import java.io.*;
import java.util.Arrays;
import java.util.Vector;
/**
 * N-gram based punctuation prediction.
 * <p>
 * Requires a corpus in which the punctuation symbols have been tokenized as
 * follows:
 * <ul>
 * <li>{@code ,} as {@code ,COMMA}</li>
 * <li>{@code .} as {@code .PERIOD}</li>
 * <li>{@code ?} as {@code ?QMARK}</li>
 * </ul>
 *
 * @author joakimlilja
 */
public class PunctuationPredicter {

    /** Corpus loaded when the constructor is given a {@code null} or empty path. */
    public static final String CORPUS_TEST_PATH = "ppSentenses.txt";

    /**
     * N-gram model used for prediction.
     * <p>
     * NOTE(review): this field is {@code static} but is reassigned by every
     * constructor call, so all instances share the model of the most recently
     * constructed predictor. It is kept public and static for compatibility
     * with existing callers.
     */
    public static NGramWrapper nGramWrapper;

    /**
     * Constructor. Builds the n-gram model and loads it from a corpus file.
     *
     * @param nGramLength
     *            - length of n-gram
     * @param corpusPath
     *            - path to corpus; if {@code null} or empty,
     *            {@link #CORPUS_TEST_PATH} is used instead
     */
    public PunctuationPredicter(int nGramLength, String corpusPath) {
        nGramWrapper = new NGramWrapper(nGramLength);
        String path = (corpusPath == null || corpusPath.isEmpty()) ? CORPUS_TEST_PATH : corpusPath;
        nGramWrapper.readFile(new File(path));
    }

    /**
     * Predicts the most likely sentence with punctuation symbols inserted,
     * given the input.
     * <p>
     * The input is split on single spaces. Runs of spaces are not collapsed
     * here, so callers should normalise whitespace first; otherwise empty
     * tokens are passed on to the model.
     *
     * @param input
     *            - string from where the prediction is to be made
     * @return the predicted sentence, as chosen by
     *         {@code HyperStringFSA3.getOptimalString()}
     */
    public String predictPunctuation(String input) {
        String[] words = input.split(" ");
        // Builds the candidate punctuation combinations for these words.
        HyperStringFSA3 hypotheses = new HyperStringFSA3(words, nGramWrapper);
        return hypotheses.getOptimalString();
    }

    /**
     * Interactive test mode: reads sentences from standard input, one per
     * line, and prints the prediction for each until end of input. Standard
     * input is deliberately left open.
     *
     * @param nGramLength
     *            - currently unused
     */
    private void handleInput(int nGramLength) {
        BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
        try {
            String line;
            while ((line = stdin.readLine()) != null) {
                System.out.println(predictPunctuation(line));
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Multiplies the model's {@code getCostOfNGram} values over every window of
     * n consecutive space-separated tokens, visiting windows from the end of
     * the sentence towards the start. Returns 1.0 when the sentence has fewer
     * tokens than the n-gram length.
     *
     * @param word
     *            - a space-separated sentence
     * @return the product of the per-window values
     */
    private double getCostOfString(String word) {
        String[] tokens = word.split(" ");
        int n = nGramWrapper.getNGramLength();
        double value = 1.0D;
        for (int end = tokens.length; end >= n; end--) {
            value *= nGramWrapper.getCostOfNGram(Arrays.copyOfRange(tokens, end - n, end));
        }
        return value;
    }

    /**
     * Evaluation driver. For each n-gram length from 3 to 7 it loads
     * {@code ppCorpus.txt} and runs the test set(s) through the predictor.
     *
     * @param args
     *            - optional {@code n-gram <length>} pair; it is parsed, but
     *            the length is only meant for the interactive mode
     *            ({@code handleInput}), which is currently disabled
     */
    public static void main(String[] args) {
        int nGramLength = 4;
        // Arguments come in "key value" pairs; a trailing key without a value
        // is ignored instead of throwing ArrayIndexOutOfBoundsException.
        for (int i = 0; i + 1 < args.length; i += 2) {
            if (args[i].equals("n-gram")) {
                nGramLength = Integer.parseInt(args[i + 1]);
            }
        }
        for (int n = 3; n <= 7; n++) {
            System.gc();
            PunctuationPredicter predictor = new PunctuationPredicter(n, "ppCorpus.txt");
            for (int setIndex = 0; setIndex < 1; setIndex++) {
                evaluateTestSet(predictor, n, setIndex);
            }
        }
        // Interactive mode (disabled): construct a predictor with nGramLength
        // and call handleInput(nGramLength) to read sentences from stdin.
    }

    /**
     * Runs one test set through the predictor and reports vocabulary
     * statistics.
     * <p>
     * The method works as follows:
     * <ul>
     * <li>It reads {@code testSentences<set>.txt} (UTF-16BE) line by line.</li>
     * <li>It writes one prediction per line to
     * {@code testSentencesAnswers<set>ngram<n>.txt} (UTF-16BE).</li>
     * <li>It writes the OOV statistics to
     * {@code testSentencesAnswers<set>OOV.txt}.</li>
     * </ul>
     * All files are closed even if an error occurs. I/O errors are printed,
     * not propagated.
     */
    private static void evaluateTestSet(PunctuationPredicter predictor, int n, int setIndex) {
        String evaluate = "testSentences" + setIndex + ".txt";
        String answers = "testSentencesAnswers" + setIndex + "ngram" + n + ".txt";
        String correct = "testSentencesCorrection" + setIndex + ".txt";
        BufferedReader br = null;
        BufferedReader correction = null;
        Writer pw = null;
        PrintWriter printOOV = null;
        try {
            br = new BufferedReader(new InputStreamReader(new FileInputStream(evaluate), "UTF-16BE"));
            // The correction file is not read at the moment. It is still opened,
            // so a missing gold file aborts this test set as before.
            correction = new BufferedReader(new InputStreamReader(new FileInputStream(correct), "UTF-16BE"));
            pw = new OutputStreamWriter(new FileOutputStream(answers), "UTF-16BE");
            // NOTE(review): unlike the other files, this one uses the platform
            // default charset.
            printOOV = new PrintWriter("testSentencesAnswers" + setIndex + "OOV.txt");

            int counting = 1;
            String line;
            // readLine() == null is the reliable end-of-stream signal;
            // ready() may report false before EOF.
            while ((line = br.readLine()) != null) {
                String fix = line.trim().replaceAll("( )+", " ");
                String answer = predictor.predictPunctuation(fix);
                String[] tokens = fix.split(" ");
                nGramWrapper.updateOOV(tokens);
                nGramWrapper.updateCoverage(tokens);
                pw.write(answer);
                pw.write('\n');
                System.err.println(counting); // progress indicator
                counting++;
            }

            System.out.println(nGramWrapper.getOOV());
            System.out.println("In vocabulary = " + nGramWrapper.numberOfTokensInVocabulary);
            System.out.println("Out of vocabulary = " + nGramWrapper.numberOfTokensOutOfVocabulary);
            System.out.println("Coverage = " + nGramWrapper.getCoverage());
            printOOV.println("In vocabulary = " + nGramWrapper.numberOfTokensInVocabulary);
            printOOV.println("Out of vocabulary = " + nGramWrapper.numberOfTokensOutOfVocabulary);
            printOOV.println(nGramWrapper.getOOV());
            nGramWrapper.resetOOV();
            nGramWrapper.resetCoverage();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            closeAndReport(printOOV);
            closeAndReport(pw);
            closeAndReport(correction);
            closeAndReport(br);
        }
    }

    /** Closes {@code resource} if it is non-null, printing (not propagating) any I/O error. */
    private static void closeAndReport(Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}