package jef.database.test; import java.io.IOException; import java.io.Reader; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Properties; import jef.common.log.LogUtil; import jef.database.DbCfg; import jef.database.DbClient; import jef.database.DbClientBuilder; import jef.database.DbUtils; import jef.database.datasource.DataSourceInfoImpl; import jef.database.datasource.MapDataSourceInfoLookup; import jef.database.datasource.RoutingDataSource; import jef.database.meta.MetaHolder; import jef.tools.IOUtils; import jef.tools.JefConfiguration; import jef.tools.StringUtils; import jef.tools.reflect.BeanUtils; import jef.tools.reflect.FieldEx; import org.apache.commons.lang.ArrayUtils; import org.junit.internal.runners.statements.ExpectException; import org.junit.runner.Description; import org.junit.runner.Result; import org.junit.runner.manipulation.Filter; import org.junit.runner.manipulation.NoTestsRemainException; import org.junit.runner.notification.Failure; import org.junit.runner.notification.RunListener; import org.junit.runner.notification.RunNotifier; import org.junit.runners.BlockJUnit4ClassRunner; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.InitializationError; import org.junit.runners.model.Statement; /** * jef的单元测试工具 * * @author jiyi * */ public class JefJUnit4DatabaseTestRunner extends BlockJUnit4ClassRunner { final Map<String, Object> connections = new HashMap<String, Object>(); private boolean isRouting; private DbClient routingDbClient; private Properties pro = new Properties(); static class DbConnectionHolder { DbConnectionHolder(DataSource ds, DbClient db) { this.datasource = ds; this.db = db; } DataSource datasource; DbClient db; } private void initContext(DataSourceContext annotation) { if (annotation == null) { throw new IllegalArgumentException("Please assign a @DataSourceContext on this class."); } this.isRouting = annotation.routing(); Reader reader = IOUtils.getReader(this.getClass().getClassLoader().getResource("junit4jef.properties"), "UTF-8"); if (reader != null) { try { pro.load(reader); } catch (IOException e) { e.printStackTrace(); } } for (DataSource ds : annotation.value()) { String url = ds.url(); if (url.isEmpty()) continue; url = apply(url); if (url.isEmpty()) { LogUtil.warn("The case {} with datasource {} was not config.", this.getName(), ds.url()); continue; } connections.put(ds.name(), ds); } } @Override protected List<FrameworkMethod> getChildren() { List<FrameworkMethod> methods = super.getChildren(); if (isRouting || connections.isEmpty()) { return methods; } List<FrameworkMethod> result = new ArrayList<FrameworkMethod>(); for (String s : connections.keySet()) { for (FrameworkMethod me : methods) { IgnoreOn at = me.getMethod().getAnnotation(IgnoreOn.class); if (at == null || isNotIgnore(at, s)) { result.add(new DbFrameworkMethod(s, me.getMethod())); } } } return result; } private boolean isNotIgnore(IgnoreOn at, String s) { if (at.allButExcept().length == 0) { return !ArrayUtils.contains(at.value(), s); } else { return ArrayUtils.contains(at.allButExcept(), s); } } @Override public void filter(final Filter raw) throws NoTestsRemainException { super.filter(new Filter() { @Override public boolean shouldRun(Description description) { String testDisplay = StringUtils.substringBefore(description.getDisplayName(), " "); if (testDisplay != description.getDisplayName()) { description = Description.createTestDescription(description.getTestClass(), testDisplay); } return raw.shouldRun(description); } @Override public String describe() { return raw.describe(); } }); } @Override protected Statement methodInvoker(FrameworkMethod method, Object test) { boolean isNew = false; boolean inject = false; if (method instanceof DbFrameworkMethod) { DbFrameworkMethod dbCase = (DbFrameworkMethod) method; String dbType = dbCase.getDbType(); Object obj = connections.get(dbType); if (obj == null) { throw new IllegalArgumentException("Database " + dbType + " is unknown."); } DbConnectionHolder holder; if (obj instanceof DataSource) { holder = createDbClient((DataSource) obj);// CreateDbClient isNew = true; } else { holder = (DbConnectionHolder) obj; if (holder.db == null) { holder = createDbClient(holder.datasource);// CreateDbClient } } inject(test, holder.db, holder.datasource.field()); inject=true; } else if (isRouting) { if (routingDbClient == null) { MapDataSourceInfoLookup lookup = new MapDataSourceInfoLookup(); for (Entry<String, Object> entry : connections.entrySet()) { if (entry.getValue() instanceof DataSource) { DataSource ds = (DataSource) entry.getValue(); DataSourceInfoImpl dsi = new DataSourceInfoImpl(ds.url()); dsi.setUser(ds.user()); dsi.setPassword(ds.password()); lookup.add(ds.name(), dsi); } } RoutingDataSource rds = new RoutingDataSource(lookup); routingDbClient = new DbClient(rds); isNew = true; } inject(test, routingDbClient, ""); inject=true; } if (isNew) { List<FrameworkMethod> methods = getTestClass().getAnnotatedMethods(DatabaseInit.class); try { for (FrameworkMethod m : methods) { printMethod(m, method); m.getMethod().invoke(test); } } catch (Exception e) { throw DbUtils.toRuntimeException(e); } } printMethod(method, null); if(inject){ return super.methodInvoker(method, test); }else{ System.err.println("数据库未配置,跳过测试:"+super.describeChild(method).getDisplayName()); return new ExpectException(null, NullPointerException.class); } } private void printMethod(FrameworkMethod m, FrameworkMethod parentMethod) { String name = m.getMethod().getDeclaringClass().getName() + "." + m.getName(); if (parentMethod instanceof DbFrameworkMethod) { name = name + "@" + ((DbFrameworkMethod) parentMethod).dbType; } else if (m instanceof DbFrameworkMethod) { name = name + "@" + ((DbFrameworkMethod) m).dbType; } System.out.println("======================== " + name + " =========================="); } private void inject(Object test, DbClient db, String field) { if (StringUtils.isEmpty(field)) { field = "db"; } FieldEx f = BeanUtils.getField(test.getClass(), field); if (f == null) { throw new IllegalStateException("The class " + test.getClass() + " must have a field named '" + field + "'."); } try { f.set(test, db); } catch (Exception e) { throw new IllegalStateException("Cann't inject DbClient into '" + field + "' on " + test); } } private DbConnectionHolder createDbClient(DataSource ds) { MetaHolder.clear(); int max = JefConfiguration.getInt(DbCfg.DB_CONNECTION_POOL_MAX, 50); DbClient db = new DbClientBuilder(apply(ds.url()), apply(ds.user()), apply(ds.password())).setMaxPoolSize(max).build(); DbConnectionHolder holder = new DbConnectionHolder(ds, db); connections.put(ds.name(), holder); return holder; } private String apply(String string) { if (pro.isEmpty()) { return string; } return StringUtils.convertProperty(string, pro); } @Override protected String testName(FrameworkMethod method) { if (method instanceof DbFrameworkMethod) { return method.toString(); } else { return super.testName(method); } } @Override public void run(RunNotifier notifier) { notifier.addListener(new RunListener() { public void testRunFinished(Result result) throws Exception { super.testRunFinished(result); if (routingDbClient != null) { close(routingDbClient, ""); routingDbClient = null; } for (Object obj : connections.values()) { if (obj instanceof DbConnectionHolder) { DbConnectionHolder holder = (DbConnectionHolder) obj; close(holder.db, holder.datasource.field()); holder.db = null; } } } @Override public void testFailure(Failure failure) throws Exception { super.testFailure(failure); System.err.println("!!!Failure!!!"); failure.getException().printStackTrace(); } }); super.run(notifier); } private void close(DbClient db, String field) { List<FrameworkMethod> methods = getTestClass().getAnnotatedMethods(DatabaseDestroy.class); if (!methods.isEmpty()) { try { Object test = super.createTest(); inject(test, db, field); for (FrameworkMethod m : methods) { m.getMethod().invoke(test); } } catch (Exception e) { throw new IllegalStateException(e); } } try { db.close(); } catch (Exception e) { } ; } static class DbFrameworkMethod extends FrameworkMethod { public DbFrameworkMethod(String type, Method method) { super(method); this.dbType = type; } String dbType; public final String getDbType() { return dbType; } public final void setDbType(String dbType) { this.dbType = dbType; } @Override public String toString() { return getMethod().getName() + " @" + dbType; } } public JefJUnit4DatabaseTestRunner(Class<?> klass) throws InitializationError { super(klass); initContext(klass.getAnnotation(DataSourceContext.class)); } }