/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package liblinear;
import static liblinear.Linear.NL;
import static liblinear.Linear.atof;
import static liblinear.Linear.atoi;
import static liblinear.Linear.closeQuietly;
import static liblinear.Linear.printf;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.List;
import java.util.StringTokenizer;
import java.util.regex.Pattern;
/**
 * Applies a trained {@link Model} to a file of instances and writes one
 * predicted label per input line.
 *
 * <p>Each input line is expected to look like
 * {@code <label> <index>:<value> <index>:<value> ...}, with tokens separated
 * by spaces or tabs.</p>
 */
public class Predict {

    /** Set by the {@code -b} option of {@link #main}; when true, probability estimates are written too. */
    private static boolean flag_predict_probability = false;

    /** Separates the feature index from the feature value inside one token. */
    private static final Pattern COLON = Pattern.compile(":");

    /**
     * Reads instances line by line from {@code reader}, predicts a label for
     * each one with {@code model} and writes the result to {@code writer}.
     *
     * <p>When probability output is enabled, a header line {@code labels ...}
     * is written first and every prediction line carries one estimate per
     * class after the predicted label.</p>
     *
     * <p><b>Note: The streams are NOT closed</b></p>
     *
     * @param reader source of the instances, one per line
     * @param writer destination of the predictions
     * @param model  model used for prediction
     * @throws IOException              if reading from {@code reader} fails
     * @throws IllegalArgumentException if probability output is enabled but the
     *                                  model's solver type is not {@code L2_LR}
     * @throws RuntimeException         if a line is empty, its label is not an
     *                                  integer, or a feature token is malformed;
     *                                  the message contains the 1-based line number
     */
    static void doPredict( BufferedReader reader, Writer writer, Model model ) throws IOException {
        // number of lines processed so far; total + 1 is the current 1-based line number
        int total = 0;
        int nr_class = model.getNrClass();
        double[] prob_estimates = null;

        int nr_feature = model.getNrFeature();
        // index of the bias feature; only used when the model has a bias term
        int n = ( model.bias >= 0 ) ? nr_feature + 1 : nr_feature;

        Formatter out = new Formatter(writer);

        if ( flag_predict_probability ) {
            if ( model.solverType != SolverType.L2_LR ) {
                throw new IllegalArgumentException("probability output is only supported for logistic regression");
            }

            int[] labels = model.getLabels();
            prob_estimates = new double[nr_class];

            printf(out, "labels");
            for ( int j = 0; j < nr_class; j++ )
                printf(out, " %d", labels[j]);
            printf(out, "\n");
        }

        String line = null;
        while ( (line = reader.readLine()) != null ) {
            int lineNumber = total + 1;
            List<FeatureNode> x = new ArrayList<FeatureNode>();
            StringTokenizer st = new StringTokenizer(line, " \t");

            // An empty or whitespace-only line has no label token; report it with
            // its line number instead of letting nextToken() fail.
            if ( !st.hasMoreTokens() ) exit_input_error(lineNumber);
            String label = st.nextToken();
            try {
                // No accuracy summary is printed, so the true label is only
                // validated here and not compared with the prediction.
                atoi(label);
            }
            catch ( NumberFormatException e ) {
                exit_input_error(lineNumber, e);
            }

            while ( st.hasMoreTokens() ) {
                String[] split = COLON.split(st.nextToken(), 2);
                if ( split.length < 2 ) exit_input_error(lineNumber);

                try {
                    int idx = atoi(split[0]);
                    double val = atof(split[1]);

                    // feature indices larger than those in training are not used
                    if ( idx <= nr_feature ) {
                        x.add(new FeatureNode(idx, val));
                    }
                }
                catch ( NumberFormatException e ) {
                    exit_input_error(lineNumber, e);
                }
            }

            // append the bias term as an extra feature right after the regular ones
            if ( model.bias >= 0 ) {
                x.add(new FeatureNode(n, model.bias));
            }

            FeatureNode[] nodes = x.toArray(new FeatureNode[x.size()]);

            int predict_label;
            if ( flag_predict_probability ) {
                predict_label = Linear.predictProbability(model, nodes, prob_estimates);
                printf(out, "%d ", predict_label);
                // bounded by nr_class so it always matches prob_estimates and the header line
                for ( int j = 0; j < nr_class; j++ )
                    printf(out, "%g ", prob_estimates[j]);
                printf(out, "\n");
            } else {
                predict_label = Linear.predict(model, nodes);
                printf(out, "%d\n", predict_label);
            }

            ++total;
        }
    }

    /**
     * Aborts processing of a malformed input line.
     *
     * @param line_num 1-based number of the offending line
     * @param cause    the underlying parse failure
     * @throws RuntimeException always
     */
    private static void exit_input_error( int line_num, Throwable cause ) {
        throw new RuntimeException("Wrong input format at line " + line_num, cause);
    }

    /**
     * Aborts processing of a malformed input line.
     *
     * @param line_num 1-based number of the offending line
     * @throws RuntimeException always
     */
    private static void exit_input_error( int line_num ) {
        throw new RuntimeException("Wrong input format at line " + line_num);
    }

    /** Prints the usage text to standard output and terminates the JVM with exit status 1. */
    private static void exit_with_help() {
        System.out.println("Usage: predict [options] test_file model_file output_file" + NL //
            + "options:" + NL //
            + "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0)" + NL //
        );
        System.exit(1);
    }

    /**
     * Command-line entry point:
     * {@code predict [options] test_file model_file output_file}.
     *
     * <p>Leading arguments that start with {@code -} are treated as options,
     * each followed by one value. A malformed or incomplete command line prints
     * the usage text and exits.</p>
     *
     * @param argv command-line arguments
     * @throws IOException if one of the files cannot be opened, read or written
     */
    public static void main( String[] argv ) throws IOException {
        int i;

        // parse options
        for ( i = 0; i < argv.length; i++ ) {
            String option = argv[i];
            // startsWith also copes with an empty argument, which charAt(0) would not
            if ( !option.startsWith("-") ) break;

            // a lone "-" has no option letter and falls through to "unknown option"
            char optionLetter = ( option.length() > 1 ) ? option.charAt(1) : '\0';
            ++i; // advance to the option's value
            switch ( optionLetter ) {
                case 'b':
                    if ( i >= argv.length ) {
                        // "-b" was the last argument, so its value is missing
                        exit_with_help();
                        break;
                    }
                    try {
                        flag_predict_probability = (atoi(argv[i]) != 0);
                    }
                    catch ( NumberFormatException e ) {
                        exit_with_help();
                    }
                    break;

                default:
                    System.err.println("unknown option" + NL);
                    exit_with_help();
                    break;
            }
        }

        // three positional arguments must remain: test_file, model_file, output_file
        if ( argv.length < i + 3 ) {
            exit_with_help();
        }

        BufferedReader reader = null;
        Writer writer = null;
        try {
            reader = new BufferedReader(new InputStreamReader(new FileInputStream(argv[i]), Linear.FILE_CHARSET));
            writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(argv[i + 2]), Linear.FILE_CHARSET));

            Model model = Linear.loadModel(new File(argv[i + 1]));
            doPredict(reader, writer, model);
        }
        finally {
            closeQuietly(reader);
            closeQuietly(writer);
        }
    }
}