/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MultiLabelOutput.java
 *    Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier;

import java.util.Arrays;
import mulan.core.ArgumentNullException;

/**
 * Class representing the output of a {@link MultiLabelLearner}.
 * This can be a bipartition of labels into <code>true</code> and <code>false</code>,
 * a ranking of labels, or an array of confidence values for each label.
 * Which of the three parts are present depends on the constructor used; absent
 * parts are <code>null</code> and can be detected with the <code>has*</code> methods.
 *
 * @author Grigorios Tsoumakas
 */
public class MultiLabelOutput {

    /** a bipartition of the labels into relevant and irrelevant */
    private boolean[] bipartition;
    /** the rank of each label, ranging from 1 to array length */
    private int[] ranking;
    /** the probability of each label being positive */
    private double[] confidences;

    /**
     * Creates a new instance of {@link MultiLabelOutput} holding only a bipartition.
     * The given array is copied.
     *
     * @param bipartition bipartition of labels
     * @throws ArgumentNullException if bipartition is null.
     */
    public MultiLabelOutput(boolean[] bipartition) {
        if (bipartition == null) {
            throw new ArgumentNullException("bipartitions");
        }
        this.bipartition = Arrays.copyOf(bipartition, bipartition.length);
    }

    /**
     * Creates a new instance of {@link MultiLabelOutput} holding only a ranking.
     * The given array is copied.
     *
     * @param ranking ranking of labels
     * @throws ArgumentNullException if ranking is null
     */
    public MultiLabelOutput(int[] ranking) {
        if (ranking == null) {
            throw new ArgumentNullException("ranking");
        }
        this.ranking = Arrays.copyOf(ranking, ranking.length);
    }

    /**
     * Creates a new instance of {@link MultiLabelOutput}. It creates a ranking
     * based on the probabilities and a bipartition based on a threshold for the
     * probabilities: a label is marked <code>true</code> when its probability is
     * greater than or equal to the threshold. The given array is copied.
     *
     * @param probabilities score of each label
     * @param threshold threshold to output bipartition based on probabilities
     * @throws ArgumentNullException if probabilities is null
     */
    public MultiLabelOutput(double[] probabilities, double threshold) {
        this(probabilities);
        bipartition = new boolean[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            bipartition[i] = probabilities[i] >= threshold;
        }
    }

    /**
     * Creates a new instance of {@link MultiLabelOutput}. It creates a ranking
     * based on the probabilities. The given array is copied, so later changes
     * made by the caller to its array do not affect this object.
     *
     * @param probabilities score of each label
     * @throws ArgumentNullException if probabilities is null
     */
    public MultiLabelOutput(double[] probabilities) {
        if (probabilities == null) {
            throw new ArgumentNullException("probabilities");
        }
        // defensive copy, consistent with the other constructors
        confidences = Arrays.copyOf(probabilities, probabilities.length);
        ranking = ranksFromValues(probabilities);
    }

    /**
     * Creates a new instance of {@link MultiLabelOutput} from a bipartition and
     * confidence values. A ranking is derived from the confidences. Both arrays
     * are copied.
     *
     * @param bipartition bipartition of labels
     * @param someConfidences values of labels
     * @throws ArgumentNullException if either of the input parameters is null
     * @throws IllegalArgumentException if the dimensions of the two arrays do not match
     */
    public MultiLabelOutput(boolean[] bipartition, double[] someConfidences) {
        this(bipartition);
        if (someConfidences == null) {
            throw new ArgumentNullException("someConfidences");
        }
        if (bipartition.length != someConfidences.length) {
            throw new IllegalArgumentException("The dimensions of the bipartition "
                    + " and confidences arrays do not match.");
        }
        confidences = Arrays.copyOf(someConfidences, someConfidences.length);
        ranking = ranksFromValues(someConfidences);
    }

    /**
     * Gets bipartition of labels. The internal array is returned, not a copy.
     *
     * @return the bipartition, or <code>null</code> if this output has none
     */
    public boolean[] getBipartition() {
        return bipartition;
    }

    /**
     * Determines whether the {@link MultiLabelOutput} has bipartition of labels.
     *
     * @return <code>true</code> if has bipartition; otherwise <code>false</code>
     */
    public boolean hasBipartition() {
        return (bipartition != null);
    }

    /**
     * Gets ranking of labels. The internal array is returned, not a copy.
     *
     * @return the ranking, or <code>null</code> if this output has none
     */
    public int[] getRanking() {
        return ranking;
    }

    /**
     * Determines whether the {@link MultiLabelOutput} has ranking of labels.
     *
     * @return <code>true</code> if has ranking; otherwise <code>false</code>
     */
    public boolean hasRanking() {
        return (ranking != null);
    }

    /**
     * Gets confidences of labels. The internal array is returned, not a copy.
     *
     * @return the confidences, or <code>null</code> if this output has none
     */
    public double[] getConfidences() {
        return confidences;
    }

    /**
     * Determines whether the {@link MultiLabelOutput} has confidences of labels.
     *
     * @return <code>true</code> if has confidences; otherwise <code>false</code>
     */
    public boolean hasConfidences() {
        return (confidences != null);
    }

    /**
     * Creates a ranking from specified values/confidences. The element that
     * comes last in the order returned by <code>weka.core.Utils.stableSort</code>
     * receives rank 1, and the element that comes first receives rank
     * <code>values.length</code>.
     *
     * @param values the values/confidences to be converted to ranking
     * @return the ranking of given values/confidences
     */
    public static int[] ranksFromValues(double[] values) {
        int[] sortedIndices = weka.core.Utils.stableSort(values);
        int[] ranks = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            ranks[sortedIndices[i]] = values.length - i;
        }
        return ranks;
    }

    /**
     * Tests if two MultiLabelOutput objects are equal. Bipartitions and rankings
     * must match exactly (same presence, same length, same elements). Confidences
     * must have the same presence and length, and are compared element-wise
     * with <code>weka.core.Utils.eq</code>.
     *
     * @param mlo a MultiLabelOutput object
     * @return true if the given object represents a MultiLabelOutput equivalent
     * to this MultiLabelOutput, false otherwise
     */
    @Override
    public boolean equals(Object mlo) {
        if (mlo == this) {
            return true;
        }
        if (!(mlo instanceof MultiLabelOutput)) {
            return false;
        }
        MultiLabelOutput other = (MultiLabelOutput) mlo;

        // Arrays.equals handles null on either side and compares lengths, so
        // arrays of different size can neither match nor cause an index error
        if (!Arrays.equals(bipartition, other.bipartition)) {
            return false;
        }
        if (!Arrays.equals(ranking, other.ranking)) {
            return false;
        }

        //check confidences
        if (confidences == null || other.confidences == null) {
            return confidences == other.confidences;
        }
        if (confidences.length != other.confidences.length) {
            return false;
        }
        for (int i = 0; i < confidences.length; i++) {
            if (!weka.core.Utils.eq(confidences[i], other.confidences[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns a textual representation listing the bipartition, the confidences
     * and the ranking, in that order, skipping the parts that are absent.
     *
     * @return the string representation of this output
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (bipartition != null) {
            sb.append("Bipartion: ").append(Arrays.toString(bipartition)).append(" ");
        }
        if (confidences != null) {
            sb.append("Confidences: ").append(Arrays.toString(confidences)).append(" ");
        }
        if (ranking != null) {
            sb.append("Ranking: ").append(Arrays.toString(ranking));
        }
        return sb.toString();
    }

    /**
     * Computes a hash code from the bipartition and the ranking. The confidences
     * are deliberately left out: {@link #equals(Object)} compares them through
     * <code>weka.core.Utils.eq</code> rather than exactly, so hashing their exact
     * values could give different hash codes to objects that are equal.
     *
     * @return the hash code of this output
     */
    @Override
    public int hashCode() {
        int hash = 7;
        hash = 89 * hash + Arrays.hashCode(this.bipartition);
        hash = 89 * hash + Arrays.hashCode(this.ranking);
        return hash;
    }
}