package liquibase.sqlgenerator.core; import org.junit.Test; import static org.junit.Assert.assertEquals; import liquibase.database.core.MSSQLDatabase; import liquibase.statement.core.InsertOrUpdateStatement; import liquibase.sql.Sql; import static junit.framework.Assert.assertTrue; import junit.framework.Assert; import java.lang.reflect.Method; import java.lang.reflect.InvocationTargetException; public class InsertOrUpdateGeneratorMSSQLTest { @Test public void getRecordCheck(){ InsertOrUpdateGeneratorMSSQL generator = new InsertOrUpdateGeneratorMSSQL(); MSSQLDatabase database = new MSSQLDatabase(); InsertOrUpdateStatement statement = new InsertOrUpdateStatement("myschema","mytable","pk_col1"); statement.addColumnValue("pk_col1","value1"); statement.addColumnValue("col2","value2"); String where = "1 = 1"; Class c = InsertOrUpdateGenerator.class.getClass(); String recordCheck = (String)invokePrivateMethod(generator,"getRecordCheck", new Object[] {statement,database,where}); Integer lineNumber = 0; String[] lines = recordCheck.split("\n"); assertEquals("DECLARE @reccount integer", lines[lineNumber]); lineNumber++; assertEquals("SELECT @reccount = count(*) FROM [myschema].[mytable] WHERE " + where, lines[lineNumber]); lineNumber++; assertEquals("IF @reccount = 0", lines[lineNumber]); } @Test public void getInsert(){ InsertOrUpdateGeneratorMSSQL generator = new InsertOrUpdateGeneratorMSSQL(); MSSQLDatabase database = new MSSQLDatabase(); InsertOrUpdateStatement statement = new InsertOrUpdateStatement("myschema","mytable","pk_col1"); statement.addColumnValue("pk_col1","value1"); statement.addColumnValue("col2","value2"); String where = "1 = 1"; Class c = InsertOrUpdateGenerator.class.getClass(); //InsertOrUpdateStatement insertOrUpdateStatement, Database database, SqlGeneratorChain sqlGeneratorChain String insertStatement = (String)invokePrivateMethod(generator,"getInsertStatement", new Object[] {statement,database,null}); Integer lineNumber = 0; String[] lines = insertStatement.split("\n"); assertEquals("BEGIN", lines[lineNumber]); lineNumber++; assertTrue(lines[lineNumber].startsWith("INSERT")); lineNumber++; assertEquals("END", lines[lineNumber]); } @Test public void getElse(){ InsertOrUpdateGeneratorMSSQL generator = new InsertOrUpdateGeneratorMSSQL(); MSSQLDatabase database = new MSSQLDatabase(); InsertOrUpdateStatement statement = new InsertOrUpdateStatement("myschema","mytable","pk_col1"); statement.addColumnValue("pk_col1","value1"); statement.addColumnValue("col2","value2"); String where = "1 = 1"; Class c = InsertOrUpdateGenerator.class.getClass(); //InsertOrUpdateStatement insertOrUpdateStatement, Database database, SqlGeneratorChain sqlGeneratorChain String insertStatement = (String)invokePrivateMethod(generator,"getElse", new Object[] {database}); Integer lineNumber = 0; String[] lines = insertStatement.split("\n"); assertEquals("ELSE", lines[lineNumber]); } @Test public void getUpdate(){ InsertOrUpdateGeneratorMSSQL generator = new InsertOrUpdateGeneratorMSSQL(); MSSQLDatabase database = new MSSQLDatabase(); InsertOrUpdateStatement statement = new InsertOrUpdateStatement("myschema","mytable","pk_col1"); statement.addColumnValue("col2","value2"); String where = "1 = 1"; Class c = InsertOrUpdateGenerator.class.getClass(); //InsertOrUpdateStatement insertOrUpdateStatement, Database database, String whereClause, SqlGeneratorChain sqlGeneratorChain String insertStatement = (String)invokePrivateMethod(generator,"getUpdateStatement", new Object[] {statement,database,where,null}); Integer lineNumber = 0; String[] lines = insertStatement.split("\n"); assertEquals("BEGIN", lines[lineNumber]); lineNumber++; assertTrue(lines[lineNumber].startsWith("UPDATE")); lineNumber++; assertEquals("END", lines[lineNumber]); } public static Object invokePrivateMethod (Object o, String methodName, Object[] params) { // Check we have valid arguments... Assert.assertNotNull(o); Assert.assertNotNull(methodName); // Assert.assertNotNull(params); // Go and find the private method... final Method methods[] = o.getClass().getDeclaredMethods(); for (int i = 0; i < methods.length; ++i) { if (methodName.equals(methods[i].getName())) { try { methods[i].setAccessible(true); return methods[i].invoke(o, params); } catch (IllegalAccessException ex) { Assert.fail ("IllegalAccessException accessing " + methodName); } catch (InvocationTargetException ite) { Assert.fail ("InvocationTargetException accessing " + methodName); } } } Assert.fail ("Method '" + methodName +"' not found"); return null; } }