package edu.stanford.nlp.pipeline;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.util.ArrayMap;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Generics;
import java.util.*;
/**
 * Functions for aggregating token attributes.
 *
 * <p>Each aggregator collapses the values stored under a single annotation key across a list of
 * {@link CoreMap}s into one value. Built-in aggregators can be looked up by name with
 * {@link #getAggregator(String)}; default per-annotation mappings are exposed as
 * {@link #DEFAULT_AGGREGATORS}, {@link #DEFAULT_NUMERIC_AGGREGATORS} and
 * {@link #DEFAULT_NUMERIC_TOKENS_AGGREGATORS}.
 *
 * <p>All built-in aggregators return {@code null} when the input list itself is {@code null}.
 *
 * @author Angel Chang
 */
public abstract class CoreMapAttributeAggregator {

  /**
   * Returns the default mapping from annotation key to aggregator.
   *
   * @return the unmodifiable {@link #DEFAULT_AGGREGATORS} map
   */
  public static Map<Class, CoreMapAttributeAggregator> getDefaultAggregators() {
    return DEFAULT_AGGREGATORS;
  }

  /**
   * Looks up a built-in aggregator by its registered name, e.g. {@code "FIRST"}, {@code "SUM"} or
   * {@code "MOST_FREQ"}.
   *
   * @param str the exact registered aggregator name
   * @return the registered aggregator, or {@code null} if no aggregator is registered under
   *     {@code str}
   */
  public static CoreMapAttributeAggregator getAggregator(String str) {
    return AGGREGATOR_LOOKUP.get(str);
  }

  /**
   * Aggregates the values stored under {@code key} across the maps in {@code in}.
   *
   * @param key the annotation key whose values are aggregated
   * @param in the maps to read from; the built-in implementations accept {@code null}
   * @return the aggregated value; the built-in implementations return {@code null} when
   *     {@code in} is {@code null}
   */
  public abstract Object aggregate(Class key, List<? extends CoreMap> in);

  /** Returns the first non-null value of the attribute, or {@code null} if every value is null. */
  public static final CoreMapAttributeAggregator FIRST_NON_NIL = new CoreMapAttributeAggregator() {
    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      for (CoreMap cm : in) {
        Object obj = cm.get(key);
        if (obj != null) {
          return obj;
        }
      }
      return null;
    }
  };

  /**
   * Returns the attribute value of the first element (which may itself be {@code null}), or
   * {@code null} for an empty list.
   */
  public static final CoreMapAttributeAggregator FIRST = new CoreMapAttributeAggregator() {
    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null || in.isEmpty()) {
        return null;
      }
      return in.get(0).get(key);
    }
  };

  /** Returns the last non-null value of the attribute, or {@code null} if every value is null. */
  public static final CoreMapAttributeAggregator LAST_NON_NIL = new CoreMapAttributeAggregator() {
    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      for (int i = in.size() - 1; i >= 0; i--) {
        Object obj = in.get(i).get(key);
        if (obj != null) {
          return obj;
        }
      }
      return null;
    }
  };

  /**
   * Returns the attribute value of the last element (which may itself be {@code null}), or
   * {@code null} for an empty list.
   */
  public static final CoreMapAttributeAggregator LAST = new CoreMapAttributeAggregator() {
    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null || in.isEmpty()) {
        return null;
      }
      return in.get(in.size() - 1).get(key);
    }
  };

  /**
   * Concatenates attribute values that are {@link List}s into one new list, in input order.
   * Null and non-list values are skipped.
   *
   * @param <T> element type of the concatenated lists (not checked at runtime)
   */
  public static final class ConcatListAggregator<T> extends CoreMapAttributeAggregator {
    public ConcatListAggregator() {
    }

    @Override
    @SuppressWarnings("unchecked") // the element type of stored lists cannot be verified at runtime
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      List<T> res = new ArrayList<>();
      for (CoreMap cm : in) {
        Object obj = cm.get(key);
        if (obj instanceof List) {
          res.addAll((List<T>) obj);
        }
      }
      return res;
    }
  }

  /**
   * Concatenates attribute values that are {@link List}s into one new list, in input order.
   *
   * <p>When {@code concatSelf} is {@code true}, an input map whose attribute value is not a list
   * (including {@code null}) contributes the map itself to the result instead.
   *
   * @param <T> element type of the result (not checked at runtime)
   */
  public static final class ConcatCoreMapListAggregator<T extends CoreMap>
      extends CoreMapAttributeAggregator {
    boolean concatSelf = false;

    public ConcatCoreMapListAggregator() {
    }

    /**
     * @param concatSelf whether maps without a list value are added to the result themselves
     */
    public ConcatCoreMapListAggregator(boolean concatSelf) {
      this.concatSelf = concatSelf;
    }

    @Override
    @SuppressWarnings("unchecked") // element types of stored lists and of cm cannot be verified
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      List<T> res = new ArrayList<>();
      for (CoreMap cm : in) {
        Object obj = cm.get(key);
        if (obj instanceof List) {
          res.addAll((List<T>) obj);
        } else if (concatSelf) {
          // No list stored under key: the map itself stands in for its contents.
          res.add((T) cm);
        }
      }
      return res;
    }
  }

  /** List concatenation typed for {@link CoreLabel}s; maps without a list value add themselves. */
  public static final ConcatCoreMapListAggregator<CoreLabel> CONCAT_TOKENS =
      new ConcatCoreMapListAggregator<>(true);

  /** List concatenation typed for {@link CoreMap}s; maps without a list value add themselves. */
  public static final ConcatCoreMapListAggregator<CoreMap> CONCAT_COREMAP =
      new ConcatCoreMapListAggregator<>(true);

  /**
   * Joins the string forms of the non-null attribute values with a delimiter.
   *
   * <p>A delimiter is inserted only once the accumulated text is non-empty, so leading values
   * whose string form is empty do not produce leading delimiters.
   */
  public static final class ConcatAggregator extends CoreMapAttributeAggregator {
    String delimiter;

    /** @param delimiter text placed between consecutive values */
    public ConcatAggregator(String delimiter) {
      this.delimiter = delimiter;
    }

    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      StringBuilder sb = new StringBuilder();
      for (CoreMap cm : in) {
        Object obj = cm.get(key);
        if (obj != null) {
          if (sb.length() > 0) {
            sb.append(delimiter);
          }
          sb.append(obj);
        }
      }
      return sb.toString();
    }
  }

  /**
   * Produces the text for the given maps by delegating to
   * {@code ChunkAnnotationUtils.getTokenText(in, key)}.
   *
   * <p>NOTE(review): {@code delimiter} is stored but never passed to {@code getTokenText}, so it
   * currently has no effect on the result; confirm whether that is intended.
   */
  public static final class ConcatTextAggregator extends CoreMapAttributeAggregator {
    String delimiter;

    public ConcatTextAggregator(String delimiter) {
      this.delimiter = delimiter;
    }

    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      return ChunkAnnotationUtils.getTokenText(in, key);
    }
  }

  /** Joins the string forms of non-null attribute values with a single space. */
  public static final CoreMapAttributeAggregator CONCAT = new ConcatAggregator(" ");

  /** Builds span text via {@link ConcatTextAggregator}. */
  public static final CoreMapAttributeAggregator CONCAT_TEXT = new ConcatTextAggregator(" ");

  /**
   * Returns the number of input maps as an {@link Integer} (the key is ignored), or {@code null}
   * when the input list is {@code null}.
   */
  public static final CoreMapAttributeAggregator COUNT = new CoreMapAttributeAggregator() {
    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      // Same null-input contract as every other built-in aggregator (avoids an NPE on in.size()).
      if (in == null) {
        return null;
      }
      return in.size();
    }
  };

  /**
   * Sums the non-null attribute values as a {@link Double}. {@link Number}s contribute their
   * {@code doubleValue()}, and {@link String}s are parsed with {@link Double#parseDouble}. Any
   * other value type raises a {@link RuntimeException}. Returns {@code 0.0} when no value is
   * present.
   */
  public static final CoreMapAttributeAggregator SUM = new CoreMapAttributeAggregator() {
    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      double sum = 0;
      for (CoreMap cm : in) {
        Object obj = cm.get(key);
        if (obj == null) {
          continue;
        }
        if (obj instanceof Number) {
          sum += ((Number) obj).doubleValue();
        } else if (obj instanceof String) {
          sum += Double.parseDouble((String) obj);
        } else {
          throw new RuntimeException("Cannot sum attribute " + key + ", object of type: " + obj.getClass());
        }
      }
      return sum;
    }
  };

  /**
   * Returns the smallest non-null attribute value by natural ordering; on ties the earliest value
   * is kept. Throws a {@link RuntimeException} for a non-{@link Comparable} value. Values that are
   * not mutually comparable may make {@code compareTo} throw. Returns {@code null} when no value is
   * present.
   */
  public static final CoreMapAttributeAggregator MIN = new CoreMapAttributeAggregator() {
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"}) // attribute values are only known as Comparable
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      Comparable min = null;
      for (CoreMap cm : in) {
        Object obj = cm.get(key);
        if (obj == null) {
          continue;
        }
        if (!(obj instanceof Comparable)) {
          throw new RuntimeException("Cannot get min of attribute " + key + ", object of type: " + obj.getClass());
        }
        Comparable c = (Comparable) obj;
        if (min == null || c.compareTo(min) < 0) {
          min = c;
        }
      }
      return min;
    }
  };

  /**
   * Returns the largest non-null attribute value by natural ordering; on ties the earliest value
   * is kept. Throws a {@link RuntimeException} for a non-{@link Comparable} value. Values that are
   * not mutually comparable may make {@code compareTo} throw. Returns {@code null} when no value is
   * present.
   */
  public static final CoreMapAttributeAggregator MAX = new CoreMapAttributeAggregator() {
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"}) // attribute values are only known as Comparable
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      Comparable max = null;
      for (CoreMap cm : in) {
        Object obj = cm.get(key);
        if (obj == null) {
          continue;
        }
        if (!(obj instanceof Comparable)) {
          throw new RuntimeException("Cannot get max of attribute " + key + ", object of type: " + obj.getClass());
        }
        Comparable c = (Comparable) obj;
        if (max == null || c.compareTo(max) > 0) {
          max = c;
        }
      }
      return max;
    }
  };

  /**
   * Returns the most frequent non-null attribute value, or {@code null} if no value was counted.
   * Values contained in the optional ignore set are not counted.
   */
  public static final class MostFreqAggregator extends CoreMapAttributeAggregator {
    Set<Object> ignoreSet;

    public MostFreqAggregator() {
    }

    /** @param set values to exclude from counting; may be {@code null} to count everything */
    public MostFreqAggregator(Set<Object> set) {
      ignoreSet = set;
    }

    @Override
    public Object aggregate(Class key, List<? extends CoreMap> in) {
      if (in == null) {
        return null;
      }
      IntCounter<Object> counter = new IntCounter<>();
      for (CoreMap cm : in) {
        Object obj = cm.get(key);
        if (obj != null && (ignoreSet == null || !ignoreSet.contains(obj))) {
          counter.incrementCount(obj);
        }
      }
      return counter.size() > 0 ? counter.argmax() : null;
    }
  }

  /** Returns the most frequent non-null attribute value, ignoring nothing. */
  public static final CoreMapAttributeAggregator MOST_FREQ = new MostFreqAggregator();

  // Name -> aggregator registry backing getAggregator(String).
  private static final Map<String, CoreMapAttributeAggregator> AGGREGATOR_LOOKUP = Generics.newHashMap();

  static {
    AGGREGATOR_LOOKUP.put("FIRST", FIRST);
    AGGREGATOR_LOOKUP.put("FIRST_NON_NIL", FIRST_NON_NIL);
    AGGREGATOR_LOOKUP.put("LAST", LAST);
    AGGREGATOR_LOOKUP.put("LAST_NON_NIL", LAST_NON_NIL);
    AGGREGATOR_LOOKUP.put("MIN", MIN);
    AGGREGATOR_LOOKUP.put("MAX", MAX);
    AGGREGATOR_LOOKUP.put("COUNT", COUNT);
    AGGREGATOR_LOOKUP.put("SUM", SUM);
    AGGREGATOR_LOOKUP.put("CONCAT", CONCAT);
    AGGREGATOR_LOOKUP.put("CONCAT_TEXT", CONCAT_TEXT);
    AGGREGATOR_LOOKUP.put("CONCAT_TOKENS", CONCAT_TOKENS);
    AGGREGATOR_LOOKUP.put("MOST_FREQ", MOST_FREQ);
  }

  /**
   * Unmodifiable default aggregators. Text is built with {@link #CONCAT_TEXT}. Begin
   * offsets/indices and {@code BeforeAnnotation} come from the first element. End offsets/indices
   * and {@code AfterAnnotation} come from the last element. Tokens are concatenated with
   * {@link #CONCAT_TOKENS}.
   */
  public static final Map<Class, CoreMapAttributeAggregator> DEFAULT_AGGREGATORS;

  /**
   * Unmodifiable extension of {@link #DEFAULT_AGGREGATORS}. It takes the first non-null numeric
   * composite type/value, named-entity tag and normalized named-entity tag.
   */
  public static final Map<Class, CoreMapAttributeAggregator> DEFAULT_NUMERIC_AGGREGATORS;

  /**
   * Unmodifiable extension of {@link #DEFAULT_NUMERIC_AGGREGATORS}. It additionally concatenates
   * {@code NumerizedTokensAnnotation} lists with {@link #CONCAT_COREMAP}.
   */
  public static final Map<Class, CoreMapAttributeAggregator> DEFAULT_NUMERIC_TOKENS_AGGREGATORS;

  static {
    Map<Class, CoreMapAttributeAggregator> defaultAggr = new ArrayMap<>();
    defaultAggr.put(CoreAnnotations.TextAnnotation.class, CoreMapAttributeAggregator.CONCAT_TEXT);
    defaultAggr.put(CoreAnnotations.CharacterOffsetBeginAnnotation.class, CoreMapAttributeAggregator.FIRST);
    defaultAggr.put(CoreAnnotations.CharacterOffsetEndAnnotation.class, CoreMapAttributeAggregator.LAST);
    defaultAggr.put(CoreAnnotations.TokenBeginAnnotation.class, CoreMapAttributeAggregator.FIRST);
    defaultAggr.put(CoreAnnotations.TokenEndAnnotation.class, CoreMapAttributeAggregator.LAST);
    defaultAggr.put(CoreAnnotations.TokensAnnotation.class, CoreMapAttributeAggregator.CONCAT_TOKENS);
    defaultAggr.put(CoreAnnotations.BeforeAnnotation.class, CoreMapAttributeAggregator.FIRST);
    defaultAggr.put(CoreAnnotations.AfterAnnotation.class, CoreMapAttributeAggregator.LAST);
    DEFAULT_AGGREGATORS = Collections.unmodifiableMap(defaultAggr);

    Map<Class, CoreMapAttributeAggregator> defaultNumericAggr = new ArrayMap<>(DEFAULT_AGGREGATORS);
    defaultNumericAggr.put(CoreAnnotations.NumericCompositeTypeAnnotation.class, CoreMapAttributeAggregator.FIRST_NON_NIL);
    defaultNumericAggr.put(CoreAnnotations.NumericCompositeValueAnnotation.class, CoreMapAttributeAggregator.FIRST_NON_NIL);
    defaultNumericAggr.put(CoreAnnotations.NamedEntityTagAnnotation.class, CoreMapAttributeAggregator.FIRST_NON_NIL);
    defaultNumericAggr.put(CoreAnnotations.NormalizedNamedEntityTagAnnotation.class, CoreMapAttributeAggregator.FIRST_NON_NIL);
    DEFAULT_NUMERIC_AGGREGATORS = Collections.unmodifiableMap(defaultNumericAggr);

    Map<Class, CoreMapAttributeAggregator> defaultNumericTokensAggr = new ArrayMap<>(DEFAULT_NUMERIC_AGGREGATORS);
    defaultNumericTokensAggr.put(CoreAnnotations.NumerizedTokensAnnotation.class, CoreMapAttributeAggregator.CONCAT_COREMAP);
    DEFAULT_NUMERIC_TOKENS_AGGREGATORS = Collections.unmodifiableMap(defaultNumericTokensAggr);
  }
}