package quickml.supervised.dataProcessing.instanceTranformer;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import quickml.data.instances.InstanceWithAttributesMap;
import java.util.*;
/**
 * Appends pairwise product features to training instances.
 *
 * <p>For every pair of "pairable" attributes that are both non-zero in enough of the same
 * instances, a new attribute named {@code attribute1 + "-" + attribute2} is written into each
 * instance where both are non-zero; its value is the product of the two attribute values.
 * Instances are modified in place.
 *
 * <p>All settings default to {@code 0} / {@code false}. Unless at least one of
 * {@link #setAllowCategoricalProductFeatures(boolean)} or
 * {@link #setAllowNumericProductFeatures(boolean)} is enabled, no attribute is pairable and
 * the data is returned unchanged.
 *
 * <p>Attribute values of pairable attributes are cast to {@link Double}; a pairable attribute
 * holding any other type results in a {@link ClassCastException}.
 *
 * Created by chrisreeves on 10/14/15.
 */
public class CommonCoocurrenceProductFeatureAppender<I extends InstanceWithAttributesMap> implements ProductFeatureAppender<I> {

    /** Separator after which the remainder of an attribute key is ignored when comparing raw attributes. */
    private static final String RAW_ATTRIBUTE_SEPARATOR = "--";
    /** Separator placed between the two source keys in the name of a product attribute. */
    private static final String PRODUCT_ATTRIBUTE_SEPARATOR = "-";

    /** An attribute must be present in strictly more than this many instances to be pairable. */
    int minObservationsOfRawAttribute;
    /** Minimum number of instances in which both attributes of a pair must be non-zero. */
    int minOverlap;
    /** When true, the overlap of frequently occurring attributes is estimated instead of counted. */
    boolean approximateOverlap;
    /** Whether attributes whose observed values are all 0.0 or 1.0 may be paired. */
    boolean allowCategoricalProductFeatures;
    /** Whether attributes with at least one observed value other than 0.0 or 1.0 may be paired. */
    boolean allowNumericProductFeatures;
    /** When true, attributes that are non-zero in every instance are never paired. */
    private boolean ignoreAttributesCommonToAllInsances = false;

    /**
     * @param ignoreAttributesCommonToAllInsances if true, a pair is rejected when either attribute
     *                                            is non-zero in every training instance
     * @return this appender, for chaining
     */
    public CommonCoocurrenceProductFeatureAppender setIgnoreAttributesCommonToAllInsances(boolean ignoreAttributesCommonToAllInsances) {
        this.ignoreAttributesCommonToAllInsances = ignoreAttributesCommonToAllInsances;
        return this;
    }

    /**
     * @param minObservationsOfRawAttribute an attribute key must appear in strictly more than this
     *                                      many instances to be considered for pairing
     * @return this appender, for chaining
     */
    public CommonCoocurrenceProductFeatureAppender setMinObservationsOfRawAttribute(int minObservationsOfRawAttribute) {
        this.minObservationsOfRawAttribute = minObservationsOfRawAttribute;
        return this;
    }

    /**
     * @param minOverlap minimum number of instances in which both attributes must be non-zero
     *                   for their product feature to be added
     * @return this appender, for chaining
     */
    public CommonCoocurrenceProductFeatureAppender setMinOverlap(int minOverlap) {
        this.minOverlap = minOverlap;
        return this;
    }

    /**
     * @param approximateOverlap if true, the overlap is estimated (rather than counted) whenever
     *                           either attribute is non-zero in more than a quarter of the instances
     * @return this appender, for chaining
     */
    public CommonCoocurrenceProductFeatureAppender setApproximateOverlap(boolean approximateOverlap) {
        this.approximateOverlap = approximateOverlap;
        return this;
    }

    /**
     * @param allowCategoricalProductFeatures whether attributes whose observed values are all
     *                                        0.0 or 1.0 may take part in product features
     * @return this appender, for chaining
     */
    public CommonCoocurrenceProductFeatureAppender setAllowCategoricalProductFeatures(boolean allowCategoricalProductFeatures) {
        this.allowCategoricalProductFeatures = allowCategoricalProductFeatures;
        return this;
    }

    /**
     * @param allowNumericProductFeatures whether attributes with at least one observed value other
     *                                    than 0.0 or 1.0 may take part in product features
     * @return this appender, for chaining
     */
    public CommonCoocurrenceProductFeatureAppender setAllowNumericProductFeatures(boolean allowNumericProductFeatures) {
        this.allowNumericProductFeatures = allowNumericProductFeatures;
        return this;
    }

    /**
     * Adds a product attribute to every instance in which both members of a sufficiently
     * co-occurring attribute pair are non-zero.
     *
     * @param trainingData the instances to augment; their attribute maps are modified in place
     * @return the same list that was passed in
     */
    @Override
    public List<I> addProductAttributes(List<I> trainingData) {
        Set<String> pairableAttributes = getPairableAttributes(trainingData, minObservationsOfRawAttribute,
                allowCategoricalProductFeatures, allowNumericProductFeatures);
        Map<String, List<Integer>> invertedIndex = buildInvertedIndexOfAttributesToInstances(trainingData, pairableAttributes);

        // Only attributes with at least one non-zero observation have an index entry. A pairable
        // attribute whose values are all 0.0 has none, can never overlap with anything, and
        // would otherwise yield a null instance list.
        List<String> orderedPairableKeys = Lists.newArrayList(invertedIndex.keySet());
        Collections.sort(orderedPairableKeys);

        // Compute each key's raw-attribute prefix once instead of once per pair.
        List<String> rawAttributeNames = new ArrayList<>(orderedPairableKeys.size());
        for (String key : orderedPairableKeys) {
            rawAttributeNames.add(rawAttributeName(key));
        }

        for (int i = 0; i < orderedPairableKeys.size(); i++) {
            String attribute1 = orderedPairableKeys.get(i);
            List<Integer> instances1 = invertedIndex.get(attribute1);
            for (int j = i + 1; j < orderedPairableKeys.size(); j++) {
                if (rawAttributeNames.get(i).equals(rawAttributeNames.get(j))) {
                    continue; // skip self association of same attribute with different vals
                }
                String attribute2 = orderedPairableKeys.get(j);
                List<Integer> instances2 = invertedIndex.get(attribute2);
                if (enoughOverlap(instances1, instances2, minOverlap, trainingData.size(), approximateOverlap, ignoreAttributesCommonToAllInsances)) {
                    appendCrossAttributeToCommonInstances(trainingData, attribute1, attribute2, instances1, instances2);
                }
            }
        }
        return trainingData;
    }

    /**
     * Returns the part of {@code attributeKey} that precedes the first {@code "--"}, or the whole
     * key when it contains no {@code "--"}. Two keys with the same prefix are treated as values of
     * the same underlying attribute and are never paired with each other.
     */
    private static String rawAttributeName(String attributeKey) {
        int separatorIndex = attributeKey.indexOf(RAW_ATTRIBUTE_SEPARATOR);
        return separatorIndex < 0 ? attributeKey : attributeKey.substring(0, separatorIndex);
    }

    /**
     * Maps each pairable attribute to the indices (positions in {@code trainingData}) of the
     * instances in which it has a non-zero value. Because instances are visited in list order,
     * every index list is sorted ascending. Attributes that are never non-zero get no entry.
     */
    private static <I extends InstanceWithAttributesMap> Map<String, List<Integer>> buildInvertedIndexOfAttributesToInstances(List<I> trainingData, Set<String> pairableAttributes) {
        Map<String, List<Integer>> invertedIndex = new HashMap<>();
        for (int i = 0; i < trainingData.size(); i++) {
            I instance = trainingData.get(i);
            for (String key : instance.getAttributes().keySet()) {
                if (!pairableAttributes.contains(key)) {
                    continue;
                }
                Object value = instance.getAttributes().get(key);
                if (((Double) value).doubleValue() == 0.0) {
                    continue;
                }
                List<Integer> instancesWithKey = invertedIndex.get(key);
                if (instancesWithKey == null) {
                    instancesWithKey = new ArrayList<>();
                    invertedIndex.put(key, instancesWithKey);
                }
                instancesWithKey.add(i);
            }
        }
        return invertedIndex;
    }

    /**
     * Selects the attribute keys eligible for pairing.
     *
     * <p>A key is eligible when it is present in strictly more than
     * {@code minObservationsOfRawAttribute} instances and its kind is allowed: a key is "numeric"
     * if any observed value is neither 1.0 nor 0.0, and "categorical" otherwise.
     */
    private static <I extends InstanceWithAttributesMap> Set<String> getPairableAttributes(List<I> trainingData, int minObservationsOfRawAttribute, boolean allowCategoricalProductFeatures, boolean allowNumericProductFeatures) {
        Map<String, Integer> attributeCounts = new HashMap<>();
        Set<String> numericAttributes = Sets.newHashSet();
        for (I instance : trainingData) {
            for (String key : instance.getAttributes().keySet()) {
                Object value = instance.getAttributes().get(key);
                if (!value.equals(1.0) && !value.equals(0.0)) {
                    numericAttributes.add(key);
                }
                Integer countSoFar = attributeCounts.get(key);
                attributeCounts.put(key, countSoFar == null ? 1 : countSoFar + 1);
            }
        }

        Set<String> pairableAttributes = Sets.newHashSet();
        for (Map.Entry<String, Integer> entry : attributeCounts.entrySet()) {
            if (entry.getValue() <= minObservationsOfRawAttribute) {
                continue;
            }
            boolean isNumeric = numericAttributes.contains(entry.getKey());
            if (isNumeric ? allowNumericProductFeatures : allowCategoricalProductFeatures) {
                pairableAttributes.add(entry.getKey());
            }
        }
        return pairableAttributes;
    }

    /**
     * Decides whether two attributes are non-zero together in at least {@code minOverlap} instances.
     *
     * @param instances1                          ascending instance indices of the first attribute
     * @param instances2                          ascending instance indices of the second attribute
     * @param minOverlap                          required number of shared instances
     * @param numInstances                        total number of training instances
     * @param approximateOverlap                  if true and either list covers more than a quarter
     *                                            of the instances, the overlap is estimated as
     *                                            {@code size1 * size2 / numInstances} instead of counted
     * @param ignoreAttributesCommonToAllInsances if true, return false when either attribute is
     *                                            non-zero in every instance
     * @return true if the (exact or estimated) overlap is at least {@code minOverlap}
     */
    private static boolean enoughOverlap(List<Integer> instances1, List<Integer> instances2, int minOverlap, int numInstances, boolean approximateOverlap, boolean ignoreAttributesCommonToAllInsances) {
        int size1 = instances1.size();
        int size2 = instances2.size();
        if (ignoreAttributesCommonToAllInsances && (size1 == numInstances || size2 == numInstances)) {
            return false;
        }

        // The parentheses matter: without them "&&" binds tighter than "||" and the estimate
        // would be used whenever the second list is large, even with approximation disabled.
        if (approximateOverlap && (size1 > numInstances / 4 || size2 > numInstances / 4)) {
            // Multiply before dividing (in long, to avoid overflow); dividing first would
            // truncate the fraction size/numInstances to 0 for every attribute that is not
            // present in all instances. numInstances is > 0 here because a list is non-empty.
            int estimatedOverlap = (int) ((long) size1 * size2 / numInstances);
            return estimatedOverlap >= minOverlap;
        }

        // Exact count: walk both sorted lists in lock step, stopping as soon as the answer is known.
        int overlap = 0;
        int index1 = 0;
        int index2 = 0;
        while (index1 < size1 && index2 < size2) {
            int instance1 = instances1.get(index1);
            int instance2 = instances2.get(index2);
            if (instance1 == instance2) {
                overlap++;
                index1++;
                index2++;
                if (overlap >= minOverlap) {
                    return true;
                }
            } else if (instance1 < instance2) {
                index1++;
            } else {
                index2++;
            }
            // Give up once too few entries remain in either list to reach minOverlap.
            int remainingOverlap = minOverlap - overlap;
            if (remainingOverlap > size2 - index2 || remainingOverlap > size1 - index1) {
                return false;
            }
        }
        return overlap >= minOverlap;
    }

    /**
     * Writes the product attribute {@code attribute1 + "-" + attribute2} into every instance whose
     * index appears in both (ascending) index lists; the value is the product of the two
     * attributes' values in that instance.
     *
     * @return the same {@code trainingData} list, whose instances have been modified in place
     */
    private static <I extends InstanceWithAttributesMap> List<I> appendCrossAttributeToCommonInstances(List<I> trainingData, String attribute1, String attribute2, List<Integer> instances1, List<Integer> instances2) {
        String newAttribute = attribute1 + PRODUCT_ATTRIBUTE_SEPARATOR + attribute2;
        int index1 = 0;
        int index2 = 0;
        while (index1 < instances1.size() && index2 < instances2.size()) {
            int instance1 = instances1.get(index1);
            int instance2 = instances2.get(index2);
            if (instance1 == instance2) {
                I instance = trainingData.get(instance1);
                double val1 = (Double) instance.getAttributes().get(attribute1);
                double val2 = (Double) instance.getAttributes().get(attribute2);
                instance.getAttributes().put(newAttribute, val1 * val2);
                index1++;
                index2++;
            } else if (instance1 < instance2) {
                index1++;
            } else {
                index2++;
            }
        }
        return trainingData;
    }
}