package LBJ2; import java.io.BufferedReader; import java.io.File; import java.io.InputStreamReader; import java.io.PrintStream; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.Map; import java.util.TreeSet; import LBJ2.IR.AST; import LBJ2.IR.ClassifierAssignment; import LBJ2.IR.CodedClassifier; import LBJ2.IR.Constant; import LBJ2.IR.ConstraintDeclaration; import LBJ2.IR.InferenceDeclaration; import LBJ2.IR.LearningClassifierExpression; import LBJ2.IR.ParameterSet; import LBJ2.learn.Accuracy; import LBJ2.learn.BatchTrainer; import LBJ2.learn.Learner; import LBJ2.learn.Lexicon; import LBJ2.learn.TestingMetric; import LBJ2.parse.ArrayFileParser; import LBJ2.parse.Parser; import LBJ2.util.ClassUtils; /** * After code has been generated with {@link TranslateToJava}, this pass * trains any classifiers for which training was indicated. * * @see LBJ2.TranslateToJava * @author Nick Rizzolo **/ public class Train extends Pass { /** <!-- stackTrace(Throwable) --> * Generates a <code>String</code> containing the name of the specified * <code>Throwable</code> and its stack trace. * * @param t <code>Throwable</code>. * @return The generated message. **/ private static String stackTrace(Throwable t) { String message = " " + t + "\n"; StackTraceElement[] elements = t.getStackTrace(); if (elements.length == 0) message += " no stack trace available\n"; for (int i = 0; i < elements.length; ++i) message += " " + elements[i] + "\n"; return message; } /** * Remembers which files have been compiled via {@link #runJavac(String)}. **/ private static final TreeSet compiledFiles = new TreeSet(); /** <!-- runJavac(String) --> * Run the <code>javac</code> compiler with the specified arguments in * addition to those specified on the command line. * * @param arguments The arguments to send to <code>javac</code>. * @return <code>true</code> iff errors were encountered. **/ public static boolean runJavac(String arguments) { String[] files = arguments.split("\\s+"); arguments = ""; for (int i = 0; i < files.length; ++i) if (compiledFiles.add(files[i])) arguments += " " + files[i]; if (arguments.length() == 0) return false; Process javac = null; String pathArguments = "-classpath " + Main.classPath + " -sourcepath " + Main.sourcePath; if (Main.generatedSourceDirectory != null) { String gsd = Main.generatedSourceDirectory; int packageIndex = -1; if (AST.globalSymbolTable.getPackage().length() != 0) packageIndex = gsd.lastIndexOf(File.separator + AST.globalSymbolTable.getPackage() .replace('.', File.separatorChar)); if (packageIndex != -1) gsd = gsd.substring(0, packageIndex); pathArguments += File.pathSeparator + gsd; } if (Main.classPackageDirectory != null) pathArguments += " -d " + Main.classPackageDirectory; String command = Configuration.javac + " " + Main.javacArguments + " " + pathArguments + arguments; try { javac = Runtime.getRuntime().exec(command); } catch (Exception e) { System.err.println("Failed to execute 'javac': " + e); System.exit(1); } BufferedReader error = new BufferedReader(new InputStreamReader(javac.getErrorStream())); try { for (String line = error.readLine(); line != null; line = error.readLine()) System.out.println(line); } catch (Exception e) { System.err.println("Error reading STDERR from 'javac': " + e); System.exit(1); } int exit = 0; try { exit = javac.waitFor(); } catch (Exception e) { System.err.println("Error waiting for 'javac' to terminate: " + e); System.exit(1); } return exit != 0; } // Member variables. /** * Progress output will be printed every <code>progressOutput</code> * examples. **/ protected int progressOutput; /** * Set to <code>true</code> iff there existed a * {@link LearningClassifierExpression} for which new code was generated. **/ protected boolean newCode; /** * An array of the training threads, which is never modified after it is * constructed. **/ protected TrainingThread[] threads; /** A map of all the training threads indexed by the name of the learner. */ protected HashMap threadMap; /** * The keys of this map are the names of learners; the values are * <code>LinkedList</code>s of the names of the learners that the learner * named by the key depends on. **/ protected HashMap learnerDependencies; // Constructor. /** * Instantiates a pass that runs on an entire {@link AST}. * * @param ast The program to run this pass on. * @param output Progress output will be printed every <code>output</code> * examples. **/ public Train(AST ast, int output) { super(ast); progressOutput = output; } // Methods related to learnerDependencies. /** <!-- addDependency(String, String) --> * Adds an edge from dependor to dependency in the * {@link #learnerDependencies} graph. If <code>dependency</code> is * <code>null</code>, no new list item is added, but the * <code>HashSet</code> associated with <code>dependor</code> is still * created if it didn't already exist. * * @param dependor The name of the node doing the depending. * @param dependency The name of the node depended on. **/ private void addDependency(String dependor, String dependency) { HashSet dependencies = (HashSet) learnerDependencies.get(dependor); if (dependencies == null) { dependencies = new HashSet(); learnerDependencies.put(dependor, dependencies); } if (dependency != null) dependencies.add(dependency); } /** <!-- fillLearnerDependorsDAG() --> * This method initializes the {@link #learnerDependencies} graph such * that the entry for each learner contains the names of all learners that * depend on it, except that cycles are broken by preferring that learners * appearing earlier in the source get trained first. **/ protected void fillLearnerDependorsDAG() { threads = (TrainingThread[]) threadMap.values().toArray(new TrainingThread[0]); Arrays.sort(threads, new Comparator() { public int compare(Object o1, Object o2) { TrainingThread t1 = (TrainingThread) o1; TrainingThread t2 = (TrainingThread) o2; return t2.byteOffset - t1.byteOffset; } }); for (int i = 0; i < threads.length - 1; ++i) for (int j = i + 1; j < threads.length; ++j) { if (SemanticAnalysis.isDependentOn(threads[i].getName(), threads[j].getName())) addDependency(threads[i].getName(), threads[j].getName()); else if (SemanticAnalysis.isDependentOn(threads[j].getName(), threads[i].getName())) addDependency(threads[j].getName(), threads[i].getName()); } } /** <!-- executeReadyThreads(String) --> * This method updates the {@link #learnerDependencies} graph by removing * the specified name from every dependencies list, and then starts every * thread that has no more dependencies. * * @param name The name of a learner whose training has completed. **/ protected void executeReadyThreads(String name) { LinkedList ready = new LinkedList(); synchronized (learnerDependencies) { for (Iterator I = learnerDependencies.entrySet().iterator(); I.hasNext(); ) { Map.Entry e = (Map.Entry) I.next(); HashSet dependencies = (HashSet) e.getValue(); dependencies.remove(name); if (dependencies.size() == 0) ready.add(e.getKey()); } } for (Iterator I = ready.iterator(); I.hasNext(); ) { TrainingThread thread = null; synchronized (threadMap) { thread = (TrainingThread) threadMap.remove(I.next()); } if (thread != null) { thread.start(); if (!Main.concurrentTraining) { try { thread.join(); } catch (InterruptedException e) { System.err.println("LBJ ERROR: Training of " + thread.getName() + " has been interrupted."); fatalError = true; } } } } } /** <!-- run(AST) --> * Runs this pass on all nodes of the indicated type. * * @param ast The node to process. **/ public void run(AST ast) { if (RevisionAnalysis.noChanges) return; threadMap = new HashMap(); learnerDependencies = new HashMap(); if (Main.fileNames.size() > 0) { String files = ""; for (Iterator I = Main.fileNames.iterator(); I.hasNext(); ) files += " " + I.next(); System.out.println("Compiling generated code"); if (runJavac(files)) return; } Main.fileNames.clear(); runOnChildren(ast); fillLearnerDependorsDAG(); executeReadyThreads(null); for (int i = 0; i < threads.length; ++i) { try { threads[i].join(); } catch (InterruptedException e) { System.err.println("LBJ ERROR: Training of " + threads[i].getName() + " has been interrupted."); fatalError = true; } } if (!fatalError && newCode) { String files = ""; for (Iterator I = Main.fileNames.iterator(); I.hasNext(); ) files += " " + I.next(); System.out.println("Compiling generated code"); compiledFiles.clear(); runJavac(files); } } /** <!-- run(LearningClassifierExpression) --> * Runs this pass on all nodes of the indicated type. * * @param lce The node to process. **/ public void run(LearningClassifierExpression lce) { runOnChildren(lce); String lceName = lce.name.toString(); if (lce.parser == null ? !RevisionAnalysis.revisionStatus.get(lceName) .equals(RevisionAnalysis.REVISED) : lce.learningStatus.equals(RevisionAnalysis.UNAFFECTED) && !lce.onlyCodeGeneration) return; newCode |= true; TrainingThread thread = new TrainingThread(lceName, lce.byteOffset, lce); threadMap.put(lceName, thread); addDependency(lceName, null); } // The following three methods are here to stop AST traversal. /** <!-- run(CodedClassifier) --> * Runs this pass on all nodes of the indicated type. There's no reason to * traverse children of {@link CodedClassifier}s, so this method exists * simply to stop that from happening. * * @param cc The node to process. **/ public void run(CodedClassifier cc) { } /** <!-- run(ConstraintDeclaration) --> * Runs this pass on all nodes of the indicated type. There's no reason to * traverse children of {@link ConstraintDeclaration}s, so this method * exists simply to stop that from happening. * * @param cd The node to process. **/ public void run(ConstraintDeclaration cd) { } /** <!-- run(InferenceDeclaration) --> * Runs this pass on all nodes of the indicated type. There's no reason to * traverse children of {@link InferenceDeclaration}s, so this method * exists simply to stop that from happening. * * @param id The node to process. **/ public void run(InferenceDeclaration id) { } /** <!-- increment(int[],int[]) --> * Helps the {@link TrainingThread#getParameterCombinations()} method * iterate through all combinations and permutations of integers such that * each integer is at least 0 and less than the corresponding element of * <code>maxes</code>. * * @param I The current array of integers. * @param maxes The maximums for each element of <code>I</code>. **/ private static boolean increment(int[] I, int[] maxes) { int i = 0; while (i < I.length && ++I[i] == maxes[i]) I[i++] = 0; return i < I.length; } /** * This class contains the code that trains a learning classifier. It is a * subclass of <code>Thread</code> so that it may be executed concurrently. * * @author Nick Rizzolo **/ protected class TrainingThread extends Thread { // Member variables. /** The byte offset at which the learner appeared. */ public int byteOffset; /** The expression that specified the learner. */ protected LearningClassifierExpression lce; /** The learning classifier being trained. */ protected Learner learner; /** The class of {@link #learner}. */ protected Class learnerClass; /** {@link #learner}'s <code>Parameters</code> class. */ protected Class parametersClass; /** The file into which training examples are extracted. */ protected String exFilePath; /** The file into which testing examples are extracted. */ protected String testExFilePath; /** The directory into which class files, model files, etc are written. */ protected String classDir; /** Whether or not example vectors should be pre-extracted. */ protected boolean preExtract; /** Whether or not pre-extracted example files should be compressed. */ protected boolean preExtractZip; /** Actually does the training. */ protected BatchTrainer trainer; /** The parser from which testing objects are obtained. */ protected Parser testParser; /** * The metric with which to measure the learner's performance on a test * set. **/ protected TestingMetric testingMetric; // Constructor. /** * Initializing constructor. * * @param n The name of the learner. * @param b The byte offset at which the learner appeared. * @param lce The expression that specified the learner. **/ public TrainingThread(String n, int b, LearningClassifierExpression lce) { super(n); byteOffset = b; this.lce = lce; if (lce.onlyCodeGeneration) return; classDir = Main.classDirectory == null ? "" : Main.classDirectory + File.separator; learner = getLearner(classDir); preExtract = lce.preExtract != null && !lce.preExtract.value.equals("false") && !lce.preExtract.value.equals("\"false\"") && !lce.preExtract.value.equals("\"none\""); boolean preExtractToDisk = preExtract && !lce.preExtract.value.startsWith("\"mem"); preExtractZip = preExtract && lce.preExtract.value.endsWith("Zip\""); if (preExtractToDisk) { exFilePath = getName() + ".ex"; testExFilePath = getName() + ".test.ex"; if (Main.generatedSourceDirectory != null) { exFilePath = Main.generatedSourceDirectory + File.separator + exFilePath; testExFilePath = Main.generatedSourceDirectory + File.separator + testExFilePath; } } Parser parser = null; if (lce.parser != null) { if (lce.featuresStatus == RevisionAnalysis.UNAFFECTED) { // Implies preExtractToDisk is true because of RevisionAnalysis; // therefore, exFilePath != null parser = new ArrayFileParser(exFilePath, lce.preExtract.value.endsWith("Zip\"")); if (lce.pruneStatus != RevisionAnalysis.UNAFFECTED) learner.readLexiconOnDemand(classDir + getName() + ".lex"); } else parser = getParser("getParser"); } if (lce.testParser != null) { if (lce.pruneStatus == RevisionAnalysis.UNAFFECTED && new File(testExFilePath).exists()) // If pruneStatus is affected, pruning will rearrange our lexicon, // so we must re-extract the test set from the original parser. In // addition, pruneStatus == UNAFFECTED implies featuresStatus == // UNAFFECTED. So, like above, as soon as we know pruneStatus == // UNAFFECTED, we know testExFilePath != null testParser = new ArrayFileParser(testExFilePath, lce.preExtract.value.endsWith("Zip\"")); else testParser = getParser("getTestParser"); } testingMetric = getTestingMetric(); if (lce.progressOutput != null) progressOutput = Integer.parseInt(lce.progressOutput.value); trainer = new BatchTrainer(learner, parser, progressOutput); } /** <!-- getLearner(String) --> * Obtain an instance of the learner appropriate for the revision status * of the source file. This method also fills in the * {@link #learnerClass} and {@link #parametersClass} fields. * * <p> If the only change between the last run of the compiler and this * run is that more training rounds were added, the entire model file can * be loaded from disk. Failing that, if features are unaffected * according to {@link RevisionAnalysis}, it means only the label lexicon * should be read. Otherwise, we just start with a fresh instance of the * learner via its static <code>getInstance()</code> method. In any * case, the learner is initialized so that it will write its model * and/or lexicon files to the specified directory as necessary. * * @param dir The directory in which the model and lexicon are written. * @return An instance of the learner. **/ private Learner getLearner(String dir) { String fullyQualified = AST.globalSymbolTable.getPackage(); if (fullyQualified.length() > 0) fullyQualified += "."; fullyQualified += getName(); learnerClass = ClassUtils.getClass(fullyQualified, true); Class[] declaredClasses = learnerClass.getDeclaredClasses(); int c = 0; while (c < declaredClasses.length && !declaredClasses[c].getName() .endsWith(getName() + "$Parameters")) ++c; if (c == declaredClasses.length) { System.err.println( "LBJ ERROR: Expected to find a single member class inside " + getName() + " named 'Parameters'."); for (int i = 0; i < declaredClasses.length; ++i) System.err.println(i + ": " + declaredClasses[i].getName()); System.exit(1); } parametersClass = declaredClasses[c]; Learner l = null; if (lce.startingRound > 1) { // In the condition above, note that before setting // lce.startingRound > 1, RevisionAnalysis ensures that the lce is // unaffected other than the number of rounds and that there will be // no parameter tuning or cross validation. l = Learner.readLearner(dir + getName() + ".lc"); l.setLexiconLocation(dir + getName() + ".lex"); } else if (lce.featuresStatus == RevisionAnalysis.UNAFFECTED) { Constructor noArg = null; try { noArg = parametersClass.getConstructor(new Class[0]); } catch (Exception e) { System.err.println( "LBJ ERROR: Can't find a no-argument constructor for " + getName() + ".Parameters."); System.exit(1); } Learner.Parameters p = null; try { p = (Learner.Parameters) noArg.newInstance(new Object[0]); } catch (Exception e) { System.err.println( "LBJ ERROR: Can't instantiate " + getName() + ".Parameters:"); e.printStackTrace(); System.exit(1); } l = Learner.readLearner(dir + getName() + ".lc", false); l.setParameters(p); l.setLexiconLocation(dir + getName() + ".lex"); } else { Method getInstance = null; try { getInstance = learnerClass.getDeclaredMethod("getInstance", new Class[0]); } catch (Exception e) { System.err.println("LBJ ERROR: Could not access method '" + fullyQualified + ".getInstance()':"); System.exit(1); } try { l = (Learner) getInstance.invoke(null, null); } catch (Exception e) { System.err.println("LBJ ERROR: Could not get unique instance of '" + fullyQualified + "': " + e); e.getCause().printStackTrace(); System.exit(1); } if (l == null) { System.err.println("LBJ ERROR: Could not get unique instance of '" + fullyQualified + "'."); System.exit(1); } l.setModelLocation(dir + getName() + ".lc"); l.setLexiconLocation(dir + getName() + ".lex"); } return l; } /** <!-- getParser(String) --> * Call the specified method of {@link #learnerClass}, and return the * <code>Parser</code> returned by that method. * * @param name The name of the method. * @return The parser returned by the named method. **/ private Parser getParser(String name) { Method m = null; try { m = learnerClass.getDeclaredMethod(name, new Class[0]); } catch (Exception e) { reportError(lce.line, "Could not access method '" + lce.name + "." + name + "()': " + e); return null; } Parser result = null; try { result = (Parser) m.invoke(null, null); } catch (Exception e) { System.err.println( "Could not instantiate parser '" + lce.parser.name + "': " + e + ", caused by"); Throwable cause = e.getCause(); System.err.print(stackTrace(cause)); if (cause instanceof ExceptionInInitializerError) { System.err.println("... caused by"); System.err.print( stackTrace(((ExceptionInInitializerError) cause).getCause())); } return null; } return result; } /** <!-- getTestingMetric() --> * Call the <code>getTestingMetric()</code> method of * {@link #learnerClass} and return the testing metric it returns. **/ private TestingMetric getTestingMetric() { TestingMetric testingMetric = null; if (lce.testingMetric != null) { Method getTestingMetric = null; try { getTestingMetric = learnerClass.getDeclaredMethod("getTestingMetric", new Class[0]); } catch (Exception e) { reportError(lce.line, "Could not access method'" + getName() + ".getTestingMetric()': " + e); return null; } try { testingMetric = (TestingMetric) getTestingMetric.invoke(null, null); } catch (Exception e) { System.err.println( "Could not instantiate testing metric '" + lce.parser.name + "': " + e + ", caused by"); System.err.print(stackTrace(e.getCause())); return null; } } else testingMetric = new Accuracy(); return testingMetric; } /** <!-- preExtractAndPrune() --> * Handles feature pre-extraction and dataset pruning under the * assumption that pre-extraction has been called for by the source code. * The two go hand-in-hand, as we only need to compute and store feature * counts during pre-extraction if we are pruning. **/ private void preExtractAndPrune() { Lexicon.PruningPolicy pruningPolicy = new Lexicon.PruningPolicy(); Lexicon.CountPolicy countPolicy = Lexicon.CountPolicy.none; if (lce.pruneCountType != null) { pruningPolicy = lce.pruneThresholdType.value.equals("\"count\"") ? new Lexicon.PruningPolicy( Integer.parseInt(lce.pruneThreshold.value)) : new Lexicon.PruningPolicy( Double.parseDouble(lce.pruneThreshold.value)); countPolicy = lce.pruneCountType.value.equals("\"global\"") ? Lexicon.CountPolicy.global : Lexicon.CountPolicy.perClass; } Learner preExtractLearner = learner; // Needed in case we're pruning. Lexicon lexicon = null; // Needed for pre-extracting the test set. // As seen below, we can always read the lexicon off disk just before // pre-extracting the test set, but if one of the operations between // now and then obtains the lexicon incidentally, we'll keep it here // to avoid reading it from disk again. if (pruningPolicy.isNone()) { if (lce.featuresStatus != RevisionAnalysis.UNAFFECTED) lexicon = trainer.preExtract(exFilePath, preExtractZip); else if (lce.pruneStatus != RevisionAnalysis.UNAFFECTED) lexicon = learner.getLexiconDiscardCounts(); else trainer.fillInSizes(); } else if (lce.featuresStatus != RevisionAnalysis.UNAFFECTED || lce.pruneStatus != RevisionAnalysis.UNAFFECTED && lce.previousPruneCountType == null) preExtractLearner = trainer.preExtract(exFilePath, preExtractZip, countPolicy); else if (lce.previousPruneCountType != null && !lce.previousPruneCountType.equals(lce.pruneCountType)) { if (lce.previousPruneCountType.value.equals("\"global\"")) // implies lce.pruneCountType.equals("\"perClass\"") preExtractLearner = trainer.preExtract(exFilePath, preExtractZip, countPolicy); else // lce.previousPruneCountType.value.equals("\"perClass\"") learner.getLexicon().perClassToGlobalCounts(); } // else pruneThresholdType or pruneThreshold may have changed, but // that does not require recounting of features. if (lce.featuresStatus == RevisionAnalysis.UNAFFECTED ? lce.pruneStatus != RevisionAnalysis.UNAFFECTED : !pruningPolicy.isNone()) { trainer.pruneDataset(exFilePath, preExtractZip, pruningPolicy, preExtractLearner); lexicon = preExtractLearner.getLexicon(); if (preExtractLearner == learner) learner.setLexicon(null); } if (testParser != null && (lce.pruneStatus != RevisionAnalysis.UNAFFECTED || !(new File(testExFilePath).exists()))) { if (lexicon == null) learner.readLexiconOnDemand(classDir + getName() + ".lex"); else { learner.setLexicon(lexicon); lexicon = null; // See comment below } BatchTrainer preExtractor = new BatchTrainer(learner, testParser, trainer.getProgressOutput(), "test set: "); preExtractor.preExtract(testExFilePath, preExtractZip, Lexicon.CountPolicy.none); testParser = preExtractor.getParser(); } // At this point, it should be the case that (lexicon == null) implies // that the lexicon is not in memory. Above, we intentionally discard // the lexicon when pre-extracting the test set, since that process will // add unwanted features (since pre-extraction always happens under the // assumption that we are training). // Given the above comment, we now ensure that when this learning classifier (ie, // the one whose feature vectors we have just pre-extracted) is // called as a feature for some other learning classifier defined in the // same sourcefile, it will be prepared take a raw example object as // input. String name = getName(); HashSet dependors = (HashSet) SemanticAnalysis.dependorGraph.get(name); if (lexicon != null && dependors.size() > 0) learner.setLexicon(lexicon); else learner.readLexiconOnDemand(classDir + name + ".lex"); } /** <!-- getParameterCombinations() --> * Uses the various {@link LBJ2.IR.ParameterSet ParameterSet}s in the AST * to generate an array of parameter combinations representing the cross * product of all {@link LBJ2.IR.ParameterSet ParameterSet}s except the * one in the {@link LearningClassifierExpression#rounds} field, if any. **/ private Learner.Parameters[] getParameterCombinations() { Class[] paramTypes = new Class[lce.parameterSets.size()]; Object[][] arguments = new Object[paramTypes.length][]; int[] lengths = new int[paramTypes.length]; int totalCombinations = 1; Iterator iterator = lce.parameterSets.iterator(); for (int i = 0; i < paramTypes.length; i++) { ParameterSet ps = (ParameterSet) iterator.next(); paramTypes[i] = ps.type.typeClass(); arguments[i] = ps.toStringArray(); lengths[i] = arguments[i].length; totalCombinations *= lengths[i]; } for (int i = 0; i < arguments.length; i++) { Class t = paramTypes[i]; if (t.isPrimitive()) { if (t.getName().equals("int")) for (int j = 0; j < lengths[i]; ++j) arguments[i][j] = new Integer((String) arguments[i][j]); else if (t.getName().equals("long")) for (int j = 0; j < lengths[i]; ++j) arguments[i][j] = new Long((String) arguments[i][j]); else if (t.getName().equals("short")) for (int j = 0; j < lengths[i]; ++j) arguments[i][j] = new Short((String) arguments[i][j]); else if (t.getName().equals("double")) for (int j = 0; j < lengths[i]; ++j) arguments[i][j] = new Double((String) arguments[i][j]); else if (t.getName().equals("float")) for (int j = 0; j < lengths[i]; ++j) arguments[i][j] = new Float((String) arguments[i][j]); else if (t.getName().equals("boolean")) for (int j = 0; j < lengths[i]; ++j) arguments[i][j] = new Boolean((String) arguments[i][j]); } } Constructor c = null; try { c = parametersClass.getConstructor(paramTypes); } catch (Exception e) { System.err.println( "LBJ ERROR: Can't find a parameter tuning constructor for " + getName() + ".Parameters."); e.printStackTrace(); System.exit(1); } Learner.Parameters[] result = new Learner.Parameters[totalCombinations]; int[] I = new int[paramTypes.length]; Object[] a = new Object[paramTypes.length]; int i = 0; do { for (int j = 0; j < a.length; ++j) a[j] = arguments[j][I[j]]; try { result[i++] = (Learner.Parameters) c.newInstance(a); } catch (Exception e) { System.err.println( "LBJ ERROR: Can't instantiate " + getName() + ".Parameters:"); e.printStackTrace(); System.exit(1); } } while (increment(I, lengths)); return result; } /** <!-- tune() --> * Determines the best parameters to use when training the learner, * under the assumption that {@link LBJ2.IR.ParameterSet}s were present. * Here, "best" means the parameters that did the best out of some small * set of particular parameter settings. **/ private Learner.Parameters tune() { Learner.Parameters[] parameterCombinations = getParameterCombinations(); int[] rounds = null; if (lce.rounds != null) { if (lce.rounds instanceof ParameterSet) rounds = ((ParameterSet) lce.rounds).toSortedIntArray(); else rounds = new int[]{ Integer.parseInt(((Constant) lce.rounds).value) }; } else rounds = new int[]{ 1 }; if (lce.K != null) { int k = Integer.parseInt(lce.K.value); double alpha = Double.parseDouble(lce.alpha.value); return trainer.tune(parameterCombinations, rounds, k, lce.splitPolicy, alpha, testingMetric); } return trainer.tune(parameterCombinations, rounds, testParser, testingMetric); } /** Performs the training and then generates the new code. */ public void run() { boolean tuningParameters = lce.parameterSets.size() > 0 || lce.rounds != null && lce.rounds instanceof ParameterSet; if (!lce.onlyCodeGeneration) { // If there's a "from" clause, train. try { if (lce.parser != null) { System.out.println("Training " + getName()); if (preExtract) { preExtractAndPrune(); System.gc(); } else learner.saveLexicon(); int trainingRounds = 1; if (tuningParameters) { String parametersPath = getName(); if (Main.classDirectory != null) parametersPath = Main.classDirectory + File.separator + parametersPath; parametersPath += ".p"; Learner.Parameters bestParameters = tune(); trainingRounds = bestParameters.rounds; Learner.writeParameters(bestParameters, parametersPath); System.out.println(" " + getName() + ": Training on entire training set"); } else { if (lce.rounds != null) trainingRounds = Integer.parseInt(((Constant) lce.rounds).value); if (lce.K != null) { int[] rounds = { trainingRounds }; int k = Integer.parseInt(lce.K.value); double alpha = Double.parseDouble(lce.alpha.value); trainer.crossValidation(rounds, k, lce.splitPolicy, alpha, testingMetric, true); System.out.println(" " + getName() + ": Training on entire training set"); } } trainer.train(lce.startingRound, trainingRounds); if (testParser != null) { System.out.println("Testing " + getName()); new Accuracy(true).test(learner, learner.getLabeler(), testParser); } System.out.println("Writing " + getName()); } else learner.saveLexicon(); // Writes .lex even if lexicon is empty. learner.save(); // Doesn't write .lex if lexicon is empty. } catch (Exception e) { System.err.println( "LBJ ERROR: Exception while training " + getName() + ":"); e.printStackTrace(); fatalError = true; return; } // Set learner's static instance field to the newly learned instance. Field field = null; try { field = learnerClass.getField("instance"); } catch (Exception e) { System.err.println("Can't access " + learnerClass + "'s 'instance' field: " + e); System.exit(1); } try { field.set(null, learner); } catch (Exception e) { System.err.println("Can't set " + learnerClass + "'s 'instance' field: " + e); System.exit(1); } } else System.out.println("Generating code for " + lce.name); // Write the new code. PrintStream out = TranslateToJava.open(lce); if (out == null) return; out.println(TranslateToJava.disclaimer); out.print("// "); TranslateToJava.compressAndPrint(lce.shallow(), out); out.println("\n"); ast.symbolTable.generateHeader(out); if (lce.cacheIn != null) { String f = lce.cacheIn.toString(); boolean cachedInMap = f.equals(ClassifierAssignment.mapCache); if (cachedInMap) out.println("import java.util.WeakHashMap;"); } out.println("\n"); if (lce.comment != null) out.println(lce.comment); out.println("\n\npublic class " + getName() + " extends " + lce.learnerName); out.println("{"); out.println(" private static java.net.URL _lcFilePath;"); out.println(" private static java.net.URL _lexFilePath;"); if (tuningParameters) out.println(" private static java.net.URL parametersPath;"); out.println(); out.println(" static"); out.println(" {"); out.println(" _lcFilePath = " + getName() + ".class.getResource(\"" + getName() + ".lc\");\n"); out.println(" if (_lcFilePath == null)"); out.println(" {"); out.println(" System.err.println(\"ERROR: Can't locate " + getName() + ".lc in the class path.\");"); out.println(" System.exit(1);"); out.println(" }\n"); out.println(" _lexFilePath = " + getName() + ".class.getResource(\"" + getName() + ".lex\");\n"); out.println(" if (_lexFilePath == null)"); out.println(" {"); out.println(" System.err.println(\"ERROR: Can't locate " + getName() + ".lex in the class path.\");"); out.println(" System.exit(1);"); out.println(" }"); if (tuningParameters) { out.println( "\n parametersPath = " + getName() + ".class.getResource(\"" + getName() + ".p\");\n"); out.println(" if (parametersPath == null)"); out.println(" {"); out.println(" System.err.println(\"ERROR: Can't locate " + getName() + ".p in the class path.\");"); out.println(" System.exit(1);"); out.println(" }"); } out.println(" }\n"); out.println(" private static void loadInstance()"); out.println(" {"); out.println(" if (instance == null)"); out.println(" {"); out.println(" instance = (" + getName() + ") Learner.readLearner(_lcFilePath);"); out.println(" instance.readLexiconOnDemand(_lexFilePath);"); out.println(" }"); out.println(" }\n"); if (tuningParameters) { out.println(" private static " + lce.learnerName + ".Parameters bestParameters;\n"); out.println(" public static " + lce.learnerName + ".Parameters getBestParameters()"); out.println(" {"); out.println(" if (bestParameters == null)"); out.println(" bestParameters = (" + lce.learnerName + ".Parameters) Learner.readParameters(parametersPath);"); out.println(" return bestParameters;"); out.println(" }\n"); } if (exFilePath != null && lce.featuresStatus != RevisionAnalysis.UNAFFECTED && new File(exFilePath).exists()) out.println( " public static Parser getParser() { return new " + "LBJ2.parse.ArrayFileParser(\"" + new File(exFilePath).getAbsolutePath() + "\"); }"); else out.println(" public static Parser getParser() { return " + lce.parser + "; }"); if (testExFilePath != null && lce.featuresStatus != RevisionAnalysis.UNAFFECTED && new File(testExFilePath).exists()) out.println( " public static Parser getTestParser() { return new " + "LBJ2.parse.ArrayFileParser(\"" + new File(testExFilePath).getAbsolutePath() + "\"); }"); else out.println(" public static Parser getTestParser() { return " + lce.testParser + "; }\n"); TranslateToJava.generateLearnerBody(out, lce); if (lce.parameterSets.size() > 0) { out.println(); out.println(" public static class Parameters extends " + lce.learnerName + ".Parameters"); out.println(" {"); out.println( " public Parameters() { super(getBestParameters()); }"); out.println(" }"); } out.println("}\n"); out.close(); executeReadyThreads(getName()); } } }