package rainbownlp.analyzer.evaluation.regression;
import java.util.ArrayList;
import java.util.List;
import rainbownlp.analyzer.evaluation.ICrossfoldValidator;
import rainbownlp.machinelearning.LearnerEngine;
import rainbownlp.machinelearning.MLExample;
import rainbownlp.util.ConfigurationUtil;
public class RegressionCrossValidation implements ICrossfoldValidator {
LearnerEngine mlModel;
public RegressionCrossValidation(LearnerEngine learningEngine)
{
mlModel = learningEngine;
}
public RegressionEvaluationResult crossValidation(List<MLExample> examples, int folds) throws Exception
{
int foldCount = examples.size()/folds;
ArrayList<RegressionEvaluationResult> results =
new ArrayList<RegressionEvaluationResult>();
for(int foldIndex = 0;foldIndex<folds;foldIndex++)
{
ConfigurationUtil.crossFoldCurrent = foldIndex+1;
int start_index = foldIndex*foldCount;
int end_index = (foldIndex+1)*foldCount;
if(end_index>=examples.size())
end_index = examples.size();
// HibernateUtil.startTransaction();
List<MLExample> train_set = new ArrayList<MLExample>();
for(int i=0;i<start_index;i++)
train_set.add(examples.get(i).clone());
for(int i=end_index;i<examples.size();i++)
train_set.add(examples.get(i).clone());
mlModel.train(train_set);
train_set = null;
System.gc();
List<MLExample> test_set = new ArrayList<MLExample>();
for(int i=start_index;i<end_index;i++)
test_set.add(examples.get(i).clone());
mlModel.test(test_set);
// HibernateUtil.endTransaction();
results.add(RegressionEvaluator.getEvaluationResult(test_set));
}
RegressionEvaluationResult averageRes = new RegressionEvaluationResult();
for(RegressionEvaluationResult fold_result : results)
{
averageRes.add(fold_result);
}
ConfigurationUtil.crossFoldCurrent = 0;
return averageRes;
}
@Override
public LearnerEngine getLearnerEngine() {
return mlModel;
}
}