package liquibase.sqlgenerator;
import liquibase.change.Change;
import liquibase.database.Database;
import liquibase.exception.UnexpectedLiquibaseException;
import liquibase.structure.DatabaseObject;
import liquibase.exception.ValidationErrors;
import liquibase.exception.Warnings;
import liquibase.servicelocator.ServiceLocator;
import liquibase.sql.Sql;
import liquibase.statement.SqlStatement;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.*;
/**
* SqlGeneratorFactory is a singleton registry of SqlGenerators.
* Use the register(SqlGenerator) method to add custom SqlGenerators,
* and the getGenerators(SqlStatement, Database) method to retrieve the SqlGenerators that can be used for a given SqlStatement.
*/
public class SqlGeneratorFactory {

    private static SqlGeneratorFactory instance;

    private final List<SqlGenerator> generators = new ArrayList<SqlGenerator>();

    //caches for expensive reflection based calls that slow down Liquibase initialization: CORE-1207
    private final Map<Class<?>, Type[]> genericInterfacesCache = new HashMap<Class<?>, Type[]>();
    private final Map<Class<?>, Type> genericSuperClassCache = new HashMap<Class<?>, Type>();

    // Lookup cache for getGenerators(), keyed by "<statement class>:<database short name>:<major version>".
    // It is cleared whenever the set of registered generators changes.
    private final Map<String, SortedSet<SqlGenerator>> generatorsByKey = new HashMap<String, SortedSet<SqlGenerator>>();

    /**
     * Instantiates (through the no-argument constructor) and registers every SqlGenerator class
     * reported by the ServiceLocator.
     *
     * @throws RuntimeException wrapping any failure while locating or instantiating a generator
     */
    private SqlGeneratorFactory() {
        Class[] classes;
        try {
            classes = ServiceLocator.getInstance().findClasses(SqlGenerator.class);

            for (Class clazz : classes) {
                register((SqlGenerator) clazz.getConstructor().newInstance());
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Return singleton SqlGeneratorFactory, creating it on first use.
     */
    public static synchronized SqlGeneratorFactory getInstance() {
        if (instance == null) {
            instance = new SqlGeneratorFactory();
        }
        return instance;
    }

    /**
     * Replaces the singleton with a freshly built factory, discarding any generators
     * that were registered or unregistered on the previous instance.
     */
    public static synchronized void reset() {
        instance = new SqlGeneratorFactory();
    }

    /**
     * Adds a generator to the registry.
     * The lookup cache is cleared so that later getGenerators() calls take the new generator into account.
     *
     * @param generator the generator to add
     */
    public void register(SqlGenerator generator) {
        generators.add(generator);
        generatorsByKey.clear();
    }

    /**
     * Removes a generator from the registry.
     * The lookup cache is cleared so that later getGenerators() calls no longer return it.
     *
     * @param generator the generator to remove
     */
    public void unregister(SqlGenerator generator) {
        generators.remove(generator);
        generatorsByKey.clear();
    }

    /**
     * Removes a registered generator whose class is exactly the given class.
     * If several registered generators have that class, the one registered last is removed.
     * Does nothing when no registered generator matches.
     *
     * @param generatorClass the exact class of the generator to remove
     */
    public void unregister(Class generatorClass) {
        SqlGenerator toRemove = null;
        for (SqlGenerator existingGenerator : generators) {
            if (existingGenerator.getClass().equals(generatorClass)) {
                toRemove = existingGenerator;
            }
        }

        if (toRemove != null) {
            unregister(toRemove);
        }
    }

    /**
     * Returns the live list of all registered generators (not a copy).
     */
    protected Collection<SqlGenerator> getGenerators() {
        return generators;
    }

    /**
     * Returns the registered generators that declare (via their generic type arguments) that they handle
     * the statement's class and whose supports() method accepts the statement for the given database.
     * The set is ordered by SqlGeneratorComparator.
     * <p>
     * Results are cached per statement class, database short name and database major version.
     * The returned set is the cached instance itself, so callers should not modify it.
     *
     * @param statement the statement to find generators for
     * @param database  the target database; may be null
     * @return the matching generators, possibly empty, never null
     */
    public SortedSet<SqlGenerator> getGenerators(SqlStatement statement, Database database) {
        String databaseName = describeDatabase(database);

        int version = 0;
        if (database != null) {
            try {
                version = database.getDatabaseMajorVersion();
            } catch (Throwable e) {
                // best effort: if the version cannot be determined, version 0 is used in the cache key
                version = 0;
            }
        }

        String key = statement.getClass().getName() + ":" + databaseName + ":" + version;

        if (generatorsByKey.containsKey(key)) {
            return generatorsByKey.get(key);
        }

        SortedSet<SqlGenerator> validGenerators = new TreeSet<SqlGenerator>(new SqlGeneratorComparator());

        for (SqlGenerator generator : getGenerators()) {
            // Walk up the class hierarchy, inspecting the generic superclass seen on the previous
            // step and every generic interface declared at the current level.
            Class<?> clazz = generator.getClass();
            Type classType = null;
            while (clazz != null) {
                if (classType instanceof ParameterizedType) {
                    checkType(classType, statement, generator, database, validGenerators);
                }

                for (Type type : getGenericInterfaces(clazz)) {
                    if (type instanceof ParameterizedType) {
                        checkType(type, statement, generator, database, validGenerators);
                    } else if (isTypeEqual(type, SqlGenerator.class)) {
                        // SqlGenerator implemented without type arguments: rely on supports() alone
                        //noinspection unchecked
                        if (generator.supports(statement, database)) {
                            validGenerators.add(generator);
                        }
                    }
                }
                classType = getGenericSuperclass(clazz);
                clazz = clazz.getSuperclass();
            }
        }

        generatorsByKey.put(key, validGenerators);
        return validGenerators;
    }

    /**
     * Cached version of Class.getGenericInterfaces().
     */
    private Type[] getGenericInterfaces(Class<?> clazz) {
        if (genericInterfacesCache.containsKey(clazz)) {
            return genericInterfacesCache.get(clazz);
        }

        Type[] genericInterfaces = clazz.getGenericInterfaces();
        genericInterfacesCache.put(clazz, genericInterfaces);
        return genericInterfaces;
    }

    /**
     * Cached version of Class.getGenericSuperclass(). A null result is cached as well.
     */
    private Type getGenericSuperclass(Class<?> clazz) {
        if (genericSuperClassCache.containsKey(clazz)) {
            return genericSuperClassCache.get(clazz);
        }

        Type genericSuperclass = clazz.getGenericSuperclass();
        genericSuperClassCache.put(clazz, genericSuperclass);
        return genericSuperclass;
    }

    /**
     * Compares a reflected type with a class. Class objects are compared by fully qualified name;
     * any other kind of Type falls back to equals().
     */
    private boolean isTypeEqual(Type aType, Class aClass) {
        if (aType instanceof Class) {
            return ((Class<?>) aType).getName().equals(aClass.getName());
        }
        return aType.equals(aClass);
    }

    /**
     * Inspects the type arguments of a parameterized supertype of the generator and adds the generator
     * to validGenerators when one of the arguments is a class assignable from the statement's class
     * and the generator reports that it supports the statement.
     * <p>
     * Type variables are replaced by their first bound. Inspection stops at an argument that resolves to
     * SqlStatement itself. Arguments that do not resolve to a plain class are skipped.
     *
     * @param type            the parameterized type to inspect; must be a ParameterizedType
     * @param statement       the statement being matched
     * @param generator       the generator that declares the type
     * @param database        the target database; may be null
     * @param validGenerators the set the generator is added to on a match
     */
    private void checkType(Type type, SqlStatement statement, SqlGenerator generator, Database database, SortedSet<SqlGenerator> validGenerators) {
        for (Type typeArgument : ((ParameterizedType) type).getActualTypeArguments()) {
            Type typeClass = typeArgument;
            if (typeClass instanceof TypeVariable) {
                typeClass = ((TypeVariable<?>) typeClass).getBounds()[0];
            }

            if (isTypeEqual(typeClass, SqlStatement.class)) {
                return;
            }

            if (!(typeClass instanceof Class)) {
                // parameterized/wildcard/array type arguments cannot be matched against a statement class
                continue;
            }

            if (((Class<?>) typeClass).isAssignableFrom(statement.getClass())) {
                //noinspection unchecked
                if (generator.supports(statement, database)) {
                    validGenerators.add(generator);
                }
            }
        }
    }

    /**
     * Builds a generator chain over the generators matching the statement and database.
     *
     * @return the chain, or null when no generator matches
     */
    private SqlGeneratorChain createGeneratorChain(SqlStatement statement, Database database) {
        SortedSet<SqlGenerator> sqlGenerators = getGenerators(statement, database);
        if (sqlGenerators == null || sqlGenerators.isEmpty()) {
            return null;
        }
        //noinspection unchecked
        return new SqlGeneratorChain(sqlGenerators);
    }

    /**
     * Name used for a database in cache keys and error messages: its short name, or "NULL" when it is null.
     */
    private static String describeDatabase(Database database) {
        return database == null ? "NULL" : database.getShortName();
    }

    /**
     * Generates the SQL for every statement produced by the change.
     *
     * @return the generated SQL; an empty array when the change produces no statements (null)
     */
    public Sql[] generateSql(Change change, Database database) {
        SqlStatement[] sqlStatements = change.generateStatements(database);
        if (sqlStatements == null) {
            return new Sql[0];
        } else {
            return generateSql(sqlStatements, database);
        }
    }

    /**
     * Generates the SQL for each statement in order and concatenates the results.
     *
     * @param statements the statements to generate SQL for; null is treated as no statements
     * @return the generated SQL, never null
     */
    public Sql[] generateSql(SqlStatement[] statements, Database database) {
        if (statements == null) {
            return new Sql[0];
        }

        List<Sql> returnList = new ArrayList<Sql>();
        // NOTE(review): this delegates to the current singleton rather than to this instance;
        // the two can differ if reset() was called after this reference was obtained - confirm intent.
        SqlGeneratorFactory factory = SqlGeneratorFactory.getInstance();
        for (SqlStatement statement : statements) {
            Sql[] sqlArray = factory.generateSql(statement, database);
            if (sqlArray != null && sqlArray.length > 0) {
                returnList.addAll(Arrays.asList(sqlArray));
            }
        }

        return returnList.toArray(new Sql[returnList.size()]);
    }

    /**
     * Generates the SQL for a single statement using the matching generator chain.
     *
     * @throws IllegalStateException if no registered generator matches the statement and database
     */
    public Sql[] generateSql(SqlStatement statement, Database database) {
        SqlGeneratorChain generatorChain = createGeneratorChain(statement, database);
        if (generatorChain == null) {
            // database may legitimately be null here, so do not dereference it unguarded
            throw new IllegalStateException("Cannot find generators for database " + (database == null ? null : database.getClass()) + ", statement: " + statement);
        }
        return generatorChain.generateSql(statement, database);
    }

    /**
     * Return true if the SqlStatement class queries the database in any way to determine Statements to execute.
     * If the statement queries the database, it cannot be used in updateSql type operations
     */
    public boolean generateStatementsVolatile(SqlStatement statement, Database database) {
        for (SqlGenerator generator : getGenerators(statement, database)) {
            if (generator.generateStatementsIsVolatile(database)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Return true if any generator matching the statement reports that its rollback statements are volatile
     * for the given database.
     */
    public boolean generateRollbackStatementsVolatile(SqlStatement statement, Database database) {
        for (SqlGenerator generator : getGenerators(statement, database)) {
            if (generator.generateRollbackStatementsIsVolatile(database)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Return true if at least one registered generator matches the statement and database.
     */
    public boolean supports(SqlStatement statement, Database database) {
        return !getGenerators(statement, database).isEmpty();
    }

    /**
     * Validates the statement through the matching generator chain.
     *
     * @throws UnexpectedLiquibaseException if no registered generator matches the statement and database
     */
    public ValidationErrors validate(SqlStatement statement, Database database) {
        //noinspection unchecked
        SqlGeneratorChain generatorChain = createGeneratorChain(statement, database);
        if (generatorChain == null) {
            throw new UnexpectedLiquibaseException("Unable to create generator chain for " + statement.getClass().getName() + " on " + describeDatabase(database));
        }
        return generatorChain.validate(statement, database);
    }

    /**
     * Collects warnings for the statement through the matching generator chain.
     *
     * @throws UnexpectedLiquibaseException if no registered generator matches the statement and database
     */
    public Warnings warn(SqlStatement statement, Database database) {
        //noinspection unchecked
        SqlGeneratorChain generatorChain = createGeneratorChain(statement, database);
        if (generatorChain == null) {
            throw new UnexpectedLiquibaseException("Unable to create generator chain for " + statement.getClass().getName() + " on " + describeDatabase(database));
        }
        return generatorChain.warn(statement, database);
    }

    /**
     * Generates the SQL for the statement and returns the union of the database objects
     * each generated Sql reports as affected.
     *
     * @return the affected objects; empty when no generator matches or no SQL is generated
     */
    public Set<DatabaseObject> getAffectedDatabaseObjects(SqlStatement statement, Database database) {
        Set<DatabaseObject> affectedObjects = new HashSet<DatabaseObject>();

        SqlGeneratorChain sqlGeneratorChain = createGeneratorChain(statement, database);
        if (sqlGeneratorChain != null) {
            //noinspection unchecked
            Sql[] sqls = sqlGeneratorChain.generateSql(statement, database);
            if (sqls != null) {
                for (Sql sql : sqls) {
                    affectedObjects.addAll(sql.getAffectedDatabaseObjects());
                }
            }
        }

        return affectedObjects;
    }
}