/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.io;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.io.ParseException;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Row;
import org.apache.flink.types.parser.FieldParser;
import java.util.Arrays;
/**
 * CSV input format that emits every parsed line as a {@link Row}.
 *
 * <p>The format is configured with one {@link TypeInformation} per produced row field and,
 * optionally, with the indices of the CSV columns that feed those fields. The i-th entry of
 * the selected-fields array names the CSV column that is written to row field i, so columns
 * may be projected and reordered. Columns that are not selected are skipped while parsing.
 *
 * <p>When {@code emptyColumnAsNull} is enabled, a column whose parser reports
 * {@link FieldParser.ParseErrorState#EMPTY_COLUMN} yields a {@code null} row field.
 */
@PublicEvolving
public class RowCsvInputFormat extends CsvInputFormat<Row> implements ResultTypeQueryable<Row> {

    private static final long serialVersionUID = 1L;

    /** Number of fields of every produced row. */
    private final int arity;

    /** Type information of the produced row fields, in row-field order. */
    private final TypeInformation[] fieldTypeInfos;

    /**
     * Maps the rank of a selected column (its position among the selected columns when
     * ordered by column index) to the index of the row field it is written to.
     */
    private final int[] fieldPosMap;

    /** Whether a column reported as empty by its parser is emitted as {@code null}. */
    private final boolean emptyColumnAsNull;

    /**
     * Creates a fully configured format.
     *
     * @param filePath path of the file to read
     * @param fieldTypeInfos type information of the row fields; must not be empty
     * @param lineDelimiter delimiter that separates records
     * @param fieldDelimiter delimiter that separates fields within a record
     * @param selectedFields CSV column index for every row field; must have the same length as
     *     {@code fieldTypeInfos} and contain only distinct, non-negative indices
     * @param emptyColumnAsNull whether empty columns are emitted as {@code null}
     * @throws IllegalArgumentException if no field type is given, if the number of types and
     *     selected fields differ, or if a selected index is negative or listed more than once
     */
    public RowCsvInputFormat(Path filePath, TypeInformation[] fieldTypeInfos, String lineDelimiter, String fieldDelimiter, int[] selectedFields, boolean emptyColumnAsNull) {
        super(filePath);
        this.arity = fieldTypeInfos.length;
        if (arity == 0) {
            throw new IllegalArgumentException("At least one field must be specified");
        }
        if (arity != selectedFields.length) {
            throw new IllegalArgumentException("Number of field types and selected fields must be the same");
        }

        // Building the mask also validates the indices, so do it before deriving anything else.
        boolean[] fieldsMask = toFieldMask(selectedFields);

        this.fieldTypeInfos = fieldTypeInfos;
        this.fieldPosMap = toFieldPosMap(selectedFields);
        this.emptyColumnAsNull = emptyColumnAsNull;

        setDelimiter(lineDelimiter);
        setFieldDelimiter(fieldDelimiter);
        setFieldsGeneric(fieldsMask, extractTypeClasses(fieldTypeInfos));
    }

    /**
     * Creates a format with explicit delimiters and column selection; empty columns are not
     * converted to {@code null}.
     */
    public RowCsvInputFormat(Path filePath, TypeInformation[] fieldTypes, String lineDelimiter, String fieldDelimiter, int[] selectedFields) {
        this(filePath, fieldTypes, lineDelimiter, fieldDelimiter, selectedFields, false);
    }

    /**
     * Creates a format with explicit delimiters that reads the first {@code fieldTypes.length}
     * columns in file order.
     */
    public RowCsvInputFormat(Path filePath, TypeInformation[] fieldTypes, String lineDelimiter, String fieldDelimiter) {
        this(filePath, fieldTypes, lineDelimiter, fieldDelimiter, sequentialScanOrder(fieldTypes.length));
    }

    /** Creates a format with the default delimiters and an explicit column selection. */
    public RowCsvInputFormat(Path filePath, TypeInformation[] fieldTypes, int[] selectedFields) {
        this(filePath, fieldTypes, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, selectedFields);
    }

    /**
     * Creates a format with the default delimiters that reads the first
     * {@code fieldTypes.length} columns in file order.
     *
     * @param emptyColumnAsNull whether empty columns are emitted as {@code null}
     */
    public RowCsvInputFormat(Path filePath, TypeInformation[] fieldTypes, boolean emptyColumnAsNull) {
        this(filePath, fieldTypes, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, sequentialScanOrder(fieldTypes.length), emptyColumnAsNull);
    }

    /**
     * Creates a format with the default delimiters that reads the first
     * {@code fieldTypes.length} columns in file order; empty columns are not converted to
     * {@code null}.
     */
    public RowCsvInputFormat(Path filePath, TypeInformation[] fieldTypes) {
        this(filePath, fieldTypes, false);
    }

    /** Returns the type class of every given type information, in the same order. */
    private static Class<?>[] extractTypeClasses(TypeInformation[] fieldTypes) {
        Class<?>[] classes = new Class<?>[fieldTypes.length];
        for (int i = 0; i < fieldTypes.length; i++) {
            classes[i] = fieldTypes[i].getTypeClass();
        }
        return classes;
    }

    /** Returns the identity selection {@code [0, 1, ..., arity - 1]}. */
    private static int[] sequentialScanOrder(int arity) {
        int[] sequentialOrder = new int[arity];
        for (int i = 0; i < arity; i++) {
            sequentialOrder[i] = i;
        }
        return sequentialOrder;
    }

    /**
     * Builds the include mask over CSV columns: entry {@code c} is {@code true} iff column
     * {@code c} is selected. The mask is as long as the highest selected index plus one.
     *
     * @throws IllegalArgumentException if an index is negative or occurs more than once
     */
    private static boolean[] toFieldMask(int[] selectedFields) {
        int maxField = 0;
        for (int selectedField : selectedFields) {
            if (selectedField < 0) {
                throw new IllegalArgumentException(
                    "Selected field index must not be negative: " + selectedField);
            }
            maxField = Math.max(maxField, selectedField);
        }

        // A freshly allocated boolean[] is already all false.
        boolean[] mask = new boolean[maxField + 1];
        for (int selectedField : selectedFields) {
            if (mask[selectedField]) {
                // Duplicates would make the rank lookup in toFieldPosMap ambiguous.
                throw new IllegalArgumentException(
                    "Field " + selectedField + " is selected more than once");
            }
            mask[selectedField] = true;
        }
        return mask;
    }

    /**
     * Builds the map from the rank of a selected column (its position once the selected
     * indices are sorted ascending) to the row field that column is written to.
     *
     * <p>Expects distinct indices; with duplicates the binary search result is not unique.
     */
    private static int[] toFieldPosMap(int[] selectedFields) {
        int[] fieldIdxs = Arrays.copyOf(selectedFields, selectedFields.length);
        Arrays.sort(fieldIdxs);

        int[] fieldPosMap = new int[selectedFields.length];
        for (int i = 0; i < selectedFields.length; i++) {
            int pos = Arrays.binarySearch(fieldIdxs, selectedFields[i]);
            fieldPosMap[pos] = i;
        }
        return fieldPosMap;
    }

    /**
     * Copies the parsed values into a row.
     *
     * @param reuse row to fill, or {@code null} to allocate a new row of this format's arity
     * @param parsedValues values to store; value {@code i} is written to row field {@code i}
     * @return the filled row ({@code reuse} if it was not {@code null})
     */
    @Override
    protected Row fillRecord(Row reuse, Object[] parsedValues) {
        Row reuseRow = (reuse == null) ? new Row(arity) : reuse;
        for (int i = 0; i < parsedValues.length; i++) {
            reuseRow.setField(i, parsedValues[i]);
        }
        return reuseRow;
    }

    /**
     * Parses one record from {@code bytes[offset, offset + numBytes)} into {@code holders}.
     *
     * <p>Selected columns are parsed with the parser of the row field they map to and stored
     * at that row field's index in {@code holders}; unselected columns are skipped. A selected
     * column is stored as {@code null} when its parser returns a negative position, or when
     * {@code emptyColumnAsNull} is set and the parser reports an empty column.
     *
     * @return {@code true} if the record was parsed; {@code false} if the format is lenient
     *     and the record has too few columns
     * @throws ParseException in non-lenient mode if the record has too few columns or a parser
     *     reports an error other than an empty column, and in any mode if skipping a column
     *     yields an invalid position
     */
    @Override
    protected boolean parseRecord(Object[] holders, byte[] bytes, int offset, int numBytes) throws ParseException {
        byte[] fieldDelimiter = this.getFieldDelimiter();
        boolean[] fieldIncluded = this.fieldIncluded;

        int startPos = offset;
        int limit = offset + numBytes;
        int lastField = fieldIncluded.length - 1;

        // 'output' counts the selected columns seen so far, i.e. the rank used by fieldPosMap.
        int output = 0;
        for (int field = 0; field < fieldIncluded.length; field++) {
            // check valid start position
            if (startPos > limit || (startPos == limit && field != lastField)) {
                if (isLenient()) {
                    return false;
                }
                throw new ParseException("Row too short: " + rowAsString(bytes, offset, numBytes));
            }

            if (fieldIncluded[field]) {
                int target = fieldPosMap[output];

                @SuppressWarnings("unchecked")
                FieldParser<Object> parser = (FieldParser<Object>) this.getFieldParsers()[target];

                int latestValidPos = startPos;
                startPos = parser.resetErrorStateAndParse(
                    bytes,
                    startPos,
                    limit,
                    fieldDelimiter,
                    holders[target]);

                FieldParser.ParseErrorState errorState = parser.getErrorState();

                // In strict mode every error except EMPTY_COLUMN aborts the record.
                if (!isLenient()
                    && errorState != FieldParser.ParseErrorState.NONE
                    && errorState != FieldParser.ParseErrorState.EMPTY_COLUMN) {
                    throw new ParseException(String.format("Parsing error for column %1$s of row '%2$s' originated by %3$s: %4$s.",
                        field, rowAsString(bytes, offset, numBytes), parser.getClass().getSimpleName(), errorState));
                }

                holders[target] = parser.getLastResult();

                // The result is null if it is invalid, or empty with emptyColumnAsNull enabled.
                // In both cases move past the column starting from the last known good position.
                if (startPos < 0
                    || (emptyColumnAsNull && errorState == FieldParser.ParseErrorState.EMPTY_COLUMN)) {
                    holders[target] = null;
                    startPos = skipFields(bytes, latestValidPos, limit, fieldDelimiter);
                }
                output++;
            } else {
                // skip field
                startPos = skipFields(bytes, startPos, limit, fieldDelimiter);
            }

            // check if something went wrong
            if (startPos < 0) {
                throw new ParseException(String.format("Unexpected parser position for column %1$s of row '%2$s'",
                    field, rowAsString(bytes, offset, numBytes)));
            } else if (startPos == limit
                && field != lastField
                && !FieldParser.endsWithDelimiter(bytes, startPos - 1, fieldDelimiter)) {
                // We are at the end of the record, but not all fields have been read
                // and the end is not a field delimiter indicating an empty last field.
                if (isLenient()) {
                    return false;
                }
                throw new ParseException("Row too short: " + rowAsString(bytes, offset, numBytes));
            }
        }
        return true;
    }

    /**
     * Decodes the raw record with the charset configured for this format, so that every error
     * message shows the row the same way regardless of the platform default charset.
     */
    private String rowAsString(byte[] bytes, int offset, int numBytes) {
        return new String(bytes, offset, numBytes, getCharset());
    }

    /** Returns a {@link RowTypeInfo} built from the field types this format was created with. */
    @Override
    public TypeInformation<Row> getProducedType() {
        return new RowTypeInfo(this.fieldTypeInfos);
    }
}