package com.antbrains.crf;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
import gnu.trove.map.hash.TObjectIntHashMap;
/**
 * Container for the weights and dictionaries of a CRF model.
 *
 * <p>Only the four weight arrays take part in Hadoop {@link Writable}
 * serialization (see {@link #write(DataOutput)} and
 * {@link #readFields(DataInput)}); the template, the dictionaries and the label
 * texts are not written by those methods.
 *
 * <p>NOTE(review): the class has no no-argument constructor; if instances are
 * ever created reflectively by Hadoop, one would presumably be required -
 * confirm against the callers.
 */
public class TrainingWeights implements java.io.Serializable, Writable {
    private static final long serialVersionUID = 8928028057374831674L;

    // weights of each label as start state
    private double[] bosTransitionWeights;
    // weights of each label as end state
    private double[] eosTransitionWeights;
    // weights from one label to another; a 1-D array represents the 2-D matrix for speed
    private double[] transitionWeights;
    // weights from label to feature; a 1-D array represents the 2-D matrix for speed
    private double[] attributeWeights;

    // maps a label string to its integer id
    private TObjectIntHashMap<String> labelDict;
    // maps a feature (attribute) to its id; concrete type is chosen in the constructor
    private FeatureDict attributeDict;
    private Template template;
    private String[] labelTexts;

    /**
     * Creates an instance that only holds the given template. The dictionaries
     * and all weight arrays are left {@code null} until set explicitly.
     *
     * @param template the feature template to associate with these weights
     */
    public TrainingWeights(Template template) {
        this.template = template;
    }

    /**
     * Creates an instance with an empty label dictionary and an attribute
     * dictionary of the requested kind.
     *
     * <p>If {@code dictType} is {@code null} or not one of the three handled
     * values, the attribute dictionary is left {@code null}.
     *
     * @param template the feature template to associate with these weights
     * @param dictType selects the {@link FeatureDict} implementation
     */
    public TrainingWeights(Template template, FeatureDictEnum dictType) {
        this.template = template;
        // initial capacity 10, load factor 0.75, and -1 returned for missing keys
        this.labelDict = new TObjectIntHashMap<String>(10, 0.75f, -1);
        if (dictType == FeatureDictEnum.TROVE_HASHMAP) {
            this.attributeDict = new TroveFeatureDict(102400);
        } else if (dictType == FeatureDictEnum.DOUBLE_ARRAY_TRIE) {
            this.attributeDict = new DATrieFeatureDict();
        } else if (dictType == FeatureDictEnum.COMPACT_TROVE_MAP) {
            this.attributeDict = new CompactedTroveFeatureDict(102400);
        }
    }

    /** @return the serialization version id of this class */
    public static long getSerialversionuid() {
        return serialVersionUID;
    }

    /** @return the start-state weights (the internal array, not a copy) */
    public double[] getBosTransitionWeights() {
        return bosTransitionWeights;
    }

    /** @param bosTransitionWeights the start-state weights; stored by reference */
    public void setBosTransitionWeights(double[] bosTransitionWeights) {
        this.bosTransitionWeights = bosTransitionWeights;
    }

    /** @return the end-state weights (the internal array, not a copy) */
    public double[] getEosTransitionWeights() {
        return eosTransitionWeights;
    }

    /** @param eosTransitionWeights the end-state weights; stored by reference */
    public void setEosTransitionWeights(double[] eosTransitionWeights) {
        this.eosTransitionWeights = eosTransitionWeights;
    }

    /** @return the flattened label-to-label weights (the internal array, not a copy) */
    public double[] getTransitionWeights() {
        return transitionWeights;
    }

    /** @param transitionWeights the flattened label-to-label weights; stored by reference */
    public void setTransitionWeights(double[] transitionWeights) {
        this.transitionWeights = transitionWeights;
    }

    /** @return the flattened label-to-feature weights (the internal array, not a copy) */
    public double[] getAttributeWeights() {
        return attributeWeights;
    }

    /** @param attributeWeights the flattened label-to-feature weights; stored by reference */
    public void setAttributeWeights(double[] attributeWeights) {
        this.attributeWeights = attributeWeights;
    }

    /** @return the feature template */
    public Template getTemplate() {
        return template;
    }

    /** @param template the feature template */
    public void setTemplate(Template template) {
        this.template = template;
    }

    /** @return the label dictionary, or {@code null} if none has been set */
    public TObjectIntHashMap<String> getLabelDict() {
        return labelDict;
    }

    /** @param labelDict the label dictionary */
    public void setLabelDict(TObjectIntHashMap<String> labelDict) {
        this.labelDict = labelDict;
    }

    /** @return the attribute dictionary, or {@code null} if none has been set */
    public FeatureDict getAttributeDict() {
        return attributeDict;
    }

    /** @param attributeDict the attribute dictionary */
    public void setAttributeDict(FeatureDict attributeDict) {
        this.attributeDict = attributeDict;
    }

    /** @return the label texts (the internal array, not a copy) */
    public String[] getLabelTexts() {
        return labelTexts;
    }

    /** @param labelTexts the label texts; stored by reference */
    public void setLabelTexts(String[] labelTexts) {
        this.labelTexts = labelTexts;
    }

    /**
     * Writes one array as an {@code int} length followed by that many doubles.
     *
     * @param out   destination
     * @param array values to write; must not be {@code null}
     * @throws IOException if the underlying output fails
     */
    private void writeDoubleArray(DataOutput out, double[] array) throws IOException {
        out.writeInt(array.length);
        for (double d : array) {
            out.writeDouble(d);
        }
    }

    /**
     * Reads one array in the format produced by
     * {@link #writeDoubleArray(DataOutput, double[])}.
     *
     * @param in    source
     * @param reuse an existing array that is filled and returned when its length
     *              matches the length read from the stream; may be {@code null}
     * @return {@code reuse} if it had the right length, otherwise a new array
     * @throws IOException if the stream fails or announces a negative length
     */
    private double[] readDoubleArray(DataInput in, double[] reuse) throws IOException {
        int len = in.readInt();
        if (len < 0) {
            // A negative prefix means the stream is corrupt or misaligned; report it as
            // an I/O problem instead of letting NegativeArraySizeException escape.
            throw new IOException("Negative array length in stream: " + len);
        }
        double[] target = (reuse != null && reuse.length == len) ? reuse : new double[len];
        for (int i = 0; i < len; i++) {
            target[i] = in.readDouble();
        }
        return target;
    }

    /**
     * Restores the four weight arrays, in the same order {@link #write(DataOutput)}
     * emits them: attribute, BOS, EOS, transition. An existing array is reused
     * when its length already matches the incoming one.
     *
     * @param in source
     * @throws IOException if the stream fails or contains a negative array length
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.attributeWeights = readDoubleArray(in, this.attributeWeights);
        this.bosTransitionWeights = readDoubleArray(in, this.bosTransitionWeights);
        this.eosTransitionWeights = readDoubleArray(in, this.eosTransitionWeights);
        this.transitionWeights = readDoubleArray(in, this.transitionWeights);
    }

    /**
     * Serializes the four weight arrays in the order attribute, BOS, EOS,
     * transition. All four arrays must be non-null; a {@code null} array causes
     * a {@link NullPointerException}.
     *
     * @param out destination
     * @throws IOException if the underlying output fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        this.writeDoubleArray(out, attributeWeights);
        this.writeDoubleArray(out, bosTransitionWeights);
        this.writeDoubleArray(out, eosTransitionWeights);
        this.writeDoubleArray(out, transitionWeights);
    }
}