package experimental.analyzer.simple;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;
import marmot.util.Converter;
import marmot.util.Encoder;
import marmot.util.FeatUtil;
import marmot.util.FeatureTable;
import marmot.util.SymbolTable;
import experimental.analyzer.AnalyzerInstance;
import experimental.analyzer.AnalyzerReading;
import experimental.analyzer.AnalyzerTag;
public class SimpleAnalyzerModel implements Serializable {
private static final long serialVersionUID = 1L;
private SymbolTable<AnalyzerTag> tag_table_;
private SymbolTable<String> pos_table_;
private SymbolTable<String> morph_table_;
private SymbolTable<Character> char_table_;
private Set<String> vocab_;
private List<List<Integer>> tag_to_sub_;
private static enum Features {
affix_feature, signature_feature, dict_feature
}
private static final int feature_bits_ = Encoder.bitsNeeded(Features
.values().length - 1);
private int char_bits_;
private int sig_bits_;
private Encoder encoder_;
private FeatureTable feature_table_;
transient private Context context_;
private FloatDict dict_;
private boolean special_signature_ = true;
private int max_affix_length_ = 10;
private boolean use_hash_table_ = false;
private static class Context {
public List<Integer> list;
public boolean insert;
public Context() {
list = new ArrayList<>();
}
}
private double[] weights_;
private long feat_length_;
private int dict_bits_;
public void init(Collection<SimpleAnalyzerInstance> instances, String dictfile) {
tag_table_ = new SymbolTable<>(true);
pos_table_ = new SymbolTable<>();
morph_table_ = new SymbolTable<>();
tag_to_sub_ = new ArrayList<>();
char_table_ = new SymbolTable<>();
vocab_ = new HashSet<>();
Logger logger = Logger.getLogger(getClass().getName());
if (dictfile != null && !dictfile.isEmpty()) {
dict_ = FloatDict.fromFile(dictfile);
dict_bits_ = Encoder.bitsNeeded(dict_.getDimension());
logger.info(String.format(
"read dict with dimension %d and %d entries.",
dict_.getDimension(), dict_.numEntries()));
}
for (SimpleAnalyzerInstance instance : instances) {
init(instance, true);
}
// logger.info(String.format("tags: %d %s", tag_table_.size(),
// tag_table_));
// logger.info(String.format("pos tags: %d %s", pos_table_.size(),
// pos_table_));
// logger.info(String.format("morph tags: %d %s", morph_table_.size(),
// morph_table_));
sig_bits_ = Encoder.bitsNeeded(FeatUtil
.getMaxSignature(special_signature_));
char_bits_ = Encoder.bitsNeeded(char_table_.size());
encoder_ = new Encoder(6);
feature_table_ = FeatureTable.StaticMethods.create(use_hash_table_);
for (SimpleAnalyzerInstance instance : instances) {
add_features(instance, true);
}
feat_length_ = feature_table_.size();
weights_ = new double[10_000_000];
}
private void add_features(SimpleAnalyzerInstance instance, boolean insert) {
if (context_ == null) {
context_ = new Context();
}
context_.insert = insert;
context_.list.clear();
addAffixIndexes(instance.getFormChars());
encoder_.append(Features.signature_feature.ordinal(), feature_bits_);
encoder_.append(instance.getSignature(), sig_bits_);
addFeature(true);
instance.setFeatureIndexes(Converter.toIntArray(context_.list));
context_.list.clear();
FloatDict.Vector vector = instance.getVector();
if (vector != null) {
for (int index : vector.getIndexes()) {
encoder_.append(Features.dict_feature.ordinal(), feature_bits_);
encoder_.append(index, dict_bits_);
addFeature(true);
}
instance.setFloatFeatIndexes(Converter.toIntArray(context_.list));
}
}
private void init(SimpleAnalyzerInstance instance, boolean insert) {
List<Integer> tag_indexes = new LinkedList<>();
if (insert) {
vocab_.add(instance.getInstance().getForm());
}
for (AnalyzerTag tag : instance.getTags()) {
int tag_index = tag_table_.toIndex(tag, -1, insert);
if (tag_index >= 0) {
tag_indexes.add(tag_index);
if (tag_to_sub_.size() <= tag_index && insert) {
assert insert;
assert tag_to_sub_.size() == tag_index;
int pos_index = pos_table_.toIndex(tag.getPosTag(), true);
int morph_index = morph_table_.toIndex(tag.getMorphTag(),
true);
tag_to_sub_.add(Arrays.asList(pos_index, morph_index));
}
}
}
instance.setTagIndexes(tag_indexes);
AnalyzerInstance a_instance = instance.getInstance();
String form = a_instance.getForm();
int signature = FeatUtil.getSignature(form, special_signature_);
instance.setSignature(signature);
String lower_form = form.toLowerCase();
short[] form_chars = FeatUtil.getCharIndexes(lower_form, char_table_,
insert);
instance.setFormChars(form_chars);
if (dict_ != null) {
instance.setVector(dict_.getVector(form));
}
}
private void addPrefixFeatures(short[] chars) {
encoder_.append(false);
for (int i = 0; i < Math.min(chars.length, max_affix_length_); i++) {
int c = chars[i];
if (c < 0)
return;
encoder_.append(c, char_bits_);
addFeature(false);
}
}
private void addSuffixFeatures(short[] chars) {
encoder_.append(true);
for (int i = chars.length - 1; i >= Math.max(0, chars.length
- max_affix_length_); i--) {
int c = chars[i];
if (c < 0)
return;
encoder_.append(c, char_bits_);
addFeature(false);
}
}
private void addAffixIndexes(short[] lemma_chars) {
encoder_.append(Features.affix_feature.ordinal(), feature_bits_);
addPrefixFeatures(lemma_chars);
encoder_.reset();
encoder_.append(Features.affix_feature.ordinal(), feature_bits_);
addSuffixFeatures(lemma_chars);
encoder_.reset();
}
private void addFeature(boolean reset) {
int index = feature_table_.getFeatureIndex(encoder_, context_.insert);
if (index >= 0) {
context_.list.add(index);
}
if (reset)
encoder_.reset();
}
public double[] getWeights() {
return weights_;
}
public void setWeights(double[] weights) {
weights_ = weights;
}
public int getNumTags() {
return tag_table_.size();
}
private void updateScore(SimpleAnalyzerInstance instance, double[] scores,
double[] updates) {
for (int tag_index = 0; tag_index < getNumTags(); tag_index++) {
for (int feat_index : instance.getFeatIndexes()) {
updateScores(feat_index, scores, updates, tag_index, 1.0);
}
int[] float_feat_indexes = instance.getFloatFeatIndexes();
if (float_feat_indexes != null) {
double[] values = instance.getFloatValues();
for (int i = 0; i < float_feat_indexes.length; i++) {
double value = values[i];
updateScores(float_feat_indexes[i], scores, updates,
tag_index, value);
}
}
}
}
private void updateScores(int feat_index, double[] scores,
double[] updates, int tag_index, double value) {
long index;
List<Integer> sub_indexes = tag_to_sub_.get(tag_index);
long pos_index = sub_indexes.get(0);
index = (long) feat_index + feat_length_ * pos_index;
updateScore(index, scores, updates, tag_index, value);
long morph_index = sub_indexes.get(1);
index = (long) feat_index + feat_length_
* (morph_index + (long) pos_table_.size());
updateScore(index, scores, updates, tag_index, value);
index = (long) feat_index
+ feat_length_
* ((long) tag_index + (long) pos_table_.size() + (long) morph_table_
.size());
updateScore(index, scores, updates, tag_index, value);
}
private void updateScore(long index, double[] scores, double[] updates,
int tag_index, double value) {
int int_index = (int) (index % (long) weights_.length);
if (updates != null) {
weights_[int_index] += updates[tag_index] * value;
}
if (scores != null) {
scores[tag_index] += weights_[int_index] * value;
}
}
public void score(SimpleAnalyzerInstance instance, double[] scores) {
Arrays.fill(scores, 0.0);
updateScore(instance, scores, null);
}
public void update(SimpleAnalyzerInstance instance, double[] updates) {
updateScore(instance, null, updates);
}
public SimpleAnalyzerInstance getInstance(AnalyzerInstance instance) {
Collection<AnalyzerTag> tags = AnalyzerReading.toTags(instance
.getReadings());
SimpleAnalyzerInstance simple_instance = new SimpleAnalyzerInstance(
instance, tags);
init(simple_instance, false);
add_features(simple_instance, false);
return simple_instance;
}
public SymbolTable<AnalyzerTag> getTagTable() {
return tag_table_;
}
public boolean isUnknown(AnalyzerInstance instance) {
return !vocab_.contains(instance.getForm());
}
}