package edu.stanford.nlp.classify;
import java.util.Arrays;
import java.util.Collection;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Index;
/**
* A central place for utility functions used when training robust logistic models.
* @author jtibs
*/
public class LogisticUtils {
public static int[][] identityMatrix(int n) {
int[][] result = new int[n][1];
for (int i = 0; i < n; i++)
result[i][0] = i;
return result;
}
public static double[] flatten(double[][] input) {
int length = 0;
for (double[] array : input)
length += array.length;
double[] result = new double[length];
int count = 0;
for (double[] array : input) {
for (double value : array)
result[count++] = value;
}
return result;
}
public static void unflatten(double[] input, double[][] output) {
int count = 0;
for (int i = 0; i < output.length; i++) {
for (int j = 0; j < output[i].length; j++) {
output[i][j] = input[count++];
}
}
}
public static double dotProduct(double[] array, int[] indices, double[] values) {
double result = 0;
for (int i = 0; i < indices.length; i++) {
if (indices[i] == -1) continue;
result += array[indices[i]] * values[i];
}
return result;
}
public static double[][] initializeDataValues(int[][] data) {
double[][] result = new double[data.length][];
for (int i = 0; i < data.length; i++) {
result[i] = new double[data[i].length];
Arrays.fill(result[i], 1.0);
}
return result;
}
public static <T> int[] indicesOf(Collection<T> input, Index<T> index) {
int[] result = new int[input.size()];
int count = 0;
for (T element : input)
result[count++] = index.indexOf(element);
return result;
}
public static double[] convertToArray(Collection<Double> input) {
double[] result = new double[input.size()];
int count = 0;
for (double d : input) {
result[count++] = d;
}
return result;
}
public static double[] calculateSums(double[][] weights, int[] featureIndices,
double[] featureValues) {
int numClasses = weights.length + 1;
double[] result = new double[numClasses];
result[0] = 0.0;
for (int c = 1; c < numClasses; c++) {
result[c] = -dotProduct(weights[c - 1], featureIndices, featureValues);
}
double total = ArrayMath.logSum(result);
for (int c = 0; c < numClasses; c++) {
result[c] -= total;
}
return result;
}
public static double[] calculateSums(double[][] weights, int[] featureIndices,
double[] featureValues, double[] intercepts) {
int numClasses = weights.length + 1;
double[] result = new double[numClasses];
result[0] = 0.0;
for (int c = 1; c < numClasses; c++) {
result[c] = -dotProduct(weights[c - 1], featureIndices, featureValues) - intercepts[c - 1];
}
double total = ArrayMath.logSum(result);
for (int c = 0; c < numClasses; c++) {
result[c] -= total;
}
return result;
}
public static double[] calculateSigmoids(double[][] weights, int[] featureIndices,
double[] featureValues) {
return ArrayMath.exp(calculateSums(weights, featureIndices, featureValues));
}
public static double getValue(double[][] weights, LogPrior prior) {
double[] flatWeights = flatten(weights);
return prior.compute(flatWeights, new double[flatWeights.length]);
}
public static int sample(double[] sigmoids) {
double probability = Math.random();
System.out.println("sigmoids: " + Arrays.toString(sigmoids));
System.out.println("probability: " + probability);
double offset = 0.0;
for (int c = 0; c < sigmoids.length; c++) {
if (probability - offset <= sigmoids[c])
return c;
offset += sigmoids[c];
}
return sigmoids.length - 1; // should never be reached
}
public static void prettyPrint(double[][] gammas, double[][] thetas, double[][] zprobs) {
prettyPrint("GAMMAS", gammas);
prettyPrint("THETAS", thetas);
prettyPrint("ZPROBS", zprobs);
}
public static void prettyPrint(String name, double[][] matrix) {
prettyPrint(name, matrix, matrix.length);
}
public static void prettyPrint(String name, double[][] matrix, int maxCount) {
System.out.println(name + ": ");
for (double[] array : matrix) {
System.out.println(Arrays.toString(array));
if (maxCount-- < 0) break;
}
System.out.println();
}
}