/**
* Copyright (c) 2011, Regents of the University of Colorado All rights
* reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer. Redistributions in binary
* form must reproduce the above copyright notice, this list of conditions and
* the following disclaimer in the documentation and/or other materials provided
* with the distribution. Neither the name of the University of Colorado at
* Boulder nor the names of its contributors may be used to endorse or promote
* products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package clear.dep.srl;
import clear.dep.DepNode;
import clear.parse.SRLParser;
import clear.util.cluster.Prob2dMap;
import clear.util.tuple.JObjectObjectTuple;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashSet;
public class SRLProb {
static public final String SYM_PREV = "<";
static public final String SYM_NEXT = ">";
static public final String SYM_ACTIVE = "a";
static public final String SYM_PASSIVE = "p";
static public final String ARG_NONE = "NONE";
static public final String ARG_END = "END";
private Prob2dMap m_prob1a;
private Prob2dMap m_prob2a;
private Prob2dMap m_prob2n;
public SRLProb() {
m_prob1a = new Prob2dMap();
m_prob2a = new Prob2dMap();
m_prob2n = new Prob2dMap();
}
// ============================= Retrieve Key =============================
public String getKey(DepNode pred, byte dir) {
String postfix, feat;
if (dir == SRLParser.DIR_LEFT) {
postfix = SYM_PREV;
} else {
postfix = SYM_NEXT;
}
if ((feat = pred.getFeat("vo")) != null && feat.equals("1")) {
postfix += SYM_PASSIVE;
} else {
postfix += SYM_ACTIVE;
}
return pred.lemma + postfix;
}
public String getKey(DepNode pred, String prevArg, byte dir) {
return getKey(pred, dir) + "|" + prevArg;
}
public boolean isPrevArg(String label) {
return (label.startsWith(SYM_PREV));
}
public boolean isNextArg(String label) {
return (label.startsWith(SYM_NEXT));
}
// ============================= Count 1st-degree =============================
/**
* For training.
*/
public void add1dArgs(DepNode pred, HashSet<String> sArgs) {
HashSet<String> pSet = new HashSet<>();
HashSet<String> nSet = new HashSet<>();
for (String label : sArgs) {
if (isPrevArg(label)) {
pSet.add(label);
} else {
nSet.add(label);
}
}
String pKey = getKey(pred, SRLParser.DIR_LEFT);
String nKey = getKey(pred, SRLParser.DIR_RIGHT);
if (pSet.isEmpty()) {
m_prob1a.increment(pKey, ARG_END);
} else {
m_prob1a.increment(pKey, pSet);
}
if (nSet.isEmpty()) {
m_prob1a.increment(nKey, ARG_END);
} else {
m_prob1a.increment(nKey, nSet);
}
}
// ============================= Count 2nd-degree =============================
/**
* For training.
*/
public void add2dArgs(DepNode pred, ArrayList<SRLArg> lsArgs) {
ArrayList<String> pList = new ArrayList<>();
ArrayList<String> nList = new ArrayList<>();
for (SRLArg arg : lsArgs) {
if (isPrevArg(arg.label)) {
pList.add(arg.label);
} else {
nList.add(arg.label);
}
}
JObjectObjectTuple<String, String> prevArgs = new JObjectObjectTuple<>(ARG_NONE, ARG_NONE);
add2dArgsAux(pList, pred, prevArgs, SRLParser.DIR_LEFT);
add2dArgsAux(nList, pred, prevArgs, SRLParser.DIR_RIGHT);
}
/**
* Called from {@link SRLProb#add1dArgs(DepNode, HashSet)}.
*/
private void add2dArgsAux(ArrayList<String> list, DepNode pred, JObjectObjectTuple<String, String> prevArgs, byte dir) {
for (String currArg : list) {
m_prob2a.increment(getKey(pred, prevArgs.o1, dir), currArg);
m_prob2n.increment(getKey(pred, prevArgs.o2, dir), currArg);
prevArgs.o1 = currArg;
if (currArg.substring(1).matches("A\\d")) {
prevArgs.o2 = currArg;
}
}
m_prob2a.increment(getKey(pred, prevArgs.o1, dir), ARG_END);
m_prob2n.increment(getKey(pred, prevArgs.o2, dir), ARG_END);
}
// ============================= Print =============================
public void printAll(String filename) {
DecimalFormat format = new DecimalFormat("#0.0000");
m_prob1a.print(filename + ".p1a", format);
m_prob2a.print(filename + ".p2a", format);
m_prob2n.print(filename + ".p2n", format);
}
}