package com.lordofthejars.nosqlunit.neo4j.extension.springtemplate;

import static ch.lambdaj.Lambda.selectFirst;
import static ch.lambdaj.Lambda.having;
import static org.hamcrest.CoreMatchers.equalTo;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import org.neo4j.graphdb.GraphDatabaseService;
import org.springframework.data.neo4j.config.JtaTransactionManagerFactoryBean;
import org.springframework.data.neo4j.conversion.Result;
import org.springframework.data.neo4j.support.Neo4jTemplate;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import com.lordofthejars.nosqlunit.core.NoSqlAssertionError;
import com.lordofthejars.nosqlunit.neo4j.Neo4jComparisonStrategy;
import com.lordofthejars.nosqlunit.neo4j.Neo4jConnectionCallback;

/**
 * {@link Neo4jComparisonStrategy} that verifies every object described by an
 * expected dataset can be found among the entities loaded from the graph
 * through a {@link Neo4jTemplate}.
 *
 * <p>The check is one-directional: each expected object must have an equal
 * counterpart among the loaded entities of the same class. Entities present in
 * the graph but absent from the dataset are not reported.</p>
 */
public class SpringTemplateComparisonStrategy implements Neo4jComparisonStrategy {

	/**
	 * Compares the objects parsed from {@code dataset} with the entities
	 * currently stored in the graph.
	 *
	 * @param connection callback supplying the {@link GraphDatabaseService} to query
	 * @param dataset stream holding the expected objects, parsed with {@link DataParser}
	 * @return {@code true} when every expected object has an equal entity in the graph
	 * @throws NoSqlAssertionError if an expected object is not matched by any
	 *         loaded entity of the same class
	 * @throws Throwable declared by the implemented interface method
	 */
	@Override
	public boolean compare(Neo4jConnectionCallback connection, InputStream dataset) throws NoSqlAssertionError,
			Throwable {

		DataParser dataParser = new DataParser();
		List<Object> expectedObjects = dataParser.readValues(dataset);

		Multimap<Class<?>, Object> expectedGroupByClass = groupByClass(expectedObjects);

		// Nothing is expected, so nothing can be missing; leave the connection untouched.
		if (expectedGroupByClass.isEmpty()) {
			return true;
		}

		// The template does not depend on the class being checked, so build it once
		// instead of once per expected class.
		final Neo4jTemplate template = neo4jTemplate(connection);

		Set<Class<?>> expectedClasses = expectedGroupByClass.keySet();
		for (Class<?> expectedClass : expectedClasses) {

			Collection<Object> expectedObjectsByClass = expectedGroupByClass.get(expectedClass);
			List<Object> insertedObjects = findAndFetchAllEntitiesByClass(template, expectedClass);

			for (Object expectedObject : expectedObjectsByClass) {
				// Matching relies on the entity's equals() via Hamcrest's equalTo matcher.
				Object matchingObject = selectFirst(insertedObjects, equalTo(expectedObject));

				if (matchingObject == null) {
					throw new NoSqlAssertionError(String.format("Object %s is not found in graph.", expectedObject));
				}
			}
		}

		return true;
	}

	/**
	 * Accepts the list of properties to ignore but does nothing with it: this
	 * strategy compares whole objects with {@code equals}, so the values passed
	 * here have no effect on {@link #compare}.
	 *
	 * @param ignoreProperties property expressions; currently unused
	 */
	@Override
	public void setIgnoreProperties(String[] ignoreProperties) {
	}

	/**
	 * Loads all entities of {@code entityClass} and fetches each of them, all
	 * inside a single transaction.
	 *
	 * @param neo4jTemplate template used for both the lookup and the fetch
	 * @param entityClass entity type to load
	 * @return the fetched entities, in the iteration order of the query result
	 */
	private List<Object> findAndFetchAllEntitiesByClass(final Neo4jTemplate neo4jTemplate,
			final Class<?> entityClass) {

		TransactionTemplate transactionalTemplate = transactionalTemplate(neo4jTemplate.getGraphDatabaseService());

		return transactionalTemplate.execute(new TransactionCallback<List<Object>>() {

			@Override
			public List<Object> doInTransaction(TransactionStatus status) {
				Result<?> allEntities = neo4jTemplate.findAll(entityClass);
				return fetchAll(neo4jTemplate, allEntities);
			}

		});
	}

	/**
	 * Passes every element of {@code entities} through
	 * {@link Neo4jTemplate#fetch(Object)} and collects the results.
	 * Called from within the transaction opened by
	 * {@link #findAndFetchAllEntitiesByClass}.
	 */
	private List<Object> fetchAll(Neo4jTemplate neo4jTemplate, Result<?> entities) {
		List<Object> fetchedData = new ArrayList<Object>();
		for (Object entity : entities) {
			fetchedData.add(neo4jTemplate.fetch(entity));
		}
		return fetchedData;
	}

	/**
	 * Builds a {@link TransactionTemplate} backed by the transaction manager
	 * that {@link JtaTransactionManagerFactoryBean} produces for the given
	 * database.
	 *
	 * @throws IllegalArgumentException wrapping any exception raised while
	 *         creating the transaction manager
	 */
	private TransactionTemplate transactionalTemplate(GraphDatabaseService graphDatabaseService) {
		try {
			JtaTransactionManagerFactoryBean jtaTransactionManagerFactoryBean = new JtaTransactionManagerFactoryBean(
					graphDatabaseService);
			return new TransactionTemplate(jtaTransactionManagerFactoryBean.getObject());
		} catch (Exception e) {
			throw new IllegalArgumentException(e);
		}
	}

	/**
	 * Groups the given objects by their runtime class.
	 *
	 * @param objects objects to group; each element must be non-null because
	 *        its class is used as the key
	 * @return multimap from runtime class to the objects of that class
	 */
	private Multimap<Class<?>, Object> groupByClass(List<Object> objects) {

		Multimap<Class<?>, Object> groupByClass = ArrayListMultimap.create();

		for (Object object : objects) {
			groupByClass.put(object.getClass(), object);
		}

		return groupByClass;
	}

	/**
	 * Creates a {@link Neo4jTemplate} over the database exposed by the
	 * connection callback.
	 */
	private Neo4jTemplate neo4jTemplate(Neo4jConnectionCallback connection) {
		GraphDatabaseService graphDatabaseService = connection.graphDatabaseService();
		return new Neo4jTemplate(graphDatabaseService);
	}

}