/*******************************************************************************
* Copyright 2012 University of Southern California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* This code was developed by the Information Integration Group as part
* of the Karma project at the Information Sciences Institute of the
* University of Southern California. For more information, publications,
* and related projects, please see: http://www.isi.edu/integration
******************************************************************************/
package edu.isi.karma.modeling.semantictypes.mycrf.globaldata ;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import edu.isi.karma.modeling.semantictypes.mycrf.common.Constants;
import edu.isi.karma.modeling.semantictypes.mycrf.crfmodel.CRFModelFieldOnly;
import edu.isi.karma.modeling.semantictypes.mycrf.graph.GraphFieldOnly;
import edu.isi.karma.modeling.semantictypes.mycrf.graph.GraphInterface;
import edu.isi.karma.modeling.semantictypes.mycrf.math.Matrix;
import edu.isi.karma.modeling.semantictypes.myutils.Prnt;
/**
* This class represents global information used while creating graphs or training the CRF model
*
* @author amangoel
*
*/
public class GlobalDataFieldOnly extends GlobalDataAbstract {
public ArrayList<String> labels ;
public CRFModelFieldOnly crfModel ;
public GlobalDataFieldOnly() {
labels = new ArrayList<String>() ;
}
public void collectAllLabels(ArrayList<String> files) {
// This method collects all labels and puts them into an ArrayList
BufferedReader br = null ;
String line = null ;
for(String file : files) {
try {
br = new BufferedReader(new FileReader(file)) ;
line = br.readLine() ;
br.close() ;
}
catch(Exception e) {
e.printStackTrace() ;
Prnt.endIt("Error, quiting.") ;
}
String[] tokens = line.split("\\s+");
String label = tokens[tokens.length-1] ;
if (!labels.contains(label)) {
labels.add(label) ;
}
}
}
public void errorGradient(double[] gradient) {
double invSD = 1.0 / (Constants.STANDARD_DEVIATION * Constants.STANDARD_DEVIATION) ;
for(int i=0;i<gradient.length;i++) {
gradient[i] = 0.0 ;
}
double[] tmpGradient = new double[gradient.length] ;
for(GraphInterface graphI : trainingGraphs) {
GraphFieldOnly graph = (GraphFieldOnly) graphI ;
graph.logLikelihoodGradient(tmpGradient) ;
Matrix.plusEquals(gradient, tmpGradient, 1) ;
}
for(int i=0;i<gradient.length;i++) {
gradient[i] = -gradient[i] ;
}
Matrix.plusEquals(gradient, crfModel.weights, invSD) ;
}
public double errorValue() {
double error = 0 ;
for(GraphInterface graphI : trainingGraphs) {
GraphFieldOnly graph = (GraphFieldOnly) graphI ;
error+=graph.logLikelihood() ;
}
error = - error + Matrix.dotProduct(crfModel.weights, crfModel.weights) / (2 * Constants.STANDARD_DEVIATION * Constants.STANDARD_DEVIATION) ;
return error ;
}
}