/* * This file is part of ELKI: * Environment for Developing KDD-Applications Supported by Index-Structures * * Copyright (C) 2017 * ELKI Development Team * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package de.lmu.ifi.dbs.elki.application; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm; import de.lmu.ifi.dbs.elki.algorithm.classification.Classifier; import de.lmu.ifi.dbs.elki.data.ClassLabel; import de.lmu.ifi.dbs.elki.data.type.TypeUtil; import de.lmu.ifi.dbs.elki.database.AbstractDatabase; import de.lmu.ifi.dbs.elki.database.Database; import de.lmu.ifi.dbs.elki.database.StaticArrayDatabase; import de.lmu.ifi.dbs.elki.database.relation.Relation; import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection; import de.lmu.ifi.dbs.elki.datasource.FileBasedDatabaseConnection; import de.lmu.ifi.dbs.elki.datasource.MultipleObjectsBundleDatabaseConnection; import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; import de.lmu.ifi.dbs.elki.evaluation.classification.ConfusionMatrix; import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.AbstractHoldout; import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.Holdout; import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.StratifiedCrossValidation; import 
de.lmu.ifi.dbs.elki.evaluation.classification.holdout.TrainingAndTestSet;
import de.lmu.ifi.dbs.elki.index.IndexFactory;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.statistics.Duration;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectListParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;

/**
 * Evaluate a classifier.
 *
 * The task loads the complete data set once, lets a {@link Holdout} split it
 * into training and test partitions, trains the classifier on every training
 * partition, classifies the matching test partition, and accumulates all
 * predictions into one confusion matrix that is logged as a statistic.
 *
 * TODO: split into application and task.
 *
 * TODO: add support for predefined test and training pairs!
 *
 * @author Erich Schubert
 * @since 0.7.0
 *
 * @param <O> Object type
 */
public class ClassifierHoldoutEvaluationTask<O> extends AbstractApplication {
  /**
   * Class logger.
   */
  private static final Logging LOG = Logging.getLogger(ClassifierHoldoutEvaluationTask.class);

  /**
   * Holds the database connection to get the initial data from.
   */
  protected DatabaseConnection databaseConnection = null;

  /**
   * Indexes to add.
   */
  protected Collection<IndexFactory<?, ?>> indexFactories;

  /**
   * Classifier to evaluate.
   */
  protected Classifier<O> algorithm;

  /**
   * Holds the holdout.
   */
  protected Holdout holdout;

  /**
   * Constructor.
   *
   * @param databaseConnection Data source
   * @param indexFactories Data indexes
   * @param algorithm Classification algorithm
   * @param holdout Evaluation holdout
   */
  public ClassifierHoldoutEvaluationTask(DatabaseConnection databaseConnection, Collection<IndexFactory<?, ?>> indexFactories, Classifier<O> algorithm, Holdout holdout) {
    this.databaseConnection = databaseConnection;
    this.indexFactories = indexFactories;
    this.algorithm = algorithm;
    this.holdout = holdout;
  }

  /**
   * Run the evaluation: load the data, iterate over all partitions of the
   * holdout, train and test the classifier on each of them, and finally log
   * the accumulated confusion matrix.
   *
   * @throws IllegalStateException if a predicted or true class label cannot
   *         be found in the label list reported by the holdout
   */
  @Override
  public void run() {
    Duration ptime = LOG.newDuration("evaluation.time.load").begin();
    MultipleObjectsBundle allData = databaseConnection.loadData();
    holdout.initialize(allData);
    LOG.statistics(ptime.end());

    Duration time = LOG.newDuration("evaluation.time.total").begin();
    // NOTE(review): the binary search below presumes that getLabels() returns
    // the labels in sorted order - confirm against the Holdout contract.
    ArrayList<ClassLabel> labels = holdout.getLabels();
    // First index: predicted label, second index: true label.
    int[][] confusion = new int[labels.size()][labels.size()];
    for(int p = 0; p < holdout.numberOfPartitions(); p++) {
      TrainingAndTestSet partition = holdout.nextPartitioning();
      // Common prefix of the statistic keys of this fold (folds count from 1).
      final String foldKey = this.getClass().getName() + ".fold-" + (p + 1);

      // Load the data set into a database structure (for indexing)
      Duration dur = LOG.newDuration(foldKey + ".init.time").begin();
      Database db = new StaticArrayDatabase(new MultipleObjectsBundleDatabaseConnection(partition.getTraining()), indexFactories);
      db.initialize();
      LOG.statistics(dur.end());

      // Train the classifier
      dur = LOG.newDuration(foldKey + ".train.time").begin();
      Relation<ClassLabel> lrel = db.getRelation(TypeUtil.CLASSLABEL);
      algorithm.buildClassifier(db, lrel);
      LOG.statistics(dur.end());

      // Evaluate the test set
      dur = LOG.newDuration(foldKey + ".evaluation.time").begin();
      // FIXME: this part is still a big hack, unfortunately!
      MultipleObjectsBundle test = partition.getTest();
      int lcol = AbstractHoldout.findClassLabelColumn(test);
      // The object column is taken to be column 0, unless the label occupies
      // column 0, in which case column 1 is used. Other columns are ignored.
      int tcol = (lcol == 0) ? 1 : 0;
      for(int i = 0, l = test.dataLength(); i < l; ++i) {
        @SuppressWarnings("unchecked")
        O obj = (O) test.data(i, tcol);
        ClassLabel truelbl = (ClassLabel) test.data(i, lcol);
        ClassLabel predlbl = algorithm.classify(obj);
        int pred = labelIndex(labels, predlbl, "predicted");
        int real = labelIndex(labels, truelbl, "true");
        confusion[pred][real]++;
      }
      LOG.statistics(dur.end());
    }
    LOG.statistics(time.end());
    ConfusionMatrix m = new ConfusionMatrix(labels, confusion);
    LOG.statistics(m.toString());
  }

  /**
   * Find the position of a class label in the label list, failing with a
   * descriptive error instead of handing a negative "not found" result of the
   * binary search on to an array access.
   *
   * @param labels Known class labels
   * @param label Label to look up
   * @param role Role of the label, only used in the error message
   * @return Index of the label in {@code labels}
   * @throws IllegalStateException if the label is {@code null} or not
   *         contained in {@code labels}
   */
  private static int labelIndex(ArrayList<ClassLabel> labels, ClassLabel label, String role) {
    if(label == null) {
      throw new IllegalStateException("Missing " + role + " class label (null).");
    }
    final int idx = Collections.binarySearch(labels, label);
    if(idx < 0) {
      throw new IllegalStateException("The " + role + " class label '" + label + "' is not among the class labels known to the holdout.");
    }
    return idx;
  }

  /**
   * Parameterization class.
   *
   * @author Erich Schubert
   *
   * @apiviz.exclude
   */
  public static class Parameterizer<O> extends AbstractApplication.Parameterizer {
    /**
     * Parameter to specify the holdout for evaluation, must extend
     * {@link de.lmu.ifi.dbs.elki.evaluation.classification.holdout.Holdout}.
     * <p>
     * Key: {@code -evaluation.holdout}
     * </p>
     * <p>
     * Default value: {@link StratifiedCrossValidation}
     * </p>
     */
    public static final OptionID HOLDOUT_ID = new OptionID("evaluation.holdout", "Holdout class used in evaluation.");

    /**
     * Holds the database connection to get the initial data from.
     */
    protected DatabaseConnection databaseConnection = null;

    /**
     * Indexes to add.
     */
    protected Collection<IndexFactory<?, ?>> indexFactories;

    /**
     * Classifier to evaluate.
     */
    protected Classifier<O> algorithm;

    /**
     * Holds the holdout.
     */
    protected Holdout holdout;

    @Override
    protected void makeOptions(Parameterization config) {
      super.makeOptions(config);
      // Get database connection.
      final ObjectParameter<DatabaseConnection> dbcP = new ObjectParameter<>(AbstractDatabase.Parameterizer.DATABASE_CONNECTION_ID, DatabaseConnection.class, FileBasedDatabaseConnection.class);
      if(config.grab(dbcP)) {
        databaseConnection = dbcP.instantiateClass(config);
      }
      // Get indexes.
      final ObjectListParameter<IndexFactory<?, ?>> indexFactoryP = new ObjectListParameter<>(AbstractDatabase.Parameterizer.INDEX_ID, IndexFactory.class, true);
      if(config.grab(indexFactoryP)) {
        indexFactories = indexFactoryP.instantiateClasses(config);
      }
      // Get the classifier to evaluate.
      ObjectParameter<Classifier<O>> algorithmP = new ObjectParameter<>(AbstractAlgorithm.ALGORITHM_ID, Classifier.class);
      if(config.grab(algorithmP)) {
        algorithm = algorithmP.instantiateClass(config);
      }
      // Get the holdout strategy.
      ObjectParameter<Holdout> holdoutP = new ObjectParameter<>(HOLDOUT_ID, Holdout.class, StratifiedCrossValidation.class);
      if(config.grab(holdoutP)) {
        holdout = holdoutP.instantiateClass(config);
      }
    }

    @Override
    protected ClassifierHoldoutEvaluationTask<O> makeInstance() {
      return new ClassifierHoldoutEvaluationTask<>(databaseConnection, indexFactories, algorithm, holdout);
    }
  }

  /**
   * Runs the classifier evaluation task accordingly to the specified
   * parameters.
   *
   * @param args parameter list according to description
   */
  public static void main(String[] args) {
    runCLIApplication(ClassifierHoldoutEvaluationTask.class, args);
  }
}