package ruc.irm.classification;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
 * Statistics accumulated from training instances for classification: the total
 * number of documents, the document count of every category, and one
 * {@link Feature} entry per word.
 *
 * @author xiatian
 *
 */
public class Variable {
    /** Per-category statistics, keyed by category name. */
    Map<String, CategoryInfo> categoryMap = new HashMap<String, CategoryInfo>();
    /** Per-word feature statistics, keyed by word. */
    Map<String, Feature> features = new HashMap<String, Feature>();
    /** Total number of documents added. */
    private int docCount = 0;

    /**
     * Serializes this object to {@code out}.
     *
     * <p>Layout: total document count; number of categories followed by each
     * (category name, {@link CategoryInfo}) pair; number of features followed by
     * each (word, {@link Feature}) pair. {@link #readFields(DataInput)} reads the
     * same layout back.
     *
     * @param out destination stream
     * @throws IOException if writing to {@code out} fails
     */
    public void write(DataOutput out) throws IOException {
        // total number of documents
        out.writeInt(docCount);
        // number of categories, then each category entry
        out.writeInt(categoryMap.size());
        for (Map.Entry<String, CategoryInfo> entry : categoryMap.entrySet()) {
            out.writeUTF(entry.getKey());
            entry.getValue().write(out);
        }
        // number of features, then each feature entry
        out.writeInt(features.size());
        for (Map.Entry<String, Feature> entry : features.entrySet()) {
            out.writeUTF(entry.getKey());
            entry.getValue().write(out);
        }
    }

    /**
     * Replaces this object's state with data read from {@code in}, in the
     * layout produced by {@link #write(DataOutput)}.
     *
     * @param in source stream
     * @throws IOException if reading from {@code in} fails
     */
    public void readFields(DataInput in) throws IOException {
        this.docCount = in.readInt();

        int size = in.readInt();
        categoryMap = new HashMap<String, CategoryInfo>();
        for (int i = 0; i < size; i++) {
            String category = in.readUTF();
            CategoryInfo info = CategoryInfo.read(in);
            categoryMap.put(category, info);
        }

        size = in.readInt();
        features = new HashMap<String, Feature>();
        for (int i = 0; i < size; i++) {
            String word = in.readUTF();
            Feature feature = Feature.read(in);
            features.put(word, feature);
        }
    }

    /**
     * Creates a new {@code Variable} and populates it from {@code in}.
     *
     * @param in source stream, in the layout produced by {@link #write(DataOutput)}
     * @return the deserialized instance
     * @throws IOException if reading from {@code in} fails
     */
    public static Variable read(DataInput in) throws IOException {
        Variable v = new Variable();
        v.readFields(in);
        return v;
    }

    /**
     * Returns the names of all known categories.
     *
     * <p>The result is a read-only view: it reflects categories added later, but
     * callers cannot remove categories through it (which would leave the
     * feature statistics inconsistent).
     *
     * @return unmodifiable view of the category names
     */
    public Collection<String> getCategories() {
        return Collections.unmodifiableSet(categoryMap.keySet());
    }

    /**
     * Returns the number of distinct features (words) seen so far.
     *
     * @return feature count
     */
    public int getFeatureCount() {
        return features.size();
    }

    /**
     * Tells whether the given word has been recorded as a feature.
     *
     * @param feature the word to look up
     * @return {@code true} if a feature entry exists for {@code feature}
     */
    public boolean containFeature(String feature) {
        return features.containsKey(feature);
    }

    /** Increments the total document count by one. */
    public void incDocCount() {
        this.docCount++;
    }

    /**
     * Returns the total number of documents added.
     *
     * @return total document count
     */
    public int getDocCount() {
        return this.docCount;
    }

    /**
     * Returns the number of documents in the specified category.
     *
     * @param category category name
     * @return document count for {@code category}, or 0 if the category is
     *     unknown (consistent with {@link #getDocCount(String, String)})
     */
    public int getDocCount(String category) {
        CategoryInfo info = categoryMap.get(category);
        // Unknown category: report zero instead of throwing NullPointerException.
        return info == null ? 0 : info.getDocCount();
    }

    /**
     * Returns the number of documents in the specified category in which the
     * given feature occurs.
     *
     * @param feature the feature word
     * @param category category name
     * @return the feature's document count for {@code category}, or 0 if the
     *     feature is unknown
     */
    public int getDocCount(String feature, String category) {
        Feature f = features.get(feature);
        if (f != null) {
            return f.getDocCount(category);
        }
        return 0;
    }

    /**
     * Adds one training instance. This increments the total document count and
     * the document count of the instance's category. For every word of the
     * instance, it also increments that feature's document count for the
     * category.
     *
     * @param instance the instance to add
     */
    public void addInstance(Instance instance) {
        incDocCount();

        String category = instance.getCategory();
        CategoryInfo info = categoryMap.get(category);
        if (info == null) {
            info = new CategoryInfo();
            categoryMap.put(category, info);
        }
        info.incDocCount();

        for (String word : instance.getWords()) {
            Feature feature = features.get(word);
            if (feature == null) {
                feature = new Feature();
                features.put(word, feature);
            }
            feature.setName(word);
            feature.incDocCount(category);
        }
    }

    /** Statistics for a single category: the number of documents it contains. */
    public static class CategoryInfo {
        /** Number of documents in this category. */
        private int docCount;

        /**
         * Returns the number of documents in this category.
         *
         * @return document count
         */
        public int getDocCount() {
            return docCount;
        }

        /** Increments the document count by one. */
        public void incDocCount() {
            this.docCount++;
        }

        /**
         * Sets the document count.
         *
         * @param docCount new document count
         */
        public void setDocCount(int docCount) {
            this.docCount = docCount;
        }

        /**
         * Writes the document count as a single int.
         *
         * @param out destination stream
         * @throws IOException if writing fails
         */
        public void write(DataOutput out) throws IOException {
            out.writeInt(docCount);
        }

        /**
         * Reads the document count written by {@link #write(DataOutput)}.
         *
         * @param in source stream
         * @throws IOException if reading fails
         */
        public void readFields(DataInput in) throws IOException {
            this.docCount = in.readInt();
        }

        /**
         * Creates a new {@code CategoryInfo} populated from {@code in}.
         *
         * @param in source stream
         * @return the deserialized instance
         * @throws IOException if reading fails
         */
        public static CategoryInfo read(DataInput in) throws IOException {
            CategoryInfo c = new CategoryInfo();
            c.readFields(in);
            return c;
        }
    }
}