import org.apache.commons.cli.*; import org.apache.commons.io.FilenameUtils; import pviz.Cluster; import pviz.Clusters; import pviz.Plotviz; import javax.xml.bind.JAXBContext; import javax.xml.bind.JAXBException; import javax.xml.bind.Marshaller; import javax.xml.bind.Unmarshaller; import java.io.*; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.*; public class LabelApply { private final String fixedClassesFile; private String vectorFolder; private String pointsFolder; private String distFolder; private String originalStockFile; private String sectorFile; private boolean histogram; private String cluserInputFile; private String clusterOutputFile; private NumberFormat intgerFormaterr = new DecimalFormat("#00"); private Map<Integer, String> permNoToSymbol = new HashMap<Integer, String>(); private Map<String, Integer> sectorToClazz = new HashMap<String, Integer>(); private Map<String, String> invertedSectors = new HashMap<String, String>(); private Map<String, Integer> invertedFixedClases = new HashMap<String, Integer>(); private NumberFormat formatter = new DecimalFormat("#0.00"); public static void main(String[] args) { Options options = new Options(); options.addOption("v", true, "Input Vector folder"); // yearly vector data options.addOption("p", true, "Points folder"); // yearly mds (rotate) output options.addOption("d", true, "Destination folder"); options.addOption("o", true, "Original stock file"); // global 10 year stock file options.addOption("s", true, "Sector file"); // If Histogram true then set this as the folder to histogram output options.addOption("h", false, "Gen from histogram"); options.addOption("e", true, "Extra classes file"); // a file containing fixed classes options.addOption("ci", true, "Cluster input file"); options.addOption("co", true, "Cluster output file"); CommandLineParser commandLineParser = new BasicParser(); try { CommandLine cmd = commandLineParser.parse(options, args); String vectorFile = cmd.getOptionValue("v"); String pointsFolder = cmd.getOptionValue("p"); String distFolder = cmd.getOptionValue("d"); String originalStocks = cmd.getOptionValue("o"); String sectorFile = cmd.getOptionValue("s"); boolean histogram = cmd.hasOption("h"); String fixedClasses = cmd.getOptionValue("e"); String clusterInputFile = cmd.getOptionValue("ci"); String clusterOutputFile = cmd.getOptionValue("co"); LabelApply program = new LabelApply(vectorFile, pointsFolder, distFolder, originalStocks, sectorFile, histogram, fixedClasses, clusterInputFile, clusterOutputFile); program.process(); } catch (ParseException e) { e.printStackTrace(); } } public LabelApply(String vectorFolder, String pointsFolder, String distFolder, String originalStockFile, String sectorFile, boolean histogram, String fixedClasses, String clusterInputFile, String clusterOutputFile) { this.vectorFolder = vectorFolder; this.pointsFolder = pointsFolder; this.distFolder = distFolder; this.originalStockFile = originalStockFile; this.histogram = histogram; this.sectorFile = sectorFile; this.fixedClassesFile = fixedClasses; this.clusterOutputFile = clusterOutputFile; this.cluserInputFile = clusterInputFile; init(); } private void init() { permNoToSymbol = Utils.loadMapping(originalStockFile); Map<String, Integer> symbolToPerm = new HashMap<String, Integer>(); for (Map.Entry<Integer, String> entry : permNoToSymbol.entrySet()) { symbolToPerm.put(entry.getValue(), entry.getKey()); } } public void process() { File inFolder = new File(vectorFolder); if (!inFolder.isDirectory()) { System.out.println("In should be a folder: " + vectorFolder); return; } boolean clusterSaved = false; this.invertedFixedClases = loadFixedClasses(fixedClassesFile); if (!histogram) { Map<String, List<String>> sectors = loadStockSectors(sectorFile); sectorToClazz = convertSectorsToClazz(sectors); if (!clusterSaved) { changeClassLabels(); clusterSaved = true; } for (Map.Entry<String, Integer> entry : sectorToClazz.entrySet()) { System.out.println(entry.getKey() + " : " + entry.getValue()); } } for (File inFile : inFolder.listFiles()) { String fileName = inFile.getName(); String fileNameWithOutExt = FilenameUtils.removeExtension(fileName); if (histogram) { sectorToClazz.clear(); invertedSectors.clear(); Map<String, List<String>> sectors = loadHistoSectors(sectorFile + "/" + fileNameWithOutExt + ".csv"); sectorToClazz = convertSectorsToClazz(sectors); if (!clusterSaved) { changeClassLabels(); clusterSaved = true; } for (Map.Entry<String, Integer> entry : sectorToClazz.entrySet()) { System.out.println(entry.getKey() + " : " + entry.getValue()); } } processFile(fileNameWithOutExt); } } private Clusters loadClusters(String cluserInputFile) { Clusters clusters; FileInputStream adrFile = null; try { adrFile = new FileInputStream(cluserInputFile); JAXBContext ctx = JAXBContext.newInstance(Clusters.class); Unmarshaller um = ctx.createUnmarshaller(); clusters = (Clusters) um.unmarshal(adrFile); return clusters; } catch (FileNotFoundException | JAXBException e) { e.printStackTrace(); } finally { if (adrFile != null) { try { adrFile.close(); } catch (IOException ignore) { } } } return null; } public static void saveClusters(String outFileName, Clusters plotviz) throws FileNotFoundException, JAXBException { FileOutputStream fileOutputStream = null; try { fileOutputStream = new FileOutputStream(outFileName); JAXBContext ctx = JAXBContext.newInstance(Clusters.class); Marshaller ma = ctx.createMarshaller(); ma.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true); ma.marshal(plotviz, fileOutputStream); } finally { if (fileOutputStream != null) { try { fileOutputStream.close(); } catch (IOException ignore) { } } } } public void changeClassLabels() { System.out.println("Reading cluster file: " + cluserInputFile); Clusters clusters = loadClusters(cluserInputFile); if (clusters == null) { System.out.println("Not clusters found to change"); return; } Map<Integer, String> classToSector = new HashMap<Integer, String>(); for (Map.Entry<String, Integer> e: sectorToClazz.entrySet()) { classToSector.put(e.getValue(), e.getKey()); } for (Cluster c : clusters.getCluster()) { // find the cluster label // this is the label String key = classToSector.get(c.getKey()); if (key != null ) { System.out.println("Setting label: " + key + " to cluster: " + c.getKey() ); c.setLabel(key); } } try { System.out.println("Writing cluster file: " + clusterOutputFile); saveClusters(clusterOutputFile, clusters); } catch (FileNotFoundException | JAXBException e) { throw new RuntimeException("Failed to write clusters", e); } } private Map<String, Integer> loadFixedClasses(String file) { FileReader input; try { Map<Integer, List<String>> fixedClaszzes = new HashMap<Integer, List<String>>(); Map<String, Integer> invertedFixedClasses = new HashMap<String, Integer>(); File f = new File(file); if (!f.exists()) { System.out.println("Extra classes file doesn't exist: " + fixedClassesFile); return invertedFixedClasses; } input = new FileReader(f); BufferedReader bufRead = new BufferedReader(input); String line; while ((line = bufRead.readLine()) != null) { String parts[] = line.split(","); int clazz = Integer.parseInt(parts[0]); List<String> symbols = new ArrayList<String>(); symbols.addAll(Arrays.asList(parts).subList(1, parts.length)); fixedClaszzes.put(clazz, symbols); } for (Map.Entry<Integer, List<String>> e : fixedClaszzes.entrySet()) { for (String s : e.getValue()) { invertedFixedClasses.put(s, e.getKey()); } } return invertedFixedClasses; } catch (IOException e) { e.printStackTrace(); } return null; } private Map<String, List<String>> loadHistoSectors(String sectorFile) { FileReader input; Map<String, List<String>> sectors = new HashMap<String, List<String>>(); try { input = new FileReader(sectorFile); BufferedReader bufRead = new BufferedReader(input); String line; int i = 1; while ((line = bufRead.readLine()) != null) { Bin sectorRecord = Utils.readBin(line); List<String> stockList = sectorRecord.symbols; String startEnd = formatter.format(sectorRecord.end); String key = intgerFormaterr.format(i) + ":" + startEnd; sectors.put(key, stockList); for (String s : stockList) { invertedSectors.put(s, key); } i++; } } catch (IOException e) { throw new RuntimeException("Failed to load sector file", e); } return sectors; } private Map<String, List<String>> loadStockSectors(String sectorFile) { FileReader input; Map<String, List<String>> sectors = new HashMap<String, List<String>>(); try { input = new FileReader(sectorFile); BufferedReader bufRead = new BufferedReader(input); String line; while ((line = bufRead.readLine()) != null) { SectorRecord sectorRecord = Utils.readSectorRecord(line); List<String> stockList = sectors.get(sectorRecord.getSector()); if (stockList == null) { stockList = new ArrayList<String>(); sectors.put(sectorRecord.getSector(), stockList); } stockList.add(sectorRecord.getSymbol()); invertedSectors.put(sectorRecord.getSymbol(), sectorRecord.getSector()); } } catch (IOException e) { throw new RuntimeException("Failed to load sector file", e); } return sectors; } private Map<String, Integer> convertSectorsToClazz(Map<String, List<String>> sectors) { List<String> sectorNames = new ArrayList<String>(sectors.keySet()); Collections.sort(sectorNames); Map<String, Integer> sectorsToClazz = new HashMap<String, Integer>(); for (int i = 0; i < sectorNames.size(); i++) { sectorsToClazz.put(sectorNames.get(i), i + 1); System.out.println(sectorNames.get(i) + ": " + (i + 1)); } return sectorsToClazz; } private void processFile(String file) { String vectorFile = vectorFolder + "/" + file + ".csv"; String pointsFile = pointsFolder + "/" + file + ".txt"; String pointsOutFile = distFolder + "/" + file + ".txt"; List<String> symbols = loadSymbols(vectorFile); applyLabel(pointsFile, pointsOutFile, symbols); } private void applyLabel(String inPointsFile, String outPointsFile, List<String> symbols) { System.out.println("Applying labels for points file: " + inPointsFile); FileReader input; BufferedWriter bufWriter = null; try { FileOutputStream fos = new FileOutputStream(outPointsFile); bufWriter = new BufferedWriter(new OutputStreamWriter(fos)); File inFile = new File(inPointsFile); if (!inFile.exists()) { System.out.println("ERROR: In file doens't exist"); return; } input = new FileReader(inPointsFile); BufferedReader bufRead = new BufferedReader(input); String inputLine; int index = 0; while ((inputLine = bufRead.readLine()) != null && index < symbols.size()) { Point p = Utils.readPoint(inputLine); String symbol = symbols.get(index); int clazz = 0; if (this.invertedFixedClases.containsKey(symbol)) { clazz = this.invertedFixedClases.get(symbol); } else { // get the corresponding symbol // get the class for this one String sector = invertedSectors.get(symbol); if (sector != null) { clazz = sectorToClazz.get(sector); } else { // System.out.println("No sector: " + symbol); } } p.setClazz(clazz); String s = p.serialize(); bufWriter.write(s); bufWriter.newLine(); index++; } System.out.println("Read lines: " + index); } catch (Exception e) { throw new RuntimeException("Failed to read/write file", e); } finally { if (bufWriter != null) { try { bufWriter.close(); } catch (IOException ignore) { } } } } // load symbols for each point in file private List<String> loadSymbols(String vectorFile) { System.out.println("Loading symbols from vector file: " + vectorFile); File vf = new File(vectorFile); List<VectorPoint> vectorPoints = Utils.readVectors(vf, 0, 7000); List<String> symbols = new ArrayList<String>(); for (int i = 0; i < vectorPoints.size(); i++) { VectorPoint v = vectorPoints.get(i); String symbol = permNoToSymbol.get(v.getKey()); symbols.add(symbol); } System.out.println("No of symbols for point: " + symbols.size()); return symbols; } }