/**
 * Copyright 2013-2015 Pierre Merienne
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.github.pmerienne.trident.ml.clustering;

import java.util.List;

import com.github.pmerienne.trident.ml.clustering.Clusterer;
import com.github.pmerienne.trident.ml.core.Instance;
import com.github.pmerienne.trident.ml.testing.data.DatasetUtils;

/**
 * Base class for clusterer tests. Provides helpers that train a
 * {@link Clusterer} on labelled samples and score it with a
 * {@code RandEvaluator}.
 */
public abstract class ClustererTest {

    /** Number of folds used by {@link #eval(Clusterer, List)}. */
    private static final int FOLD_NB = 10;

    /**
     * Evaluates a clusterer with 10-fold cross validation.
     *
     * @param clusterer
     *            the clusterer to evaluate; it is reset before each fold
     * @param samples
     *            labelled samples split into training and evaluation folds
     * @return the score computed by {@code RandEvaluator}, averaged over the
     *         folds
     */
    protected double eval(Clusterer clusterer, List<Instance<Integer>> samples) {
        return this.eval(clusterer, samples, FOLD_NB);
    }

    /**
     * Evaluates a clusterer with k-fold cross validation.
     *
     * <p>
     * For each fold {@code i}, the training and evaluation subsets are obtained
     * from {@code DatasetUtils.getTrainingFolds} and
     * {@code DatasetUtils.getEvalFold}. They are then passed to
     * {@link #eval(Clusterer, List, List)}.
     *
     * @param clusterer
     *            the clusterer to evaluate; it is reset before each fold
     * @param samples
     *            labelled samples split into training and evaluation folds
     * @param foldNb
     *            number of folds, must be at least 1
     * @return the score computed by {@code RandEvaluator}, averaged over the
     *         folds
     * @throws IllegalArgumentException
     *             if {@code foldNb < 1}
     */
    protected double eval(Clusterer clusterer, List<Instance<Integer>> samples, int foldNb) {
        // Without this guard, a zero fold count would silently return 0.0 / 0 = NaN.
        if (foldNb < 1) {
            throw new IllegalArgumentException("foldNb must be >= 1, got " + foldNb);
        }

        double randIndexSum = 0.0;
        for (int i = 0; i < foldNb; i++) {
            List<Instance<Integer>> training = DatasetUtils.getTrainingFolds(i, foldNb, samples);
            List<Instance<Integer>> eval = DatasetUtils.getEvalFold(i, foldNb, samples);
            randIndexSum += this.eval(clusterer, training, eval);
        }
        return randIndexSum / foldNb;
    }

    /**
     * Trains a clusterer from scratch and scores it on a held-out set.
     *
     * <p>
     * The clusterer is reset first. It is then updated with the features of
     * every training sample, in list order. Finally it is scored on
     * {@code eval} with a new {@code RandEvaluator}.
     *
     * @param clusterer
     *            the clusterer to train and evaluate (its state is reset)
     * @param training
     *            samples whose features are fed to {@code clusterer.update}
     * @param eval
     *            samples passed to {@code RandEvaluator.evaluate}
     * @return the value returned by {@code RandEvaluator.evaluate}
     */
    protected double eval(Clusterer clusterer, List<Instance<Integer>> training, List<Instance<Integer>> eval) {
        clusterer.reset();

        // Train
        for (Instance<Integer> sample : training) {
            clusterer.update(sample.features);
        }

        RandEvaluator randEvaluator = new RandEvaluator();
        return randEvaluator.evaluate(clusterer, eval);
    }
}