package edu.fudan.ml.classifier.hier;
import edu.fudan.ml.types.alphabet.LabelAlphabet;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Hierarchy structure (described by the original author as a DAG),
 * represented as nodes plus edges.
 * @author xpqiu
 * @version 1.0
 * Tree
 * package edu.fudan.ml.classifier.hier
 */
public class Tree implements Serializable {

    private static final long serialVersionUID = 5846146204699950799L;

    /** Number of distinct nodes registered so far. */
    public int size=0;

    /** Number of distinct levels found by the last call to {@link #travel()}. */
    private int depth=0;

    /**
     * Root-to-node paths built by {@link #CalcPath()}; {@code treepath[i]} ends with {@code i}.
     * The array is indexed by node id and has {@link #size} rows, so it is only
     * meaningful when node ids are exactly {@code 0..size-1}.
     */
    int[][] treepath;

    /** Node ids in the order in which they were first seen. */
    List<Integer> nodes = new ArrayList<Integer>();

    /** Ids of the nodes that have no children. */
    TIntSet leafs = new TIntHashSet();

    /**
     * Parent node -> set of its child nodes.
     */
    HashMap<Integer,Set<Integer>> edges = new HashMap<Integer,Set<Integer>>();

    /**
     * Child node -> parent node. Only one parent is kept per child: a later
     * edge into the same child overwrites the earlier parent.
     */
    HashMap<Integer,Integer> edgesInv = new HashMap<Integer,Integer>();

    /**
     * Level (root = 0) -> set of node ids on that level.
     */
    HashMap<Integer,Set<Integer>> hier = new HashMap<Integer, Set<Integer>>();

    /**
     * Builds the hierarchy from the labels of an alphabet. Every label is
     * treated as a path whose components are joined by {@code sep}; each
     * prefix ending right before a separator is looked up in the alphabet
     * and chained below the label "Root".
     *
     * @param la  alphabet providing the labels and their indices
     * @param sep separator between the levels inside a label
     */
    public Tree(LabelAlphabet la,String sep){
        Map<String, Integer> map = la.toMap();
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            String key = entry.getKey();
            int value = entry.getValue();
            int idx = key.indexOf(sep, 0);
            int plabel = la.lookupIndex("Root");
            while(idx!=-1){
                String label = key.substring(0,idx);
                int clabel = la.lookupIndex(label);
                addEdge(plabel, clabel);
                plabel = clabel;
                idx = key.indexOf(sep, idx+1);
            }
            // an edge must not point to its own source
            if(plabel!=value) {
                addEdge(plabel, value);
            }
        }
        travel();
    }

    /**
     * Creates an empty tree; fill it with one of the {@code loadFrom...} methods.
     */
    public Tree() {
    }

    /**
     * Returns the node id stored at position {@code i} of the insertion-ordered node list.
     *
     * @param i position in the node list
     * @return the node id at that position
     */
    public Integer getNode(int i) {
        return nodes.get(i);
    }

    /**
     * Loads the hierarchy from a UTF-8 file in which every line is one edge:
     * {@code parentLabel childLabel}, separated by a single space. Blank lines are skipped.
     *
     * @param file     path of the edge file
     * @param alphabet alphabet used to turn labels into node ids
     * @throws IOException if the file cannot be read
     */
    public void loadFromFileWithEdge(String file, LabelAlphabet alphabet) throws IOException {
        BufferedReader reader = openReader(file);
        try {
            String line;
            while((line=reader.readLine())!=null){
                if (isBlank(line)) {
                    continue;
                }
                String[] tok = line.split(" ");
                addEdge(alphabet.lookupIndex(tok[0]),alphabet.lookupIndex(tok[1]));
            }
        } finally {
            reader.close();
        }
        travel();
    }

    /**
     * Loads the hierarchy from a UTF-8 file in which every line is one path:
     * space-separated labels, each one the parent of the next.
     *
     * @param file     path of the path file
     * @param alphabet alphabet used to turn labels into node ids
     * @throws IOException if the file cannot be read
     */
    public void loadFromFileWithPath(String file, LabelAlphabet alphabet) throws IOException {
        BufferedReader reader = openReader(file);
        try {
            String line;
            while((line=reader.readLine())!=null){
                String[] tok = line.split(" ");
                for(int i=0;i<tok.length-1;i++){
                    addEdge(alphabet.lookupIndex(tok[i]),alphabet.lookupIndex(tok[i+1]));
                }
            }
        } finally {
            reader.close();
        }
        travel();
    }

    /**
     * Loads the hierarchy from a UTF-8 file in which every line is one edge
     * given as two integer node ids: {@code parent child}, separated by a
     * single space. Blank lines are skipped.
     *
     * @param file path of the edge file
     * @throws IOException if the file cannot be read
     */
    public void loadFromFile(String file) throws IOException{
        BufferedReader reader = openReader(file);
        try {
            String line;
            while((line=reader.readLine())!=null){
                if (isBlank(line)) {
                    continue;
                }
                String[] tok = line.split(" ");
                addEdge(Integer.parseInt(tok[0]),Integer.parseInt(tok[1]));
            }
        } finally {
            reader.close();
        }
        travel();
    }

    /** Opens {@code file} as a buffered UTF-8 reader; the caller must close it. */
    private static BufferedReader openReader(String file) throws IOException {
        FileInputStream in = new FileInputStream(new File(file));
        try {
            return new BufferedReader(new InputStreamReader(in, "UTF-8"));
        } catch (IOException e) {
            // do not leak the stream if the reader could not be created
            in.close();
            throw e;
        }
    }

    /** Tells whether {@code line} holds nothing but whitespace. */
    private static boolean isBlank(String line) {
        return line.trim().length() == 0;
    }

    /**
     * Returns the set of leaf node ids. The returned set is the internal one, not a copy.
     *
     * @return ids of all nodes without children
     */
    public TIntSet getLeafs(){
        return leafs;
    }

    /**
     * Collects the per-level node sets, the leaf set and the depth, then
     * rebuilds the root-to-node paths. Results of an earlier run are discarded
     * so that loading additional edges does not leave stale entries behind.
     */
    private void travel() {
        hier.clear();
        leafs.clear();
        for(int i=0;i<nodes.size();i++){
            // work with the node id itself, not with its position in the list
            int node = nodes.get(i);
            int level = getLevel(node);
            Set<Integer> sameLevel = hier.get(level);
            if(sameLevel==null){
                sameLevel = new HashSet<Integer>();
                hier.put(level, sameLevel);
            }
            sameLevel.add(node);
            if(edges.get(node)==null){
                leafs.add(node);
            }
        }
        depth = hier.size();
        CalcPath();
    }

    /**
     * Returns the level of a node, i.e. the number of parent links between it
     * and the top of its chain; a node without parent is on level 0.
     *
     * @param i node id
     * @return level of the node
     */
    private int getLevel(int i) {
        int n=0;
        Integer j=i;
        while((j=edgesInv.get(j))!=null){
            n++;
        }
        return n;
    }

    /**
     * Registers the edge {@code i -> j}, adding either node if it is new.
     * Self-edges are ignored: a node that is its own parent would make every
     * walk along the parent links run forever.
     *
     * @param i parent node
     * @param j child node
     */
    private void addEdge(int i, int j) {
        if(i==j){
            return;
        }
        if(!nodes.contains(i)){
            nodes.add(i);
            size++;
        }
        Set<Integer> children = edges.get(i);
        if(children==null){
            children = new HashSet<Integer>();
            edges.put(i, children);
        }
        if(!nodes.contains(j)){
            nodes.add(j);
            size++;
        }
        edgesInv.put(j, i);
        children.add(j);
    }

    /**
     * Small manual check: loads an integer edge file and prints the node count
     * and the number of levels.
     *
     * @param args optional: path of the edge file; a built-in default path is used when absent
     * @throws IOException if the file cannot be read
     */
    public static void main(String[] args) throws IOException{
        String file = args.length>0 ? args[0] : "D:/Datasets/wipo/e.txt";
        Tree t = new Tree();
        t.loadFromFile(file);
        System.out.println(t.size);
        System.out.println(t.hier.size());
        t.dist(5, 6);
    }

    /**
     * Stores, for every id in {@code 0..size-1}, the path from the top of its
     * parent chain down to the id itself.
     */
    public void CalcPath() {
        treepath = new int[size][];
        for(int i=0;i<size;i++){
            // collect bottom-up: the node first, then its ancestors
            TIntArrayList list= new TIntArrayList ();
            list.add(i);
            Integer j=i;
            while((j=edgesInv.get(j))!=null){
                list.add(j);
            }
            // reverse so that the stored path runs top-down
            int s = list.size();
            treepath[i] = new int[s];
            for(int k=0;k<s;k++){
                treepath[i][s-k-1] = list.get(k);
            }
        }
    }

    /**
     * Returns the stored top-down path of a node. The returned array is the internal one, not a copy.
     *
     * @param i node id, must be in {@code 0..size-1}
     * @return the path ending with {@code i}
     */
    public int[] getPath(int i) {
        return treepath[i];
    }

    /**
     * Returns the ancestors of a node, nearest first (parent, grandparent, ...).
     *
     * @param i node id
     * @return list of ancestor ids; empty when the node has no parent
     */
    public ArrayList<Integer> getAnc(Integer i) {
        ArrayList<Integer> list= new ArrayList<Integer> ();
        Integer j=i;
        while((j=edgesInv.get(j))!=null){
            list.add(j);
        }
        return list;
    }

    /**
     * Same as {@link #getAnc(Integer)}, but returned as a primitive array.
     *
     * @param i node id
     * @return ancestor ids, nearest first
     */
    public int[] getAncIdx(Integer i) {
        ArrayList<Integer> list = getAnc(i);
        int[] idx = new int[list.size()];
        for(int j=0;j<list.size();j++){
            idx[j] = list.get(j);
        }
        return idx;
    }

    /**
     * Computes the path distance between two nodes from their stored paths.
     * With {@code k} being the length of the common prefix of both paths, the
     * result is {@code len(i) + len(j) - 2*k + 1}. For two nodes below a
     * common root this is the number of nodes on the path between them
     * (edge count plus one), so the distance of a node to itself is 1.
     *
     * @param i first node id, must be in {@code 0..size-1}
     * @param j second node id, must be in {@code 0..size-1}
     * @return the distance value
     */
    public int dist(int i, int j) {
        int[] anci = treepath[i];
        int[] ancj = treepath[j];
        int k=0;
        for(;k<Math.min(ancj.length, anci.length);k++){
            if(anci[k]!=ancj[k])
                break;
        }
        int d = anci.length+ancj.length-2*k+1;
        return d;
    }

    /**
     * Returns the number of distinct levels found when the tree was last traversed.
     *
     * @return depth of the hierarchy
     */
    public int getDepth() {
        return depth;
    }
}