package edu.fudan.ml.types.featurecluster;
import java.util.ArrayList;
import java.util.HashMap;
/**
 * Greedy bottom-up clustering of {@link ClassData} entries.
 *
 * <p>Every entry starts as its own cluster. {@link #process()} repeatedly
 * picks the pair of active entries with the smallest distance (as computed
 * by the supplied {@link AbstractDistance}), merges the second into the
 * first, and continues until only {@code feasize} active entries remain.
 *
 * <p>Construction has side effects on its inputs: each entry's label vector
 * is rescaled so that it sums to one, and the static
 * {@code ClassData.allCount} is set to the sum of all entry counts.
 */
public class ClusterOri extends AbstractCluster {
    /** Entries being clustered; merged-away entries stay in the list but are flagged inactive. */
    private final ArrayList<ClassData> datalist;
    /** Maps every original key to the key of the cluster it currently belongs to. */
    private final HashMap<Integer, Integer> map;
    /** Maps a cluster key to the list of original keys that belong to that cluster. */
    private final HashMap<Integer, ArrayList<Integer>> mapList;
    /** {@code flag.get(i)} is true while {@code datalist.get(i)} is still an active cluster. */
    private final ArrayList<Boolean> flag;
    private final AbstractDistance distance;
    /** Pairwise distances: row i, column j holds {@code distance.cal(datalist[i], datalist[j])}. */
    private final ArrayList<ArrayList<Double>> distanceList;
    /** Number of clusters that should remain once {@link #process()} finishes. */
    private final int feasize;

    /**
     * Creates the clusterer, normalizes every label vector in place and
     * publishes the total count to {@code ClassData.allCount}.
     *
     * @param datalist entries to cluster; the entries are modified in place
     * @param distance distance measure used to pick the pair to merge
     * @param feasize  number of clusters to keep
     */
    public ClusterOri(ArrayList<ClassData> datalist, AbstractDistance distance, int feasize) {
        this.datalist = datalist;
        this.distance = distance;
        this.feasize = feasize;
        distanceList = new ArrayList<ArrayList<Double>>();
        flag = new ArrayList<Boolean>();
        map = new HashMap<Integer, Integer>();
        mapList = new HashMap<Integer, ArrayList<Integer>>();
        paraInit();
        setAllCount();
    }

    /**
     * @return the live (not copied) map from each original key to the key of
     *         the cluster it currently belongs to
     */
    public HashMap<Integer, Integer> getMap() {
        return map;
    }

    /** Registers every entry as its own singleton cluster and normalizes the labels. */
    private void paraInit() {
        for (ClassData cd : datalist) {
            int key = cd.getKey();
            map.put(key, key);
            ArrayList<Integer> members = new ArrayList<Integer>();
            members.add(key);
            mapList.put(key, members);
            flag.add(true);
        }
        regular();
    }

    /** Stores the sum of all entry counts in the static {@code ClassData.allCount}. */
    private void setAllCount() {
        int allCount = 0;
        for (ClassData cd : datalist) {
            allCount += cd.getCount();
        }
        ClassData.allCount = allCount;
    }

    /** Normalizes the label vector of every entry. */
    private void regular() {
        for (ClassData cd : datalist) {
            regular(cd);
        }
    }

    /**
     * Rescales the label vector of {@code cd} so that its elements sum to one.
     * A vector whose elements sum to zero is left untouched, because dividing
     * by zero would fill it with NaN and poison every later distance.
     */
    private void regular(ClassData cd) {
        double[] label = cd.getLabel();
        double sum = 0;
        for (double ele : label) {
            sum += ele;
        }
        if (sum == 0) {
            return;
        }
        for (int i = 0; i < label.length; i++) {
            label[i] = label[i] / sum;
        }
        cd.setLabel(label);
    }

    /**
     * Merges entry {@code id2} into entry {@code id1}: the surviving entry gets
     * the summed count and the count-weighted average label, the key maps are
     * redirected, and {@code id2} is flagged inactive.
     */
    private void merge(int id1, int id2) {
        ClassData cd1 = datalist.get(id1);
        ClassData cd2 = datalist.get(id2);
        int count1 = cd1.getCount();
        int count2 = cd2.getCount();
        int count = updateCount(count1, count2);
        // With no observations on either side there is no weighting information;
        // average evenly instead of producing a NaN ratio from 0/0.
        double ratio = (count == 0) ? 0.5 : (double) count1 / (double) count;
        double[] label = updateLabel(cd1.getLabel(), cd2.getLabel(), ratio);
        cd1.update(label, count);
        mapKey(cd2.getKey(), cd1.getKey());
        deleteClassData(id2);
    }

    /** Count of a merged cluster: the sum of both counts. */
    private int updateCount(int count1, int count2) {
        return count1 + count2;
    }

    /**
     * Returns the element-wise blend {@code ratio * label1 + (1 - ratio) * label2}.
     * The result has the length of {@code label1}.
     */
    private double[] updateLabel(double[] label1, double[] label2, double ratio) {
        int length = label1.length;
        double[] label = new double[length];
        for (int i = 0; i < length; i++) {
            label[i] = ratio * label1[i] + (1 - ratio) * label2[i];
        }
        return label;
    }

    /**
     * Moves every key of the cluster containing {@code orikey} into the
     * cluster containing {@code key}.
     */
    private void mapKey(int orikey, int key) {
        int orivalue = map.get(orikey);
        int value = map.get(key);
        if (orivalue == value) {
            // Both keys already share a cluster; appending a list to itself
            // while iterating it would throw ConcurrentModificationException.
            return;
        }
        ArrayList<Integer> oriKeyList = mapList.get(orivalue);
        ArrayList<Integer> keyList = mapList.get(value);
        for (Integer temp : oriKeyList) {
            map.put(temp, value);
            keyList.add(temp);
        }
        mapList.remove(orivalue);
    }

    /** Flags the entry at index {@code id} as merged away. */
    private void deleteClassData(int id) {
        flag.set(id, false);
    }

    /** (Re)builds the full pairwise distance matrix from the current entries. */
    private void initDistanceAll() {
        // Start from scratch so a repeated process() call does not append a
        // second matrix behind the first one.
        distanceList.clear();
        for (int i = 0; i < datalist.size(); i++) {
            ArrayList<Double> disId = new ArrayList<Double>();
            for (int j = 0; j < datalist.size(); j++) {
                disId.add(distance.cal(datalist.get(i), datalist.get(j)));
            }
            distanceList.add(disId);
        }
    }

    /**
     * Recomputes every distance that involves the entry at index {@code id}
     * and an active entry. Both the row and the column of {@code id} are
     * refreshed: {@link #minId()} scans the whole matrix, so a column left at
     * its pre-merge value would let stale distances drive later merges.
     */
    private void updateDistance(int id) {
        ClassData cd = datalist.get(id);
        ArrayList<Double> disId = distanceList.get(id);
        for (int i = 0; i < datalist.size(); i++) {
            if (!flag.get(i)) {
                continue;
            }
            ClassData other = datalist.get(i);
            disId.set(i, distance.cal(cd, other));
            if (i != id) {
                // Same argument order as initDistanceAll uses for cell [i][id];
                // the measure is not assumed to be symmetric.
                distanceList.get(i).set(id, distance.cal(other, cd));
            }
        }
    }

    /**
     * Finds the pair of distinct active entries with the smallest distance.
     * Ties go to the pair met first in row-major order. If no distance
     * compares smaller than another (for example all are NaN), the first
     * eligible pair is returned so that a valid pair is always produced
     * whenever one exists.
     *
     * @return {@code {i, j}} with {@code i != j} and both active, or
     *         {@code null} when fewer than two active entries remain
     */
    private int[] minId() {
        int size = distanceList.size();
        int[] id = null;
        double min = Double.MAX_VALUE;
        for (int i = 0; i < size; i++) {
            if (!flag.get(i)) {
                continue;
            }
            for (int j = 0; j < size; j++) {
                if (j == i || !flag.get(j)) {
                    continue;
                }
                double temp = getDistance(i, j);
                if (id == null || temp < min) {
                    min = temp;
                    id = new int[]{i, j};
                }
            }
        }
        return id;
    }

    /** Looks up the cached distance from entry {@code i} to entry {@code j}. */
    private double getDistance(int i, int j) {
        return distanceList.get(i).get(j);
    }

    /**
     * Runs the clustering: builds the distance matrix, then merges the closest
     * active pair until {@code feasize} clusters remain or no pair is left.
     * Progress (the merged indices and the iteration number) is printed to
     * standard output.
     */
    public void process() {
        initDistanceAll();
        System.out.println("Finish distance init");
        int cycle = getFeaSize(feasize);
        for (int i = 0; i < cycle; i++) {
            int[] id = minId();
            if (id == null) {
                // Nothing left to merge (feasize below one, or too few entries).
                break;
            }
            System.out.println(id[0] + " " + id[1]);
            merge(id[0], id[1]);
            updateDistance(id[0]);
            System.out.println(i);
        }
    }

    /**
     * Number of merges needed to get from the currently active entries down
     * to {@code feasize}. Counting active entries (rather than all entries)
     * gives the same result on a first run and makes a repeated
     * {@link #process()} call a no-op instead of an over-merge.
     */
    private int getFeaSize(int feasize) {
        int active = 0;
        for (Boolean alive : flag) {
            if (alive) {
                active++;
            }
        }
        return active - feasize;
    }
}