package edu.fudan.ml.classifier; import java.util.Arrays; /** * 用来输出带得分的预测结果 * @author xpqiu * */ public class Predict<T> implements TPredict<T> { /** * 标签数组 */ public T[] labels; /** * 得分数组 */ public float[] scores; /** * 保存个数 */ int n; /** * 缺省只保存一个最大值 */ public Predict(){ this(1); } /** * 保存前n个最大值 * @param n */ public Predict(int n){ this.n = n; labels = (T[]) new Object[n]; scores = new float[n]; Arrays.fill(scores, Float.NEGATIVE_INFINITY); } public int size() { return labels.length; } /** * 返回得分最高的标签 * @return */ public T getLabel() { return labels[0]; } /** * 返回得分第i高的标签 */ public T getLabel(int i) { return labels[i]; } public float getScore(int i) { return scores[i]; } /** * 设置位置i上的标签和得分 * @param i * @param label2 * @param d */ public void set(int i, T label2, float d) { labels[i] = label2; scores[i] = d; } /** * 增加新的标签和得分,并根据得分调整排序 * * @param label 标签 * @param score 得分 * @return 插入位置 */ public int add(T label,float score) { int i = 0; int ret = i; if (n != 0) { for (i = 0; i < n; i++) { if (score > scores[i]) break; } if (i != n || n < scores.length) { for (int k = n - 2; k >= i; k--) { scores[k + 1] = scores[k]; labels[k + 1] = labels[k]; } ret = i; }else if (n < scores.length) { }else { ret = -1; } } if(ret!=-1){ scores[i] = score; labels[i] = label; } if (n < scores.length) n++; return ret; } /** * 将得分归一化到[0,1]区间 */ public void normalize() { float base = 1; if(scores[0]!=0.0f) base = scores[0]/2; float sum = 0; for(int i=0;i<scores.length;i++){ float s = (float) Math.exp(scores[i]/base); scores[i] = s; sum +=s; } for(int i=0;i<scores.length;i++){ float s = scores[i]/sum; scores[i] = s; } } /** * 简单可视输出 */ public String toString(){ StringBuilder sb = new StringBuilder(); for(int i=0;i<labels.length;i++){ sb.append(labels[i]); sb.append(" "); sb.append(scores[i]); sb.append("\n"); } return sb.toString(); } /** * 取得所有返回结果 * @return */ public T[] getLabels() { return labels; } @Override public void remove(int i) { // TODO Auto-generated method stub System.err.println("没有实现"); } }