package bots.mctsbot.ai.opponentmodels.weka; import java.io.BufferedWriter; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileWriter; import java.io.IOException; import java.io.Writer; import java.util.ArrayList; import weka.classifiers.Classifier; import weka.classifiers.trees.M5P; import weka.core.Instance; import weka.core.Instances; import weka.core.SerializationHelper; import weka.core.converters.ConverterUtils.DataSource; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Remove; import bots.mctsbot.ai.opponentmodels.weka.instances.InstancesBuilder; public class ARFFFile { private final String nl = InstancesBuilder.nl; private final File path; private final String name; private File arffFile; private Writer file; private long count = 0; private WekaOptions config; private Instances instances; private ArrayList<Prediction> predictions; private M5P cl = null; private boolean echo = false; public ARFFFile(String path, Object player, String name, String attributes, WekaOptions config) throws Exception { // if (name.equals("PreFoldCallRaise.arff")) echo = true; String playerPath = player.toString().replace("\\", "").replace("/", ""); this.path = new File(path, playerPath); this.name = name; this.config = config; // TODO: false => !config.arffOverwrite() arffFile = new File(this.path, "/arff/" + name); if (!arffFile.getParentFile().exists()) { arffFile.getParentFile().mkdirs(); } file = new BufferedWriter(new FileWriter(arffFile, false)); file.write(attributes); file.flush(); DataSource source = new DataSource(new FileInputStream(arffFile)); instances = source.getDataSet(); // make it clean instances.delete(); predictions = new ArrayList<Prediction>(); // initiate accuracies for (int i = 0; i < MAX_DECREASE; i++) { accuracies[i] = -1; } } // private double countDataLines() { // InputStream is; // try { // is = new BufferedInputStream(new FileInputStream(path + player + name)); // byte[] c = new byte[1024]; // int count = 0; // int readChars = 0; // boolean startReading = false; // while ((readChars = is.read(c)) != -1) { // for (int i = 0; i < readChars; ++i) { // if (c[i] == '\n' && startReading) // ++count; // else if (!startReading && i >= 4 && c[i - 4] == '@' // && c[i - 3] == 'd' && c[i - 2] == 'a' // && c[i - 1] == 't' && c[i] == 'a') // startReading = true; // } // } // is.close(); // return count + (count > 0 ? -1 : 0); // } catch (FileNotFoundException e) { // e.printStackTrace(); // } catch (IOException e) { // e.printStackTrace(); // } // return 0; // } // // private boolean fileExists() throws FileNotFoundException { // return new File(path + player + name).exists(); // } public void close() throws IOException { file.close(); } public void write(Instance instance) { // System.out.println("Writing instance " + (count +1) + " in file " + name); try { count++; file.write(instance.toString() + nl); file.flush(); instances.add(instance); adjustWindow(); } catch (IOException e) { throw new IllegalStateException(e); } } public void addPrediction(Prediction p) { // if (echo) System.out.println("Adding " + p); for (int i = 0; i < instances.numInstances() - predictions.size() - 1; i++) predictions.add(null); predictions.add(p); } public double getWindowSize() { return instances.numInstances(); } public double getAccuracy() { if (predictions.isEmpty()) return 0.0; double truePositive = 0.0; double trueNegative = 0.0; double falsePositive = 0.0; double falseNegative = 0.0; for (int i = 0; i < predictions.size(); i++) { Prediction p = predictions.get(i); if (p != null) { truePositive += p.getTruePositive(); trueNegative += p.getTrueNegative(); falsePositive += p.getFalsePositive(); falseNegative += p.getFalseNegative(); } } return (trueNegative + truePositive) / (trueNegative + truePositive + falseNegative + falsePositive); } private final int MAX_DECREASE = 20; private double[] accuracies = new double[MAX_DECREASE]; private int currentDecrease = 0; private boolean decreasingAcc(double accuracy) { currentDecrease++; if (currentDecrease > MAX_DECREASE) { for (int i = 0; i < MAX_DECREASE - 1; i++) { accuracies[i] = accuracies[i + 1]; } accuracies[MAX_DECREASE - 1] = accuracy; } else accuracies[currentDecrease - 1] = accuracy; double slope = calculateLeastSquaresSlope(accuracies); return (slope < 0); } private double calculateLeastSquaresSlope(double[] accuracies) { double n = accuracies.length; double sumY = 0.0; double sumX = 0.0; double sumXY = 0.0; double sumX2 = 0.0; for (int i = 0; i < accuracies.length; i++) { if (echo) System.out.print(accuracies[i] + ", "); if (accuracies[i] != -1) { sumY += accuracies[i]; sumX += i; sumXY += i * accuracies[i]; sumX2 += i * i; } } double slope = ((n * sumXY) - (sumX * sumY)) / ((n * sumX2) - (sumX * sumX)); double intercept = (sumY - (sumX * slope)) / n; if (echo) System.out.print("slope: " + slope + ", intercept: " + intercept); if (echo) System.out.println(""); return slope; } private boolean printed = false; private void adjustWindow() { if (cl == null) return; double windowSize = instances.numInstances(); double coverage = windowSize / cl.measureNumRules(); double accuracy = getAccuracy(); boolean decreasing = decreasingAcc(accuracy); double l; if ((coverage < config.getCdLowCoverage()) || (accuracy < config.getCdAccuracy() && decreasing)) l = Math.round(0.2 * windowSize); else if (coverage > 2 * config.getCdHighCoverage() && accuracy > config.getCdAccuracy()) l = 2; else if (coverage > config.getCdHighCoverage() && accuracy > config.getCdAccuracy()) l = 1; else l = 0; if (echo && !printed) { System.out.println("L \t Accuracy \t Coverage \t Instances \t Decreasing"); printed = true; } if (echo) System.out.println(l + "\t" + accuracy + "\t" + coverage + "\t" + windowSize + "\t" + decreasing); for (int i = 0; i < l; i++) { instances.delete(0); if (!predictions.isEmpty()) predictions.remove(0); } // windowSize = windowSize - l; // System.out.println(name + ", " + windowSize + ", l: " + l + ", acc: " + accuracy + ", coverage: " + coverage); } public boolean isModelReady() { return count > config.getMinimalLearnExamples(); } public long getNrExamples() { return count; } public String getName() { return name; } public Classifier createModel(String fileName, String attribute, String[] rmAttributes) throws Exception { // System.out.println("Creating model for " + player + name); Instances data; if (config.solveConceptDrift()) data = instances; else { DataSource source = new DataSource(new FileInputStream(arffFile)); data = source.getDataSet(); } if (rmAttributes.length > 0) { String[] optionsDel = new String[2]; optionsDel[0] = "-R"; optionsDel[1] = ""; for (int i = 0; i < rmAttributes.length; i++) optionsDel[1] += (1 + data.attribute(rmAttributes[i]).index()) + ","; Remove remove = new Remove(); remove.setOptions(optionsDel); remove.setInputFormat(data); data = Filter.useFilter(data, remove); } // setting class attribute if the data format does not provide this information // E.g., the XRFF format saves the class attribute information as well if (data.classIndex() == -1) data.setClass(data.attribute(attribute)); // train M5P cl = new M5P(); cl.setBuildRegressionTree(true); cl.setUnpruned(false); cl.setUseUnsmoothed(false); // further options... cl.buildClassifier(data); // System.out.println("Number of instances: " + data.numInstances()); // System.out.println("Number of measures: " + cl.measureNumRules()); // System.out.println(cl); // save model + header // save model + header if (config.modelPersistency()) { File modelFile = new File(this.path, "/model/" + fileName + ".model"); if (!modelFile.getParentFile().exists()) { modelFile.getParentFile().mkdirs(); } SerializationHelper.write(new FileOutputStream(modelFile), cl); } return cl; } }