package LBJ2.learn; import java.io.PrintStream; import java.util.Collection; import java.util.Iterator; import LBJ2.classify.Classifier; import LBJ2.classify.DiscretePrimitiveStringFeature; import LBJ2.classify.Feature; import LBJ2.classify.FeatureVector; import LBJ2.classify.ScoreSet; import LBJ2.util.ExceptionlessInputStream; import LBJ2.util.ExceptionlessOutputStream; import LBJ2.util.OVector; /** * A <code>SparseNetworkLearner</code> uses multiple * {@link LinearThresholdUnit}s to make a multi-class classification. * Any {@link LinearThresholdUnit} may be used, so long as it implements its * <code>clone()</code> method and a public constructor that takes no * arguments. * * <p> It is assumed that a single discrete label feature will be produced in * association with each example object. A feature taking one of the values * observed in that label feature will be produced by the learned classifier. * * <p> This algorithm's user-configurable parameters are stored in member * fields of this class. They may be set via either a constructor that names * each parameter explicitly or a constructor that takes an instance of * {@link LBJ2.learn.SparseNetworkLearner.Parameters Parameters} as input. * The documentation in each member field in this class indicates the default * value of the associated parameter when using the former type of * constructor. The documentation of the associated member field in the * {@link LBJ2.learn.SparseNetworkLearner.Parameters Parameters} class * indicates the default value of the parameter when using the latter type of * constructor. * * @author Nick Rizzolo **/ public class SparseNetworkLearner extends Learner { /** Default for {@link #baseLTU}. */ public static final LinearThresholdUnit defaultBaseLTU = new SparseAveragedPerceptron(); /** * The underlying algorithm used to learn each class separately as a binary * classifier; default {@link #defaultBaseLTU}. **/ protected LinearThresholdUnit baseLTU; /** * A collection of the linear threshold units used to learn each label, * indexed by the label. **/ protected OVector network; /** The total number of examples in the training data, or 0 if unknown. */ protected int numExamples; /** * The total number of distinct features in the training data, or 0 if * unknown. **/ protected int numFeatures; /** Whether or not this learner's labeler produces conjunctive features. */ protected boolean conjunctiveLabels; /** * Instantiates this multi-class learner with the default learning * algorithm: {@link #defaultBaseLTU}. **/ public SparseNetworkLearner() { this(""); } /** * Instantiates this multi-class learner using the specified algorithm to * learn each class separately as a binary classifier. This constructor * will normally only be called by the compiler. * * @param ltu The linear threshold unit used to learn binary classifiers. **/ public SparseNetworkLearner(LinearThresholdUnit ltu) { this("", ltu); } /** * Initializing constructor. Sets all member variables to their associated * settings in the {@link SparseNetworkLearner.Parameters} object. * * @param p The settings of all parameters. **/ public SparseNetworkLearner(Parameters p) { this("", p); } /** * Instantiates this multi-class learner with the default learning * algorithm: {@link #defaultBaseLTU}. * * @param n The name of the classifier. **/ public SparseNetworkLearner(String n) { this(n, new Parameters()); } /** * Instantiates this multi-class learner using the specified algorithm to * learn each class separately as a binary classifier. * * @param n The name of the classifier. * @param ltu The linear threshold unit used to learn binary classifiers. **/ public SparseNetworkLearner(String n, LinearThresholdUnit ltu) { super(n); Parameters p = new Parameters(); p.baseLTU = ltu; setParameters(p); network = new OVector(); } /** * Initializing constructor. Sets all member variables to their associated * settings in the {@link SparseNetworkLearner.Parameters} object. * * @param n The name of the classifier. * @param p The settings of all parameters. **/ public SparseNetworkLearner(String n, Parameters p) { super(n); setParameters(p); network = new OVector(); } /** * Sets the values of parameters that control the behavior of this learning * algorithm. * * @param p The parameters. **/ public void setParameters(Parameters p) { if (!p.baseLTU.getOutputType().equals("discrete")) { System.err.println( "LBJ WARNING: SparseNetworkLearner will only work with a " + "LinearThresholdUnit that returns discrete."); System.err.println( " The given LTU, " + p.baseLTU.getClass().getName() + ", returns " + p.baseLTU.getOutputType() + "."); } setLTU(p.baseLTU); } /** * Retrieves the parameters that are set in this learner. * * @return An object containing all the values of the parameters that * control the behavior of this learning algorithm. **/ public Learner.Parameters getParameters() { Parameters p = new Parameters(super.getParameters()); p.baseLTU = baseLTU; return p; } /** * Sets the <code>baseLTU</code> variable. This method will <i>not</i> * have any effect on the LTUs that already exist in the network. However, * new LTUs created after this method is executed will be of the same type * as the object specified. * * @param ltu The new LTU. **/ public void setLTU(LinearThresholdUnit ltu) { baseLTU = ltu; baseLTU.name = name + "$baseLTU"; } /** * Sets the labeler. * * @param l A labeling classifier. **/ public void setLabeler(Classifier l) { if (getClass().getName().indexOf("SparseNetworkLearner") != -1 && !l.getOutputType().equals("discrete")) { System.err.println( "LBJ WARNING: SparseNetworkLearner will only work with a " + "label classifier that returns discrete."); System.err.println( " The given label classifier, " + l.getClass().getName() + ", returns " + l.getOutputType() + "."); } super.setLabeler(l); } /** * Sets the extractor. * * @param e A feature extracting classifier. **/ public void setExtractor(Classifier e) { super.setExtractor(e); baseLTU.setExtractor(e); int N = network.size(); for (int i = 0; i < N; ++i) ((LinearThresholdUnit) network.get(i)).setExtractor(e); } /** * Each example is treated as a positive example for the linear threshold * unit associated with the label's value that is active for the example * and as a negative example for all other linear threshold units in the * network. * * @param exampleFeatures The example's array of feature indices. * @param exampleValues The example's array of feature values. * @param exampleLabels The example's label(s). * @param labelValues The labels' values. **/ public void learn(int[] exampleFeatures, double[] exampleValues, int[] exampleLabels, double[] labelValues) { int label = exampleLabels[0]; int N = network.size(); if (label >= N || network.get(label) == null) { conjunctiveLabels |= labelLexicon.lookupKey(label).isConjunctive(); LinearThresholdUnit ltu = (LinearThresholdUnit) baseLTU.clone(); ltu.initialize(numExamples, numFeatures); network.set(label, ltu); N = label + 1; } int[] l = new int[1]; for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null) continue; l[0] = (i == label) ? 1 : 0; ltu.learn(exampleFeatures, exampleValues, l, labelValues); } } /** Simply calls <code>doneLearning()</code> on every LTU in the network. */ public void doneLearning() { super.doneLearning(); int N = network.size(); for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null) continue; ltu.doneLearning(); } } /** Sets the number of examples and features. */ public void initialize(int ne, int nf) { numExamples = ne; numFeatures = nf; } /** Simply calls {@link LinearThresholdUnit#doneWithRound()} on every LTU in the network. */ public void doneWithRound() { super.doneWithRound(); int N = network.size(); for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null) continue; ltu.doneWithRound(); } } /** Clears the network. */ public void forget() { super.forget(); network = new OVector(); } /** * Returns scores for only those labels in the given collection. If the * given collection is empty, scores for all labels will be returned. If * there is no {@link LinearThresholdUnit} associated with a given label * from the collection, that label's score in the returned {@link ScoreSet} * will be set to <code>Double.NEGATIVE_INFINITY</code>. * * <p> The elements of <code>candidates</code> must all be * <code>String</code>s. * * @param example The example object. * @param candidates A list of the only labels the example may take. * @return Scores for only those labels in <code>candidates</code>. **/ public ScoreSet scores(Object example, Collection candidates) { Object[] exampleArray = getExampleArray(example, false); return scores((int[]) exampleArray[0], (double[]) exampleArray[1], candidates); } /** * Returns scores for only those labels in the given collection. If the * given collection is empty, scores for all labels will be returned. If * there is no {@link LinearThresholdUnit} associated with a given label * from the collection, that label's score in the returned {@link ScoreSet} * will be set to <code>Double.NEGATIVE_INFINITY</code>. * * <p> The elements of <code>candidates</code> must all be * <code>String</code>s. * * @param exampleFeatures The example's array of feature indices. * @param exampleValues The example's array of feature values. * @param candidates A list of the only labels the example may take. * @return Scores for only those labels in <code>candidates</code>. **/ public ScoreSet scores(int[] exampleFeatures, double[] exampleValues, Collection candidates) { ScoreSet result = new ScoreSet(); Iterator I = candidates.iterator(); if (I.hasNext()) { if (conjunctiveLabels) return conjunctiveScores(exampleFeatures, exampleValues, I); while (I.hasNext()) { String label = (String) I.next(); Feature f = new DiscretePrimitiveStringFeature( labeler.containingPackage, labeler.name, "", label, labeler.valueIndexOf(label), (short) labeler.allowableValues().length); if (labelLexicon.contains(f)) { int key = labelLexicon.lookup(f); LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(key); if (ltu != null) result.put(label.toString(), ltu.score(exampleFeatures, exampleValues) - ltu.getThreshold()); } } } else { int N = network.size(); for (int l = 0; l < N; ++l) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(l); if (ltu == null) continue; result.put(labelLexicon.lookupKey(l).getStringValue(), ltu.score(exampleFeatures, exampleValues) - ltu.getThreshold()); } } return result; } /** * This method is a surrogate for * {@link #scores(int[],double[],Collection)} when the labeler is known to * produce conjunctive features. It is necessary because when given a * string label from the collection, we will not know how to construct the * appropriate conjunctive feature key for lookup in the label lexicon. * So, we must go through each feature in the label lexicon and use * {@link LBJ2.classify.Feature#valueEquals(String)}. * * @param exampleFeatures The example's array of feature indices. * @param exampleValues The example's array of feature values. * @param I An iterator over the set of labels to choose * from. * @return The label chosen by this classifier or <code>null</code> if the * network did not contain any of the specified labels. **/ protected ScoreSet conjunctiveScores(int[] exampleFeatures, double[] exampleValues, Iterator I) { ScoreSet result = new ScoreSet(); int N = network.size(); while (I.hasNext()) { String label = (String) I.next(); for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null || !labelLexicon.lookupKey(i).valueEquals(label)) continue; double score = ltu.score(exampleFeatures, exampleValues); result.put(label.toString(), score); break; } } return result; } /** * Produces a set of scores indicating the degree to which each possible * discrete classification value is associated with the given example * object. These scores are just the scores of each LTU's positive * classification as produced by * <code>LinearThresholdUnit.scores(Object)</code>. * * @see LinearThresholdUnit#scores(Object) * @param exampleFeatures The example's array of feature indices. * @param exampleValues The example's array of feature values. * @return The set of scores produced by the LTUs **/ public ScoreSet scores(int[] exampleFeatures, double[] exampleValues) { ScoreSet result = new ScoreSet(); int N = network.size(); for (int l = 0; l < N; l++) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(l); if (ltu == null) continue; result.put(labelLexicon.lookupKey(l).getStringValue(), ltu.score(exampleFeatures, exampleValues) - ltu.getThreshold()); } return result; } /** * Returns the classification of the given example as a single feature * instead of a {@link FeatureVector}. * * @param f The features array. * @param v The values array. * @return The classification of the example as a feature. **/ public Feature featureValue(int[] f, double[] v) { double bestScore = Double.NEGATIVE_INFINITY; int bestValue = -1; int N = network.size(); for (int l = 0; l < N; l++) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(l); if (ltu == null) continue; double score = ltu.score(f, v); if (score > bestScore) { bestValue = l; bestScore = score; } } return bestValue == -1 ? null : predictions.get(bestValue); } /** * This implementation uses a winner-take-all comparison of the outputs * from the individual linear threshold units' score methods. * * @param exampleFeatures The example's array of feature indices. * @param exampleValues The example's array of feature values. * @return A single value with the winning linear threshold unit's * associated value. **/ public String discreteValue(int[] exampleFeatures, double[] exampleValues) { return featureValue(exampleFeatures, exampleValues).getStringValue(); } /** * This implementation uses a winner-take-all comparison of the outputs * from the individual linear threshold units' score methods. * * @param exampleFeatures The example's array of feature indices. * @param exampleValues The example's array of feature values. * @return A single feature with the winning linear threshold unit's * associated value. **/ public FeatureVector classify(int[] exampleFeatures, double[] exampleValues) { return new FeatureVector(featureValue(exampleFeatures, exampleValues)); } /** * Using this method, the winner-take-all competition is narrowed to * involve only those labels contained in the specified list. The list * must contain only <code>String</code>s. * * @param example The example object. * @param candidates A list of the only labels the example may take. * @return The prediction as a feature or <code>null</code> if the network * did not contain any of the specified labels. **/ public Feature valueOf(Object example, Collection candidates) { Object[] exampleArray = getExampleArray(example, false); return valueOf((int[]) exampleArray[0], (double[]) exampleArray[1], candidates); } /** * Using this method, the winner-take-all competition is narrowed to * involve only those labels contained in the specified list. The list * must contain only <code>String</code>s. * * @param exampleFeatures The example's array of feature indices. * @param exampleValues The example's array of feature values. * @param candidates A list of the only labels the example may take. * @return The prediction as a feature or <code>null</code> if the network * did not contain any of the specified labels. **/ public Feature valueOf(int[] exampleFeatures, double[] exampleValues, Collection candidates) { double bestScore = Double.NEGATIVE_INFINITY; int bestValue = -1; Iterator cI = candidates.iterator(); if (cI.hasNext()) { if (conjunctiveLabels) return conjunctiveValueOf(exampleFeatures, exampleValues, cI); while (cI.hasNext()) { double score = Double.NEGATIVE_INFINITY; String label = (String) cI.next(); Feature f = new DiscretePrimitiveStringFeature( labeler.containingPackage, labeler.name, "", label, labeler.valueIndexOf(label), (short) labeler.allowableValues().length); int key = -1; if (labelLexicon.contains(f)) { key = labelLexicon.lookup(f); LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(key); if (ltu != null) score = ltu.score(exampleFeatures, exampleValues); } if (score > bestScore) { bestValue = key; bestScore = score; } } } else { int N = network.size(); for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null) continue; double score = ltu.score(exampleFeatures, exampleValues); if (score > bestScore) { bestValue = i; bestScore = score; } } } return predictions.get(bestValue); } /** * This method is a surrogate for * {@link #valueOf(int[],double[],Collection)} when the labeler is known to * produce conjunctive features. It is necessary because when given a * string label from the collection, we will not know how to construct the * appropriate conjunctive feature key for lookup in the label lexicon. * So, we must go through each feature in the label lexicon and use * {@link LBJ2.classify.Feature#valueEquals(String)}. * * @param exampleFeatures The example's array of feature indices. * @param exampleValues The example's array of feature values. * @param I An iterator over the set of labels to choose * from. * @return The label chosen by this classifier or <code>null</code> if the * network did not contain any of the specified labels. **/ protected Feature conjunctiveValueOf( int[] exampleFeatures, double[] exampleValues, Iterator I) { double bestScore = Double.NEGATIVE_INFINITY; int bestValue = -1; int N = network.size(); while (I.hasNext()) { String label = (String) I.next(); for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null || !labelLexicon.lookupKey(i).valueEquals(label)) continue; double score = ltu.score(exampleFeatures, exampleValues); if (score > bestScore) { bestScore = score; bestValue = i; } break; } } return predictions.get(bestValue); } /** * Writes the algorithm's internal representation as text. * * @param out The output stream. **/ public void write(PrintStream out) { out.println(baseLTU.getClass().getName()); baseLTU.write(out); int N = network.size(); for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null) continue; out.println("label: " + labelLexicon.lookupKey(i).getStringValue()); ltu.setLexicon(lexicon); ltu.write(out); ltu.setLexicon(null); } out.println("End of SparseNetworkLearner"); } /** * Writes the learned function's internal representation in binary form. * * @param out The output stream. **/ public void write(ExceptionlessOutputStream out) { super.write(out); baseLTU.write(out); out.writeBoolean(conjunctiveLabels); int N = network.size(); out.writeInt(N); for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null) out.writeString(null); else ltu.write(out); } } /** * Reads the binary representation of a learner with this object's run-time * type, overwriting any and all learned or manually specified parameters * as well as the label lexicon but without modifying the feature lexicon. * * @param in The input stream. **/ public void read(ExceptionlessInputStream in) { super.read(in); baseLTU = (LinearThresholdUnit) Learner.readLearner(in); conjunctiveLabels = in.readBoolean(); int N = in.readInt(); network = new OVector(N); for (int i = 0; i < N; ++i) network.add(Learner.readLearner(in)); } /** Returns a deep clone of this learning algorithm. */ public Object clone() { SparseNetworkLearner clone = null; try { clone = (SparseNetworkLearner) super.clone(); } catch (Exception e) { System.err.println("Error cloning SparseNetworkLearner: " + e); e.printStackTrace(); System.exit(1); } clone.baseLTU = (LinearThresholdUnit) baseLTU.clone(); int N = network.size(); clone.network = new OVector(N); for (int i = 0; i < N; ++i) { LinearThresholdUnit ltu = (LinearThresholdUnit) network.get(i); if (ltu == null) clone.network.add(null); else clone.network.add(ltu.clone()); } return clone; } /** * Simply a container for all of {@link SparseNetworkLearner}'s * configurable parameters. Using instances of this class should make code * more readable and constructors less complicated. * * @author Nick Rizzolo **/ public static class Parameters extends Learner.Parameters { /** * The underlying algorithm used to learn each class separately as a * binary classifier; default * {@link SparseNetworkLearner#defaultBaseLTU}. **/ public LinearThresholdUnit baseLTU; /** Sets all the default values. */ public Parameters() { baseLTU = (LinearThresholdUnit) defaultBaseLTU.clone(); } /** * Sets the parameters from the parent's parameters object, giving * defaults to all parameters declared in this object. **/ public Parameters(Learner.Parameters p) { super(p); baseLTU = (LinearThresholdUnit) defaultBaseLTU.clone(); } /** Copy constructor. */ public Parameters(Parameters p) { super(p); baseLTU = p.baseLTU; } /** * Calls the appropriate <code>Learner.setParameters(Parameters)</code> * method for this <code>Parameters</code> object. * * @param l The learner whose parameters will be set. **/ public void setParameters(Learner l) { ((SparseNetworkLearner) l).setParameters(this); } /** * Creates a string representation of these parameters in which only * those parameters that differ from their default values are mentioned. **/ public String nonDefaultString() { String name = baseLTU.getClass().getName(); name = name.substring(name.lastIndexOf('.') + 1); return name + ": " + baseLTU.getParameters().nonDefaultString(); } } }