/**
* Copyright 2015, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.component.mode.ner;
import java.util.List;
import java.util.StringJoiner;
import java.util.function.Function;
import edu.emory.clir.clearnlp.collection.map.IntObjectHashMap;
import edu.emory.clir.clearnlp.collection.pair.ObjectIntPair;
import edu.emory.clir.clearnlp.collection.tree.PrefixNode;
import edu.emory.clir.clearnlp.collection.tree.PrefixTree;
import edu.emory.clir.clearnlp.collection.triple.ObjectIntIntTriple;
import edu.emory.clir.clearnlp.component.state.AbstractTagState;
import edu.emory.clir.clearnlp.component.utils.CFlag;
import edu.emory.clir.clearnlp.dependency.DEPNode;
import edu.emory.clir.clearnlp.dependency.DEPTree;
import edu.emory.clir.clearnlp.ner.BILOU;
import edu.emory.clir.clearnlp.ner.NERInfoSet;
import edu.emory.clir.clearnlp.ner.NERLib;
import edu.emory.clir.clearnlp.util.constant.StringConst;
/**
* @since 3.0.3
* @author Jinho D. Choi ({@code jinho.choi@emory.edu})
*/
public class NERState extends AbstractTagState
{
/** Information from prefix-tree. */
private List<ObjectIntIntTriple<NERInfoSet>> info_list;
private PrefixTree<String,NERInfoSet> ne_dictionary;
private String[] ambiguity_classes;
// ====================================== INITIALIZATION ======================================
public NERState(DEPTree tree, CFlag flag, PrefixTree<String,NERInfoSet> namedEntityDictionary)
{
super(tree, flag);
init(namedEntityDictionary);
}
public void init(PrefixTree<String,NERInfoSet> namedEntityDictionary)
{
ne_dictionary = namedEntityDictionary;
// info_list = ne_dictionary.getAll(d_tree.toNodeArray(), 1, DEPNode::getWordForm, true, false);
info_list = ne_dictionary.getAll(d_tree.toNodeArray(), 1, DEPNode::getLowerSimplifiedWordForm, true, false);
ambiguity_classes = getAmbiguityClasses();
}
// private void initAmbiguityClasses()
// {
// List<Set<String>> sets = IntStream.range(0, t_size).mapToObj(k -> new HashSet<String>()).collect(Collectors.toList());
// StringJoiner[] joiners = new StringJoiner[t_size];
// int i, j, size = info_list.size();
// ObjectIntIntTriple<NERInfoSet> t;
// String tag;
//
// for (i=1; i<t_size; i++)
// joiners[i] = new StringJoiner("-");
//
// for (i=0; i<size; i++)
// {
// t = info_list.get(i);
// tag = t.o.joinTags(StringConst.COLON);
//
// if (t.i1 == t.i2)
// joiners[t.i1].add(NERLib.toBILOUTag(BILOU.U, tag));
// else
// {
// joiners[t.i1].add(NERLib.toBILOUTag(BILOU.B, tag));
// joiners[t.i2].add(NERLib.toBILOUTag(BILOU.L, tag));
//
// for (j=t.i1+1; j<t.i2; j++)
// joiners[j].add(NERLib.toBILOUTag(BILOU.I, tag));
// }
//
// for (j=t.i1; j<=t.i2; j++)
// sets.get(j).addAll(t.o.getCategorySet());
// }
//
// ambiguity_class_set = new String[t_size][];
// ambiguity_classes = new String[t_size];
//
// for (i=1; i<t_size; i++)
// {
// ambiguity_classes [i] = joiners.length == 0 ? null : joiners[i].toString();
// ambiguity_class_set[i] = DSUtils.toArray(sets.get(i));
// }
// }
private String[] getAmbiguityClasses()
{
StringJoiner[] joiners = new StringJoiner[t_size];
ObjectIntIntTriple<NERInfoSet> t;
int i, j, size = info_list.size();
String tag;
for (i=1; i<t_size; i++)
joiners[i] = new StringJoiner("-");
for (i=0; i<size; i++)
{
t = info_list.get(i);
tag = t.o.joinTags(StringConst.COLON);
if (t.i1 == t.i2)
joiners[t.i1].add(NERLib.toBILOUTag(BILOU.U, tag));
else
{
joiners[t.i1].add(NERLib.toBILOUTag(BILOU.B, tag));
joiners[t.i2].add(NERLib.toBILOUTag(BILOU.L, tag));
for (j=t.i1+1; j<t.i2; j++)
joiners[j].add(NERLib.toBILOUTag(BILOU.I, tag));
}
}
String[] classes = new String[t_size];
for (i=1; i<t_size; i++) classes[i] = joiners.length == 0 ? null : joiners[i].toString();
return classes;
}
// ====================================== ORACLE/LABEL ======================================
@Override
protected String clearOracle(DEPNode node)
{
String tag = node.clearNamedEntityTag();
return tag.startsWith("I") ? "I" : tag;
// return node.clearNamedEntityTag();
}
// ====================================== TRANSITION ======================================
protected void setLabel(DEPNode node, String label)
{
node.setNamedEntityTag(label);
}
// ====================================== FEATURES ======================================
@Override
public String getAmbiguityClass(DEPNode node)
{
return ambiguity_classes[node.getID()];
}
// public String[] getAmbiguityClasses(DEPNode node)
// {
// return ambiguity_class_set[node.getID()];
// }
//
// public String[] getCooccuranceFeatures(DEPNode node)
// {
// String[] categories = {"PER", "LOC", "ORG", "MISC"};
// int[] cooccurrences = new int[categories.length];
// int i;
// List<DEPNode> prevWords = node.getSubNodeList();
// for (DEPNode prevWord : prevWords) {
// for (i=0;i<categories.length;i++) {
// if (prevWord.getNamedEntityTag().equals(categories[i])) {
// cooccurrences[i]++;
// }
// }
// }
// StringJoiner[] joiner= new StringJoiner[categories.length];
// String[] features = new String[categories.length];
// for (i=0;i<categories.length;i++) {
// joiner[i]= new StringJoiner("-");
// joiner[i].add(categories[i])
// .add(Double.toString(cooccurrences[i]/Math.log(prevWords.size())));
// features[i] = joiner[i].toString();
// }
//
// return features;
// }
// ====================================== DICTIONARY ======================================
/** For training. */
public void adjustDictionary()
{
IntObjectHashMap<String> goldMap = collectNamedEntityMap(g_oracle, String::toString);
populateDictionary(goldMap);
}
private IntObjectHashMap<ObjectIntIntTriple<NERInfoSet>> populateDictionary(IntObjectHashMap<String> goldMap)
{
IntObjectHashMap<ObjectIntIntTriple<NERInfoSet>> dictMap = toNERInfoMap();
NERInfoSet list;
int bIdx, eIdx;
// add gold entries to the dictionary
for (ObjectIntPair<String> p : goldMap)
{
dictMap.remove(p.i);
bIdx = p.i / t_size;
eIdx = p.i % t_size;
list = pick(ne_dictionary, p.o, d_tree.toNodeArray(), bIdx, eIdx+1, DEPNode::getWordForm, 1);
list.addCorrectCount(1);
}
for (ObjectIntPair<ObjectIntIntTriple<NERInfoSet>> p : dictMap)
p.o.o.addCorrectCount(-1);
return dictMap;
}
/**
* @param beginIndex inclusive
* @param endIndex exclusive
*/
static public <T>NERInfoSet pick(PrefixTree<String,NERInfoSet> dictionary, String tag, T[] array, int beginIndex, int endIndex, Function<T,String> f, int inc)
{
PrefixNode<String,NERInfoSet> node = dictionary.add(array, beginIndex, endIndex, f);
NERInfoSet set = node.getValue();
if (set == null)
{
set = new NERInfoSet();
node.setValue(set);
}
set.addCategory(tag);
return set;
}
private IntObjectHashMap<ObjectIntIntTriple<NERInfoSet>> toNERInfoMap()
{
IntObjectHashMap<ObjectIntIntTriple<NERInfoSet>> map = new IntObjectHashMap<>();
for (ObjectIntIntTriple<NERInfoSet> t : info_list)
map.put(NERState.getKey(t.i1, t.i2, t_size), t);
return map;
}
static public <T>IntObjectHashMap<String> collectNamedEntityMap(T[] array, Function<T,String> f)
{
IntObjectHashMap<String> map = new IntObjectHashMap<>();
int i, beginIndex = -1, size = array.length;
String tag;
for (i=1; i<size; i++)
{
tag = f.apply(array[i]);
if (tag == null || tag.length() < 3) continue;
switch (NERLib.toBILOU(tag))
{
case U: map.put(getKey(i,i,size), NERLib.toNamedEntity(tag)); beginIndex = -1; break;
case B: beginIndex = i; break;
case L: if (0 < beginIndex&&beginIndex < i) map.put(getKey(beginIndex,i,size), NERLib.toNamedEntity(tag)); beginIndex = -1; break;
case O: beginIndex = -1; break;
case I: break;
}
}
return map;
}
static private int getKey(int beginIndex, int endIndex, int size)
{
return beginIndex * size + endIndex;
}
// ====================================== POST-PROCESS ======================================
public void postProcess()
{
int i, beginIndex = -1;
DEPNode curr;
for (i=1; i<t_size; i++)
{
curr = getNode(i);
switch (NERLib.toBILOU(curr.getNamedEntityTag()))
{
case B: beginIndex = postProcessB(curr); break;
case L: beginIndex = postProcessL(curr, beginIndex, i); break;
case U: postProcessU(curr); break;
case I: postProcessI(curr); break;
case O: beginIndex = -1; break;
}
}
}
private int postProcessB(DEPNode curr)
{
DEPNode next = getNode(curr.getID()+1);
if (next == null)
{
curr.setNamedEntityTag(NERLib.changeChunkType(BILOU.U, curr.getNamedEntityTag()));
return -1;
}
curr.setNamedEntityTag("O");
return curr.getID();
}
private int postProcessL(DEPNode curr, int beginIndex, int endIndex)
{
String tag = NERLib.toNamedEntity(curr.getNamedEntityTag());
if (endIndex == 1)
{
curr.setNamedEntityTag(NERLib.toBILOUTag(BILOU.U, tag));
}
else if (beginIndex > 0)
{
getNode(beginIndex).setNamedEntityTag(NERLib.toBILOUTag(BILOU.B, tag));
for (int i=beginIndex+1; i<endIndex; i++)
getNode(i).setNamedEntityTag(NERLib.toBILOUTag(BILOU.I, tag));
}
else
{
curr.setNamedEntityTag("O");
// DEPNode prev = getNode(endIndex-1);
//
// if (prev.isNamedEntityTag("O"))
// {
// String snd = prev.getFeat(DEPLib.FEAT_NER2);
//
// if (snd != null && snd.endsWith(tag))
// prev.setNamedEntityTag(NERTag.toBILOUTag(BILOU.B, tag));
// }
// else if (prev.isNamedEntityTag(curr.getNamedEntityTag()))
// prev.setNamedEntityTag(NERTag.toBILOUTag(BILOU.B, tag));
}
return -1;
}
private int postProcessU(DEPNode curr)
{
DEPNode next = getNode(curr.getID()+1);
if (next != null && next.getNamedEntityTag().startsWith("L"))
{
curr.setNamedEntityTag("O");
return curr.getID();
}
return -1;
}
private void postProcessI(DEPNode curr)
{
curr.setNamedEntityTag("O");
}
}