/*******************************************************************************
* Copyright 2016 Observational Health Data Sciences and Informatics
*
* This file is part of WhiteRabbit
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package org.ohdsi.rabbitInAHat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

import org.ohdsi.rabbitInAHat.dataModel.Database;
import org.ohdsi.rabbitInAHat.dataModel.ETL;
import org.ohdsi.rabbitInAHat.dataModel.Field;
import org.ohdsi.rabbitInAHat.dataModel.Table;
import org.ohdsi.utilities.StringUtilities;
import org.ohdsi.utilities.files.WriteTextFile;
/**
 * Generates an R script that serves as a unit-test framework for an ETL.
 *
 * <p>
 * The generated script defines, per non-stem table, R functions that build SQL statements:
 * <ul>
 * <li>{@code add_<table>} appends an INSERT statement for a source table to {@code insertSql};</li>
 * <li>{@code set_defaults_<table>} / {@code get_defaults_<table>} manage default field values of source tables;</li>
 * <li>{@code expect_<table>}, {@code expect_no_<table>} and {@code expect_count_<table>} append a test statement
 * for a target table to {@code testSql};</li>
 * <li>{@code lookup_<table>} builds a sub-query on a target table that can be passed to the expect functions.</li>
 * </ul>
 *
 * <p>
 * The class keeps static state ({@link #keywordSet}) and is therefore not designed for concurrent use.
 */
public class ETLTestFrameWorkGenerator {

    /**
     * Reserved words: a table or field name equal to one of these (case-insensitively) is wrapped in square brackets
     * when written into generated SQL.
     */
    public static String[] keywords = new String[] { "ADD", "ALL", "ALTER", "AND", "ANY", "AS", "ASC", "AUTHORIZATION", "BACKUP", "BEGIN", "BETWEEN",
            "BREAK", "BROWSE", "BULK", "BY", "CASCADE", "CASE", "CHECK", "CHECKPOINT", "CLOSE", "CLUSTERED", "COALESCE", "COLLATE", "COLUMN", "COMMIT",
            "COMPUTE", "CONSTRAINT", "CONTAINS", "CONTAINSTABLE", "CONTINUE", "CONVERT", "CREATE", "CROSS", "CURRENT", "CURRENT_DATE", "CURRENT_TIME",
            "CURRENT_TIMESTAMP", "CURRENT_USER", "CURSOR", "DATABASE", "DBCC", "DEALLOCATE", "DECLARE", "DEFAULT", "DELETE", "DENY", "DESC", "DISK",
            "DISTINCT", "DISTRIBUTED", "DOUBLE", "DROP", "DUMP", "ELSE", "END", "ERRLVL", "ESCAPE", "EXCEPT", "EXEC", "EXECUTE", "EXISTS", "EXIT", "EXTERNAL",
            "FETCH", "FILE", "FILLFACTOR", "FOR", "FOREIGN", "FREETEXT", "FREETEXTTABLE", "FROM", "FULL", "FUNCTION", "GOTO", "GRANT", "GROUP", "HAVING",
            "HOLDLOCK", "IDENTITY", "IDENTITY_INSERT", "IDENTITYCOL", "IF", "IN", "INDEX", "INNER", "INSERT", "INTERSECT", "INTO", "IS", "JOIN", "KEY", "KILL",
            "LEFT", "LIKE", "LINENO", "LOAD", "MERGE", "NATIONAL", "NOCHECK", "NONCLUSTERED", "NOT", "NULL", "NULLIF", "OF", "OFF", "OFFSETS", "ON", "OPEN",
            "OPENDATASOURCE", "OPENQUERY", "OPENROWSET", "OPENXML", "OPTION", "OR", "ORDER", "OUTER", "OVER", "PERCENT", "PIVOT", "PLAN", "PRECISION",
            "PRIMARY", "PRINT", "PROC", "PROCEDURE", "PUBLIC", "RAISERROR", "READ", "READTEXT", "RECONFIGURE", "REFERENCES", "REPLICATION", "RESTORE",
            "RESTRICT", "RETURN", "REVERT", "REVOKE", "RIGHT", "ROLLBACK", "ROWCOUNT", "ROWGUIDCOL", "RULE", "SAVE", "SCHEMA", "SECURITYAUDIT", "SELECT",
            "SEMANTICKEYPHRASETABLE", "SEMANTICSIMILARITYDETAILSTABLE", "SEMANTICSIMILARITYTABLE", "SESSION_USER", "SET", "SETUSER", "SHUTDOWN", "SOME",
            "STATISTICS", "SYSTEM_USER", "TABLE", "TABLESAMPLE", "TEXTSIZE", "THEN", "TO", "TOP", "TRAN", "TRANSACTION", "TRIGGER", "TRUNCATE", "TRY_CONVERT",
            "TSEQUAL", "UNION", "UNIQUE", "UNPIVOT", "UPDATE", "UPDATETEXT", "USE", "USER", "VALUES", "VARYING", "VIEW", "WAITFOR", "WHEN", "WHERE", "WHILE",
            "WITH", "WITHIN GROUP", "WRITETEXT" };

    /**
     * Lookup set built from {@link #keywords}. Initialised eagerly so that name conversion never sees {@code null},
     * and rebuilt at the start of every {@link #generate(ETL, String)} call so edits to the public array are honoured.
     */
    private static Set<String> keywordSet = buildKeywordSet();

    /** Selects the {@code expect_<table>} flavour: fail when no matching row exists. */
    private static final int DEFAULT = 0;
    /** Selects the {@code expect_no_<table>} flavour: fail when a matching row exists. */
    private static final int NEGATE = 1;
    /** Selects the {@code expect_count_<table>} flavour: fail when the number of matching rows differs from rowCount. */
    private static final int COUNT = 2;

    /**
     * Generates the R test framework for the given ETL and writes it to a file, one script line per file line.
     *
     * @param etl      the ETL whose source and target databases drive the generated functions
     * @param filename path of the R file to write
     */
    public static void generate(ETL etl, String filename) {
        keywordSet = buildKeywordSet();
        List<String> script = generateRScript(etl);
        WriteTextFile out = new WriteTextFile(filename);
        try {
            for (String line : script) {
                out.writeln(line);
            }
        } finally {
            // Release the file handle even when writing fails half-way.
            out.close();
        }
    }

    /** Builds a fresh lookup set from the current contents of {@link #keywords}. */
    private static Set<String> buildKeywordSet() {
        Set<String> result = new HashSet<String>();
        for (String keyword : keywords) {
            result.add(keyword);
        }
        return result;
    }

    /**
     * Assembles the complete R script. Source-database functions (init, defaults, add) come first, followed by the
     * target-database functions (the three expect flavours and lookup).
     */
    private static List<String> generateRScript(ETL etl) {
        List<String> r = new ArrayList<String>();
        Database source = etl.getSourceDatabase();
        Database target = etl.getTargetDatabase();
        createInitFunction(r, source);
        createDeclareTestFunction(r);
        createSetDefaultFunctions(r, source);
        createGetDefaultFunctions(r, source);
        createAddFunctions(r, source);
        createExpectFunctions(r, DEFAULT, target);
        createExpectFunctions(r, NEGATE, target);
        createExpectFunctions(r, COUNT, target);
        createLookupFunctions(r, target);
        return r;
    }

    /**
     * Emits the R function {@code declareTest(id, description)}, which stores the current test id and description in
     * the global environment and adds a SQL comment line to both the insert and the test script.
     */
    private static void createDeclareTestFunction(List<String> r) {
        r.add("declareTest <- function(id, description) {");
        r.add(" assign(\"testId\", id, envir = globalenv()) ");
        r.add(" assign(\"testDescription\", description, envir = globalenv()) ");
        r.add(" sql <- c(\"\", paste0(\"-- \", id, \": \", description))");
        r.add(" assign(\"insertSql\", c(get(\"insertSql\", envir = globalenv()), sql), envir = globalenv())");
        r.add(" assign(\"testSql\", c(get(\"testSql\", envir = globalenv()), sql), envir = globalenv())");
        r.add("}");
        r.add("");
    }

    /**
     * Emits one expect function per non-stem table. The generated function inserts a PASS/FAIL row into
     * {@code test_results} based on how many rows of the table match the supplied field values.
     *
     * @param r        script lines collected so far; new lines are appended
     * @param type     one of {@link #DEFAULT}, {@link #NEGATE} or {@link #COUNT}; any other value is treated as COUNT
     * @param database database whose tables get an expect function
     */
    private static void createExpectFunctions(List<String> r, int type, Database database) {
        for (Table table : database.getTables()) {
            if (table.isStem()) {
                continue;
            }
            String rTableName = convertToRName(table.getName());
            String sqlTableName = convertToSqlName(table.getName());

            String signatureStart;
            String verdictLine;
            if (type == DEFAULT) {
                signatureStart = "expect_" + rTableName + " <- function(";
                verdictLine = " statement <- paste0(statement, \") = 0 THEN 'FAIL' ELSE 'PASS' END AS status;\")";
            } else if (type == NEGATE) {
                signatureStart = "expect_no_" + rTableName + " <- function(";
                verdictLine = " statement <- paste0(statement, \") != 0 THEN 'FAIL' ELSE 'PASS' END AS status;\")";
            } else {
                signatureStart = "expect_count_" + rTableName + " <- function(rowCount, ";
                verdictLine = " statement <- paste0(statement, \") != \",rowCount ,\" THEN 'FAIL' ELSE 'PASS' END AS status;\")";
            }

            r.add(signatureStart + StringUtilities.join(collectRFieldNames(table), ", ") + ") {");
            r.add(" statement <- paste0(\"INSERT INTO test_results SELECT "
                    + "\", get(\"testId\", envir = globalenv()), \" AS id, "
                    + "'\", get(\"testDescription\", envir = globalenv()), \"' AS description, "
                    + "'Expect " + table.getName() + "' AS test, "
                    + "CASE WHEN(SELECT COUNT(*) FROM " + sqlTableName + " WHERE\")");
            r.add(" first <- TRUE");
            r.addAll(createWhereClauseLines(table));
            r.add(verdictLine);
            r.add(" assign(\"testSql\", c(get(\"testSql\", envir = globalenv()), statement), envir = globalenv())");
            r.add(" invisible(statement)");
            r.add("}");
            r.add("");
        }
    }

    /**
     * Emits one {@code lookup_<table>(fetchField, ...)} function per non-stem table. The generated function returns a
     * SELECT statement tagged with R class {@code subQuery}, which the expect and lookup functions embed as
     * {@code field = (subquery)}.
     */
    private static void createLookupFunctions(List<String> r, Database database) {
        for (Table table : database.getTables()) {
            if (table.isStem()) {
                continue;
            }
            String rTableName = convertToRName(table.getName());
            String sqlTableName = convertToSqlName(table.getName());

            r.add("lookup_" + rTableName + " <- function(fetchField, " + StringUtilities.join(collectRFieldNames(table), ", ") + ") {");
            r.add(" statement <- paste0(\"SELECT \", fetchField , \" FROM " + sqlTableName + " WHERE\")");
            r.add(" first <- TRUE");
            r.addAll(createWhereClauseLines(table));
            r.add(" class(statement) <- \"subQuery\"");
            r.add(" return(statement)");
            r.add("}");
            r.add("");
        }
    }

    /** Returns the R-safe names of all fields of the table, in field order; used as R function arguments. */
    private static List<String> collectRFieldNames(Table table) {
        List<String> names = new ArrayList<String>();
        for (Field field : table.getFields()) {
            names.add(convertToRName(field.getName()));
        }
        return names;
    }

    /**
     * Returns the R lines that extend {@code statement} with one WHERE condition per supplied argument. Conditions
     * are joined with AND; a NULL argument becomes {@code IS NULL}, a {@code subQuery} argument becomes
     * {@code = (subquery)}, and any other value is compared as a single-quoted literal. The generated code relies on
     * an R variable {@code first} having been set to TRUE beforehand.
     */
    private static List<String> createWhereClauseLines(Table table) {
        List<String> lines = new ArrayList<String>();
        for (Field field : table.getFields()) {
            String rFieldName = convertToRName(field.getName());
            String sqlFieldName = convertToSqlName(field.getName());
            lines.add(" if (!missing(" + rFieldName + ")) {");
            lines.add(" if (first) {");
            lines.add(" first <- FALSE");
            lines.add(" } else {");
            lines.add(" statement <- paste0(statement, \" AND\")");
            lines.add(" }");
            lines.add(" if (is.null(" + rFieldName + ")) {");
            lines.add(" statement <- paste0(statement, \" " + sqlFieldName + " IS NULL\")");
            lines.add(" } else if (is(" + rFieldName + ", \"subQuery\")){");
            lines.add(" statement <- paste0(statement, \" " + sqlFieldName + " = (\", as.character(" + rFieldName + "), \")\")");
            lines.add(" } else {");
            lines.add(" statement <- paste0(statement, \" " + sqlFieldName + " = '\", " + rFieldName + ",\"'\")");
            lines.add(" }");
            lines.add(" }");
            lines.add("");
        }
        return lines;
    }

    /**
     * Converts a table or field name to the form used in generated SQL: names containing a space or a dot, or equal
     * to a reserved word, are wrapped in square brackets.
     */
    private static String convertToSqlName(String name) {
        // Locale.ROOT keeps the keyword match independent of the JVM's default locale (e.g. Turkish dotted I).
        boolean needsQuoting = name.contains(" ") || name.contains(".") || keywordSet.contains(name.toUpperCase(Locale.ROOT));
        return needsQuoting ? "[" + name + "]" : name;
    }

    /**
     * Emits the R function {@code initFramework()} followed by a call to it. The function resets {@code insertSql}
     * (starting with a TRUNCATE per table), resets {@code testSql} (re-creating the {@code test_results} table),
     * and registers per-table default field values in the {@code defaultValues} environment.
     *
     * <p>
     * Stem tables are skipped for both the TRUNCATE statements and the defaults, matching the other generators in
     * this class, none of which emit functions for stem tables.
     */
    private static void createInitFunction(List<String> r, Database database) {
        r.add("initFramework <- function() {");
        r.add(" insertSql <- c()");
        for (Table table : database.getTables()) {
            if (!table.isStem()) {
                String sqlTableName = convertToSqlName(table.getName());
                r.add(" insertSql <- c(insertSql, \"TRUNCATE TABLE " + sqlTableName + ";\")");
            }
        }
        r.add(" assign(\"insertSql\", insertSql, envir = globalenv())");
        r.add(" testSql <- c()");
        r.add(" testSql <- c(testSql, \"IF OBJECT_ID('test_results', 'U') IS NOT NULL\")");
        r.add(" testSql <- c(testSql, \" DROP TABLE test_results;\")");
        r.add(" testSql <- c(testSql, \"\")");
        r.add(" testSql <- c(testSql, \"CREATE TABLE test_results (id INT, description VARCHAR(512), test VARCHAR(256), status VARCHAR(5));\")");
        r.add(" testSql <- c(testSql, \"\")");
        r.add(" assign(\"testSql\", testSql, envir = globalenv())");
        r.add(" assign(\"testId\", 1, envir = globalenv())");
        r.add(" assign(\"testDescription\", \"\", envir = globalenv())");
        r.add("");
        r.add(" defaultValues <- new.env(parent = globalenv())");
        r.add(" assign(\"defaultValues\", defaultValues, envir = globalenv())");
        for (Table table : database.getTables()) {
            if (table.isStem()) {
                continue;
            }
            String rTableName = convertToRName(table.getName());
            r.add("");
            r.add(" defaults <- list()");
            for (Field field : table.getFields()) {
                String defaultValue = firstValueOrEmpty(field);
                if (!defaultValue.isEmpty()) {
                    r.add(" defaults$" + convertToRName(field.getName()) + " <- \"" + escapeRString(defaultValue) + "\"");
                }
            }
            r.add(" assign(\"" + rTableName + "\", defaults, envir = defaultValues)");
        }
        r.add("}");
        r.add("");
        r.add("initFramework()");
        r.add("");
    }

    /**
     * Returns the value in the first row, first column of the field's value counts, or the empty string when the
     * field has no value counts or that value is absent.
     */
    private static String firstValueOrEmpty(Field field) {
        String[][] valueCounts = field.getValueCounts();
        if (valueCounts == null || valueCounts.length == 0 || valueCounts[0] == null || valueCounts[0].length == 0) {
            return "";
        }
        String value = valueCounts[0][0];
        return value == null ? "" : value;
    }

    /**
     * Escapes a value for use inside a double-quoted R string literal. Backslashes must be handled first so the
     * backslashes introduced for the quotes are not escaped again.
     */
    private static String escapeRString(String value) {
        return value.replace("\\", "\\\\").replace("\"", "\\\"");
    }

    /**
     * Emits one {@code add_<table>(...)} function per non-stem table. The generated function takes one argument per
     * field, falls back to the table's registered default for missing arguments, leaves out fields whose value is
     * NULL, and appends the resulting INSERT statement to {@code insertSql}.
     */
    private static void createAddFunctions(List<String> r, Database database) {
        for (Table table : database.getTables()) {
            if (table.isStem()) {
                continue;
            }
            String rTableName = convertToRName(table.getName());
            String sqlTableName = convertToSqlName(table.getName());

            r.add("add_" + rTableName + " <- function(" + StringUtilities.join(collectRFieldNames(table), ", ") + ") {");
            r.add(" defaults <- get(\"" + rTableName + "\", envir = defaultValues)");
            r.add(" insertFields <- c()");
            r.add(" insertValues <- c()");
            for (Field field : table.getFields()) {
                String rFieldName = convertToRName(field.getName());
                String sqlFieldName = convertToSqlName(field.getName());
                r.add(" if (missing(" + rFieldName + ")) {");
                r.add(" " + rFieldName + " <- defaults$" + rFieldName);
                r.add(" }");
                r.add(" if (!is.null(" + rFieldName + ")) {");
                r.add(" insertFields <- c(insertFields, \"" + sqlFieldName + "\")");
                r.add(" insertValues <- c(insertValues, " + rFieldName + ")");
                r.add(" }");
                r.add("");
            }
            r.add(" statement <- paste0(\"INSERT INTO " + sqlTableName + " (\", "
                    + "paste(insertFields, collapse = \", \"), "
                    + "\") VALUES ('\", "
                    + "paste(insertValues, collapse = \"', '\"), "
                    + "\"');\")");
            r.add(" assign(\"insertSql\", c(get(\"insertSql\", envir = globalenv()), statement), envir = globalenv())");
            r.add(" invisible(statement)");
            r.add("}");
            r.add("");
        }
    }

    /**
     * Emits one {@code set_defaults_<table>(...)} function per non-stem table. The generated function overwrites the
     * registered default of every field for which an argument is supplied.
     */
    private static void createSetDefaultFunctions(List<String> r, Database database) {
        for (Table table : database.getTables()) {
            if (table.isStem()) {
                continue;
            }
            String rTableName = convertToRName(table.getName());
            List<String> rFieldNames = collectRFieldNames(table);

            r.add("set_defaults_" + rTableName + " <- function(" + StringUtilities.join(rFieldNames, ", ") + ") {");
            r.add(" defaults <- get(\"" + rTableName + "\", envir = defaultValues)");
            for (String rFieldName : rFieldNames) {
                r.add(" if (!missing(" + rFieldName + ")) {");
                r.add(" defaults$" + rFieldName + " <- " + rFieldName);
                r.add(" }");
            }
            r.add(" assign(\"" + rTableName + "\", defaults, envir = defaultValues)");
            r.add(" invisible(defaults)");
            r.add("}");
            r.add("");
        }
    }

    /**
     * Emits one {@code get_defaults_<table>()} function per non-stem table. Stem tables are skipped because
     * {@code initFramework} registers no defaults for them, so a getter for such a table would have nothing to read.
     */
    private static void createGetDefaultFunctions(List<String> r, Database database) {
        for (Table table : database.getTables()) {
            if (table.isStem()) {
                continue;
            }
            String rTableName = convertToRName(table.getName());
            r.add("get_defaults_" + rTableName + " <- function() {");
            r.add(" defaults <- get(\"" + rTableName + "\", envir = defaultValues)");
            r.add(" return(defaults)");
            r.add("}");
            r.add("");
        }
    }

    /** Converts a table or field name to the form used in generated R identifiers: spaces and hyphens become underscores. */
    private static String convertToRName(String name) {
        return name.replace(' ', '_').replace('-', '_');
    }
}