package water.mojo.glm; import au.com.bytecode.opencsv.CSVReader; import hex.genmodel.ModelMojoReader; import hex.genmodel.MojoModel; import hex.genmodel.MojoReaderBackend; import java.io.*; import java.util.Arrays; import java.util.zip.GZIPInputStream; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; public class GlmMojoBenchHelper { static void readData(File f, int cols, String firstColName, double[][] out, MojoModel mojo) throws IOException { int[] mapping = new int[cols]; for (int i = 0; i < cols; i++) mapping[i] = i; readData(f, mapping, firstColName, out, mojo); } static void readData(File f, int[] mapping, String firstColName, double[][] out, MojoModel mojo) throws IOException { InputStream is = new FileInputStream(f); try { InputStream source; if (f.getName().endsWith(".zip")) { ZipInputStream zis = new ZipInputStream(is); ZipEntry entry = zis.getNextEntry(); if (! entry.getName().endsWith(".csv")) throw new IllegalStateException("CSV file expected, name " + entry.getName()); source = zis; } else { source = new GZIPInputStream(is); } CSVReader r = new CSVReader(new InputStreamReader(source)); if (firstColName != null) { String[] header = r.readNext(); if (header == null) throw new IllegalStateException("File empty"); if (! firstColName.equals(header[0])) throw new IllegalStateException("Header expected"); } int rowIdx = 0; String[] row; while ((rowIdx < out.length) && ((row = r.readNext()) != null)) { double[] outRow = out[rowIdx++]; if (row.length < mapping.length) throw new IllegalStateException("Row too short: " + Arrays.toString(row)); for (int i = 0; i < mapping.length; i++) { int target = mapping[i]; if (target < 0) continue; if ("NA".equals(row[i])) { outRow[target] = Double.NaN; continue; } String[] domain = mojo.getDomainValues(target); if (domain == null) outRow[target] = Double.parseDouble(row[i]); else { outRow[target] = -1; for (int d = 0; d < domain.length; d++) if (domain[d].equals(row[i])) { outRow[target] = d; break; } if (outRow[target] < 0) throw new IllegalStateException("Value " + row[i] + " not found in domain " + Arrays.toString(domain)); } } } } finally { is.close(); } } static MojoModel loadMojo(String dir) throws IOException { return ModelMojoReader.readFrom(new ClasspathReaderBackend(dir)); } private static class ClasspathReaderBackend implements MojoReaderBackend { private final String _dir; public ClasspathReaderBackend(String dir) { _dir = dir; } @Override public BufferedReader getTextFile(String filename) throws IOException { InputStream is = GlmMojoBenchHelper.class.getResourceAsStream(_dir + "/" + filename); return new BufferedReader(new InputStreamReader(is)); } @Override public byte[] getBinaryFile(String filename) throws IOException { throw new UnsupportedOperationException("Unexpected call to getBinaryFile()"); } @Override public boolean exists(String name) { throw new UnsupportedOperationException("Unexpected call to exists()"); } } }