package org.shanbo.feluca.vectors;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map.Entry;
import java.util.Properties;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import org.apache.commons.lang3.text.StrBuilder;
import org.shanbo.feluca.data2.DataSetInfo.Statistic;
import org.shanbo.feluca.data2.HashPartitioner;
import org.shanbo.feluca.data2.Vector;
import org.shanbo.feluca.data2.util.BytesUtil;
import org.shanbo.feluca.data2.util.NumericTokenizer;
import org.shanbo.feluca.data2.util.NumericTokenizer.FeatureWeight;
/**
 * A feature vector that carries an integer label/tag, used for classification.
 * <p>
 * Text form is libsvm-like: {@code label fid1:weight1 fid2:weight2 fid3:weight3 ...}
 * </p>
 * The label is kept in the 4-byte header; feature ids and weights live in two
 * parallel lists that are allocated lazily by {@link #parseLine(String)} or
 * {@link #swallow(Vector)}.
 *
 * @author lgn
 */
public class LabelVector extends GeneralVector{

    /** Text layout of one {@code " fid:weight"} pair (leading space, 4 decimals). */
    private static final String FEATURE_FORMAT = " %d:%.4f";

    /**
     * Collects per-label statistics: how many distinct labels were seen and how
     * many vectors carried each of them.
     */
    public static class LabelStatistic extends Statistic{

        /**
         * label -&gt; {@code int[2]}: element 0 is the index assigned to the label in
         * first-seen order, element 1 is the number of vectors seen with that label.
         */
        HashMap<Integer, int[]> labelInfoBag = new HashMap<Integer, int[]>();

        /**
         * Accounts for the current vector: registers its label on first sight,
         * otherwise increments that label's counter.
         */
        @Override
        public void statAsOne() {
            LabelVector lv = (LabelVector)this.current;
            // Decode the label once; getLabel() re-reads the header bytes on every call.
            int label = lv.getLabel();
            int[] labelInfo = labelInfoBag.get(label);
            if (labelInfo == null){
                // New label: its index is the number of labels registered so far.
                labelInfoBag.put(label, new int[]{labelInfoBag.size(), 1});
            }else{
                labelInfo[1] += 1;
            }
        }

        /** No per-feature statistic is collected for labels. */
        @Override
        public void statOnFeature(int index) {
        }

        /**
         * Exports the collected statistics.
         *
         * @return properties holding the number of distinct labels under
         *         {@code CLASSES} and, under {@code LABEL_INFO}, a string of
         *         {@code label:index:count} entries, each followed by one space
         */
        @Override
        public Properties getStatInfo() {
            Properties p = new Properties();
            // NOTE(review): the value is stored as an Integer, so Properties.getProperty()
            // would return null for this key; kept as-is because callers may read it
            // through get(). Confirm against the consumers of these properties.
            p.put(CLASSES, this.labelInfoBag.size());
            StringBuilder sb = new StringBuilder();
            for(Entry<Integer, int[]> entry : labelInfoBag.entrySet()){
                int[] info = entry.getValue();
                // Locale.ROOT keeps the digits ASCII whatever the JVM default locale is.
                sb.append(String.format(Locale.ROOT, "%d:%d:%d ", entry.getKey(), info[0], info[1]));
            }
            p.put(LABEL_INFO, sb.toString());
            return p;
        }
    }

    /**
     * Creates an empty vector whose input and output types are both
     * {@code LABEL_FID_WEIGHT}. The header and the feature lists are not
     * allocated here; {@link #parseLine(String)} and {@link #swallow(Vector)}
     * create them on first use.
     */
    public LabelVector(){
        this.inputType = VectorType.LABEL_FID_WEIGHT;
        this.outputType = VectorType.LABEL_FID_WEIGHT;
    }

    /**
     * @return the label decoded as an int from the header bytes
     */
    public int getLabel(){
        return BytesUtil.getInt(getHeader());
    }

    /**
     * Replaces the content of this vector with the one parsed from {@code line}.
     * The first number on the line is the label; every following token is read
     * as a feature id/weight pair.
     *
     * @param line text in {@code label fid:weight fid:weight ...} form
     * @return {@code true} if at least one feature was parsed, {@code false} otherwise
     */
    @Override
    public boolean parseLine(String line) {
        if (fids == null){
            // First use: allocate the buffers and the 4-byte header.
            fids = new TIntArrayList(1024);
            weights = new TFloatArrayList(1024);
            head = new byte[4];
        }else{
            // Reuse the existing buffers; only the logical size is reset.
            fids.resetQuick();
            weights.resetQuick();
        }
        NumericTokenizer nt = new NumericTokenizer();
        nt.load(line);
        // The label is written into the existing header array in place.
        BytesUtil.int2Byte(nt.nextNumber().intValue(), head);
        while(nt.hasNext()){
            FeatureWeight nextKeyWeight = nt.nextKeyWeight();
            fids.add(nextKeyWeight.getId());
            weights.add(nextKeyWeight.getWeight());
        }
        return fids.size() > 0;
    }

    /**
     * Splits this vector into one vector per shard, routing every feature to the
     * shard chosen by {@code partitioner}. Each resulting vector carries this
     * vector's label. The split goes through the text form, so weights are
     * rounded to 4 decimals.
     *
     * @param partitioner decides the shard of each feature id and the number of shards
     * @return one vector per shard, in shard order; a shard that received no
     *         feature still gets a vector
     */
    @Deprecated
    public List<Vector> divideByFeature(HashPartitioner partitioner) {
        int shards = partitioner.getMaxShards();
        List<Vector> vectors = new ArrayList<Vector>(shards);
        List<StrBuilder> lines = new ArrayList<StrBuilder>(shards);
        for(int i = 0 ; i < shards; i++){
            vectors.add(new LabelVector());
            lines.add(new StrBuilder().append(getLabel())); //label
        }
        for(int i = 0 ; i < getSize(); i++){
            int shardId = partitioner.decideShard(getFId(i));
            lines.get(shardId).append(formatFeature(fids.getQuick(i), weights.getQuick(i)));
        }
        for(int i = 0; i < vectors.size(); i++){
            vectors.get(i).parseLine(lines.get(i).toString());
        }
        return vectors;
    }

    /**
     * Appends every feature of {@code v} to this vector. If this vector has no
     * buffers yet, they are created and the header is taken from {@code v};
     * otherwise the current header is left untouched.
     *
     * @param v vector to absorb; {@code null} is ignored
     */
    @Override
    public void swallow(Vector v) {
        if (v == null)
            return;
        if (fids == null){
            fids = new TIntArrayList();
            weights = new TFloatArrayList();
            // Copy instead of aliasing: parseLine() overwrites the header array in
            // place, so sharing it would let one vector change the other's label.
            byte[] sourceHeader = ((GeneralVector)v).getHeader();
            head = (sourceHeader == null) ? null : sourceHeader.clone();
        }
        for(int i = 0 ; i < v.getSize(); i++){
            fids.add(v.getFId(i));
            weights.add(((GeneralVector)v).getWeight(i));
        }
    }

    /**
     * @return the text form {@code label fid:weight fid:weight ...} with weights
     *         printed with 4 decimals, independent of the default locale
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(String.valueOf(this.getLabel()));
        for(int i = 0 ; i < fids.size(); i++){
            sb.append(formatFeature(fids.getQuick(i), weights.getQuick(i)));
        }
        return sb.toString();
    }

    /**
     * @return fresh statistic collectors for this vector type: a
     *         {@link LabelStatistic} followed by a {@code BasicStatistic}
     */
    @Override
    public List<Statistic> getStat() {
        List<Statistic> stats = new ArrayList<Statistic>();
        stats.add(new LabelStatistic());
        stats.add(new BasicStatistic());
        return stats;
    }

    /**
     * Formats one feature as {@code " fid:weight"}. Locale.ROOT is used so the
     * decimal separator is always '.' and digits are ASCII; the default locale
     * could otherwise produce text that no longer matches the line format.
     */
    private static String formatFeature(int fid, float weight) {
        return String.format(Locale.ROOT, FEATURE_FORMAT, fid, weight);
    }
}