package edu.fudan.ml.classifier.hier;
import edu.fudan.ml.types.Instance;
import edu.fudan.ml.types.InstanceSet;
import edu.fudan.ml.types.alphabet.LabelAlphabet;
import edu.fudan.ml.types.sv.HashSparseVector;
import edu.fudan.ml.types.sv.ISparseVector;
/**
 * Computes class centers: for every label, the mean of the feature vectors
 * of the training instances assigned to that label.
 *
 * @author xpqiu
 */
public class Mean {

    /**
     * Builds one mean vector per label in the default label alphabet of
     * {@code trainingList}.
     *
     * <p>Each instance's data is cast to {@link ISparseVector} and its target to
     * {@link Integer}; the target is used directly as an index into the result.
     * When {@code tree} is non-null, the instance's vector is added to every
     * label index returned by {@code tree.getPath(target)} (presumably the target
     * together with its ancestors in the hierarchy -- confirm against
     * {@code Tree}); otherwise it is added to the target label only.
     *
     * <p>After accumulation, each sum is divided by the number of vectors that
     * were added to it. Labels that received no vector are left as the freshly
     * constructed {@link HashSparseVector}.
     *
     * @param trainingList instances to average; also supplies the label alphabet
     * @param tree label hierarchy, or {@code null} for flat (non-hierarchical) labels
     * @return array of length {@code alphabet.size()}, indexed by label index
     */
    public static HashSparseVector[] mean(InstanceSet trainingList, Tree tree) {
        LabelAlphabet labels = trainingList.getAlphabetFactory().DefaultLabelAlphabet();
        int labelCount = labels.size();

        HashSparseVector[] centers = new HashSparseVector[labelCount];
        int[] counts = new int[labelCount];
        for (int label = 0; label < labelCount; label++) {
            centers[label] = new HashSparseVector();
        }

        // Pass 1: accumulate per-label sums and how many vectors went into each.
        for (int idx = 0; idx < trainingList.size(); idx++) {
            Instance example = trainingList.getInstance(idx);
            ISparseVector features = (ISparseVector) example.getData();
            int target = (Integer) example.getTarget();

            if (tree == null) {
                // Flat case: only the instance's own label is credited.
                accumulate(centers, counts, target, features);
                continue;
            }
            // Hierarchical case: credit every label on the path for this target.
            for (int node : tree.getPath(target)) {
                accumulate(centers, counts, node, features);
            }
        }

        // Pass 2: turn sums into means, skipping labels that saw no instance
        // so we never divide by zero.
        for (int label = 0; label < labelCount; label++) {
            int seen = counts[label];
            if (seen > 0) {
                centers[label].scaleDivide(seen);
            }
        }
        return centers;
    }

    /** Adds {@code features} to the running sum for {@code label} and bumps its count. */
    private static void accumulate(HashSparseVector[] sums, int[] counts,
                                   int label, ISparseVector features) {
        sums[label].plus(features);
        counts[label] += 1;
    }
}