/* * Copyright 2011 Red Hat Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.persistence.marshalling.util; import bitronix.tm.resource.jdbc.PoolingDataSource; import javax.persistence.EntityManagerFactory; import javax.persistence.Persistence; import javax.persistence.Table; import java.io.File; import java.io.FilenameFilter; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.net.URL; import java.sql.Connection; import java.sql.DriverManager; import java.sql.Statement; import java.util.ArrayList; import java.util.HashMap; import java.util.Properties; import static org.drools.persistence.util.DroolsPersistenceUtil.DATASOURCE; import static org.drools.persistence.util.DroolsPersistenceUtil.getDatasourceProperties; import static org.drools.persistence.util.DroolsPersistenceUtil.setupPoolingDataSource; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.kie.api.runtime.EnvironmentName.ENTITY_MANAGER_FACTORY; public class MarshallingDBUtil { protected static String MARSHALLING_TEST_DB = "testData"; protected static final String MARSHALLING_BASE_DB = "baseData-current"; protected static boolean clearMarshallingTestDb = true; /** * This method is necessary in order to setup the infrastructure to save, retrieve and compare * the marshalled data generated by Drools/jBPM. * <p/> * This method does the following:<ul> * <li>Find the (absolute) path of the marshalling test database (which stores marshalled data generated during tests).</li> * <li>If this is the first time the test db is being used, delete and recreate the test db for this test run.</li> * </ul> * <i>Note</i>: we find the database in src/test/resources -- NOT in target/test-classes/.. or whichever * folder your IDE/build system might copy the database to. * <p/> * @param jdbcProps The JDBC (database) properties. * @param testClass The class of the test being run (that will generate marshalled data). * @return A Sting containing the URL (in src/test/resources) of the database. */ public static String initializeTestDb(Properties jdbcProps, Class<?> testClass) { Object makeBaseDb = jdbcProps.getProperty("makeBaseDb"); if( "true".equals(makeBaseDb) ) { MARSHALLING_TEST_DB = MARSHALLING_BASE_DB; clearMarshallingTestDb = false; } String dbPath = generatePathToTestDb(testClass); if( clearMarshallingTestDb ) { clearMarshallingTestDb = false; URL dbUrl = Object.class.getResource("/" + MARSHALLING_TEST_DB + ".h2.db"); deleteTestDatabase(dbUrl, dbPath); createMarshallingTestDatabase(dbPath, jdbcProps.getProperty("driverClassName")); } String jdbcURLBase = jdbcProps.getProperty("url"); return jdbcURLBase + dbPath; } /** * This method constructs the path to the database and ensures that the * file that the path refers to exists. * @param testClass The class of the test doing this, in order to access the classLoader/resources. * @return A String containg the absolute URL/path of the test DB. */ protected static String generatePathToTestDb(Class<?> testClass) { URL classUrl = testClass.getResource(testClass.getSimpleName() + ".class"); String projectPath = classUrl.getPath().replaceFirst("target.*", ""); String resourcesPath = projectPath + "target/test-classes/marshalling/"; new File(resourcesPath).mkdirs(); String dbPath = resourcesPath + MARSHALLING_TEST_DB; return dbPath; } /** * This method deletes the test database file. * @param dbUrl * @param dbPath */ private static void deleteTestDatabase(URL dbUrl, String dbPath) { if( dbUrl != null ) { new File(dbUrl.getPath()).delete(); } new File(dbPath + ".h2.db").delete(); } /** * This method quickly creates a H2 database: a direct JDBC connection is used for this. * <p/> * @param dbPath The path to the database. * @param driverClass The name of the JDBC driver class. */ private static void createMarshallingTestDatabase(String dbPath, String driverClass) { try { Class.forName(driverClass); Connection conn = DriverManager.getConnection("jdbc:h2:"+dbPath); conn.setAutoCommit(true); Statement stat = conn.createStatement(); String dropTableQuery = "drop table if exists " + getTableName(MarshalledData.class); stat.executeUpdate(dropTableQuery); conn.close(); } catch (Exception e) { e.printStackTrace(); fail( "Unable to create marshalling database: " + dbPath); } } /** * This method uses reflection to get the name of the table used for an entity. * @param dataClass The class for which we want the table name. * @return A String containing the name of the table for the given class. */ private static String getTableName(Class<?> dataClass) { String tableName = null; Annotation [] anno = dataClass.getDeclaredAnnotations(); for( int i = 0; i < anno.length; ++i ) { Class<?> annoClass = anno[i].annotationType(); if( annoClass.equals(Table.class) ) { Method [] annoMethod = annoClass.getMethods(); int a = 0; while( a < annoMethod.length && ! annoMethod[a].getName().equals("name") ) { ++a; } try { tableName = (String) annoMethod[a].invoke(anno[i]); } catch (Exception e) { e.printStackTrace(); fail( "Unable to generate HQL query - could not get table name: " + e.getMessage() ); } } } if( tableName == null ) { tableName = dataClass.getName(); tableName = tableName.substring(tableName.lastIndexOf('.')+1).toLowerCase(); } return tableName; } public static HashMap<String, Object> initializeMarshalledDataEMF(String persistenceUnitName, Class<?> testClass, boolean useBaseDb) { return initializeMarshalledDataEMF(persistenceUnitName, testClass, useBaseDb, null ); } /** * This method initializes an EntityManagerFactory with a connection to the base (marshalled) data DB. * This database stores the marshalled data that we expect (for a given drools/jbpm version). * @param persistenceUnitName The persistence unit name. * @param testClass The class of the (local) unit test running. * @return A HashMap<String, Object> containg the datasource and entity manager factory. */ public static HashMap<String, Object> initializeMarshalledDataEMF(String persistenceUnitName, Class<?> testClass, boolean useBaseDb, String baseDbVer ) { HashMap<String, Object> context = new HashMap<String, Object>(); Properties dsProps = getDatasourceProperties(); String driverClass = dsProps.getProperty("driverClassName"); if ( ! driverClass.startsWith("org.h2")) { return null; } String dbFilePath = generatePathToTestDb(testClass); if( useBaseDb ) { dbFilePath = dbFilePath.replace(MARSHALLING_TEST_DB, MARSHALLING_BASE_DB); if( baseDbVer != null && baseDbVer.length() > 0) { dbFilePath = dbFilePath.replace("current", baseDbVer); } } String jdbcURLBase = dsProps.getProperty("url"); // trace level file = 0 means that modifying the inode of the db file will _not_ cause a "corrupted" exception String jdbcUrl = jdbcURLBase + dbFilePath; // Setup the datasource PoolingDataSource ds1 = setupPoolingDataSource(dsProps); ds1.getDriverProperties().setProperty("url", jdbcUrl ); ds1.init(); context.put(DATASOURCE, ds1); // Setup persistence Properties overrideProperties = new Properties(); overrideProperties.setProperty("hibernate.connection.url", jdbcUrl); EntityManagerFactory emf = Persistence.createEntityManagerFactory(persistenceUnitName, overrideProperties); context.put(ENTITY_MANAGER_FACTORY, emf); return context; } protected static String [] getListOfBaseDbVers(Class<?> testClass) { String [] versions; ArrayList<String> versionList = new ArrayList<String>(); String path = generatePathToTestDb(testClass); path = path.replace("target" + File.separatorChar + "test-classes" + File.separatorChar + "marshalling" + File.separatorChar + "testData", "src" + File.separatorChar + "test" + File.separatorChar + "resources" + File.separatorChar + "marshalling" + File.separatorChar ); File marshallingDir = new File(path); FilenameFilter baseDatafilter = new FilenameFilter() { public boolean accept(File dir, String name) { return name.startsWith("baseData"); } }; String [] dbFiles = marshallingDir.list(baseDatafilter); assertTrue("No files found in marshalling directory [" + marshallingDir + "]!", dbFiles != null && dbFiles.length > 0 ); for(int i = 0; i < dbFiles.length; ++i ) { String ver = dbFiles[i]; ver = ver.replace(".h2.db", ""); ver = ver.replace("baseData", ""); ver = ver.replace("-", ""); versionList.add(ver); } versions = new String [versionList.size()]; for( int v = 0; v < versions.length; ++v ) { versions[v] = versionList.get(v); } return versions; } }