/*
* Copyright (C) 2012-2015 DataStax Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.driver.core;
import com.datastax.driver.core.utils.CassandraVersion;
import com.google.common.collect.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.Test;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
/**
* The goal of this test is to cover the serialization and deserialization of datatypes.
* <p/>
* It creates a table with a column of a given type, inserts a value and then tries to retrieve it.
* There are 3 variants for the insert query: a raw string, a simple statement with a parameter
* (protocol v2 and above only) and a prepared statement.
* This is repeated with a large number of datatypes.
*/
public class DataTypeIntegrationTest extends CCMTestsSupport {

    private static final Logger logger = LoggerFactory.getLogger(DataTypeIntegrationTest.class);

    /**
     * Format string for assertion descriptions; its arguments are the statement type, the
     * CREATE TABLE statement and the INSERT statement of the table under test.
     */
    private static final String FAILURE_DESCRIPTION = "Test failure on %s statement with table:%n%s;%n" +
            "insert statement:%n%s;%n";

    /**
     * Number of children generated at each non-leaf level by {@link #nestedObject(DataType)}.
     */
    private static final int NESTED_CHILDREN = 5;

    /**
     * Number of collection levels of the randomly generated nested collection types.
     */
    private static final int RANDOM_NESTING_DEPTH = 5;

    private Map<DataType, Object> samples;
    private List<TestTable> tables;
    private VersionNumber cassandraVersion;

    enum StatementType {RAW_STRING, SIMPLE_WITH_PARAM, PREPARED}

    /**
     * Builds the sample values and the test tables, then creates every table whose minimum
     * Cassandra version is satisfied by the first host found in the cluster metadata.
     */
    @Override
    public void onTestContextInitialized() {
        ProtocolVersion protocolVersion = ccm().getProtocolVersion();
        samples = PrimitiveTypeSamples.samples(protocolVersion);
        tables = allTables();
        Host host = cluster().getMetadata().getAllHosts().iterator().next();
        // NOTE(review): nextStable() presumably strips pre-release qualifiers so that the
        // minimum-version checks treat e.g. a snapshot build as its release; confirm in VersionNumber.
        cassandraVersion = host.getCassandraVersion().nextStable();
        List<String> statements = Lists.newArrayList();
        for (TestTable table : tables) {
            if (isSupportedByCluster(table)) {
                statements.add(table.createStatement);
            } else {
                logger.debug("Skipping table because it uses a feature not supported by Cassandra {}: {}",
                        cassandraVersion, table.createStatement);
            }
        }
        execute(statements);
    }

    @Test(groups = "long")
    public void should_insert_and_retrieve_data_with_legacy_statements() {
        should_insert_and_retrieve_data(StatementType.RAW_STRING);
    }

    @Test(groups = "long")
    public void should_insert_and_retrieve_data_with_prepared_statements() {
        should_insert_and_retrieve_data(StatementType.PREPARED);
    }

    @Test(groups = "long")
    @CassandraVersion(value = "2.0", description = "Uses parameterized simple statements, which are only available with protocol v2")
    public void should_insert_and_retrieve_data_with_parameterized_simple_statements() {
        should_insert_and_retrieve_data(StatementType.SIMPLE_WITH_PARAM);
    }

    /**
     * For every table created in {@link #onTestContextInitialized()}, inserts the table's sample
     * value using the given kind of statement, reads it back and checks it. Each table is truncated
     * afterwards so that the next statement type starts from an empty table.
     *
     * @param statementType how the INSERT statement is built and executed.
     */
    protected void should_insert_and_retrieve_data(StatementType statementType) {
        ProtocolVersion protocolVersion = cluster().getConfiguration().getProtocolOptions().getProtocolVersion();
        CodecRegistry codecRegistry = cluster().getConfiguration().getCodecRegistry();
        for (TestTable table : tables) {
            if (!isSupportedByCluster(table)) {
                // This table was never created, see onTestContextInitialized().
                continue;
            }
            TypeCodec<Object> codec = codecRegistry.codecFor(table.testColumnType);
            switch (statementType) {
                case RAW_STRING: {
                    // Inline the CQL literal of the sample in place of the bind marker.
                    String formatValue = codec.format(table.sampleValue);
                    assertThat(formatValue).isNotNull();
                    String query = table.insertStatement.replace("?", formatValue);
                    session().execute(query);
                    break;
                }
                case SIMPLE_WITH_PARAM: {
                    SimpleStatement statement = new SimpleStatement(table.insertStatement, table.sampleValue);
                    checkGetValuesReturnsSerializedValue(protocolVersion, statement, table);
                    session().execute(statement);
                    break;
                }
                case PREPARED: {
                    PreparedStatement ps = session().prepare(table.insertStatement);
                    BoundStatement bs = ps.bind(table.sampleValue);
                    checkGetterReturnsValue(bs, table);
                    session().execute(bs);
                    break;
                }
            }
            Row row = session().execute(table.selectStatement).one();
            // The codec returns the decoded object as-is (null for a null column),
            // so compare it against expectedValue.
            Object queriedValue = codec.deserialize(row.getBytesUnsafe("v"), protocolVersion);
            assertThat(queriedValue)
                    .as(FAILURE_DESCRIPTION, statementType, table.createStatement, table.insertStatement)
                    .isEqualTo(table.expectedValue);
            // Typed row getters (getInt(), getBool(), ...) yield the Java primitive default for a
            // null column, so compare their result against expectedPrimitiveValue.
            assertThat(getValue(row, table.testColumnType))
                    .as(FAILURE_DESCRIPTION, statementType, table.createStatement, table.insertStatement)
                    .isEqualTo(table.expectedPrimitiveValue);
            session().execute(table.truncateStatement);
        }
    }

    /**
     * Returns whether the Cassandra version of the cluster is at least the table's minimum version,
     * i.e. whether the table is created and exercised by this test.
     */
    private boolean isSupportedByCluster(TestTable table) {
        return cassandraVersion.compareTo(table.minCassandraVersion) >= 0;
    }

    /**
     * Checks that the typed getter, {@code getObject(int)} and {@code getObject(String)} of a bound
     * statement give back the value that was bound to column {@code v}.
     */
    private void checkGetterReturnsValue(BoundStatement bs, TestTable table) {
        String description = "Expected values to match for " + table.testColumnType;
        // Typed getters are expected to return the Java primitive default for a null value.
        Object getterResult = getValue(bs, table.testColumnType);
        assertThat(getterResult).as(description).isEqualTo(table.expectedPrimitiveValue);
        // getObject() is expected to return the bound object itself, including null.
        assertThat(bs.getObject(0)).as(description).isEqualTo(table.sampleValue);
        assertThat(bs.getObject("v")).as(description).isEqualTo(table.sampleValue);
    }

    /**
     * Checks that a parameterized simple statement holds exactly one value, serialized with the
     * codec registered for the table's column type.
     */
    public void checkGetValuesReturnsSerializedValue(ProtocolVersion protocolVersion, SimpleStatement statement, TestTable table) {
        CodecRegistry codecRegistry = cluster().getConfiguration().getCodecRegistry();
        ByteBuffer[] values = statement.getValues(protocolVersion, codecRegistry);
        assertThat(values.length).isEqualTo(1);
        assertThat(values[0])
                .as("Value not serialized as expected for " + table.sampleValue)
                .isEqualTo(codecRegistry.codecFor(table.testColumnType).serialize(table.sampleValue, protocolVersion));
    }

    /**
     * Abstracts information about a table (corresponding to a given column type).
     * <p/>
     * Each instance gets a unique table name with an int key {@code k} and a column {@code v} of the
     * tested type; the INSERT and SELECT statements always use {@code k = 1}.
     */
    static class TestTable {
        private static final AtomicInteger counter = new AtomicInteger();

        // Declared before the statement fields below, which are initialized from it.
        private final String tableName = "data_type_test" + counter.incrementAndGet();

        final DataType testColumnType;
        /** Value bound into the INSERT statement. */
        final Object sampleValue;
        /** Value expected back from the codec when reading the column. */
        final Object expectedValue;
        /** Value expected back from the typed getter (primitive default instead of null). */
        final Object expectedPrimitiveValue;
        final String createStatement;
        final String insertStatement = String.format("INSERT INTO %s (k, v) VALUES (1, ?)", tableName);
        final String selectStatement = String.format("SELECT v FROM %s WHERE k = 1", tableName);
        final String truncateStatement = String.format("TRUNCATE %s", tableName);
        final VersionNumber minCassandraVersion;

        TestTable(DataType testColumnType, Object sampleValue, String minCassandraVersion) {
            this(testColumnType, sampleValue, sampleValue, minCassandraVersion);
        }

        TestTable(DataType testColumnType, Object sampleValue, Object expectedValue, String minCassandraVersion) {
            this(testColumnType, sampleValue, expectedValue, expectedValue, minCassandraVersion);
        }

        TestTable(DataType testColumnType, Object sampleValue, Object expectedValue, Object expectedPrimitiveValue, String minCassandraVersion) {
            this.testColumnType = testColumnType;
            this.sampleValue = sampleValue;
            this.expectedValue = expectedValue;
            this.expectedPrimitiveValue = expectedPrimitiveValue;
            this.minCassandraVersion = VersionNumber.parse(minCassandraVersion);
            this.createStatement = String.format("CREATE TABLE %s (k int PRIMARY KEY, v %s)", tableName, testColumnType);
        }
    }

    /**
     * Returns an immutable list of every table exercised by this test.
     */
    private List<TestTable> allTables() {
        List<TestTable> tables = Lists.newArrayList();
        tables.addAll(tablesWithPrimitives());
        tables.addAll(tablesWithPrimitivesNull());
        tables.addAll(tablesWithCollectionsOfPrimitives());
        tables.addAll(tablesWithMapsOfPrimitives());
        tables.addAll(tablesWithNestedCollections());
        tables.addAll(tablesWithRandomlyGeneratedNestedCollections());
        return ImmutableList.copyOf(tables);
    }

    /**
     * One table per primitive sample type, storing that sample.
     */
    private List<TestTable> tablesWithPrimitives() {
        List<TestTable> tables = Lists.newArrayList();
        for (Map.Entry<DataType, Object> entry : samples.entrySet()) {
            tables.add(new TestTable(entry.getKey(), entry.getValue(), "1.2.0"));
        }
        return tables;
    }

    /**
     * One table per primitive type that maps to a Java primitive, storing null.
     * <p/>
     * Reading such a column with its typed getter returns the Java default value given here
     * instead of null. Types that do not map to a Java primitive are skipped.
     */
    private List<TestTable> tablesWithPrimitivesNull() {
        List<TestTable> tables = Lists.newArrayList();
        for (DataType dataType : TestUtils.allPrimitiveTypes(ccm().getProtocolVersion())) {
            Object expectedPrimitiveValue;
            switch (dataType.getName()) {
                case BIGINT:
                case TIME:
                    expectedPrimitiveValue = 0L;
                    break;
                case DOUBLE:
                    expectedPrimitiveValue = 0.0;
                    break;
                case FLOAT:
                    expectedPrimitiveValue = 0.0f;
                    break;
                case INT:
                    expectedPrimitiveValue = 0;
                    break;
                case SMALLINT:
                    expectedPrimitiveValue = (short) 0;
                    break;
                case TINYINT:
                    expectedPrimitiveValue = (byte) 0;
                    break;
                case BOOLEAN:
                    expectedPrimitiveValue = false;
                    break;
                default:
                    // not a Java primitive type
                    continue;
            }
            tables.add(new TestTable(dataType, null, null, expectedPrimitiveValue, "1.2.0"));
        }
        return tables;
    }

    /**
     * For each primitive sample type, a list table (two copies of the sample) and, except for
     * duration, a set table (the sample alone).
     */
    private List<TestTable> tablesWithCollectionsOfPrimitives() {
        List<TestTable> tables = Lists.newArrayList();
        for (Map.Entry<DataType, Object> entry : samples.entrySet()) {
            DataType elementType = entry.getKey();
            Object elementSample = entry.getValue();
            tables.add(new TestTable(DataType.list(elementType), Lists.newArrayList(elementSample, elementSample), "1.2.0"));
            // Duration not supported in Set
            if (!DataType.duration().equals(elementType)) {
                tables.add(new TestTable(DataType.set(elementType), Sets.newHashSet(elementSample), "1.2.0"));
            }
        }
        return tables;
    }

    /**
     * One map table per (key type, value type) pair of primitive sample types, storing a single
     * entry; duration is never used as a key type.
     */
    private List<TestTable> tablesWithMapsOfPrimitives() {
        List<TestTable> tables = Lists.newArrayList();
        for (Map.Entry<DataType, Object> keyEntry : samples.entrySet()) {
            DataType keyType = keyEntry.getKey();
            // Duration not supported as Map key
            if (DataType.duration().equals(keyType)) {
                continue;
            }
            Object keySample = keyEntry.getValue();
            for (Map.Entry<DataType, Object> valueEntry : samples.entrySet()) {
                DataType valueType = valueEntry.getKey();
                Object valueSample = valueEntry.getValue();
                tables.add(new TestTable(DataType.map(keyType, valueType),
                        ImmutableMap.builder().put(keySample, valueSample).build(),
                        "1.2.0"));
            }
        }
        return tables;
    }

    /**
     * Tables for collections whose elements, keys or values are themselves frozen collections of int.
     * <p/>
     * To avoid combinatorial explosion, only int is used as the primitive type, with two levels of
     * nesting. This yields collections like list&lt;frozen&lt;map&lt;int, int&gt;&gt;&gt; or
     * map&lt;frozen&lt;set&lt;int&gt;&gt;, frozen&lt;list&lt;int&gt;&gt;&gt;.
     */
    private Collection<? extends TestTable> tablesWithNestedCollections() {
        List<TestTable> tables = Lists.newArrayList();
        // Types and samples for the inner collections like frozen<list<int>>
        Map<DataType, Object> childCollectionSamples = ImmutableMap.<DataType, Object>builder()
                .put(DataType.frozenList(DataType.cint()), Lists.newArrayList(1, 1))
                .put(DataType.frozenSet(DataType.cint()), Sets.newHashSet(1, 2))
                .put(DataType.frozenMap(DataType.cint(), DataType.cint()), ImmutableMap.<Integer, Integer>builder().put(1, 2).put(3, 4).build())
                .build();
        for (Map.Entry<DataType, Object> entry : childCollectionSamples.entrySet()) {
            DataType elementType = entry.getKey();
            Object elementSample = entry.getValue();
            tables.add(new TestTable(DataType.list(elementType), Lists.newArrayList(elementSample, elementSample), "2.1.3"));
            tables.add(new TestTable(DataType.set(elementType), Sets.newHashSet(elementSample), "2.1.3"));
            for (Map.Entry<DataType, Object> valueEntry : childCollectionSamples.entrySet()) {
                DataType valueType = valueEntry.getKey();
                Object valueSample = valueEntry.getValue();
                tables.add(new TestTable(DataType.map(elementType, valueType),
                        ImmutableMap.builder().put(elementSample, valueSample).build(), "2.1.3"));
            }
        }
        return tables;
    }

    /**
     * One table each for a randomly shaped, deeply nested list, set and map type.
     * The generated type shows up in the table's CREATE statement, which is part of every failure
     * description.
     */
    private Collection<? extends TestTable> tablesWithRandomlyGeneratedNestedCollections() {
        List<TestTable> tables = Lists.newArrayList();
        DataType nestedListType = buildNestedType(DataType.Name.LIST, RANDOM_NESTING_DEPTH);
        DataType nestedSetType = buildNestedType(DataType.Name.SET, RANDOM_NESTING_DEPTH);
        DataType nestedMapType = buildNestedType(DataType.Name.MAP, RANDOM_NESTING_DEPTH);
        tables.add(new TestTable(nestedListType, nestedObject(nestedListType), "2.1.3"));
        tables.add(new TestTable(nestedSetType, nestedObject(nestedSetType), "2.1.3"));
        tables.add(new TestTable(nestedMapType, nestedObject(nestedMapType), "2.1.3"));
        return tables;
    }

    /**
     * Populates a nested collection based on the given type and its arguments.
     * <p/>
     * A leaf collection (whose element or value type is not a collection) holds the ints 1, 2, 3
     * (for maps: 1=2, 3=4, 5=6). A non-leaf collection holds {@link #NESTED_CHILDREN} children built
     * recursively; for maps, nesting goes through the value type and the keys are the ints
     * {@code 0..NESTED_CHILDREN-1}. All children of a given level are equal, so non-leaf sets end up
     * containing a single element.
     *
     * @param type a list, set or map type whose innermost elements are ints.
     * @return a sample value for {@code type}.
     * @throws IllegalArgumentException if {@code type} is not a list, set or map type.
     */
    public Object nestedObject(DataType type) {
        List<DataType> typeArguments = type.getTypeArguments();
        if (typeArguments.isEmpty()) {
            throw new IllegalArgumentException("Not a list, set or map type: " + type);
        }
        // For maps, the nesting continues through the value type (index 1).
        DataType argument = typeArguments.get(typeArguments.size() > 1 ? 1 : 0);
        boolean isAtBottom = !(argument instanceof DataType.CollectionType);
        switch (type.getName()) {
            case LIST: {
                if (isAtBottom) {
                    return Lists.newArrayList(1, 2, 3);
                }
                List<Object> list = Lists.newArrayListWithExpectedSize(NESTED_CHILDREN);
                for (int i = 0; i < NESTED_CHILDREN; i++) {
                    list.add(nestedObject(argument));
                }
                return list;
            }
            case SET: {
                if (isAtBottom) {
                    return Sets.newHashSet(1, 2, 3);
                }
                Set<Object> set = Sets.newHashSet();
                for (int i = 0; i < NESTED_CHILDREN; i++) {
                    set.add(nestedObject(argument));
                }
                return set;
            }
            case MAP: {
                if (isAtBottom) {
                    Map<Integer, Integer> leaf = Maps.newHashMap();
                    leaf.put(1, 2);
                    leaf.put(3, 4);
                    leaf.put(5, 6);
                    return leaf;
                }
                Map<Integer, Object> map = Maps.newHashMap();
                for (int i = 0; i < NESTED_CHILDREN; i++) {
                    map.put(i, nestedObject(argument));
                }
                return map;
            }
            default:
                throw new IllegalArgumentException("Not a list, set or map type: " + type);
        }
    }

    /**
     * Builds a randomly shaped nested collection type.
     * <p/>
     * The outermost level is a non-frozen collection of kind {@code baseType}. Every inner level is a
     * randomly chosen frozen list, set or map, and maps always have int keys. The innermost elements
     * are int. For example, depth 3 with LIST may yield
     * {@code list<frozen<map<int, frozen<set<int>>>>>}.
     *
     * @param baseType The base type to use, one of SET, MAP, LIST.
     * @param depth    How many collection levels to generate, including the outermost one; must be
     *                 at least 1.
     * @return a DataType that is a nested collection with the given baseType with the given depth.
     * @throws IllegalArgumentException if {@code depth < 1} or {@code baseType} is not SET, MAP or LIST.
     */
    public DataType buildNestedType(DataType.Name baseType, int depth) {
        if (depth < 1) {
            throw new IllegalArgumentException("depth must be at least 1, got " + depth);
        }
        Random r = new Random();
        // Start from the innermost element type and wrap it in depth - 1 frozen collections.
        DataType t = DataType.cint();
        for (int i = 1; i < depth; i++) {
            switch (r.nextInt(3)) {
                case 0:
                    t = DataType.frozenList(t);
                    break;
                case 1:
                    t = DataType.frozenSet(t);
                    break;
                default:
                    t = DataType.frozenMap(DataType.cint(), t);
                    break;
            }
        }
        switch (baseType) {
            case LIST:
                return DataType.list(t);
            case SET:
                return DataType.set(t);
            case MAP:
                return DataType.map(DataType.cint(), t);
            default:
                throw new IllegalArgumentException("baseType must be one of LIST, SET, MAP, got " + baseType);
        }
    }

    /**
     * Reads column 0 of {@code data} through the typed getter matching {@code dataType}
     * (e.g. {@code getInt} for INT). Collection getters are given the Java types of the codecs
     * registered for the element, key and value types. Fails the test for types not covered here.
     */
    private Object getValue(GettableByIndexData data, DataType dataType) {
        // This is kind of lame, but better than testing all getters manually
        CodecRegistry codecRegistry = cluster().getConfiguration().getCodecRegistry();
        switch (dataType.getName()) {
            case ASCII:
                return data.getString(0);
            case BIGINT:
                return data.getLong(0);
            case BLOB:
                return data.getBytes(0);
            case BOOLEAN:
                return data.getBool(0);
            case DECIMAL:
                return data.getDecimal(0);
            case DOUBLE:
                return data.getDouble(0);
            case FLOAT:
                return data.getFloat(0);
            case INET:
                return data.getInet(0);
            case TINYINT:
                return data.getByte(0);
            case SMALLINT:
                return data.getShort(0);
            case INT:
                return data.getInt(0);
            case TEXT:
            case VARCHAR:
                return data.getString(0);
            case TIMESTAMP:
                return data.getTimestamp(0);
            case DATE:
                return data.getDate(0);
            case TIME:
                return data.getTime(0);
            case UUID:
            case TIMEUUID:
                return data.getUUID(0);
            case VARINT:
                return data.getVarint(0);
            case LIST:
                return data.getList(0, codecRegistry.codecFor(dataType.getTypeArguments().get(0)).getJavaType());
            case SET:
                return data.getSet(0, codecRegistry.codecFor(dataType.getTypeArguments().get(0)).getJavaType());
            case MAP:
                return data.getMap(0,
                        codecRegistry.codecFor(dataType.getTypeArguments().get(0)).getJavaType(),
                        codecRegistry.codecFor(dataType.getTypeArguments().get(1)).getJavaType());
            case DURATION:
                return data.get(0, Duration.class);
            case CUSTOM:
            case COUNTER:
            default:
                fail("Unexpected type in bound statement test: " + dataType);
                return null;
        }
    }
}