/* * Microsoft JDBC Driver for SQL Server * * Copyright(c) Microsoft Corporation All rights reserved. * * This program is made available under the terms of the MIT License. See the LICENSE file in the project root for more information. */ package com.microsoft.sqlserver.testframework; import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.ByteArrayInputStream; import java.io.CharArrayReader; import java.net.URI; import java.sql.SQLException; import java.util.ArrayList; import java.util.Arrays; import java.util.logging.Level; import java.util.logging.Logger; import com.microsoft.sqlserver.testframework.sqlType.SqlBigInt; import com.microsoft.sqlserver.testframework.sqlType.SqlBinary; import com.microsoft.sqlserver.testframework.sqlType.SqlBit; import com.microsoft.sqlserver.testframework.sqlType.SqlChar; import com.microsoft.sqlserver.testframework.sqlType.SqlDate; import com.microsoft.sqlserver.testframework.sqlType.SqlDateTime; import com.microsoft.sqlserver.testframework.sqlType.SqlDateTime2; import com.microsoft.sqlserver.testframework.sqlType.SqlDateTimeOffset; import com.microsoft.sqlserver.testframework.sqlType.SqlDecimal; import com.microsoft.sqlserver.testframework.sqlType.SqlFloat; import com.microsoft.sqlserver.testframework.sqlType.SqlInt; import com.microsoft.sqlserver.testframework.sqlType.SqlMoney; import com.microsoft.sqlserver.testframework.sqlType.SqlNChar; import com.microsoft.sqlserver.testframework.sqlType.SqlNVarChar; import com.microsoft.sqlserver.testframework.sqlType.SqlNVarCharMax; import com.microsoft.sqlserver.testframework.sqlType.SqlNumeric; import com.microsoft.sqlserver.testframework.sqlType.SqlReal; import com.microsoft.sqlserver.testframework.sqlType.SqlSmallDateTime; import com.microsoft.sqlserver.testframework.sqlType.SqlSmallInt; import com.microsoft.sqlserver.testframework.sqlType.SqlSmallMoney; import com.microsoft.sqlserver.testframework.sqlType.SqlTime; import 
com.microsoft.sqlserver.testframework.sqlType.SqlTinyInt;
import com.microsoft.sqlserver.testframework.sqlType.SqlType;
import com.microsoft.sqlserver.testframework.sqlType.SqlVarBinary;
import com.microsoft.sqlserver.testframework.sqlType.SqlVarBinaryMax;
import com.microsoft.sqlserver.testframework.sqlType.SqlVarChar;
import com.microsoft.sqlserver.testframework.sqlType.SqlVarCharMax;

/**
 * Generic utility class shared by the test classes.
 *
 * @since 6.1.2
 */
public class Utils {
    public static final Logger log = Logger.getLogger("Utils");

    // 'SQL' represents SQL Server, while 'SQLAzure' represents SQL Azure.
    public static final String SERVER_TYPE_SQL_SERVER = "SQL";
    public static final String SERVER_TYPE_SQL_AZURE = "SQLAzure";

    // Lazily populated by types(); stays null until the first call.
    private static ArrayList<SqlType> types = null;

    /**
     * Returns the server type configured through the {@code server.type} property (see
     * {@link #getConfiguredProperty(String)}).
     *
     * @return {@link #SERVER_TYPE_SQL_AZURE} when the property equals "SQLAzure" (case-insensitive); otherwise
     *         {@link #SERVER_TYPE_SQL_SERVER}. An unrecognized value is logged at FINE and falls back to SQL Server.
     */
    public static String getServerType() {
        String serverTypeProperty = getConfiguredProperty("server.type");
        if (null == serverTypeProperty) {
            // default to SQL Server
            return SERVER_TYPE_SQL_SERVER;
        }
        if (serverTypeProperty.equalsIgnoreCase(SERVER_TYPE_SQL_AZURE)) {
            return SERVER_TYPE_SQL_AZURE;
        }
        if (!serverTypeProperty.equalsIgnoreCase(SERVER_TYPE_SQL_SERVER) && log.isLoggable(Level.FINE)) {
            log.fine("Server.type '" + serverTypeProperty + "' is not supported yet. Default to SQL Server");
        }
        return SERVER_TYPE_SQL_SERVER;
    }

    /**
     * Reads a configuration value from the JVM system properties, falling back to the environment variable of the
     * same name when no system property is set.
     *
     * @param key
     *        property / environment variable name
     * @return the configured value, or {@code null} if neither source defines it
     */
    public static String getConfiguredProperty(String key) {
        String value = System.getProperty(key);
        if (value == null) {
            value = System.getenv(key);
        }
        return value;
    }

    /**
     * Convenience overload of {@link #getConfiguredProperty(String)} that substitutes a default for a missing value.
     *
     * @param key
     *        property / environment variable name
     * @param defaultValue
     *        value returned when the key is not configured
     * @return the configured value, or {@code defaultValue} if neither source defines it
     */
    public static String getConfiguredProperty(String key, String defaultValue) {
        String value = getConfiguredProperty(key);
        return (value == null) ? defaultValue : value;
    }

    /**
     * Finds the first registered {@link SqlType} whose {@code getType()} is exactly {@code javatype}.
     *
     * @param javatype
     *        Java class to look up (compared by identity)
     * @return the matching type, or {@code null} if none is registered
     */
    public static SqlType find(Class<?> javatype) {
        // types() initializes the registry on first use; the previous guard was inverted and
        // skipped initialization, so a cold lookup always returned null.
        for (SqlType type : types()) {
            if (type.getType() == javatype) {
                return type;
            }
        }
        return null;
    }

    /**
     * Finds the first registered {@link SqlType} whose name matches {@code name}, ignoring case.
     *
     * @param name
     *        SQL type name to look up
     * @return the matching type, or {@code null} if none is registered
     */
    public static SqlType find(String name) {
        for (SqlType type : types()) {
            if (type.getName().equalsIgnoreCase(name)) {
                return type;
            }
        }
        return null;
    }

    /**
     * Returns the registry of SQL types exercised by the test framework, building it on the first call.
     *
     * @return the shared (mutable) list of registered types
     */
    public static ArrayList<SqlType> types() {
        if (null == types) {
            types = new ArrayList<SqlType>();

            types.add(new SqlInt());
            types.add(new SqlSmallInt());
            types.add(new SqlTinyInt());
            types.add(new SqlBit());
            types.add(new SqlDateTime());
            types.add(new SqlSmallDateTime());
            types.add(new SqlDecimal());
            types.add(new SqlNumeric());
            types.add(new SqlReal());
            types.add(new SqlFloat());
            types.add(new SqlMoney());
            types.add(new SqlSmallMoney());
            types.add(new SqlVarChar());
            types.add(new SqlChar());
            // types.add(new SqlText());
            types.add(new SqlBinary());
            types.add(new SqlVarBinary());
            // types.add(new SqlImage());
            // types.add(new SqlTimestamp());

            types.add(new SqlNVarChar());
            types.add(new SqlNChar());
            // types.add(new SqlNText());
            // types.add(new SqlGuid());

            types.add(new SqlBigInt());
            // types.add(new SqlVariant(this));

            // 9.0 types
            types.add(new SqlVarCharMax());
            types.add(new SqlNVarCharMax());
            types.add(new SqlVarBinaryMax());
            // types.add(new SqlXml());

            // 10.0 types
            types.add(new SqlDate());
            types.add(new SqlDateTime2());
            types.add(new SqlTime());
            types.add(new SqlDateTimeOffset());
        }
        return types;
    }

    /**
     * Wrapper class for BinaryStream; keeps a reference to the source bytes.
     */
    public static class DBBinaryStream extends ByteArrayInputStream {
        byte[] data;

        /**
         * @param value
         *        bytes served by the stream
         */
        public DBBinaryStream(byte[] value) {
            super(value);
            data = value;
        }
    }

    /**
     * Wrapper for CharacterStream; keeps a reference to the source string.
     */
    public static class DBCharacterStream extends CharArrayReader {
        String localValue;

        /**
         * @param value
         *        text served by the reader
         */
        public DBCharacterStream(String value) {
            super(value.toCharArray());
            localValue = value;
        }
    }

    /**
     * Wrapper for NCharacterStream.
     * Static so it can be created without an enclosing Utils instance, like its sibling wrappers.
     */
    static class DBNCharacterStream extends DBCharacterStream {
        /**
         * @param value
         *        text served by the reader
         */
        public DBNCharacterStream(String value) {
            super(value);
        }
    }

    /**
     * Returns the code-source location (class-path root) from which this class was loaded, with a trailing "/"
     * appended.
     *
     * @return decoded filesystem path of the resource location; fails the current test if it cannot be resolved
     */
    public static String getCurrentClassPath() {
        try {
            String location = Utils.class.getProtectionDomain().getCodeSource().getLocation().getPath() + "/";
            // Round-trip through URI to decode percent-escapes (e.g. %20) in the URL path.
            return new URI(location).getPath();
        } catch (Exception e) {
            fail("Failed to get CSV file path. " + e.getMessage());
        }
        return null;
    }

    /**
     * mimic "DROP TABLE IF EXISTS ..." for older versions of SQL Server
     */
    public static void dropTableIfExists(String tableName, java.sql.Statement stmt) throws SQLException {
        dropObjectIfExists(tableName, "IsTable", stmt);
    }

    /**
     * mimic "DROP PROCEDURE IF EXISTS ..." for older versions of SQL Server
     */
    public static void dropProcedureIfExists(String procName, java.sql.Statement stmt) throws SQLException {
        dropObjectIfExists(procName, "IsProcedure", stmt);
    }

    /**
     * Actually performs the conditional "DROP TABLE / PROCEDURE".
     *
     * @param objectName
     *        object name; wrapped in square brackets unless already bracketed
     * @param objectProperty
     *        OBJECTPROPERTY name to check ("IsTable" or "IsProcedure")
     * @param stmt
     *        statement used to execute the generated batch
     * @throws SQLException
     *         if the batch fails
     */
    private static void dropObjectIfExists(String objectName, String objectProperty,
            java.sql.Statement stmt) throws SQLException {
        StringBuilder sb = new StringBuilder();
        if (!objectName.startsWith("[")) {
            sb.append("[");
        }
        sb.append(objectName);
        if (!objectName.endsWith("]")) {
            sb.append("]");
        }
        String bracketedObjectName = sb.toString();
        // The name is embedded in an N'...' string literal, so single quotes must be doubled there.
        String quotedLiteralName = bracketedObjectName.replace("'", "''");
        String sql = String.format(
                "IF EXISTS " + "( " + "SELECT * from sys.objects "
                        + "WHERE object_id = OBJECT_ID(N'%s') AND OBJECTPROPERTY(object_id, N'%s') = 1 " + ") "
                        + "DROP %s %s ",
                quotedLiteralName, objectProperty, "IsProcedure".equals(objectProperty) ? "PROCEDURE" : "TABLE",
                bracketedObjectName);
        stmt.executeUpdate(sql);
    }

    /**
     * Asserts that {@code retrieved} starts with {@code expectedData} and that every byte past that prefix is zero
     * (the padding a fixed-length BINARY column returns).
     *
     * @param expectedData
     *        bytes originally written
     * @param retrieved
     *        bytes read back
     * @return {@code true} when all assertions pass
     */
    public static boolean parseByte(byte[] expectedData, byte[] retrieved) {
        assertTrue(Arrays.equals(expectedData, Arrays.copyOf(retrieved, expectedData.length)),
                " unexpected BINARY value, expected");
        for (int i = expectedData.length; i < retrieved.length; i++) {
            assertTrue(0 == retrieved[i], "unexpected data BINARY");
        }
        return true;
    }
}