package ContextForest;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
/**
 * Fowlkes-Mallows style similarity between two clusterings of string labels.
 *
 * <p>Each clustering is a list of clusters and each cluster is a list of strings.
 * Repeated strings are treated as a multiset: the overlap of two clusters counts
 * a string {@code min(countInA, countInB)} times.
 *
 * <p>Typical use: construct with the two clusterings, optionally configure the
 * mismatch penalty through the setters, then call {@link #Compute()}.
 *
 * <p>Field names and types are kept unchanged on purpose: the class is
 * {@link Serializable} and declares no explicit {@code serialVersionUID}.
 */
public class FowlkesMallows implements Serializable {

    //Fields

    //data: the two clusterings, their flattened element lists, and distinct-element sets
    private LinkedList<LinkedList<String>> Set1;
    private LinkedList<String> Set1LS;
    private HashSet<String> Set1HS;
    private LinkedList<LinkedList<String>> Set2;
    private LinkedList<String> Set2LS;
    private HashSet<String> Set2HS;
    private HashSet<String> CombinedHash;

    //Adjustment options

    //Summed mismatch
    private boolean SummedMismatchPenalty;
    private boolean FreeMismatches;
    private int NumberOfFreeMatches;
    private double PenaltyperMismatch;

    //Exact penalty
    private boolean AdjustmentPenalty;

    //Matching statistics (filled in by ElementCounts)
    private int Set1Only;
    private int Set2Only;
    private int Intersection;
    private int Union;

    //computation
    private int[][] Matrix;
    private double OriginalFowlkesMallows;
    private double AdjustmentFactor;
    private double B;

    //sizes
    private int QueryLeafCount;
    private boolean IdenticalDataSets;

    //Constructor

    /**
     * Stores both clusterings and derives the flattened lists and distinct-element sets.
     *
     * @param Set1 first clustering (list of clusters); must not be null or contain null clusters
     * @param Set2 second clustering (list of clusters); must not be null or contain null clusters
     */
    public FowlkesMallows(LinkedList<LinkedList<String>> Set1, LinkedList<LinkedList<String>> Set2){

        //parameters
        this.Set1 = Set1;
        this.Set2 = Set2;
        this.Set1LS = Set2List(Set1);
        this.Set2LS = Set2List(Set2);
        this.Set1HS = new HashSet<String>(Set1LS);
        this.Set2HS = new HashSet<String>(Set2LS);

        //build combined hash: every distinct string seen in either clustering
        this.CombinedHash = new HashSet<String>(Set1LS);
        this.CombinedHash.addAll(Set2LS);
    }

    // ------ Dissimilarity Processing -----------//

    /**
     * Runs the whole computation: element counts, adjustment factor, the
     * Fowlkes-Mallows value, and their product.
     *
     * <p>The adjustment factor is {@link #SummedMismatchPenalty()} when the
     * {@code AdjustmentPenalty} flag is set and 1 otherwise.
     * NOTE(review): the boolean {@code SummedMismatchPenalty} flag is not consulted
     * here, only {@code AdjustmentPenalty} is - confirm this is intended.
     *
     * @return the adjusted value, also stored and available through {@link #getB()}
     */
    public double Compute(){

        //determine elements counts (needed by the penalty and by the index itself)
        ElementCounts();

        //determine adjustment factor
        if (AdjustmentPenalty){
            AdjustmentFactor = SummedMismatchPenalty();
        } else {
            AdjustmentFactor = 1;
        }

        //retrieve original value
        OriginalFowlkesMallows = OriginalFowlkesMallows();

        //adjust value
        B = OriginalFowlkesMallows * AdjustmentFactor;

        return B;
    }

    /**
     * Computes the unadjusted Fowlkes-Mallows value for the two clusterings.
     *
     * <p>Builds the matching matrix {@code M[i][j]} = overlap of cluster {@code i}
     * of Set1 with cluster {@code j} of Set2 (stored and available through
     * {@link #getMatrix()}), then evaluates {@code T / sqrt(P * Q)} where
     * {@code T = sum(M[i][j]^2) - n}, {@code P = sum(rowSum^2) - n},
     * {@code Q = sum(columnSum^2) - n} and {@code n} is the intersection count
     * produced by {@link #ElementCounts()}.
     *
     * <p>The element counts are refreshed at the start of this method, so calling
     * it directly (without {@link #Compute()}) uses the right {@code n}.
     *
     * @return a value in [0, 1]; NaN and negative results are reported as 0,
     *         results above 1 as 1
     */
    public double OriginalFowlkesMallows(){

        //n below depends on Intersection, so make sure it is up to date
        ElementCounts();

        int rows = Set1.size();
        int cols = Set2.size();

        //count each Set2 cluster once, instead of once per matrix cell
        LinkedList<Map<String, Integer>> set2Counts = new LinkedList<Map<String, Integer>>();
        for (LinkedList<String> cluster : Set2){
            set2Counts.add(countOccurrences(cluster));
        }

        //Fill out matrix (Set1 = i, Set2 = j) and accumulate the sums in the same pass.
        //long accumulators: the squares overflow int once a count exceeds 46340.
        Matrix = new int[rows][cols];
        long[] columnSums = new long[cols];
        long P = 0, Q = 0, T = 0;

        int i = 0;
        for (LinkedList<String> cluster : Set1){
            Map<String, Integer> rowCounts = countOccurrences(cluster);
            long rowSum = 0;
            int j = 0;
            for (Map<String, Integer> columnCounts : set2Counts){
                int shared = sharedCount(rowCounts, columnCounts);
                Matrix[i][j] = shared;
                T += (long) shared * shared;
                rowSum += shared;
                columnSums[j] += shared;
                j++;
            }
            P += rowSum * rowSum;
            i++;
        }
        for (long columnSum : columnSums){
            Q += columnSum * columnSum;
        }

        //When the data sets are identical, Intersection/Union are the same.
        long n = Intersection;

        //final adjustments for P, Q, and T
        Q = Q - n;
        P = P - n;
        T = T - n;

        double D = (double) T / Math.sqrt((double) P * (double) Q);

        //NaN case - P or Q are 0 (0/0), or their product is negative
        if (Double.isNaN(D) || D < 0){
            D = 0.0;
        } else if (D > 1.0){
            D = 1.0;
        }

        return D;
    }

    /**
     * Returns the multiset overlap of two lists: each distinct string contributes
     * the smaller of its two occurrence counts.
     *
     * @param A first list of strings
     * @param B second list of strings
     * @return number of elements the two lists have in common, counting repeats
     */
    public int CommonCounts(LinkedList<String> A, LinkedList<String> B){
        return sharedCount(countOccurrences(A), countOccurrences(B));
    }

    //Counts how many times each string occurs in the given items.
    private static Map<String, Integer> countOccurrences(Iterable<String> items){
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String item : items){
            Integer seen = counts.get(item);
            counts.put(item, seen == null ? 1 : seen + 1);
        }
        return counts;
    }

    //Sums min(count in A, count in B) over the strings present in both count maps.
    private static int sharedCount(Map<String, Integer> countsA, Map<String, Integer> countsB){
        //walk the smaller map, probe the larger one
        Map<String, Integer> smaller = countsA.size() <= countsB.size() ? countsA : countsB;
        Map<String, Integer> larger = (smaller == countsA) ? countsB : countsA;

        int shared = 0;
        for (Map.Entry<String, Integer> entry : smaller.entrySet()){
            Integer other = larger.get(entry.getKey());
            if (other != null){
                shared += Math.min(entry.getValue(), other);
            }
        }
        return shared;
    }

    // --- Preprocessing ------ //

    /**
     * Compares the flattened element lists and updates the matching statistics:
     * elements only in Set1, only in Set2, the intersection, the union, the query
     * leaf count (intersection + Set2-only) and the identical-data-sets flag.
     * Repeated strings are counted with their multiplicity.
     */
    public void ElementCounts(){

        Map<String, Integer> counts1 = countOccurrences(Set1LS);
        Map<String, Integer> counts2 = countOccurrences(Set2LS);

        int Intersect = 0;
        int Only1 = 0;
        int Only2 = 0;

        for (String s : CombinedHash){
            Integer boxed1 = counts1.get(s);
            Integer boxed2 = counts2.get(s);
            int Count1 = (boxed1 == null) ? 0 : boxed1;
            int Count2 = (boxed2 == null) ? 0 : boxed2;

            //the shared part is the smaller count; any surplus belongs to one side only
            int shared = Math.min(Count1, Count2);
            Intersect += shared;
            Only1 += Count1 - shared;
            Only2 += Count2 - shared;
        }

        //update values
        Set1Only = Only1;
        Set2Only = Only2;
        Intersection = Intersect;
        Union = Only1 + Only2 + Intersect;

        //Query Leaves
        QueryLeafCount = Intersect + Only2;

        //Identical data sets if no objects only in one set or only in another set
        IdenticalDataSets = (Only1 == 0 && Only2 == 0);
    }

    // ---- AdjustmentStep ----- //

    /**
     * Converts the number of mismatched elements into a scale factor.
     *
     * <p>Mismatches are {@code Set1Only + Set2Only} as last computed by
     * {@link #ElementCounts()}. When free mismatches are enabled, the configured
     * number is subtracted first (never going below zero). The factor is
     * {@code 1 - mismatches * PenaltyperMismatch}, clamped to [0, 1].
     *
     * @return scale factor in [0, 1]
     */
    public double SummedMismatchPenalty(){

        //Compute mismatches
        int TotalMismatches = Set1Only + Set2Only;

        //Adjust for free matches
        if (FreeMismatches){
            TotalMismatches = Math.max(TotalMismatches - NumberOfFreeMatches, 0);
        }

        double penalty = 1.0 - (double) TotalMismatches * PenaltyperMismatch;

        //clamp into a usable scale factor
        if (penalty < 0){
            penalty = 0;
        } else if (penalty > 1){
            penalty = 1;
        }

        return penalty;
    }

    /**
     * Computes a Dice or Jaccard coefficient from the last element counts.
     *
     * <p>With {@code AdjustmentPenalty} set the result is the Dice form
     * {@code 2 * Intersection / (|Set1LS| + |Set2LS|)}; otherwise it is the
     * Jaccard form {@code Intersection / Union}. The statistics come from
     * {@link #ElementCounts()}, which this method does not call itself. A 0/0
     * division yields NaN.
     *
     * @return the selected coefficient
     */
    public double DiceOrJaccardPenalty(){

        double penalty;

        if (AdjustmentPenalty){     //Dice penalty
            penalty = (2.0 * (double) Intersection /
                    ((double) Set1LS.size() + (double) Set2LS.size()));
        } else {                    //Jaccard penalty
            penalty = ((double) Intersection / (double) Union);
        }

        return penalty;
    }

    // ------ CONVERSIONS --------------//

    /**
     * Flattens a clustering into a single list, keeping cluster order and repeats.
     *
     * @param L clustering (list of clusters)
     * @return new list holding every element of every cluster
     */
    public LinkedList<String> Set2List(LinkedList<LinkedList<String>> L){

        LinkedList<String> Output = new LinkedList<String>();

        for (LinkedList<String> list : L){
            Output.addAll(list);
        }

        return Output;
    }

    //------ SETTERS AND GETTERS -------//

    /** @return the first clustering */
    public LinkedList<LinkedList<String>> getSet1() {
        return Set1;
    }

    /**
     * Replaces the first clustering. Only the field is assigned; the flattened
     * list and the hash sets derived in the constructor are not refreshed.
     */
    public void setSet1(LinkedList<LinkedList<String>> set1) {
        Set1 = set1;
    }

    /** @return the second clustering */
    public LinkedList<LinkedList<String>> getSet2() {
        return Set2;
    }

    /**
     * Replaces the second clustering. Only the field is assigned; the flattened
     * list and the hash sets derived in the constructor are not refreshed.
     */
    public void setSet2(LinkedList<LinkedList<String>> set2) {
        Set2 = set2;
    }

    /** @return the adjusted value stored by the last {@link #Compute()} call */
    public double getB() {
        return B;
    }

    public void setB(double b) {
        B = b;
    }

    /** @return the flattened element list of the first clustering */
    public LinkedList<String> getSet1LS() {
        return Set1LS;
    }

    public void setSet1LS(LinkedList<String> set1LS) {
        Set1LS = set1LS;
    }

    /** @return the flattened element list of the second clustering */
    public LinkedList<String> getSet2LS() {
        return Set2LS;
    }

    public void setSet2LS(LinkedList<String> set2LS) {
        Set2LS = set2LS;
    }

    /** @return the matching matrix built by the last {@link #OriginalFowlkesMallows()} call */
    public int[][] getMatrix() {
        return Matrix;
    }

    public void setMatrix(int[][] matrix) {
        Matrix = matrix;
    }

    /** @return the unadjusted value stored by the last {@link #Compute()} call */
    public double getOriginalFowlkesMallows() {
        return OriginalFowlkesMallows;
    }

    public void setOriginalFowlkesMallows(double originalFowlkesMallows) {
        OriginalFowlkesMallows = originalFowlkesMallows;
    }

    /** @return the adjustment factor stored by the last {@link #Compute()} call */
    public double getAdjustmentFactor() {
        return AdjustmentFactor;
    }

    public void setAdjustmentFactor(double adjustmentFactor) {
        AdjustmentFactor = adjustmentFactor;
    }

    public boolean isSummedMismatchPenalty() {
        return SummedMismatchPenalty;
    }

    public void setSummedMismatchPenalty(boolean summedMismatchPenalty) {
        SummedMismatchPenalty = summedMismatchPenalty;
    }

    public boolean isFreeMismatches() {
        return FreeMismatches;
    }

    public void setFreeMismatches(boolean freeMismatches) {
        FreeMismatches = freeMismatches;
    }

    public int getNumberOfFreeMatches() {
        return NumberOfFreeMatches;
    }

    public void setNumberOfFreeMatches(int numberOfFreeMatches) {
        NumberOfFreeMatches = numberOfFreeMatches;
    }

    public double getPenaltyperMismatch() {
        return PenaltyperMismatch;
    }

    public void setPenaltyperMismatch(double penaltyperMismatch) {
        PenaltyperMismatch = penaltyperMismatch;
    }

    public boolean isAdjustmentPenalty() {
        return AdjustmentPenalty;
    }

    public void setAdjustmentPenalty(boolean dicePenalty) {
        AdjustmentPenalty = dicePenalty;
    }

    /** @return intersection count plus Set2-only count from the last {@link #ElementCounts()} call */
    public int getQueryLeaves() {
        return QueryLeafCount;
    }

    public void setQueryLeaves(int queryLeaves) {
        QueryLeafCount = queryLeaves;
    }

    /** @return true when the last {@link #ElementCounts()} call found no element exclusive to either side */
    public boolean isIdenticalDataSets() {
        return IdenticalDataSets;
    }

    public void setIdenticalDataSets(boolean identicalDataSets) {
        IdenticalDataSets = identicalDataSets;
    }
}