/*******************************************************************************
* Copyright 2012 University of Southern California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* This code was developed by the Information Integration Group as part
* of the Karma project at the Information Sciences Institute of the
* University of Southern California. For more information, publications,
* and related projects, please see: http://www.isi.edu/integration
******************************************************************************/
package edu.isi.karma.modeling.semantictypes.mycrf.graph ;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import edu.isi.karma.modeling.semantictypes.mycrf.common.Constants;
import edu.isi.karma.modeling.semantictypes.mycrf.common.Node;
import edu.isi.karma.modeling.semantictypes.mycrf.fieldonly.LblFtrPair;
import edu.isi.karma.modeling.semantictypes.mycrf.globaldata.GlobalDataFieldOnly;
import edu.isi.karma.modeling.semantictypes.mycrf.math.LargeNumber;
import edu.isi.karma.modeling.semantictypes.myutils.Prnt;
/**
 * A single-node graph for the field-only CRF model.
 * <p>
 * The graph holds one {@link Node} (a field string plus its features and,
 * optionally, a label index). It also caches the values computed from the
 * model in {@code globalData}:
 * <ul>
 * <li>the partition function {@code Z},</li>
 * <li>the potential of the node's assigned label ({@code graphPotential}),</li>
 * <li>per-label marginals ({@code nodeMarginals}).</li>
 * </ul>
 * These caches are only valid after the corresponding {@code compute*} method
 * has been called.
 *
 * @author amangoel
 *
 */
public class GraphFieldOnly implements GraphInterface {

	/** Shared state: the label list ({@code labels}) and the CRF model ({@code crfModel}). */
	GlobalDataFieldOnly globalData ;
	/** The single field node of this graph. */
	public Node node ;
	/** Partition function; set by {@link #compute_Z()}. */
	LargeNumber Z ;
	/** Potential of the node's assigned label; set by {@link #compute_graphPotential()}. */
	LargeNumber graphPotential ;
	/** Per-label marginals, indexed like {@code globalData.labels}; filled by {@link #computeNodeMarginals()}. */
	double[] nodeMarginals ;

	/**
	 * Builds the graph from the first line of a text file.
	 * <p>
	 * The line is split on whitespace. The first token is the field string.
	 * The remaining tokens are features, except that when {@code labeled} is
	 * true the last token is treated as the label name instead.
	 *
	 * @param file path of the file; only its first line is read
	 * @param labeled whether the last token of the line is a label
	 * @param globalData shared label list and model
	 */
	public GraphFieldOnly(String file, boolean labeled, GlobalDataFieldOnly globalData) {
		this.globalData = globalData ;
		String line = readFirstLine(file) ;
		if (line == null) {
			// readLine() returns null for an empty file; report it instead of failing with an NPE below.
			Prnt.endIt("GraphFieldOnly: file " + file + " is empty.") ;
		}
		// Trim first: leading whitespace would otherwise yield an empty first token,
		// which would become the field string and shift the real string into the features.
		String[] tokens = line.trim().split("\\s+") ;
		node = new Node(Constants.FIELD_TYPE, 0, 0) ;
		node.string = tokens[0] ;
		int featureEnd = tokens.length - (labeled ? 1 : 0) ;
		for(int i=1;i<featureEnd;i++) {
			node.features.add(tokens[i]) ;
		}
		if(labeled){
			// NOTE(review): an unknown label silently yields labelIndex -1 here, whereas the
			// other constructor rejects it via Prnt.endIt -- confirm this difference is intended.
			node.labelIndex = globalData.labels.indexOf(tokens[tokens.length-1]) ;
		}
		else {
			node.labelIndex = -1 ;
		}
		nodeMarginals = new double[globalData.labels.size()] ;
	}

	/**
	 * Builds the graph from already-parsed parts.
	 *
	 * @param fieldString the field's string value
	 * @param label label name (trimmed before lookup), or {@code null} for an unlabeled graph
	 * @param features the node's features; copied, and must be non-empty
	 * @param globalData shared label list and model
	 */
	public GraphFieldOnly(String fieldString, String label, ArrayList<String> features, GlobalDataFieldOnly globalData) {
		if (features == null || features.size() == 0) {
			Prnt.endIt("GraphFieldOnly constructor called with empty feature list.") ;
		}
		this.globalData = globalData ;
		this.node = new Node(Constants.FIELD_TYPE, 0, 0) ;
		node.string = fieldString ;
		if (label == null) {
			node.labelIndex = -1 ;
		}
		else {
			label = label.trim() ;
			node.labelIndex = globalData.labels.indexOf(label) ;
			if (node.labelIndex == -1) {
				Prnt.endIt("Label not found in globalData.labels") ;
			}
		}
		// Defensive copy so later changes to the caller's list do not affect this node.
		node.features = new ArrayList<String>(features) ;
		nodeMarginals = new double[globalData.labels.size()] ;
	}

	/**
	 * Reads the first line of {@code file}.
	 * <p>
	 * The reader is always closed. Read failures are reported via
	 * {@code Prnt.endIt("Quiting")}.
	 *
	 * @return the first line, or {@code null} if the file is empty
	 */
	private static String readFirstLine(String file) {
		BufferedReader br = null ;
		String line = null ;
		try {
			br = new BufferedReader(new FileReader(file)) ;
			line = br.readLine() ;
		}
		catch(Exception e) {
			e.printStackTrace() ;
			Prnt.endIt("Quiting") ;
		}
		finally {
			if (br != null) {
				try {
					br.close() ;
				}
				catch (IOException ignored) {
					// Read-only reader: the line (if any) is already read, so a close failure is harmless.
				}
			}
		}
		return line ;
	}

	/** Returns the node's features as a set, so membership tests are O(1) instead of a list scan. */
	private Set<String> featureSet() {
		return new HashSet<String>(node.features) ;
	}

	/**
	 * Computes {@link #potentialExpForLabelIndex(int)} for every label in a single pass
	 * over the feature functions.
	 * <p>
	 * Weights are accumulated in ascending feature-function order, as in
	 * {@code potentialExpForLabelIndex}, so each entry is identical to what that method returns.
	 */
	private double[] potentialExpsForAllLabels() {
		int numLabels = globalData.labels.size() ;
		double[] exps = new double[numLabels] ;
		Set<String> features = featureSet() ;
		for(int f=0;f<globalData.crfModel.ffs.size();f++) {
			LblFtrPair ff = globalData.crfModel.ffs.get(f) ;
			// A feature function whose label is outside the label list matches no label,
			// exactly as in potentialExpForLabelIndex.
			if (ff.labelIndex >= 0 && ff.labelIndex < numLabels && features.contains(ff.feature)) {
				exps[ff.labelIndex] += globalData.crfModel.weights[f] ;
			}
		}
		return exps ;
	}

	/**
	 * Computes the partition function and stores it in {@code Z}.
	 * <p>
	 * Z is the sum, over every label, of
	 * {@code LargeNumber.makeLargeNumberUsingExponent(score)}, where score is the
	 * label's potential exponent (presumably representing exp(score)).
	 */
	public void compute_Z() {
		double[] exps = potentialExpsForAllLabels() ;
		LargeNumber tmp = new LargeNumber(0.0, 0) ;
		for(int i=0;i<exps.length;i++) {
			tmp.plusEquals(LargeNumber.makeLargeNumberUsingExponent(exps[i])) ;
		}
		Z = tmp ;
	}

	/**
	 * Computes the potential of the node's assigned label and stores it in {@code graphPotential}.
	 * <p>
	 * For an unlabeled graph ({@code labelIndex == -1}) no feature function matches,
	 * so the exponent is 0.
	 */
	public void compute_graphPotential() {
		this.graphPotential = LargeNumber.makeLargeNumberUsingExponent(potentialExpForLabelIndex(node.labelIndex)) ;
	}

	/**
	 * Fills {@code nodeMarginals[i]} with {@code potential(label i) / Z}.
	 * <p>
	 * The array is reallocated if the label count has changed. Requires
	 * {@code Z} to have been computed by {@link #compute_Z()}.
	 */
	public void computeNodeMarginals() {
		double[] exps = potentialExpsForAllLabels() ;
		if (nodeMarginals.length != exps.length) {
			nodeMarginals = new double[exps.length] ;
		}
		for(int i=0;i<exps.length;i++) {
			nodeMarginals[i] = LargeNumber.divide(LargeNumber.makeLargeNumberUsingExponent(exps[i]), this.Z) ;
		}
	}

	/**
	 * Returns the potential exponent for one label.
	 * <p>
	 * This is the sum of the weights of all feature functions whose label equals
	 * {@code labelIndex} and whose feature is present on the node.
	 *
	 * @param labelIndex index into {@code globalData.labels}
	 * @return the summed weights (0.0 if no feature function matches)
	 */
	public double potentialExpForLabelIndex(int labelIndex) {
		Set<String> features = featureSet() ;
		double exp = 0.0 ;
		for(int f=0;f<globalData.crfModel.ffs.size();f++) {
			LblFtrPair ff = globalData.crfModel.ffs.get(f) ;
			if (ff.labelIndex == labelIndex && features.contains(ff.feature)) {
				exp+=globalData.crfModel.weights[f] ;
			}
		}
		return exp ;
	}

	/**
	 * Returns {@code log(graphPotential / Z)}, the log-likelihood of the assigned label.
	 * <p>
	 * Requires {@code graphPotential} and {@code Z} to have been computed.
	 * NOTE(review): the ratio is formed as a double before taking the log, so a very
	 * small likelihood may underflow to 0 and give -Infinity.
	 */
	public double logLikelihood() {
		double likelihood = LargeNumber.divide(graphPotential, Z) ;
		return Math.log(likelihood);
	}

	/**
	 * Writes the log-likelihood gradient into {@code gradient}, one entry per feature function.
	 * <p>
	 * Each entry is the observed count minus the expected count: 0 if the feature is
	 * absent from the node; otherwise (1 if the feature function's label is the node's
	 * label, else 0) minus that label's marginal.
	 *
	 * @param gradient output array; must have at least as many entries as there are feature functions
	 */
	public void logLikelihoodGradient(double[] gradient) {
		Set<String> features = featureSet() ;
		for(int f=0;f<globalData.crfModel.ffs.size();f++) {
			LblFtrPair ff = globalData.crfModel.ffs.get(f) ;
			if (!features.contains(ff.feature)) {
				gradient[f] = 0.0 ;
				continue ;
			}
			double observed = (this.node.labelIndex == ff.labelIndex) ? 1.0 : 0.0 ;
			gradient[f] = observed - this.nodeMarginals[ff.labelIndex] ;
		}
	}

	/** Computes {@code graphPotential} and then {@code Z}. */
	public void computeGraphPotentialAndZ() {
		compute_graphPotential() ;
		compute_Z() ;
	}

	/** Computes {@code graphPotential}, {@code Z}, and then {@code nodeMarginals} (which depend on {@code Z}). */
	public void computeGraphPotentialAndZAndMarginals() {
		computeGraphPotentialAndZ() ;
		computeNodeMarginals() ;
	}
}