package edu.fudan.ml.types.featurecluster;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.TreeMap;
/**
 * Greedy agglomerative clustering over a list of {@link ClassData} entries that keeps a fixed
 * number ({@code feasize}) of active slots.
 *
 * <p>The first {@code feasize} entries of {@code datalist} seed the slots. For every remaining
 * entry, {@link #process()} does three things:
 * <ul>
 *   <li>merges the two active slots with the smallest distance, as computed by the supplied
 *       {@link AbstractDistance};</li>
 *   <li>places the new entry in the freed slot;</li>
 *   <li>recomputes the distances of both affected slots.</li>
 * </ul>
 * {@link #getMap()} maps every entry key to the key of the cluster representative it ended up
 * in.
 *
 * <p>Side effects of construction: every entry's label is normalized in place so that it sums
 * to 1 (unless it sums to 0). The total of all entry counts is written to the static field
 * {@code ClassData.allCount}.
 */
public class ClusterFix extends AbstractCluster {
    /** Input entries; {@link #idList} stores indices into this list. */
    private final ArrayList<ClassData> datalist;
    /** Entry key -> key of the cluster representative the entry currently belongs to. */
    private final HashMap<Integer, Integer> map;
    /** Cluster representative key -> keys of all entries in that cluster. */
    private final HashMap<Integer, ArrayList<Integer>> mapList;
    /** Metric used to compare two entries. */
    private final AbstractDistance distance;
    /** Number of active slots, i.e. the target number of clusters. */
    private final int feasize;
    /** feasize x feasize matrix of distances between the entries occupying slots i and j. */
    private final ArrayList<ArrayList<Double>> distanceList;
    /** Slot index -> index into {@link #datalist} of the entry occupying that slot. */
    private final ArrayList<Integer> idList;
    /** Distance -> set of slot pairs encoded as "i$j" (i < j), ordered by ascending distance. */
    private final TreeMap<Double, Set<String>> sortmap;

    /**
     * Creates the clusterer; every entry starts in its own singleton cluster.
     *
     * @param datalist entries to cluster; their labels are normalized in place
     * @param distance metric used to compare two entries
     * @param feasize  number of clusters to keep; must be at least 2 whenever
     *                 {@code datalist} holds more than {@code feasize} entries
     */
    public ClusterFix(ArrayList<ClassData> datalist, AbstractDistance distance, int feasize) {
        this.datalist = datalist;
        this.distance = distance;
        this.feasize = feasize;
        map = new HashMap<Integer, Integer>();
        mapList = new HashMap<Integer, ArrayList<Integer>>();
        distanceList = new ArrayList<ArrayList<Double>>();
        idList = new ArrayList<Integer>();
        sortmap = new TreeMap<Double, Set<String>>();
        paraInit();
        setAllCount();
    }

    /**
     * Returns the live (mutable) mapping from entry key to cluster representative key.
     *
     * @return the map
     */
    public HashMap<Integer, Integer> getMap() {
        return map;
    }

    /** Puts every entry in its own singleton cluster and normalizes all labels. */
    private void paraInit() {
        for (ClassData cd : datalist) {
            int key = cd.getKey();
            map.put(key, key);
            ArrayList<Integer> alist = new ArrayList<Integer>();
            alist.add(key);
            mapList.put(key, alist);
        }
        regular();
    }

    /** Stores the sum of all entry counts in the static {@code ClassData.allCount}. */
    private void setAllCount() {
        int allCount = 0;
        for (ClassData cd : datalist) {
            allCount += cd.getCount();
        }
        ClassData.allCount = allCount;
    }

    /** Normalizes the label of every entry. */
    private void regular() {
        for (ClassData cd : datalist) {
            regular(cd);
        }
    }

    /**
     * Scales the entry's label in place so that its components sum to 1.
     * An all-zero label is left untouched rather than turned into NaNs.
     */
    private void regular(ClassData cd) {
        double[] label = cd.getLabel();
        double sum = 0;
        for (double ele : label) {
            sum += ele;
        }
        // 0/0 would yield NaN and poison every distance involving this entry.
        if (sum != 0) {
            for (int i = 0; i < label.length; i++) {
                label[i] = label[i] / sum;
            }
        }
        cd.setLabel(label);
    }

    /**
     * Merges the entry in slot {@code id2} into the entry in slot {@code id1}.
     *
     * <p>The merged label is the count-weighted average of both labels. The counts are summed.
     * All keys of id2's cluster are remapped to id1's cluster.
     *
     * @throws IllegalStateException if the cluster bookkeeping is inconsistent
     */
    private void merge(int id1, int id2) {
        ClassData cd1 = datalist.get(idList.get(id1));
        ClassData cd2 = datalist.get(idList.get(id2));
        int count1 = cd1.getCount();
        int count2 = cd2.getCount();
        // Sum in double so that large counts cannot overflow int.
        double total = (double) count1 + (double) count2;
        // Weight equally when both counts are 0 instead of producing a NaN ratio.
        double ratio = total == 0 ? 0.5 : count1 / total;
        double[] label = updateLabel(cd1.getLabel(), cd2.getLabel(), ratio);
        cd1.update(label, updateCount(count1, count2));
        mapKey(cd2.getKey(), cd1.getKey());
    }

    /** Returns the count of a merged entry. */
    private int updateCount(int count1, int count2) {
        return count1 + count2;
    }

    /** Returns {@code ratio * label1 + (1 - ratio) * label2}, component-wise. */
    private double[] updateLabel(double[] label1, double[] label2, double ratio) {
        int length = label1.length;
        double[] label = new double[length];
        for (int i = 0; i < length; i++) {
            label[i] = ratio * label1[i] + (1 - ratio) * label2[i];
        }
        return label;
    }

    /**
     * Moves every key of the cluster containing {@code orikey} into the cluster containing
     * {@code key}, then drops the now-empty source cluster.
     *
     * @throws IllegalStateException if either key, or the source cluster, is unknown
     */
    private void mapKey(int orikey, int key) {
        Integer orivalue = map.get(orikey);
        Integer value = map.get(key);
        if (orivalue == null) {
            throw new IllegalStateException("Unknown cluster key: " + orikey);
        }
        if (value == null) {
            throw new IllegalStateException("Unknown cluster key: " + key);
        }
        if (orivalue.intValue() == value.intValue()) {
            // Already in the same cluster. Proceeding would iterate a list while appending to it.
            return;
        }
        ArrayList<Integer> oriKeyList = mapList.remove(orivalue);
        ArrayList<Integer> keyList = mapList.get(value);
        if (oriKeyList == null || keyList == null) {
            throw new IllegalStateException(
                    "Missing member list for cluster " + (oriKeyList == null ? orivalue : value));
        }
        for (Integer temp : oriKeyList) {
            map.put(temp, value);
        }
        keyList.addAll(oriKeyList);
    }

    /** Puts entry {@code idd} of {@link #datalist} into slot {@code id}. */
    private void addClassData(int id, int idd) {
        idList.set(id, idd);
    }

    /** Seeds slot i with entry i and computes the full slot-to-slot distance matrix. */
    private void initDistanceAll() {
        for (int i = 0; i < feasize; i++) {
            ArrayList<Double> disId = new ArrayList<Double>(feasize);
            idList.add(i);
            for (int j = 0; j < feasize; j++) {
                disId.add(distance.cal(datalist.get(i), datalist.get(j)));
            }
            distanceList.add(disId);
        }
    }

    /**
     * Recomputes the distances between slot {@code id} and every other slot. Updates both the
     * matrix (symmetrically) and the sorted pair index.
     */
    private void updateDistance(int id) {
        ClassData cd = datalist.get(idList.get(id));
        for (int i = 0; i < idList.size(); i++) {
            if (i == id) {
                continue;
            }
            double distemp = distance.cal(cd, datalist.get(idList.get(i)));
            // The sorted index stores each pair once, with the smaller slot index first.
            if (i < id) {
                updateSortMap(i, id, getDistance(i, id), distemp);
            } else {
                updateSortMap(id, i, getDistance(id, i), distemp);
            }
            setDistance(id, i, distemp);
            setDistance(i, id, distemp);
        }
    }

    private void setDistance(int i, int j, double value) {
        distanceList.get(i).set(j, value);
    }

    /**
     * Returns a slot pair {i, j} (i < j) with the smallest indexed distance.
     *
     * @throws IllegalStateException if there is no pair to choose from
     */
    private int[] minIdMap() {
        if (sortmap.isEmpty()) {
            throw new IllegalStateException(
                    "No slot pairs available to merge (feasize=" + feasize + ")");
        }
        Set<String> closest = sortmap.get(sortmap.firstKey());
        return string2Id(closest.iterator().next());
    }

    /**
     * Finds a pair of distinct slots with the smallest distance by scanning the whole matrix.
     * The scan stops early at the first distance that is &lt;= 0, presumably because distances
     * are non-negative.
     *
     * <p>Only meaningful after {@link #process()} has built the distance matrix. It is not built
     * when no clustering was needed.
     *
     * @return {i, j}; {0, 0} if no pair was found
     */
    public int[] minId() {
        int[] id = new int[]{0, 0};
        double min = Double.MAX_VALUE;
        for (int i = 0; i < feasize; i++) {
            for (int j = 0; j < feasize; j++) {
                if (j == i) {
                    continue;
                }
                double temp = getDistance(i, j);
                if (temp < min) {
                    min = temp;
                    id[0] = i;
                    id[1] = j;
                }
                if (min <= 0) {
                    return id;
                }
            }
        }
        return id;
    }

    /** Indexes every unordered slot pair (i < j) by its current distance. */
    private void initSortMap() {
        for (int i = 0; i < feasize - 1; i++) {
            for (int j = i + 1; j < feasize; j++) {
                addSortMap(i, j, getDistance(i, j));
            }
        }
        System.out.println("Map init size: " + sortmap.size());
    }

    /** Re-indexes pair (a, b) from distance {@code oridata} to {@code updatedata}. */
    private void updateSortMap(int a, int b, double oridata, double updatedata) {
        deleteSortMap(a, b, oridata);
        addSortMap(a, b, updatedata);
    }

    /**
     * Removes pair (a, b) from the bucket for {@code oridata}. Drops the bucket only once it
     * is empty, so other pairs at the same distance are never lost.
     */
    private void deleteSortMap(int a, int b, double oridata) {
        Set<String> set = sortmap.get(oridata);
        if (set == null) {
            return;
        }
        set.remove(id2String(a, b));
        if (set.isEmpty()) {
            sortmap.remove(oridata);
        }
    }

    /** Adds pair (a, b) to the bucket for {@code updatedata}, creating the bucket if needed. */
    private void addSortMap(int a, int b, double updatedata) {
        Set<String> set = sortmap.get(updatedata);
        if (set == null) {
            set = new HashSet<String>();
            sortmap.put(updatedata, set);
        }
        set.add(id2String(a, b));
    }

    private double getDistance(int i, int j) {
        return distanceList.get(i).get(j);
    }

    /** Encodes a slot pair as "a$b". */
    private String id2String(int a, int b) {
        return a + "$" + b;
    }

    /** Decodes a slot pair produced by {@link #id2String(int, int)}. */
    private int[] string2Id(String s) {
        String[] sid = s.split("\\$");
        return new int[]{Integer.parseInt(sid[0]), Integer.parseInt(sid[1])};
    }

    /**
     * Runs the clustering. Afterwards {@link #getMap()} maps each entry key to one of
     * {@code feasize} cluster representatives.
     *
     * <p>If {@code datalist} has at most {@code feasize} entries, nothing is merged and no
     * distances are computed.
     *
     * @throws IllegalStateException if merging is required but {@code feasize < 2}, or if the
     *                               cluster bookkeeping becomes inconsistent
     */
    public void process() {
        // Check before initialization. Building the matrix first wastes O(feasize^2) distance
        // calls, and it fails with IndexOutOfBoundsException when feasize > datalist.size().
        if (feasize >= datalist.size()) {
            System.out.println("Do not need feature cluster");
            return;
        }
        if (feasize < 2) {
            throw new IllegalStateException(
                    "feasize must be at least 2 to merge clusters, got " + feasize);
        }
        initDistanceAll();
        initSortMap();
        System.out.println("Finish distance & sortmap init");
        for (int i = feasize; i < datalist.size(); i++) {
            // id[0] < id[1]: merge slot id[1] into id[0], then reuse slot id[1] for entry i.
            int[] id = minIdMap();
            merge(id[0], id[1]);
            addClassData(id[1], i);
            updateDistance(id[0]);
            updateDistance(id[1]);
            if ((i + 1) % 10000 == 0) {
                System.out.println(i);
                System.out.println(sortmap.size());
            }
        }
    }
}