/** * Copyright (c) 2009, Regents of the University of Colorado All rights * reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. Redistributions in binary * form must reproduce the above copyright notice, this list of conditions and * the following disclaimer in the documentation and/or other materials provided * with the distribution. Neither the name of the University of Colorado at * Boulder nor the names of its contributors may be used to endorse or promote * products derived from this software without specific prior written * permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ package clear.decode; import clear.model.AbstractModel; import clear.model.OneVsAllModel; import clear.util.tuple.JIntDoubleTuple; import com.carrotsearch.hppc.IntArrayList; import java.io.BufferedReader; import java.util.ArrayList; import java.util.Arrays; /** * One-vs-all decoder. * * @author Jinho D. Choi <br><b>Last update:</b> 11/8/2010 */ public class OneVsAllDecoder extends AbstractMultiDecoder { protected OneVsAllModel m_model; public OneVsAllDecoder(String modelFile) { m_model = new OneVsAllModel(modelFile); } public OneVsAllDecoder(BufferedReader fin) { m_model = new OneVsAllModel(fin); } public OneVsAllDecoder(OneVsAllModel model) { m_model = model; } @Override public JIntDoubleTuple predict(int[] x) { return predictAux(m_model.getScores(x)); } @Override public JIntDoubleTuple predict(IntArrayList x) { return predictAux(m_model.getScores(x)); } @Override public JIntDoubleTuple predict(JIntDoubleTuple[] x) { return predictAux(m_model.getScores(x)); } @Override public JIntDoubleTuple predict(ArrayList<JIntDoubleTuple> x) { return predictAux(m_model.getScores(x)); } private JIntDoubleTuple predictAux(double[] scores) { int[] aLabels = m_model.a_labels; JIntDoubleTuple max = new JIntDoubleTuple(aLabels[0], scores[0]); int i; for (i = 1; i < m_model.n_labels; i++) { if (scores[i] > max.d) { max.set(aLabels[i], scores[i]); } } // max.d = AbstractModel.logistic(max.d); return max; } @Override public JIntDoubleTuple[] predictAll(int[] x) { return predictAllAux(m_model.getScores(x)); } @Override public JIntDoubleTuple[] predictAll(IntArrayList x) { return predictAllAux(m_model.getScores(x)); } @Override public JIntDoubleTuple[] predictAll(ArrayList<JIntDoubleTuple> x) { return predictAllAux(m_model.getScores(x)); } private JIntDoubleTuple[] predictAllAux(double[] scores) { int[] aLabels = m_model.a_labels; JIntDoubleTuple[] aRes = new JIntDoubleTuple[m_model.n_labels]; for (int i = 0; i < m_model.n_labels; i++) { aRes[i] = new JIntDoubleTuple(aLabels[i], AbstractModel.logistic(scores[i])); } Arrays.sort(aRes); return aRes; } }