/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.component.configuration;
import java.io.InputStream;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.trainer.AbstractAdaGrad;
import edu.emory.clir.clearnlp.classification.trainer.AbstractLiblinear;
import edu.emory.clir.clearnlp.classification.trainer.AbstractTrainer;
import edu.emory.clir.clearnlp.classification.trainer.AdaGradLR;
import edu.emory.clir.clearnlp.classification.trainer.AdaGradSVM;
import edu.emory.clir.clearnlp.classification.trainer.LiblinearL2LR;
import edu.emory.clir.clearnlp.classification.trainer.LiblinearL2SVM;
import edu.emory.clir.clearnlp.collection.map.ObjectIntHashMap;
import edu.emory.clir.clearnlp.component.utils.NLPMode;
import edu.emory.clir.clearnlp.reader.AbstractReader;
import edu.emory.clir.clearnlp.reader.LineReader;
import edu.emory.clir.clearnlp.reader.RawReader;
import edu.emory.clir.clearnlp.reader.TReader;
import edu.emory.clir.clearnlp.reader.TSVReader;
import edu.emory.clir.clearnlp.util.XmlUtils;
import edu.emory.clir.clearnlp.util.lang.TLanguage;
/**
 * Configuration backed by an XML document: exposes the data reader, the language,
 * the thread size, and the per-mode settings (trainers, bootstrap flag, beam size).
 * <p>
 * The element, attribute and value names used here ({@code E_*}, {@code A_*}, {@code V_*},
 * {@code ALG_*}, {@code FIELD_*}) are not declared in this class; presumably they are
 * inherited from {@link ConfigurationXML}.
 * <p>
 * NOTE: when the instance is created with {@link #AbstractConfiguration(NLPMode)} no document
 * is loaded, so the XML root stays {@code null} and {@link #getReader()} returns {@code null}.
 *
 * @since 3.0.0
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public class AbstractConfiguration implements ConfigurationXML
{
	/** Reader built from the reader element of the document; {@code null} until a document is loaded. */
	private AbstractReader<?> d_reader;
	/** Mode whose element is consulted by the no-argument, mode-specific getters. */
	private NLPMode n_mode;
	/** Root element of the configuration document; {@code null} until a document is loaded. */
	private Element x_top;

//	=================================== CONSTRUCTORS ===================================

	/**
	 * Creates a configuration that only carries a mode; no XML document is loaded.
	 * @param mode the mode to store.
	 */
	public AbstractConfiguration(NLPMode mode)
	{
		setMode(mode);
	}

	/**
	 * Creates a configuration from an XML document; the mode is left unset.
	 * @param in the stream holding the XML configuration.
	 */
	public AbstractConfiguration(InputStream in)
	{
		init(in);
	}

	/**
	 * Creates a configuration from an XML document for the specific mode.
	 * @param mode the mode to store.
	 * @param in the stream holding the XML configuration.
	 */
	public AbstractConfiguration(NLPMode mode, InputStream in)
	{
		this(mode);
		init(in);
	}

	/** Parses the document root from the stream, then builds the reader from it. */
	private void init(InputStream in)
	{
		x_top    = XmlUtils.getDocumentElement(in);
		d_reader = initReader();
	}

	/**
	 * Builds the reader described by the first reader element of the document.
	 * Raw and line readers take no arguments; any other type yields a {@link TSVReader}
	 * configured from the column declarations.
	 */
	private AbstractReader<?> initReader()
	{
		Element eReader = getFirstElement(E_READER);
		TReader type = TReader.getType(XmlUtils.getTrimmedAttribute(eReader, A_TYPE));

		if (type == TReader.RAW)
			return new RawReader();

		if (type == TReader.LINE)
			return new LineReader();

		// Each configured index is decremented by one before it is handed to the reader
		// (presumably 1-based in the XML and 0-based in TSVReader - TODO confirm).
		// What map.get returns for a field that is not declared is decided by ObjectIntHashMap.
		ObjectIntHashMap<String> map = getFieldMap(eReader);

		int iID     = map.get(FIELD_ID)     - 1;
		int iForm   = map.get(FIELD_FORM)   - 1;
		int iLemma  = map.get(FIELD_LEMMA)  - 1;
		int iPOSTag = map.get(FIELD_POS)    - 1;
		int iNament = map.get(FIELD_NAMENT) - 1;
		int iFeats  = map.get(FIELD_FEATS)  - 1;
		int iHeadID = map.get(FIELD_HEADID) - 1;
		int iDeprel = map.get(FIELD_DEPREL) - 1;
		int iXHeads = map.get(FIELD_XHEADS) - 1;
		int iSHeads = map.get(FIELD_SHEADS) - 1;

		return new TSVReader(iID, iForm, iLemma, iPOSTag, iNament, iFeats, iHeadID, iDeprel, iXHeads, iSHeads);
	}

	/**
	 * Called by {@link #initReader()}.
	 * @param eReader the reader element.
	 * @return a map from the field attribute of every column element under the reader element to its index attribute.
	 */
	private ObjectIntHashMap<String> getFieldMap(Element eReader)
	{
		NodeList columns = eReader.getElementsByTagName(E_COLUMN);
		ObjectIntHashMap<String> map = new ObjectIntHashMap<>();
		int size = columns.getLength();

		for (int i=0; i<size; i++)
		{
			Element column = (Element)columns.item(i);
			String  field  = XmlUtils.getTrimmedAttribute(column, A_FIELD);
			int     index  = XmlUtils.getIntegerAttribute(column, A_INDEX);
			map.put(field, index);
		}

		return map;
	}

//	=================================== GETTERS ===================================

	/** @return the reader built while loading the document, or {@code null} if no document was loaded. */
	public AbstractReader<?> getReader()
	{
		return d_reader;
	}

	/** @return the language type resolved from the trimmed text of the first language element. */
	public TLanguage getLanguage()
	{
		String language = XmlUtils.getTrimmedTextContent(getFirstElement(E_LANGUAGE));
		return TLanguage.getType(language);
	}

	/** @return the integer text content of the first thread-size element. */
	public int getThreadSize()
	{
		return XmlUtils.getIntegerTextContent(getFirstElement(E_THREAD_SIZE));
	}

//	=================================== ELEMENT ===================================

	/**
	 * @param tag the tag name to look up.
	 * @return the result of {@code XmlUtils.getFirstElementByTagName} applied to the document root.
	 */
	protected Element getFirstElement(String tag)
	{
		return XmlUtils.getFirstElementByTagName(x_top, tag);
	}

	/**
	 * Finds the element describing the specific mode among the direct children of the document root.
	 * @param mode the mode whose {@code toString()} is compared against the child node names.
	 * @return the first matching child element, or {@code null} if there is none.
	 */
	protected Element getModeElement(NLPMode mode)
	{
		NodeList children = x_top.getChildNodes();
		String name = mode.toString();
		int len = children.getLength();

		for (int i=0; i<len; i++)
		{
			Node node = children.item(i);

			// Only elements are eligible: other node kinds can carry arbitrary names
			// (e.g. a processing instruction target) and must not be cast to Element.
			if (node instanceof Element && node.getNodeName().equals(name))
				return (Element)node;
		}

		return null;
	}

//	=================================== MODE ===================================

	/** @return the mode currently stored in this configuration (may be {@code null}). */
	public NLPMode getMode()
	{
		return n_mode;
	}

	/** @param mode the mode to store. */
	public void setMode(NLPMode mode)
	{
		n_mode = mode;
	}

	/** @return {@link #getModeElement(NLPMode)} for the stored mode. */
	protected Element getModeElement()
	{
		return getModeElement(n_mode);
	}

//	=================================== TRAINER ===================================

	/**
	 * @return {@code true} only if the element of the stored mode has a bootstrap element
	 * whose trimmed text parses to {@code true}; {@code false} if that element is absent.
	 */
	public boolean isBootstrap()
	{
		Element eMode = getModeElement();
		Element eBootstrap = XmlUtils.getFirstElementByTagName(eMode, E_BOOTSTRAPS);
		return eBootstrap != null && Boolean.parseBoolean(XmlUtils.getTrimmedTextContent(eBootstrap));
	}

	/**
	 * Same as {@link #getTrainers(StringModel[], boolean)} with {@code reset = true}.
	 * @param models the models to create trainers for.
	 * @return one trainer per model.
	 */
	public AbstractTrainer[] getTrainers(StringModel[] models)
	{
		return getTrainers(models, true);
	}

	/**
	 * Creates one trainer per model; the i'th model is paired with the i'th trainer element
	 * of the stored mode.
	 * @param models the models to create trainers for.
	 * @param reset if {@code true}, {@code reset()} is called on each model before its trainer is created.
	 * @return one trainer per model, in the same order as {@code models}.
	 * @throws IllegalArgumentException if a trainer element is missing, or declares an unknown algorithm or type.
	 */
	public AbstractTrainer[] getTrainers(StringModel[] models, boolean reset)
	{
		AbstractTrainer[] trainers = new AbstractTrainer[models.length];
		Element eMode = getModeElement();

		for (int i=0; i<models.length; i++)
			trainers[i] = getTrainer(eMode, models, i, reset);

		return trainers;
	}

	/**
	 * Creates the trainer for {@code models[index]} from the trainer element at the same index.
	 * Called by {@link #getTrainers(StringModel[], boolean)}.
	 */
	private AbstractTrainer getTrainer(Element eMode, StringModel[] models, int index, boolean reset)
	{
		Element eTrainer = XmlUtils.getElementByTagName(eMode, E_TRAINER, index);

		// Fail with a readable message (and before touching the model) when the
		// configuration declares fewer trainers than there are models.
		if (eTrainer == null)
			throw new IllegalArgumentException("No "+E_TRAINER+" element is configured at index "+index+".");

		String algorithm = XmlUtils.getTrimmedAttribute(eTrainer, A_ALGORITHM);
		StringModel model = models[index];
		if (reset) model.reset();

		switch (algorithm)
		{
		case ALG_ADAGRAD  : return getTrainerAdaGrad  (eTrainer, model);
		case ALG_LIBLINEAR: return getTrainerLiblinear(eTrainer, model);
		default           : throw new IllegalArgumentException(algorithm+" is not a valid algorithm name.");
		}
	}

	/**
	 * Creates an AdaGrad trainer (SVM or logistic regression, chosen by the type attribute)
	 * from the attributes of the trainer element.
	 * @throws IllegalArgumentException if the type attribute matches neither variant.
	 */
	private AbstractAdaGrad getTrainerAdaGrad(Element eTrainer, StringModel model)
	{
		int labelCutoff   = XmlUtils.getIntegerAttribute(eTrainer, A_LABEL_CUTOFF);
		int featureCutoff = XmlUtils.getIntegerAttribute(eTrainer, A_FEATURE_CUTOFF);
		String type       = XmlUtils.getTrimmedAttribute(eTrainer, A_TYPE);
		boolean average   = XmlUtils.getBooleanAttribute(eTrainer, "average");
		double alpha      = XmlUtils.getDoubleAttribute (eTrainer, "alpha");
		double rho        = XmlUtils.getDoubleAttribute (eTrainer, "rho");
		double bias       = XmlUtils.getDoubleAttribute (eTrainer, "bias");

		switch (type)
		{
		case V_SUPPORT_VECTOR_MACHINE: return new AdaGradSVM(model, labelCutoff, featureCutoff, average, alpha, rho, bias);
		case V_LOGISTIC_REGRESSION   : return new AdaGradLR (model, labelCutoff, featureCutoff, average, alpha, rho, bias);
		default                      : throw new IllegalArgumentException(type+" is not a valid algorithm type.");
		}
	}

	/**
	 * Creates a Liblinear trainer (L2 SVM or L2 logistic regression, chosen by the type attribute)
	 * from the attributes of the trainer element.
	 * @throws IllegalArgumentException if the type attribute matches neither variant.
	 */
	private AbstractLiblinear getTrainerLiblinear(Element eTrainer, StringModel model)
	{
		int labelCutoff   = XmlUtils.getIntegerAttribute(eTrainer, A_LABEL_CUTOFF);
		int featureCutoff = XmlUtils.getIntegerAttribute(eTrainer, A_FEATURE_CUTOFF);
		int numThreads    = XmlUtils.getIntegerAttribute(eTrainer, A_NUMBER_OF_THREADS);
		String type       = XmlUtils.getTrimmedAttribute(eTrainer, A_TYPE);
		double cost       = XmlUtils.getDoubleAttribute(eTrainer, "cost");
		double eps        = XmlUtils.getDoubleAttribute(eTrainer, "eps");
		double bias       = XmlUtils.getDoubleAttribute(eTrainer, "bias");

		switch (type)
		{
		case V_SUPPORT_VECTOR_MACHINE: return new LiblinearL2SVM(model, labelCutoff, featureCutoff, numThreads, cost, eps, bias);
		case V_LOGISTIC_REGRESSION   : return new LiblinearL2LR (model, labelCutoff, featureCutoff, numThreads, cost, eps, bias);
		default                      : throw new IllegalArgumentException(type+" is not a valid algorithm type.");
		}
	}

//	=================================== TEXT CONTENTS ===================================

	/**
	 * @param eMode the element to search in.
	 * @param tagName the tag name of the element to read.
	 * @return the text content of the first {@code tagName} element found via {@code eMode}, read as a double.
	 */
	public double getDoubleTextContent(Element eMode, String tagName)
	{
		return XmlUtils.getDoubleTextContent(XmlUtils.getFirstElementByTagName(eMode, tagName));
	}

	/**
	 * @param eMode the element to search in.
	 * @param tagName the tag name of the element to read.
	 * @return the text content of the first {@code tagName} element found via {@code eMode}, read as an integer.
	 */
	public int getIntegerTextContent(Element eMode, String tagName)
	{
		return XmlUtils.getIntegerTextContent(XmlUtils.getFirstElementByTagName(eMode, tagName));
	}

	/**
	 * @param eMode the element to search in.
	 * @param tagName the tag name of the element to read.
	 * @return the trimmed text content of the first {@code tagName} element found via {@code eMode}.
	 */
	public String getTextContent(Element eMode, String tagName)
	{
		return XmlUtils.getTrimmedTextContent(XmlUtils.getFirstElementByTagName(eMode, tagName));
	}

//	=================================== BEAM ===================================

	/**
	 * @param mode the mode whose beam size is requested; if {@code null}, the stored mode is used.
	 * @return the integer text content of the first beam-size element under the element of that mode.
	 */
	public int getBeamSize(NLPMode mode)
	{
		// The requested mode must drive the lookup; previously the argument was ignored
		// and the stored mode was always used. A null argument keeps that former behavior.
		Element eMode = (mode != null) ? getModeElement(mode) : getModeElement();
		return XmlUtils.getIntegerTextContent(XmlUtils.getFirstElementByTagName(eMode, E_BEAM_SIZE));
	}
}