package org.maltparser.parser.guide.decision;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.feature.FeatureModel;
import org.maltparser.core.feature.FeatureVector;
import org.maltparser.core.helper.HashMap;
import org.maltparser.core.syntaxgraph.DependencyStructure;
import org.maltparser.parser.DependencyParserConfig;
import org.maltparser.parser.guide.ClassifierGuide;
import org.maltparser.parser.guide.GuideException;
import org.maltparser.parser.guide.instance.AtomicModel;
import org.maltparser.parser.guide.instance.FeatureDivideModel;
import org.maltparser.parser.guide.instance.InstanceModel;
import org.maltparser.parser.history.action.GuideDecision;
import org.maltparser.parser.history.action.MultipleDecision;
import org.maltparser.parser.history.action.SingleDecision;
import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision;
/**
 * Decision model for a multiple (complex) decision whose sub-decisions are related by branching.
 *
 * <p>The sub-decision at {@link #getDecisionIndex()} is handled by this model's own instance model.
 * If there is a following sub-decision and the current one says to continue, the rest of the
 * decision is delegated to a child decision model. That child is selected by the decision code of
 * the current sub-decision. Children are created lazily, by reflection, the first time a given
 * decision code is seen.
 *
 * @author Johan Hall
 * @since 1.1
 */
public class BranchedDecisionModel implements DecisionModel {
    private ClassifierGuide guide;
    private String modelName;
    private FeatureModel featureModel;
    /** Created lazily on first use by addInstance/predict/predictExtract. */
    private InstanceModel instanceModel;
    private int decisionIndex;
    private DecisionModel parentDecisionModel;
    /** Child models keyed by the decision code of this model's sub-decision; created lazily, may be null. */
    private HashMap<Integer, DecisionModel> children;
    /** Underscore-joined decision symbols along the branch leading to this model; empty for the root. */
    private String branchedDecisionSymbols;

    /**
     * Creates the root decision model, with decision index 0, model name "bdm0" and no parent.
     *
     * @param guide the classifier guide that owns this model
     * @param featureModel the feature model shared by this model and its children
     * @throws MaltChainedException declared by the constructor signature
     */
    public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException {
        this.branchedDecisionSymbols = "";
        setGuide(guide);
        setFeatureModel(featureModel);
        setDecisionIndex(0);
        setModelName("bdm" + decisionIndex);
        setParentDecisionModel(null);
    }

    /**
     * Creates a child decision model one decision index below {@code parentDecisionModel}.
     *
     * <p>This signature is looked up reflectively by {@code initChildDecisionModel}, so it must not change.
     *
     * @param guide the classifier guide that owns this model
     * @param parentDecisionModel the parent model; its decision index and feature model are inherited
     * @param branchedDecisionSymbol the branch symbols identifying this child; null or empty means none
     * @throws MaltChainedException declared by the constructor signature
     */
    public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException {
        this.branchedDecisionSymbols = (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0)
                ? branchedDecisionSymbol
                : "";
        setGuide(guide);
        setParentDecisionModel(parentDecisionModel);
        setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1);
        setFeatureModel(parentDecisionModel.getFeatureModel());
        // branchedDecisionSymbols is never null here; when it is empty this yields "bdm" + decisionIndex.
        setModelName("bdm" + decisionIndex + branchedDecisionSymbols);
    }

    /** Updates the shared feature model. */
    public void updateFeatureModel() throws MaltChainedException {
        featureModel.update();
    }

    /**
     * Forwards end-of-sentence processing to this model's instance model (if created) and to every child.
     *
     * @param dependencyGraph the dependency structure of the finished sentence
     */
    public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
        if (instanceModel != null) {
            instanceModel.finalizeSentence(dependencyGraph);
        }
        if (children != null) {
            for (DecisionModel child : children.values()) {
                child.finalizeSentence(dependencyGraph);
            }
        }
    }

    /**
     * Signals that no more training instances will arrive, then trains this model and all children.
     *
     * @throws GuideException if the guide is in CLASSIFY mode
     */
    public void noMoreInstances() throws MaltChainedException {
        if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
            throw new GuideException("The decision model could not create it's model. ");
        }
        if (instanceModel != null) {
            instanceModel.noMoreInstances();
            instanceModel.train();
        }
        if (children != null) {
            for (DecisionModel child : children.values()) {
                child.noMoreInstances();
            }
        }
    }

    /** Terminates this model's instance model (dropping the reference) and all children. */
    public void terminate() throws MaltChainedException {
        if (instanceModel != null) {
            instanceModel.terminate();
            instanceModel = null;
        }
        if (children != null) {
            for (DecisionModel child : children.values()) {
                child.terminate();
            }
        }
    }

    /**
     * Adds a training instance for this model's sub-decision and, when applicable, recurses into the
     * child selected by that sub-decision's code.
     *
     * @param decision a {@link MultipleDecision}
     * @throws GuideException if {@code decision} is a {@link SingleDecision}
     */
    public void addInstance(GuideDecision decision) throws MaltChainedException {
        checkMultipleDecision(decision);
        featureModel.update();
        final SingleDecision singleDecision = ((MultipleDecision) decision).getSingleDecision(decisionIndex);
        ensureInstanceModel(singleDecision);
        instanceModel.addInstance(singleDecision);
        if (hasNextDecision(decision) && singleDecision.continueWithNextDecision()) {
            getOrCreateChild(decision, singleDecision).addInstance(decision);
        }
    }

    /**
     * Predicts this model's sub-decision and, when applicable, lets the selected child predict the rest.
     *
     * @param decision a {@link MultipleDecision}; it is filled in by the instance models
     * @return always {@code true}
     * @throws GuideException if {@code decision} is a {@link SingleDecision}
     */
    public boolean predict(GuideDecision decision) throws MaltChainedException {
        // Same guard as the sibling methods; without it a SingleDecision fails with a ClassCastException below.
        checkMultipleDecision(decision);
        featureModel.update();
        final SingleDecision singleDecision = ((MultipleDecision) decision).getSingleDecision(decisionIndex);
        ensureInstanceModel(singleDecision);
        instanceModel.predict(singleDecision);
        if (hasNextDecision(decision) && singleDecision.continueWithNextDecision()) {
            getOrCreateChild(decision, singleDecision).predict(decision);
        }
        return true;
    }

    /**
     * Like {@link #predict(GuideDecision)}, but uses the instance model's predictExtract.
     *
     * @param decision a {@link MultipleDecision}
     * @return the feature vector returned by this model's own instance model; children's vectors are discarded
     * @throws GuideException if {@code decision} is a {@link SingleDecision}
     */
    public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException {
        checkMultipleDecision(decision);
        featureModel.update();
        final SingleDecision singleDecision = ((MultipleDecision) decision).getSingleDecision(decisionIndex);
        ensureInstanceModel(singleDecision);
        final FeatureVector fv = instanceModel.predictExtract(singleDecision);
        if (hasNextDecision(decision) && singleDecision.continueWithNextDecision()) {
            getOrCreateChild(decision, singleDecision).predictExtract(decision);
        }
        return fv;
    }

    /**
     * Updates the feature model and extracts a feature vector from this model's instance model.
     *
     * @return the extracted feature vector
     * @throws GuideException if no instance model has been created yet; it is only created by
     *         addInstance, predict or predictExtract
     */
    public FeatureVector extract() throws MaltChainedException {
        featureModel.update();
        if (instanceModel == null) {
            throw new GuideException("The decision model '" + modelName + "' has no instance model to extract features from. ");
        }
        return instanceModel.extract(); // TODO handle many feature vectors
    }

    /**
     * Tries to produce the next alternative decision from the k-best lists.
     *
     * <p>An existing child for the current sub-decision is asked first. If it does not succeed, this
     * model's own sub-decision is updated via {@code updateFromKBestList()}. The following
     * sub-decisions are then re-predicted by the (possibly new) matching child.
     *
     * @param decision a {@link MultipleDecision}
     * @return the child's result, or otherwise the result of this model's {@code updateFromKBestList()}
     * @throws GuideException if {@code decision} is a {@link SingleDecision}
     */
    public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException {
        checkMultipleDecision(decision);
        boolean success = false;
        final SingleDecision singleDecision = ((MultipleDecision) decision).getSingleDecision(decisionIndex);
        if (hasNextDecision(decision) && singleDecision.continueWithNextDecision() && children != null) {
            final DecisionModel child = children.get(singleDecision.getDecisionCode());
            if (child != null) {
                success = child.predictFromKBestList(decision);
            }
        }
        if (!success) {
            success = singleDecision.updateFromKBestList();
            // The sub-decision may have changed, so the remainder must be predicted again.
            if (hasNextDecision(decision) && singleDecision.continueWithNextDecision()) {
                getOrCreateChild(decision, singleDecision).predict(decision);
            }
        }
        return success;
    }

    public ClassifierGuide getGuide() {
        return guide;
    }

    public String getModelName() {
        return modelName;
    }

    public FeatureModel getFeatureModel() {
        return featureModel;
    }

    public int getDecisionIndex() {
        return decisionIndex;
    }

    public DecisionModel getParentDecisionModel() {
        return parentDecisionModel;
    }

    private void setFeatureModel(FeatureModel featureModel) {
        this.featureModel = featureModel;
    }

    private void setDecisionIndex(int decisionIndex) {
        this.decisionIndex = decisionIndex;
    }

    private void setParentDecisionModel(DecisionModel parentDecisionModel) {
        this.parentDecisionModel = parentDecisionModel;
    }

    private void setModelName(String modelName) {
        this.modelName = modelName;
    }

    private void setGuide(ClassifierGuide guide) {
        this.guide = guide;
    }

    /** Throws the same GuideException for every entry point that requires a multiple decision. */
    private void checkMultipleDecision(GuideDecision decision) throws MaltChainedException {
        if (decision instanceof SingleDecision) {
            throw new GuideException("A branched decision model expect more than one decisions. ");
        }
    }

    /** Returns true when {@code decision} has a sub-decision after the one handled by this model. */
    private boolean hasNextDecision(GuideDecision decision) throws MaltChainedException {
        return decisionIndex + 1 < decision.numberOfDecisions();
    }

    /** Lazily creates the instance model, keyed by the table container name of {@code singleDecision}. */
    private void ensureInstanceModel(SingleDecision singleDecision) throws MaltChainedException {
        if (instanceModel == null) {
            initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
        }
    }

    /**
     * Returns the child model registered for the decision code of {@code singleDecision}.
     * If none exists yet, it creates one for the next sub-decision and caches it.
     */
    private DecisionModel getOrCreateChild(GuideDecision decision, SingleDecision singleDecision) throws MaltChainedException {
        if (children == null) {
            children = new HashMap<Integer, DecisionModel>();
        }
        DecisionModel child = children.get(singleDecision.getDecisionCode());
        if (child == null) {
            final SingleDecision nextDecision = ((MultipleDecision) decision).getSingleDecision(decisionIndex + 1);
            // Child symbols extend this model's branch: "<parentSymbols>_<symbol>", or just "<symbol>" at the root.
            final String childSymbols = branchedDecisionSymbols + (branchedDecisionSymbols.length() == 0 ? "" : "_")
                    + singleDecision.getDecisionSymbol();
            child = initChildDecisionModel(nextDecision, childSymbols);
            children.put(singleDecision.getDecisionCode(), child);
        }
        return child;
    }

    /**
     * Reflectively instantiates the child decision model class for the relation of {@code decision}
     * to its next decision. SEQUANTIAL maps to SeqDecisionModel, BRANCHED to BranchedDecisionModel,
     * and NONE to OneDecisionModel.
     *
     * @param decision the sub-decision the child will handle
     * @param branchedDecisionSymbol the branch symbols passed to the child's constructor
     * @return the new child model
     * @throws GuideException if no model class matches, or if construction fails (the cause is preserved)
     */
    private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException {
        final RelationToNextDecision relation = decision.getRelationToNextDecision();
        Class<?> decisionModelClass = null;
        if (relation == RelationToNextDecision.SEQUANTIAL) {
            decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class;
        } else if (relation == RelationToNextDecision.BRANCHED) {
            decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class;
        } else if (relation == RelationToNextDecision.NONE) {
            decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class;
        }
        if (decisionModelClass == null) {
            throw new GuideException("Could not find an appropriate decision model for the relation to the next decision");
        }
        try {
            final Class<?>[] argTypes = {org.maltparser.parser.guide.ClassifierGuide.class,
                    org.maltparser.parser.guide.decision.DecisionModel.class, java.lang.String.class};
            final Object[] arguments = {getGuide(), this, branchedDecisionSymbol};
            final Constructor<?> constructor = decisionModelClass.getConstructor(argTypes);
            return (DecisionModel) constructor.newInstance(arguments);
        } catch (NoSuchMethodException e) {
            throw new GuideException("The decision model class '" + decisionModelClass.getName() + "' cannot be initialized. ", e);
        } catch (InstantiationException e) {
            throw new GuideException("The decision model class '" + decisionModelClass.getName() + "' cannot be initialized. ", e);
        } catch (IllegalAccessException e) {
            throw new GuideException("The decision model class '" + decisionModelClass.getName() + "' cannot be initialized. ", e);
        } catch (InvocationTargetException e) {
            throw new GuideException("The decision model class '" + decisionModelClass.getName() + "' cannot be initialized. ", e);
        }
    }

    /**
     * Creates the instance model for this decision model.
     *
     * <p>The feature vector is chosen from the most to the least specific: "{branchSymbols}.{subModelName}",
     * then "{subModelName}", then the main feature vector. An AtomicModel is used when the guide option
     * "data_split_column" is empty; otherwise a FeatureDivideModel is used.
     *
     * @param subModelName the table container name of the sub-decision
     */
    private void initInstanceModel(String subModelName) throws MaltChainedException {
        FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols + "." + subModelName);
        if (fv == null) {
            fv = featureModel.getFeatureVector(subModelName);
        }
        if (fv == null) {
            fv = featureModel.getMainFeatureVector();
        }
        final DependencyParserConfig c = guide.getConfiguration();
        if (c.getOptionValue("guide", "data_split_column").toString().length() == 0) {
            instanceModel = new AtomicModel(-1, fv, this);
        } else {
            instanceModel = new FeatureDivideModel(fv, this);
        }
    }

    /** Returns the model name followed by the string form of each child, each entry ending in ", ". */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(modelName).append(", ");
        // children is created lazily and may still be null.
        if (children != null) {
            for (DecisionModel model : children.values()) {
                sb.append(model.toString()).append(", ");
            }
        }
        return sb.toString();
    }
}