package joshua.discriminative.training.risk_annealer;

import java.io.BufferedReader;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

import joshua.discriminative.FileUtilityOld;

/**
 * Base class for minimum-risk MERT trainers.
 *
 * <p>Holds the training configuration, and provides helpers that read feature weights and
 * locate the oracle feature in a configuration file.
 *
 * <p>The configuration format understood by the helpers is line based. Each line is trimmed.
 * Blank lines and lines starting with {@code '#'} are skipped. Lines containing {@code '='}
 * are parameter lines and are also skipped. Every remaining line is a model line: a
 * whitespace-separated list of tokens whose first token is the model name and whose last
 * token is the model's weight.
 */
public abstract class AbstractMinRiskMERT {

	/** Configuration file given to the constructor. */
	protected String configFile;

	/** Not referenced in this class; presumably maintained by subclasses. */
	protected double[] lastWeightVector;

	// ===== annealer specific =====

	/** Not referenced in this class; presumably created and driven by subclasses. */
	protected DeterministicAnnealer annealer;

	/** Reference files given to the constructor. */
	protected String[] referenceFiles;

	/** Not referenced in this class; presumably set by subclasses. */
	protected int numPara;

	/** Number of training sentences given to the constructor. */
	protected int numTrainingSentence;

	private static final Logger logger =
			Logger.getLogger(AbstractMinRiskMERT.class.getSimpleName());

	/**
	 * @param configFile          configuration file (see the class comment for its format)
	 * @param numTrainingSentence number of training sentences
	 * @param referenceFiles      reference files; stored as-is, not copied
	 */
	public AbstractMinRiskMERT(String configFile, int numTrainingSentence, String[] referenceFiles) {
		this.configFile = configFile;
		this.referenceFiles = referenceFiles;
		this.numTrainingSentence = numTrainingSentence;
	}

	/** Runs training. TODO: this function should have an option for not annealing. */
	public abstract void mainLoop();

	/**
	 * Presumably decodes the test set with {@code weights} and writes n-best output to
	 * {@code nbestFile}. Confirm against the implementations.
	 */
	public abstract void decodingTestSet(double[] weights, String nbestFile);

	/**
	 * Reads the weight of every model line in {@code configFile}, in file order.
	 *
	 * @param configFile configuration file (see the class comment for its format)
	 * @return one weight per model line, taken from the line's last token
	 * @throws NumberFormatException if a model line's last token is not a number
	 */
	protected List<Double> readBaselineFeatureWeights(String configFile) {
		List<String[]> models = readModelLines(configFile);
		List<Double> weights = new ArrayList<Double>(models.size());
		for (String[] fds : models) {
			// Double.valueOf ignores surrounding whitespace itself, so no trim is needed.
			weights.add(Double.valueOf(fds[fds.length - 1]));
		}
		return weights;
	}

	/**
	 * Returns the index of the model line whose name (first token) is {@code "oracle"}.
	 * Indices count model lines only, matching the order of
	 * {@link #readBaselineFeatureWeights(String)}.
	 *
	 * <p>If more than one oracle model is present, this logs a severe message and calls
	 * {@code System.exit(1)}.
	 *
	 * @param configFile configuration file (see the class comment for its format)
	 * @return the oracle model's index, or {@code null} if there is no oracle model
	 */
	protected Integer inferOracleFeatureID(String configFile) {
		List<String[]> models = readModelLines(configFile);
		Integer oracleFeatureID = null;
		for (int id = 0; id < models.size(); id++) {
			if ("oracle".equals(models.get(id)[0])) {
				if (oracleFeatureID != null) {
					logger.severe("more than one oracle model, must be wrong");
					System.exit(1);
				}
				oracleFeatureID = id;
			}
		}
		return oracleFeatureID;
	}

	/**
	 * Divides every entry of {@code weightVector}, in place, by the absolute value of
	 * {@code weightVector[featID]}. After the call, that entry is +1 or -1.
	 *
	 * <p>A negative reference weight is allowed but logged. If the reference weight is zero
	 * or NaN, normalization is meaningless: the division would turn every entry into NaN or
	 * +/-Infinity. In that case the vector is left unchanged and a warning is logged.
	 *
	 * @param weightVector weights to normalize (modified in place)
	 * @param featID       index of the entry to normalize by
	 */
	protected void normalizeWeightsByFirstFeature(double[] weightVector, int featID) {
		double weight = weightVector[featID];
		if (weight == 0 || Double.isNaN(weight)) {
			logger.warning("cannot normalize by weight " + weight + " at index " + featID
					+ "; weight vector left unchanged");
			return;
		}
		if (weight < 0) {
			logger.warning("first weight is negative");
		}
		// Read once before the loop, because the loop overwrites weightVector[featID].
		double scale = Math.abs(weight);
		for (int i = 0; i < weightVector.length; i++) {
			weightVector[i] /= scale;
		}
	}

	/**
	 * Reads {@code configFile} and returns the whitespace-split tokens of every model line,
	 * in file order. Blank lines, {@code '#'} comment lines and {@code '='} parameter lines
	 * are skipped. The reader is always closed, even if reading fails.
	 */
	private List<String[]> readModelLines(String configFile) {
		List<String[]> models = new ArrayList<String[]>();
		BufferedReader configReader = FileUtilityOld.getReadFileStream(configFile);
		try {
			String line;
			while ((line = FileUtilityOld.readLineLzf(configReader)) != null) {
				line = line.trim();
				if (line.length() == 0 || line.startsWith("#") || line.indexOf('=') != -1) {
					continue; // blank, comment, or parameter line
				}
				models.add(line.split("\\s+"));
			}
		} finally {
			FileUtilityOld.closeReadFile(configReader);
		}
		return models;
	}
}