package com.antbrains.crf;
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.io.compress.CompressionCodecFactory;
import com.google.gson.Gson;
import de.ruedigermoeller.serialization.FSTObjectInput;
import de.ruedigermoeller.serialization.FSTObjectOutput;
/*
*
* java port of crfsuite's sgd crf.
* see https://github.com/chokkan/crfsuite/blob/143be5187863091c25de98e0349947ee2f98cd4f/lib/crf/src/train_pegasos.c
*
SGD for L2-regularized MAP estimation.
The iterative algorithm is inspired by Pegasos:
Shai Shalev-Shwartz, Yoram Singer, and Nathan Srebro.
Pegasos: Primal Estimated sub-GrAdient SOlver for SVM.
In Proc. of ICML 2007, pp 807-814, 2007.
The calibration strategy is inspired by the implementation of sgd:
http://leon.bottou.org/projects/sgd
written by Léon Bottou.
The objective function to minimize is:
f(w) = (lambda/2) * ||w||^2 + (1/N) * \sum_i^N log P^i(y|x)
lambda = 2 * C / N
The original version of the Pegasos algorithm.
0) Initialization
t = t0
k = [the batch size]
1) Computing the learning rate (eta).
eta = 1 / (lambda * t)
2) Updating feature weights.
w = (1 - eta * lambda) w - (eta / k) \sum_i (oexp - mexp)
3) Projecting feature weights within an L2-ball.
w = min{1, (1/sqrt(lambda))/||w||} * w
4) Goto 1 until convergence.
A naive implementation requires O(K) computations for steps 2 and 3,
where K is the total number of features. This code implements the procedure
in an efficient way:
0) Initialization
norm2 = 0
decay = 1
proj = 1
1) Computing various factors
eta = 1 / (lambda * t)
decay *= (1 - eta * lambda)
scale = decay * proj
gain = (eta / k) / scale
2) Updating feature weights
Updating feature weights from observation expectation:
delta = gain * (1.0) * f(x,y)
norm2 += delta * (delta + w + w);
w += delta
Updating feature weights from model expectation:
delta = gain * (-P(y|x)) * f(x,y)
norm2 += delta * (delta + w + w);
w += delta
3) Projecting feature weights within an L2-ball
If 1.0 / lambda < norm2 * scale * scale:
proj = 1.0 / (sqrt(norm2 * lambda) * scale)
4) Goto 1 until convergence.
*/
public class SgdCrf {
/**
 * Resets the BOS, EOS, transition and attribute weight arrays of {@code model} to zero,
 * in place.
 */
private static void initWeights(TrainingWeights model) {
  double[] bos = model.getBosTransitionWeights();
  Arrays.fill(bos, 0.0);
  double[] eos = model.getEosTransitionWeights();
  Arrays.fill(eos, 0.0);
  double[] transitions = model.getTransitionWeights();
  Arrays.fill(transitions, 0.0);
  double[] attributes = model.getAttributeWeights();
  Arrays.fill(attributes, 0.0);
}
/**
 * Sums, for every item and label, the weights of the attributes that fire on that item.
 *
 * @param instance sequence whose attribute ids are read; negative ids are skipped
 * @param exp when {@code true} every summed score is replaced by its exponential
 * @param labelNum number of labels (row width of the result and of the weight table)
 * @param attributeWeights flat table indexed by {@code attributeId * labelNum + label}
 * @return flat array indexed by {@code item * labelNum + label}
 */
private static double[] computeStateScores(Instance instance, boolean exp, int labelNum,
    double[] attributeWeights) {
  final int positions = instance.length();
  final int attrsPerItem = instance.rowSize();
  final int[] attrs = instance.getAttrIds();
  final double[] scores = new double[positions * labelNum];
  for (int pos = 0; pos < positions; pos++) {
    final int scoreBase = pos * labelNum;
    final int attrBase = pos * attrsPerItem;
    for (int slot = 0; slot < attrsPerItem; slot++) {
      final int attr = attrs[attrBase + slot];
      if (attr < 0) {
        continue;
      }
      final int weightBase = attr * labelNum;
      for (int label = 0; label < labelNum; label++) {
        scores[scoreBase + label] += attributeWeights[weightBase + label];
      }
    }
  }
  if (exp) {
    // The array holds exactly positions * labelNum entries, so one flat pass covers them all.
    for (int k = 0; k < scores.length; k++) {
      scores[k] = Math.exp(scores[k]);
    }
  }
  return scores;
}
/**
 * Runs the scaled forward recursion over one instance.
 *
 * <p>Each row of the result is normalised to sum to one; the factor it was multiplied by is
 * stored in {@code scaleFactors[item]} (1 when the row sum is zero). The extra slot
 * {@code scaleFactors[itemNum]} receives the reciprocal of the last row combined with the EOS
 * weights.
 *
 * @param scaleFactors output array with at least {@code instance.length() + 1} slots
 * @return scaled forward scores indexed by {@code [item][label]}
 */
private static double[][] computeForwardScores(Instance instance,
    double[] expBosTransitionWeights, double[] expEosTransitionWeights,
    double[][] expTransitionWeights, double[] expStatesScores, double[] scaleFactors, int labelNum) {
  final int length = instance.length();
  final double[][] alpha = new double[length][];
  for (int pos = 0; pos < length; pos++) {
    alpha[pos] = new double[labelNum];
  }

  // Position 0: BOS weight times state score.
  double[] row = alpha[0];
  double total = 0;
  for (int label = 0; label < labelNum; label++) {
    final double value = expBosTransitionWeights[label] * expStatesScores[label];
    row[label] = value;
    total += value;
  }
  double factor = (total == 0) ? 1 : 1.0 / total;
  scaleFactors[0] = factor;
  for (int label = 0; label < labelNum; label++) {
    row[label] *= factor;
  }

  // Positions 1..length-1: propagate through the transition matrix, then rescale.
  for (int pos = 1; pos < length; pos++) {
    final double[] prev = alpha[pos - 1];
    final int stateBase = pos * labelNum;
    row = alpha[pos];
    total = 0;
    for (int to = 0; to < labelNum; to++) {
      double incoming = 0;
      for (int from = 0; from < labelNum; from++) {
        incoming += prev[from] * expTransitionWeights[from][to];
      }
      final double value = incoming * expStatesScores[stateBase + to];
      row[to] = value;
      total += value;
    }
    factor = (total == 0) ? 1 : 1.0 / total;
    scaleFactors[pos] = factor;
    for (int label = 0; label < labelNum; label++) {
      row[label] *= factor;
    }
  }

  // Final slot: last row combined with the EOS weights.
  final double[] last = alpha[length - 1];
  double tail = 0;
  for (int label = 0; label < labelNum; label++) {
    tail += last[label] * expEosTransitionWeights[label];
  }
  scaleFactors[length] = (tail == 0) ? 1 : 1.0 / tail;
  return alpha;
}
/**
 * Computes the log-probability of the instance's own label sequence.
 *
 * <p>The first and last terms are recovered from the scaled forward/backward rows by
 * subtracting the log of the corresponding scale factor; the inner terms add the raw
 * transition weight and the log of the (exponentiated) state score. {@code logNorm} is
 * subtracted at the end.
 *
 * @param statesScores exponentiated state scores, indexed by {@code item * labelNum + label}
 * @param transitionWeights raw transition weights, indexed by {@code from * labelNum + to}
 */
private static double computeLogProb(Instance instance, double[] statesScores,
    double[][] forwardScores, double[][] backwardScores, double[] scaleFactors, double logNorm,
    double[] transitionWeights, int labelNum) {
  final int length = instance.length();
  final int[] path = instance.labelIds();
  int prev = path[0];
  double logProb = Math.log(forwardScores[0][prev]) - Math.log(scaleFactors[0]);
  for (int pos = 1; pos < length; pos++) {
    final int cur = path[pos];
    logProb += transitionWeights[prev * labelNum + cur];
    logProb += Math.log(statesScores[pos * labelNum + cur]);
    prev = cur;
  }
  final int lastPos = length - 1;
  logProb += Math.log(backwardScores[lastPos][prev]) - Math.log(scaleFactors[lastPos]);
  return logProb - logNorm;
}
/**
 * Evaluates the regularised log-likelihood of the first {@code samplesNum} instances.
 *
 * <p>All exponentiated BOS, EOS and transition weights are taken to be 1 here; state scores
 * come from {@code attributeWeights}. The result is the summed log-probability minus
 * {@code 0.5 * lambda * ||attributeWeights||^2 * samplesNum}.
 */
private static double computeInitialLoglikelihood(TrainingDataSet dataSet, int samplesNum,
    double lambda, double[] attributeWeights, double[] transitionWeights) {
  final int labelNum = dataSet.getLabelNum();
  final double[] unitBos = new double[labelNum];
  final double[] unitEos = new double[labelNum];
  Arrays.fill(unitBos, 1.0);
  Arrays.fill(unitEos, 1.0);
  final double[][] unitTransitions = new double[labelNum][];
  for (int from = 0; from < labelNum; from++) {
    final double[] ones = new double[labelNum];
    Arrays.fill(ones, 1.0);
    unitTransitions[from] = ones;
  }

  double logLikelihood = 0;
  final List<Instance> samples = dataSet.getInstances();
  for (int s = 0; s < samplesNum; s++) {
    final Instance sample = samples.get(s);
    final double[] states = computeStateScores(sample, true, labelNum, attributeWeights);
    final double[] factors = new double[sample.length() + 1];
    final double[][] alpha = computeForwardScores(sample, unitBos, unitEos, unitTransitions,
        states, factors, labelNum);
    double logNorm = 0;
    for (int pos = 0; pos <= sample.length(); pos++) {
      logNorm -= Math.log(factors[pos]);
    }
    final double[][] beta = computeBackwardScores(sample, unitEos, unitTransitions, states,
        factors, labelNum);
    logLikelihood += computeLogProb(sample, states, alpha, beta, factors, logNorm,
        transitionWeights, labelNum);
  }

  double squaredNorm = 0;
  for (int attr = 0; attr < dataSet.getAttributeNum(); attr++) {
    final int base = attr * labelNum;
    for (int label = 0; label < labelNum; label++) {
      final double w = attributeWeights[base + label];
      squaredNorm += w * w;
    }
  }
  return logLikelihood - 0.5 * lambda * squaredNorm * samplesNum;
}
/**
 * Runs the scaled backward recursion over one instance, reusing the per-item scale factors
 * produced by the forward pass.
 *
 * @param scaleFactors factors indexed by item; entry {@code item} multiplies row {@code item}
 * @return scaled backward scores indexed by {@code [item][label]}
 */
private static double[][] computeBackwardScores(Instance instance,
    double[] expEosTransitionWeights, double[][] expTransitionScores, double[] expStatesScores,
    double[] scaleFactors, int labelNum) {
  final int length = instance.length();
  final double[][] beta = new double[length][];
  for (int pos = 0; pos < length; pos++) {
    beta[pos] = new double[labelNum];
  }

  // Last position: EOS weight times that position's scale factor.
  final int lastPos = length - 1;
  double[] row = beta[lastPos];
  double factor = scaleFactors[lastPos];
  for (int label = 0; label < labelNum; label++) {
    row[label] = expEosTransitionWeights[label] * factor;
  }

  // Remaining positions, right to left.
  for (int pos = lastPos - 1; pos >= 0; pos--) {
    final double[] next = beta[pos + 1];
    final int nextStateBase = (pos + 1) * labelNum;
    row = beta[pos];
    factor = scaleFactors[pos];
    for (int from = 0; from < labelNum; from++) {
      final double[] outgoing = expTransitionScores[from];
      double acc = 0;
      for (int to = 0; to < labelNum; to++) {
        acc += outgoing[to] * expStatesScores[nextStateBase + to] * next[to];
      }
      row[from] = acc * factor;
    }
  }
  return beta;
}
/**
 * Performs one SGD pass over the first {@code seqNum} instances with the learning-rate offset
 * {@code t0} and returns the summed log-probability observed along the way.
 *
 * <p>Each instance's log-probability is measured before the weights are updated for that
 * instance. The four weight arrays are modified in place.
 */
private static double calibrateSgd(List<Instance> trainInstances, int seqNum, double t0,
    double lambda, int labelNum, double[] bosTransitionWeights, double[] eosTransitionWeights,
    double[] transitionWeights, double[] attributeWeights) {
  final double[] expBos = new double[labelNum];
  final double[] expEos = new double[labelNum];
  final double[][] expTrans = new double[labelNum][];
  for (int from = 0; from < labelNum; from++) {
    expTrans[from] = new double[labelNum];
  }

  final double proj = 1.0;
  double decay = 1.0;
  double totalLogProb = 0;
  // The step counter doubles as the instance index: both start at 0 and advance together.
  for (int step = 0; step < seqNum; step++) {
    final Instance sample = trainInstances.get(step);
    final int length = sample.length();
    final double eta = 1.0 / (lambda * (t0 + step));
    decay *= (1.0 - eta * lambda);
    final double scale = decay * proj;
    final double gain = eta / scale;

    // Refresh the exponentiated transition tables from the current raw weights.
    for (int label = 0; label < labelNum; label++) {
      expBos[label] = Math.exp(bosTransitionWeights[label]);
      expEos[label] = Math.exp(eosTransitionWeights[label]);
    }
    for (int from = 0; from < labelNum; from++) {
      final double[] expRow = expTrans[from];
      final int base = from * labelNum;
      for (int to = 0; to < labelNum; to++) {
        expRow[to] = Math.exp(transitionWeights[base + to]);
      }
    }

    final double[] states = computeStateScores(sample, true, labelNum, attributeWeights);
    final double[] factors = new double[length + 1];
    final double[][] alpha = computeForwardScores(sample, expBos, expEos, expTrans, states,
        factors, labelNum);
    final double[][] beta = computeBackwardScores(sample, expEos, expTrans, states, factors,
        labelNum);
    double logNorm = 0;
    for (int pos = 0; pos <= length; pos++) {
      logNorm -= Math.log(factors[pos]);
    }
    totalLogProb += computeLogProb(sample, states, alpha, beta, factors, logNorm,
        transitionWeights, labelNum);
    updateFeatureWeights(sample, expTrans, states, alpha, beta, factors, gain, labelNum,
        attributeWeights, bosTransitionWeights, eosTransitionWeights, transitionWeights);
  }
  return totalLogProb;
}
/**
 * Adds {@code diff} to {@code weights[index]} and returns the resulting change of the squared
 * value, i.e. {@code (old + diff)^2 - old^2} written as {@code diff * (diff + 2 * old)}.
 */
private static double updateWeight(double[] weights, int index, double diff) {
  final double before = weights[index];
  final double squaredDelta = diff * (diff + before * 2);
  weights[index] = before + diff;
  return squaredDelta;
}
/**
 * Applies the model-expectation step {@code -gain * prob} to {@code weights[index]} and, when
 * {@code isTarget} is set, the observation step {@code +gain} as well.
 *
 * @return the summed change of the squared weight value caused by the applied steps
 */
private static double updateWeight(double[] weights, int index, double gain, double prob,
    boolean isTarget) {
  final double modelPart = updateWeight(weights, index, -gain * prob);
  if (!isTarget) {
    return modelPart;
  }
  return modelPart + updateWeight(weights, index, gain);
}
/**
 * Applies one stochastic gradient step for a single instance to the BOS, EOS, transition and
 * attribute weights and returns the resulting change of the squared norm of those weights.
 *
 * <p>For each item the per-label probability is
 * {@code forward * backward * scaleFactors[itemNum] / scaleFactors[item]}. Every weight that
 * fires on the item is moved by {@code -gain * probability} and, when the label equals the
 * instance's own label, additionally by {@code +gain}.
 *
 * <p>Fix over the previous version: for a one-item instance the "first item" and "last item"
 * attribute passes both covered item 0, so its attribute weights received the step twice. The
 * last-item attribute pass now runs only when the instance has more than one item.
 *
 * @param expTransitionScores exponentiated transition weights, {@code [from][to]}
 * @param statesScores exponentiated state scores, {@code item * labelNum + label}
 * @param scaleFactors factors from the forward pass; slot {@code itemNum} is the final one
 * @param gain step size already divided by the current weight scale
 * @return sum of the values returned by the individual weight updates
 */
private static double updateFeatureWeights(Instance instance, double[][] expTransitionScores,
    double[] statesScores, double[][] forwardScores, double[][] backwardScores,
    double[] scaleFactors, double gain, int labelNum, double[] attributeWeights,
    double[] bosTransitionWeights, double[] eosTransitionWeights, double[] transitionWeights) {
  int itemNum = instance.length();
  int rowSize = instance.rowSize();
  int[] attrIds = instance.getAttrIds();
  int[] labelIds = instance.labelIds();
  double norm2diff = 0;

  // First item: its probabilities drive both the BOS weights and the item's attributes.
  double[] fwd = forwardScores[0];
  double[] bwd = backwardScores[0];
  double coeff = scaleFactors[itemNum] / scaleFactors[0];
  double[] probs = new double[labelNum];
  for (int i = 0; i < labelNum; i++) {
    probs[i] = fwd[i] * bwd[i] * coeff;
  }
  for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
    norm2diff += updateWeight(bosTransitionWeights, labelIndex, gain, probs[labelIndex],
        labelIndex == labelIds[0]);
  }
  for (int i = 0; i < rowSize; i++) {
    int attributeIndex = attrIds[i];
    if (attributeIndex < 0) {
      continue;
    }
    for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
      norm2diff += updateWeight(attributeWeights, attributeIndex * labelNum + labelIndex, gain,
          probs[labelIndex], labelIndex == labelIds[0]);
    }
  }

  // Last item: its probabilities drive the EOS weights.
  fwd = forwardScores[itemNum - 1];
  bwd = backwardScores[itemNum - 1];
  coeff = scaleFactors[itemNum] / scaleFactors[itemNum - 1];
  for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
    probs[labelIndex] = fwd[labelIndex] * bwd[labelIndex] * coeff;
  }
  for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
    norm2diff += updateWeight(eosTransitionWeights, labelIndex, gain, probs[labelIndex],
        labelIndex == labelIds[itemNum - 1]);
  }
  // When itemNum == 1 the last item is the first item, whose attributes were already updated
  // above; repeating the pass would apply the attribute step twice.
  if (itemNum > 1) {
    for (int i = (itemNum - 1) * rowSize; i < attrIds.length; i++) {
      int attributeIndex = attrIds[i];
      if (attributeIndex < 0) {
        continue;
      }
      for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
        norm2diff += updateWeight(attributeWeights, attributeIndex * labelNum + labelIndex,
            gain, probs[labelIndex], labelIndex == labelIds[itemNum - 1]);
      }
    }
  }

  // Interior items: attribute weights only.
  for (int itemIndex = 1; itemIndex < itemNum - 1; itemIndex++) {
    fwd = forwardScores[itemIndex];
    bwd = backwardScores[itemIndex];
    coeff = scaleFactors[itemNum] / scaleFactors[itemIndex];
    // -1 marks "not computed yet"; the probabilities are filled in on first use so that items
    // without any active attribute cost nothing.
    for (int i = 0; i < labelNum; i++) {
      probs[i] = -1;
    }
    for (int i = 0; i < rowSize; i++) {
      int attributeIndex = attrIds[itemIndex * rowSize + i];
      if (attributeIndex < 0) {
        continue;
      }
      for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
        if (probs[labelIndex] == -1) {
          probs[labelIndex] = fwd[labelIndex] * bwd[labelIndex] * coeff;
        }
        norm2diff += updateWeight(attributeWeights, attributeIndex * labelNum + labelIndex,
            gain, probs[labelIndex], labelIndex == labelIds[itemIndex]);
      }
    }
  }

  // Transitions between every pair of adjacent items.
  for (int itemIndex = 0; itemIndex < itemNum - 1; itemIndex++) {
    fwd = forwardScores[itemIndex];
    bwd = backwardScores[itemIndex + 1];
    coeff = scaleFactors[itemNum];
    for (int fromLabelIndex = 0; fromLabelIndex < labelNum; fromLabelIndex++) {
      double[] edge = expTransitionScores[fromLabelIndex];
      for (int toLabelIndex = 0; toLabelIndex < labelNum; toLabelIndex++) {
        norm2diff += updateWeight(transitionWeights, fromLabelIndex * labelNum + toLabelIndex,
            gain, fwd[fromLabelIndex] * edge[toLabelIndex]
                * statesScores[(itemIndex + 1) * labelNum + toLabelIndex] * bwd[toLabelIndex]
                * coeff, fromLabelIndex == labelIds[itemIndex]
                && toLabelIndex == labelIds[itemIndex + 1]);
      }
    }
  }
  return norm2diff;
}
/**
 * Searches for a learning rate by running trial SGD passes over a sample of the data set and
 * returns the corresponding {@code t0 = 1 / (lambda * eta)}.
 *
 * <p>Starting from {@code param.getEta()}, eta is multiplied by {@code param.getRate()} while
 * trials beat the initial log-likelihood; after the first failing trial the search restarts at
 * {@code initialEta / rate} and keeps dividing. The loop ends once the requested number of
 * successful candidates has been seen and the search has switched to the decreasing phase.
 * The data set's instance list is shuffled and {@code weights} are overwritten.
 */
private static double calibrate(TrainingDataSet dataSet, double lambda, TrainingParams param,
    TrainingWeights weights) {
  final List<Instance> samples = dataSet.getInstances();
  final int samplesNum = Math.min(samples.size(), param.getSamplesNum());
  System.out.println(String.format("sgd.calibration.eta: %f\n", param.getEta()));
  System.out.println(String.format("sgd.calibration.rate: %f\n", param.getRate()));
  System.out.println(String.format("sgd.calibration.samples: %d\n", samplesNum));
  System.out.println(String.format("sgd.calibration.candidates: %d\n", param.getCandidatesNum()));
  Collections.shuffle(dataSet.getInstances());

  initWeights(weights);
  final double baseline = computeInitialLoglikelihood(dataSet, samplesNum, lambda,
      weights.getAttributeWeights(), weights.getTransitionWeights());
  System.out.println(String.format("Initial Log-likelihood: %f\n", baseline));

  int remaining = param.getCandidatesNum();
  final double startEta = param.getEta();
  double eta = startEta;
  double bestEta = startEta;
  double bestLogp = -Double.MAX_VALUE;
  boolean shrinking = false;
  System.out.println("calibrating");
  for (int trial = 1; remaining > 0 || !shrinking; trial++) {
    System.out.println(String.format("Trial #%d (eta = %f): ", trial, eta));
    initWeights(weights);
    final double logp = calibrateSgd(samples, samplesNum, 1.0 / (lambda * eta), lambda,
        dataSet.getLabelNum(), weights.getBosTransitionWeights(),
        weights.getEosTransitionWeights(), weights.getTransitionWeights(),
        weights.getAttributeWeights());
    final boolean improved = !Double.isInfinite(logp) && logp > baseline;
    if (improved) {
      System.out.println(String.format("%f\n", logp));
      remaining--;
      if (logp > bestLogp) {
        bestLogp = logp;
        bestEta = eta;
      }
    } else {
      System.out.println(String.format("%f (worse)\n", logp));
    }

    if (shrinking) {
      System.out.println(String.format("etaValue(%f)/=rate(%f)", eta, param.getRate()));
      eta /= param.getRate();
      System.out.println("etaValue=" + eta);
    } else if (improved) {
      System.out.println(String.format("etaValue(%f)*=rate(%f)", eta, param.getRate()));
      eta *= param.getRate();
      System.out.println("etaValue=" + eta);
    } else {
      // First failure while growing: restart just below the initial value and shrink.
      shrinking = true;
      System.out.println(String.format("initEtaValue(%f)/=rate(%f)", startEta,
          param.getRate()));
      eta = startEta / param.getRate();
      System.out.println("etaValue=" + eta);
    }
  }
  eta = bestEta;
  System.out.println(String.format("Best learning rate (eta): %f\n", eta));
  return 1.0 / (lambda * eta);
}
/**
 * Serialises the training parameters followed by the weights into {@code fileName}.
 *
 * <p>The file stream is opened and closed separately from the FST writer so that it is
 * released even if constructing the writer fails, and regardless of whether the writer's
 * {@code close()} closes the stream it wraps.
 *
 * @throws IOException if the file cannot be opened or written
 */
public static void saveModel(TrainingParams params, TrainingWeights weights, String fileName)
    throws IOException {
  FileOutputStream fileOut = new FileOutputStream(fileName);
  try {
    FSTObjectOutput foo = new FSTObjectOutput(fileOut);
    try {
      foo.writeObject(params, TrainingParams.class);
      foo.writeObject(weights, TrainingWeights.class);
    } finally {
      // Closed first so buffered data reaches the file before the file stream is closed.
      foo.close();
    }
  } finally {
    // Closing an already-closed FileOutputStream has no effect.
    fileOut.close();
  }
}
/**
 * Reads a model from {@code is}: first the training parameters, then the weights, in the order
 * {@code saveModel} writes them. The FST reader wrapping {@code is} is closed before returning.
 */
public static CrfModel loadModel(InputStream is) throws Exception {
  // If the constructor throws there is nothing to close yet, so it sits outside the try.
  final FSTObjectInput reader = new FSTObjectInput(is);
  try {
    final Object params = reader.readObject(TrainingParams.class);
    final Object weights = reader.readObject(TrainingWeights.class);
    return new CrfModel((TrainingParams) params, (TrainingWeights) weights);
  } finally {
    reader.close();
  }
}
/**
 * Reads a model from the file {@code fileName}.
 *
 * <p>Delegates to {@link #loadModel(InputStream)} instead of repeating its body, and closes
 * the file stream itself so the descriptor is released even when the FST reader cannot be
 * constructed.
 *
 * @throws Exception if the file cannot be opened or its content cannot be deserialised
 */
public static CrfModel loadModel(String fileName) throws Exception {
  InputStream fileIn = new FileInputStream(fileName);
  try {
    return loadModel(fileIn);
  } finally {
    // Closing an already-closed FileInputStream has no effect.
    fileIn.close();
  }
}
/**
 * Parses one input line of the form {@code <key> TAB <json>} and deserialises the JSON part
 * (everything after the first tab) into an {@link Instance}.
 */
private static Instance readInstance(Gson gson, String line) {
  final String[] keyAndJson = line.split("\t", 2);
  final String json = keyAndJson[1];
  return gson.fromJson(json, Instance.class);
}
/**
 * Draws up to {@code K} instances uniformly at random from all lines of the files directly
 * under {@code path} (reservoir sampling) and reports the total line count.
 *
 * <p>Sub-directories and files whose name starts with {@code "_"} are skipped. Files with a
 * recognised compression codec are decompressed on the fly. Only lines that end up in the
 * reservoir are parsed.
 *
 * <p>Changes over the previous version: the raw HDFS stream is closed when wrapping it in a
 * codec or reader fails (e.g. unsupported charset), and {@code K == 0} no longer reaches
 * {@code nextLong(0, 0)}, which rejects an empty range.
 *
 * @param K reservoir size
 * @param totalSample out-parameter; slot 0 receives the number of lines read, so it must have
 *        length of at least 1
 * @return the sampled instances (fewer than {@code K} if the input has fewer lines)
 * @throws IOException if listing or reading a file fails
 */
public static List<Instance> buildSamples4Calibrate(FileSystem hdfsFs, String path, int K,
    String charset, long[] totalSample) throws IOException {
  Instance[] r = new Instance[K];
  Path file = new Path(path);
  FileStatus[] status = hdfsFs.listStatus(file);
  RandomDataGenerator rndgen = new RandomDataGenerator();
  CompressionCodecFactory factory = new CompressionCodecFactory(hdfsFs.getConf());
  Gson gson = new Gson();
  long lineNum = 0;
  for (FileStatus stat : status) {
    if (stat.isDir()) {
      System.out.println("ignore subdir: " + stat.getPath().toString());
      continue;
    }
    Path f = stat.getPath();
    if (f.getName().startsWith("_")) {
      System.out.println("ignore hidden file: " + f.toString());
      continue;
    }
    System.out.println("process: " + f.toString());
    InputStream stream = hdfsFs.open(f);
    BufferedReader br = null;
    try {
      // check if we have a compression codec we need to use
      CompressionCodec codec = factory.getCodec(f);
      if (codec != null) {
        stream = codec.createInputStream(stream);
      }
      br = new BufferedReader(new InputStreamReader(stream, charset));
      String line;
      while ((line = br.readLine()) != null) {
        // Reservoir sampling
        if (lineNum < K) {
          r[(int) lineNum] = readInstance(gson, line);
        } else if (K > 0) {
          // nextLong draws from the closed range [0, lineNum]; here lineNum >= K >= 1.
          long j = rndgen.nextLong(0, lineNum);
          if (j < K) {
            r[(int) j] = readInstance(gson, line);
          }
        }
        lineNum++;
        if (lineNum % 100000 == 0) {
          System.out.println(new java.util.Date() + " buildSamples4Calibrate: " + lineNum);
        }
      }
    } finally {
      if (br != null) {
        br.close();
      } else {
        // The reader was never created, so the underlying stream must be closed directly.
        stream.close();
      }
    }
  }
  totalSample[0] = lineNum;
  List<Instance> instances = new ArrayList<Instance>(r.length);
  for (Instance instance : r) {
    if (instance != null) {
      instances.add(instance);
    }
  }
  return instances;
}
/**
 * Scans every line of the files directly under {@code path} once, building the attribute and
 * label dictionaries and keeping a uniform random sample of up to {@code K} instances
 * (reservoir sampling).
 *
 * <p>Sub-directories and files whose name starts with {@code "."} are skipped. Every line is
 * split on tabs and turned into an instance via {@code buildInstance}.
 *
 * <p>Fixes over the previous version:
 * <ul>
 * <li>The replacement slot was {@code rnd.nextLong() % lineNum}, which is negative for about
 * half of all draws; a negative value passed the {@code j < K} test and then indexed the
 * reservoir with a negative index. The draw is now non-negative and covers the closed range
 * {@code [0, lineNum]} that reservoir sampling requires.</li>
 * <li>{@code K == 0} no longer divides by zero.</li>
 * <li>Unused reservoir slots (input shorter than {@code K}) are no longer added to the result
 * as {@code null} entries.</li>
 * <li>The HDFS stream is closed when the reader cannot be created (e.g. unsupported
 * charset).</li>
 * </ul>
 *
 * @param K reservoir size
 * @return result holding the dictionaries, the sampled instances and the total line count
 * @throws IOException if listing or reading a file fails
 */
public static OnePassResult firstPassScan(FileSystem hdfsFs, String path, TagConvertor tc,
    Template template, int K, String charset, FeatureDictEnum dictType) throws IOException {
  OnePassResult opr = new OnePassResult(dictType);
  FeatureDict attributeDict = opr.getAttributes();
  TObjectIntHashMap<String> labelDict = opr.getLabels();
  Path file = new Path(path);
  FileStatus[] status = hdfsFs.listStatus(file);
  Random rnd = new Random();
  Instance[] r = new Instance[K];
  long lineNum = 0;
  for (FileStatus stat : status) {
    if (stat.isDir()) {
      System.out.println("ignore subdir: " + stat.getPath().toString());
      continue;
    }
    Path f = stat.getPath();
    if (f.getName().startsWith(".")) {
      System.out.println("ignore hidden file: " + f.toString());
      continue;
    }
    System.out.println("process: " + f.toString());
    InputStream stream = hdfsFs.open(f);
    BufferedReader br = null;
    try {
      br = new BufferedReader(new InputStreamReader(stream, charset));
      String line;
      while ((line = br.readLine()) != null) {
        String[] tokens = line.split("\t");
        Instance instance = buildInstance(tokens, tc, attributeDict, labelDict, template,
            true, true);
        // Reservoir sampling
        if (lineNum < K) {
          r[(int) lineNum] = instance;
        } else if (K > 0) {
          // ">>> 1" clears the sign bit so the remainder is never negative.
          long j = (rnd.nextLong() >>> 1) % (lineNum + 1);
          if (j < K) {
            r[(int) j] = instance;
          }
        }
        lineNum++;
        if (lineNum % 100000 == 0) {
          System.out.println(new java.util.Date() + " firstPass: " + lineNum);
        }
      }
    } finally {
      if (br != null) {
        br.close();
      } else {
        // The reader was never created, so the underlying stream must be closed directly.
        stream.close();
      }
    }
  }
  opr.setTotalNumber(lineNum);
  ArrayList<Instance> instances = opr.getInstances();
  instances.ensureCapacity(K);
  for (Instance instance : r) {
    if (instance != null) {
      instances.add(instance);
    }
  }
  return opr;
}
/**
 * Runs {@code firstPassScan} over {@code path} and serialises its attribute dictionary to
 * {@code featureFile} and its sampled instances to {@code instanceFile}.
 *
 * <p>Each output uses its own file stream and writer. Previously one writer variable was
 * reused, so a failure while opening the second file closed the already-closed first writer
 * again, and a file stream leaked whenever constructing a writer failed.
 *
 * @throws IOException if scanning the input or writing either file fails
 */
public static void genFeatureDictAndInstances(FileSystem hdfsFs, String path, int validateNum,
    Template template, int iterationNum, TrainingParams param, TagConvertor tc, String charset,
    TrainingWeights weights, FeatureDictEnum dictType, String featureFile, String instanceFile)
    throws IOException {
  int sampleNum = param.getSamplesNum();
  OnePassResult opr = firstPassScan(hdfsFs, path, tc, template, sampleNum, charset, dictType);
  FeatureDict attributeDict = opr.getAttributes();

  System.out.println("save Dict to " + featureFile);
  FileOutputStream dictFile = new FileOutputStream(featureFile);
  try {
    FSTObjectOutput dictOut = new FSTObjectOutput(dictFile);
    try {
      dictOut.writeObject(attributeDict);
    } finally {
      dictOut.close();
    }
  } finally {
    // Closing an already-closed FileOutputStream has no effect.
    dictFile.close();
  }

  System.out.println("save instance to " + instanceFile);
  List<Instance> instances = opr.getInstances();
  FileOutputStream instFile = new FileOutputStream(instanceFile);
  try {
    FSTObjectOutput instOut = new FSTObjectOutput(instFile);
    try {
      instOut.writeObject(instances);
    } finally {
      instOut.close();
    }
  } finally {
    instFile.close();
  }
}
/**
 * Samples calibration instances from HDFS, derives {@code lambda} from the total number of
 * instances found, runs the learning-rate calibration and prints the resulting {@code t0}.
 *
 * <p>As visible here, this overload stops after calibration; it does not run training epochs.
 *
 * <p>Fix over the previous version: the out-parameter passed to
 * {@code buildSamples4Calibrate} was a zero-length array, but that method stores the total
 * count in slot 0 and this method reads slot 0, so every call failed with
 * {@code ArrayIndexOutOfBoundsException}. The array now has one slot. Locals that were
 * allocated but never used have been removed.
 *
 * @throws IOException if reading the sample files fails
 */
public static void train(FileSystem hdfsFs, String path, Template template, int iterationNum,
    TrainingParams param, TagConvertor tc, String charset, TrainingWeights weights,
    FeatureDict dict) throws IOException {
  int sampleNum = param.getSamplesNum();
  System.out.println("get instances for calibrate, number=" + sampleNum);
  // Label ids are assigned in the iteration order of the tag list.
  TObjectIntHashMap<String> labelDict = new TObjectIntHashMap<String>();
  for (String tag : tc.getTags()) {
    labelDict.put(tag, labelDict.size());
  }
  // One-slot out-parameter: slot 0 receives the total number of lines read.
  long[] totalInstances = new long[1];
  List<Instance> calibrateInstances = buildSamples4Calibrate(hdfsFs, path, sampleNum, charset,
      totalInstances);
  System.out.println("totalInstances: " + totalInstances[0]);
  TrainingDataSet ds = new TrainingDataSet();
  ds.setLabelNum(labelDict.size());
  ds.setInstances(calibrateInstances);
  System.out.println("start sgd");
  double lambda = 1.0 / (param.getSigma() * param.getSigma() * totalInstances[0]);
  System.out.println("calibrating");
  double t0 = calibrate(ds, lambda, param, weights);
  System.out.println(t0);
}
/**
 * Trains the CRF weights with SGD for {@code iterationNum} epochs.
 *
 * <p>The instance list is shuffled; its first {@code validateNum} entries form the validation
 * set and the rest the training set. {@code t0} is taken from {@code param.getT0()} or, when
 * that is 0, obtained by calibration. After every epoch the validation set (if non-empty) is
 * evaluated, and after the last epoch all instances are evaluated; results are reported
 * through {@code tp}.
 *
 * <p>{@code norm2} tracks the squared norm of the weights as stored in the arrays, and the
 * projection test uses {@code norm2 * scale^2 * lambda}.
 *
 * <p>Fix over the previous version: when {@code scale} dropped below 1e-20 the stored weights
 * were multiplied by {@code scale} and decay/proj were reset to 1, but {@code norm2} kept its
 * pre-rescale value. Every later projection test then used a norm that was too large by a
 * factor of {@code 1 / scale^2}. {@code norm2} is now rescaled together with the weights.
 * Unused locals and commented-out debugging code have been removed.
 *
 * @param validateNum number of leading (post-shuffle) instances held out for validation
 * @param weights weight arrays, reset to zero and then updated in place
 * @param tp progress callback
 */
public static void train(TrainingDataSet dataSet, int validateNum, int iterationNum,
    TrainingParams param, TrainingWeights weights, TrainingProgress tp) {
  List<Instance> instances = dataSet.getInstances();
  int labelNum = dataSet.getLabelNum();
  Collections.shuffle(instances);
  List<Instance> validateInstances = instances.subList(0, validateNum);
  List<Instance> trainInstances = instances.subList(validateNum, instances.size());
  tp.startTraining();
  boolean validate = !validateInstances.isEmpty();
  int t = 0;
  double decay = 1.0, proj = 1.0;
  int trainInstanceNum = trainInstances.size();
  double lambda = 1.0 / (param.getSigma() * param.getSigma() * trainInstanceNum);
  double[] expBosTransitionWeights = new double[labelNum];
  double[] expEosTransitionWeights = new double[labelNum];
  double[][] expTransitionWeighs = new double[labelNum][];
  for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
    expTransitionWeighs[labelIndex] = new double[labelNum];
  }
  double t0 = 0;
  if (param.getT0() == 0) {
    t0 = calibrate(dataSet, lambda, param, weights);
  } else {
    t0 = param.getT0();
  }
  double norm2 = 0;
  initWeights(weights);
  double[] bosTransitionWeights = weights.getBosTransitionWeights();
  double[] eosTransitionWeights = weights.getEosTransitionWeights();
  double[] transitionWeights = weights.getTransitionWeights();
  double[] attributeWeights = weights.getAttributeWeights();
  for (int epoch = 1; epoch <= iterationNum; epoch++) {
    tp.doIter(epoch);
    Collections.shuffle(trainInstances);
    double scale = 0;
    for (Instance trainInstance : trainInstances) {
      double eta = 1.0 / (lambda * (t0 + t));
      decay *= (1.0 - eta * lambda);
      scale = decay * proj;
      double gain = eta / scale;
      for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
        expBosTransitionWeights[labelIndex] = Math.exp(bosTransitionWeights[labelIndex]);
        expEosTransitionWeights[labelIndex] = Math.exp(eosTransitionWeights[labelIndex]);
      }
      for (int fromLabelIndex = 0; fromLabelIndex < labelNum; fromLabelIndex++) {
        for (int toLabelIndex = 0; toLabelIndex < labelNum; toLabelIndex++) {
          expTransitionWeighs[fromLabelIndex][toLabelIndex] = Math
              .exp(transitionWeights[fromLabelIndex * labelNum + toLabelIndex]);
        }
      }
      int itemNum = trainInstance.length();
      double[] statesScores = computeStateScores(trainInstance, true, labelNum, attributeWeights);
      double[] scaleFactors = new double[itemNum + 1];
      double[][] forwardScores = computeForwardScores(trainInstance, expBosTransitionWeights,
          expEosTransitionWeights, expTransitionWeighs, statesScores, scaleFactors, labelNum);
      double[][] backwardScores = computeBackwardScores(trainInstance, expEosTransitionWeights,
          expTransitionWeighs, statesScores, scaleFactors, labelNum);
      norm2 += updateFeatureWeights(trainInstance, expTransitionWeighs, statesScores,
          forwardScores, backwardScores, scaleFactors, gain, labelNum, attributeWeights,
          bosTransitionWeights, eosTransitionWeights, transitionWeights);
      // Project back when the scaled squared norm exceeds 1 / lambda.
      double boundary = norm2 * scale * scale * lambda;
      if (boundary > 1.0) {
        proj = 1.0 / Math.sqrt(boundary);
      }
      ++t;
    }
    if (scale < 1e-20) {
      // Fold the accumulated scale into the stored weights before it underflows.
      for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
        bosTransitionWeights[labelIndex] *= scale;
        eosTransitionWeights[labelIndex] *= scale;
      }
      for (int fromLabelIndex = 0; fromLabelIndex < labelNum; fromLabelIndex++) {
        for (int toLabelIndex = 0; toLabelIndex < labelNum; toLabelIndex++) {
          transitionWeights[fromLabelIndex * labelNum + toLabelIndex] *= scale;
        }
      }
      for (int idx = 0; idx < attributeWeights.length; idx++) {
        attributeWeights[idx] *= scale;
      }
      // Every stored weight shrank by `scale`, so their squared norm shrinks by scale^2.
      norm2 *= scale * scale;
      decay = 1.0;
      proj = 1.0;
    }
    if (validate) {
      EvaluationResult statsOnValidateData = evaluate(validateInstances, weights);
      tp.doValidate(statsOnValidateData.toString());
    }
  }
  // Final report on the full instance list (validation and training parts together).
  EvaluationResult statsOnValidateData = evaluate(instances, weights);
  tp.doValidate(statsOnValidateData.toString());
}
/**
 * Tags every instance with the current weights and tallies item-level and sequence-level
 * accuracy plus a per-label confusion count.
 *
 * <p>For each item and each label, "gold" is the label stored in the instance and "predicted"
 * is the label returned by {@code tagId}:
 * <ul>
 * <li>TP: predicted and gold</li>
 * <li>FP: predicted but not gold</li>
 * <li>FN: gold but not predicted</li>
 * <li>TN: neither</li>
 * </ul>
 *
 * <p>Fix over the previous version: the FP and FN counters were swapped (a missed gold label
 * was counted as a false positive and a spurious prediction as a false negative).
 * NOTE(review): {@code EvaluationResult} is not visible here; confirm it does not compensate
 * for the old swap when deriving precision and recall.
 *
 * <p>{@code null} instances are skipped; {@code tagId} already tolerates them, but the loop
 * below would otherwise dereference the instance.
 */
private static EvaluationResult evaluate(List<Instance> instances, TrainingWeights weights) {
  String[] labelTexts = weights.getLabelTexts();
  EvaluationResult evaluation = new EvaluationResult(labelTexts);
  for (Instance instance : instances) {
    if (instance == null) {
      continue;
    }
    int[] predictedIds = tagId(instance, weights);
    int[] goldIds = instance.labelIds();
    evaluation.totalItemCount += predictedIds.length;
    boolean hasError = false;
    for (int itemIndex = 0; itemIndex < instance.length(); itemIndex++) {
      int predicted = predictedIds[itemIndex];
      int gold = goldIds[itemIndex];
      for (int labelIndex = 0; labelIndex < labelTexts.length; labelIndex++) {
        int[] counts = evaluation.labelIndex2count[labelIndex];
        boolean isGold = gold == labelIndex;
        if (predicted == labelIndex) {
          if (isGold) {
            counts[EvaluationResult.TP_INDEX]++;
          } else {
            counts[EvaluationResult.FP_INDEX]++;
          }
        } else {
          if (isGold) {
            counts[EvaluationResult.FN_INDEX]++;
          } else {
            counts[EvaluationResult.TN_INDEX]++;
          }
        }
      }
      if (gold == predicted) {
        evaluation.correctItemCount++;
      } else {
        hasError = true;
      }
    }
    evaluation.totalSeqCount++;
    if (!hasError) {
      evaluation.correctSeqCount++;
    }
  }
  return evaluation;
}
/**
 * Builds an instance from {@code tokens} using the dictionaries and template stored in the
 * model's weights, then scores the instance's label sequence.
 */
public static double getScore(String[] tokens, TagConvertor tc, CrfModel model){
  final TrainingWeights modelWeights = model.weights;
  final Instance built = buildInstance(tokens, tc, modelWeights.getAttributeDict(),
      modelWeights.getLabelDict(), modelWeights.getTemplate(), false, false);
  return getScore(modelWeights, built);
}
/**
 * Returns the unnormalised score of the instance's own label sequence: BOS weight and state
 * score of the first item, plus transition weight and state score for each later item, plus
 * the EOS weight of the last label.
 */
public static double getScore(TrainingWeights weights,Instance instance){
  final int labelNum = weights.getLabelTexts().length;
  final int[] path = instance.labelIds();
  final double[] state = computeStateScores(instance, false, labelNum,
      weights.getAttributeWeights());
  final int first = path[0];
  double total = weights.getBosTransitionWeights()[first] + state[first];
  final double[] transitions = weights.getTransitionWeights();
  int prev = first;
  for (int pos = 1; pos < path.length; pos++) {
    final int cur = path[pos];
    total += (transitions[prev * labelNum + cur] + state[pos * labelNum + cur]);
    prev = cur;
  }
  // After the loop `prev` is the label of the last item.
  total += weights.getEosTransitionWeights()[prev];
  return total;
}
/**
 * Finds the highest-scoring label sequence for {@code instance} with a Viterbi pass over
 * state scores, transition weights and BOS/EOS weights.
 *
 * @param instance the instance to decode; {@code null} yields an empty result
 * @param weights model weights
 * @return one label id per item; an empty array for a {@code null} or zero-length instance
 */
public static int[] tagId(Instance instance, TrainingWeights weights) {
  if (instance == null) {
    return new int[0];
  }
  int itemNum = instance.length();
  if (itemNum == 0) {
    // Nothing to decode; without this guard the code below indexes into empty arrays.
    return new int[0];
  }
  int labelNum = weights.getLabelTexts().length;
  int[] tagIndexes = new int[itemNum];
  double[] stateScores = computeStateScores(instance, false, labelNum,
      weights.getAttributeWeights());
  // Both tables are flattened as [itemIndex * labelNum + labelIndex].
  int[] bestBackIndexes = new int[itemNum * labelNum];
  double[] bestScores = new double[itemNum * labelNum];
  double[] bosTransitionWeights = weights.getBosTransitionWeights();
  double[] transitionWeights = weights.getTransitionWeights();
  double[] eosTransitionWeights = weights.getEosTransitionWeights();
  // First item: BOS weight plus state score.
  for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
    bestScores[labelIndex] = bosTransitionWeights[labelIndex] + stateScores[labelIndex];
  }
  // Forward pass: for every target label keep the best predecessor.
  for (int itemIndex = 1, itemMulIndex = labelNum; itemIndex < itemNum; itemIndex++, itemMulIndex += labelNum) {
    for (int toLabelIndex = 0; toLabelIndex < labelNum; toLabelIndex++) {
      // Seed the maximum with predecessor label 0.
      double maxScore = bestScores[itemMulIndex - labelNum] + transitionWeights[toLabelIndex];
      int maxFromLabelIndex = 0;
      for (int fromLabelIndex = 1, fromLabelMulIndex = labelNum; fromLabelIndex < labelNum; fromLabelIndex++, fromLabelMulIndex += labelNum) {
        double score = bestScores[itemMulIndex - labelNum + fromLabelIndex]
            + transitionWeights[fromLabelMulIndex + toLabelIndex];
        if (score > maxScore) {
          maxScore = score;
          maxFromLabelIndex = fromLabelIndex;
        }
      }
      bestScores[itemMulIndex + toLabelIndex] = maxScore
          + stateScores[itemMulIndex + toLabelIndex];
      bestBackIndexes[itemMulIndex + toLabelIndex] = maxFromLabelIndex;
    }
  }
  // Pick the best final label, taking the EOS weight into account.
  int itemMulIndex = (itemNum - 1) * labelNum;
  double maxScore = bestScores[itemMulIndex] + eosTransitionWeights[0];
  int maxFromLabelIndex = 0;
  for (int labelIndex = 1; labelIndex < labelNum; labelIndex++) {
    double score = bestScores[itemMulIndex + labelIndex] + eosTransitionWeights[labelIndex];
    if (score > maxScore) {
      maxScore = score;
      maxFromLabelIndex = labelIndex;
    }
  }
  tagIndexes[itemNum - 1] = maxFromLabelIndex;
  // Follow the back pointers from the last item to the first.
  for (int itemIndex = itemNum - 2; itemIndex >= 0; itemIndex--, itemMulIndex -= labelNum) {
    maxFromLabelIndex = bestBackIndexes[itemMulIndex + maxFromLabelIndex];
    tagIndexes[itemIndex] = maxFromLabelIndex;
  }
  return tagIndexes;
}
/**
 * Looks up the id of a feature string, optionally registering it.
 *
 * <p>During training an unknown key is registered via {@code map.get(key, true)}, except for
 * keys containing {@code "_B-"} or {@code "_B+"}, which are never added.
 *
 * @param map the feature dictionary
 * @param key the feature string
 * @param isTraining whether unknown features may be added
 * @return the feature id, or the negative value returned by the dictionary when the feature
 *         is unknown and was not added
 */
private static int getFeatureIndex(FeatureDict map, String key, boolean isTraining) {
  final int existing = map.get(key, false);
  if (existing >= 0 || !isTraining) {
    return existing;
  }
  final boolean excluded = key.contains("_B-") || key.contains("_B+");
  if (excluded) {
    return existing;
  }
  return map.get(key, true);
}
/**
 * Returns the id of a label, assigning the next free id (the current dictionary size) when
 * the label is unknown and {@code isTraining} is set.
 *
 * <p>NOTE(review): the "unknown" test is {@code index < 0}, so this presumably relies on the
 * dictionary having been created with a negative no-entry value - confirm at its creation.
 *
 * @param labelDict label text to id mapping
 * @param key the label text
 * @param isTraining whether an unknown label may be registered
 * @return the label id, or whatever {@code labelDict.get} returns for a missing key
 */
private static int getLabelIndex(TObjectIntHashMap<String> labelDict, String key,
    boolean isTraining) {
  final int found = labelDict.get(key);
  if (found >= 0 || !isTraining) {
    return found;
  }
  final int assigned = labelDict.size();
  labelDict.put(key, assigned);
  return assigned;
}
/**
 * Builds a labelled instance from segmented tokens.
 *
 * <p>The tag convertor supplies one tag per item; every single-char substring of every token
 * becomes a raw attribute, which the template (if any) expands. The resulting attribute list
 * is split evenly across the items.
 *
 * @param tokens the segmented tokens
 * @param tc convertor producing the per-item tags
 * @param attributeDict feature dictionary
 * @param labelDict label dictionary
 * @param template feature template, may be {@code null}
 * @param addFeatureIfNotExist whether unknown features are added to {@code attributeDict}
 * @param addLableIfNotExist whether unknown labels are added to {@code labelDict}
 * @return the instance holding attribute ids and label ids
 */
public static Instance buildInstance(String[] tokens, TagConvertor tc, FeatureDict attributeDict,
    TObjectIntHashMap<String> labelDict, Template template, boolean addFeatureIfNotExist,
    boolean addLableIfNotExist) {
  final String[] tags = tc.tokens2Tags(tokens);
  final int itemNum = tags.length;
  List<String> attributes = new ArrayList<String>();
  for (int t = 0; t < tokens.length; t++) {
    final String token = tokens[t];
    for (int c = 0; c < token.length(); c++) {
      attributes.add(String.valueOf(token.charAt(c)));
    }
  }
  if (template != null) {
    attributes = template.expandTemplate(attributes, itemNum);
  }
  final int attrCount = attributes.size();
  final int[] attrIds = new int[attrCount];
  final int rowSize = attrCount / itemNum;
  final int[] labelIds = new int[itemNum];
  int cursor = 0;
  for (int item = 0; item < itemNum; item++) {
    labelIds[item] = getLabelIndex(labelDict, tags[item], addLableIfNotExist);
    // Consume this item's row of attributes.
    final int rowEnd = cursor + rowSize;
    while (cursor < rowEnd) {
      attrIds[cursor] = getFeatureIndex(attributeDict, attributes.get(cursor),
          addFeatureIfNotExist);
      cursor++;
    }
  }
  return new Instance(attrIds, labelIds);
}
/**
 * Builds an instance from a flat attribute list and optional labels, using the template and
 * dictionaries held by {@code weights}.
 *
 * @param attributes raw attributes, expanded by the template when one is set
 * @param itemNum number of items; the attribute list is split evenly across them
 * @param labels one label per item, or {@code null} for unlabelled data
 * @param weights source of the template, label dictionary and attribute dictionary
 * @param isTraining whether unknown labels/features may be registered
 * @return a labelled instance when {@code labels} is given, otherwise an unlabelled one
 */
private static Instance buildInstance(List<String> attributes, int itemNum, List<String> labels,
    TrainingWeights weights, boolean isTraining) {
  final Template template = weights.getTemplate();
  final TObjectIntHashMap<String> labelDict = weights.getLabelDict();
  final FeatureDict attributeDict = weights.getAttributeDict();
  if (template != null) {
    attributes = template.expandTemplate(attributes, itemNum);
  }
  final int attrCount = attributes.size();
  final int[] attrIds = new int[attrCount];
  final int rowSize = attrCount / itemNum;
  final boolean labelled = (labels != null);
  final int[] labelIds = labelled ? new int[itemNum] : null;
  int cursor = 0;
  for (int item = 0; item < itemNum; item++) {
    if (labelled) {
      labelIds[item] = getLabelIndex(labelDict, labels.get(item), isTraining);
    }
    final int rowEnd = cursor + rowSize;
    while (cursor < rowEnd) {
      attrIds[cursor] = getFeatureIndex(attributeDict, attributes.get(cursor), isTraining);
      cursor++;
    }
  }
  if (labelled) {
    return new Instance(attrIds, labelIds);
  }
  return new Instance(attrIds, itemNum);
}
/**
 * Reads labelled test instances from a column-formatted file without registering new
 * features or labels.
 *
 * @param filename path of the test file
 * @param charset name of the file encoding
 * @param weights model weights providing template and dictionaries
 * @return the instances read
 * @throws IOException if the file cannot be read
 */
public static List<Instance> readTestData(String filename, String charset, TrainingWeights weights)
    throws IOException {
  final boolean containsLabels = true;
  final boolean isTraining = false;
  return getInstances(filename, charset, containsLabels, weights, isTraining);
}
/**
 * Reads test instances from a file with one tab-separated, pre-segmented sentence per line,
 * without registering new features or labels.
 *
 * @param filename path of the test file
 * @param charset name of the file encoding
 * @param weights model weights providing template and dictionaries
 * @param tc convertor that derives tags from the tokens
 * @return the instances read
 * @throws IOException if the file cannot be read
 */
public static List<Instance> readTestData2(String filename, String charset,
    TrainingWeights weights, TagConvertor tc) throws IOException {
  final boolean containsLabels = true;
  final boolean isTraining = false;
  return getInstances2(filename, charset, containsLabels, weights, isTraining, tc);
}
/**
 * Streams a tab-separated test file, decoding and scoring each non-blank line as it is read
 * instead of first materialising all instances.
 *
 * <p>The counting logic mirrors {@code evaluate(List, TrainingWeights)}.
 * NOTE(review): as there, the FP and FN counters look swapped relative to the usual
 * definitions; confirm against EvaluationResult before changing either method.
 *
 * @param filename path of the test file (one tab-separated sentence per line)
 * @param charset name of the file encoding
 * @param weights model weights
 * @param tc convertor that derives tags from the tokens
 * @return accumulated statistics
 * @throws IOException if the file cannot be read or the charset is unsupported
 */
public static EvaluationResult readAndEvaluate(String filename, String charset,
    TrainingWeights weights, TagConvertor tc) throws IOException {
  String[] labelTexts = weights.getLabelTexts();
  EvaluationResult evaluation = new EvaluationResult(labelTexts);
  // Open the raw stream separately so it is still closed if the charset name is rejected
  // by the InputStreamReader constructor.
  InputStream in = new FileInputStream(filename);
  BufferedReader br = null;
  try {
    br = new BufferedReader(new InputStreamReader(in, charset));
    String line;
    int lineNumber = 0;
    while ((line = br.readLine()) != null) {
      if (line.trim().length() == 0) {
        continue;
      }
      lineNumber++;
      Instance instance = buildInstance(line.split("\t"), tc, weights.getAttributeDict(),
          weights.getLabelDict(), weights.getTemplate(), false, false);
      if (lineNumber % 10000 == 0) {
        System.out.println(lineNumber + " lines evaluated");
      }
      int[] tagIds = tagId(instance, weights);
      evaluation.totalItemCount += tagIds.length;
      boolean hasError = false;
      for (int itemIndex = 0; itemIndex < instance.length(); itemIndex++) {
        int keyIndex = tagIds[itemIndex];
        int answerIndex = instance.labelIds()[itemIndex];
        // Update one counter per label for this item.
        for (int labelIndex = 0; labelIndex < labelTexts.length; labelIndex++) {
          int[] counts = evaluation.labelIndex2count[labelIndex];
          if (answerIndex == labelIndex) {
            if (keyIndex == labelIndex) {
              counts[EvaluationResult.TP_INDEX]++;
            } else {
              counts[EvaluationResult.FP_INDEX]++;
            }
          } else {
            if (keyIndex == labelIndex) {
              counts[EvaluationResult.FN_INDEX]++;
            } else {
              counts[EvaluationResult.TN_INDEX]++;
            }
          }
        }
        if (answerIndex == keyIndex) {
          evaluation.correctItemCount++;
        } else {
          hasError = true;
        }
      }
      evaluation.totalSeqCount++;
      if (!hasError) {
        evaluation.correctSeqCount++;
      }
    }
  } finally {
    if (br != null) {
      br.close();
    } else {
      in.close();
    }
  }
  return evaluation;
}
/**
 * Optionally prunes rare features, wraps the instances in a {@link TrainingDataSet} and
 * allocates zero-initialised weight arrays sized by the current dictionaries.
 *
 * <p>Pruning only runs when {@code minFeatureFreq > 1}. The label text table is filled by
 * inverting the label dictionary (id to text).
 *
 * @param minFeatureFreq feature frequency threshold passed to the pruning step
 * @param instances training instances
 * @param weights weights object whose arrays and label texts are (re)initialised
 * @return the data set describing instances, attribute count and label count
 */
private static TrainingDataSet shrinkAndInit(int minFeatureFreq, List<Instance> instances,
    TrainingWeights weights) {
  if (minFeatureFreq > 1) {
    shrinkAttributeDict(instances, minFeatureFreq, weights.getAttributeDict());
  }
  final TrainingDataSet dataSet = new TrainingDataSet();
  dataSet.setInstances(instances);
  final int attrNum = weights.getAttributeDict().size();
  final int labelNum = weights.getLabelDict().size();
  dataSet.setAttributeNum(attrNum);
  dataSet.setLabelNum(labelNum);

  weights.setAttributeWeights(new double[labelNum * attrNum]);
  weights.setTransitionWeights(new double[labelNum * labelNum]);
  weights.setBosTransitionWeights(new double[labelNum]);
  weights.setEosTransitionWeights(new double[labelNum]);

  // Invert the label dictionary into an id-indexed text table.
  final String[] labelTexts = new String[labelNum];
  for (TObjectIntIterator<String> it = weights.getLabelDict().iterator(); it.hasNext();) {
    it.advance();
    labelTexts[it.value()] = it.key();
  }
  weights.setLabelTexts(labelTexts);
  return dataSet;
}
/**
 * Reads a labelled column-formatted training file, registering features and labels as they
 * are seen, then prunes rare features and initialises the weight arrays.
 *
 * @param filename path of the training file
 * @param charset name of the file encoding
 * @param weights weights object to populate
 * @param minFeatureFreq feature frequency threshold for pruning
 * @return the prepared training data set
 * @throws IOException if the file cannot be read
 */
public static TrainingDataSet readTrainingData(String filename, String charset,
    TrainingWeights weights, int minFeatureFreq) throws IOException {
  final boolean containsLabels = true;
  final boolean isTraining = true;
  final List<Instance> loaded =
      getInstances(filename, charset, containsLabels, weights, isTraining);
  return shrinkAndInit(minFeatureFreq, loaded, weights);
}
/**
 * Reads a training file with one tab-separated, pre-segmented sentence per line.
 *
 * <p>All tags known to the convertor are registered in the label dictionary up front; the
 * instances are then built with feature registration enabled, and finally rare features are
 * pruned and the weight arrays initialised.
 *
 * @param filename path of the training file
 * @param charset name of the file encoding
 * @param weights weights object to populate
 * @param minFeatureFreq feature frequency threshold for pruning
 * @param tc convertor supplying the tag set and per-sentence tags
 * @return the prepared training data set
 * @throws IOException if the file cannot be read
 */
public static TrainingDataSet readTrainingData2(String filename, String charset,
    TrainingWeights weights, int minFeatureFreq, TagConvertor tc) throws IOException {
  TObjectIntHashMap<String> labeldict = weights.getLabelDict();
  for (String tag : tc.getTags()) {
    // Only assign an id to tags not yet present: re-putting an existing tag with the current
    // size would leave a hole in the id range and produce an id that is out of bounds for
    // the label text table built later.
    if (!labeldict.containsKey(tag)) {
      labeldict.put(tag, labeldict.size());
    }
  }
  List<Instance> instances = getInstances2(filename, charset, true, weights, true, tc);
  return shrinkAndInit(minFeatureFreq, instances, weights);
}
/**
 * Reads a file with one tab-separated sentence per line and builds one instance per
 * non-blank line. Progress is printed every 10000 lines.
 *
 * @param filename path of the input file
 * @param charset name of the file encoding
 * @param containsLabels unused by this method; kept for signature compatibility
 * @param weights source of the dictionaries and template
 * @param isTraining whether unknown features are registered (labels are never added here)
 * @param tc convertor that derives tags from the tokens
 * @return the instances read, in file order
 * @throws IOException if the file cannot be read or the charset is unsupported
 */
private static List<Instance> getInstances2(String filename, String charset,
    boolean containsLabels, TrainingWeights weights, boolean isTraining, TagConvertor tc)
    throws IOException {
  List<Instance> instances = new ArrayList<Instance>();
  // Open the raw stream separately so it is still closed if the charset name is rejected
  // by the InputStreamReader constructor.
  InputStream in = new FileInputStream(filename);
  BufferedReader br = null;
  try {
    br = new BufferedReader(new InputStreamReader(in, charset));
    String line;
    int lineNumber = 0;
    while ((line = br.readLine()) != null) {
      if (line.trim().length() == 0) {
        continue;
      }
      lineNumber++;
      Instance instance = buildInstance(line.split("\t"), tc, weights.getAttributeDict(),
          weights.getLabelDict(), weights.getTemplate(), isTraining, false);
      instances.add(instance);
      if (lineNumber % 10000 == 0) {
        System.out.println(lineNumber + " lines read");
      }
    }
  } finally {
    if (br != null) {
      br.close();
    } else {
      in.close();
    }
  }
  return instances;
}
/**
 * Reads a column-formatted file: one item per line with whitespace-separated fields, and
 * blank lines separating sequences. When {@code containsLabels} is set the last field of
 * each line is the label.
 *
 * <p>Every line must have the same number of attribute fields as the first one.
 *
 * @param filename path of the input file
 * @param charset name of the file encoding
 * @param containsLabels whether the last column holds the label
 * @param weights source of the dictionaries and template
 * @param isTraining whether unknown features/labels are registered
 * @return the instances read, in file order
 * @throws IOException if the file cannot be read or the charset is unsupported
 * @throws IllegalStateException if a line has a different number of fields than the first
 */
private static List<Instance> getInstances(String filename, String charset,
    boolean containsLabels, TrainingWeights weights, boolean isTraining) throws IOException {
  List<Instance> instances = new ArrayList<Instance>();
  // Open the raw stream separately so it is still closed if the charset name is rejected
  // by the InputStreamReader constructor.
  InputStream in = new FileInputStream(filename);
  BufferedReader br = null;
  try {
    br = new BufferedReader(new InputStreamReader(in, charset));
    String line;
    List<String> itemList = new ArrayList<String>();
    List<String> labelList = containsLabels ? new ArrayList<String>() : null;
    int seqCount = 0, itemCount = 0;
    // Rows collected for the sequence in progress. Tracked separately from labelList so
    // unlabelled input (labelList == null) no longer causes a NullPointerException.
    int rowsInSeq = 0;
    System.out.println("extracting instances");
    int fieldNum = -1;
    while ((line = br.readLine()) != null) {
      if (line.trim().length() == 0) {
        // A blank line terminates the current sequence.
        if (itemList.size() > 0) {
          Instance instance = buildInstance(itemList, rowsInSeq, labelList, weights,
              isTraining);
          instances.add(instance);
          seqCount++;
          if (seqCount % 10000 == 0) {
            System.out.println(seqCount + " lines read");
          }
          itemList.clear();
          if (labelList != null) {
            labelList.clear();
          }
          rowsInSeq = 0;
        }
      } else {
        String[] fields = line.split("\\s+");
        int thisFieldNum = containsLabels ? fields.length - 1 : fields.length;
        if (fieldNum < 0) {
          fieldNum = thisFieldNum;
        } else {
          if (fieldNum != thisFieldNum) {
            throw new IllegalStateException("inconsistent input format: " + line);
          }
        }
        if (containsLabels) {
          for (int i = 0; i < fields.length - 1; i++) {
            itemList.add(fields[i]);
          }
          labelList.add(fields[fields.length - 1]);
        } else {
          for (String field : fields) {
            itemList.add(field);
          }
        }
        rowsInSeq++;
        itemCount++;
      }
    }
    // The file may end without a trailing blank line.
    if (itemList.size() > 0) {
      Instance instance = buildInstance(itemList, rowsInSeq, labelList, weights,
          isTraining);
      instances.add(instance);
      // Count the trailing sequence too so the summary below matches instances.size().
      seqCount++;
    }
    System.out.println("found " + seqCount + " instances and " + itemCount + " items");
  } finally {
    if (br != null) {
      br.close();
    } else {
      in.close();
    }
  }
  return instances;
}
/**
 * Prints the command-line usage to standard error and terminates the JVM with exit code 1.
 *
 * <p>The command names listed here match the ones dispatched by {@code main}
 * ({@code train-hdfs} and {@code seg}; the previous text advertised {@code hdfs-train} and
 * {@code tag}, which {@code main} rejects as unknown commands).
 */
public static void showUsageAndExit() {
  System.err.println("Usage:");
  System.err.println("\t" + "SgdCrf help");
  System.err
      .println("\t"
          + "SgdCrf train <CRF++_format_train_file> <model_file> <crf_train_properties_file> [encoding]");
  System.err
      .println("\t"
          + "SgdCrf train2 <tab_sep_text_train_file> <model_file> <crf_train_properties_file> [encoding]");
  System.err
      .println("\t"
          + "SgdCrf train-hdfs <hdfs_dir> <model_file> <crf_train_properties_file> <feature_dict> [encoding] [hdfsconf1] [hdfsconf2] ...");
  System.err.println("\t" + "SgdCrf test <test_file> <model_file> [encoding]");
  System.err.println("\t" + "SgdCrf test2 <test_file> <model_file> [encoding]");
  System.err.println("\t" + "SgdCrf seg <model_file> [nBest] [encoding]");
  System.exit(1);
}
/**
 * Loads training parameters from a Java properties file.
 *
 * <p>Recognised keys and their defaults when absent: {@code mininumFeatureFrequency} (1),
 * {@code eta} (0.1), {@code sigma} (10.0), {@code rate} (2), {@code iterateCount} (100),
 * {@code candidatesNum} (10), {@code samplesNum} (1000), {@code t0} (0). The key
 * {@code templateFile} names the template file, which is read with
 * {@link #readTemplates(String)}; it has no default.
 *
 * <p>Note that the key {@code mininumFeatureFrequency} is spelled this way in the code;
 * existing configuration files depend on it.
 *
 * @param configFile path of the properties file
 * @return the populated parameters
 * @throws IOException if the properties file or the template file cannot be read
 */
public static TrainingParams loadParams(String configFile) throws IOException {
  TrainingParams params = new TrainingParams();
  Properties props = new Properties();
  // Properties.load does not close the stream it is given, so close it here.
  InputStream in = new FileInputStream(new File(configFile));
  try {
    props.load(in);
  } finally {
    in.close();
  }
  params.setMinFeatureFreq(getIntParam(props, "mininumFeatureFrequency", 1));
  params.setEta(getDoubleParam(props, "eta", .1));
  params.setSigma(getDoubleParam(props, "sigma", 10.0));
  params.setRate(getDoubleParam(props, "rate", 2));
  params.setIterationNum(getIntParam(props, "iterateCount", 100));
  params.setCandidatesNum(getIntParam(props, "candidatesNum", 10));
  params.setSamplesNum(getIntParam(props, "samplesNum", 1000));
  params.setT0(getDoubleParam(props, "t0", 0));
  String templateFile = props.getProperty("templateFile");
  params.setTemplates(readTemplates(templateFile));
  return params;
}
/**
 * Reads template lines from a text file using the platform default charset.
 *
 * <p>Each line is trimmed; blank lines and lines starting with {@code #} are skipped.
 *
 * @param path path of the template file
 * @return the remaining trimmed lines, in file order
 * @throws IOException if the file cannot be opened or read
 */
public static List<String> readTemplates(String path) throws IOException {
  final List<String> templates = new ArrayList<String>();
  final BufferedReader reader =
      new BufferedReader(new InputStreamReader(new FileInputStream(path)));
  try {
    for (String raw = reader.readLine(); raw != null; raw = reader.readLine()) {
      final String trimmed = raw.trim();
      final boolean keep = trimmed.length() > 0 && !trimmed.startsWith("#");
      if (keep) {
        templates.add(trimmed);
      }
    }
  } finally {
    reader.close();
  }
  return templates;
}
/**
 * Removes infrequent features from the dictionary and renumbers the surviving ones densely,
 * rewriting the attribute ids stored in the instances accordingly (removed features
 * become -1).
 *
 * <p>NOTE(review): a feature survives only if its count is strictly greater than
 * {@code freqThreshold}; a feature seen exactly {@code freqThreshold} times is dropped.
 * Confirm this is the intended meaning of the "minimum feature frequency" setting.
 *
 * @param instances instances whose attribute id arrays are rewritten in place
 * @param freqThreshold count a feature must exceed to be kept
 * @param attributeDict dictionary to prune and renumber in place
 * @return the number of dictionary entries removed
 */
private static int shrinkAttributeDict(List<Instance> instances, int freqThreshold,
    FeatureDict attributeDict) {
  final int dictSize = attributeDict.size();

  // Pass 1: occurrence count per old id.
  final int[] occurrences = new int[dictSize];
  for (Instance instance : instances) {
    final int[] ids = instance.getAttrIds();
    for (int i = 0; i < ids.length; i++) {
      if (ids[i] >= 0) {
        occurrences[ids[i]]++;
      }
    }
  }

  // Pass 2: old id -> new dense id, or -1 when the feature is dropped.
  final int[] remap = new int[dictSize];
  int nextId = 0;
  for (int oldId = 0; oldId < dictSize; oldId++) {
    remap[oldId] = (occurrences[oldId] > freqThreshold) ? nextId++ : -1;
  }

  // Pass 3: apply the mapping to the dictionary itself.
  int removed = 0;
  for (TObjectIntIterator<String> it = attributeDict.iterator(); it.hasNext();) {
    it.advance();
    final int mapped = remap[it.value()];
    if (mapped >= 0) {
      it.setValue(mapped);
    } else {
      it.remove();
      removed++;
    }
  }

  // Pass 4: apply the mapping to every instance.
  for (Instance instance : instances) {
    final int[] ids = instance.getAttrIds();
    for (int i = 0; i < ids.length; ++i) {
      if (ids[i] >= 0) {
        ids[i] = remap[ids[i]];
      }
    }
  }
  return removed;
}
/**
 * Reads an integer property, falling back to {@code defaultValue} when the key has no
 * string value.
 *
 * <p>The value is trimmed before parsing because {@link Properties} keeps trailing
 * whitespace from the file, which {@code Integer} parsing would otherwise reject.
 *
 * @param props the loaded properties
 * @param key property name
 * @param defaultValue value returned when the property is absent
 * @return the parsed value or the default
 * @throws NumberFormatException if the value is present but not a valid integer
 */
private static int getIntParam(Properties props, String key, int defaultValue) {
  String raw = props.getProperty(key);
  if (raw == null) {
    return defaultValue;
  }
  return Integer.parseInt(raw.trim());
}
/**
 * Reads a floating-point property, returning {@code defaultValue} when the key is absent.
 *
 * @param props the loaded properties
 * @param key property name
 * @param defaultValue value returned when the key is not present
 * @return the parsed value or the default
 * @throws NumberFormatException if the value is present but not a valid number
 */
private static double getDoubleParam(Properties props, String key, double defaultValue) {
  if (!props.containsKey(key)) {
    return defaultValue;
  }
  final String raw = props.getProperty(key);
  return Double.parseDouble(raw);
}
/**
 * Variant of the instance builder used for explanations: it never registers new features or
 * labels, and records the feature string behind every attribute id in {@code featureMap}.
 *
 * <p>Unknown features are recorded as well, under whatever negative id the dictionary
 * returns for them.
 *
 * @param attributes raw attributes, expanded by the template when one is set
 * @param itemNum number of items; the attribute list is split evenly across them
 * @param labels one label per item, or {@code null}
 * @param featureMap output map from attribute id to feature string
 * @param template feature template, may be {@code null}
 * @param attributeDict feature dictionary (read only here)
 * @param labelDict label dictionary (read only here)
 * @return a labelled instance when {@code labels} is given, otherwise an unlabelled one
 */
private static Instance buildInstance4Explanation(List<String> attributes, int itemNum,
    List<String> labels, Map<Integer, String> featureMap, Template template,
    FeatureDict attributeDict, TObjectIntHashMap<String> labelDict) {
  if (template != null) {
    attributes = template.expandTemplate(attributes, itemNum);
  }
  final int attrCount = attributes.size();
  final int[] attrIds = new int[attrCount];
  final int rowSize = attrCount / itemNum;
  final boolean labelled = (labels != null);
  final int[] labelIds = labelled ? new int[itemNum] : null;
  int cursor = 0;
  for (int item = 0; item < itemNum; item++) {
    if (labelled) {
      labelIds[item] = getLabelIndex(labelDict, labels.get(item), false);
    }
    final int rowEnd = cursor + rowSize;
    while (cursor < rowEnd) {
      final String feature = attributes.get(cursor);
      final int id = getFeatureIndex(attributeDict, feature, false);
      attrIds[cursor] = id;
      featureMap.put(id, feature);
      cursor++;
    }
  }
  if (labelled) {
    return new Instance(attrIds, labelIds);
  }
  return new Instance(attrIds, itemNum);
}
/**
 * Computes per-item, per-label state scores and records, for every (item, label) cell, which
 * features contributed and with what weight.
 *
 * <p>After scoring, each cell's feature/weight lists are reordered by descending absolute
 * weight. The ordering uses {@link Double#compare} so that equal weights compare as equal;
 * a comparator that never returns 0 violates the {@link Comparator} contract and can make
 * the sort throw.
 *
 * @param instance the instance to score
 * @param exp whether the scores are exponentiated before being returned and stored
 * @param details pre-allocated [item][label] cells that receive features, weights and score
 * @param featureMap attribute id to feature string
 * @param labelNum number of labels
 * @param attributeWeights weights flattened as [attributeId * labelNum + label]
 * @return state scores flattened as [item * labelNum + label]
 */
private static double[] computeStateScores4Explanation(Instance instance, boolean exp,
    FeatureWeightScore[][] details, Map<Integer, String> featureMap, int labelNum,
    double[] attributeWeights) {
  int itemNum = instance.length();
  int rowSize = instance.rowSize();
  int[] attrIds = instance.getAttrIds();
  double[] stateScores = new double[itemNum * labelNum];
  for (int itemIndex = 0; itemIndex < itemNum; itemIndex++) {
    for (int i = 0; i < rowSize; i++) {
      int attributeIndex = attrIds[itemIndex * rowSize + i];
      // Negative ids mark features that are not in the dictionary.
      if (attributeIndex >= 0) {
        // The feature text is the same for every label, so look it up once.
        String feature = featureMap.get(attributeIndex);
        for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
          double weight = attributeWeights[attributeIndex * labelNum + labelIndex];
          stateScores[itemIndex * labelNum + labelIndex] += weight;
          details[itemIndex][labelIndex].features.add(feature);
          details[itemIndex][labelIndex].weights.add(weight);
        }
      }
    }
  }
  if (exp) {
    for (int itemIndex = 0; itemIndex < itemNum; itemIndex++) {
      for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
        double score = stateScores[itemIndex * labelNum + labelIndex];
        stateScores[itemIndex * labelNum + labelIndex] = Math.exp(score);
      }
    }
  }
  for (int itemIndex = 0; itemIndex < itemNum; itemIndex++) {
    for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
      details[itemIndex][labelIndex].score = stateScores[itemIndex * labelNum + labelIndex];
      List<String> featureList = details[itemIndex][labelIndex].features;
      List<Double> weightList = details[itemIndex][labelIndex].weights;
      // Pair up features and weights so they can be sorted together.
      List<Object[]> sortHelper = new ArrayList<Object[]>(featureList.size());
      for (int i = 0; i < featureList.size(); i++) {
        sortHelper.add(new Object[] { featureList.get(i), weightList.get(i) });
      }
      Collections.sort(sortHelper, new Comparator<Object[]>() {
        @Override
        public int compare(Object[] arg0, Object[] arg1) {
          double w1 = Math.abs((Double) arg0[1]);
          double w2 = Math.abs((Double) arg1[1]);
          // Descending by absolute weight; ties compare as equal.
          return Double.compare(w2, w1);
        }
      });
      featureList.clear();
      weightList.clear();
      for (Object[] arr : sortHelper) {
        featureList.add((String) arr[0]);
        weightList.add((Double) arr[1]);
      }
    }
  }
  return stateScores;
}
/**
 * Tags a sentence character by character and returns the explanation of the decoding, with
 * the single-character tokens attached.
 *
 * @param sen the sentence to explain
 * @param model the model to decode with
 * @return the explanation produced by {@link #tagAndExplain}, with {@code tokens} set
 */
public static Explanation explain(String sen, CrfModel model) {
  final int length = sen.length();
  final List<String> chars = new ArrayList<String>(length);
  for (int pos = 0; pos < length; pos++) {
    chars.add(String.valueOf(sen.charAt(pos)));
  }
  final Explanation explanation = tagAndExplain(chars, length, model);
  explanation.tokens = chars;
  return explanation;
}
/**
 * Decodes the best label sequence like {@link #tagId} while also collecting, for every
 * (item, label) cell, the contributing features and weights.
 *
 * <p>For {@code itemNum <= 0} the result only carries an empty {@code bestTagIds}; the
 * instance builder divides by {@code itemNum}, so an empty input would otherwise fail.
 *
 * @param features raw per-item attributes
 * @param itemNum number of items in the sequence
 * @param model the model to decode with
 * @return the explanation holding best tags, per-cell details, transition weights and
 *         label texts
 */
public static Explanation tagAndExplain(List<String> features, int itemNum, CrfModel model) {
  double[] bosTransitionWeights = model.weights.getBosTransitionWeights();
  double[] transitionWeights = model.weights.getTransitionWeights();
  double[] eosTransitionWeights = model.weights.getEosTransitionWeights();
  double[] attributeWeights = model.weights.getAttributeWeights();
  Template template = model.weights.getTemplate();
  FeatureDict attributeDict = model.weights.getAttributeDict();
  TObjectIntHashMap<String> labelDict = model.weights.getLabelDict();
  Explanation result = new Explanation();
  if (itemNum <= 0) {
    // Nothing to decode; same shape as the null-instance result below.
    result.bestTagIds = new int[0];
    return result;
  }
  Map<Integer, String> featureMap = new HashMap<Integer, String>();
  Instance instance = buildInstance4Explanation(features, itemNum, null, featureMap, template,
      attributeDict, labelDict);
  if (instance == null) {
    result.bestTagIds = new int[0];
    return result;
  }
  int[] tagIndexes = new int[itemNum];
  int labelNum = model.weights.getLabelDict().size();
  // One detail cell per (item, label).
  FeatureWeightScore[][] details = new FeatureWeightScore[itemNum][];
  for (int i = 0; i < details.length; i++) {
    details[i] = new FeatureWeightScore[labelNum];
    for (int j = 0; j < details[i].length; j++) {
      details[i][j] = new FeatureWeightScore();
    }
  }
  result.details = details;
  double[] stateScores = computeStateScores4Explanation(instance, false, details, featureMap,
      labelNum, attributeWeights);
  // Both tables are flattened as [itemIndex * labelNum + labelIndex].
  int[] bestBackIndexes = new int[itemNum * labelNum];
  double[] bestScores = new double[itemNum * labelNum];
  // First item: BOS weight plus state score.
  for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
    bestScores[labelIndex] = bosTransitionWeights[labelIndex] + stateScores[labelIndex];
  }
  // Forward pass: for every target label keep the best predecessor.
  for (int itemIndex = 1, itemMulIndex = labelNum; itemIndex < itemNum; itemIndex++, itemMulIndex += labelNum) {
    for (int toLabelIndex = 0; toLabelIndex < labelNum; toLabelIndex++) {
      double maxScore = bestScores[itemMulIndex - labelNum] + transitionWeights[toLabelIndex];
      int maxFromLabelIndex = 0;
      for (int fromLabelIndex = 1, fromLabelMulIndex = labelNum; fromLabelIndex < labelNum; fromLabelIndex++, fromLabelMulIndex += labelNum) {
        double score = bestScores[itemMulIndex - labelNum + fromLabelIndex]
            + transitionWeights[fromLabelMulIndex + toLabelIndex];
        if (score > maxScore) {
          maxScore = score;
          maxFromLabelIndex = fromLabelIndex;
        }
      }
      bestScores[itemMulIndex + toLabelIndex] = maxScore
          + stateScores[itemMulIndex + toLabelIndex];
      bestBackIndexes[itemMulIndex + toLabelIndex] = maxFromLabelIndex;
    }
  }
  // Pick the best final label, taking the EOS weight into account.
  int itemMulIndex = (itemNum - 1) * labelNum;
  double maxScore = bestScores[itemMulIndex] + eosTransitionWeights[0];
  int maxFromLabelIndex = 0;
  for (int labelIndex = 1; labelIndex < labelNum; labelIndex++) {
    double score = bestScores[itemMulIndex + labelIndex] + eosTransitionWeights[labelIndex];
    if (score > maxScore) {
      maxScore = score;
      maxFromLabelIndex = labelIndex;
    }
  }
  tagIndexes[itemNum - 1] = maxFromLabelIndex;
  // Follow the back pointers from the last item to the first.
  for (int itemIndex = itemNum - 2; itemIndex >= 0; itemIndex--, itemMulIndex -= labelNum) {
    maxFromLabelIndex = bestBackIndexes[itemMulIndex + maxFromLabelIndex];
    tagIndexes[itemIndex] = maxFromLabelIndex;
  }
  result.bestTagIds = tagIndexes;
  result.bosTransitionWeights = bosTransitionWeights;
  result.eosTransitionWeights = eosTransitionWeights;
  result.transitionWeights = transitionWeights;
  result.labelTexts = model.weights.getLabelTexts();
  return result;
}
/**
 * Decodes up to {@code N} best label sequences for {@code instance}.
 *
 * <p>For every (item, label) cell the {@code N} best partial path scores are kept in
 * descending order; unused slots are marked with {@code -Double.MAX_VALUE}. When fewer than
 * {@code N} distinct paths exist (short inputs), only the available ones are returned.
 *
 * @param instance the instance to decode; {@code null} yields an empty list
 * @param N maximum number of sequences to return; values {@code <= 0} yield an empty list
 * @param relativeScore output array receiving the score of each returned sequence at the
 *        same index; must have room for the returned sequences (at most {@code N})
 * @param model the model to decode with
 * @return label text sequences, best first
 */
public static List<String[]> tagNBest(Instance instance, int N, double[] relativeScore,
    CrfModel model) {
  if (instance == null || N <= 0) {
    return new ArrayList<String[]>();
  }
  int itemNum = instance.length();
  if (itemNum == 0) {
    return new ArrayList<String[]>();
  }
  List<String[]> result = new ArrayList<String[]>(N);
  // compute state scores
  int labelNum = model.weights.getLabelDict().size();
  double[] attributeWeights = model.weights.getAttributeWeights();
  double[] bosTransitionWeights = model.weights.getBosTransitionWeights();
  double[] transitionWeights = model.weights.getTransitionWeights();
  double[] eosTransitionWeights = model.weights.getEosTransitionWeights();
  String[] labelTexts = model.weights.getLabelTexts();
  double[] stateScores = computeStateScores(instance, false, labelNum, attributeWeights);
  // backs[item][label][rank] = {previous label, rank within the previous cell}
  int[][][][] backs = new int[itemNum][][][];
  // scores[item][label][rank], descending by rank
  double[][][] scores = new double[itemNum][][];
  scores[0] = new double[labelNum][];
  for (int i = 0; i < labelNum; i++) {
    scores[0][i] = new double[N];
  }
  // First item: exactly one candidate per label, remaining slots are sentinels.
  for (int labelIndex = 0; labelIndex < labelNum; labelIndex++) {
    scores[0][labelIndex][0] = stateScores[labelIndex] + bosTransitionWeights[labelIndex];
    for (int i = 1; i < N; i++) {
      scores[0][labelIndex][i] = -Double.MAX_VALUE;
    }
  }
  for (int itemIndex = 1; itemIndex < itemNum; itemIndex++) {
    scores[itemIndex] = new double[labelNum][];
    backs[itemIndex] = new int[labelNum][][];
    for (int toLabelIndex = 0; toLabelIndex < labelNum; toLabelIndex++) {
      double[] nBest = new double[N];
      int[][] nBestIndex = new int[N][2];
      int curCount = 0;
      // Sentinel marking the first unused slot; it is shifted right on every insertion.
      nBest[0] = -Double.MAX_VALUE;
      for (int fromLabelIndex = 0; fromLabelIndex < labelNum; fromLabelIndex++) {
        for (int i = 0; i < N; i++) {
          if (scores[itemIndex - 1][fromLabelIndex][i] == -Double.MAX_VALUE) {
            break;
          }
          double score = scores[itemIndex - 1][fromLabelIndex][i]
              + transitionWeights[fromLabelIndex * labelNum + toLabelIndex];
          if (curCount < N || score > nBest[N - 1]) {
            // Insertion sort into the descending n-best list.
            int j = 0;
            for (; j < N; j++) {
              if (score > nBest[j]) {
                break;
              }
            }
            for (int k = N - 1; k > j; k--) {
              nBest[k] = nBest[k - 1];
              nBestIndex[k] = nBestIndex[k - 1];
            }
            nBest[j] = score;
            nBestIndex[j] = new int[] { fromLabelIndex, i };
            curCount++;
          }
        }
      }
      // The state score is identical for all candidates of this cell, so add it afterwards.
      for (int i = 0; i < curCount && i < N; i++) {
        nBest[i] += stateScores[itemIndex * labelNum + toLabelIndex];
      }
      scores[itemIndex][toLabelIndex] = nBest;
      backs[itemIndex][toLabelIndex] = nBestIndex;
    }
  }
  // Gather all final candidates {score incl. EOS, label, rank} and sort them best first.
  List<Object[]> helper = new ArrayList<Object[]>();
  for (int i = 0; i < labelNum; i++) {
    double[] score = scores[itemNum - 1][i];
    for (int j = 0; j < N; j++) {
      double s = score[j];
      if (s == -Double.MAX_VALUE) {
        break;
      }
      s = s + eosTransitionWeights[i];
      Object[] arr = new Object[] { s, i, j };
      helper.add(arr);
    }
  }
  Collections.sort(helper, new Comparator<Object[]>() {
    @Override
    public int compare(Object[] arg0, Object[] arg1) {
      double s1 = (Double) arg0[0];
      double s2 = (Double) arg1[0];
      // Descending by score; ties compare as equal to honour the Comparator contract.
      return Double.compare(s2, s1);
    }
  });
  // There may be fewer than N complete paths, e.g. for very short inputs.
  int resultNum = Math.min(N, helper.size());
  int[] tmp;
  for (int i = 0; i < resultNum; i++) {
    String[] tags = new String[itemNum];
    Object[] arr = helper.get(i);
    int j = (Integer) arr[1];
    int k = (Integer) arr[2];
    tags[itemNum - 1] = labelTexts[j];
    relativeScore[i] = (Double) arr[0];
    // Walk the back pointers from the last item to the first.
    for (int itemIndex = itemNum - 2; itemIndex >= 0; itemIndex--) {
      tmp = backs[itemIndex + 1][j][k];
      j = tmp[0];
      k = tmp[1];
      tags[itemIndex] = labelTexts[j];
    }
    result.add(tags);
  }
  return result;
}
/**
 * Translates label ids into their label texts using the model's label table.
 *
 * @param tags label ids
 * @param model model providing the label texts
 * @return an array of the same length holding the text of each id
 */
public static String[] tagId2Text(int[] tags, CrfModel model) {
  final String[] labelTexts = model.weights.getLabelTexts();
  final String[] texts = new String[tags.length];
  int out = 0;
  for (int tag : tags) {
    texts[out++] = labelTexts[tag];
  }
  return texts;
}
/**
 * Segments a sentence: every char becomes one item, the best tag sequence is decoded and
 * the tag convertor turns tags plus sentence into tokens.
 *
 * @param sentence the text to segment
 * @param model the model to decode with
 * @param tagConvertor convertor from tags back to tokens
 * @return the tokens produced by the convertor
 */
public static List<String> segment(String sentence, CrfModel model, TagConvertor tagConvertor) {
  final int length = sentence.length();
  final List<String> chars = new ArrayList<String>(length);
  for (int pos = 0; pos < length; pos++) {
    chars.add(String.valueOf(sentence.charAt(pos)));
  }
  final Instance unlabelled = buildInstance(chars, chars.size(), null, model.weights, false);
  final int[] bestTags = tagId(unlabelled, model.weights);
  final String[] tagTexts = tagId2Text(bestTags, model);
  return tagConvertor.tags2TokenList(tagTexts, sentence);
}
/**
 * Segments a sentence into up to {@code nBest} alternative tokenisations, best first.
 *
 * @param sentence the text to segment
 * @param model the model to decode with
 * @param tagConvertor convertor from tags back to tokens
 * @param nBest number of alternatives requested
 * @return one token array per decoded tag sequence
 */
public static List<String[]> segment(String sentence, CrfModel model, TagConvertor tagConvertor,
    int nBest) {
  final List<String[]> segmentations = new ArrayList<String[]>();
  final int length = sentence.length();
  final List<String> chars = new ArrayList<String>(length);
  for (int pos = 0; pos < length; pos++) {
    chars.add(String.valueOf(sentence.charAt(pos)));
  }
  final Instance unlabelled = buildInstance(chars, chars.size(), null, model.weights, false);
  // Scores are collected by tagNBest but not used by this method.
  final double[] relativeScore = new double[nBest];
  final List<String[]> tagSequences = tagNBest(unlabelled, nBest, relativeScore, model);
  for (int i = 0; i < tagSequences.size(); i++) {
    segmentations.add(tagConvertor.tags2Tokens(tagSequences.get(i), sentence));
  }
  return segmentations;
}
/**
 * Command-line entry point. The first argument selects the command:
 * {@code help}, {@code train}, {@code train2}, {@code train-hdfs}, {@code test},
 * {@code test2} or {@code seg}. Wrong argument counts and unknown commands print the usage
 * and exit via {@link #showUsageAndExit()}.
 *
 * @param args command followed by its command-specific arguments
 * @throws Exception any failure from loading, training, evaluating or I/O is propagated
 */
public static void main(String[] args) throws Exception {
if (args.length < 1) {
showUsageAndExit();
}
String command = args[0];
if (command.equals("help")) {
showUsageAndExit();
} else if (command.equals("train")) {
// train <train_file> <model_file> <properties_file> [encoding]
if (args.length != 4 && args.length != 5) {
showUsageAndExit();
}
String trainFilename = args[1];
String modelFilename = args[2];
String configFilename = args[3];
String charset = "UTF8";
if (args.length > 4) {
charset = args[4];
}
TrainingParams params = loadParams(configFilename);
Template template = new Template(params.getTemplates().toArray(new String[0]));
TrainingWeights weights = new TrainingWeights(template, FeatureDictEnum.TROVE_HASHMAP);
TrainingDataSet dataSet = SgdCrf.readTrainingData(trainFilename, charset, weights,
params.getMinFeatureFreq());
SgdCrf.train(dataSet, 0, params.getIterationNum(), params, weights,
new PrintTrainingProgress());
dataSet = null;// free memory
SgdCrf.saveModel(params, weights, modelFilename);
} else if (command.equals("train2")) {
// train2 <tab_separated_train_file> <model_file> <properties_file> [encoding]
// Same flow as "train", but the input is read through a tag convertor.
if (args.length != 4 && args.length != 5) {
showUsageAndExit();
}
String trainFilename = args[1];
String modelFilename = args[2];
String configFilename = args[3];
String charset = "UTF8";
if (args.length > 4) {
charset = args[4];
}
TrainingParams params = loadParams(configFilename);
Template template = new Template(params.getTemplates().toArray(new String[0]));
TrainingWeights weights = new TrainingWeights(template, FeatureDictEnum.TROVE_HASHMAP);
TagConvertor tc = new BESB1B2MTagConvertor();
TrainingDataSet dataSet = SgdCrf.readTrainingData2(trainFilename, charset, weights,
params.getMinFeatureFreq(), tc);
SgdCrf.train(dataSet, 0, params.getIterationNum(), params, weights,
new PrintTrainingProgress());
dataSet = null;
SgdCrf.saveModel(params, weights, modelFilename);
}
// An earlier "train-hdfs" variant without a pre-built feature dictionary used to live
// here as commented-out code; the branch below replaced it.
else if (command.equals("train-hdfs")) {
// train-hdfs <hdfs_dir> <model_file> <properties_file> <feature_dict> [encoding] [hdfsconf...]
if (!(args.length >= 5)) {
showUsageAndExit();
}
String trainFilename = args[1];
// NOTE(review): modelFilename is read but not used in this branch; presumably the HDFS
// train overload handles persistence itself - confirm.
String modelFilename = args[2];
String configFilename = args[3];
String featureFilename = args[4];
String charset = "UTF8";
if (args.length > 5) {
charset = args[5];
}
// Every argument from index 6 on is added as a Hadoop configuration resource.
Configuration conf = new Configuration();
for (int i = 6; i < args.length; i++) {
System.out.println("add hdfs conf: " + args[i]);
conf.addResource(new Path(args[i]));
}
FileSystem fs = FileSystem.get(conf);
TrainingParams params = loadParams(configFilename);
Template template = new Template(params.getTemplates().toArray(new String[0]));
TagConvertor tc = new BESB1B2MTagConvertor();
TrainingWeights weights = new TrainingWeights(template);
// The feature dictionary is deserialised from a local file rather than built here.
FSTObjectInput foi = null;
FeatureDict dict = null;
System.out.println("load featuredict from: " + featureFilename);
try {
foi = new FSTObjectInput(new FileInputStream(featureFilename));
dict = (FeatureDict) foi.readObject();
} finally {
if (foi != null) {
foi.close();
}
}
SgdCrf.train(fs, trainFilename, template, params.getIterationNum(), params, tc, charset,
weights, dict);
}
else if (command.equals("test")) {
// test <test_file> <model_file> [encoding]
if (args.length != 3 && args.length != 4) {
showUsageAndExit();
}
String testFilename = args[1];
String modelFilename = args[2];
String charset = "UTF8";
if (args.length > 3) {
charset = args[3];
}
CrfModel model = SgdCrf.loadModel(modelFilename);
List<Instance> instances = SgdCrf.readTestData(testFilename, charset, model.weights);
EvaluationResult er = SgdCrf.evaluate(instances, model.weights);
System.out.println(er);
} else if (command.equals("test2")) {
// test2 <test_file> <model_file> [encoding]
// Evaluates line by line while reading instead of loading all instances first.
if (args.length != 3 && args.length != 4) {
showUsageAndExit();
}
String testFilename = args[1];
String modelFilename = args[2];
String charset = "UTF8";
if (args.length > 3) {
charset = args[3];
}
CrfModel model = SgdCrf.loadModel(modelFilename);
EvaluationResult er = SgdCrf.readAndEvaluate(testFilename, charset, model.weights,
new BESB1B2MTagConvertor());
System.out.println(er);
} else if (command.equals("seg")) {
// seg <model_file> [nBest] [encoding]
// Interactive loop: reads sentences from stdin and writes tab-separated tokens to stdout.
if (args.length != 3 && args.length != 2 && args.length != 4) {
showUsageAndExit();
}
String modelFilename = args[1];
// An empty charset means "use the platform default" for stdin/stdout.
String charset = "";
int nBest = 1;
if (args.length > 2) {
nBest = Integer.valueOf(args[2]);
}
if (args.length > 3) {
charset = args[3];
}
CrfModel model = SgdCrf.loadModel(modelFilename);
BufferedReader br;
BufferedWriter bw;
if (charset.equals("")) {
br = new BufferedReader(new InputStreamReader(System.in));
bw = new BufferedWriter(new OutputStreamWriter(System.out));
} else {
br = new BufferedReader(new InputStreamReader(System.in, charset));
bw = new BufferedWriter(new OutputStreamWriter(System.out, charset));
}
String line;
bw.write("Enter Chinese sentences to be segment, enter quit to exit!\n");
bw.flush();
TagConvertor tc = new BESB1B2MTagConvertor();
while ((line = br.readLine()) != null) {
if (line.trim().equals("quit")) {
break;
}
if (line.trim().equals("")) {
continue;
}
bw.write("Input: " + line + "\n");
if (nBest < 2) {
// Single best segmentation on one line.
List<String> result = SgdCrf.segment(line, model, tc);
boolean isFirst = true;
for (String word : result) {
if (isFirst) {
isFirst = false;
} else {
bw.write("\t");
}
bw.write(word);
}
bw.write("\n");
bw.write("Enter Chinese sentences to be segment, enter quit to exit!\n");
bw.flush();
} else {
// One line per alternative segmentation.
List<String[]> result = SgdCrf.segment(line, model, tc, nBest);
for (String[] tks : result) {
boolean isFirst = true;
for (String word : tks) {
if (isFirst) {
isFirst = false;
} else {
bw.write("\t");
}
bw.write(word);
}
bw.write("\n");
}
bw.write("Enter Chinese sentences to be segment, enter quit to exit!\n");
bw.flush();
}
}
br.close();
bw.close();
} else {
System.err.println("unknown command: " + command);
showUsageAndExit();
}
}
}