package quickml.experiments; import org.javatuples.Pair; import quickml.data.instances.ClassifierInstance; import quickml.data.instances.RegressionInstance; import quickml.supervised.crossValidation.RegressionLossChecker; import quickml.supervised.crossValidation.SimpleCrossValidator; import quickml.supervised.crossValidation.data.FoldedData; import quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions.RegressionRMSELossFunction; import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest; import quickml.supervised.ensembles.randomForest.randomRegressionForest.RandomRegressionForest; import quickml.supervised.ensembles.randomForest.randomRegressionForest.RandomRegressionForestBuilder; import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability; import quickml.supervised.tree.regressionTree.OptimizedRegressionForests; import quickml.supervised.tree.regressionTree.RegressionTree; import quickml.supervised.tree.regressionTree.RegressionTreeBuilder; import quickml.utlities.CSVToInstanceReader; import quickml.utlities.CSVToInstanceReaderBuilder; import quickml.utlities.selectors.NumericSelector; import java.io.Serializable; import java.util.List; import java.util.Map; /** * Created by alexanderhawk on 9/16/15. */ public class kin88nm { public static void main(String[] args) { CSVToInstanceReaderBuilder csvToInstanceReaderBuilder = new CSVToInstanceReaderBuilder().numericSelector(new NumericSelector() { @Override public boolean isNumeric(String columnName) { return true; } @Override public String cleanValue(String value) { return value; } }).delimiter(',').collumnNameForLabel("x8").hasHeader(false); CSVToInstanceReader csvToInstanceReader =csvToInstanceReaderBuilder.buildCsvReader(); try { List<RegressionInstance> allTrainingData = csvToInstanceReader.readRegressionInstancesFromCsv("uci-20070111-kin8nm.csv"); List<RegressionInstance> trainData = csvToInstanceReader.readRegressionInstancesFromCsv("/Users/alexanderhawk/msda-denoising/spearmint/data/kin8nm_train.csv"); List<RegressionInstance> valData = csvToInstanceReader.readRegressionInstancesFromCsv("/Users/alexanderhawk/msda-denoising/spearmint/data/kin8nm_test.csv"); RegressionTreeBuilder<RegressionInstance> regressionTreeBuilder = new RegressionTreeBuilder<>() .degreeOfGainRatioPenalty(1.0) .attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.5)) .maxDepth(18) .minLeafInstances(2) .minSplitFraction(0.1) .numNumericBins(10) .numSamplesPerNumericBin(20); RandomRegressionForestBuilder<RegressionInstance> regressionForestBuilder = new RandomRegressionForestBuilder<>(regressionTreeBuilder).numTrees(400); //RegressionTree regressionTree = regressionTreeBuilder.buildPredictiveModel(trainData); RandomRegressionForest randomRegressionForest = regressionForestBuilder.buildPredictiveModel(trainData); //Pair<Map<String, Serializable>, RandomRegressionForest> randomForestPair = OptimizedRegressionForests.<RegressionInstance>getOptimizedRandomForest(trainData); //RandomRegressionForest randomRegressionForest = randomForestPair.getValue1(); double loss =0; for (RegressionInstance instance: valData) { loss+=(instance.getLabel() - randomRegressionForest.predict(instance.getAttributes())) *(instance.getLabel() - randomRegressionForest.predict(instance.getAttributes())); } loss=Math.sqrt(loss/valData.size()); System.out.println("loss " + loss); SimpleCrossValidator simpleCrossValidator = new SimpleCrossValidator(regressionTreeBuilder, new RegressionLossChecker(new RegressionRMSELossFunction()), new FoldedData(allTrainingData, 8, 8)); // double loss=simpleCrossValidator.getLossForModel(); System.out.println("here"); } catch (Exception e) { e.printStackTrace(); throw new RuntimeException(e); } } }