package org.shanbo.feluca.data2; import gnu.trove.list.array.TFloatArrayList; import gnu.trove.list.array.TIntArrayList; import java.io.IOException; import java.util.ArrayList; import java.util.List; import org.apache.commons.lang3.text.StrBuilder; import org.msgpack.packer.Packer; import org.msgpack.unpacker.Unpacker; import org.shanbo.feluca.data2.DataSetInfo.Statistic; import org.shanbo.feluca.data2.util.NumericTokenizer; import org.shanbo.feluca.data2.util.NumericTokenizer.FeatureWeight; public abstract class Vector { public enum VectorType{ LABEL_FID_WEIGHT, VID_FID_WEIGHT, NUMBER_FID_WEIGHT, } protected TIntArrayList fids; protected VectorType inputType; //use only for convert or build protected VectorType outputType; // public abstract List<Statistic> getStat(); public abstract void pack(Packer packer) throws IOException; public abstract void unpack(Unpacker unpacker) throws IOException; public abstract boolean parseLine(String line); protected abstract boolean readObject(Object... values); public abstract String toString(); public abstract List<Vector> divideByFeature(HashPartitioner partitioner); public abstract void swallow(Vector v); public abstract int getSpaceCost(); public VectorType getOutVectorType(){ return this.outputType; } public void setOutputType(VectorType outputType){ this.outputType = outputType; } public int getSize(){ return fids.size(); } @Deprecated public long getLongHeader(){ return 0; } @Deprecated public int getIntHeader(){ return 0; } public int getFId(int idx){ return fids.getQuick(idx); } @Deprecated public float getWeight(int idx){ return 0; } @Deprecated public byte[] getBytesPayload(int idx){ return new byte[]{}; } @Deprecated public int getIntPayload(int idx){ return 0; } @Deprecated public long getLongPayload(int idx){ return 0l; } @Deprecated public float getFloatPayload(int idx) { return 0.0f; } public static Vector create(VectorType vt){ Vector v; if (vt == VectorType.LABEL_FID_WEIGHT){ v = new LWVector(); }else { v = new VIDVector(); } return v; } public static Vector create(VectorType vt, String line){ Vector v = create(vt); v.parseLine(line); return v; } public static Vector create(VectorType vt, Unpacker unpacker) throws IOException{ Vector v = create(vt); v.unpack(unpacker); return v; } public static Vector create(VectorType vt, Object... values) throws IOException{ Vector v = create(vt); v.readObject(values); return v; } public double getDoublePayload(int idx) { return 0.0; } public static class LWVector extends Vector{ int label ; TFloatArrayList weights; public LWVector(){ this.inputType = VectorType.LABEL_FID_WEIGHT; this.outputType = VectorType.LABEL_FID_WEIGHT; } @Override public void pack(Packer packer) throws IOException { packer.write(label); packer.write(this.fids.toArray()); packer.write(this.weights.toArray()); } @Override public void unpack(Unpacker unpacker) throws IOException { this.label = unpacker.readInt(); this.fids = new TIntArrayList(unpacker.read(int[].class)); this.weights = new TFloatArrayList(unpacker.read(float[].class)); } public int getIntHeader(){ return label; } public float getWeight(int idx){ return weights.getQuick(idx); } @Override public boolean parseLine(String line) { if (fids == null){ //don't know fids = new TIntArrayList(1024); weights = new TFloatArrayList(1024); }else{ fids.resetQuick(); weights.resetQuick(); } NumericTokenizer nt = new NumericTokenizer(); nt.load(line); this.label = (Integer)(nt.nextNumber()); while(nt.hasNext()){ // long kv = nt.nextKeyValuePair(); FeatureWeight nextKeyWeight = nt.nextKeyWeight(); // int fid = nextKeyWeight.getId(); // float weight = NumericTokenizer.extractWeight(kv); fids.add(nextKeyWeight.getId()); weights.add(nextKeyWeight.getWeight()); } if (fids.size() == 0){ return false; } return true; } @Override public String toString() { StringBuilder sb = new StringBuilder(label + ""); for(int i = 0 ; i < fids.size(); i++){ sb.append(String.format(" %d:%.4f", fids.getQuick(i), weights.getQuick(i) )); } return sb.toString(); } @Override public int getSpaceCost() { return 4 + (fids.size() << 3 ) ; //label + id[] * 4 + weight[] * 4 } @Override protected boolean readObject(Object... values) { label = (Integer)values[0]; fids = (TIntArrayList)values[1]; weights = (TFloatArrayList)values[2]; return true; } public List<Vector> divideByFeature(HashPartitioner partitioner) { List<Vector> vectors = new ArrayList<Vector>(partitioner.getMaxShards()); List<StrBuilder> lines = new ArrayList<StrBuilder>(partitioner.getMaxShards()); for(int i = 0 ; i < partitioner.getMaxShards(); i++){ vectors.add(create(getOutVectorType())); lines.add(new StrBuilder().append(getIntHeader())); //label } for(int i = 0 ; i < getSize(); i++){ int shardId = partitioner.decideShard(getFId(i)); lines.get(shardId).append(String.format(" %d:%.4f", fids.getQuick(i), weights.getQuick(i) )); } for(int i = 0; i < vectors.size(); i++){ vectors.get(i).parseLine(lines.get(i).toString()); } return vectors; } @Override public void swallow(Vector v) { if (v == null) return; if (fids == null){ fids = new TIntArrayList(); weights = new TFloatArrayList(); label = v.getIntHeader(); } for(int i = 0 ; i < v.getSize(); i++){ fids.add(v.getFId(i)); weights.add(v.getWeight(i)); } } @Override public List<Statistic> getStat() { // TODO Auto-generated method stub return null; } } public static class VIDVector extends LWVector{ } }