// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.morph;
import java.util.Arrays;
import java.util.Collection;
import lemming.lemma.ranker.RankerModel;
import marmot.core.ArrayFloatFeatureVector;
import marmot.core.FeatureVector;
import marmot.core.FloatFeatureVector;
import marmot.core.FloatWeights;
import marmot.core.Model;
import marmot.core.Sequence;
import marmot.core.State;
import marmot.core.WeightVector;
import marmot.util.Encoder;
import marmot.util.FeatureTable;
// Feature-encoding weight vector of the MarMoT morphological tagger: packs
// feature components into bit strings with an Encoder, maps them to weight
// indexes via a FeatureTable, and supports scoring (dotProduct) and
// SGD/perceptron-style updates with an optional cumulative L1 penalty.
public class MorphWeightVector implements WeightVector, FloatWeights {
private static final long serialVersionUID = 1L;
// Maximum prefix/suffix (and infix window) length for character features.
private int max_affix_length_;
// Number of active state-feature template slots; must stay in sync with the
// fc counter advanced in extractStateFeatures (checked by an assert there).
private int num_state_features_;
// Initial capacity of the transient bit encoder.
private static final int ENCODER_CAPACITY_ = 10;
// If true, feature indexes are hashed into a fixed power-of-two table
// instead of growing weights_ on demand (see getIndex()).
private boolean use_hash_vector;
// Reusable bit-string encoder; rebuilt lazily after deserialization.
private transient Encoder encoder_;
// Global outstanding L1 penalty (stored divided by scale_factor_).
private double accumulated_penalty_;
// Per-index penalty already applied to discrete / float weights.
private transient double[] accumulated_penalties_;
private transient double[] accumulated_float_penalties_;
// Discrete feature weights; the last 2 * max_level_ slots hold the
// observed/unobserved indicator weights (see getObservedIndex()).
private double[] weights_;
// Dense weights for real-valued features (FloatWeights interface).
private double[] float_weights_;
// Whether feature extraction may still add new entries to the feature table.
private boolean extend_feature_set_;
private MorphModel model_;
private FeatureTable feature_table_;
// Offset of sub-morph tags inside the flattened (universal) tag space.
private int simple_sub_morph_start_index_;
// Bit widths used when packing the various feature components.
private int signature_bits_;
private int word_bits_;
private int[] tag_bits_;
private int state_feature_bits_;
private int char_bits_;
private int shape_bits_;
private int order_bits_;
// Tags per level and the total over all levels plus sub-morph tags.
private int[] num_tags_;
private int total_num_tags_;
private int level_bits_;
private int max_level_;
// Lazy scaling: stored weights are multiplied by this factor when scoring.
private double scale_factor_;
// Feature-template switches, configured from MorphOptions.
private boolean shape_;
private int initial_vector_size_;
private int token_feature_bits_;
// Levels above this threshold get no transition features (-1 = no limit
// behavior is handled in extractTransitionFeatures).
private int max_transition_feature_level_;
// Optional morphological dictionary and float-feature dictionary.
private MorphDictionary mdict_;
private FloatHashDictionary fdict_;
private int mdict_bits_;
private boolean use_state_features_;
private boolean use_form_feature_;
private boolean use_rare_feature_;
private boolean use_lexical_context_feature_;
private boolean use_affix_features_;
private boolean use_signature_features_;
private boolean use_infix_features_;
private boolean use_bigrams_;
private boolean use_hash_feature_table_;
/**
 * Configures the feature set from {@link MorphOptions}: which state-feature
 * templates are active, the affix length, hashing options and the optional
 * morphological / float-feature dictionaries. Also pre-computes
 * num_state_features_, the number of active state-feature template slots.
 */
public MorphWeightVector(MorphOptions options) {
    shape_ = options.getShape();
    max_transition_feature_level_ = options.getMaxTransitionFeatureLevel();
    initial_vector_size_ = options.getInitialVectorSize();
    use_state_features_ = options.getUseDefaultFeatures();
    use_hash_vector = options.getUseHashVector();
    max_affix_length_ = options.getMaxAffixLength();
    use_hash_feature_table_ = options.getUseHashFeatureTable();

    // Defaults: every template except the infix features is enabled.
    use_form_feature_ = true;
    use_rare_feature_ = true;
    use_lexical_context_feature_ = true;
    use_affix_features_ = true;
    use_signature_features_ = true;
    use_infix_features_ = false;
    use_bigrams_ = true;

    // An explicit comma-separated template list overrides the defaults.
    String feature_template = options.getFeatureTemplates();
    if (feature_template != null) {
        use_form_feature_ = false;
        use_rare_feature_ = false;
        use_lexical_context_feature_ = false;
        use_affix_features_ = false;
        use_signature_features_ = false;
        use_bigrams_ = false;
        for (String feat : feature_template.toLowerCase().split(",")) {
            switch (feat) {
            case "form":
                use_form_feature_ = true;
                break;
            case "affix":
                use_affix_features_ = true;
                break;
            case "rare":
                use_rare_feature_ = true;
                break;
            case "context":
                use_lexical_context_feature_ = true;
                break;
            case "sig":
                use_signature_features_ = true;
                break;
            case "bigrams":
                use_bigrams_ = true;
                break;
            case "infix":
                use_infix_features_ = true;
                break;
            default:
                throw new RuntimeException("Unknown value: " + feat);
            }
        }
    }

    // Disabling the default features turns every template off.
    if (!use_state_features_) {
        use_form_feature_ = false;
        use_rare_feature_ = false;
        use_lexical_context_feature_ = false;
        use_affix_features_ = false;
        use_signature_features_ = false;
        use_bigrams_ = false;
    }

    if (!options.getMorphDict().isEmpty()) {
        mdict_ = MorphDictionary.create(options.getMorphDict());
        mdict_bits_ = Encoder.bitsNeeded(mdict_.numTags());
    }

    // Count the active template slots. Affix contributes two (prefix and
    // suffix) and lexical context contributes two (previous and next form);
    // shape contributes three (current, previous and next shape).
    num_state_features_ = 0;
    if (use_form_feature_) {
        num_state_features_ += 1;
    }
    if (use_rare_feature_) {
        num_state_features_ += 1;
    }
    if (use_affix_features_) {
        num_state_features_ += 2;
    }
    if (use_lexical_context_feature_) {
        num_state_features_ += 2;
    }
    if (use_signature_features_) {
        num_state_features_ += 1;
    }
    if (shape_) {
        num_state_features_ += 3;
    }
    if (mdict_ != null) {
        num_state_features_ += 1;
    }
    if (use_infix_features_) {
        num_state_features_ += 1;
    }
    // One extra slot is always reserved for the token features.
    num_state_features_ += 1;

    if (!options.getFloatTypeDict().isEmpty()) {
        fdict_ = new FloatHashDictionary();
        MorphDictionaryOptions opt = MorphDictionaryOptions.parse(
                options.getFloatTypeDict(), false);
        if (opt.getIndexes() == null) {
            // Default to the first annotation column.
            opt.setIndexes(new int[] { 0 });
        }
        fdict_.init(opt);
    }
}
/**
 * Sets whether feature extraction may add new entries to the feature table
 * (the flag is forwarded to FeatureTable.getFeatureIndex via addFeature).
 */
@Override
public void setExtendFeatureSet(boolean flag) {
    extend_feature_set_ = flag;
}
/**
 * Enables or disables the cumulative L1 penalty used during training.
 *
 * When disabled, all penalty bookkeeping is dropped. When enabled, the
 * global penalty is stored divided by scale_factor_ (weights are kept
 * unscaled internally) and the per-index bookkeeping arrays are allocated
 * lazily. The penalty is also forwarded to the lemmatizer model, if any.
 *
 * Fix: removed the redundant {@code (double)} cast — the operand is
 * already a double expression.
 *
 * @param penalize whether L1 regularization is active
 * @param accumulated_penalty the total penalty accumulated so far
 */
@Override
public void setPenalty(boolean penalize, double accumulated_penalty) {
    if (!penalize) {
        accumulated_penalties_ = null;
        accumulated_float_penalties_ = null;
        accumulated_penalty_ = 0.0;
    } else {
        accumulated_penalty_ = accumulated_penalty / scale_factor_;
        if (accumulated_penalties_ == null) {
            accumulated_penalties_ = new double[weights_.length];
        }
        // Float penalties are only needed when float weights exist.
        if (accumulated_float_penalties_ == null && float_weights_ != null) {
            accumulated_float_penalties_ = new double[float_weights_.length];
        }
    }
    RankerModel model = model_.getLemmaModel();
    if (model != null) {
        model.setPenalty(penalize, accumulated_penalty);
    }
}
/**
 * Extracts higher-level state features for {@code state}: one feature per
 * prefix of the zero-order sub-level chain, each encoding the level and
 * tag index of every state visited so far. The word index of the
 * underlying state vector is propagated so getObservedIndex() still works.
 *
 * Fix: removed the unused local counter {@code fc} — it was always 0 at
 * its single use and the trailing {@code fc++} was dead code.
 */
@Override
public FeatureVector extractStateFeatures(State state) {
    prepareEncoder();
    MorphFeatureVector new_vector = new MorphFeatureVector(
            1 + state.getLevel(), state.getVector());
    encoder_.append(0, order_bits_);
    encoder_.append(state.getLevel() + 1, level_bits_);
    // Template slot 0, packed in two bits (as in the original encoding).
    encoder_.append(0, 2);
    State run = state.getZeroOrderState();
    while (run != null) {
        encoder_.append(run.getLevel(), level_bits_);
        encoder_.append(run.getIndex(), tag_bits_[run.getLevel()]);
        addFeature(new_vector);
        run = run.getSubLevelState();
    }
    encoder_.reset();
    new_vector.setIsState(true);
    new_vector.setWordIndex(((MorphFeatureVector) state.getVector())
            .getWordIndex());
    return new_vector;
}
/**
 * Extracts all zero-order (emission) features for the token at
 * {@code token_index}. Every feature is bit-packed as
 * (order = 0, level = 0, template counter fc, payload...) and mapped to an
 * index through addFeature(). fc advances once per template slot so that
 * different templates cannot collide; this counting must stay in sync with
 * num_state_features_ from the constructor (checked by the final assert).
 */
@Override
public FeatureVector extractStateFeatures(Sequence sequence, int token_index) {
prepareEncoder();
Word word = (Word) sequence.get(token_index);
// Morphological-dictionary tags for this word form, if a dict is loaded.
int[] mdict_indexes = null;
if (mdict_ != null) {
mdict_indexes = mdict_.getIndexes(word.getWordForm());
}
short[] chars = word.getCharIndexes();
assert chars != null;
int form_index = word.getWordFormIndex();
boolean is_rare = model_.isRare(form_index);
MorphFeatureVector features = new MorphFeatureVector(20);
int fc = 0;
if (use_state_features_) {
// --- Word form (only if the form is in the vocabulary) ---
if (use_form_feature_) {
if (form_index >= 0) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
encoder_.append(form_index, word_bits_);
addFeature(features);
encoder_.reset();
}
fc++;
}
// --- Rare-word indicator ---
if (use_rare_feature_) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
encoder_.append(is_rare);
addFeature(features);
encoder_.reset();
fc++;
}
// --- Word shape, fired only for rare words ---
if (shape_) {
int shape_index = -1;
shape_index = word.getWordShapeIndex();
if (is_rare && shape_index >= 0) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
encoder_.append(shape_index, shape_bits_);
addFeature(features);
encoder_.reset();
}
fc++;
}
// --- Previous token: form (optionally form bigram) and shape ---
if (token_index - 1 >= 0) {
int pform_index = ((Word) sequence.get(token_index - 1))
.getWordFormIndex();
if (use_lexical_context_feature_) {
if (pform_index >= 0) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
encoder_.append(pform_index, word_bits_);
addFeature(features);
// Bigram: previous form followed by the current form.
if (form_index >= 0 && use_bigrams_) {
encoder_.append(form_index, word_bits_);
addFeature(features);
}
encoder_.reset();
}
}
int pshape_index = -1;
if (shape_) {
pshape_index = ((Word) sequence.get(token_index - 1))
.getWordShapeIndex();
}
// The shape slot is fc + 1 because the context slot fc is only
// advanced after this whole block; the feature fires only when the
// previous form is rare.
// NOTE(review): pform_index may be -1 here — assumes isRare copes
// with out-of-vocabulary indexes; confirm in MorphModel.
if (pshape_index >= 0) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc + 1, state_feature_bits_);
encoder_.append(pshape_index, shape_bits_);
if (model_.isRare(pform_index)) {
addFeature(features);
}
encoder_.reset();
}
}
// Advance the slots even when token_index is at the boundary so the
// numbering stays position-independent.
if (use_lexical_context_feature_) {
fc++;
}
if (shape_) {
fc++;
}
// --- Next token: form (optionally form bigram) and shape ---
if (token_index + 1 < sequence.size()) {
int nform_index = ((Word) sequence.get(token_index + 1))
.getWordFormIndex();
if (use_lexical_context_feature_) {
if (nform_index >= 0) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
encoder_.append(nform_index, word_bits_);
addFeature(features);
if (form_index >= 0 && use_bigrams_) {
encoder_.append(form_index, word_bits_);
addFeature(features);
}
encoder_.reset();
}
}
int nshape_index = -1;
if (shape_) {
nshape_index = ((Word) sequence.get(token_index + 1))
.getWordShapeIndex();
}
if (nshape_index >= 0) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc + 1, state_feature_bits_);
encoder_.append(nshape_index, shape_bits_);
if (model_.isRare(nform_index)) {
addFeature(features);
}
encoder_.reset();
}
}
if (use_lexical_context_feature_) {
fc++;
}
if (shape_) {
fc++;
}
// --- Word signature (rare words only) ---
if (use_signature_features_) {
if (is_rare) {
int signature = word.getWordSignature();
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
encoder_.append(signature, signature_bits_);
addFeature(features);
encoder_.reset();
}
fc++;
}
// Infix feature: all character n-grams up to max_affix_length_,
// fired only for rare words.
if (use_infix_features_) {
if (is_rare) {
assert chars != null;
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
for (int position = 0; position < chars.length; position++) {
for (int length = 0; length < max_affix_length_; length++) {
int end_position = position + length;
if (end_position >= chars.length) {
break;
}
short c = chars[end_position];
if (c < 0) {
break;
}
encoder_.append(c, char_bits_);
addFeature(features);
}
// NOTE(review): this reset also clears the (order, level, fc)
// prefix appended before the loop, so features for position > 0
// consist of char bits only — looks like a bug; confirm the
// intended encoding before relying on infix features.
encoder_.reset();
}
}
fc++;
}
// Prefix feature: all prefixes up to max_affix_length_ (rare words).
if (use_affix_features_) {
if (is_rare) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
for (int position = 0; position < Math.min(chars.length,
max_affix_length_); position++) {
assert chars != null;
short c = chars[position];
if (c < 0) {
// Unknown character!
break;
}
encoder_.append(c, char_bits_);
addFeature(features);
}
encoder_.reset();
}
fc++;
}
// Suffix feature: all suffixes up to max_affix_length_ (rare words).
if (use_affix_features_) {
if (is_rare) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
for (int position = 0; position < Math.min(chars.length,
max_affix_length_); position++) {
short c = chars[chars.length - position - 1];
if (c < 0) {
// Unknown character!
break;
}
encoder_.append(c, char_bits_);
addFeature(features);
}
encoder_.reset();
}
fc++;
}
}
// --- Token features (active even without the default state features) ---
int[] token_feature_indexes = word.getTokenFeatureIndexes();
if (token_feature_indexes != null) {
for (int token_feature_index : token_feature_indexes) {
if (token_feature_index >= 0) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
encoder_.append(token_feature_index, token_feature_bits_);
addFeature(features);
encoder_.reset();
}
}
fc++;
}
// --- Real-valued features: float dictionary takes precedence over
// weighted token features ---
if (fdict_ != null) {
FloatFeatureVector vector = extractFloatFeatures(sequence,
token_index);
features.setFloatVector(vector);
} else {
token_feature_indexes = word.getWeightedTokenFeatureIndexes();
if (token_feature_indexes != null) {
features.setFloatVector(new ArrayFloatFeatureVector(
token_feature_indexes, word
.getWeightedTokenFeatureWeights(), model_
.getWeightedTokenFeatureTable().size()));
}
}
// --- Morphological-dictionary tags ---
if (mdict_ != null) {
if (mdict_indexes != null) {
for (int index : mdict_indexes) {
if (index >= 0) {
encoder_.append(0, order_bits_);
encoder_.append(0, level_bits_);
encoder_.append(fc, state_feature_bits_);
encoder_.append(index, mdict_bits_);
addFeature(features);
encoder_.reset();
}
}
}
fc++;
}
features.setIsState(true);
features.setWordIndex(form_index);
// The constructor always reserves one slot for token features, but fc
// only advances when the word actually has token-feature indexes —
// hence fc + 1 == num_state_features_ is also accepted.
assert fc == num_state_features_ || fc + 1 == num_state_features_ : String
.format("%d != %d", fc, num_state_features_);
return features;
}
/**
 * Looks up the encoder's current bit string in the feature table and, if
 * it maps to a known (or, when extension is enabled, newly added) feature,
 * appends that feature index to the vector.
 */
private void addFeature(FeatureVector features) {
    final int feature_index = feature_table_.getFeatureIndex(encoder_,
            extend_feature_set_);
    if (feature_index < 0) {
        return; // unseen feature while the feature set is frozen
    }
    features.add(feature_index);
}
/**
 * Returns the float-feature vector of the word form at {@code token_index},
 * or null when the index is out of range (fdict_.getVector may itself
 * return null for unknown forms).
 */
private FloatFeatureVector extractFloatFeatures(Sequence sentence,
        int token_index) {
    if (token_index < 0 || token_index >= sentence.size()) {
        return null;
    }
    Word word = (Word) sentence.get(token_index);
    return fdict_.getVector(word.getWordForm());
}
/**
 * Extracts transition features for {@code state}: for each depth of the
 * sub-level hierarchy (skipping levels above the configured
 * max_transition_feature_level_) one feature encoding the model order, the
 * level and the tag indexes along the previous-sub-order chain.
 */
@Override
public FeatureVector extractTransitionFeatures(State state) {
prepareEncoder();
int max_level = state.getLevel();
int order = state.getOrder();
FeatureVector features = new FeatureVector(max_level + 1
+ model_.getNumSubTags());
for (int depth = 0; depth <= max_level; depth++) {
int level = max_level - depth;
// Skip levels above the transition-feature cutoff (a negative cutoff
// disables the limit).
if (max_transition_feature_level_ >= 0
&& level > max_transition_feature_level_) {
continue;
}
encoder_.append(order, order_bits_);
encoder_.append(level, level_bits_);
// Constant marker bit (always 0).
encoder_.append(0, 1);
// Append the tag index at this depth for every state on the
// previous-sub-order chain.
State run = state;
while (run != null) {
State sub_state = run.getSubLevel(depth);
int index = sub_state.getIndex();
encoder_.append(index, tag_bits_[level]);
run = run.getPreviousSubOrderState();
}
addFeature(features);
encoder_.reset();
}
return features;
}
/** Raw (unscaled) lookup of a discrete feature weight by dense index. */
protected double getWeight(int index) {
    return weights_[index];
}
/**
 * Scores a state: sums the weights of all features in {@code vector}
 * combined with the zero-order tag, plus float-feature weights, sub-tag
 * weights and (for state vectors) the observed/unobserved indicator
 * weight. The result is multiplied by the global scale factor.
 */
@Override
public double dotProduct(State state, FeatureVector vector) {
    assert vector != null;
    State zero_order_state = state.getZeroOrderState();
    final int tag_index = getUniversalIndex(zero_order_state);
    double score = 0.0;
    final int num_features = vector.size();
    for (int f = 0; f < num_features; f++) {
        // getIndex may grow the weight array as a side effect.
        score += getWeight(getIndex(vector.get(f), tag_index));
    }
    FloatFeatureVector float_vector = vector.getFloatVector();
    if (float_vector != null) {
        score += float_vector.getDotProduct(this, tag_index, 0);
    }
    score += dotProductSubTags(zero_order_state, vector);
    if (vector.getIsState()) {
        int observed_index = getObservedIndex((MorphFeatureVector) vector,
                state);
        if (observed_index >= 0) {
            score += getWeight(observed_index);
        }
    }
    return score * scale_factor_;
}
/** Raw lookup of a float-feature weight (FloatWeights interface). */
public double getFloatWeight(int index) {
    return float_weights_[index];
}

/**
 * Maps a (float feature, universal tag) pair into float_weights_:
 * features are laid out in consecutive blocks of total_num_tags_ entries.
 */
public int getFloatIndex(int feature, int tag_index) {
    return tag_index + feature * total_num_tags_;
}
/**
 * Adds the scores contributed by the sub-morph tags of the state's tag:
 * for every sub-tag, all features of {@code vector} (and the float
 * features) are scored against the sub-tag's universal index. Returns 0
 * when the level or tag has no sub-tag decomposition.
 */
private double dotProductSubTags(State state, FeatureVector vector) {
    final int level = state.getLevel();
    if (level >= model_.getTagToSubTags().length) {
        return 0.0;
    }
    int[][] sub_tags_by_tag = model_.getTagToSubTags()[level];
    if (sub_tags_by_tag == null) {
        return 0.0;
    }
    int[] sub_tag_indexes = sub_tags_by_tag[state.getIndex()];
    if (sub_tag_indexes == null) {
        return 0.0;
    }
    double score = 0.0;
    FloatFeatureVector float_vector = vector.getFloatVector();
    for (int sub_tag : sub_tag_indexes) {
        final int simple_index = getSimpleSubMorphIndex(sub_tag);
        for (int f = 0; f < vector.size(); f++) {
            score += getWeight(getIndex(vector.get(f), simple_index));
        }
        if (float_vector != null) {
            score += float_vector.getDotProduct(this, simple_index, 0);
        }
    }
    return score;
}
/**
 * Encodes the tag indexes along the sub-level chain of {@code state} into
 * a single mixed-radix product index; used as the key for
 * MorphModel.hasBeenObserved() in getObservedIndex().
 */
private int getProductIndex(State state) {
if (state.getLevel() == 0) {
return state.getIndex();
}
// NOTE(review): the radix is the tag count of the *current* level — this
// must match the encoding used when observations were recorded; confirm
// against MorphModel.
int size = num_tags_[state.getLevel()];
return getProductIndex(state.getSubLevelState()) * size
+ state.getIndex();
}
/**
 * Converts a (tag index, level) pair into an index in the flat "universal"
 * tag space obtained by concatenating the tag sets of all levels.
 */
private int getUniversalIndex(int tag_index, int level) {
    int offset = 0;
    for (int l = 0; l < level; l++) {
        offset += num_tags_[l];
    }
    return offset + tag_index;
}

/** Convenience overload taking a zero-order state. */
private int getUniversalIndex(State zero_order_state) {
    return getUniversalIndex(zero_order_state.getIndex(),
            zero_order_state.getLevel());
}
/**
 * Maps a (feature, universal tag index) pair to a position in weights_.
 * In hashing mode the index is spread with HashMap-style bit mixing and
 * masked into the table; this assumes the usable capacity
 * (weights_.length - 2 * max_level_) is a power of two, which init()
 * arranges — TODO confirm it still holds after deserialization. In
 * non-hashing mode the array grows on demand, relocating the
 * 2 * max_level_ indicator slots kept at the tail.
 */
private int getIndex(int feature, int tag_index) {
int index = feature * total_num_tags_ + tag_index;
int capacity = weights_.length - 2 * max_level_;
int h = index;
if (use_hash_vector) {
// Supplemental hash (same spreading as java.util.HashMap) before masking.
h ^= (h >>> 20) ^ (h >>> 12);
h = h ^ (h >>> 7) ^ (h >>> 4);
h = h & (capacity - 1);
} else {
if (index >= capacity) {
// Grow to ~1.5x the required size and move the tail indicator slots
// (and their penalty bookkeeping) to the new end of the array.
int old_capacity = capacity;
capacity = (3 * (index + 1)) / 2;
int length = capacity + 2 * max_level_;
weights_ = Arrays.copyOf(weights_, length);
if (accumulated_penalties_ != null) {
accumulated_penalties_ = Arrays.copyOf(
accumulated_penalties_, length);
}
for (int i = 0; i < 2 * max_level_; i++) {
weights_[capacity + i] = weights_[old_capacity + i];
weights_[old_capacity + i] = 0.0;
if (accumulated_penalties_ != null) {
accumulated_penalties_[capacity + i] = accumulated_penalties_[old_capacity
+ i];
accumulated_penalties_[old_capacity + i] = 0.0;
}
}
}
}
assert h >= 0 : String.format("H: %d", h);
assert h < capacity : String.format("H: %d Capacity: %d", h, capacity);
return h;
}
/**
 * Initializes the vector for {@code model}: computes per-level tag counts,
 * the bit widths of all feature components, allocates the weight arrays
 * and creates the feature table. {@code sequences} is not used here.
 */
@Override
public void init(Model model, Collection<Sequence> sequences) {
int max_level = model.getTagTables().size();
feature_table_ = FeatureTable.StaticMethods
.create(use_hash_feature_table_);
model_ = (MorphModel) model;
max_level_ = max_level;
num_tags_ = new int[max_level];
total_num_tags_ = 0;
tag_bits_ = new int[max_level];
// (Math.min is redundant here — max_level already equals the table count.)
for (int level = 0; level < Math.min(model.getTagTables().size(),
max_level); level++) {
num_tags_[level] = model.getTagTables().get(level).size();
tag_bits_[level] = Encoder.bitsNeeded(num_tags_[level]);
total_num_tags_ += num_tags_[level];
}
// Sub-morph tags are appended after the per-level tags in the
// universal tag space.
simple_sub_morph_start_index_ = total_num_tags_;
total_num_tags_ += model_.getNumSubTags();
word_bits_ = Encoder.bitsNeeded(model_.getWordTable().size());
state_feature_bits_ = Encoder.bitsNeeded(num_state_features_);
char_bits_ = Encoder.bitsNeeded(model_.getCharTable().size());
if (shape_)
shape_bits_ = Encoder.bitsNeeded(model_.getNumShapes());
order_bits_ = Encoder.bitsNeeded(model.getOrder());
level_bits_ = Encoder.bitsNeeded(max_level);
signature_bits_ = Encoder.bitsNeeded(model_.getMaxSignature());
token_feature_bits_ = Encoder.bitsNeeded(model_.getTokenFeatureTable()
.size());
// Float weights: one block of total_num_tags_ entries per float feature
// dimension (from the float dictionary or the weighted token features).
if (fdict_ != null) {
float_weights_ = new double[fdict_.getDimension() * total_num_tags_];
} else {
float_weights_ = new double[model_.getWeightedTokenFeatureTable()
.size() * total_num_tags_];
}
extend_feature_set_ = true;
scale_factor_ = 1.;
// Round the initial capacity up to a power of two — required by the
// mask-based indexing in getIndex() when use_hash_vector is set.
int capacity = 1;
int initial_size = initial_vector_size_;
while (capacity < initial_size)
capacity <<= 1;
// The extra 2 * max_level slots at the tail hold the observed/unobserved
// indicator weights (see getObservedIndex()).
weights_ = new double[capacity + 2 * max_level];
}
/**
 * Adds {@code value} to the weights of every feature of the state's
 * vector, combined with each tag on the zero-order sub-level chain, plus
 * the float-feature and sub-tag weights. Mirrors the scoring performed in
 * dotProduct()/dotProductSubTags().
 */
private void update(State state, double value) {
FeatureVector vector = state.getVector();
if (vector != null) {
State run = state.getZeroOrderState();
while (run != null) {
int tag_index = getUniversalIndex(run);
for (int findex = 0; findex < vector.size(); findex++) {
int feature = vector.get(findex);
int index = getIndex(feature, tag_index);
updateWeight(index, value);
}
FloatFeatureVector float_vector = vector.getFloatVector();
if (float_vector != null) {
float_vector.updateFloatWeight(this, tag_index, 0, value);
}
updateSubTags(run, vector, value);
// For first-order states the chain is terminated after one iteration
// and the update recurses into the sub-level state instead; the
// observed/unobserved indicator weight is also updated exactly once.
if (state.getOrder() == 1) {
run = null;
State sub_level_state = state.getSubLevelState();
if (sub_level_state != null)
update(sub_level_state, value);
if (vector.getIsState()) {
int index = getObservedIndex(
(MorphFeatureVector) vector, state);
if (index >= 0) {
updateWeight(index, value);
}
}
} else {
run = run.getSubLevelState();
}
}
}
}
/**
 * Update counterpart of dotProductSubTags(): adds {@code value} to the
 * weights of all features of {@code vector} combined with every sub-morph
 * tag of the state's tag, plus the corresponding float weights. A no-op
 * when the level or tag has no sub-tag decomposition.
 */
private void updateSubTags(State state, FeatureVector vector, double value) {
    final int level = state.getLevel();
    if (level >= model_.getTagToSubTags().length) {
        return;
    }
    int[][] sub_tags_by_tag = model_.getTagToSubTags()[level];
    if (sub_tags_by_tag == null) {
        return;
    }
    int[] sub_tag_indexes = sub_tags_by_tag[state.getIndex()];
    if (sub_tag_indexes == null) {
        return;
    }
    FloatFeatureVector float_vector = vector.getFloatVector();
    for (int sub_tag : sub_tag_indexes) {
        final int simple_index = getSimpleSubMorphIndex(sub_tag);
        for (int f = 0; f < vector.size(); f++) {
            updateWeight(getIndex(vector.get(f), simple_index), value);
        }
        if (float_vector != null) {
            float_vector.updateFloatWeight(this, simple_index, 0, value);
        }
    }
}
/**
 * Sub-morph tags live after the regular per-level tags in the universal
 * tag space; this maps a sub-morph index to its universal index.
 */
protected int getSimpleSubMorphIndex(int sub_morph_index) {
    return sub_morph_index + simple_sub_morph_start_index_;
}

/**
 * Adds {@code value} to a discrete weight and, when L1 regularization is
 * active, clips the result toward zero via the accumulated-penalty
 * bookkeeping (see applyPenalty()).
 */
protected void updateWeight(int index, double value) {
    weights_[index] += value;
    if (accumulated_penalties_ != null) {
        weights_[index] = applyPenalty(index, weights_[index],
                accumulated_penalties_);
    }
}
/**
 * Adds {@code value} to a float weight and applies the cumulative L1
 * penalty when regularization is active.
 *
 * Bug fix: the original guarded on {@code accumulated_penalties_} (the
 * array for the discrete weights) but then dereferenced
 * {@code accumulated_float_penalties_}, which setPenalty() only allocates
 * when {@code float_weights_} was non-null at that time — so the original
 * could throw a NullPointerException inside applyPenalty(). Guard on the
 * array that is actually used.
 */
public void updateFloatWeight(int index, double value) {
    float_weights_[index] += value;
    if (accumulated_float_penalties_ != null) {
        float_weights_[index] = applyPenalty(index, float_weights_[index],
                accumulated_float_penalties_);
    }
}
/**
 * Returns the index of the observed/unobserved indicator weight for the
 * (word, level, tag combination) of {@code vector} and {@code state}.
 * These indicator weights occupy the last 2 * max_level_ slots of
 * weights_, two per level: slot 0 = observed, slot 1 = unobserved.
 */
protected int getObservedIndex(MorphFeatureVector vector, State state) {
int word_index = vector.getWordIndex();
int level = state.getLevel();
int product_index = getProductIndex(state);
int feature = model_.hasBeenObserved(word_index, level, product_index) ? 0
: 1;
int start_index = weights_.length - max_level_ * 2;
int index = start_index + level * 2 + feature;
return index;
}
/**
 * Clips {@code weight} toward zero by the outstanding L1 penalty and
 * records the amount actually applied at {@code index}. This appears to
 * implement the cumulative L1 penalty of Tsuruoka et al. (2009) for
 * stochastic gradient training — TODO confirm against the training loop.
 */
protected double applyPenalty(int index, double weight,
double[] accumulated_penalty) {
double z = weight;
// 1e-10 tolerance: treat near-zero weights as zero and leave them alone.
if (z - 1e-10 > 0.) {
// Positive weight: subtract the remaining penalty, never crossing zero.
weight = Math.max(0, z
- (accumulated_penalty_ + accumulated_penalty[index]));
} else if (z + 1e-10 < 0.) {
// Negative weight: add the remaining penalty, never crossing zero.
weight = Math.min(0, z
+ (accumulated_penalty_ - accumulated_penalty[index]));
}
// Record the penalty actually applied to this index.
accumulated_penalty[index] += weight - z;
return weight;
}
/**
 * Lazily creates the transient encoder (e.g. after deserialization) and
 * clears it for reuse.
 */
protected void prepareEncoder() {
    if (encoder_ == null) {
        encoder_ = new Encoder(ENCODER_CAPACITY_);
    }
    encoder_.reset();
}
/**
 * Applies an update of {@code value} to the given state and — for
 * non-transition (emission) updates — to every state down its sub-order
 * chain. Weights are stored unscaled, so the value is divided by the
 * global scale factor first.
 */
@Override
public void updateWeights(State state, double value, boolean is_transition) {
    final double unscaled = value / scale_factor_;
    update(state, unscaled);
    if (is_transition) {
        return;
    }
    for (State sub = state.getSubOrderState(); sub != null;
            sub = sub.getSubOrderState()) {
        update(sub, unscaled);
    }
}
/**
 * Rescales the whole vector lazily by folding the factor into
 * scale_factor_; the stored penalty is divided so the effective penalty
 * stays constant under rescaling.
 */
@Override
public void scaleBy(double scale_factor) {
    accumulated_penalty_ /= scale_factor;
    scale_factor_ *= scale_factor;
}
/** @return the backing array of discrete feature weights (unscaled). */
@Override
public double[] getWeights() {
    return weights_;
}

/** Replaces the discrete weight array. */
@Override
public void setWeights(double[] weights) {
    weights_ = weights;
}

/** @return the morphological dictionary, or null if none was configured. */
public MorphDictionary getMorphDict() {
    return mdict_;
}

/** @return the backing array of float-feature weights. */
@Override
public double[] getFloatWeights() {
    return float_weights_;
}

/** Replaces the float weight array. */
@Override
public void setFloatWeights(double[] weights) {
    float_weights_ = weights;
}

/** @return the morphological model this vector was initialized with. */
public MorphModel getModel() {
    return model_;
}

/** @return the table mapping encoded feature strings to indexes. */
public FeatureTable getFeatureTable() {
    return feature_table_;
}
}