package org.maltparser.ml.lib;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.LinkedHashMap;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.feature.FeatureVector;
import org.maltparser.core.helper.NoPrintStream;
import org.maltparser.core.helper.Util;
import org.maltparser.parser.guide.instance.InstanceModel;
public class LibLinear extends Lib {
public LibLinear(InstanceModel owner, Integer learnerMode) throws MaltChainedException {
super(owner, learnerMode, "liblinear");
if (learnerMode == CLASSIFY) {
try {
ObjectInputStream input = new ObjectInputStream(getInputStreamFromConfigFileEntry(".moo"));
try {
model = (MaltLibModel)input.readObject();
} finally {
input.close();
}
} catch (ClassNotFoundException e) {
throw new LibException("Couldn't load the liblinear model", e);
} catch (Exception e) {
throw new LibException("Couldn't load the liblinear model", e);
}
}
}
/**
 * Trains a liblinear model in-process from the ".ins" instance file, compacts the
 * weight matrix with {@link #convert(double[], int, int)} and serializes the result
 * to the ".moo" file. The ".ins" file is deleted afterwards unless
 * <code>saveInstanceFiles</code> is set.
 *
 * @param featureVector not used by this method
 * @throws MaltChainedException if training fails or the model cannot be written
 */
protected void trainInternal(FeatureVector featureVector) throws MaltChainedException {
    try {
        if (configLogger.isInfoEnabled()) {
            configLogger.info("Creating Liblinear model "+getFile(".moo").getName()+"\n");
        }
        final Problem problem = readProblem(getInstanceInputStreamReader(".ins"));
        final Parameter parameter = getLiblinearParameters();

        // Silence stdout/stderr while liblinear trains, and always put both
        // streams back afterwards, even if training throws.
        final PrintStream out = System.out;
        final PrintStream err = System.err;
        final Model trained;
        try {
            System.setOut(NoPrintStream.NO_PRINTSTREAM);
            System.setErr(NoPrintStream.NO_PRINTSTREAM);
            trained = Linear.train(problem, parameter);
        } finally {
            System.setOut(out);
            System.setErr(err);
        }

        if (configLogger.isInfoEnabled()) {
            configLogger.info("Optimize memory usage for the Liblinear model "+getFile(".moo").getName()+"\n");
        }
        final double[][] wmatrix = convert(trained.getFeatureWeights(), trained.getNrClass(), trained.getNrFeature());
        final MaltLiblinearModel xmodel = new MaltLiblinearModel(trained.getLabels(), trained.getNrClass(), wmatrix.length, wmatrix, parameter.getSolverType());
        if (configLogger.isInfoEnabled()) {
            configLogger.info("Save the Liblinear model "+getFile(".moo").getName()+"\n");
        }
        final ObjectOutputStream output = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
        try {
            output.writeObject(xmodel);
        } finally {
            output.close();
        }
        if (!saveInstanceFiles) {
            getFile(".ins").delete();
        }
    } catch (OutOfMemoryError e) {
        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
    } catch (IllegalArgumentException e) {
        throw new LibException("The Liblinear learner was not able to redirect Standard Error stream. ", e);
    } catch (SecurityException e) {
        throw new LibException("The Liblinear learner cannot remove the instance file. ", e);
    } catch (IOException e) {
        // The only file written in this method is the ".moo" model file.
        throw new LibException("The Liblinear learner cannot save the model file '"+getFile(".moo").getAbsolutePath()+"'. ", e);
    }
}
/**
 * Converts the flat liblinear weight array into a compact matrix with one row per
 * retained feature.
 * <ul>
 * <li>Trailing zero weights of each feature row are trimmed.</li>
 * <li>Features whose weights are identical for all <code>nr_class</code> classes
 * (including all-zero rows) are dropped; they are removed from
 * <code>featureMap</code> and the indices of the remaining features are shifted
 * down accordingly.</li>
 * <li>Rows with identical content share a single array instance.</li>
 * </ul>
 *
 * @param w flat weight array, read as <code>nr_class</code> consecutive weights per feature
 * @param nr_class number of classes
 * @param nr_feature number of features
 * @return the reduced weight matrix containing only the retained features
 */
private double[][] convert(double[] w, int nr_class, int nr_feature) {
    final double[][] wmatrix = new double[nr_feature][];
    final double[] wsignature = new double[nr_feature];
    int ne = 0; // number of eliminated features so far
    final Long[] reverseMap = featureMap.reverseMap();
    for (int i = 0; i < nr_feature; i++) {
        final int offset = i * nr_class;
        // Trim trailing zero weights; k is the length of the remaining prefix.
        int k = nr_class;
        while (k > 0 && w[offset + k - 1] == 0.0) {
            k--;
        }
        final double[] copy = new double[k];
        System.arraycopy(w, offset, copy, 0, k);

        // A feature carries no information only if all nr_class weights are equal.
        // A trimmed row (0 < k < nr_class) ends in a non-zero weight followed by
        // implicit zeros, so it is never uniform even if its stored prefix is.
        final boolean uninformative = (k == 0) || (k == nr_class && eliminate(copy));
        if (uninformative) {
            ne++;
            featureMap.removeIndex(reverseMap[i + 1]);
            reverseMap[i + 1] = null;
            wmatrix[i] = null;
        } else {
            featureMap.setIndex(reverseMap[i + 1], i + 1 - ne);
            for (int j = 0; j < copy.length; j++) {
                wsignature[i] += copy[j];
            }
            // The weight sum is a cheap pre-filter before the element-wise comparison.
            boolean reuse = false;
            for (int j = 0; j < i; j++) {
                // Eliminated rows are null (with signature 0) and can never be shared.
                if (wmatrix[j] != null && wsignature[j] == wsignature[i] && Util.equals(copy, wmatrix[j])) {
                    wmatrix[i] = wmatrix[j];
                    reuse = true;
                    break;
                }
            }
            if (!reuse) {
                wmatrix[i] = copy;
            }
        }
    }
    featureMap.setFeatureCounter(featureMap.getFeatureCounter() - ne);
    // Compact the matrix by skipping the rows of eliminated features.
    final double[][] wmatrix_reduced = new double[nr_feature - ne][];
    for (int i = 0, j = 0; i < wmatrix.length; i++) {
        if (wmatrix[i] != null) {
            wmatrix_reduced[j++] = wmatrix[i];
        }
    }
    return wmatrix_reduced;
}
/**
 * Tells whether all elements of the array are equal to their neighbour.
 * An empty array and a single-element array both yield <code>true</code>.
 *
 * @param a the array to inspect
 * @return <code>true</code> if no two adjacent elements differ, otherwise <code>false</code>
 */
public static boolean eliminate(double[] a) {
    // Walk backwards over adjacent pairs; the first mismatch decides the answer.
    int idx = a.length - 1;
    while (idx > 0) {
        if (a[idx] != a[idx - 1]) {
            return false;
        }
        idx--;
    }
    return true;
}
/**
 * Trains a liblinear model by running the external trainer at
 * <code>pathExternalTrain</code>. The ".ins" instances are first rewritten to
 * ".ins.tmp", the external program is asked to write ".mod", and that file is then
 * loaded and serialized to ".moo". The intermediate files are deleted afterwards
 * unless <code>saveInstanceFiles</code> is set.
 *
 * @param featureVector not used by this method
 * @throws MaltChainedException if the external process or any file operation fails
 */
protected void trainExternal(FeatureVector featureVector) throws MaltChainedException {
    try {
        if (configLogger.isInfoEnabled()) {
            owner.getGuide().getConfiguration().getConfigLogger().info("Creating liblinear model (external) "+getFile(".mod").getName());
        }
        binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp"));

        // Command line: <trainer> <params...> <instance file> <model file>
        final String[] params = getLibParamStringArray();
        final String[] arrayCommands = new String[params.length+3];
        int i = 0;
        arrayCommands[i++] = pathExternalTrain;
        for (; i <= params.length; i++) {
            arrayCommands[i] = params[i-1];
        }
        arrayCommands[i++] = getFile(".ins.tmp").getAbsolutePath();
        arrayCommands[i++] = getFile(".mod").getAbsolutePath();

        if (verbosity == Verbostity.ALL) {
            owner.getGuide().getConfiguration().getConfigLogger().info('\n');
        }
        final Process child = Runtime.getRuntime().exec(arrayCommands);
        final InputStream in = child.getInputStream();
        final InputStream err = child.getErrorStream();
        try {
            // NOTE(review): stdout is drained completely before stderr; a child that
            // writes a lot to stderr first could block. Left as is to keep the
            // logging order unchanged.
            int c;
            while ((c = in.read()) != -1){
                if (verbosity == Verbostity.ALL) {
                    owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
                }
            }
            while ((c = err.read()) != -1){
                if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) {
                    owner.getGuide().getConfiguration().getConfigLogger().info((char)c);
                }
            }
            if (child.waitFor() != 0) {
                owner.getGuide().getConfiguration().getConfigLogger().info(" FAILED ("+child.exitValue()+")");
            }
        } finally {
            // Release the process streams on every path, not only on success.
            try {
                in.close();
            } finally {
                err.close();
            }
        }
        if (configLogger.isInfoEnabled()) {
            configLogger.info("\nSaving Liblinear model "+getFile(".moo").getName()+"\n");
        }
        final MaltLiblinearModel xmodel = new MaltLiblinearModel(getFile(".mod"));
        final ObjectOutputStream output = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath())));
        try {
            output.writeObject(xmodel);
        } finally {
            output.close();
        }
        if (!saveInstanceFiles) {
            getFile(".ins").delete();
            getFile(".mod").delete();
            getFile(".ins.tmp").delete();
        }
        if (configLogger.isInfoEnabled()) {
            configLogger.info('\n');
        }
    } catch (InterruptedException e) {
        // Preserve the interrupt status for code further up the call stack.
        Thread.currentThread().interrupt();
        throw new LibException("Learner is interrupted. ", e);
    } catch (IllegalArgumentException e) {
        throw new LibException("The learner was not able to redirect Standard Error stream. ", e);
    } catch (SecurityException e) {
        throw new LibException("The learner cannot remove the instance file. ", e);
    } catch (IOException e) {
        throw new LibException("The learner cannot save the model file '"+getFile(".mod").getAbsolutePath()+"'. ", e);
    } catch (OutOfMemoryError e) {
        throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
    }
}
/**
 * Terminates the learner by delegating to the superclass implementation;
 * no liblinear-specific work is done here.
 *
 * @throws MaltChainedException if the superclass termination fails
 */
public void terminate() throws MaltChainedException {
    super.terminate();
}
/**
 * Installs the default liblinear option values, keyed by their option flag,
 * into <code>libOptions</code>.
 */
public void initLibOptions() {
    final LinkedHashMap<String, String> defaults = new LinkedHashMap<String, String>();
    defaults.put("s", "4");   // solver type; "4" is mapped to SolverType.MCSVM_CS in getLiblinearParameters()
    defaults.put("c", "0.1"); // cost
    defaults.put("e", "0.1"); // epsilon (stopping tolerance)
    defaults.put("B", "-1");  // bias
    libOptions = defaults;
}
/**
 * Declares the option flags this learner accepts, one character per flag:
 * <code>s</code>, <code>c</code>, <code>e</code> and <code>B</code>.
 */
public void initAllowedLibOptionFlags() {
    allowedLibOptionFlags = "sceB";
}
/**
 * Reads the training instances from the given reader into a liblinear
 * {@link Problem}. Each line is converted with <code>binariesInstance</code>; lines
 * for which it returns <code>-1</code> are skipped. The reader is always closed
 * before this method returns, also when reading fails.
 *
 * @param isr reader positioned at the start of the instance file
 * @return the populated problem
 * @throws MaltChainedException if the instance file cannot be read or contains
 *         more instances than <code>getNumberOfInstances()</code> reports
 */
private Problem readProblem(InputStreamReader isr) throws MaltChainedException {
    final Problem problem = new Problem();
    final FeatureList featureList = new FeatureList();
    try {
        final BufferedReader fp = new BufferedReader(isr);
        try {
            problem.bias = -1;
            problem.l = getNumberOfInstances();
            problem.x = new FeatureNode[problem.l][];
            problem.y = new int[problem.l];
            int i = 0; // index of the next instance slot to fill
            String line;
            while ((line = fp.readLine()) != null) {
                final int y = binariesInstance(line, featureList);
                if (y == -1) {
                    continue;
                }
                try {
                    problem.y[i] = y;
                    problem.x[i] = new FeatureNode[featureList.size()];
                    for (int k = 0; k < featureList.size(); k++) {
                        final MaltFeatureNode x = featureList.get(k);
                        problem.x[i][k] = new FeatureNode(x.getIndex(), x.getValue());
                    }
                    i++;
                } catch (ArrayIndexOutOfBoundsException e) {
                    throw new LibException("Couldn't read liblinear problem from the instance file. ", e);
                }
            }
        } finally {
            // Close on every path so the instance file handle is not leaked.
            fp.close();
        }
        problem.n = featureMap.size();
    } catch (IOException e) {
        throw new LibException("Cannot read from the instance file. ", e);
    }
    return problem;
}
/**
 * Builds the liblinear {@link Parameter} object from <code>libOptions</code>:
 * <code>s</code> selects the solver type ("0" to "7"), <code>c</code> is the cost
 * and <code>e</code> the epsilon value.
 *
 * @return the liblinear training parameters
 * @throws MaltChainedException if the solver type is missing or unknown, or if the
 *         cost or epsilon option is missing or not a number
 */
private Parameter getLiblinearParameters() throws MaltChainedException {
    final Parameter param = new Parameter(SolverType.MCSVM_CS, 0.1, 0.1);

    // Constant-first equals keeps a missing "s" option from raising a NullPointerException.
    final String type = libOptions.get("s");
    if ("0".equals(type)) {
        param.setSolverType(SolverType.L2R_LR);
    } else if ("1".equals(type)) {
        param.setSolverType(SolverType.L2R_L2LOSS_SVC_DUAL);
    } else if ("2".equals(type)) {
        param.setSolverType(SolverType.L2R_L2LOSS_SVC);
    } else if ("3".equals(type)) {
        param.setSolverType(SolverType.L2R_L1LOSS_SVC_DUAL);
    } else if ("4".equals(type)) {
        param.setSolverType(SolverType.MCSVM_CS);
    } else if ("5".equals(type)) {
        param.setSolverType(SolverType.L1R_L2LOSS_SVC);
    } else if ("6".equals(type)) {
        param.setSolverType(SolverType.L1R_LR);
    } else if ("7".equals(type)) {
        param.setSolverType(SolverType.L2R_LR_DUAL);
    } else {
        throw new LibException("The liblinear type (-s) is not an integer value between 0 and 7. ");
    }

    final String cost = libOptions.get("c");
    if (cost == null) {
        throw new LibException("The liblinear cost (-c) value is not numerical value. ");
    }
    try {
        param.setC(Double.parseDouble(cost));
    } catch (NumberFormatException e) {
        throw new LibException("The liblinear cost (-c) value is not numerical value. ", e);
    }

    final String epsilon = libOptions.get("e");
    if (epsilon == null) {
        throw new LibException("The liblinear epsilon (-e) value is not numerical value. ");
    }
    try {
        param.setEps(Double.parseDouble(epsilon));
    } catch (NumberFormatException e) {
        throw new LibException("The liblinear epsilon (-e) value is not numerical value. ", e);
    }
    return param;
}
}