/* * Copyright (C) 2012-2016 DuyHai DOAN * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package info.archinnov.achilles.junit; import static info.archinnov.achilles.junit.AchillesTestResource.Steps.BOTH; import static info.archinnov.achilles.validation.Validator.validateTrue; import static java.util.stream.Collectors.toList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; import java.util.stream.Stream; import org.junit.rules.ExternalResource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.datastax.driver.core.Cluster; import com.datastax.driver.core.PreparedStatement; import com.datastax.driver.core.Session; import info.archinnov.achilles.embedded.CassandraEmbeddedServer; import info.archinnov.achilles.embedded.CassandraEmbeddedServerBuilder; import info.archinnov.achilles.internals.cache.StatementsCache; import info.archinnov.achilles.internals.runtime.AbstractManagerFactory; import info.archinnov.achilles.logger.AchillesLoggers; import info.archinnov.achilles.script.ScriptExecutor; import info.archinnov.achilles.type.TypedMap; /** * <strong>WARNING: this AchillesTestResource will use an unsafe Cassandra daemon, it is not suitable for production</strong> * <br/><br/> * Test resource for JUnit. 
Example of usage:
 *
 * <pre class="code"><code class="java">
 *
 * {@literal @}Rule
 * public AchillesTestResource&lt;ManagerFactory&gt; resource = AchillesTestResourceBuilder
 *     .forJunit()
 *     .withKeyspace("unit_test") // default keyspace = achilles_embedded
 *     .entityClassesToTruncate(SimpleEntity.class)
 *     .truncateBeforeAndAfterTest()
 *     .build((cluster, statementsCache) -> ManagerFactoryBuilder
 *         .builder(cluster)
 *         .doForceSchemaCreation(true)
 *         .withStatementCache(statementsCache)
 *         .withDefaultKeyspaceName(DEFAULT_CASSANDRA_EMBEDDED_KEYSPACE_NAME)
 *         .build()
 *     );
 * private Session session = resource.getNativeSession();
 * private ScriptExecutor scriptExecutor = resource.getScriptExecutor();
 * private SimpleEntity_Manager manager = resource.getManagerFactory().forSimpleEntity();
 * {@literal @}Test
 * public void should_test_xxx() throws Exception {
 *     //Given
 *     final long id = RandomUtils.nextLong(0, Long.MAX_VALUE);
 *     scriptExecutor.executeScriptTemplate("SimpleEntity/insert_single_row.cql", ImmutableMap.of("id", id));
 *     //When
 *     //Then
 *     Row actual = session.execute("SELECT ....").one();
 *     assertTrue(actual.getString("xxx").equals("yyy"));
 *     ...
* } * </code></pre> */ public class AchillesTestResource<T extends AbstractManagerFactory> extends ExternalResource { // Default statement cache for unit testing private static final StatementsCache STATEMENTS_CACHE = new StatementsCache(10000); private static final Logger DML_LOG = LoggerFactory.getLogger(AchillesLoggers.ACHILLES_DML_STATEMENT); private static final Map<String, PreparedStatement> TABLES_TO_TRUNCATE = new ConcurrentHashMap<>(); private final TypedMap cassandraParams; private final Optional<String> keyspaceName; private final List<PreparedStatement> truncateStatements; private final CassandraEmbeddedServer server; private final T managerFactory; private final Session session; private final ScriptExecutor scriptExecutor; private final Steps steps; public AchillesTestResource(BiFunction<Cluster, StatementsCache, T> managerFactoryBuilder, TypedMap cassandraParams, Optional<String> keyspaceName, List<String> tablesToTruncate, List<Class<?>> entityClassesToTruncate) { this(managerFactoryBuilder, cassandraParams, keyspaceName, BOTH, tablesToTruncate, entityClassesToTruncate); } public AchillesTestResource(BiFunction<Cluster, StatementsCache, T> managerFactoryBuilder, TypedMap cassandraParams, Optional<String> keyspaceName, Steps cleanUpSteps, List<String> tablesToTruncate, List<Class<?>> entityClassesToTruncate) { this.cassandraParams = cassandraParams; this.keyspaceName = keyspaceName; this.steps = cleanUpSteps; this.server = buildServer(); this.session = buildSession(this.server); this.scriptExecutor = new ScriptExecutor(this.session); this.managerFactory = buildManagerFactory(this.server, managerFactoryBuilder); this.truncateStatements = determineTableToTruncate(this.managerFactory, this.session, tablesToTruncate, entityClassesToTruncate); } public Session getNativeSession() { return this.session; } public ScriptExecutor getScriptExecutor() { return this.scriptExecutor; } public T getManagerFactory() { return this.managerFactory; } private 
CassandraEmbeddedServer buildServer() { return CassandraEmbeddedServerBuilder .builder() .withParams(cassandraParams) .buildServer(); } private T buildManagerFactory(CassandraEmbeddedServer server, BiFunction<Cluster, StatementsCache, T> managerFactoryBuilder) { return managerFactoryBuilder.apply(server.getNativeCluster(), STATEMENTS_CACHE); } private Session buildSession(CassandraEmbeddedServer server) { final Session defaultSession = server.getNativeSession(); final Session session = keyspaceName .filter(ks -> !ks.equals(defaultSession.getLoggedKeyspace())) .map(x -> defaultSession.getCluster().connect(x)) .orElse(defaultSession); server.registerSessionForShutdown(session); return session; } private List<PreparedStatement> determineTableToTruncate(T managerFactory, Session session, List<String> tablesToTruncate, List<Class<?>> entityClassesToTruncate) { entityClassesToTruncate .forEach(clazz -> validateTrue(managerFactory.staticTableNameFor(clazz).isPresent(), "Entity class '%s' is not managed by Achilles. 
Did you forget to add @Table annotation ?", clazz.getCanonicalName())); maybeGenerateTruncateStatement(session, entityClassesToTruncate .stream() .map(clazz -> managerFactory.staticTableNameFor(clazz).get().toLowerCase()) .collect(toList())); maybeGenerateTruncateStatement(session, tablesToTruncate); return Stream.concat(tablesToTruncate.stream(), entityClassesToTruncate.stream().map(clazz -> managerFactory.staticTableNameFor(clazz).get().toLowerCase())) .map(TABLES_TO_TRUNCATE::get) .collect(toList()); } private void maybeGenerateTruncateStatement(Session session, List<String> tablesToTruncate) { tablesToTruncate .stream() .filter(tableName -> !TABLES_TO_TRUNCATE.containsKey(tableName)) .forEach(table -> TABLES_TO_TRUNCATE.put(table, session.prepare("TRUNCATE " + table))); } protected void before() throws Throwable { if (steps.isBefore()) truncateTables(); } protected void after() { if (steps.isAfter()) truncateTables(); } public void truncateTables() { truncateStatements .forEach(statement -> { if (DML_LOG.isDebugEnabled()) { DML_LOG.debug(statement.getQueryString()); } session.execute(statement.bind()); }); } public enum Steps { BEFORE_TEST, AFTER_TEST, BOTH; public boolean isBefore() { return (this == BOTH || this == BEFORE_TEST); } public boolean isAfter() { return (this == BOTH || this == AFTER_TEST); } } }