package edu.fudan.ml.classifier.knn;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import edu.fudan.ml.classifier.AbstractClassifier;
import edu.fudan.ml.classifier.LabelParser.Type;
import edu.fudan.ml.classifier.LinkedPredict;
import edu.fudan.ml.classifier.Predict;
import edu.fudan.ml.classifier.TPredict;
import edu.fudan.ml.classifier.linear.Linear;
import edu.fudan.ml.types.Instance;
import edu.fudan.ml.types.InstanceSet;
import edu.fudan.nlp.pipe.Pipe;
import edu.fudan.nlp.pipe.String2Dep;
import edu.fudan.nlp.similarity.ISimilarity;
import edu.fudan.util.exception.LoadModelException;
/**
 * A k-nearest-neighbour classifier. It scores an instance against every stored prototype
 * instance using a pluggable {@link ISimilarity} measure.
 *
 * <p>The constructor immediately runs a self-evaluation over the prototypes. It prints the
 * instance count and the leave-zero-out and leave-one-out 1-best accuracy to standard output.
 */
public class KNN extends AbstractClassifier{

	private static final long serialVersionUID = 4459814160943364300L;

	/** Similarity measure used to score a query instance against each prototype. */
	private ISimilarity sim;

	/**
	 * Capacity argument passed to {@code new LinkedPredict<String>(k)} in {@link #classify(Instance, int)}.
	 * NOTE(review): presumably the number of nearest neighbours kept — confirm against LinkedPredict.
	 */
	private int k;

	/** Feature transformer (pipe). */
	protected Pipe pipe;

	/** The KNN model: the labelled prototype instances that queries are compared against. */
	protected InstanceSet prototypes;

	/** Flag passed to {@code LinkedPredict.mergeDuplicate} when merging predictions that share a label. */
	private boolean useScore = true;

	/**
	 * Returns the similarity measure used for classification.
	 *
	 * @return the current similarity measure
	 */
	public ISimilarity getSim() {
		return sim;
	}

	/**
	 * Replaces the similarity measure used for classification.
	 *
	 * @param sim the new similarity measure
	 */
	public void setSim(ISimilarity sim) {
		this.sim = sim;
	}

	/**
	 * Initializes the classifier and prints its accuracy on the prototype set.
	 *
	 * @param instset labelled prototype instances. The set is stored by reference, not copied.
	 *                Each element is temporarily removed and then restored during the
	 *                leave-one-out self-evaluation.
	 * @param p       feature pipe
	 * @param sim     similarity measure used to compare instances
	 * @param k       capacity argument passed to LinkedPredict when classifying
	 */
	public KNN(InstanceSet instset,Pipe p, ISimilarity sim, int k){
		prototypes = instset;
		this.pipe = p;
		this.sim = sim;
		this.k = k;
		reportSelfAccuracy();
	}

	/**
	 * Prints the instance count, then the 1-best accuracy over the prototypes in two modes:
	 * "leave-zero-out" (the instance itself is still in the prototype set) and "leave-one-out"
	 * (the instance is removed first).
	 */
	private void reportSelfAccuracy() {
		int correctWithSelf = 0;
		int correctWithoutSelf = 0;
		int total = prototypes.size();
		System.out.println("实例数量:"+total);
		for(int i=0;i<total;i++){
			Instance inst = prototypes.get(i);
			if(predictsTarget(classify(inst, 1), inst))
				correctWithSelf++;
			prototypes.remove(i);
			try {
				if(predictsTarget(classify(inst, 1), inst))
					correctWithoutSelf++;
			} finally {
				// Always put the instance back at its original index so the caller's set is not
				// left corrupted if classification throws.
				prototypes.add(i, inst);
			}
		}
		System.out.println("Leave-zero-out正确率:"+correctWithSelf*1.0f/total);
		System.out.println("Leave-one-out正确率:"+correctWithoutSelf*1.0f/total);
	}

	/**
	 * Tells whether the top-ranked label of a prediction matches the instance's target.
	 *
	 * @param pred prediction returned by classify; may be null when similarity computation failed
	 * @param inst the instance whose target is compared against
	 * @return true only if {@code pred} is non-null and its first label equals the target
	 */
	private static boolean predictsTarget(TPredict pred, Instance inst) {
		return pred != null && Objects.equals(pred.getLabel(0), inst.getTarget());
	}

	/**
	 * Replaces the feature pipe.
	 *
	 * @param p the new pipe
	 */
	public void setPipe(Pipe p) {
		this.pipe = p;
	}

	/**
	 * Classifies an instance by scoring it against every prototype and returning the labels.
	 *
	 * @param instance the instance to classify
	 * @param n        size passed to {@code assertSize} on the merged prediction
	 * @return the merged label predictions, or {@code null} if the similarity computation failed
	 *         for any prototype. In that case the exception's stack trace is printed.
	 */
	public TPredict classify(Instance instance, int n){
		LinkedPredict<String> pred = new LinkedPredict<String>(k);
		for(int i = 0; i < prototypes.size(); i++){
			Instance curInst = prototypes.get(i);
			float score;
			try {
				score = sim.calc(instance.getData(), curInst.getData());
			} catch (Exception e) {
				// NOTE(review): best-effort contract preserved for existing callers. The failure is
				// reported on stderr and signalled with a null result instead of being rethrown.
				e.printStackTrace();
				return null;
			}
			pred.add((String) curInst.getTarget(), score,(String) curInst.getSource());
		}
		// Merge entries that share a label (the original note described this step as sorting),
		// then apply assertSize(n) to the merged result.
		LinkedPredict<String> newpred = pred.mergeDuplicate(useScore);
		newpred.assertSize(n);
		return newpred;
	}

	/**
	 * Classifies an instance. The label type is ignored; this delegates to
	 * {@link #classify(Instance, int)}.
	 */
	@Override
	public TPredict classify(Instance instance, Type type, int n) {
		return classify(instance, n);
	}

	/**
	 * Saves the classifier to a gzip-compressed Java-serialized file, creating parent directories
	 * as needed.
	 *
	 * @param file destination path
	 * @throws IOException if the file cannot be written
	 */
	public void saveTo(String file) throws IOException {
		File f = new File(file);
		File path = f.getParentFile();
		// getParentFile() is null for a bare file name such as "model.gz"; nothing to create then.
		if(path != null && !path.exists()){
			path.mkdirs();
		}
		// Declare the raw file stream separately so it is closed even if the
		// ObjectOutputStream constructor fails.
		try (FileOutputStream fos = new FileOutputStream(file);
				ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(
						new BufferedOutputStream(fos)))) {
			out.writeObject(this);
		}
	}

	/**
	 * Reads a classifier from a file written by {@link #saveTo(String)}.
	 *
	 * @param file source path
	 * @return the deserialized classifier
	 * @throws LoadModelException wrapping any failure to open, read, or cast the stored object
	 */
	public static KNN loadFrom(String file) throws LoadModelException{
		// SECURITY NOTE(review): this uses Java native deserialization; only load model files
		// from trusted sources.
		try (FileInputStream fis = new FileInputStream(file);
				ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(
						new BufferedInputStream(fis)))) {
			return (KNN) in.readObject();
		} catch (Exception e) {
			throw new LoadModelException(e,file);
		}
	}

	/**
	 * Sets the capacity argument used when building the LinkedPredict in classify.
	 *
	 * @param k the new value
	 */
	public void setK(int k) {
		this.k = k;
	}
}