package edu.fudan.ml.classifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Vector;
import edu.fudan.util.MyCollection;
/**
* KNN分类器结果记录
*
* @author lcao
*
*/
public class LinkedPredict<T> implements TPredict<T> {
/**
* 类标签
*/
private LinkedList<T> labels;
/**
* 分数
*/
private LinkedList<Float> scores;
/**
* 证据
*/
private LinkedList<T> evidences;
private int k;
public LinkedPredict(int k){
this.k = k;
labels = new LinkedList<T>();
scores = new LinkedList<Float>();
evidences = new LinkedList<T>();
}
public LinkedPredict(){
this.k = Integer.MAX_VALUE;
labels = new LinkedList<T>();
scores = new LinkedList<Float>();
evidences = new LinkedList<T>();
}
/**
* 简单可视输出
*/
public String toString(){
StringBuilder sb = new StringBuilder();
for(int i=0;i<labels.size();i++){
sb.append(labels.get(i));
sb.append(" ");
sb.append(scores.get(i));
sb.append("\n");
}
return sb.toString();
}
/**
* 增加新的标签和得分,并根据得分调整排序 *
* @param label 标签
* @param score 得分
* @return 插入位置
*/
public int add(T label,float score) {
int j = 0;
for(; j < labels.size(); j++)
if(scores.get(j) < score){
labels.add(j, label);
scores.add(j,score);
break;
}
if(j == labels.size() && labels.size() < k){
labels.add(j, label);
scores.add(j,score);
}
if(labels.size() > k){
labels.removeLast();
scores.removeLast();
}
return j;
}
public void add(T t, float score, T source) {
int j = add(t,score);
if(j>=k)
return;
evidences.add(j,source);
if(evidences.size()>k)
evidences.removeLast();
}
@Override
public T getLabel(int i) {
if(labels.size()==0)
return null;
return labels.get(i);
}
@Override
public float getScore(int i) {
return scores.get(i);
}
@Override
public void normalize() {
float base = 1;
if(scores.get(0)!=0.0f)
base = scores.get(0)/2;
float sum = 0;
for(int i=0;i<scores.size();i++){
float s = (float) Math.exp(scores.get(i)/base);
scores.set(i, s);
sum +=s;
}
for(int i=0;i<scores.size();i++){
float s = scores.get(i)/sum;
scores.set(i, s);
}
}
@Override
public int size() {
return labels.size();
}
/**
* 合并重复标签,并重新排序
* @param true 用得分; false 计数
* @return
*/
public LinkedPredict<T> mergeDuplicate(boolean useScore) {
LinkedPredict<T> pred = new LinkedPredict<T>();
for(int i = 0; i < labels.size(); i++){
T l = labels.get(i);
float score;
if(useScore)
score = scores.get(i);
else
score=1;
pred.addoradjust(l, score);
}
return pred;
}
/**
* 合并重复标签,有问题(未排序)
*/
public void mergeDuplicate() {
for(int i = 0; i < labels.size(); i++)
for(int j = i + 1; j < labels.size(); j++){
T tagi = labels.get(i);
T tagj = labels.get(j);
if(tagi.equals(tagj)){
scores.set(i, scores.get(i) + scores.get(j));
labels.remove(j);
scores.remove(j);
j--;
}
}
}
/**
* 合并
* @param pred
* @param w
*/
public void addorjust(TPredict<T> pred, float w) {
for(int i=0;i<pred.size();i++){
T l = pred.getLabel(i);
float s = pred.getScore(i);
addoradjust(l,s*w);
}
}
/**
*
* @param label
* @param f
*/
public void addoradjust(T label, float f) {
int j = 0;
for(; j < labels.size(); j++){
T tagj = labels.get(j);
if(tagj.equals(label)){
break;
}
}
if(j<labels.size()){
float ts = scores.get(j);
labels.remove(j);
scores.remove(j);
add(label,ts+f);
}else{
add(label,f);
}
}
@Override
public T[] getLabels() {
if(labels==null)
return null;
else
return (T[]) labels.toArray();
}
@Override
public void remove(int i) {
labels.remove(i);
scores.remove(i);
if(evidences.size()>i)
evidences.remove(i);
}
/**
* 确保结果个数小于等于n
* @param n
*/
public void assertSize(int n) {
while(labels.size()>n){
labels.removeLast();
scores.removeLast();
if(evidences.size()>n)
evidences.removeLast();
}
}
}