package edu.fudan.ml.types.featurecluster;
import gnu.trove.map.hash.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.HashMap;
/**
 * Single-pass greedy feature clustering.
 *
 * <p>The first {@code feasize} entries of the data list act as cluster
 * representatives. Every later entry is compared against all representatives
 * and merged into the closest one: the representative's label becomes the
 * count-weighted average of both labels and its count the sum of both counts.
 * The key-to-cluster assignment is available through {@link #getMap()}.
 *
 * <p>Construction normalises every label in place so that its components sum
 * to one, and overwrites the static {@code ClassData.allCount} with the total
 * count of the supplied data.
 */
public class ClusterKmeans extends AbstractCluster {
    /** Items to cluster; indices {@code [0, feasize)} are the cluster representatives. */
    private final ArrayList<ClassData> datalist;
    /** Item key -> key identifying the cluster the item currently belongs to. */
    private final HashMap<Integer, Integer> map;
    /** Cluster key -> keys of all items currently assigned to that cluster. */
    private final HashMap<Integer, ArrayList<Integer>> mapList;
    /** Distance measure used to compare two items. */
    private final AbstractDistance distance;
    /** Number of clusters to keep (number of leading items used as representatives). */
    private final int feasize;
    /** Scratch buffer: distance from the item being placed to each representative. */
    private final ArrayList<Double> distanceList;
    /** Position in {@link #distanceList} -> index of the representative in {@link #datalist}. */
    private final ArrayList<Integer> idList;
    /** Optional key -> name lookup; when set, only items whose names share a first character may merge. */
    private TIntObjectHashMap<String> index = null;

    /**
     * Creates the clusterer, registers every item as its own cluster,
     * normalises all labels and publishes the total count to
     * {@code ClassData.allCount}.
     *
     * @param datalist items to cluster; their labels are modified in place
     * @param distance distance measure between two items
     * @param feasize  number of clusters to keep
     */
    public ClusterKmeans(ArrayList<ClassData> datalist, AbstractDistance distance, int feasize) {
        this.datalist = datalist;
        this.distance = distance;
        this.feasize = feasize;
        map = new HashMap<Integer, Integer>();
        mapList = new HashMap<Integer, ArrayList<Integer>>();
        distanceList = new ArrayList<Double>();
        idList = new ArrayList<Integer>();
        paraInit();
        setAllCount();
    }

    /**
     * @return the live mapping from each item key to the key of the cluster it
     *         belongs to; before {@link #process()} every key maps to itself
     */
    public HashMap<Integer, Integer> getMap() {
        return map;
    }

    /**
     * Installs the key -> name lookup used to restrict merging. Once set, two
     * items are only compared when the first characters of their names match;
     * every key that is compared must then have a non-empty entry.
     *
     * @param index the index to set, or {@code null} to disable the restriction
     */
    public void setIndex(TIntObjectHashMap<String> index) {
        this.index = index;
    }

    /** Registers each item as a singleton cluster keyed by its own key, then normalises all labels. */
    private void paraInit() {
        for (ClassData cd : datalist) {
            int key = cd.getKey();
            map.put(key, key);
            ArrayList<Integer> members = new ArrayList<Integer>();
            members.add(key);
            mapList.put(key, members);
        }
        regular();
    }

    /** Sums the counts of all items and stores the total in the static {@code ClassData.allCount}. */
    private void setAllCount() {
        int allCount = 0;
        for (ClassData cd : datalist) {
            allCount += cd.getCount();
        }
        ClassData.allCount = allCount;
    }

    /** Normalises the label of every item. */
    private void regular() {
        for (ClassData cd : datalist) {
            regular(cd);
        }
    }

    /**
     * Scales the label of {@code cd} so that its components sum to one.
     * A label whose components sum to zero is left untouched, because dividing
     * by zero would fill it with NaN and make every later distance comparison
     * fail.
     */
    private void regular(ClassData cd) {
        double[] label = cd.getLabel();
        double sum = 0;
        for (double ele : label) {
            sum += ele;
        }
        if (sum != 0) {
            for (int i = 0; i < label.length; i++) {
                label[i] = label[i] / sum;
            }
        }
        cd.setLabel(label);
    }

    /**
     * Merges an item into a representative.
     *
     * @param id1  position of the representative in {@link #idList}
     * @param idd2 index in {@link #datalist} of the item being absorbed
     */
    private void merge(int id1, int idd2) {
        int idd1 = idList.get(id1);
        ClassData cd1 = datalist.get(idd1);
        ClassData cd2 = datalist.get(idd2);
        int count1 = cd1.getCount();
        int count2 = cd2.getCount();
        // Sum in double so the weight cannot be distorted by int overflow.
        double total = (double) count1 + (double) count2;
        // With no observations on either side there is nothing to weight by;
        // fall back to a plain average instead of producing NaN.
        double ratio = (total == 0) ? 0.5 : (double) count1 / total;
        double[] label1 = cd1.getLabel();
        double[] label2 = cd2.getLabel();
        int count = updateCount(count1, count2);
        double[] label = updateLabel(label1, label2, ratio);
        cd1.update(label, count);
        mapKey(cd2.getKey(), cd1.getKey());
    }

    /** @return the combined count of two merged items */
    private int updateCount(int count1, int count2) {
        return count1 + count2;
    }

    /**
     * Blends two labels component-wise.
     *
     * @param ratio weight given to {@code label1}; {@code label2} gets {@code 1 - ratio}
     * @return a new array of the same length as {@code label1}
     */
    private double[] updateLabel(double[] label1, double[] label2, double ratio) {
        int length = label1.length;
        double[] label = new double[length];
        for (int i = 0; i < length; i++) {
            label[i] = ratio * label1[i] + (1 - ratio) * label2[i];
        }
        return label;
    }

    /**
     * Moves every member of the cluster containing {@code orikey} into the
     * cluster containing {@code key} and drops the emptied cluster.
     *
     * @throws IllegalStateException if either key was never registered
     */
    private void mapKey(int orikey, int key) {
        Integer orivalue = map.get(orikey);
        Integer value = map.get(key);
        if (orivalue == null || value == null) {
            throw new IllegalStateException(
                    "Unregistered key while merging " + orikey + " into " + key);
        }
        // Both keys already share a cluster (e.g. duplicate keys in the data):
        // re-adding the members would modify the list while iterating over it.
        if (orivalue.equals(value)) {
            return;
        }
        ArrayList<Integer> oriKeyList = mapList.get(orivalue);
        ArrayList<Integer> keyList = mapList.get(value);
        for (Integer member : oriKeyList) {
            map.put(member, value);
            keyList.add(member);
        }
        mapList.remove(orivalue);
    }

    /**
     * Resets the scratch lists so that position {@code i} refers to
     * representative {@code i} with an initial distance of zero. Clearing first
     * keeps the lists at {@code feasize} entries if processing is started again.
     */
    private void initStack() {
        idList.clear();
        distanceList.clear();
        for (int i = 0; i < feasize; i++) {
            idList.add(i);
            distanceList.add(0.0);
        }
    }

    /**
     * Recomputes the distance from the item at {@code id} to every representative.
     *
     * @param id index in {@link #datalist} of the item being placed
     */
    private void updateDistance(int id) {
        ClassData cd = datalist.get(id);
        for (int i = 0; i < idList.size(); i++) {
            int idDataList = idList.get(i);
            distanceList.set(i, distanceOfTwo(cd, datalist.get(idDataList)));
        }
    }

    /**
     * @return the distance between the two items, or {@code Double.MAX_VALUE}
     *         when an index is set and the items do not share a template
     */
    private double distanceOfTwo(ClassData cd1, ClassData cd2) {
        if (index == null || isSameTemplate(cd1, cd2)) {
            return distance.cal(cd1, cd2);
        }
        return Double.MAX_VALUE;
    }

    /**
     * Two items share a template when the names stored for their keys in
     * {@link #index} start with the same character. Both keys must have a
     * non-empty entry in the index.
     */
    private boolean isSameTemplate(ClassData cd1, ClassData cd2) {
        String key1 = index.get(cd1.getKey());
        String key2 = index.get(cd2.getKey());
        return key1.charAt(0) == key2.charAt(0);
    }

    /**
     * Finds the closest representative for the distances currently held in
     * {@link #distanceList}.
     *
     * @return position of the smallest distance, stopping early once a distance
     *         of zero or less is found; {@code -1} when no distance is below
     *         {@code Double.MAX_VALUE}
     */
    private int minId() {
        int id = -1;
        double min = Double.MAX_VALUE;
        for (int i = 0; i < feasize; i++) {
            double temp = distanceList.get(i);
            if (temp < min) {
                id = i;
                min = temp;
            }
            if (min <= 0) {
                return id;
            }
        }
        return id;
    }

    /**
     * Runs the clustering: every item after the first {@code feasize} is merged
     * into its closest representative. An item for which no representative is
     * eligible (all distances are {@code Double.MAX_VALUE}) is left as its own
     * cluster. Does nothing but print a notice when there are no more items
     * than requested clusters.
     */
    public void process() {
        if (feasize >= datalist.size()) {
            System.out.println("Do not need feature cluster");
            return;
        }
        initStack();
        for (int i = feasize; i < datalist.size(); i++) {
            updateDistance(i);
            int id = minId();
            // -1 means no compatible representative exists; merging would
            // index idList with -1, so keep the item unmerged instead.
            if (id >= 0) {
                merge(id, i);
            }
            if ((i + 1) % 10000 == 0) {
                System.out.println(i);
            }
        }
    }
}