package edu.fudan.data.reader;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import edu.fudan.ml.types.Instance;
import edu.fudan.util.UnicodeReader;
/**
* 读入序列标记的数据,目标值(若存在)在最后一列。列数必须一致
* 为了数据处理方便,内部的行列和文件里的行列翻转
* 格式为
* x1 y1
* x2 y2
*
* x3 y3
* x4 y4
* 每行数据用\t隔开, * 不同样本以空行分开
* 输出数据格式为:
* data:ArrayList<ArrayList<String>>
* target: ArrayList<String>
* @author xpqiu
*
*/
public class SequenceReader extends Reader {
BufferedReader reader;
Instance cur;
/**
* 默认包含目标值
*/
private boolean hasTarget = true;;
static final char delimiter = '\t';
/**
* 当前行号
*/
int lineNo=0;
/**
* 构造函数
* @param file 文件名
* @param hasTarget 是否包含目标值
*/
public SequenceReader(String file,boolean hasTarget) {
this(file, hasTarget,"UTF-8");
}
public SequenceReader(String file,boolean hasTarget, String charsetName) {
this.hasTarget = hasTarget;
try {
reader = new BufferedReader(new UnicodeReader(
new FileInputStream(file), charsetName));
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}
public SequenceReader(InputStream is) {
reader = new BufferedReader(new UnicodeReader(
is,null));
}
public boolean hasNext() {
cur = readSequence();
return (cur != null);
}
public Instance next() {
return cur;
}
private Instance readSequence() {
cur = null;
try {
ArrayList<ArrayList<String>> seq = new ArrayList<ArrayList<String>>();
ArrayList<String> first = new ArrayList(); //至少有一列元素
seq.add(first);
ArrayList<String> labels = null;
if(hasTarget){
labels = new ArrayList<String>();
}
String content = null;
while ((content = reader.readLine()) != null) {
lineNo++;
// content = content.trim();
if (content.matches("^$")){
if(first.size()>0) //第一列个数>0
break;
else
continue;
}
int colsnum = 0;
int start = 0;
int next =0;
while ((next = content.indexOf(delimiter, next)) != -1) {
if(next==start){ //防止字符和分隔符相同
next++;
continue;
}
ensure(colsnum,seq);
seq.get(colsnum).add(content.substring(start,next));
next++;
colsnum++;
start = next;
}
//处理最后一列
if(hasTarget){
if(start<2){
System.out.println("数据格式错误,只有一列,请检查!");
System.out.println("第"+lineNo+"行");
continue;
}
labels.add(content.substring(start));
}else{
ensure(colsnum,seq);
seq.get(colsnum).add(content.substring(start));
}
}
if (first.size() > 0){
cur = new Instance(seq, labels);
}
seq = null;
labels = null;
} catch (IOException e) {
e.printStackTrace();
}
return cur;
}
private void ensure(int colsnum, ArrayList<ArrayList<String>> seq) {
while(colsnum>=seq.size()){
seq.add(new ArrayList<String>());
}
}
public static void main(String[] args) {
SequenceReader sr = new SequenceReader("example-data/sequence/train.txt",true);
// SequenceReader sr = new SequenceReader("example-data/sequence/test0.txt",false);
Instance inst = null;
int count = 0;
while (sr.hasNext()) {
inst = sr.next();
System.out.print(".");
inst = null;
count++;
}
System.out.println(count);
}
}