import java.io.File;
import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.Vector;
/**
 * Hyper-string finite state automaton that inserts punctuation between the
 * words of a sentence and searches for the highest-scoring variant.
 * <p>
 * Every word becomes a node. After each word node, one child is created per
 * entry in {@link #TRANSITIONS} (either "no punctuation" or a punctuation
 * token). Paths are scored with the n-gram model in {@code NGramWrapper} and
 * explored best-first with a {@link PriorityQueue}. The first complete path
 * taken off the queue is stored and can be read back with
 * {@link #getOptimalString()}.
 * <p>
 * Only the transitions listed in {@link #TRANSITIONS} are explored. Word
 * capitalization is currently not varied, although helper methods for it
 * remain in this class.
 *
 * @author joakimlilja
 */
public class HyperStringFSA3 {

    /** Token for the "no punctuation" transition; never emitted by {@link #getOptimalString()}. */
    public static final String EMPTY_PUNCT = "" + ((char) 007) + "EMPTY ";

    /**
     * Punctuation transitions tried after every word. Larger sets such as
     * {@code ",COMMA "}, {@code "?QMARK "} or {@code "!EXCL "} may be added;
     * {@link #postProcessing(String)} maps each entry to its surface form via
     * the entry's first character, so the two arrays need not stay in sync.
     */
    public static final String[] TRANSITIONS = { EMPTY_PUNCT, ".PERIOD " };

    /** Surface forms (symbol followed by a space) used by {@link #postProcessing(String)}. */
    public static final String[] POSTPROCESSES = { " ", ", ", ". ", "? ", "! " };

    /** Number of entries in {@link #TRANSITIONS}. */
    public static final int TRANSITION_COUNT = TRANSITIONS.length;

    private static final String END_OF_LINE = "¿EOL";

    /** Enumerated paths; only filled by {@link #generateOutputs}, which the best-first search does not call. */
    Vector<String[]> outputs;

    /** Language model used to score n-grams. */
    NGramWrapper nGram;

    /** Lowest leaf cost seen by the exhaustive (queue-less) expansion in {@link #generateNodes}. */
    private double lowestValue = Double.POSITIVE_INFINITY;

    /** Highest leaf cost seen by the exhaustive expansion (Double.MIN_VALUE would be the smallest positive value). */
    private double highestValue = Double.NEGATIVE_INFINITY;

    /** Cleared once the best-first search has popped a complete path. */
    private boolean optimalNotFound = true;

    /** Last node of the best complete path, or null if none was found. */
    private Node optimalNode;

    /**
     * Builds the automaton for the given words and immediately runs the
     * best-first search for the highest-scoring punctuation.
     *
     * @param s     words of the sentence in order; may be empty or null
     * @param nGram n-gram model used to score candidate paths
     */
    public HyperStringFSA3(String[] s, NGramWrapper nGram) {
        outputs = new Vector<String[]>();
        this.nGram = nGram;
        constructFSA(s, outputs);
    }

    /**
     * Runs the best-first search from a root node carrying the empty transition.
     * <p>
     * {@code outputs} is not populated by this path; it is kept for the
     * exhaustive enumeration in {@link #generateOutputs(Node, Vector)}.
     *
     * @param s       words of the sentence
     * @param outputs vector for enumerated paths (unused by the search)
     */
    private void constructFSA(String[] s, Vector<String[]> outputs) {
        Node startNode = new Node(EMPTY_PUNCT, 1.0);
        usePriorityQueue(s, startNode);
    }

    /**
     * Best-first expansion. Repeatedly pops the highest-cost partial path and
     * extends it by one word. It stops when a complete path is popped (which
     * {@link #generateNodes} records in {@link #optimalNode}) or when the queue
     * runs dry.
     * <p>
     * The first complete path is only guaranteed optimal if extending a path
     * never increases its cost, i.e. if {@code getCostOfNGram} returns values
     * in [0, 1]. NOTE(review): this is assumed; confirm against NGramWrapper.
     */
    private void usePriorityQueue(String[] s, Node startNode) {
        PriorityQueue<PriorityQueueElement> queue = new PriorityQueue<PriorityQueueElement>();
        queue.add(new PriorityQueueElement(s, startNode));
        while (optimalNotFound && !queue.isEmpty()) {
            PriorityQueueElement next = queue.poll();
            generateNodes(next.s, next.self, queue);
        }
        queue.clear();
    }

    /**
     * Appends one output per leaf under {@code node}. Each output is the path
     * text split on spaces, with the leaf cost appended to the last token.
     *
     * @param node    subtree root
     * @param outputs destination vector
     */
    private void generateOutputs(Node node, Vector<String[]> outputs) {
        if (node.children.isEmpty()) {
            String path = backTrack(node, "", Integer.MAX_VALUE) + node.cost;
            outputs.add(path.split(" "));
            return;
        }
        for (Node child : node.children) {
            generateOutputs(child, outputs);
        }
    }

    /**
     * Returns the leaf with the highest cost in the subtree rooted at
     * {@code root}, or {@code root} itself when it has no children.
     * The whole path can then be rebuilt with {@link #backTrackFromChild}.
     */
    private Node findHighestValuedChild(Node root) {
        if (root.children.isEmpty()) {
            return root;
        }
        Node best = null;
        for (Node child : root.children) {
            Node candidate = findHighestValuedChild(child);
            if (best == null || candidate.cost > best.cost) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Appends the values from the root down to {@code n} to {@code sb},
     * skipping empty transitions.
     */
    private void backTrackFromChild(Node n, StringBuilder sb) {
        if (n.parent != null) {
            backTrackFromChild(n.parent, sb);
        }
        if (!n.value.equals(EMPTY_PUNCT)) {
            sb.append(n.toString());
        }
    }

    /**
     * Backtracks from {@code node} towards the root, prepending node values
     * (empty transitions contribute nothing).
     *
     * @param node node to start from
     * @param s    text accumulated so far
     * @param n    number of further ancestors to visit after {@code node}
     * @return the concatenated path text
     */
    private String backTrack(Node node, String s, int n) {
        String value = node.value.equals(EMPTY_PUNCT) ? "" : node.value;
        if (node.parent == null || n == 0) {
            return value + s;
        }
        return backTrack(node.parent, value + s, n - 1);
    }

    /**
     * Scores a space-separated token string as the product of the costs of
     * all complete n-gram windows in it. A string with fewer tokens than the
     * n-gram length scores 1.0.
     */
    private double getCostOfString(String word) {
        String[] tokens = word.split(" ");
        int n = nGram.getNGramLength();
        double value = 1.0D;
        // Window tokens[end - n .. end - 1]; end runs down to n so every window, including the last, is scored.
        for (int end = tokens.length; end >= n; end--) {
            value *= nGram.getCostOfNGram(Arrays.copyOfRange(tokens, end - n, end));
        }
        return value;
    }

    /**
     * Extends {@code parent} with the next word {@code s[0]} and its
     * punctuation transitions.
     * <p>
     * With a queue ({@code pq != null}), each new transition node is offered to
     * the queue together with the remaining words. When no words remain, the
     * node is offered with {@code null} words to mark a complete path. Popping
     * such an element (null or empty {@code s}) records it as the optimal path.
     * <p>
     * Without a queue, the whole tree is expanded recursively, and the leaf
     * cost range is tracked in {@link #lowestValue} / {@link #highestValue}.
     *
     * @param s      remaining words; null or empty means the path is complete
     * @param parent node to extend
     * @param pq     search queue, or null for exhaustive expansion
     * @return {@code parent}
     */
    private Node generateNodes(String[] s, Node parent, PriorityQueue<PriorityQueueElement> pq) {
        if (s == null || s.length == 0) {
            if (pq != null) {
                optimalNotFound = false;
                optimalNode = parent;
            }
            return parent;
        }

        String[] ngramHolder = new String[nGram.getNGramLength()];
        Node wordNode = new Node(s[0] + " ");
        wordNode.parent = parent;
        wordNode.cost = getCost2(wordNode, ngramHolder, ngramHolder.length - 1) * parent.cost;
        generateTransitions(wordNode);
        parent.children.add(wordNode);

        if (s.length > 1) {
            // The remaining-words array is never mutated, so all children can share one copy.
            String[] rest = Arrays.copyOfRange(s, 1, s.length);
            for (int i = 0; i < wordNode.children.size(); i++) {
                Node child = wordNode.children.get(i);
                if (pq == null) {
                    wordNode.children.set(i, generateNodes(rest, child, null));
                } else {
                    pq.offer(new PriorityQueueElement(rest, child));
                }
            }
        } else if (pq != null) {
            for (Node child : wordNode.children) {
                pq.offer(new PriorityQueueElement(null, child));
            }
        } else {
            recordLeafStatistics(wordNode);
        }
        return parent;
    }

    /**
     * Updates the lowest/highest leaf costs from the transition children of
     * {@code wordNode}. Each new maximum path is printed to stderr.
     */
    private void recordLeafStatistics(Node wordNode) {
        for (Node leaf : wordNode.children) {
            double leafCost = leaf.cost;
            if (leafCost > highestValue) {
                highestValue = leafCost;
                StringBuilder sb = new StringBuilder();
                backTrackFromChild(leaf, sb);
                System.err.println(sb.toString() + "\t" + leafCost);
            }
            if (leafCost < lowestValue) {
                lowestValue = leafCost;
            }
        }
    }

    /** Queue entry: a partial path ending at {@code self}, plus the words still to place. */
    class PriorityQueueElement implements Comparable<PriorityQueueElement> {
        String[] s;
        Node self;

        public PriorityQueueElement(String[] s, Node self) {
            this.s = s;
            this.self = self;
        }

        /** Delegates to {@link Node#compareTo}, so higher-cost paths are polled first. */
        @Override
        public int compareTo(PriorityQueueElement pqe) {
            return self.compareTo(pqe.self);
        }
    }

    /**
     * Collects the last {@code length + 1} tokens on the path ending at
     * {@code parent} into {@code ngram[0..length]}, skipping empty
     * transitions, and returns the model cost of that n-gram.
     *
     * @param parent node whose path is scored
     * @param ngram  window buffer, filled from the end
     * @param length index of the next slot to fill
     * @return the n-gram cost, or 1.0 if the path has too few tokens to fill the window
     */
    private double getCost2(Node parent, String[] ngram, int length) {
        if (parent.value.equals(EMPTY_PUNCT)) {
            // Empty transitions are invisible to the language model.
            if (parent.parent != null) {
                return getCost2(parent.parent, ngram, length);
            }
            return 1.0D;
        }
        ngram[length] = parent.value.trim();
        if (length == 0) {
            return nGram.getCostOfNGram(ngram);
        }
        if (parent.parent != null) {
            return getCost2(parent.parent, ngram, length - 1);
        }
        // NOTE(review): partial windows near the sentence start score a neutral 1.0;
        // scoring the shorter n-gram instead was considered by the original author.
        return 1.0D;
    }

    /**
     * Scores {@code word} appended to the last few tokens before {@code parent}.
     *
     * @return the n-gram cost, or NaN when the joined text contains no tokens
     */
    private double getCost(Node parent, String word) {
        String[] ngram = backTrack(parent, word, nGram.getNGramLength() - 2).split(" ");
        // split(" ") drops trailing empty strings, so an all-blank path yields an empty array.
        if (ngram.length == 0) {
            return Double.NaN;
        }
        return nGram.getCostOfNGram(ngram);
    }

    /**
     * Returns the best path found by the search, as words and punctuation
     * tokens concatenated. Empty transitions are omitted.
     *
     * @return the optimal path text, or "" if no path was found
     */
    public String getOptimalString() {
        if (optimalNode == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        backTrackFromChild(optimalNode, sb);
        return sb.toString();
    }

    /**
     * Adds one child per entry in {@link #TRANSITIONS} to {@code parent}.
     *
     * @param parent word node to attach the transitions to
     */
    private void generateTransitions(Node parent) {
        for (String emission : TRANSITIONS) {
            Node transNode;
            if (emission.equals(EMPTY_PUNCT)) {
                // The empty transition is free: it keeps the word node's cost unchanged.
                // TODO(original author): EMPTY_PUNCT could be represented by null instead.
                transNode = new Node(emission, parent.cost);
                transNode.parent = parent;
                parent.children.add(transNode);
            } else {
                String[] ngram = new String[nGram.getNGramLength()];
                transNode = new Node(emission);
                transNode.parent = parent;
                parent.children.add(transNode);
                // Punctuation is scored as the last token of the n-gram ending at this node.
                transNode.cost = getCost2(transNode, ngram, ngram.length - 1) * parent.cost;
            }
        }
    }

    /** Upper-cases the first character; null or empty input is returned unchanged. */
    private String capitalizeWord(String input) {
        if (input == null || input.isEmpty()) {
            return input;
        }
        return input.substring(0, 1).toUpperCase() + input.substring(1);
    }

    /** Lower-cases the first character; null or empty input is returned unchanged. */
    private String deCapitalizeWord(String input) {
        if (input == null || input.isEmpty()) {
            return input;
        }
        return input.substring(0, 1).toLowerCase() + input.substring(1);
    }

    @Override
    public String toString() {
        return outputs.toString();
    }

    /** @return the enumerated outputs (empty unless {@link #generateOutputs} was used) */
    public Vector<String[]> getOutputs() {
        return outputs;
    }

    /**
     * Replaces every transition token in {@code input} with its surface form.
     * For example, {@code ".PERIOD "} becomes {@code ". "} and
     * {@link #EMPTY_PUNCT} becomes {@code " "}.
     * <p>
     * Tokens are matched literally rather than as regular expressions, because
     * characters such as {@code .} and {@code ?} are regex metacharacters.
     *
     * @param input text built from transition tokens
     * @return the post-processed text
     */
    public static String postProcessing(String input) {
        String result = input;
        for (String transition : TRANSITIONS) {
            result = result.replace(transition, surfaceFormOf(transition));
        }
        return result;
    }

    /**
     * Finds the {@link #POSTPROCESSES} entry whose symbol equals the first
     * character of {@code transition}. Falls back to the plain space for the
     * empty transition or for unknown symbols.
     */
    private static String surfaceFormOf(String transition) {
        if (!transition.equals(EMPTY_PUNCT) && !transition.isEmpty()) {
            char symbol = transition.charAt(0);
            for (String candidate : POSTPROCESSES) {
                if (!candidate.isEmpty() && candidate.charAt(0) == symbol) {
                    return candidate;
                }
            }
        }
        return POSTPROCESSES[0];
    }

    /** Small demo: punctuates two words using a trigram model read from sentences.txt. */
    public static void main(String... args) {
        String[] words = { "mars", "scientists" };
        NGramWrapper ngw = new NGramWrapper(3);
        ngw.readFile(new File("sentences.txt"));
        HyperStringFSA3 fsa = new HyperStringFSA3(words, ngw);
        for (String[] s : fsa.outputs) {
            System.out.println(Arrays.toString(s));
        }
        // outputs stays empty after the best-first search, so also show the search result.
        String best = fsa.getOptimalString();
        System.out.println(best);
        System.out.println(postProcessing(best));
    }

    /** A state in the automaton: either a word (value ends with a space) or a transition token. */
    private static class Node implements Comparable<Node> {
        String value;
        double cost;
        Node parent;
        Vector<Node> children;

        public Node(String value) {
            this(value, 0.0D);
        }

        public Node(String value, double cost) {
            children = new Vector<Node>();
            this.cost = cost;
            this.value = value;
        }

        /** Orders by descending cost, so a PriorityQueue polls the highest-cost node first. */
        @Override
        public int compareTo(Node n2) {
            if (cost < n2.cost) {
                return 1;
            }
            if (cost > n2.cost) {
                return -1;
            }
            return 0;
        }

        @Override
        public String toString() {
            return value;
        }
    }
}