package joshua.discriminative.training.learning_algorithm;
import java.util.HashMap;
/*Zhifei Li, <zhifei.work@gmail.com>
* Johns Hopkins University
*/
/* Classes extending this one should implement:
 * (1) process_one_sent: get the reranked 1-best and collect its feature counts
 * (2) reranking of the hypotheses
 * (3) feature extraction from the 1-best and the oracle
 * */
public class DefaultPerceptron extends GradientBasedOptimizer {
  HashMap g_tbl_sum_model = null; // key: feature string; value: model parameter
  HashMap g_tbl_avg_model = null; // key: feature string; value: Double[3] = {last avg-model parameter, last iteration id, last sum-model parameter}
  public DefaultPerceptron(HashMap sum_model, HashMap avg_model, int train_size, int batch_update_size,
      int converge_pass, double init_gain, double sigma, boolean is_minimize_score) {
    super(train_size, batch_update_size, converge_pass, init_gain, sigma, is_minimize_score);
    g_tbl_sum_model = sum_model;
    g_tbl_avg_model = avg_model;
    if (g_tbl_sum_model == null || g_tbl_avg_model == null) {
      System.out.println("model table is null");
      System.exit(1); // exit with a failure code on a fatal configuration error
    }
  }
  public void initModel(double min_value, double max_value) {
    // no-op: features enter the model incrementally (see update_sum_model), starting at zero
  }
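  // Standard perceptron step: getGradient (inherited from GradientBasedOptimizer) presumably
  // returns, per feature, the difference between the empirical (oracle) counts and the model
  // (1-best) counts; update_sum_model then moves each weight by update_gain times that value.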
  // updates g_tbl_sum_model and g_tbl_avg_model in place
  public void updateModel(HashMap tbl_feats_empirical, HashMap tbl_feats_model) {
    numModelChanges++;
    System.out.println("######## update the perceptron model ############### " + numModelChanges);
    HashMap gradient = getGradient(tbl_feats_empirical, tbl_feats_model);
    //Support.print_hash_tbl(gradient);
    double update_gain = computeGain(numModelChanges);
    System.out.println("update gain is " + update_gain + "; gradient table size " + gradient.size());
    update_sum_model(g_tbl_sum_model, gradient, update_gain);
    update_avg_model(g_tbl_sum_model, g_tbl_avg_model, gradient, numModelChanges);
  }
  // updates tbl_sum_model in place: w[f] += update_gain * gradient[f]
  protected void update_sum_model(HashMap tbl_sum_model, HashMap gradient, double update_gain) {
    for (Object keyObj : gradient.keySet()) {
      String key = (String) keyObj;
      Double old_v = (Double) tbl_sum_model.get(key);
      double delta = update_gain * (Double) gradient.get(key);
      if (old_v != null)
        tbl_sum_model.put(key, old_v + delta);
      else
        tbl_sum_model.put(key, delta); // incrementally add a new feature
    }
  }
  // updates tbl_avg_model in place; feature_set holds the features to update
  protected void update_avg_model(HashMap tbl_sum_model, HashMap tbl_avg_model, HashMap feature_set, int cur_iter_id) {
    for (Object keyObj : feature_set.keySet()) {
      String key = (String) keyObj;
      update_avg_model_one_feature(tbl_sum_model, tbl_avg_model, key, cur_iter_id);
    }
  }
  // precondition: tbl_sum_model has already been updated for this feature
  // value layout: Double[3] = {last avg-model parameter, last iteration id, last sum-model parameter}
  // updates tbl_avg_model in place
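  // Lazy averaging: the averaged weight at iteration t is avg_t = (1/t) * sum_{i=1..t} w_i.
  // A feature's sum-model weight is constant between its updates, so each of the
  // t - t_last - 1 iterations skipped since the last update at t_last contributed the stale
  // weight w_last, while the current iteration contributes the fresh weight w_t:
  //   avg_t = (avg_last * t_last + w_last * (t - t_last - 1) + w_t) / t
  // If the feature was already updated at iteration t (as in force_update_avg_model), then
  // w_last == w_t and the formula reduces to avg_last, so re-running it is harmless.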
  protected void update_avg_model_one_feature(HashMap tbl_sum_model, HashMap tbl_avg_model, String feat_key, int cur_iter_id) {
    Double[] old_v = (Double[]) tbl_avg_model.get(feat_key);
    Double[] new_v = new Double[3];
    new_v[1] = (double) cur_iter_id; // iteration id
    new_v[2] = (Double) tbl_sum_model.get(feat_key); // sum-model parameter
    if (old_v != null)
      new_v[0] = (old_v[0] * old_v[1] + old_v[2] * (cur_iter_id - old_v[1] - 1) + new_v[2]) / cur_iter_id; // running average
    else // incrementally add a new feature
      new_v[0] = new_v[2] / cur_iter_id; // average over all iterations so far
    tbl_avg_model.put(feat_key, new_v);
  }
  // force-update the whole avg model (safe even for features already updated at the current iteration)
  public void force_update_avg_model() {
    System.out.println("force avg update is called");
    update_avg_model(g_tbl_sum_model, g_tbl_avg_model, g_tbl_sum_model, numModelChanges); // update all features
  }
  public HashMap getAvgModel() {
    return g_tbl_avg_model;
  }
  public HashMap getSumModel() {
    return g_tbl_sum_model;
  }
  public void setFeatureWeight(String feat, double weight) {
    g_tbl_sum_model.put(feat, weight);
    Double[] vals = new Double[3];
    vals[0] = weight;
    vals[1] = 1.0; // TODO: placeholder iteration id
    vals[2] = 0.0; // TODO: placeholder last sum-model parameter
    g_tbl_avg_model.put(feat, vals);
  }
}
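
/* A minimal usage sketch, not part of Joshua: the class name DefaultPerceptronDemo, the
 * feature name "lm_score", and all constructor arguments (train size 100, batch size 1,
 * one converge pass, init gain 0.1, sigma 1.0) are illustrative assumptions. It runs one
 * perceptron step on a toy feature table and reads back the averaged weight. */
class DefaultPerceptronDemo {
  public static void main(String[] args) {
    HashMap sumModel = new HashMap();
    HashMap avgModel = new HashMap();
    DefaultPerceptron perceptron = new DefaultPerceptron(
        sumModel, avgModel, 100, 1, 1, 0.1, 1.0, false);
    HashMap empirical = new HashMap(); // feature counts from the oracle hypothesis
    empirical.put("lm_score", 2.0);
    HashMap model = new HashMap(); // feature counts from the current 1-best
    model.put("lm_score", 1.0);
    perceptron.updateModel(empirical, model); // one update: weights move toward the oracle
    perceptron.force_update_avg_model(); // finalize the lazy averages before reading
    Double[] avg = (Double[]) perceptron.getAvgModel().get("lm_score");
    System.out.println("averaged weight for lm_score = " + avg[0]);
  }
}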