package edu.stanford.nlp.sentiment; import java.io.Serializable; public class RNNTrainOptions implements Serializable { public int batchSize = 27; /** Number of times through all the trees */ public int epochs = 400; public int debugOutputEpochs = 8; public int maxTrainTimeSeconds = 60 * 60 * 24; public double learningRate = 0.01; public double scalingForInit = 1.0; private double[] classWeights = null; /** * The classWeights can be passed in as a comma separated list of * weights using the -classWeights flag. If the classWeights are * not specified, the value is assumed to be 1.0. classWeights only * apply at train time; we do not weight the classes at all during * test time. */ public double getClassWeight(int i) { if (classWeights == null) { return 1.0; } return classWeights[i]; } /** Regularization cost for the transform matrix */ public double regTransformMatrix = 0.001; /** Regularization cost for the classification matrices */ public double regClassification = 0.0001; /** Regularization cost for the word vectors */ public double regWordVector = 0.0001; /** * The value to set the learning rate for each parameter when initializing adagrad. */ public double initialAdagradWeight = 0.0; /** * How many epochs between resets of the adagrad learning rates. * Set to 0 to never reset. */ public int adagradResetFrequency = 1; /** Regularization cost for the transform tensor */ public double regTransformTensor = 0.001; /** * Shuffle matrices when training. Usually should be true. Set to * false to compare training across different implementations, such * as with the original Matlab version */ public boolean shuffleMatrices = true; /** * If set, the initial matrices are logged to this location as a single file * using SentimentModel.toString() */ public String initialMatrixLogPath = null; public int nThreads = 1; @Override public String toString() { StringBuilder result = new StringBuilder(); result.append("TRAIN OPTIONS\n"); result.append("batchSize=" + batchSize + "\n"); result.append("epochs=" + epochs + "\n"); result.append("debugOutputEpochs=" + debugOutputEpochs + "\n"); result.append("maxTrainTimeSeconds=" + maxTrainTimeSeconds + "\n"); result.append("learningRate=" + learningRate + "\n"); result.append("scalingForInit=" + scalingForInit + "\n"); if (classWeights == null) { result.append("classWeights=null\n"); } else { result.append("classWeights="); result.append(classWeights[0]); for (int i = 1; i < classWeights.length; ++i) { result.append("," + classWeights[i]); } result.append("\n"); } result.append("regTransformMatrix=" + regTransformMatrix + "\n"); result.append("regTransformTensor=" + regTransformTensor + "\n"); result.append("regClassification=" + regClassification + "\n"); result.append("regWordVector=" + regWordVector + "\n"); result.append("initialAdagradWeight=" + initialAdagradWeight + "\n"); result.append("adagradResetFrequency=" + adagradResetFrequency + "\n"); result.append("shuffleMatrices=" + shuffleMatrices + "\n"); result.append("initialMatrixLogPath=" + initialMatrixLogPath + "\n"); result.append("nThreads=" + nThreads + "\n"); return result.toString(); } public int setOption(String[] args, int argIndex) { if (args[argIndex].equalsIgnoreCase("-batchSize")) { batchSize = Integer.parseInt(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-epochs")) { epochs = Integer.parseInt(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-debugOutputEpochs")) { debugOutputEpochs = Integer.parseInt(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-maxTrainTimeSeconds")) { maxTrainTimeSeconds = Integer.parseInt(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-learningRate")) { learningRate = Double.parseDouble(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-scalingForInit")) { scalingForInit = Double.parseDouble(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-regTransformMatrix")) { regTransformMatrix = Double.parseDouble(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-regTransformTensor")) { regTransformTensor = Double.parseDouble(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-regClassification")) { regClassification = Double.parseDouble(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-regWordVector")) { regWordVector = Double.parseDouble(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-initialAdagradWeight")) { initialAdagradWeight = Double.parseDouble(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-adagradResetFrequency")) { adagradResetFrequency = Integer.parseInt(args[argIndex + 1]); return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-classWeights")) { String classWeightString = args[argIndex + 1]; String[] pieces = classWeightString.split(","); classWeights = new double[pieces.length]; for (int i = 0; i < pieces.length; ++i) { classWeights[i] = Double.parseDouble(pieces[i]); } return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-shuffleMatrices")) { shuffleMatrices = true; return argIndex + 1; } else if (args[argIndex].equalsIgnoreCase("-noShuffleMatrices")) { shuffleMatrices = false; return argIndex + 1; } else if (args[argIndex].equalsIgnoreCase("-initialMatrixLogPath")) { initialMatrixLogPath = args[argIndex + 1]; return argIndex + 2; } else if (args[argIndex].equalsIgnoreCase("-nThreads") || args[argIndex].equalsIgnoreCase("-numThreads")) { nThreads = Integer.parseInt(args[argIndex + 1]); return argIndex + 2; } else { return argIndex; } } private static final long serialVersionUID = 1; }