package com.formulasearchengine.mathosphere.mathpd; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVParser; import org.apache.commons.csv.CSVPrinter; import org.apache.commons.csv.CSVRecord; import org.apache.commons.io.FileUtils; import org.apache.flink.api.java.tuple.Tuple2; import java.io.*; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; /** * Quickly hacked tool to convert pairs with distances in rows to matrix representation * Felix Hamborg */ public class ConverterPairCSVToMatrix { private static final CSVFormat CSV_FORMAT = CSVFormat.RFC4180.withSkipHeaderRecord(); private static final boolean DEBUG = false; private static Tuple2<String, String> getDocumentIDsFromRow(CSVRecord row) { return new Tuple2<>(row.get(0), row.get(1)); } private static double getDistanceFromRow(CSVRecord row, int distanceIndex) { return Double.valueOf(row.get(2 + distanceIndex)); } private static List<Double> getOrderedRow(final HashMap<Tuple2<String, String>, Double> matrix, final List<String> orderedDimensionValues, String rowName) { final List<Double> orderedCellValues = new ArrayList<>(); for (String dimensionValue : orderedDimensionValues) { double value = matrix.getOrDefault(new Tuple2<String, String>(rowName, dimensionValue), -10000.0); orderedCellValues.add(value); } return orderedCellValues; } private static List<Object> getRowWithDescriptionInFirstCol(String rowName, List<Double> values) { List<Object> entries = new ArrayList<>(); entries.add(rowName); entries.addAll(values); return entries; } private static List<String>[] mergeKeysIntoRowsAndCols(HashMap<Tuple2<String, String>, Double> matrix) { final List<String> orderedRowValues = new ArrayList<>(); final List<String> orderedColValues = new ArrayList<>(); if (DEBUG) { try { System.out.println(new File("matrixKeySet").getAbsolutePath()); FileUtils.writeLines(new File("matrixKeySet"), matrix.keySet()); } catch (IOException e) { e.printStackTrace(); } } System.out.println("merging " + matrix.size() + " keys (matrix size in cells: " + (matrix.size() * matrix.size()) + ")"); int tmpCounter = 0; for (Tuple2<String, String> key : matrix.keySet()) { if (!orderedRowValues.contains(key.f0)) orderedRowValues.add(key.f0); if (!orderedColValues.contains(key.f1)) orderedColValues.add(key.f1); if (++tmpCounter % 100000 == 0) { System.out.println("merged " + tmpCounter + " keys (" + (tmpCounter / (float) matrix.size()) + ")"); } } // sort System.out.println("sorting rows"); Collections.sort(orderedRowValues); System.out.println("sorting columns"); Collections.sort(orderedColValues); if (DEBUG) { try { FileUtils.writeLines(new File("orderedRowValues"), orderedRowValues); FileUtils.writeLines(new File("orderedColValues"), orderedColValues); } catch (IOException e) { e.printStackTrace(); } } return new List[]{orderedRowValues, orderedColValues}; } private static void writeOrderedMatrix(HashMap<Tuple2<String, String>, Double> matrix, String filepath, List<String> orderedRowValues, List<String> orderedColValues) throws Exception { // print to disk final CSVPrinter printer = new CSVPrinter(new FileWriter(filepath), CSV_FORMAT); // write first row (header) List<String> tmpHeader = new ArrayList<>(); tmpHeader.add(""); tmpHeader.addAll(orderedColValues); printer.printRecord(tmpHeader); System.out.println("writing " + orderedRowValues.size() + " records"); for (String rowName : orderedRowValues) { printer.printRecord( getRowWithDescriptionInFirstCol( rowName, getOrderedRow(matrix, orderedColValues, rowName))); } printer.close(); } public static void main(String[] args) throws Exception { System.out.println("number of args given = " + args.length); String in = "/home/felix/170113run"; if (args.length == 0) { System.out.println("input file name? "); BufferedReader buffer = new BufferedReader(new InputStreamReader(System.in)); in = buffer.readLine().trim(); } else { in = args[0]; } final String outbase = in + "_out_"; List<String> rows = null; List<String> cols = null; for (int i = 0; i < 5; i++) { System.out.println("parsing file to CSV: " + in); final CSVParser parser = CSVParser.parse(new File(in), Charset.defaultCharset(), CSV_FORMAT); System.out.println("finished parsing"); System.out.println("creating matrix (" + i + ")"); final HashMap<Tuple2<String, String>, Double> matrix = new HashMap<>(); for (CSVRecord row : parser) { Tuple2<String, String> key = getDocumentIDsFromRow(row); if (matrix.containsKey(key)) { throw new RuntimeException("matrix already contains key: " + key); } matrix.put(key, getDistanceFromRow(row, i)); } System.out.println("finished creating matrix (" + i + ")"); if (rows == null && cols == null) { final List<String>[] rowsAndCols = mergeKeysIntoRowsAndCols(matrix); rows = rowsAndCols[0]; cols = rowsAndCols[1]; } else { System.out.println("reusing previously created rows and columns"); } System.out.println("writing matrix"); writeOrderedMatrix(matrix, outbase + i + ".csv", rows, cols); System.out.println("finished writing matrix"); // reset parser parser.close(); } } }