package bots.mctsbot.ai.opponentmodels.weka;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.Writer;
import weka.classifiers.Classifier;
import weka.classifiers.trees.M5P;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
import bots.mctsbot.ai.opponentmodels.weka.instances.InstancesBuilder;
public class ARFFFile {
private final String nl = InstancesBuilder.nl;
private File arffFile;
private File path;
private Writer file;
private double count;
private boolean modelReady;
public ARFFFile(String path, Object player, String name, String attributes, boolean overwrite) throws IOException {
String playerPath = player.toString().replace("\\", "").replace("/", "");
this.path = new File(path, playerPath);
arffFile = new File(this.path, "/arff/" + name);
if (!arffFile.getParentFile().exists()) {
arffFile.getParentFile().mkdirs();
}
boolean fileExists = arffFile.exists();
file = new BufferedWriter(new FileWriter(arffFile, !overwrite));
if (overwrite || !fileExists) {
file.write(attributes);
file.flush();
}
count = countDataLines();
if (count >= WekaOptions.getModelCreationTreshold())
modelReady = true;
}
private double countDataLines() throws IOException {
InputStream is = new BufferedInputStream(new FileInputStream(arffFile));
byte[] c = new byte[1024];
int count = 0;
int readChars = 0;
boolean startReading = false;
while ((readChars = is.read(c)) != -1) {
for (int i = 0; i < readChars; ++i) {
if (c[i] == '\n' && startReading)
++count;
else if (!startReading && i >= 4 && c[i - 4] == '@' && c[i - 3] == 'd' && c[i - 2] == 'a' && c[i - 1] == 't' && c[i] == 'a')
startReading = true;
}
}
is.close();
return count + (count > 0 ? -1 : 0);
}
public void close() throws IOException {
file.close();
}
public void write(Instance instance) {
try {
file.write(instance.toString() + nl);
file.flush();
count++;
if (count >= WekaOptions.getModelCreationTreshold())
modelReady = true;
} catch (IOException e) {
e.printStackTrace();
throw new IllegalStateException(e);
} catch (ArrayIndexOutOfBoundsException e) {
e.printStackTrace();
throw new IllegalStateException(e);
}
}
public boolean isModelReady() {
return modelReady;
}
public Classifier createModel(String fileName, String attribute, String[] rmAttributes) throws Exception {
if (!modelReady)
throw new IllegalStateException("Model didn't reach threshold for creation!");
DataSource source = new DataSource(new FileInputStream(arffFile));
Instances data = source.getDataSet();
if (rmAttributes.length > 0) {
String[] optionsDel = new String[2];
optionsDel[0] = "-R";
optionsDel[1] = "";
for (int i = 0; i < rmAttributes.length; i++)
optionsDel[1] += (1 + data.attribute(rmAttributes[i]).index()) + ",";
Remove remove = new Remove();
remove.setOptions(optionsDel);
remove.setInputFormat(data);
data = Filter.useFilter(data, remove);
}
// setting class attribute if the data format does not provide this information
// E.g., the XRFF format saves the class attribute information as well
if (data.classIndex() == -1)
data.setClass(data.attribute(attribute));
// train M5P
M5P cl = new M5P();
cl.setBuildRegressionTree(true);
cl.setUnpruned(false);
cl.setUseUnsmoothed(false);
// further options...
cl.buildClassifier(data);
// save model + header
if (WekaOptions.isModelPersistency()) {
File modelFile = new File(this.path, "/model/" + fileName + ".model");
if (!modelFile.getParentFile().exists()) {
modelFile.getParentFile().mkdirs();
}
SerializationHelper.write(new FileOutputStream(modelFile), cl);
}
return cl;
}
}