/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sequencelearning.hmm;
import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import com.google.common.base.Charsets;
import com.google.common.collect.Maps;
import com.google.common.io.CharStreams;
import com.google.common.io.Resources;
import org.apache.mahout.math.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This class implements a sample program that uses a pre-tagged training data
* set to train an HMM model as a POS tagger. The training data is automatically
* downloaded from the following URL:
* http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt It then
* trains an HMM Model using supervised learning and tests the model on the
* following test data set:
* http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt Further
* details regarding the data files can be found at
* http://flexcrfs.sourceforge.net/#Case_Study
*/
public final class PosTagger {
private static final Logger log = LoggerFactory.getLogger(PosTagger.class);
private static final Pattern SPACE = Pattern.compile(" ");
private static final Pattern SPACES = Pattern.compile("[ ]+");
/**
* No public constructors for utility classes.
*/
private PosTagger() {
// nothing to do here really.
}
/**
* Model trained in the example.
*/
private static HmmModel taggingModel;
/**
* Map for storing the IDs for the POS tags (hidden states)
*/
private static Map<String, Integer> tagIDs;
/**
* Counter for the next assigned POS tag ID The value of 0 is reserved for
* "unknown POS tag"
*/
private static int nextTagId;
/**
* Map for storing the IDs for observed words (observed states)
*/
private static Map<String, Integer> wordIDs;
/**
* Counter for the next assigned word ID The value of 0 is reserved for
* "unknown word"
*/
private static int nextWordId = 1; // 0 is reserved for "unknown word"
/**
* Used for storing a list of POS tags of read sentences.
*/
private static List<int[]> hiddenSequences;
/**
* Used for storing a list of word tags of read sentences.
*/
private static List<int[]> observedSequences;
/**
* number of read lines
*/
private static int readLines;
/**
* Given an URL, this function fetches the data file, parses it, assigns POS
* Tag/word IDs and fills the hiddenSequences/observedSequences lists with
* data from those files. The data is expected to be in the following format
* (one word per line): word pos-tag np-tag sentences are closed with the .
* pos tag
*
* @param url Where the data file is stored
* @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
* training data, not needed for test data)
* @throws IOException in case data file cannot be read.
*/
private static void readFromURL(String url, boolean assignIDs) throws IOException {
// initialize the data structure
hiddenSequences = new LinkedList<int[]>();
observedSequences = new LinkedList<int[]>();
readLines = 0;
// now read line by line of the input file
List<Integer> observedSequence = new LinkedList<Integer>();
List<Integer> hiddenSequence = new LinkedList<Integer>();
for (String line : CharStreams.readLines(Resources.newReaderSupplier(new URL(url), Charsets.UTF_8))) {
if (line.isEmpty()) {
// new sentence starts
int[] observedSequenceArray = new int[observedSequence.size()];
int[] hiddenSequenceArray = new int[hiddenSequence.size()];
for (int i = 0; i < observedSequence.size(); ++i) {
observedSequenceArray[i] = observedSequence.get(i);
hiddenSequenceArray[i] = hiddenSequence.get(i);
}
// now register those arrays
hiddenSequences.add(hiddenSequenceArray);
observedSequences.add(observedSequenceArray);
// and reset the linked lists
observedSequence.clear();
hiddenSequence.clear();
continue;
}
readLines++;
// we expect the format [word] [POS tag] [NP tag]
String[] tags = SPACE.split(line);
// when analyzing the training set, assign IDs
if (assignIDs) {
if (!wordIDs.containsKey(tags[0])) {
wordIDs.put(tags[0], nextWordId++);
}
if (!tagIDs.containsKey(tags[1])) {
tagIDs.put(tags[1], nextTagId++);
}
}
// determine the IDs
Integer wordID = wordIDs.get(tags[0]);
Integer tagID = tagIDs.get(tags[1]);
// handle unknown values
wordID = wordID == null ? 0 : wordID;
tagID = tagID == null ? 0 : tagID;
// now construct the current sequence
observedSequence.add(wordID);
hiddenSequence.add(tagID);
}
// if there is still something in the pipe, register it
if (!observedSequence.isEmpty()) {
int[] observedSequenceArray = new int[observedSequence.size()];
int[] hiddenSequenceArray = new int[hiddenSequence.size()];
for (int i = 0; i < observedSequence.size(); ++i) {
observedSequenceArray[i] = observedSequence.get(i);
hiddenSequenceArray[i] = hiddenSequence.get(i);
}
// now register those arrays
hiddenSequences.add(hiddenSequenceArray);
observedSequences.add(observedSequenceArray);
}
}
private static void trainModel(String trainingURL) throws IOException {
tagIDs = Maps.newHashMapWithExpectedSize(44); // we expect 44 distinct tags
wordIDs = Maps.newHashMapWithExpectedSize(19122); // we expect 19122
// distinct words
log.info("Reading and parsing training data file from URL: {}", trainingURL);
long start = System.currentTimeMillis();
readFromURL(trainingURL, true);
long end = System.currentTimeMillis();
double duration = (end - start) / 1000.0;
log.info("Parsing done in {} seconds!", duration);
log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
new Object[] {readLines, hiddenSequences.size(), nextWordId - 1, nextTagId - 1});
start = System.currentTimeMillis();
taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
hiddenSequences, observedSequences, 0.05);
// we have to adjust the model a bit,
// since we assume a higher probability that a given unknown word is NNP
// than anything else
Matrix emissions = taggingModel.getEmissionMatrix();
for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i) {
emissions.setQuick(i, 0, 0.1 / (double) taggingModel.getNrOfHiddenStates());
}
int nnptag = tagIDs.get("NNP");
emissions.setQuick(nnptag, 0, 1 / (double) taggingModel.getNrOfHiddenStates());
// re-normalize the emission probabilities
HmmUtils.normalizeModel(taggingModel);
// now register the names
taggingModel.registerHiddenStateNames(tagIDs);
taggingModel.registerOutputStateNames(wordIDs);
end = System.currentTimeMillis();
duration = (end - start) / 1000.0;
log.info("Trained HMM models in {} seconds!", duration);
}
private static void testModel(String testingURL) throws IOException {
log.info("Reading and parsing test data file from URL: {}", testingURL);
long start = System.currentTimeMillis();
readFromURL(testingURL, false);
long end = System.currentTimeMillis();
double duration = (end - start) / 1000.0;
log.info("Parsing done in {} seconds!", duration);
log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());
start = System.currentTimeMillis();
int errorCount = 0;
int totalCount = 0;
for (int i = 0; i < observedSequences.size(); ++i) {
// fetch the viterbi path as the POS tag for this observed sequence
int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences.get(i), false);
// compare with the expected
int[] posExpected = hiddenSequences.get(i);
for (int j = 0; j < posExpected.length; ++j) {
totalCount++;
if (posEstimate[j] != posExpected[j]) {
errorCount++;
}
}
}
end = System.currentTimeMillis();
duration = (end - start) / 1000.0;
log.info("POS tagged test file in {} seconds!", duration);
double errorRate = (double) errorCount / (double) totalCount;
log.info("Tagged the test file with an error rate of: {}", errorRate);
}
private static List<String> tagSentence(String sentence) {
// first, we need to isolate all punctuation characters, so that they
// can be recognized
sentence = sentence.replaceAll("[,.!?:;\"]", " $0 ");
sentence = sentence.replaceAll("''", " '' ");
// now we tokenize the sentence
String[] tokens = SPACES.split(sentence);
// now generate the observed sequence
int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays.asList(tokens), true, 0);
// POS tag this observedSequence
int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence, false);
// and now decode the tag names
return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false, null);
}
public static void main(String[] args) throws IOException {
// generate the model from URL
trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
// tag an exemplary sentence
String test = "McDonalds is a huge company with many employees .";
String[] testWords = SPACE.split(test);
List<String> posTags = tagSentence(test);
for (int i = 0; i < posTags.size(); ++i) {
log.info("{}[{}]", testWords[i], posTags.get(i));
}
}
}