package com.thinkbiganalytics.discovery.parsers.hadoop;
/*-
* #%L
* thinkbig-schema-discovery-default
* %%
* Copyright (C) 2017 ThinkBig Analytics
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import com.thinkbiganalytics.discovery.model.DefaultQueryResult;
import com.thinkbiganalytics.discovery.model.DefaultQueryResultColumn;
import com.thinkbiganalytics.discovery.schema.Field;
import com.thinkbiganalytics.discovery.schema.QueryResult;
import com.thinkbiganalytics.discovery.schema.QueryResultColumn;
import com.thinkbiganalytics.discovery.schema.Schema;
import com.thinkbiganalytics.discovery.util.TableSchemaType;
import com.thinkbiganalytics.spark.rest.model.TransformResponse;
import com.thinkbiganalytics.spark.shell.SparkShellProcessManager;
import com.thinkbiganalytics.spark.shell.SparkShellRestClient;
import org.apache.commons.lang3.StringUtils;
import org.mockito.Mockito;
import org.mockito.internal.util.reflection.Whitebox;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
/**
 * Tests how {@link SparkFileSchemaParserService#doParse} converts the columns of a
 * Spark transform result into schema fields.
 *
 * <p>No Spark shell is contacted: the REST client and the process manager are Mockito
 * mocks, and the REST client is stubbed to return a canned {@link TransformResponse}.
 */
public class SparkFileSchemaParserServiceTest {

    /**
     * Builds a column list containing one {@code decimal(17,12)} column and one
     * {@code string} column.
     */
    private List<QueryResultColumn> decimalColumns() {
        List<QueryResultColumn> columns = new ArrayList<>();
        columns.add(newColumn("decimalColumn", "decimal(17,12)"));
        columns.add(newColumn("stringColumn", "string"));
        return columns;
    }

    /**
     * Builds a column list whose data types carry no precision/scale suffix.
     */
    private List<QueryResultColumn> nonDecimalColumns() {
        List<QueryResultColumn> columns = new ArrayList<>();
        columns.add(newColumn("intColumn", "int"));
        columns.add(newColumn("stringColumn", "string"));
        return columns;
    }

    /**
     * Creates a result column for the fixed table {@code "table"} in database
     * {@code "database"}.
     *
     * @param name     value used for both the field name and the display name
     * @param dataType the data type string reported for the column
     * @return the populated column
     */
    private QueryResultColumn newColumn(String name, String dataType) {
        QueryResultColumn column = new DefaultQueryResultColumn();
        column.setField(name);
        column.setDisplayName(name);
        column.setTableName("table");
        column.setDataType(dataType);
        column.setDatabaseName("database");
        return column;
    }

    /**
     * Wraps the given columns in a {@link TransformResponse} with status
     * {@code SUCCESS}.
     *
     * @param columns the columns to expose through the response's query result
     * @return the canned response
     */
    private TransformResponse transformResponse(List<QueryResultColumn> columns) {
        TransformResponse transformResponse = new TransformResponse();
        transformResponse.setStatus(TransformResponse.Status.SUCCESS);
        QueryResult result = new DefaultQueryResult("query");
        result.setColumns(columns);
        transformResponse.setResults(result);
        return transformResponse;
    }

    /**
     * Runs the real {@code doParse} implementation on a mocked service whose
     * collaborators are stubbed to return the supplied columns.
     *
     * @param columns         the columns the mocked REST client reports back
     * @param sparkFileType   the file type passed to {@code doParse}
     * @param tableSchemaType the target schema type passed to {@code doParse}
     * @return whatever {@code doParse} returns for an empty input stream
     * @throws Exception if stubbing or parsing fails
     */
    private Schema parseQueryResult(List<QueryResultColumn> columns, SparkFileSchemaParserService.SparkFileType sparkFileType, TableSchemaType tableSchemaType) throws Exception {
        final SparkShellRestClient restClient = Mockito.mock(SparkShellRestClient.class);
        final SparkShellProcessManager sparkShellProcessManager = Mockito.mock(SparkShellProcessManager.class);

        // The service itself is a mock so that only doParse runs real code; its
        // collaborators are injected directly into the private fields.
        SparkFileSchemaParserService service = Mockito.mock(SparkFileSchemaParserService.class);
        Whitebox.setInternalState(service, "shellProcessManager", sparkShellProcessManager);
        Whitebox.setInternalState(service, "restClient", restClient);

        Mockito.when(service.doParse(Mockito.any(InputStream.class), Mockito.any(), Mockito.any())).thenCallRealMethod();
        Mockito.when(sparkShellProcessManager.getSystemProcess()).thenReturn(null);
        Mockito.when(restClient.transform(Mockito.any(), Mockito.any())).thenReturn(transformResponse(columns));

        // The stream content is irrelevant here because the transform call is stubbed.
        InputStream inputStream = new ByteArrayInputStream(new byte[0]);
        return service.doParse(inputStream, sparkFileType, tableSchemaType);
    }

    /**
     * Test to ensure the column types that have precision,scale get parsed correctly to the field.precisionScale property
     *
     * @throws Exception if parsing fails; propagated so that JUnit reports the failure
     *                   instead of the test passing silently
     */
    @org.junit.Test
    public void testDecimalParsing() throws Exception {
        Schema decimalSchema = parseQueryResult(decimalColumns(), SparkFileSchemaParserService.SparkFileType.PARQUET, TableSchemaType.HIVE);
        assertNotNull(decimalSchema);

        Field decimalField = decimalSchema.getFields().stream()
            .filter(field -> field.getName().equalsIgnoreCase("decimalColumn"))
            .findFirst()
            .orElse(null);
        assertNotNull(decimalField);
        assertEquals("decimal", decimalField.getDerivedDataType());
        assertEquals("17,12", decimalField.getPrecisionScale());
        assertEquals("decimal(17,12)", decimalField.getDataTypeWithPrecisionAndScale());
    }

    /**
     * Test to ensure standard columns without precision work
     *
     * @throws Exception if parsing fails; propagated so that JUnit reports the failure
     *                   instead of the test passing silently
     */
    @org.junit.Test
    public void testParsing() throws Exception {
        Schema schema = parseQueryResult(nonDecimalColumns(), SparkFileSchemaParserService.SparkFileType.AVRO, TableSchemaType.HIVE);
        assertNotNull(schema);
        schema.getFields().forEach(field -> {
            assertNotNull(field);
            assertTrue(StringUtils.isBlank(field.getPrecisionScale()));
        });
    }
}