/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.avro;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigDecimal;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TimeZone;
import org.apache.avro.Conversions;
import org.apache.avro.LogicalType;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData.Array;
import org.apache.avro.generic.GenericRecord;
import org.apache.nifi.serialization.RecordSetWriter;
import org.apache.nifi.serialization.SimpleRecordSchema;
import org.apache.nifi.serialization.WriteResult;
import org.apache.nifi.serialization.record.DataType;
import org.apache.nifi.serialization.record.MapRecord;
import org.apache.nifi.serialization.record.Record;
import org.apache.nifi.serialization.record.RecordField;
import org.apache.nifi.serialization.record.RecordFieldType;
import org.apache.nifi.serialization.record.RecordSchema;
import org.apache.nifi.serialization.record.RecordSet;
import org.junit.Test;
public abstract class TestWriteAvroResult {
protected abstract RecordSetWriter createWriter(Schema schema, OutputStream out) throws IOException;
protected abstract GenericRecord readRecord(InputStream in, Schema schema) throws IOException;
protected void verify(final WriteResult writeResult) {
}
@Test
public void testLogicalTypes() throws IOException, ParseException {
final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/logical-types.avsc"));
testLogicalTypes(schema);
}
@Test
public void testNullableLogicalTypes() throws IOException, ParseException {
final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/logical-types-nullable.avsc"));
testLogicalTypes(schema);
}
private void testLogicalTypes(Schema schema) throws ParseException, IOException {
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
final List<RecordField> fields = new ArrayList<>();
fields.add(new RecordField("timeMillis", RecordFieldType.TIME.getDataType()));
fields.add(new RecordField("timeMicros", RecordFieldType.TIME.getDataType()));
fields.add(new RecordField("timestampMillis", RecordFieldType.TIMESTAMP.getDataType()));
fields.add(new RecordField("timestampMicros", RecordFieldType.TIMESTAMP.getDataType()));
fields.add(new RecordField("date", RecordFieldType.DATE.getDataType()));
// Avro decimal is represented as double in NiFi type system.
fields.add(new RecordField("decimal", RecordFieldType.DOUBLE.getDataType()));
final RecordSchema recordSchema = new SimpleRecordSchema(fields);
final String expectedTime = "2017-04-04 14:20:33.789";
final DateFormat df = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS");
df.setTimeZone(TimeZone.getTimeZone("gmt"));
final long timeLong = df.parse(expectedTime).getTime();
final Map<String, Object> values = new HashMap<>();
values.put("timeMillis", new Time(timeLong));
values.put("timeMicros", new Time(timeLong));
values.put("timestampMillis", new Timestamp(timeLong));
values.put("timestampMicros", new Timestamp(timeLong));
values.put("date", new Date(timeLong));
// Avro decimal is represented as double in NiFi type system.
final BigDecimal expectedDecimal = new BigDecimal("123.45");
values.put("decimal", expectedDecimal.doubleValue());
final Record record = new MapRecord(recordSchema, values);
try (final RecordSetWriter writer = createWriter(schema, baos)) {
writer.write(RecordSet.of(record.getSchema(), record));
}
final byte[] data = baos.toByteArray();
try (final InputStream in = new ByteArrayInputStream(data)) {
final GenericRecord avroRecord = readRecord(in, schema);
final long secondsSinceMidnight = 33 + (20 * 60) + (14 * 60 * 60);
final long millisSinceMidnight = (secondsSinceMidnight * 1000L) + 789;
assertEquals((int) millisSinceMidnight, avroRecord.get("timeMillis"));
assertEquals(millisSinceMidnight * 1000L, avroRecord.get("timeMicros"));
assertEquals(timeLong, avroRecord.get("timestampMillis"));
assertEquals(timeLong * 1000L, avroRecord.get("timestampMicros"));
assertEquals(17260, avroRecord.get("date"));
// Double value will be converted into logical decimal if Avro schema is defined as logical decimal.
final Schema decimalSchema = schema.getField("decimal").schema();
final LogicalType logicalType = decimalSchema.getLogicalType() != null
? decimalSchema.getLogicalType()
// Union type doesn't return logical type. Find the first logical type defined within the union.
: decimalSchema.getTypes().stream().map(s -> s.getLogicalType()).filter(Objects::nonNull).findFirst().get();
final BigDecimal decimal = new Conversions.DecimalConversion().fromBytes((ByteBuffer) avroRecord.get("decimal"), decimalSchema, logicalType);
assertEquals(expectedDecimal, decimal);
}
}
@Test
public void testDataTypes() throws IOException {
final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/datatypes.avsc"));
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
final List<RecordField> subRecordFields = Collections.singletonList(new RecordField("field1", RecordFieldType.STRING.getDataType()));
final RecordSchema subRecordSchema = new SimpleRecordSchema(subRecordFields);
final DataType subRecordDataType = RecordFieldType.RECORD.getRecordDataType(subRecordSchema);
final List<RecordField> fields = new ArrayList<>();
fields.add(new RecordField("string", RecordFieldType.STRING.getDataType()));
fields.add(new RecordField("int", RecordFieldType.INT.getDataType()));
fields.add(new RecordField("long", RecordFieldType.LONG.getDataType()));
fields.add(new RecordField("double", RecordFieldType.DOUBLE.getDataType()));
fields.add(new RecordField("float", RecordFieldType.FLOAT.getDataType()));
fields.add(new RecordField("boolean", RecordFieldType.BOOLEAN.getDataType()));
fields.add(new RecordField("bytes", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType())));
fields.add(new RecordField("nullOrLong", RecordFieldType.LONG.getDataType()));
fields.add(new RecordField("array", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType())));
fields.add(new RecordField("record", subRecordDataType));
fields.add(new RecordField("map", RecordFieldType.MAP.getMapDataType(subRecordDataType)));
final RecordSchema recordSchema = new SimpleRecordSchema(fields);
final Record innerRecord = new MapRecord(subRecordSchema, Collections.singletonMap("field1", "hello"));
final Map<String, Object> innerMap = new HashMap<>();
innerMap.put("key1", innerRecord);
final Map<String, Object> values = new HashMap<>();
values.put("string", "hello");
values.put("int", 8);
values.put("long", 42L);
values.put("double", 3.14159D);
values.put("float", 1.23456F);
values.put("boolean", true);
values.put("bytes", AvroTypeUtil.convertByteArray("hello".getBytes()));
values.put("nullOrLong", null);
values.put("array", new Integer[] {1, 2, 3});
values.put("record", innerRecord);
values.put("map", innerMap);
final Record record = new MapRecord(recordSchema, values);
final WriteResult writeResult;
try (final RecordSetWriter writer = createWriter(schema, baos)) {
writeResult = writer.write(RecordSet.of(record.getSchema(), record));
}
verify(writeResult);
final byte[] data = baos.toByteArray();
try (final InputStream in = new ByteArrayInputStream(data)) {
final GenericRecord avroRecord = readRecord(in, schema);
assertMatch(record, avroRecord);
}
}
protected void assertMatch(final Record record, final GenericRecord avroRecord) {
for (final String fieldName : record.getSchema().getFieldNames()) {
Object avroValue = avroRecord.get(fieldName);
final Object recordValue = record.getValue(fieldName);
if (recordValue instanceof String) {
assertNotNull(fieldName + " should not have been null", avroValue);
avroValue = avroValue.toString();
}
if (recordValue instanceof Object[] && avroValue instanceof ByteBuffer) {
final ByteBuffer bb = (ByteBuffer) avroValue;
final Object[] objectArray = (Object[]) recordValue;
assertEquals("For field " + fieldName + ", byte buffer remaining should have been " + objectArray.length + " but was " + bb.remaining(),
objectArray.length, bb.remaining());
for (int i = 0; i < objectArray.length; i++) {
assertEquals(objectArray[i], bb.get());
}
} else if (recordValue instanceof Object[]) {
assertTrue(fieldName + " should have been instanceof Array", avroValue instanceof Array);
final Array<?> avroArray = (Array<?>) avroValue;
final Object[] recordArray = (Object[]) recordValue;
assertEquals(fieldName + " not equal", recordArray.length, avroArray.size());
for (int i = 0; i < recordArray.length; i++) {
assertEquals(fieldName + "[" + i + "] not equal", recordArray[i], avroArray.get(i));
}
} else if (recordValue instanceof byte[]) {
final ByteBuffer bb = ByteBuffer.wrap((byte[]) recordValue);
assertEquals(fieldName + " not equal", bb, avroValue);
} else if (recordValue instanceof Map) {
assertTrue(fieldName + " should have been instanceof Map", avroValue instanceof Map);
final Map<?, ?> avroMap = (Map<?, ?>) avroValue;
final Map<?, ?> recordMap = (Map<?, ?>) recordValue;
assertEquals(fieldName + " not equal", recordMap.size(), avroMap.size());
for (Object s : avroMap.keySet()) {
assertMatch((Record) recordMap.get(s.toString()), (GenericRecord) avroMap.get(s));
}
} else if (recordValue instanceof Record) {
assertMatch((Record) recordValue, (GenericRecord) avroValue);
} else {
assertEquals(fieldName + " not equal", recordValue, avroValue);
}
}
}
}