/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.java.io; import org.apache.flink.annotation.Internal; import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.core.fs.FileInputSplit; import org.apache.flink.core.fs.Path; import org.apache.flink.util.Preconditions; import java.io.IOException; import java.lang.reflect.Field; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @Internal public class PojoCsvInputFormat<OUT> extends CsvInputFormat<OUT> { private static final long serialVersionUID = 1L; private Class<OUT> pojoTypeClass; private String[] pojoFieldNames; private transient PojoTypeInfo<OUT> pojoTypeInfo; private transient Field[] pojoFields; public PojoCsvInputFormat(Path filePath, PojoTypeInfo<OUT> pojoTypeInfo) { this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, pojoTypeInfo); } public PojoCsvInputFormat(Path filePath, PojoTypeInfo<OUT> pojoTypeInfo, String[] fieldNames) { this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, pojoTypeInfo, fieldNames, createDefaultMask(pojoTypeInfo.getArity())); } public PojoCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, PojoTypeInfo<OUT> pojoTypeInfo) { this(filePath, lineDelimiter, fieldDelimiter, pojoTypeInfo, pojoTypeInfo.getFieldNames(), createDefaultMask(pojoTypeInfo.getArity())); } public PojoCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, PojoTypeInfo<OUT> pojoTypeInfo, String[] fieldNames) { this(filePath, lineDelimiter, fieldDelimiter, pojoTypeInfo, fieldNames, createDefaultMask(fieldNames.length)); } public PojoCsvInputFormat(Path filePath, PojoTypeInfo<OUT> pojoTypeInfo, int[] includedFieldsMask) { this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, pojoTypeInfo, pojoTypeInfo.getFieldNames(), toBooleanMask(includedFieldsMask)); } public PojoCsvInputFormat(Path filePath, PojoTypeInfo<OUT> pojoTypeInfo, String[] fieldNames, int[] includedFieldsMask) { this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, pojoTypeInfo, fieldNames, includedFieldsMask); } public PojoCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, PojoTypeInfo<OUT> pojoTypeInfo, int[] includedFieldsMask) { this(filePath, lineDelimiter, fieldDelimiter, pojoTypeInfo, pojoTypeInfo.getFieldNames(), includedFieldsMask); } public PojoCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, PojoTypeInfo<OUT> pojoTypeInfo, String[] fieldNames, int[] includedFieldsMask) { super(filePath); boolean[] mask = (includedFieldsMask == null) ? createDefaultMask(fieldNames.length) : toBooleanMask(includedFieldsMask); configure(lineDelimiter, fieldDelimiter, pojoTypeInfo, fieldNames, mask); } public PojoCsvInputFormat(Path filePath, PojoTypeInfo<OUT> pojoTypeInfo, boolean[] includedFieldsMask) { this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, pojoTypeInfo, pojoTypeInfo.getFieldNames(), includedFieldsMask); } public PojoCsvInputFormat(Path filePath, PojoTypeInfo<OUT> pojoTypeInfo, String[] fieldNames, boolean[] includedFieldsMask) { this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, pojoTypeInfo, fieldNames, includedFieldsMask); } public PojoCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, PojoTypeInfo<OUT> pojoTypeInfo, boolean[] includedFieldsMask) { this(filePath, lineDelimiter, fieldDelimiter, pojoTypeInfo, pojoTypeInfo.getFieldNames(), includedFieldsMask); } public PojoCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, PojoTypeInfo<OUT> pojoTypeInfo, String[] fieldNames, boolean[] includedFieldsMask) { super(filePath); configure(lineDelimiter, fieldDelimiter, pojoTypeInfo, fieldNames, includedFieldsMask); } private void configure(String lineDelimiter, String fieldDelimiter, PojoTypeInfo<OUT> pojoTypeInfo, String[] fieldNames, boolean[] includedFieldsMask) { if (includedFieldsMask == null) { includedFieldsMask = createDefaultMask(fieldNames.length); } for (String name : fieldNames) { if (name == null) { throw new NullPointerException("Field name must not be null."); } if (pojoTypeInfo.getFieldIndex(name) < 0) { throw new IllegalArgumentException("Field \"" + name + "\" not part of POJO type " + pojoTypeInfo.getTypeClass().getCanonicalName()); } } setDelimiter(lineDelimiter); setFieldDelimiter(fieldDelimiter); Class<?>[] classes = new Class<?>[fieldNames.length]; for (int i = 0; i < fieldNames.length; i++) { try { classes[i] = pojoTypeInfo.getTypeAt(pojoTypeInfo.getFieldIndex(fieldNames[i])).getTypeClass(); } catch (IndexOutOfBoundsException e) { throw new IllegalArgumentException("Invalid field name: " + fieldNames[i]); } } this.pojoTypeClass = pojoTypeInfo.getTypeClass(); this.pojoTypeInfo = pojoTypeInfo; setFieldsGeneric(includedFieldsMask, classes); setOrderOfPOJOFields(fieldNames); } private void setOrderOfPOJOFields(String[] fieldNames) { Preconditions.checkNotNull(fieldNames); int includedCount = 0; for (boolean isIncluded : fieldIncluded) { if (isIncluded) { includedCount++; } } Preconditions.checkArgument(includedCount == fieldNames.length, includedCount + " CSV fields and " + fieldNames.length + " POJO fields selected. The number of selected CSV and POJO fields must be equal."); for (String field : fieldNames) { Preconditions.checkNotNull(field, "The field name cannot be null."); Preconditions.checkArgument(pojoTypeInfo.getFieldIndex(field) != -1, "Field \"" + field + "\" is not a member of POJO class " + pojoTypeClass.getName()); } pojoFieldNames = Arrays.copyOfRange(fieldNames, 0, fieldNames.length); } @Override public void open(FileInputSplit split) throws IOException { super.open(split); pojoFields = new Field[pojoFieldNames.length]; Map<String, Field> allFields = new HashMap<String, Field>(); findAllFields(pojoTypeClass, allFields); for (int i = 0; i < pojoFieldNames.length; i++) { pojoFields[i] = allFields.get(pojoFieldNames[i]); if (pojoFields[i] != null) { pojoFields[i].setAccessible(true); } else { throw new RuntimeException("There is no field called \"" + pojoFieldNames[i] + "\" in " + pojoTypeClass.getName()); } } } /** * Finds all declared fields in a class and all its super classes. * * @param clazz Class for which all declared fields are found * @param allFields Map containing all found fields so far */ private void findAllFields(Class<?> clazz, Map<String, Field> allFields) { for (Field field : clazz.getDeclaredFields()) { allFields.put(field.getName(), field); } if (clazz.getSuperclass() != null) { findAllFields(clazz.getSuperclass(), allFields); } } @Override public OUT fillRecord(OUT reuse, Object[] parsedValues) { for (int i = 0; i < parsedValues.length; i++) { try { pojoFields[i].set(reuse, parsedValues[i]); } catch (IllegalAccessException e) { throw new RuntimeException("Parsed value could not be set in POJO field \"" + pojoFieldNames[i] + "\"", e); } } return reuse; } }