package experimental.analyzer.simple; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Random; import java.util.logging.Level; import java.util.logging.Logger; import marmot.util.Counter; import marmot.util.Mutable; import org.javatuples.Pair; import cc.mallet.optimize.LimitedMemoryBFGS; import cc.mallet.optimize.Optimizable.ByGradientValue; import cc.mallet.optimize.OptimizationException; import cc.mallet.optimize.Optimizer; import experimental.analyzer.Analyzer; import experimental.analyzer.AnalyzerInstance; import experimental.analyzer.AnalyzerReading; import experimental.analyzer.AnalyzerTag; import experimental.analyzer.AnalyzerTrainer; import experimental.analyzer.simple.SimpleAnalyzer.Mode; public class SimpleAnalyzerTrainer extends AnalyzerTrainer { private Mode train_mode_; private Mode tag_mode_; private double penalty_; public final static String MODE = "mode"; public final static String PENALTY = "penalty"; public final static String PAIR_CONSTRAINT = "pair-constraint"; public final static String PAIR_CONSTRAINT_THRESHOLD = "pair-constraint-threshold"; private boolean optimize_threshold_ = false; private boolean mallet_ = false; enum PairConstraint {simple, weighted, none}; private PairConstraint pair_constraint_ = PairConstraint.weighted; private double pair_constraint_threshold_ = 0.1; private Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts_ = null; @Override public Analyzer train(Collection<AnalyzerInstance> instances) { System.err.format("Num instances: %d\n", instances.size()); boolean use_simple_optimizer = false; boolean couple_tags = false; tag_mode_ = Mode.binary; train_mode_ = Mode.binary; if (options_.containsKey(MODE)) { Mode mode = Mode.valueOf(options_.get(MODE)); tag_mode_ = mode; train_mode_ = mode; } if (options_.containsKey(PAIR_CONSTRAINT)) { pair_constraint_ = PairConstraint.valueOf(options_.get(PAIR_CONSTRAINT)); } if (options_.containsKey(PAIR_CONSTRAINT_THRESHOLD)) { pair_constraint_threshold_ = Double.valueOf(options_.get(PAIR_CONSTRAINT_THRESHOLD)); } System.err.format("Modes: %s / %s\n", tag_mode_, train_mode_); penalty_ = 1.0; if (options_.containsKey(PENALTY)) { penalty_ = Double.valueOf(options_.get(PENALTY)); } System.err.format("Penalty: %g\n", penalty_); Collection<Pair<AnalyzerTag, AnalyzerTag>> coupled = null; if (couple_tags) coupled = getCoupledTags(instances); if (pair_constraint_ != PairConstraint.none) { preparePairConstraints(instances); } Collection<SimpleAnalyzerInstance> simple_instances = new LinkedList<>(); for (AnalyzerInstance instance : instances) { Collection<AnalyzerTag> tags = AnalyzerReading.toTags(instance .getReadings()); simple_instances.add(new SimpleAnalyzerInstance(instance, tags)); } SimpleAnalyzerModel model = new SimpleAnalyzerModel(); String float_dict_file = null; if (options_.containsKey(AnalyzerTrainer.FLOAT_DICT_)) { float_dict_file = options_.get(AnalyzerTrainer.FLOAT_DICT_); } model.init(simple_instances, float_dict_file); if (mallet_) { run_mallet(model, simple_instances); } else { run_sgd(model, simple_instances, 10, true, 0.1); } double best_threshold = 0.01; if (optimize_threshold_) { SimpleThresholdOptimizer opt = new SimpleThresholdOptimizer( use_simple_optimizer); best_threshold = opt.findTreshold(model, instances, tag_mode_); System.err.println("Best threshold on train: " + best_threshold); } SimpleAnalyzer analyzer = new SimpleAnalyzer(model, best_threshold, tag_mode_, coupled); return analyzer; } private void preparePairConstraints(Collection<AnalyzerInstance> instances) { TagStats stats = getTagStates(instances); relative_counts_ = new HashMap<>(); for (Map.Entry<Pair<AnalyzerTag, AnalyzerTag>, Double> entry : stats.tag_tag_counts .entrySet()) { Pair<AnalyzerTag, AnalyzerTag> pair = entry.getKey(); Double count = entry.getValue(); addRelativeProb(pair.getValue0(), pair.getValue1(), count, stats.tag_counts.count(pair.getValue0()), relative_counts_); addRelativeProb(pair.getValue1(), pair.getValue0(), count, stats.tag_counts.count(pair.getValue1()), relative_counts_); } for (Map.Entry<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> entry : relative_counts_ .entrySet()) { AnalyzerTag tag = entry.getKey(); Map<AnalyzerTag, Mutable<Double>> map = entry.getValue(); map.put(tag, new Mutable<Double>(1.0)); double sum = 0.0; for (Mutable<Double> count : map.values()) { sum += count.get(); } for (Mutable<Double> count : map.values()) { count.set(count.get() / sum); } System.err.println(tag + " " + map); } } private void addRelativeProb(AnalyzerTag tag, AnalyzerTag other_tag, Double tag_tag_count, Double tag_count, Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts) { double prob = tag_tag_count / tag_count; if (prob > pair_constraint_threshold_) { Map<AnalyzerTag, Mutable<Double>> map = relative_counts.get(tag); if (map == null) { map = new HashMap<>(); relative_counts.put(tag, map); } assert !map.containsKey(other_tag); map.put(other_tag, new Mutable<Double>(prob)); } } private void run_sgd(SimpleAnalyzerModel model, Collection<SimpleAnalyzerInstance> simple_instances, int steps_, boolean verbose_, double step_width_) { List<SimpleAnalyzerInstance> instances = new LinkedList<>( simple_instances); SimpleAnalyzerObjective objective = new SimpleAnalyzerObjective( penalty_, model, simple_instances, train_mode_, relative_counts_, pair_constraint_); int number = 0; Random random = new Random(42); for (int step = 0; step < steps_; step++) { if (verbose_) System.err.println("step: " + step); Collections.shuffle(instances, random); for (SimpleAnalyzerInstance instance : instances) { double step_width = step_width_ / (1 + (number / (double) instances.size())); objective.update(instance, step_width, true); number++; } } } private void run_mallet(SimpleAnalyzerModel model, Collection<SimpleAnalyzerInstance> simple_instances) { Logger logger = Logger.getLogger(getClass().getName()); logger.info("Start optimization"); ByGradientValue objective = new SimpleAnalyzerObjective(penalty_, model, simple_instances, train_mode_, relative_counts_, pair_constraint_); Optimizer optimizer = new LimitedMemoryBFGS(objective); Logger.getLogger(optimizer.getClass().getName()).setLevel(Level.OFF); objective.setParameters(model.getWeights()); // SimpleAnalyzer analyzer = new SimpleAnalyzer(model, 0.5); try { optimizer.optimize(1); // double memory_usage_during_optimization = // Sys.getUsedMemoryInMegaBytes(); // logger.info(String.format("Memory usage after first iteration: %g / %g MB", // memory_usage_during_optimization, // Sys.getMaxHeapSizeInMegaBytes())); for (int i = 0; i < 200 && !optimizer.isConverged(); i++) { optimizer.optimize(1); logger.info(String.format("Iteration: %3d / %3d: %g", i + 1, 200, objective.getValue())); } } catch (IllegalArgumentException e) { } catch (OptimizationException e) { } } private static class TagStats { Counter<AnalyzerTag> tag_counts = new Counter<>(); Counter<Pair<AnalyzerTag, AnalyzerTag>> tag_tag_counts = new Counter<>(); public TagStats() { tag_counts = new Counter<>(); tag_tag_counts = new Counter<>(); } } private TagStats getTagStates(Collection<AnalyzerInstance> instances) { TagStats stats = new TagStats(); for (AnalyzerInstance instance : instances) { Collection<AnalyzerTag> tags = AnalyzerReading.toTags(instance .getReadings()); for (AnalyzerTag tag : tags) { stats.tag_counts.increment(tag, 1.0); } List<AnalyzerTag> tag_list = new ArrayList<>(tags); for (int i = 0; i < tag_list.size(); i++) { AnalyzerTag tag = tag_list.get(i); for (int j = i + 1; j < tag_list.size(); j++) { AnalyzerTag other_tag = tag_list.get(j); if (tag.hashCode() < other_tag.hashCode()) { stats.tag_tag_counts.increment(new Pair<>(other_tag, tag), 1.0); } else { stats.tag_tag_counts.increment(new Pair<>(tag, other_tag), 1.0); } } } } return stats; } private Collection<Pair<AnalyzerTag, AnalyzerTag>> getCoupledTags( Collection<AnalyzerInstance> instances) { TagStats stats = getTagStates(instances); Collection<Pair<AnalyzerTag, AnalyzerTag>> coupled = new LinkedList<>(); for (Map.Entry<Pair<AnalyzerTag, AnalyzerTag>, Double> entry : stats.tag_tag_counts .entrySet()) { Pair<AnalyzerTag, AnalyzerTag> pair = entry.getKey(); double tag_count = stats.tag_counts.count(pair.getValue0()); assert tag_count < instances.size(); double other_tag_count = stats.tag_counts.count(pair.getValue1()); assert other_tag_count < instances.size(); double joint_count = entry.getValue(); assert joint_count < instances.size(); if (entry.getValue() >= 10) { double pseudo_pmi = joint_count / Math.sqrt(tag_count * other_tag_count); if (pseudo_pmi > 0.99) { coupled.add(pair); } } } System.err.println("|Coupled|: " + coupled.size()); System.err.println("Coupled: " + coupled); return coupled; } }