package experimental.ising; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.LinkedList; import java.util.List; import marmot.util.Numerics; public class Analyzer { private DataReader drTrain; private DataReader drDev; private List<Datum> data; private List<IsingFactorGraph> trainingFactorGraphs; private List<IsingFactorGraph> devFactorGraphs; private UnaryFeatureExtractor ufe; private double[] parameters; private double[] gradient; public Analyzer(DataReader drTrain) { this.drTrain = drTrain; this.drDev = drDev; this.data = new LinkedList<Datum>(); this.trainingFactorGraphs = new LinkedList<IsingFactorGraph>(); this.devFactorGraphs = new LinkedList<IsingFactorGraph>(); this.ufe = new UnaryFeatureExtractor(0,5); System.out.println("...num variables:\t" + drTrain.numVariables); System.out.println("...num pairs:\t" + drTrain.pairs.size()); int counter = 0; int multiple = 0; ufe.setStartFeature(0); ufe.setTotalNumVariables(drTrain.numVariables); for (Datum d : drTrain.data) { ufe.extract(d.getWord()); } System.out.println("...num parameters:\t" + ufe.getNumFeatures()); this.parameters = new double[ufe.getNumFeatures()]; this.gradient = new double[ufe.getNumFeatures()]; int datumI = 0; for (Datum d : drTrain.data) { System.out.println(counter + "\t" + d.getWord() + "\t" + d.getTag().size()); ArrayList<Integer> golden = new ArrayList<Integer>(); for (int i = 0; i < drTrain.numVariables; ++i) { golden.add(0); } for (Integer t : d.getTag()) { golden.set(t, 1); } if (d.getTag().size() > 1) multiple += 1; IsingFactorGraph fg = new IsingFactorGraph(d.getWord(), ufe, 1, drTrain.numVariables, drTrain.pairsLst, golden, drTrain.tagNames); if (datumI < 844) { this.trainingFactorGraphs.add(fg); } else { this.devFactorGraphs.add(fg); } ++counter; ++datumI; } train(50,2.0); System.out.println("...train accuracy:\t" + decode(this.trainingFactorGraphs)); System.out.println("...dev accuracy:\t" + decode(this.devFactorGraphs)); ; System.exit(0); System.out.println("...feature dump:\t"); for (IsingFactorGraph fg : this.trainingFactorGraphs) { for (UnaryFactor uf : fg.getUnaryFactor()) { System.out.println(uf.getTag()); for (Integer feat : uf.getFeaturesPositive()) { System.out.println(ufe.getInt2Feature().get(feat) + "\t" + feat); System.out.println(ufe.getInt2Feature().get(feat + 1) + "\t" + (feat + 1)); } } } System.out.println("...tag names:\t"); System.out.println(drTrain.tagNames); } public void train(int numIterations, double eta) { // train for (int i = 0; i < this.gradient.length; ++i) { this.gradient[i] = 0.0; } for (int iter = 0; iter < numIterations; ++iter) { double likelihood = 0.0; for (IsingFactorGraph ig : this.trainingFactorGraphs) { ig.updatePotentials(parameters); ig.featurizedGradient(gradient, this.trainingFactorGraphs.size()); for (int i = 0; i < this.gradient.length; ++i) { this.parameters[i] += eta * this.gradient[i] ; //- 0.001 * parameters[i]; this.gradient[i] = 0.0; } //likelihood += ig.logLikelihood(); } eta *= .9; System.out.println("...iteration:\t" + iter); System.out.println("...likelihood:\t" + likelihood); // descent } } public double decode(List<IsingFactorGraph> factorGraphs) { double correct = 0.0; int total = 0; for (IsingFactorGraph ig : factorGraphs) { System.out.println(ig); ig.updatePotentials(this.parameters); ig.inference(1, 0.01); List<String> decoded = ig.posteriorDecode(); List<String> golden = new LinkedList<String>(); int counter = 0; for (Integer g : ig.golden) { if (g == 1) { String tag = this.drTrain.integer2Tag.get(counter); golden.add(tag); } ++counter; } Collections.sort(decoded); Collections.sort(golden); if (decoded.equals(golden)) { correct += 1; } else { System.out.println("...word:\t" + ig.getWord()); System.out.println("...predicted:\t" + decoded); System.out.println("...golden:\t" + golden); } total += 1; } return correct / total; } public static void main(String[] args) { //DataReader drTrain = new ThomasReader(args[0]); //DataReader drDev = new ThomasReader(args[1]); DataReader drTrain = new MorphItReader(args[0]); new Analyzer(drTrain); } }