package edu.fudan.ml.classifier.hier;
import edu.fudan.ml.classifier.TPredict;
import gnu.trove.list.linked.TFloatLinkedList;
import gnu.trove.list.linked.TIntLinkedList;
/**
* 用来保存标签和相应的得分
* 标签可能为中间计算结果
*
* @author xpqiu
*
*/
public class Predict implements TPredict<Integer> {
/**
* 记录前n个结果,默认为-1,不限制个数
*/
int n=-1;
/**
* 标签得分值
*/
public TFloatLinkedList scores;
/**
* 标签
*/
public TIntLinkedList labels;
public Object other;
public Predict() {
this(1);
}
public Predict(int n) {
this.n = n;
scores = new TFloatLinkedList();
labels = new TIntLinkedList();
}
/**
* 返回插入的位置
* @param score 得分
* @param label 标签
* @return 插入位置
*/
public int add(int label,float score) {
int i = 0;
int max;
if(n==-1)
max = scores.size();
else
max = n>scores.size()?scores.size():n;
for (i = 0; i < max; i++) {
if (score > scores.get(i))
break;
}
//TODO: 没有删除多余的信息
if(n!=-1&&i>=n)
return -1;
if(i<scores.size()){
scores.insert(i,score);
labels.insert(i,label);
}else{
scores.add(score);
labels.add(label);
}
return i;
}
/**
* 获得预测结果
*
* @param i
* 位置
* @return 第i个预测结果;如果不存在,为-1
*/
public Integer getLabel(int i) {
if (i < 0 || i >= labels.size())
return -1;
return labels.get(i);
}
/**
* 获得预测结果的得分
*
* @param i 位置
* @return 第i个预测结果的得分;不存在为Double.NEGATIVE_INFINITY
*/
public float getScore(int i) {
if (i < 0 || i >=scores.size())
return Float.NEGATIVE_INFINITY;
return scores.get(i);
}
/**
* 预测结果数量
*
* @return 预测结果的数量
*/
public int size() {
return n;
}
public void normalize(){
float base = scores.get(0)/2;
float sum = 0;
for(int i=0;i<scores.size();i++){
float s = (float) Math.exp(scores.get(i)/base);
scores.set(i, s);
sum +=s;
}
for(int i=0;i<scores.size();i++){
float s = scores.get(i)/sum;
// if(s <0.001f)
// s=0;
scores.set(i, s);
}
}
@Override
public Integer[] getLabels() {
// TODO Auto-generated method stub
return null;
}
@Override
public void remove(int i) {
// TODO Auto-generated method stub
}
}