package edu.fudan.ml.types;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import edu.fudan.data.reader.Reader;
import edu.fudan.ml.types.alphabet.AlphabetFactory;
import edu.fudan.nlp.pipe.Pipe;
import edu.fudan.nlp.pipe.SeriesPipes;
/**
 * A collection of samples ({@link Instance} objects).
 *
 * <p>Besides the list of instances, the set carries the default pipe used to
 * convert raw data into instances and the alphabet factory that manages the
 * feature/label index dictionaries. Both are handed on to every set derived
 * from this one ({@link #split(float)}, {@link #randomSplit(float)},
 * {@link #subSet(int, int)}).
 *
 * @author xpqiu
 */
public class InstanceSet extends ArrayList<Instance> {

    private static final long serialVersionUID = 3449458306217680806L;

    /** Default data-type conversion pipe of this sample set; may be {@code null}. */
    private Pipe pipes = null;

    /** Manager of the feature and label index dictionaries that belong to this sample set. */
    private AlphabetFactory factory = null;

    /** Number of features; this class never reads or updates it itself. */
    public int numFeatures = 0;

    /** Free-form name of the set; this class never reads or updates it itself. */
    public String name = "";

    /**
     * Creates an empty set that uses the given pipe.
     *
     * @param pipes default conversion pipe
     */
    public InstanceSet(Pipe pipes) {
        this.pipes = pipes;
    }

    /**
     * Creates an empty set that uses the given pipe and alphabet factory.
     *
     * @param pipes   default conversion pipe
     * @param factory alphabet (dictionary) manager
     */
    public InstanceSet(Pipe pipes, AlphabetFactory factory) {
        this.pipes = pipes;
        this.factory = factory;
    }

    /**
     * Creates an empty set that uses the given alphabet factory and no pipe.
     *
     * @param factory alphabet (dictionary) manager
     */
    public InstanceSet(AlphabetFactory factory) {
        this.factory = factory;
    }

    /** Creates an empty set without pipe or alphabet factory. */
    public InstanceSet() {
    }

    /**
     * Splits the set randomly into two sets whose sizes are about
     * {@code i/n} and {@code (n-i)/n} of this set.
     *
     * @param i numerator of the proportion that goes to the first set
     * @param n denominator of the proportion (total, relative to {@code i})
     * @return array of two sets; see {@link #split(float)}
     */
    public InstanceSet[] split(int i, int n) {
        return split((float) i / (float) n);
    }

    /**
     * Splits the set randomly into two sets. This set is shuffled in place
     * first; the leading {@code round(percent * size())} instances go to the
     * first result set and all remaining instances go to the second one.
     *
     * @param percent proportion of the first set; must lie in [0, 1]. Other
     *                values are not checked here and make the underlying
     *                {@code subList} call fail.
     * @return array of two sets sharing this set's pipe and alphabet factory
     */
    public InstanceSet[] split(float percent) {
        shuffle();
        int length = this.size();
        InstanceSet[] sets = new InstanceSet[2];
        sets[0] = new InstanceSet(pipes, factory);
        sets[1] = new InstanceSet(pipes, factory);
        int idx = Math.round(percent * length);
        sets[0].addAll(subList(0, idx));
        // subList's upper bound is exclusive, so the second part has to start at
        // idx itself; starting at idx + 1 would silently drop one instance.
        sets[1].addAll(subList(idx, length));
        return sets;
    }

    /**
     * Splits the set into two sets class by class: within every target class
     * a random share of about {@code percent} of the instances goes to the
     * first set and the rest to the second. The order of this set is not
     * changed and instances keep their relative order in the results.
     *
     * @param percent proportion of each class assigned to the first set
     * @return array of two sets sharing this set's pipe and alphabet factory
     * @throws Exception (an {@link IllegalArgumentException}) if
     *                   {@code percent} is NaN or outside [0, 1]
     */
    public InstanceSet[] randomSplit(float percent) throws Exception {
        // Written as a negated range test so that NaN is rejected as well.
        if (!(percent >= 0 && percent <= 1))
            throw new IllegalArgumentException("Percent should be in [0, 1]");
        InstanceSet[] sets = new InstanceSet[2];
        sets[0] = new InstanceSet(pipes, factory);
        sets[1] = new InstanceSet(pipes, factory);
        int[] flag = labelFlag();
        List<ArrayList<Integer>> list = listLabel(flag);
        flag = randomSet(flag, list, percent);
        for (int i = 0; i < flag.length; i++) {
            // randomSet marks the chosen positions with a negative value.
            if (flag[i] < 0)
                sets[0].add(this.get(i));
            else
                sets[1].add(this.get(i));
        }
        return sets;
    }

    /**
     * Randomly marks a share of the positions of every group by setting the
     * corresponding entry of {@code flag} to {@code -1}.
     *
     * <p>For each group the quota is {@code round(group size * percent)}, but
     * at least one position of every non-empty group is marked even when the
     * rounded quota is zero (behaviour kept from the original implementation).
     * Positions whose flag is already negative are never counted. The method
     * always terminates: if fewer unmarked positions exist than the quota
     * asks for, all of them are marked.
     *
     * @param flag    per-position markers, modified in place; non-negative
     *                means "not selected"
     * @param list    groups of positions (indices into {@code flag})
     * @param percent share of each group to mark
     * @return the same {@code flag} array
     */
    public int[] randomSet(int[] flag, List<ArrayList<Integer>> list, float percent) {
        Random r = new Random();
        for (ArrayList<Integer> alist : list) {
            if (alist.isEmpty())
                continue;
            int quota = Math.max(1, Math.round(alist.size() * percent));
            // Shuffling a copy and walking it once picks a uniformly random
            // subset without the unbounded retry loop of rejection sampling.
            List<Integer> candidates = new ArrayList<Integer>(alist);
            Collections.shuffle(candidates, r);
            int count = 0;
            for (int index : candidates) {
                if (flag[index] >= 0) {
                    flag[index] = -1;
                    count++;
                    if (count >= quota)
                        break;
                }
            }
        }
        return flag;
    }

    /**
     * Groups positions by their class label.
     *
     * @param flag class label of every position, as produced by
     *             {@link #labelFlag()}
     * @return one list per class (indexed by label) holding the positions
     *         that carry this label, in ascending order
     */
    public List<ArrayList<Integer>> listLabel(int[] flag) {
        int classCount = classSize().size();
        List<ArrayList<Integer>> list = new ArrayList<ArrayList<Integer>>(classCount);
        for (int i = 0; i < classCount; i++) {
            list.add(new ArrayList<Integer>());
        }
        for (int i = 0; i < flag.length; i++) {
            list.get(flag[i]).add(i);
        }
        return list;
    }

    /**
     * Computes the class label of every instance of this set.
     *
     * @return array with one entry per instance holding the label that
     *         {@link #classSize()} assigns to the instance's target
     */
    public int[] labelFlag() {
        int length = this.size();
        int[] flag = new int[length];
        Map<Object, Integer> map = classSize();
        for (int i = 0; i < length; i++) {
            Object target = this.get(i).getTarget();
            flag[i] = map.get(target);
        }
        return flag;
    }

    /**
     * Assigns consecutive integer labels to the distinct targets of this set.
     * Despite its name the map does not contain class sizes: every distinct
     * target is mapped to a label 0, 1, 2, ... in order of first appearance.
     *
     * @return map from target object to its integer label
     */
    public Map<Object, Integer> classSize() {
        Map<Object, Integer> map = new HashMap<Object, Integer>();
        int label = 0;
        for (Instance ins : this) {
            Object target = ins.getTarget();
            if (!map.containsKey(target)) {
                map.put(target, label++);
            }
        }
        return map;
    }

    /**
     * Copies a contiguous range of this set into a new set.
     *
     * @param from index of the first instance (inclusive)
     * @param end  index after the last instance (exclusive)
     * @return new set with the selected instances and this set's pipe and
     *         alphabet factory
     */
    public InstanceSet subSet(int from, int end) {
        InstanceSet set = new InstanceSet(pipes, factory);
        set.addAll(subList(from, end));
        return set;
    }

    /**
     * Builds the set from a data reader: every instance delivered by the
     * reader is passed through this set's default pipe (if one is set) and
     * then added.
     *
     * @param reader data reader
     * @throws Exception propagated from the reader or the pipe
     */
    public void loadThruPipes(Reader reader) throws Exception {
        while (reader.hasNext()) {
            Instance inst = reader.next();
            if (pipes != null)
                pipes.addThruPipe(inst);
            this.add(inst);
        }
    }

    /**
     * Processes the data stage by stage: the first pipe of the series is
     * applied to each instance while reading, afterwards every further pipe
     * processes the complete set before the next one starts. Instances that
     * the reader delivers as {@code null} are skipped.
     *
     * <p>The default pipe of this set must be a {@link SeriesPipes}.
     *
     * @param reader data reader
     * @throws Exception propagated from the reader or the pipes
     */
    public void loadThruStagePipes(Reader reader) throws Exception {
        SeriesPipes p = (SeriesPipes) pipes;
        Pipe p1 = p.getPipe(0);
        while (reader.hasNext()) {
            Instance inst = reader.next();
            if (inst != null) {
                if (p1 != null)
                    p1.addThruPipe(inst);
                this.add(inst);
            }
        }
        for (int i = 1; i < p.size(); i++)
            p.getPipe(i).process(this);
    }

    /**
     * Experimental, for MultiCorpus; ignore for production use.
     *
     * <p>Works like {@link #loadThruStagePipes(Reader)} but reads from several
     * readers in turn and tags every instance with the name of the corpus it
     * came from.
     *
     * @param readers     data readers, one per corpus
     * @param corpusNames corpus names; {@code corpusNames[i]} belongs to
     *                    {@code readers[i]}
     * @throws Exception propagated from the readers or the pipes
     */
    public void loadThruStagePipesForMultiCorpus(Reader[] readers, String[] corpusNames) throws Exception {
        SeriesPipes p = (SeriesPipes) pipes;
        Pipe p1 = p.getPipe(0);
        for (int i = 0; i < readers.length; i++) {
            while (readers[i].hasNext()) {
                Instance inst = readers[i].next();
                if (inst != null) {
                    // Tag only after the null check; tagging first would throw
                    // a NullPointerException for null instances.
                    inst.setClasue(corpusNames[i]);
                    if (p1 != null)
                        p1.addThruPipe(inst);
                    this.add(inst);
                }
            }
        }
        for (int i = 1; i < p.size(); i++)
            p.getPipe(i).process(this);
    }

    /** Shuffles the instances of this set in place. */
    public void shuffle() {
        Collections.shuffle(this);
    }

    /**
     * Sorts the instances in place by descending weight. Weights are compared
     * with {@link Float#compare(float, float)}, which yields a consistent
     * total order even if a weight is NaN.
     */
    public void sortByWeights() {
        Collections.sort(this, new Comparator<Instance>() {
            @Override
            public int compare(Instance o1, Instance o2) {
                // Arguments swapped on purpose: larger weights come first.
                return Float.compare(o2.getWeight(), o1.getWeight());
            }
        });
    }

    /**
     * Shuffles the instances of this set in place using the given source of
     * randomness.
     *
     * @param r random number generator
     */
    public void shuffle(Random r) {
        Collections.shuffle(this, r);
    }

    /**
     * @return the default conversion pipe of this set, possibly {@code null}
     */
    public Pipe getPipes() {
        return pipes;
    }

    /**
     * Returns the instance at the given position without throwing on an
     * invalid index.
     *
     * @param idx position of the instance
     * @return the instance, or {@code null} if {@code idx} is negative or not
     *         smaller than the size of the set
     */
    public Instance getInstance(int idx) {
        if (idx < 0 || idx >= this.size())
            return null;
        return this.get(idx);
    }

    /**
     * @return the alphabet factory of this set, possibly {@code null}
     */
    public AlphabetFactory getAlphabetFactory() {
        return factory;
    }

    /**
     * Replaces the default conversion pipe.
     *
     * @param pipes new default pipe
     */
    public void setPipes(Pipe pipes) {
        this.pipes = pipes;
    }

    /**
     * Replaces the alphabet factory.
     *
     * @param factory new alphabet factory
     */
    public void setAlphabetFactory(AlphabetFactory factory) {
        this.factory = factory;
    }

    /**
     * @return the string forms of all instances, each followed by a line feed
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < size(); i++) {
            sb.append(get(i));
            sb.append("\n");
        }
        return sb.toString();
    }
}