package rainbownlp.analyzer.evaluation.classification;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import rainbownlp.analyzer.evaluation.ICrossfoldValidator;
import rainbownlp.machinelearning.LearnerEngine;
import rainbownlp.machinelearning.MLExample;
import rainbownlp.util.FileUtil;
import rainbownlp.util.ConfigurationUtil;
/**
 * Runs k-fold cross-validation of a {@link LearnerEngine} over a list of
 * {@link MLExample}s and reports the per-class counts summed over all folds.
 *
 * <p>The examples are split into contiguous folds in the order given; no
 * shuffling is done here, so callers that need randomized folds must shuffle
 * the list beforehand.
 */
public class CrossValidation implements ICrossfoldValidator {
    /** Learner that is trained and tested once per fold. */
    LearnerEngine mlModel;

    /**
     * @param learningEngine the learner to train and test on every fold
     */
    public CrossValidation(LearnerEngine learningEngine) {
        mlModel = learningEngine;
    }

    /**
     * Trains and tests the learner once per fold and sums the per-class
     * FN/FP/TN/TP counts of all folds.
     *
     * <p>Fold {@code i} (0-based) is tested on the examples in
     * {@code [i*size/folds, (i+1)*size/folds)} and trained on all the others.
     * These boundaries equal the plain {@code size/folds} split whenever the
     * size is divisible by {@code folds}; otherwise the remainder is spread
     * over the folds so that every example is tested exactly once. Folds whose
     * test range is empty (only possible when {@code folds > size}) are
     * skipped.
     *
     * <p>While a fold is running, {@code ConfigurationUtil.crossFoldCurrent}
     * holds its 1-based index; it is reset to 0 when this method exits,
     * whether normally or with an exception.
     *
     * @param examples the examples to split; each one is cloned before it is
     *                 handed to the learner
     * @param folds    number of folds, must be positive
     * @return a result whose {@code evaluationResultByClass} maps each class
     *         to its counts summed over the folds
     * @throws IllegalArgumentException if {@code folds} is not positive
     * @throws Exception whatever the learner or the evaluator throws
     */
    public EvaluationResult crossValidation(List<MLExample> examples, int folds) throws Exception {
        if (folds <= 0) {
            throw new IllegalArgumentException("folds must be positive, got " + folds);
        }

        try {
            int exampleCount = examples.size();
            List<EvaluationResult> foldResults = new ArrayList<EvaluationResult>();

            for (int foldIndex = 0; foldIndex < folds; foldIndex++) {
                ConfigurationUtil.crossFoldCurrent = foldIndex + 1;

                // long arithmetic so that foldIndex * exampleCount cannot overflow
                int testStart = (int) ((long) foldIndex * exampleCount / folds);
                int testEnd = (int) ((long) (foldIndex + 1) * exampleCount / folds);
                if (testStart == testEnd) {
                    // nothing to test in this fold, so training would be wasted
                    continue;
                }

                foldResults.add(runFold(examples, testStart, testEnd));
            }

            HashMap<String, ResultRow> totalsByClass = sumByClass(foldResults);
            logTotals(totalsByClass);

            EvaluationResult summary = new EvaluationResult();
            summary.evaluationResultByClass = totalsByClass;
            return summary;
        } finally {
            // always leave the global fold marker clean, even on failure
            ConfigurationUtil.crossFoldCurrent = 0;
        }
    }

    /** Trains on everything outside [testStart, testEnd), tests on that range and evaluates it. */
    private EvaluationResult runFold(List<MLExample> examples, int testStart, int testEnd)
            throws Exception {
        List<MLExample> trainSet = new ArrayList<MLExample>();
        addClones(examples, 0, testStart, trainSet);
        addClones(examples, testEnd, examples.size(), trainSet);
        mlModel.train(trainSet);

        // drop the training copies before the test copies are built
        trainSet = null;
        System.gc();

        List<MLExample> testSet = new ArrayList<MLExample>();
        addClones(examples, testStart, testEnd, testSet);
        mlModel.test(testSet);
        return Evaluator.getEvaluationResult(testSet);
    }

    /** Appends clones of examples[from, to) to target. */
    private static void addClones(List<MLExample> examples, int from, int to,
            List<MLExample> target) {
        for (int i = from; i < to; i++) {
            target.add(examples.get(i).clone());
        }
    }

    /**
     * Sums the FN/FP/TN/TP counts of every class over all fold results.
     * The row of the first fold that mentions a class is reused as the
     * accumulator for that class, so that fold's row object ends up holding
     * the totals.
     */
    private static HashMap<String, ResultRow> sumByClass(List<EvaluationResult> foldResults) {
        HashMap<String, ResultRow> totalsByClass = new HashMap<String, ResultRow>();
        for (EvaluationResult foldResult : foldResults) {
            for (String evaluatedClass : foldResult.evaluationResultByClass.keySet()) {
                ResultRow foldRow = foldResult.evaluationResultByClass.get(evaluatedClass);
                ResultRow totalRow = totalsByClass.get(evaluatedClass);
                if (totalRow == null) {
                    totalsByClass.put(evaluatedClass, foldRow);
                } else {
                    totalRow.FN += foldRow.FN;
                    totalRow.FP += foldRow.FP;
                    totalRow.TN += foldRow.TN;
                    totalRow.TP += foldRow.TP;
                }
            }
        }
        return totalsByClass;
    }

    /** Logs each class name followed by its summed row. */
    private static void logTotals(HashMap<String, ResultRow> totalsByClass) {
        for (String evaluatedClass : totalsByClass.keySet()) {
            FileUtil.logLine("Class: " + evaluatedClass);
            totalsByClass.get(evaluatedClass).print();
        }
    }

    /** @return the learner this validator trains and tests */
    @Override
    public LearnerEngine getLearnerEngine() {
        return mlModel;
    }
}