/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.component.mode.dep;
import java.io.ObjectInputStream;
import java.util.List;
import edu.emory.clir.clearnlp.classification.instance.StringInstance;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.prediction.StringPrediction;
import edu.emory.clir.clearnlp.classification.vector.StringFeatureVector;
import edu.emory.clir.clearnlp.collection.pair.ObjectIntPair;
import edu.emory.clir.clearnlp.component.AbstractStatisticalComponent;
import edu.emory.clir.clearnlp.component.mode.dep.state.AbstractDEPState;
import edu.emory.clir.clearnlp.component.mode.dep.state.DEPStateBranch;
import edu.emory.clir.clearnlp.dependency.DEPNode;
import edu.emory.clir.clearnlp.dependency.DEPTree;
/**
* @since 3.0.0
* @author Jinho D. Choi ({@code jinho.choi@emory.edu})
*/
public abstract class AbstractDEPParser extends AbstractStatisticalComponent<DEPLabel, AbstractDEPState, DEPEval, DEPFeatureExtractor, DEPConfiguration> implements DEPTransition
{
	/** Label-index table derived from the labels of the first model by {@code AbstractDEPState.initLabelIndices}. */
	private int[][] label_indices;

	/** Creates a dependency parser for train. */
	public AbstractDEPParser(DEPConfiguration configuration, DEPFeatureExtractor[] extractors, Object lexicons)
	{
		super(configuration, extractors, lexicons, false, 1);
		init();
	}

	/** Creates a dependency parser for bootstrap or evaluate. */
	public AbstractDEPParser(DEPConfiguration configuration, DEPFeatureExtractor[] extractors, Object lexicons, StringModel[] models, boolean bootstrap)
	{
		super(configuration, extractors, lexicons, models, bootstrap);
		init();
	}

	/** Creates a dependency parser for decode, reading its state from an object stream. */
	public AbstractDEPParser(DEPConfiguration configuration, ObjectInputStream in)
	{
		super(configuration, in);
		init();
	}

	/** Creates a dependency parser for decode, reading its state from a byte array. */
	public AbstractDEPParser(DEPConfiguration configuration, byte[] models)
	{
		super(configuration, models);
		init();
	}

	/** Builds {@link #label_indices} from the labels of the first model; shared by all constructors. */
	private void init()
	{
		label_indices = AbstractDEPState.initLabelIndices(s_models[0].getLabels());
	}

//	====================================== LEXICONS ======================================

	/** @return always {@code null}; this component exposes no lexicons. */
	@Override
	public Object getLexicons()
	{
		return null;
	}

	/** Does nothing; the given lexicons are ignored. */
	@Override
	public void setLexicons(Object lexicons)
	{
	}

//	====================================== EVAL ======================================

	/** Installs a fresh {@link DEPEval}, configured with the punctuation-evaluation setting of the configuration. */
	protected void initEval()
	{
		c_eval = new DEPEval(t_configuration.evaluatePunctuation());
	}

//	====================================== PROCESS ======================================

	/**
	 * Parses the given tree.
	 * Runs an initial pass over a new {@link DEPStateBranch}; if the state starts branching,
	 * every further branch is processed and offered to {@code saveBest}, and the instances
	 * returned by {@code setBest} (when non-null) are appended to those of the initial pass.
	 * In train/bootstrap mode the collected instances are added to the first model; otherwise
	 * headless nodes are post-processed and, in evaluate mode, the tree is scored against the oracle.
	 * @param tree the dependency tree to process.
	 */
	@Override
	public void process(DEPTree tree)
	{
		final AbstractDEPState parseState = new DEPStateBranch(tree, c_flag, t_configuration);
		final List<StringInstance> collected = process(parseState);

		if (parseState.startBranching())
		{
			while (parseState.nextBranch())
			{
				List<StringInstance> branchInstances = process(parseState);
				parseState.saveBest(branchInstances);
			}

			List<StringInstance> bestInstances = parseState.setBest();

			if (bestInstances != null)
			{
				collected.addAll(bestInstances);
			}
		}

		if (isTrainOrBootstrap())
		{
			s_models[0].addInstances(collected);
			return;
		}

		processHeadless(parseState);

		if (isEvaluate())
		{
			c_eval.countCorrect(tree, parseState.getOracle());
		}
	}

	/** Delegates feature extraction for the current state to the first feature extractor. */
	@Override
	protected StringFeatureVector createStringFeatureVector(AbstractDEPState state)
	{
		return f_extractors[0].createStringFeatureVector(state);
	}

	/**
	 * Predicts the label for the current state.
	 * The top prediction becomes the returned label; when that label is a no-arc transition the
	 * predictions are also handed to {@code save2ndHead}, and they are always handed to {@code saveBranch}.
	 * @return the label built from the highest-ranked prediction.
	 */
	@Override
	protected DEPLabel getAutoLabel(AbstractDEPState state, StringFeatureVector vector)
	{
		final StringPrediction[] predictions = getPredictions(state, vector);
		final DEPLabel topLabel = new DEPLabel(predictions[0]);

		if (topLabel.isArc(ARC_NO))
		{
			state.save2ndHead(predictions);
		}

		state.saveBranch(predictions);
		return topLabel;
	}

	/**
	 * Retrieves the top-2 predictions of the first model for the given vector, restricted to the
	 * label indices supplied by the state when it provides any.
	 * Each score is replaced by its logistic transform {@code 1/(1+e^-score)}.
	 * @return the rescored predictions, as returned by the model's {@code predictTop2}.
	 */
	protected StringPrediction[] getPredictions(AbstractDEPState state, StringFeatureVector vector)
	{
		final int[] allowed = state.getLabelIndices(label_indices);
		final StringPrediction[] top2;

		if (allowed == null)
		{
			top2 = s_models[0].predictTop2(vector);
		}
		else
		{
			top2 = s_models[0].predictTop2(vector, allowed);
		}

		for (StringPrediction prediction : top2)
		{
			prediction.setScore(1/(1+Math.exp(-prediction.getScore())));
		}

		return top2;
	}

//	====================================== POST-PROCESS ======================================

	/**
	 * Attaches every node (index 1 onward) that still lacks a head and for which
	 * {@code state.find2ndHead} fails.
	 * Candidates to the left are scored with the RIGHT_ARC label indices and candidates to the
	 * right with the LEFT_ARC label indices; if no candidate yields a prediction, the node is
	 * attached to node 0 with the configured root label.
	 */
	private void processHeadless(AbstractDEPState state)
	{
		final int treeSize = state.getTreeSize();

		for (int id = 1; id < treeSize; id++)
		{
			DEPNode orphan = state.getNode(id);

			// Skip nodes that already have a head or that obtain one via the saved 2nd-head candidates.
			if (orphan.hasHead() || state.find2ndHead(orphan))
			{
				continue;
			}

			ObjectIntPair<StringPrediction> best = new ObjectIntPair<StringPrediction>(null, -1000);
			processHeadlessAll(state, orphan, best, label_indices[AbstractDEPState.RIGHT_ARC], -1);
			processHeadlessAll(state, orphan, best, label_indices[AbstractDEPState.LEFT_ARC], 1);

			if (best.o != null)
			{
				orphan.setHead(state.getNode(best.i), new DEPLabel(best.o).getDeprel());
			}
			else
			{
				orphan.setHead(state.getNode(0), t_configuration.getRootLabel());
			}
		}
	}

	/**
	 * Scans candidate heads for {@code node} in one direction and keeps the best-scoring one in {@code max}.
	 * @param max running best prediction and the index of its candidate head; updated in place.
	 * @param indices label indices the prediction is restricted to.
	 * @param dir step applied to the node index: {@code -1} walks leftward, {@code 1} walks rightward.
	 */
	private void processHeadlessAll(AbstractDEPState state, DEPNode node, ObjectIntPair<StringPrediction> max, int[] indices, int dir)
	{
		final int nodeId = node.getID();
		final int treeSize = state.getTreeSize();
		int headId = nodeId + dir;

		while (0 <= headId && headId < treeSize)
		{
			DEPNode candidate = state.getNode(headId);

			// A descendant of the node cannot become its head.
			if (!candidate.isDescendantOf(node))
			{
				// reset() always receives the smaller index first.
				if (dir < 0)
				{
					state.reset(headId, nodeId);
				}
				else
				{
					state.reset(nodeId, headId);
				}

				StringFeatureVector features = createStringFeatureVector(state);
				StringPrediction scored = s_models[0].predictBest(features, indices);

				if (max.o == null || max.o.compareTo(scored) < 0)
				{
					max.set(scored, headId);
				}
			}

			headId += dir;
		}
	}

//	====================================== ONLINE TRAIN ======================================

	/** Delegates online training on the given trees to {@code onlineTrainSingleAdaGrad}. */
	@Override
	public void onlineTrain(List<DEPTree> trees)
	{
		onlineTrainSingleAdaGrad(trees);
	}

	/** Does nothing; no lexicon is updated from the tree. */
	@Override
	protected void onlineLexicons(DEPTree tree)
	{
	}
}