/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.pmml; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStreamWriter; import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import ml.shifu.shifu.util.Constants; import org.dmg.pmml.DataType; import org.dmg.pmml.FieldName; import org.jpmml.evaluator.EvaluationException; import org.jpmml.evaluator.Evaluator; import org.jpmml.evaluator.EvaluatorUtil; import org.jpmml.evaluator.FieldValue; public class CsvUtil { private CsvUtil() { } static public Table readTable(File file) throws IOException { return readTable(file, null); } static public Table readTable(File file, String separator) throws IOException { Table table = new Table(); BufferedReader reader = new BufferedReader(new InputStreamReader( new FileInputStream(file), Constants.DEFAULT_CHARSET)); try { while (true) { String line = reader.readLine(); if (line == null) { break; } // End if if ((line.trim()).equals("")) { break; } // End if if (separator == null) { separator = getSeparator(line); } table.add(parseLine(line, separator)); } } finally { reader.close(); } table.setSeparator(separator); return table; } static public void writeTable(Table table, File file) throws IOException { BufferedWriter writer = new BufferedWriter(new OutputStreamWriter( new FileOutputStream(file), Constants.DEFAULT_CHARSET)); try { String terminator = ""; for (List<String> row : table) { StringBuilder sb = new StringBuilder(); sb.append(terminator); terminator = "\n"; String separator = ""; for (int i = 0; i < row.size(); i++) { sb.append(separator); separator = table.getSeparator(); sb.append(row.get(i)); } writer.write(sb.toString()); } writer.flush(); } finally { writer.close(); } } static private String getSeparator(String line) { String[] separators = { "\t", ";", "," }; for (String separator : separators) { String[] cells = line.split(separator); if (cells.length > 1) { return separator; } } throw new IllegalArgumentException(); } static public List<String> parseLine(String line, String separator) { List<String> result = new ArrayList<String>(); String[] cells = line.split(separator); for (String cell : cells) { // Remove quotation marks, if any cell = stripQuotes(cell, "\""); cell = stripQuotes(cell, "\'"); // Standardize decimal marks to Full Stop (US) if (!(",").equals(separator)) { cell = cell.replace(',', '.'); } result.add(cell); } return result; } static private String stripQuotes(String string, String quote) { if (string.startsWith(quote) && string.endsWith(quote)) { string = string.substring(quote.length(), string.length() - quote.length()); } return string; } @SuppressWarnings("unused") static public List<Map<FieldName, FieldValue>> prepareAll(Evaluator evaluator, Table table) { List<FieldName> names = new ArrayList<FieldName>(); List<FieldName> activeFields = evaluator.getActiveFields(); List<FieldName> groupFields = evaluator.getGroupFields(); header: { List<String> headerRow = table.get(0); for (int column = 0; column < headerRow.size(); column++) { FieldName field = FieldName.create(headerRow.get(column)); if (!(activeFields.contains(field) || groupFields.contains(field))) { field = null; } names.add(field); } } List<Map<FieldName, Object>> stringRows = new ArrayList<Map<FieldName, Object>>(); body: for (int row = 1; row < table.size(); row++) { List<String> bodyRow = table.get(row); Map<FieldName, Object> stringRow = new LinkedHashMap<FieldName, Object>(); for (int column = 0; column < bodyRow.size(); column++) { FieldName name = names.get(column); if (name == null) { continue; } String value = bodyRow.get(column); if (("").equals(value) || ("NA").equals(value) || ("N/A").equals(value)) { value = null; } stringRow.put(name, value); } stringRows.add(stringRow); } if (groupFields.size() == 1) { FieldName groupField = groupFields.get(0); stringRows = EvaluatorUtil.groupRows(groupField, stringRows); } else if (groupFields.size() > 1) { throw new EvaluationException(); } List<Map<FieldName, FieldValue>> fieldValueRows = new ArrayList<Map<FieldName, FieldValue>>(); for (Map<FieldName, Object> stringRow : stringRows) { Map<FieldName, FieldValue> fieldValueRow = new LinkedHashMap<FieldName, FieldValue>(); Collection<Map.Entry<FieldName, Object>> entries = stringRow.entrySet(); for (Map.Entry<FieldName, Object> entry : entries) { FieldName name = entry.getKey(); // Pre Data process: for numeric variable convert non-double // value to null. if (evaluator.getDataField(name).getDataType() == DataType.DOUBLE) { try { Double.parseDouble((String) entry.getValue()); } catch (Exception e) { entry.setValue(null); } } FieldValue value = EvaluatorUtil.prepare(evaluator, name, entry.getValue()); fieldValueRow.put(name, value); } fieldValueRows.add(fieldValueRow); } return fieldValueRows; } static public List<Map<FieldName, FieldValue>> load(Evaluator evaluator, String dataPath, String c) throws IOException { Table table = CsvUtil.readTable(new File(dataPath), c); return CsvUtil.prepareAll(evaluator, table); } static public class Table extends ArrayList<List<String>> { private static final long serialVersionUID = -3317839096636490372L; private String separator = null; public String getSeparator() { return this.separator; } public void setSeparator(String separator) { this.separator = separator; } } }