package liquibase.sqlgenerator;
import liquibase.database.Database;
import liquibase.database.structure.DatabaseObject;
import liquibase.exception.ValidationErrors;
import liquibase.exception.Warnings;
import liquibase.servicelocator.ServiceLocator;
import liquibase.sql.Sql;
import liquibase.sqlgenerator.core.AbstractSqlGenerator;
import liquibase.statement.SqlStatement;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.util.*;
/**
* SqlGeneratorFactory is a singleton registry of SqlGenerators.
* Use the register(SqlGenerator) method to add custom SqlGenerators,
* and the getBestGenerator() method to retrieve the SqlGenerator that should be used for a given SqlStatement.
*/
public class SqlGeneratorFactory {
private static SqlGeneratorFactory instance;
private List<SqlGenerator> generators = new ArrayList<SqlGenerator>();
private SqlGeneratorFactory() {
Class[] classes;
try {
classes = ServiceLocator.getInstance().findClasses(SqlGenerator.class);
for (Class clazz : classes) {
register((SqlGenerator) clazz.getConstructor().newInstance());
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* Return singleton SqlGeneratorFactory
*/
public static SqlGeneratorFactory getInstance() {
if (instance == null) {
instance = new SqlGeneratorFactory();
}
return instance;
}
public static void reset() {
instance = new SqlGeneratorFactory();
}
public void register(SqlGenerator generator) {
generators.add(generator);
}
public void unregister(SqlGenerator generator) {
generators.remove(generator);
}
public void unregister(Class generatorClass) {
SqlGenerator toRemove = null;
for (SqlGenerator existingGenerator : generators) {
if (existingGenerator.getClass().equals(generatorClass)) {
toRemove = existingGenerator;
}
}
unregister(toRemove);
}
protected Collection<SqlGenerator> getGenerators() {
return generators;
}
protected SortedSet<SqlGenerator> getGenerators(SqlStatement statement, Database database) {
SortedSet<SqlGenerator> validGenerators = new TreeSet<SqlGenerator>(new SqlGeneratorComparator());
for (SqlGenerator generator : getGenerators()) {
Class clazz = generator.getClass();
Type classType = null;
while (clazz != null) {
if (classType instanceof ParameterizedType) {
checkType(classType, statement, generator, database, validGenerators);
}
for (Type type : clazz.getGenericInterfaces()) {
if (type instanceof ParameterizedType) {
checkType(type, statement, generator, database, validGenerators);
} else if (isTypeEqual( type, SqlGenerator.class)) {
//noinspection unchecked
if (generator.supports(statement, database)) {
validGenerators.add(generator);
}
}
}
classType = clazz.getGenericSuperclass();
clazz = clazz.getSuperclass();
}
}
return validGenerators;
}
private boolean isTypeEqual(Type aType, Class aClass) {
if (aType instanceof Class) {
return ((Class) aType).getName().equals(aClass.getName());
}
return aType.equals(aClass);
}
private void checkType(Type type, SqlStatement statement, SqlGenerator generator, Database database, SortedSet<SqlGenerator> validGenerators) {
for (Type typeClass : ((ParameterizedType) type).getActualTypeArguments()) {
if (typeClass instanceof TypeVariable) {
typeClass = ((TypeVariable) typeClass).getBounds()[0];
}
if (isTypeEqual( typeClass, SqlStatement.class)) {
return;
}
if (((Class) typeClass).isAssignableFrom(statement.getClass())) {
if (generator.supports(statement, database)) {
validGenerators.add(generator);
}
}
}
}
private SqlGeneratorChain createGeneratorChain(SqlStatement statement, Database database) {
SortedSet<SqlGenerator> sqlGenerators = getGenerators(statement, database);
if (sqlGenerators == null || sqlGenerators.size() == 0) {
return null;
}
//noinspection unchecked
return new SqlGeneratorChain(sqlGenerators);
}
public Sql[] generateSql(SqlStatement statement, Database database) {
SqlGeneratorChain generatorChain = createGeneratorChain(statement, database);
if (generatorChain == null) {
throw new IllegalStateException("Cannot find generators for database " + database.getClass() + ", statement: " + statement);
}
return generatorChain.generateSql(statement, database);
}
public boolean requiresCurrentDatabaseMetadata(SqlStatement statement, Database database) {
for (SqlGenerator generator : getGenerators(statement, database)) {
if (generator.requiresUpdatedDatabaseMetadata(database)) {
return true;
}
}
return false;
}
public boolean supports(SqlStatement statement, Database database) {
return getGenerators(statement, database).size() > 0;
}
public ValidationErrors validate(SqlStatement statement, Database database) {
//noinspection unchecked
return createGeneratorChain(statement, database).validate(statement, database);
}
public Warnings warn(SqlStatement statement, Database database) {
//noinspection unchecked
return createGeneratorChain(statement, database).warn(statement, database);
}
public Set<DatabaseObject> getAffectedDatabaseObjects(SqlStatement statement, Database database) {
Set<DatabaseObject> affectedObjects = new HashSet<DatabaseObject>();
SqlGeneratorChain sqlGeneratorChain = createGeneratorChain(statement, database);
if (sqlGeneratorChain != null) {
//noinspection unchecked
Sql[] sqls = sqlGeneratorChain.generateSql(statement, database);
if (sqls != null) {
for (Sql sql : sqls) {
affectedObjects.addAll(sql.getAffectedDatabaseObjects());
}
}
}
return affectedObjects;
}
}