package edu.stanford.nlp.classify; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.util.Index; import java.util.Collection; import java.util.List; import java.util.Random; /** * @author Galen Andrew * @author Sarah Spikes (sdspikes@cs.stanford.edu) (Templatization) */ public class WeightedDataset<L, F> extends Dataset<L, F> { private static final long serialVersionUID = -5435125789127705430L; protected float[] weights; public WeightedDataset(Index<L> labelIndex, int[] labels, Index<F> featureIndex, int[][] data, int size, float[] weights) { super(labelIndex, labels, featureIndex, data, size); this.weights = weights; } public WeightedDataset() { this(10); } public WeightedDataset(int initSize) { super(initSize); weights = new float[initSize]; } private float[] trimToSize(float[] i) { float[] newI = new float[size]; synchronized (System.class) { System.arraycopy(i, 0, newI, 0, size); } return newI; } public float[] getWeights() { weights = trimToSize(weights); return weights; } @Override public float[] getFeatureCounts() { float[] counts = new float[featureIndex.size()]; for (int i = 0, m = size; i < m; i++) { for (int j = 0, n = data[i].length; j < n; j++) { counts[data[i][j]] += weights[i]; } } return counts; } @Override public void add(Datum<L, F> d) { add(d, 1.0f); } @Override public void add(Collection<F> features, L label) { add(features, label, 1.0f); } public void add(Datum<L, F> d, float weight) { add(d.asFeatures(), d.label(), weight); } @Override protected void ensureSize() { super.ensureSize(); if (weights.length == size) { float[] newWeights = new float[size * 2]; synchronized (System.class) { System.arraycopy(weights, 0, newWeights, 0, size); } weights = newWeights; } } public void add(Collection<F> features, L label, float weight) { ensureSize(); addLabel(label); addFeatures(features); weights[size++] = weight; } /** * Set the weight of datum i. * @param i The index of the datum to change the weight of. * @param weight The weight to set */ public void setWeight(int i, float weight) { weights[i] = weight; } /** * Randomizes (shuffles) the data array in place. * Needs to be redefined here because we need to randomize the weights as well. */ @Override public void randomize(long randomSeed) { Random rand = new Random(randomSeed); for(int j = size - 1; j > 0; j --){ int randIndex = rand.nextInt(j); int [] tmp = data[randIndex]; data[randIndex] = data[j]; data[j] = tmp; int tmpL = labels[randIndex]; labels[randIndex] = labels[j]; labels[j] = tmpL; float tmpW = weights[randIndex]; weights[randIndex] = weights[j]; weights[j] = tmpW; } } /** * Randomizes (shuffles) the data array in place. * Needs to be redefined here because we need to randomize the weights as well. */ @Override public <E> void shuffleWithSideInformation(long randomSeed, List<E> sideInformation) { if (size != sideInformation.size()) { throw new IllegalArgumentException("shuffleWithSideInformation: sideInformation not of same size as Dataset"); } Random rand = new Random(randomSeed); for(int j = size - 1; j > 0; j --){ int randIndex = rand.nextInt(j); int [] tmp = data[randIndex]; data[randIndex] = data[j]; data[j] = tmp; int tmpL = labels[randIndex]; labels[randIndex] = labels[j]; labels[j] = tmpL; float tmpW = weights[randIndex]; weights[randIndex] = weights[j]; weights[j] = tmpW; E tmpE = sideInformation.get(randIndex); sideInformation.set(randIndex, sideInformation.get(j)); sideInformation.set(j, tmpE); } } }