/** * Copyright (c) 2010, Regents of the University of Colorado All rights * reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. Redistributions in binary * form must reproduce the above copyright notice, this list of conditions and * the following disclaimer in the documentation and/or other materials provided * with the distribution. Neither the name of the University of Colorado at * Boulder nor the names of its contributors may be used to endorse or promote * products derived from this software without specific prior written * permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ package clear.model; import clear.train.kernel.AbstractKernel; import clear.util.IOUtil; import clear.util.tuple.JIntDoubleTuple; import com.carrotsearch.hppc.IntArrayList; import java.io.BufferedReader; import java.io.PrintStream; import java.util.ArrayList; import java.util.Arrays; /** * One-vs-all model. * * @author Jinho D. Choi <b>Last update:</b> 11/5/2010 */ public class OneVsAllModel extends AbstractMultiModel { public OneVsAllModel(AbstractKernel kernel) { super(kernel); } public OneVsAllModel(String modelFile) { super(modelFile); } public OneVsAllModel(BufferedReader fin) { super(fin); } public OneVsAllModel(int nLabels, int nFeatures, int[] aLabels, double[] dWeights) { super(nLabels, nFeatures, aLabels, dWeights); } @Override public void init(AbstractKernel kernel) { n_labels = kernel.L; n_features = kernel.D; a_labels = kernel.a_labels; d_weights = new double[n_labels * n_features]; } @Override public void load(String modelFile) { try { BufferedReader fin = IOUtil.createBufferedFileReader(modelFile); loadAux(fin); fin.close(); } catch (Exception e) { e.printStackTrace(); } } @Override public void load(BufferedReader fin) { try { loadAux(fin); } catch (Exception e) { e.printStackTrace(); } } public void loadAux(BufferedReader fin) throws Exception { n_labels = Integer.parseInt(fin.readLine()); n_features = Integer.parseInt(fin.readLine()); a_labels = new int[n_labels]; d_weights = new double[n_labels * n_features]; readLabels(fin); readWeights(fin); } @Override public void save(String modelFile) { try { PrintStream fout = IOUtil.createPrintFileStream(modelFile); saveAux(fout); fout.flush(); fout.close(); } catch (Exception e) { e.printStackTrace(); } } @Override public void save(PrintStream fout) { try { saveAux(fout); } catch (Exception e) { e.printStackTrace(); } } private void saveAux(PrintStream fout) throws Exception { fout.println(n_labels); fout.println(n_features); printLabels(fout); printWeights(fout); } private int getBeginIndex(int label, int index) { return index * n_labels + label; } @Override public void copyWeight(int label, double[] weight) { int i; for (i = 0; i < n_features; i++) { d_weights[getBeginIndex(label, i)] = weight[i]; } } @Override public double[] getScores(int[] x) { double[] scores = Arrays.copyOf(d_weights, n_labels); int i, idx, label; for (i = 0; i < x.length; i++) { for (label = 0; label < n_labels; label++) { if ((idx = getBeginIndex(label, x[i])) < d_weights.length) { scores[label] += d_weights[idx]; } } } return scores; } @Override public double[] getScores(IntArrayList x) { double[] scores = Arrays.copyOf(d_weights, n_labels); int i, idx, label; for (i = 0; i < x.size(); i++) { for (label = 0; label < n_labels; label++) { if ((idx = getBeginIndex(label, x.get(i))) < d_weights.length) { scores[label] += d_weights[idx]; } } } return scores; } public double[] getScores(JIntDoubleTuple[] x) { double[] scores = Arrays.copyOf(d_weights, n_labels); int idx, label; for (JIntDoubleTuple tup : x) { for (label = 0; label < n_labels; label++) { if ((idx = getBeginIndex(label, tup.i)) < d_weights.length) { scores[label] += (d_weights[idx] * tup.d); } } } for (label = 0; label < n_labels; label++) { scores[label] = scores[label]; } return scores; } public double[] getScores(ArrayList<JIntDoubleTuple> x) { double[] scores = Arrays.copyOf(d_weights, n_labels); int idx, label; for (JIntDoubleTuple tup : x) { for (label = 0; label < n_labels; label++) { if ((idx = getBeginIndex(label, tup.i)) < d_weights.length) { scores[label] += (d_weights[idx] * tup.d); } } } for (label = 0; label < n_labels; label++) { scores[label] = scores[label]; } return scores; } }