/* * Copyright 2011 Red Hat Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.persistence.util; import static org.junit.Assert.*; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; import java.sql.SQLException; import java.util.Properties; import org.drools.persistence.jta.JtaTransactionManager; import org.drools.runtime.Environment; import org.drools.runtime.EnvironmentName; import org.h2.tools.DeleteDbFiles; import org.h2.tools.Server; import org.junit.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import bitronix.tm.resource.jdbc.PoolingDataSource; public class PersistenceUtil { private static Logger logger = LoggerFactory.getLogger( PersistenceUtil.class ); public static final String DROOLS_PERSISTENCE_UNIT_NAME = "org.drools.persistence.jpa"; public static final String JBPM_PERSISTENCE_UNIT_NAME = "org.jbpm.persistence.jpa"; protected static final String DATASOURCE_PROPERTIES = "/datasource.properties"; private static TestH2Server h2Server = new TestH2Server(); private static Properties defaultProperties = null; private static Properties getDefaultProperties() { if( defaultProperties == null ) { String [] keyArr = { "serverName", "portNumber", "databaseName", "url", "user", "password", "driverClassName", "className", "maxPoolSize", "allowLocalTransactions" }; String [] defaultPropArr= { "", "", "", "jdbc:h2:tcp://localhost/JPADroolsFlow", "sa", "", "org.h2.Driver", "bitronix.tm.resource.jdbc.lrc.LrcXADataSource", "16", "true" }; Assert.assertTrue("Unequal number of keys for default properties", keyArr.length == defaultPropArr.length); defaultProperties = new Properties(); for( int i = 0; i < keyArr.length; ++i ) { defaultProperties.put(keyArr[i], defaultPropArr[i]); } } return defaultProperties; } public static Properties getDatasourceProperties() { boolean propertiesNotFound = false; InputStream propsInputStream = PersistenceUtil.class.getResourceAsStream(DATASOURCE_PROPERTIES); Properties props = new Properties(); if( propsInputStream != null ) { try { props.load(propsInputStream); } catch (IOException ioe) { propertiesNotFound = true; logger.warn("Unable to find properties, using default H2 properties: " + ioe.getMessage()); ioe.printStackTrace(); } } else { propertiesNotFound = true; } String password = props.getProperty("password"); if( "${maven.jdbc.password}".equals(password) || propertiesNotFound ) { props = getDefaultProperties(); } return props; } public static PoolingDataSource setupPoolingDataSource() { Properties dsProps = getDatasourceProperties(); PoolingDataSource pds = new PoolingDataSource(); // The name must match what's in the persistence.xml! pds.setUniqueName("jdbc/testDS1"); pds.setClassName(dsProps.getProperty("className")); pds.setMaxPoolSize(Integer.parseInt(dsProps.getProperty("maxPoolSize"))); pds.setAllowLocalTransactions(Boolean.parseBoolean(dsProps .getProperty("allowLocalTransactions"))); for (String propertyName : new String[] { "user", "password" }) { pds.getDriverProperties().put(propertyName, dsProps.getProperty(propertyName)); } String driverClass = dsProps.getProperty("driverClassName"); if (driverClass.startsWith("org.h2")) { h2Server.start(); for (String propertyName : new String[] { "url", "driverClassName" }) { pds.getDriverProperties().put(propertyName, dsProps.getProperty(propertyName)); } } else { pds.setClassName(dsProps.getProperty("className")); if( driverClass.startsWith("oracle") ) { pds.getDriverProperties().put("driverType", "thin"); pds.getDriverProperties().put("URL", dsProps.getProperty("url")); } else if( driverClass.startsWith("com.ibm.db2") ) { // placeholder for eventual future modifications } else if( driverClass.startsWith("com.microsoft") ) { for (String propertyName : new String[] { "serverName", "portNumber", "databaseName" }) { pds.getDriverProperties().put(propertyName, dsProps.getProperty(propertyName)); } pds.getDriverProperties().put("URL", dsProps.getProperty("url")); pds.getDriverProperties().put("selectMethod", "cursor"); pds.getDriverProperties().put("InstanceName", "MSSQL01"); // pds.getDriverProperties().put("instanceName", dsProps.getProperty("databaseName")); // do nothing // pds.getDriverProperties().put("instanceName", "mssql"); } else if( driverClass.startsWith("com.mysql") ) { for (String propertyName : new String[] { "databaseName", "serverName", "portNumber", "url" }) { pds.getDriverProperties().put(propertyName, dsProps.getProperty(propertyName)); } } else if( driverClass.startsWith("com.sybase") ) { for (String propertyName : new String[] { "databaseName", "portNumber", "serverName" }) { pds.getDriverProperties().put(propertyName, dsProps.getProperty(propertyName)); } pds.getDriverProperties().put("REQUEST_HA_SESSION", "false"); pds.getDriverProperties().put("networkProtocol", "Tds"); } else if( driverClass.startsWith("org.postgresql") ) { for (String propertyName : new String[] { "databaseName", "portNumber", "serverName" }) { pds.getDriverProperties().put(propertyName, dsProps.getProperty(propertyName)); } } else { throw new RuntimeException("Unknown driver class: " + driverClass); } } return pds; } public static boolean useTransactions() { boolean useTransactions = false; String databaseDriverClassName = getDatasourceProperties().getProperty("driverClassName"); // Postgresql has a "Large Object" api which REQUIRES the use of transactions, since // @Lob/byte array is actually stored in multiple tables. if( databaseDriverClassName.startsWith("org.postgresql") ) { useTransactions = true; } return useTransactions; } private static class TestH2Server { private Server realH2Server; public void start() { if (realH2Server == null || !realH2Server.isRunning(false)) { try { DeleteDbFiles.execute("", "JPADroolsFlow", true); realH2Server = Server.createTcpServer(new String[0]); realH2Server.start(); } catch (SQLException e) { throw new RuntimeException("can't start h2 server db", e); } } } @Override protected void finalize() throws Throwable { if (realH2Server != null) { realH2Server.stop(); } DeleteDbFiles.execute("", "JPADroolsFlow", true); super.finalize(); } } public static Object getValueOfField(String fieldname, Object source) { String sourceClassName = source.getClass().getSimpleName(); Field field = null; try { field = source.getClass().getDeclaredField(fieldname); field.setAccessible(true); } catch (SecurityException e) { fail("Unable to retrieve " + fieldname + " field from " + sourceClassName + ": " + e.getCause()); } catch (NoSuchFieldException e) { fail("Unable to retrieve " + fieldname + " field from " + sourceClassName + ": " + e.getCause()); } assertNotNull("." + fieldname + " field is null!?!", field); Object fieldValue = null; try { fieldValue = field.get(source); } catch (IllegalArgumentException e) { fail("Unable to retrieve value of " + fieldname + " from " + sourceClassName + ": " + e.getCause()); } catch (IllegalAccessException e) { fail("Unable to retrieve value of " + fieldname + " from " + sourceClassName + ": " + e.getCause()); } return fieldValue; } }