package hu.u_szeged.kpe.main;
import hu.u_szeged.kpe.KpeMain;
import hu.u_szeged.kpe.candidates.NGram;
import hu.u_szeged.kpe.candidates.NGramStats;
import hu.u_szeged.kpe.readers.DocumentData;
import hu.u_szeged.kpe.readers.DocumentSet;
import hu.u_szeged.ml.DataMiningException;
import hu.u_szeged.ml.Model;
import hu.u_szeged.ml.mallet.MalletDataHandler;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelector;
import cc.mallet.types.InfoGain;
import edu.stanford.nlp.util.CoreMap;
/**
* Class responsible for the training phase of keyphrase extraction.
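*
* <p>A minimal usage sketch; the document set, the feature list and the
* classifier name below are illustrative assumptions (see {@link KpeMain} for
* how this class is actually driven):
*
* <pre>{@code
* ExtractionModelBuilder builder = new ExtractionModelBuilder();
* builder.setDocSet(trainingDocs); // a previously constructed DocumentSet
* builder.setMinNumOccur(2);
* builder.setMaxPhraseLength(4);
* String log = builder.buildModel(0, 10, featureNames, "maxent", 0.35, 1.0,
*     new boolean[] { true, true, true }, null, false, false, false);
* }</pre>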
*/
public class ExtractionModelBuilder {
private KPEFilter m_KPEFilter = null;
/** The maximum length of phrases */
private int m_MaxPhraseLength = 5;
/** The minimum length of phrases */
private int m_MinPhraseLength = 1;
/** The minimum number of occurrences of a phrase */
private int m_MinNumOccur = 2;
/** The reader for the processing of the training data */
private DocumentSet docSet;
public void setDocSet(DocumentSet docs) {
docSet = docs;
}
public DocumentSet getDocSet() {
return docSet;
}
/**
* Get the value of MinNumOccur.
*
* @return Value of MinNumOccur.
*/
public int getMinNumOccur() {
return m_MinNumOccur;
}
public KPEFilter getKPEFilter() {
return m_KPEFilter;
}
/**
* Set the value of MinNumOccur.
*
* @param newMinNumOccur
* Value to assign to MinNumOccur.
*/
public void setMinNumOccur(int newMinNumOccur) {
m_MinNumOccur = newMinNumOccur;
}
/**
* Get the value of MaxPhraseLength.
*
* @return Value of MaxPhraseLength.
*/
public int getMaxPhraseLength() {
return m_MaxPhraseLength;
}
/**
* Set the value of MaxPhraseLength.
*
* @param newMaxPhraseLength
* Value to assign to MaxPhraseLength.
*/
public void setMaxPhraseLength(int newMaxPhraseLength) {
m_MaxPhraseLength = newMaxPhraseLength;
}
/**
* Get the value of MinPhraseLength.
*
* @return Value of MinPhraseLength.
*/
public int getMinPhraseLength() {
return m_MinPhraseLength;
}
/**
* Set the value of MinPhraseLength.
*
* @param newMinPhraseLength
* Value to assign to MinPhraseLength.
*/
public void setMinPhraseLength(int newMinPhraseLength) {
m_MinPhraseLength = newMinPhraseLength;
}
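/**
* Trains the keyphrase extraction model on the training portion of the given
* cross-validation fold and returns a short log line describing the run.
*
* @param foldNum
* index of the current cross-validation fold
* @param totalFolds
* total number of cross-validation folds
* @param features
* names of the feature classes to employ
* @param classifier
* name of the classifier to be trained
* @param commonWordsThreshold
* threshold used when collecting the common words of scientific documents
* @param selectedFeatureRatio
* ratio of the features to keep after information gain-based ranking
* @param employBIESmarkup
* flags enabling BIES markup for the compatible features
* @param targetDomainDocs
* additional documents from the target domain (if any)
* @param noSWpruning
* if true, stopword-based pruning of candidate phrases is disabled
* @param noPOSpruning
* if true, POS-based pruning of candidate phrases is disabled
* @param serialize
* whether preprocessed documents are serialized
* @return a tab-separated log line with instance and feature statistics
* @throws Exception
* if reading the documents or training the classifier fails
*/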
public String buildModel(int foldNum, int totalFolds, List<String> features, String classifier,
double commonWordsThreshold, double selectedFeatureRatio, boolean[] employBIESmarkup,
DocumentSet targetDomainDocs, boolean noSWpruning, boolean noPOSpruning, boolean serialize) throws Exception {
m_KPEFilter = new KPEFilter(noSWpruning, noPOSpruning);
m_KPEFilter.setMaxPhraseLength(getMaxPhraseLength());
m_KPEFilter.setMinPhraseLength(getMinPhraseLength());
m_KPEFilter.setMinNumOccur(getMinNumOccur());
String[] BIEScompatibleFeatures = { "PosFeature", "StrangeOrthographyFeature", "SuffixFeature" };
Map<String, Boolean> employBIES = new HashMap<>();
// guard against a markup array longer than the list of BIES-compatible features
for (int i = 0; i < Math.min(employBIESmarkup.length, BIEScompatibleFeatures.length); ++i) {
employBIES.put(BIEScompatibleFeatures[i], employBIESmarkup[i]);
}
m_KPEFilter.setNumFeature(features, classifier, employBIES);
// m_KPEFilter.setAcceptSynonyms(useSynonymsForTraining);
List<DocumentData> documents = docSet.determineDocumentSet(foldNum, totalFolds, true, targetDomainDocs);
m_KPEFilter.setDocsNumber(documents.size());
int i = 0;
boolean containsScientific = false;
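// first pass: update the global dictionaries from the keyphrases and sections of every training document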
for (DocumentData doc : documents) {
containsScientific = containsScientific || doc.isScientific();
if (++i % 500 == 0) {
System.err.print(i + "\t");
}
m_KPEFilter.updateGlobalDictionary(doc.getKeyphrases(), doc.getSections(docSet.getReader(), serialize));
}
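// common-word statistics are only collected when the training set contains scientific documents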
if (containsScientific) {
m_KPEFilter.setCommonWords(commonWordsThreshold, documents.size());
}
String log = buildClassifier(foldNum, documents, selectedFeatureRatio, serialize);
System.err.println("Classifier built of " + documents.size() + " documents in "
+ ((System.currentTimeMillis() - KpeMain.time) / 1000d) + " seconds.");
return log;
}
/**
* Builds the classifier from the candidate phrases of the training documents and returns a short log line.
*/
private String buildClassifier(int foldNum, List<DocumentData> docsToLearn, double featureRatio, boolean serialize) {
m_KPEFilter.initializeFeatureFields();
System.err.println("Global dictionaries built...\t" + (System.currentTimeMillis() - KpeMain.time) / 1000d);
Map<String, Object> initClassifier = new HashMap<String, Object>();
initClassifier.put("classifier", m_KPEFilter.getClassifierName());
MalletDataHandler dh = new MalletDataHandler();
try {
dh.initClassifier(initClassifier);
} catch (DataMiningException dme) {
dme.printStackTrace();
}
dh.createNewDataset(null);
int id = 1;
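// turn each training document into candidate phrases and add them to the dataset as labeled instances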
for (int docNum = 0; docNum < docsToLearn.size(); ++docNum) {
boolean belongsToTargetDomain = docNum >= docSet.size();
DocumentData doc = docsToLearn.get(docNum);
doc.setDocId(id++);
Map<String, Map<NGram, NGramStats>> hash = new HashMap<String, Map<NGram, NGramStats>>();
List<int[]> length = new ArrayList<int[]>(1);
// training is based on separate documents always
List<Map<Integer, List<CoreMap>>> grammars = new ArrayList<Map<Integer, List<CoreMap>>>(1);
TreeMap<Integer, List<CoreMap>> sections = doc.getSections(docSet.getReader(), serialize);
grammars.add(sections);
length.add(m_KPEFilter.getPhrases(hash, doc, Integer.MAX_VALUE, sections));
List<Map<String, Map<NGram, NGramStats>>> listOfHashs = new ArrayList<Map<String, Map<NGram, NGramStats>>>(1);
listOfHashs.add(hash);
for (Entry<String, Map<NGram, NGramStats>> phrase : hash.entrySet()) {
m_KPEFilter.addFeats(belongsToTargetDomain, phrase, length, listOfHashs, grammars, dh, true, doc);
}
}
// for (Entry<String, Map<Double, Integer>> feats :
// m_KPEFilter.getFeatures().getFeatureValDistribution().entrySet()) {
// int i = 0, sum = 0;
// if (feats.getKey().contains("TfIdf") && foldNum == 9) {
// for (Entry<Double, Integer> v : feats.getValue().entrySet()) {
// if (++i < 10 || i > 6000)
// System.err.println(i + "\t" + v);
// sum += v.getValue();
// }
// System.err.println(sum + "\t" + i);
// }
// }
String log = "N/A";
try {
int instanceNum = dh.getInstanceCount();
int featureNum = dh.getFeatureCount();
if (featureRatio < 1.0) {
FeatureSelector fs = new FeatureSelector(new InfoGain.Factory(), Double.MIN_VALUE);
fs.selectFeaturesFor(dh.data);
List<String> selection = new LinkedList<String>();
BitSet selectedFeatures = dh.data.getFeatureSelection().getBitSet();
Object[] alphabet = dh.data.getFeatureSelection().getAlphabet().toArray();
for (int j = 0; j < Math.min(featureNum, featureNum * featureRatio); ++j) {
boolean selected = selectedFeatures.get(j);
if (selected) {
selection.add(alphabet[j].toString());
}
}
// System.err.println("Selected features: " + selection);
fs.selectFeaturesFor(dh.data);
}
// dh.removeFeature(m_KPEFilter.getFeatures().getRareFeatures(1));
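// keep the feature and label alphabets so that instances built at extraction time map to the same feature space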
List<Alphabet> alphabets = new LinkedList<Alphabet>();
alphabets.add(dh.getAlphabet("feature"));
alphabets.add(dh.getAlphabet("label"));
m_KPEFilter.setAlphabets(alphabets);
// count the number of positive and negative training instances
int positiveTrainingInstances = 0, negativeTrainingInstances = 0;
for (String s : dh.instanceIds.keySet()) {
if (Boolean.parseBoolean((String) dh.getLabel(s))) {
positiveTrainingInstances++;
} else {
negativeTrainingInstances++;
}
}
log = instanceNum + " inst.\t" + positiveTrainingInstances + " pos + " + negativeTrainingInstances + " neg\t"
+ featureNum + " (" + dh.getFeatureCount() + ") features\t" + dh.getFeatureNames().size()
+ " pruned features\t" + (System.currentTimeMillis() - KpeMain.time) / 1000.0d;
Model learnedModel = dh.trainClassifier();
m_KPEFilter.setModel(learnedModel);
} catch (Exception e) {
e.printStackTrace();
}
System.err.println(log);
return log;
}
}