package org.hivedb.hibernate.simplified.session;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.hibernate.Interceptor;
import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.hibernate.shards.Shard;
import org.hibernate.shards.session.ShardedSessionFactory;
import org.hibernate.shards.session.ShardedSessionImpl;
import org.hivedb.Hive;
import org.hivedb.HiveRuntimeException;
import org.hivedb.configuration.EntityHiveConfig;
import org.hivedb.hibernate.RecordNodeOpenSessionEvent;
import org.hivedb.hibernate.simplified.HiveInterceptorDecorator;
import org.hivedb.meta.Node;
import org.hivedb.util.functional.*;

import java.sql.SQLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

// TODO Enforce ReadWrite Constraints
/**
 * Hive-aware {@link HiveSessionFactory} implementation. Sessions opened with a
 * primary index key, resource id, or secondary index key are routed to the single
 * data node that the hive directory resolves for that key; sessions opened without
 * a key are sharded sessions spanning all nodes.
 */
public class HiveSessionFactoryImpl implements HiveSessionFactory {
  private final static Log log = LogFactory.getLog(HiveSessionFactoryImpl.class);
  private ShardedSessionFactory factory;
  private EntityHiveConfig config;
  private Hive hive;
  private Map<Integer, SessionFactory> factories;
  public Map<Integer, Node> nodesById;

  public HiveSessionFactoryImpl(ShardedSessionFactory shardedFactory, Hive hive, EntityHiveConfig hiveConfig) {
    this.factory = shardedFactory;
    this.hive = hive;
    this.config = hiveConfig;
    this.factories = buildSessionFactoryMap(factory, hive.getNodes());
    this.nodesById = buildNodeToIdMap(hive);
  }

  public Session openSession() {
    return openShardedSession(getDefaultInterceptor());
  }

  public Session openSession(Interceptor interceptor) {
    return openShardedSession(wrapWithHiveInterceptor(interceptor));
  }

  public Session openSession(Object primaryIndexKey) {
    return openSession(hive.directory().getNodeIdsOfPrimaryIndexKey(primaryIndexKey), getDefaultInterceptor());
  }

  public Session openSession(Object primaryIndexKey, Interceptor interceptor) {
    return openSession(hive.directory().getNodeIdsOfPrimaryIndexKey(primaryIndexKey), wrapWithHiveInterceptor(interceptor));
  }

  public Session openSession(String resource, Object resourceId) {
    return openSession(hive.directory().getNodeIdsOfResourceId(resource, resourceId), getDefaultInterceptor());
  }

  public Session openSession(String resource, Object resourceId, Interceptor interceptor) {
    return openSession(hive.directory().getNodeIdsOfResourceId(resource, resourceId), wrapWithHiveInterceptor(interceptor));
  }

  public Session openSession(String resource, String indexName, Object secondaryIndexKey) {
    return openSession(hive.directory().getNodeIdsOfSecondaryIndexKey(resource, indexName, secondaryIndexKey), getDefaultInterceptor());
  }

  public Session openSession(String resource, String indexName, Object secondaryIndexKey, Interceptor interceptor) {
    return openSession(hive.directory().getNodeIdsOfSecondaryIndexKey(resource, indexName, secondaryIndexKey), wrapWithHiveInterceptor(interceptor));
  }

  public Interceptor getDefaultInterceptor() {
    return new HiveInterceptorDecorator(config, hive);
  }

  public Interceptor wrapWithHiveInterceptor(Interceptor interceptor) {
    return new HiveInterceptorDecorator(interceptor, config, hive);
  }

  private Session openShardedSession(Interceptor interceptor) {
    // Callers pass an interceptor that is already hive-aware (either the default
    // HiveInterceptorDecorator or a caller-supplied interceptor wrapped in one),
    // so it is not wrapped a second time here.
    return addEventsToShardedSession((ShardedSessionImpl) factory.openSession(interceptor));
  }

  private Session openSession(Collection<Integer> nodeIds, Interceptor interceptor) {
    Collection<Node> nodes = getNodesFromIds(nodeIds);
    if (nodes.size() > 1)
      throw new IllegalStateException(
        "Record appears to be stored on more than one node. " +
        "HiveDB's Hibernate support currently only allows a record to be stored on a single node.");
    return openSession(Atom.getFirstOrThrow(nodes), interceptor);
  }

  private Session openSession(Node node, Interceptor interceptor) {
    return addEventsToSession(getSessionFactory(node).openSession(interceptor));
  }

  private SessionFactory getSessionFactory(Node node) {
    return factories.get(node.getId());
  }

  private Session addEventsToShardedSession(ShardedSessionImpl session) {
    // Register an open-session event on each shard so the owning node is recorded
    // when the shard's session is actually opened.
    for (Shard shard : session.getShards()) {
      shard.addOpenSessionEvent(new RecordNodeOpenSessionEvent());
    }
    return session;
  }

  private Session addEventsToSession(Session session) {
    RecordNodeOpenSessionEvent.setNode(session);
    return session;
  }

  @SuppressWarnings("deprecation")
  public String extractFactoryURL(SessionFactory factory) {
    // Uses the deprecated Session.connection() to read the JDBC URL from the
    // connection metadata; used below to match session factories to hive nodes.
    Session session = null;
    try {
      session = factory.openSession();
      return session.connection().getMetaData().getURL();
    } catch (SQLException e) {
      throw new HiveRuntimeException(e);
    } finally {
      if (session != null)
        session.close();
    }
  }

  private Collection<Node> getNodesFromIds(Collection<Integer> ids) {
    return Transform.map(new Unary<Integer, Node>() {
      public Node f(Integer item) {
        return nodesById.get(item);
      }
    }, ids);
  }

  private Map<Integer, Node> buildNodeToIdMap(Hive hive) {
    Map<Integer, Node> nodeMap = new HashMap<Integer, Node>();
    for (Node node : hive.getNodes()) {
      nodeMap.put(node.getId(), node);
    }
    return nodeMap;
  }

  private Map<Integer, SessionFactory> buildSessionFactoryMap(ShardedSessionFactory factory, Collection<Node> nodes) {
    // Pair each per-shard SessionFactory with its hive node by comparing the
    // factory's JDBC URL against the node URIs.
    Map<Integer, SessionFactory> factoryMap = new HashMap<Integer, SessionFactory>();
    for (SessionFactory f : factory.getSessionFactories()) {
      Node node = matchNodeToFactoryUrl(extractFactoryURL(f), nodes);
      factoryMap.put(node.getId(), f);
    }
    return factoryMap;
  }

  private Node matchNodeToFactoryUrl(final String url, Collection<Node> nodes) {
    return Filter.grepSingle(new Predicate<Node>() {
      public boolean f(Node item) {
        return item.getUri().startsWith(url);
      }
    }, nodes);
  }
}
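
// Usage sketch (illustrative only): assumes a ShardedSessionFactory, Hive, and
// EntityHiveConfig have already been configured elsewhere; the variable names
// below (shardedFactory, hive, hiveConfig, customerId) are hypothetical and not
// part of this class.
//
//   HiveSessionFactory sessionFactory =
//       new HiveSessionFactoryImpl(shardedFactory, hive, hiveConfig);
//
//   // Open a session against the node that owns this primary index key.
//   Session session = sessionFactory.openSession(customerId);
//   try {
//     // ... load or save entities on the resolved node ...
//   } finally {
//     session.close();
//   }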