/*******************************************************************************
* Copyright 2016 Observational Health Data Sciences and Informatics
*
* This file is part of WhiteRabbit
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package org.ohdsi.rabbitInAHat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.ohdsi.rabbitInAHat.dataModel.Database;
import org.ohdsi.rabbitInAHat.dataModel.ETL;
import org.ohdsi.rabbitInAHat.dataModel.Field;
import org.ohdsi.rabbitInAHat.dataModel.Table;
import org.ohdsi.utilities.StringUtilities;
import org.ohdsi.utilities.files.WriteTextFile;
public class ETLPackageTestFrameWorkGenerator {
public static String[] keywords = new String[] { "ADD", "ALL", "ALTER", "AND", "ANY", "AS", "ASC", "AUTHORIZATION", "BACKUP", "BEGIN", "BETWEEN",
"BREAK", "BROWSE", "BULK", "BY", "CASCADE", "CASE", "CHECK", "CHECKPOINT", "CLOSE", "CLUSTERED", "COALESCE", "COLLATE", "COLUMN", "COMMIT",
"COMPUTE", "CONSTRAINT", "CONTAINS", "CONTAINSTABLE", "CONTINUE", "CONVERT", "CREATE", "CROSS", "CURRENT", "CURRENT_DATE", "CURRENT_TIME",
"CURRENT_TIMESTAMP", "CURRENT_USER", "CURSOR", "DATABASE", "DBCC", "DEALLOCATE", "DECLARE", "DEFAULT", "DELETE", "DENY", "DESC", "DISK",
"DISTINCT", "DISTRIBUTED", "DOUBLE", "DROP", "DUMP", "ELSE", "END", "ERRLVL", "ESCAPE", "EXCEPT", "EXEC", "EXECUTE", "EXISTS", "EXIT", "EXTERNAL",
"FETCH", "FILE", "FILLFACTOR", "FOR", "FOREIGN", "FREETEXT", "FREETEXTTABLE", "FROM", "FULL", "FUNCTION", "GOTO", "GRANT", "GROUP", "HAVING",
"HOLDLOCK", "IDENTITY", "IDENTITY_INSERT", "IDENTITYCOL", "IF", "IN", "INDEX", "INNER", "INSERT", "INTERSECT", "INTO", "IS", "JOIN", "KEY", "KILL",
"LEFT", "LIKE", "LINENO", "LOAD", "MERGE", "NATIONAL", "NOCHECK", "NONCLUSTERED", "NOT", "NULL", "NULLIF", "OF", "OFF", "OFFSETS", "ON", "OPEN",
"OPENDATASOURCE", "OPENQUERY", "OPENROWSET", "OPENXML", "OPTION", "OR", "ORDER", "OUTER", "OVER", "PERCENT", "PIVOT", "PLAN", "PRECISION",
"PRIMARY", "PRINT", "PROC", "PROCEDURE", "PUBLIC", "RAISERROR", "READ", "READTEXT", "RECONFIGURE", "REFERENCES", "REPLICATION", "RESTORE",
"RESTRICT", "RETURN", "REVERT", "REVOKE", "RIGHT", "ROLLBACK", "ROWCOUNT", "ROWGUIDCOL", "RULE", "SAVE", "SCHEMA", "SECURITYAUDIT", "SELECT",
"SEMANTICKEYPHRASETABLE", "SEMANTICSIMILARITYDETAILSTABLE", "SEMANTICSIMILARITYTABLE", "SESSION_USER", "SET", "SETUSER", "SHUTDOWN", "SOME",
"STATISTICS", "SYSTEM_USER", "TABLE", "TABLESAMPLE", "TEXTSIZE", "THEN", "TO", "TOP", "TRAN", "TRANSACTION", "TRIGGER", "TRUNCATE", "TRY_CONVERT",
"TSEQUAL", "UNION", "UNIQUE", "UNPIVOT", "UPDATE", "UPDATETEXT", "USE", "USER", "VALUES", "VARYING", "VIEW", "WAITFOR", "WHEN", "WHERE", "WHILE",
"WITH", "WITHIN GROUP", "WRITETEXT" };
private static Set<String> keywordSet;
private static int DEFAULT = 0;
private static int NEGATE = 1;
private static int COUNT = 2;
public static void generate(ETL etl, String filename) {
keywordSet = new HashSet<String>();
for (String keyword : keywords)
keywordSet.add(keyword);
List<String> r = generateRScript(etl);
WriteTextFile out = new WriteTextFile(filename);
for (String line : r)
out.writeln(line);
out.close();
}
private static List<String> generateRScript(ETL etl) {
List<String> r = new ArrayList<String>();
createInitFunction(r, etl.getSourceDatabase());
createDeclareTestFunction(r);
createSetDefaultFunctions(r, etl.getSourceDatabase());
createGetDefaultFunctions(r, etl.getSourceDatabase());
createAddFunctions(r, etl.getSourceDatabase());
createExpectFunctions(r, DEFAULT, etl.getTargetDatabase());
createExpectFunctions(r, NEGATE, etl.getTargetDatabase());
createExpectFunctions(r, COUNT, etl.getTargetDatabase());
createLookupFunctions(r, etl.getTargetDatabase());
return r;
}
private static void createDeclareTestFunction(List<String> r) {
r.add("declareTestGroup <- function(groupName) {");
r.add(" frameworkContext$groupIndex <- frameworkContext$groupIndex + 1 ;");
r.add(" frameworkContext$currentGroup <- {}");
r.add("");
r.add(" frameworkContext$currentGroup$groupName <- groupName;");
r.add(" frameworkContext$currentGroup$groupItemIndex <- -1;");
r.add(" sql <- c(\"\",paste0(\"-- \", frameworkContext$groupIndex, \". \", groupName));");
r.add(" frameworkContext$insertSql = c(frameworkContext$insertSql, sql);");
r.add(" frameworkContext$testSql = c(frameworkContext$testSql, sql);");
r.add("}");
r.add("");
r.add("declareTest <- function(description, source_pid = NULL, cdm_pid = NULL) {");
r.add(" frameworkContext$testId = frameworkContext$testId + 1;");
r.add(" frameworkContext$testDescription = description;");
r.add(" frameworkContext$patient$source_pid = source_pid;");
r.add(" frameworkContext$patient$cdm_pid = cdm_pid;");
r.add(" if (is.null(frameworkContext$currentGroup)) { ");
r.add(" sql <- c(paste0(\"-- Test \", frameworkContext$testId, \": \", frameworkContext$testDescription));");
r.add(" } else {");
r.add(" frameworkContext$currentGroup$groupItemIndex = frameworkContext$currentGroup$groupItemIndex + 1;");
r.add(" sql <- c(paste0(\"-- \", frameworkContext$groupIndex, \".\", frameworkContext$currentGroup$groupItemIndex, \" \", frameworkContext$testDescription, \" [Test ID: \", frameworkContext$testId, \"]\"));");
r.add(" }");
r.add(" frameworkContext$insertSql = c(frameworkContext$insertSql, \"--\",sql,\"--\");");
r.add(" frameworkContext$testSql = c(frameworkContext$testSql, \"--\",sql,\"--\");");
r.add("}");
r.add("");
}
private static void createExpectFunctions(List<String> r, int type, Database database) {
for (Table table : database.getTables()) {
StringBuilder line = new StringBuilder();
String rTableName = convertToRName(table.getName());
String sqlTableName = convertToSqlName(table.getName());
List<String> argDefs = new ArrayList<String>();
List<String> testDefs = new ArrayList<String>();
for (Field field : table.getFields()) {
String rFieldName = convertToRName(field.getName());
String sqlFieldName = convertToSqlName(field.getName());
argDefs.add(rFieldName);
testDefs.add(" if (!missing(" + rFieldName + ")) {");
testDefs.add(" if (is.null(" + rFieldName + ")) {");
testDefs.add(" whereClauses <- c(whereClauses, \"" + sqlFieldName + " IS NULL\")");
testDefs.add(" } else if (is(" + rFieldName + ", \"subQuery\")){");
testDefs.add(" whereClauses <- c(whereClauses, paste0(\"" + sqlFieldName + " = (\", as.character(" + rFieldName + "), \")\"))");
testDefs.add(" } else {");
testDefs.add(" whereClauses <- c(whereClauses, paste0(\"" + sqlFieldName + " = '\", " + rFieldName + ",\"'\"))");
testDefs.add(" }");
testDefs.add(" }");
testDefs.add("");
}
if (type == DEFAULT)
line.append("expect_" + rTableName + " <- function(");
else if (type == NEGATE)
line.append("expect_no_" + rTableName + " <- function(");
else
line.append("expect_count_" + rTableName + " <- function(rowCount, ");
line.append(StringUtilities.join(argDefs, ", "));
line.append(") {");
r.add(line.toString());
r.add("");
r.add(" if (is.null(frameworkContext$currentGroup)) {");
r.add(" testName <- frameworkContext$testDescription;");
r.add(" } else {");
r.add(" testName <- paste0(frameworkContext$groupIndex, \".\", frameworkContext$currentGroup$groupItemIndex, \" \", frameworkContext$testDescription);");
r.add(" }");
r.add("");
r.add(" source_pid <- frameworkContext$patient$source_pid;");
r.add(" if (is.null(source_pid)) {");
r.add(" source_pid <- \"NULL\";");
r.add(" } else {");
r.add(" source_pid <- paste0(\"'\", as.character(source_pid), \"'\");");
r.add(" }");
r.add("");
r.add(" cdm_pid <- frameworkContext$patient$cdm_pid;");
r.add(" if (is.null(cdm_pid)) {");
r.add(" cdm_pid <- \"NULL\"");
r.add(" }");
r.add("");
line = new StringBuilder();
line.append(" statement <- paste0(\"INSERT INTO @cdm_schema.test_results (id, description, test, source_pid, cdm_pid, status) SELECT ");
line.append("\", frameworkContext$testId, \" AS id, ");
line.append("'\", testName, \"' AS description, ");
line.append("'Expect " + table.getName() + "' AS test, ");
line.append("\", source_pid, \" as source_pid, ");
line.append("\", cdm_pid, \" as cdm_pid, ");
line.append("CASE WHEN(SELECT COUNT(*) FROM @cdm_schema." + sqlTableName + " WHERE \")");
r.add(line.toString());
r.add(" whereClauses = NULL;");
r.addAll(testDefs);
r.add(" statement <- paste0(statement, paste0(whereClauses, collapse=\" AND \"));");
if (type == DEFAULT)
r.add(" statement <- paste0(statement, \") = 0 THEN 'FAIL' ELSE 'PASS' END AS status;\")");
else if (type == NEGATE)
r.add(" statement <- paste0(statement, \") != 0 THEN 'FAIL' ELSE 'PASS' END AS status;\")");
else
r.add(" statement <- paste0(statement, \") != \",rowCount ,\" THEN 'FAIL' ELSE 'PASS' END AS status;\")");
r.add(" frameworkContext$testSql = c(frameworkContext$testSql, statement);");
r.add(" invisible(statement)");
r.add("}");
r.add("");
}
}
private static void createLookupFunctions(List<String> r, Database database) {
for (Table table : database.getTables()) {
StringBuilder line = new StringBuilder();
String rTableName = convertToRName(table.getName());
String sqlTableName = convertToSqlName(table.getName());
List<String> argDefs = new ArrayList<String>();
List<String> testDefs = new ArrayList<String>();
for (Field field : table.getFields()) {
String rFieldName = convertToRName(field.getName());
String sqlFieldName = convertToSqlName(field.getName());
argDefs.add(rFieldName);
testDefs.add(" if (!missing(" + rFieldName + ")) {");
testDefs.add(" if (is.null(" + rFieldName + ")) {");
testDefs.add(" whereClauses <- c(whereClauses, \"" + sqlFieldName + " IS NULL\")");
testDefs.add(" } else {");
testDefs.add(" whereClauses <- c(whereClauses, paste0(\"" + sqlFieldName + " = '\", " + rFieldName + ",\"'\"))");
testDefs.add(" }");
testDefs.add(" }");
testDefs.add("");
}
line.append("lookup_" + rTableName + " <- function(fetchField, ");
line.append(StringUtilities.join(argDefs, ", "));
line.append(") {");
r.add(line.toString());
r.add(" whereClauses = NULL;");
line = new StringBuilder();
line.append(" statement <- paste0(\"SELECT \", fetchField , \" FROM @cdm_schema.");
line.append(sqlTableName);
line.append(" WHERE \")");
r.add(line.toString());
r.addAll(testDefs);
r.add(" statement <- paste0(statement, paste0(whereClauses, collapse=\" AND \"));");
r.add(" class(statement) <- \"subQuery\"");
r.add(" return(statement)");
r.add("}");
r.add("");
}
}
private static String convertToSqlName(String name) {
if (name.contains(" ") || name.contains(".") || keywordSet.contains(name.toUpperCase()))
name = "[" + name + "]";
return name;
}
private static void createInitFunction(List<String> r, Database database) {
r.add("frameworkContext <- new.env(parent = emptyenv());");
r.add("initFramework <- function() {");
r.add(" frameworkContext$groupIndex <- 0;");
r.add(" insertSql <- c()");
for (Table table : database.getTables()) {
String sqlTableName = convertToSqlName(table.getName());
r.add(" insertSql <- c(insertSql, \"TRUNCATE TABLE @source_schema." + sqlTableName + ";\")");
}
r.add(" frameworkContext$insertSql <- insertSql;");
r.add(" testSql <- c()");
r.add(" testSql <- c(testSql, \"IF OBJECT_ID('@cdm_schema.test_results', 'U') IS NOT NULL\")");
r.add(" testSql <- c(testSql, \" DROP TABLE @cdm_schema.test_results;\")");
r.add(" testSql <- c(testSql, \"\")");
r.add(" testSql <- c(testSql, \"CREATE TABLE @cdm_schema.test_results (id INT, description VARCHAR(512), test VARCHAR(256), source_pid VARCHAR(50), cdm_pid int, status VARCHAR(5));\")");
r.add(" testSql <- c(testSql, \"\")");
r.add(" frameworkContext$testSql <- testSql;");
r.add(" frameworkContext$testId = 0;");
r.add(" frameworkContext$testDescription = \"\";");
r.add("");
r.add(" patient <- {}");
r.add(" patient$source_pid <- NULL");
r.add(" patient$cdm_pid <- NULL");
r.add(" frameworkContext$patient = patient;");
r.add("");
r.add(" frameworkContext$defaultValues =new.env(parent = emptyenv());");
for (Table table : database.getTables()) {
String rTableName = convertToRName(table.getName());
r.add("");
r.add(" defaults <- new.env(parent = emptyenv())");
for (Field field : table.getFields()) {
String rFieldName = field.getName().replaceAll(" ", "_").replaceAll("-", "_");
String defaultValue;
if (field.getValueCounts().length == 0 || field.getValueCounts()[0][0].equalsIgnoreCase("List truncated..."))
defaultValue = "";
else
defaultValue = field.getValueCounts()[0][0];
if (!defaultValue.equals(""))
r.add(" defaults$" + rFieldName + " <- \"" + defaultValue + "\"");
}
r.add(" frameworkContext$defaultValues$" + rTableName + " = defaults;");
}
r.add("}");
r.add("");
}
private static void createAddFunctions(List<String> r, Database database) {
for (Table table : database.getTables()) {
StringBuilder line = new StringBuilder();
String rTableName = convertToRName(table.getName());
String sqlTableName = convertToSqlName(table.getName());
List<String> argDefs = new ArrayList<String>();
List<String> insertLines = new ArrayList<String>();
for (Field field : table.getFields()) {
String rFieldName = field.getName().replaceAll(" ", "_").replaceAll("-", "_");
String sqlFieldName = convertToSqlName(field.getName());
argDefs.add(rFieldName);
insertLines.add(" if (missing(" + rFieldName + ")) {");
insertLines.add(" " + rFieldName + " <- defaults$" + rFieldName);
insertLines.add(" }");
insertLines.add(" if (!is.null(" + rFieldName + ")) {");
insertLines.add(" insertFields <- c(insertFields, \"" + sqlFieldName + "\")");
insertLines.add(" insertValues <- c(insertValues, " + rFieldName + ")");
insertLines.add(" }");
insertLines.add("");
}
line.append("add_" + rTableName + " <- function(");
line.append(StringUtilities.join(argDefs, ", "));
line.append(") {");
r.add(line.toString());
r.add(" defaults <- frameworkContext$defaultValues$"+ rTableName + ";");
r.add(" insertFields <- c()");
r.add(" insertValues <- c()");
r.addAll(insertLines);
line = new StringBuilder();
line.append(" statement <- paste0(\"INSERT INTO @source_schema." + sqlTableName + " (\", ");
line.append("paste(insertFields, collapse = \", \"), ");
line.append("\") VALUES ('\", ");
line.append("paste(insertValues, collapse = \"', '\"), ");
line.append("\"');\")");
r.add(line.toString());
r.add(" frameworkContext$insertSql = c(frameworkContext$insertSql, statement);");
r.add(" invisible(statement);");
r.add("}");
r.add("");
}
}
private static void createSetDefaultFunctions(List<String> r, Database database) {
for (Table table : database.getTables()) {
StringBuilder line = new StringBuilder();
String rTableName = convertToRName(table.getName());
List<String> argDefs = new ArrayList<String>();
List<String> insertLines = new ArrayList<String>();
for (Field field : table.getFields()) {
String rFieldName = field.getName().replaceAll(" ", "_").replaceAll("-", "_");
argDefs.add(rFieldName);
insertLines.add(" if (!missing(" + rFieldName + ")) {");
insertLines.add(" defaults$" + rFieldName + " <- " + rFieldName);
insertLines.add(" }");
}
line.append("set_defaults_" + rTableName + " <- function(");
line.append(StringUtilities.join(argDefs, ", "));
line.append(") {");
r.add(line.toString());
r.add(" defaults <- frameworkContext$defaultValues$" + rTableName + ";");
r.addAll(insertLines);
r.add(" invisible(defaults)");
r.add("}");
r.add("");
}
}
private static void createGetDefaultFunctions(List<String> r, Database database) {
for (Table table : database.getTables()) {
String rTableName = convertToRName(table.getName());
r.add("get_defaults_" + rTableName + " <- function() {");
r.add(" return(frameworkContext$defaultValues)");
r.add("}");
r.add("");
}
}
private static String convertToRName(String name) {
if (name.startsWith("_") )
name = "U_" + name.substring(1);
name = name.replaceAll(" ", "_").replaceAll("-", "_");
return name;
}
}