package experimental.analyzer.simple; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Map; import marmot.util.Numerics; import marmot.util.SymbolTable; import org.javatuples.Pair; import experimental.analyzer.Analyzer; import experimental.analyzer.AnalyzerInstance; import experimental.analyzer.AnalyzerReading; import experimental.analyzer.AnalyzerTag; public class SimpleAnalyzer implements Analyzer { private static final long serialVersionUID = 1L; public static enum Mode {binary, classifier}; private SimpleAnalyzerModel model_; private double threshold_; private Mode mode_; private List<List<Integer>> pairs_; public SimpleAnalyzer(SimpleAnalyzerModel model, double threshold, Mode mode, Collection<Pair<AnalyzerTag, AnalyzerTag>> coupled) { model_ = model; threshold_ = threshold; mode_ = mode; if (coupled != null) { SymbolTable<AnalyzerTag> table = model_.getTagTable(); pairs_ = new ArrayList<>(table.size()); for (int i=0;i<table.size();i++) { pairs_.add(Collections.<Integer> emptyList()); } for (Pair<AnalyzerTag, AnalyzerTag> pair : coupled) { int index = table.toIndex(pair.getValue0()); int other_index = table.toIndex(pair.getValue1()); assert index != other_index; if (index > other_index) { int temp = other_index; other_index = index; index = temp; } List<Integer> list = pairs_.get(other_index); if (list.isEmpty()) { list = new LinkedList<>(); pairs_.set(other_index, list); } list.add(index); } } } @Override public Collection<AnalyzerReading> analyze(AnalyzerInstance instance) { SimpleAnalyzerInstance simple_instance = model_.getInstance(instance); double[] scores = new double[model_.getNumTags()]; model_.score(simple_instance, scores); Collection<AnalyzerReading> readings = new LinkedList<>(); SymbolTable<AnalyzerTag> tag_table = model_.getTagTable(); tag_table.setBidirectional(true); Collection<Map.Entry<AnalyzerTag, Integer>> tags = tag_table.entrySet(); switch (mode_) { case binary: binaryScore(scores); break; case classifier: classifierScore(scores); break; default: throw new RuntimeException("Unknown mode: " + mode_); } boolean[] activated = new boolean[scores.length]; for (int tag_index = 0; tag_index < activated.length; tag_index ++) { activated[tag_index] = scores[tag_index]> threshold_; if (pairs_ != null) { List<Integer> related_tags = pairs_.get(tag_index); for (int other_index : related_tags) { assert other_index < tag_index; if (activated[other_index] != activated[tag_index]) { double diff = Math.abs(threshold_ - scores[tag_index]); double other_diff = Math.abs(threshold_ - scores[other_index]); if (diff > other_diff) { activated[other_index] = activated[tag_index]; } else { activated[tag_index] = activated[other_index]; } System.err.format("%s: %s %s [%s]\n", instance.getForm(), tag_table.toSymbol(tag_index), tag_table.toSymbol(other_index), activated[tag_index]); } } } } AnalyzerTag max_tag = null; double max_prob = Double.NEGATIVE_INFINITY; for (Map.Entry<AnalyzerTag, Integer> entry : tags) { int tag_index = entry.getValue(); if (activated[tag_index]) { AnalyzerTag tag = entry.getKey(); readings.add(new AnalyzerReading(tag, null)); } double prob = scores[tag_index]; if (prob > max_prob) { max_prob = prob; max_tag = entry.getKey(); } } if (readings.isEmpty()) { readings.add(new AnalyzerReading(max_tag, null)); } return readings; } private void classifierScore(double[] scores) { double sum = Double.NEGATIVE_INFINITY; for (int tag_index=0; tag_index < scores.length; tag_index++) { sum = Numerics.sumLogProb(scores[tag_index], sum); } for (int tag_index=0; tag_index < scores.length; tag_index++) { double score = scores[tag_index]; double prob = Math.exp(score - sum); scores[tag_index] = prob; }; } private void binaryScore(double[] scores) { for (int tag_index=0; tag_index < scores.length; tag_index++) { double score = scores[tag_index]; double prob = Math.exp(score - Numerics.sumLogProb(score, 0)); scores[tag_index] = prob; } } @Override public String represent(AnalyzerInstance instance) { SimpleAnalyzerInstance simple_instance = model_.getInstance(instance); return Arrays.toString(simple_instance.getFeatIndexes()); } @Override public int getNumTags() { return model_.getNumTags(); } @Override public boolean isUnknown(AnalyzerInstance instance) { return model_.isUnknown(instance); } public SimpleAnalyzerModel getModel() { return model_; } public Mode getMode() { return mode_; } }