/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.explore;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.chombo.mr.FeatureField;
import org.chombo.util.Pair;
import org.chombo.util.Triplet;
/**
 * Collects mutual information statistics for features and ranks the features
 * with several mutual information based feature selection criteria
 * (MIM, MIFS, JMI, DISR and MRMR).
 *
 * Statistics are registered through the <code>add...</code> methods; each
 * <code>get...Score</code> method then produces a ranked list of features.
 *
 * @author pranab
 *
 */
public class MutualInformationScore {
    private List<FeatureMutualInfo> featureClassMutualInfoList = new ArrayList<FeatureMutualInfo>();
    private List<FeaturePairMutualInfo> featurePairMutualInfoList = new ArrayList<FeaturePairMutualInfo>();
    private List<FeaturePairMutualInfo> featurePairClassMutualInfoList = new ArrayList<FeaturePairMutualInfo>();
    private List<FeaturePairEntropy> featurePairClassEntropyList = new ArrayList<FeaturePairEntropy>();

    /**
     * Feature ordinal (left) paired with a score (right). The natural ordering
     * is descending by score, so sorting puts the highest score first.
     *
     * @author pranab
     *
     */
    public static class FeatureMutualInfo extends Pair<Integer, Double> implements Comparable<FeatureMutualInfo> {
        public FeatureMutualInfo(int featureOrdinal, double mutualInfo) {
            super(featureOrdinal, mutualInfo);
        }

        @Override
        public int compareTo(FeatureMutualInfo that) {
            // reversed operands: larger value sorts first
            return that.getRight().compareTo(this.getRight());
        }
    }

    /**
     * Two feature ordinals (left, center) with their mutual information (right).
     *
     * @author pranab
     *
     */
    public static class FeaturePairMutualInfo extends Triplet<Integer, Integer, Double> {
        public FeaturePairMutualInfo(int firstFeatureOrdinal, int secondFeatureOrdinal, double mutualInfo) {
            super(firstFeatureOrdinal, secondFeatureOrdinal, mutualInfo);
        }
    }

    /**
     * Two feature ordinals (left, center) with their entropy (right).
     *
     * @author pranab
     *
     */
    public static class FeaturePairEntropy extends Triplet<Integer, Integer, Double> {
        public FeaturePairEntropy(int firstFeatureOrdinal, int secondFeatureOrdinal, double entropy) {
            super(firstFeatureOrdinal, secondFeatureOrdinal, entropy);
        }
    }

    /**
     * Registers the mutual information between a feature and the class.
     *
     * @param featureOrdinal ordinal of the feature
     * @param mutualInfo mutual information between the feature and the class
     */
    public void addFeatureClassMutualInfo(int featureOrdinal, double mutualInfo) {
        featureClassMutualInfoList.add(new FeatureMutualInfo(featureOrdinal, mutualInfo));
    }

    /**
     * Sorts the registered feature class mutual information in place, highest
     * mutual information first.
     */
    public void sortFeatureMutualInfo() {
        Collections.sort(featureClassMutualInfoList);
    }

    /**
     * Mutual Information Maximization (MIM)
     *
     * @return the internal feature class mutual information list (not a copy),
     *         sorted in place with the highest mutual information first
     */
    public List<FeatureMutualInfo> getMutualInfoMaximizerScore() {
        sortFeatureMutualInfo();
        return featureClassMutualInfoList;
    }

    /**
     * Registers the mutual information between two features.
     *
     * @param firstFeatureOrdinal ordinal of the first feature
     * @param secondFeatureOrdinal ordinal of the second feature
     * @param mutualInfo mutual information between the two features
     */
    public void addFeaturePairMutualInfo(int firstFeatureOrdinal, int secondFeatureOrdinal, double mutualInfo) {
        featurePairMutualInfoList.add(
                new FeaturePairMutualInfo(firstFeatureOrdinal, secondFeatureOrdinal, mutualInfo));
    }

    /**
     * Mutual Information Feature Selection (MIFS)
     *
     * Greedily selects features. The score of a candidate is its feature class
     * mutual information minus <code>redunacyFactor</code> times the sum of its
     * feature pair mutual information with the features selected so far.
     *
     * @param redunacyFactor weight applied to the summed pair mutual information
     * @return features in selection order, each with the score it was selected with;
     *         empty if no feature class mutual information has been registered
     */
    public List<FeatureMutualInfo> getMutualInfoFeatureSelectionScore(double redunacyFactor) {
        List<FeatureMutualInfo> mutualInfoFeatureSelection = new ArrayList<FeatureMutualInfo>();
        Set<Integer> selectedFeatures = new HashSet<Integer>();
        while (selectedFeatures.size() < featureClassMutualInfoList.size()) {
            boolean found = false;
            double maxScore = Double.NEGATIVE_INFINITY;
            int selectedFeature = 0;

            //all features not selected yet
            for (FeatureMutualInfo muInfo : featureClassMutualInfoList) {
                int feature = muInfo.getLeft();
                if (selectedFeatures.contains(feature)) {
                    continue;
                }
                double sum = sumPairValuesWithSelected(featurePairMutualInfoList, feature, selectedFeatures);
                double score = muInfo.getRight() - redunacyFactor * sum;
                if (!found || isBetterScore(score, maxScore)) {
                    found = true;
                    maxScore = score;
                    selectedFeature = feature;
                }
            }

            //nothing left to select, e.g. the same feature ordinal was registered more than once
            if (!found) {
                break;
            }

            //add the feature with max score
            mutualInfoFeatureSelection.add(new FeatureMutualInfo(selectedFeature, maxScore));
            selectedFeatures.add(selectedFeature);
        }
        return mutualInfoFeatureSelection;
    }

    /**
     * Registers the mutual information between a feature pair and the class.
     *
     * @param firstFeatureOrdinal ordinal of the first feature
     * @param secondFeatureOrdinal ordinal of the second feature
     * @param mutualInfo mutual information between the feature pair and the class
     */
    public void addFeaturePairClassMutualInfo(int firstFeatureOrdinal, int secondFeatureOrdinal, double mutualInfo) {
        featurePairClassMutualInfoList.add(
                new FeaturePairMutualInfo(firstFeatureOrdinal, secondFeatureOrdinal, mutualInfo));
    }

    /**
     * Registers the entropy of a feature pair and the class.
     *
     * @param firstFeatureOrdinal ordinal of the first feature
     * @param secondFeatureOrdinal ordinal of the second feature
     * @param entropy entropy for the feature pair and the class
     */
    public void addFeaturePairClassEntropy(int firstFeatureOrdinal, int secondFeatureOrdinal, double entropy) {
        featurePairClassEntropyList.add(
                new FeaturePairEntropy(firstFeatureOrdinal, secondFeatureOrdinal, entropy));
    }

    /**
     * Joint Mutual Info (JMI)
     *
     * @return features in selection order with their scores; empty if no feature
     *         class mutual information has been registered
     */
    public List<FeatureMutualInfo> getJointMutualInfoScore() {
        return getJointMutualInfoScoreHelper(true);
    }

    /**
     * Double Input Symmetrical Relevance (DISR)
     *
     * @return features in selection order with their scores; empty if no feature
     *         class mutual information has been registered
     * @throws IllegalStateException if a feature pair class mutual information has
     *         been registered without the matching feature pair class entropy
     */
    public List<FeatureMutualInfo> getDoubleInputSymmetricalRelevanceScore() {
        return getJointMutualInfoScoreHelper(false);
    }

    /**
     * Shared greedy selection for JMI and DISR. The selection is boot strapped with
     * the feature having the highest feature class mutual information (this sorts
     * the internal feature class mutual information list in place). Every further
     * candidate is scored with the sum, over the feature pairs it forms with already
     * selected features, of the feature pair class mutual information (JMI) or of
     * that value divided by the feature pair class entropy (DISR).
     *
     * @param joinMutInfo true for JMI, false for DISR
     * @return features in selection order with their scores
     */
    private List<FeatureMutualInfo> getJointMutualInfoScoreHelper(boolean joinMutInfo) {
        List<FeatureMutualInfo> featureJointMutualInfoList = new ArrayList<FeatureMutualInfo>();

        //no features registered, so there is nothing to boot strap from
        if (featureClassMutualInfoList.isEmpty()) {
            return featureJointMutualInfoList;
        }

        Set<Integer> selectedFeatures = new HashSet<Integer>();

        //boot strap selected feature set with one based on max relevancy
        FeatureMutualInfo mostRelevantFeature = getMutualInfoMaximizerScore().get(0);
        featureJointMutualInfoList.add(
                new FeatureMutualInfo(mostRelevantFeature.getLeft(), mostRelevantFeature.getRight()));
        selectedFeatures.add(mostRelevantFeature.getLeft());

        //select features
        while (selectedFeatures.size() < featureClassMutualInfoList.size()) {
            boolean found = false;
            double maxScore = Double.NEGATIVE_INFINITY;
            int selectedFeature = 0;

            //all features not selected yet
            for (FeatureMutualInfo featureMuInfo : featureClassMutualInfoList) {
                int feature = featureMuInfo.getLeft();
                if (selectedFeatures.contains(feature)) {
                    continue;
                }
                double score = sumJointRelevanceWithSelected(feature, selectedFeatures, joinMutInfo);
                if (!found || isBetterScore(score, maxScore)) {
                    found = true;
                    maxScore = score;
                    selectedFeature = feature;
                }
            }

            //nothing left to select, e.g. the same feature ordinal was registered more than once
            if (!found) {
                break;
            }

            //add the feature with max score
            featureJointMutualInfoList.add(new FeatureMutualInfo(selectedFeature, maxScore));
            selectedFeatures.add(selectedFeature);
        }
        return featureJointMutualInfoList;
    }

    /**
     * Looks up the registered entropy for a feature pair, in either order.
     *
     * @param featureOne ordinal of one feature
     * @param featureTwo ordinal of the other feature
     * @return first matching entry, or null if none has been registered
     */
    private FeaturePairEntropy getFeaturePairClassEntropy(int featureOne, int featureTwo) {
        for (FeaturePairEntropy thisFeaturePairEntropy : featurePairClassEntropyList) {
            int left = thisFeaturePairEntropy.getLeft();
            int center = thisFeaturePairEntropy.getCenter();
            if (left == featureOne && center == featureTwo || left == featureTwo && center == featureOne) {
                return thisFeaturePairEntropy;
            }
        }
        return null;
    }

    /**
     * Min redundancy Max Relevance (MRMR)
     *
     * Greedily selects features. The score of a candidate is its feature class
     * mutual information minus the sum of its feature pair mutual information with
     * the features selected so far divided by the number of selected features.
     *
     * @return features in selection order, each with the score it was selected with;
     *         empty if no feature class mutual information has been registered
     */
    public List<FeatureMutualInfo> getMinRedundancyMaxrelevanceScore() {
        List<FeatureMutualInfo> minRedundancyMaxrelevance = new ArrayList<FeatureMutualInfo>();
        Set<Integer> selectedFeatures = new HashSet<Integer>();
        while (selectedFeatures.size() < featureClassMutualInfoList.size()) {
            boolean found = false;
            double maxScore = Double.NEGATIVE_INFINITY;
            int selectedFeature = 0;

            //all features not selected yet
            for (FeatureMutualInfo featureMuInfo : featureClassMutualInfoList) {
                int feature = featureMuInfo.getLeft();
                if (selectedFeatures.contains(feature)) {
                    continue;
                }
                double feMuInfo = featureMuInfo.getRight();
                double sum = sumPairValuesWithSelected(featurePairMutualInfoList, feature, selectedFeatures);

                //no redundancy term for the very first selection
                double score = selectedFeatures.size() > 0 ? feMuInfo - sum / selectedFeatures.size() : feMuInfo;
                if (!found || isBetterScore(score, maxScore)) {
                    found = true;
                    maxScore = score;
                    selectedFeature = feature;
                }
            }

            //nothing left to select, e.g. the same feature ordinal was registered more than once
            if (!found) {
                break;
            }

            //add the feature with max score
            minRedundancyMaxrelevance.add(new FeatureMutualInfo(selectedFeature, maxScore));
            selectedFeatures.add(selectedFeature);
        }
        return minRedundancyMaxrelevance;
    }

    /**
     * Tells whether a pair is made of the given feature and an already selected feature.
     *
     * @param pair feature pair entry
     * @param feature ordinal of the candidate feature
     * @param selectedFeatures ordinals of the features selected so far
     * @return true if one side of the pair is the feature and the other side is selected
     */
    private static boolean pairsFeatureWithSelected(FeaturePairMutualInfo pair, int feature,
            Set<Integer> selectedFeatures) {
        return pair.getLeft() == feature && selectedFeatures.contains(pair.getCenter())
                || pair.getCenter() == feature && selectedFeatures.contains(pair.getLeft());
    }

    /**
     * Sums the values of all pairs made of the given feature and an already selected feature.
     *
     * @param pairs feature pair entries to scan, in list order
     * @param feature ordinal of the candidate feature
     * @param selectedFeatures ordinals of the features selected so far
     * @return sum of the matching pair values, 0 if there is no match
     */
    private static double sumPairValuesWithSelected(List<FeaturePairMutualInfo> pairs, int feature,
            Set<Integer> selectedFeatures) {
        double sum = 0;
        for (FeaturePairMutualInfo pair : pairs) {
            if (pairsFeatureWithSelected(pair, feature, selectedFeatures)) {
                sum += pair.getRight();
            }
        }
        return sum;
    }

    /**
     * Sums the JMI or DISR contributions of all feature pair class entries made of
     * the given feature and an already selected feature.
     *
     * @param feature ordinal of the candidate feature
     * @param selectedFeatures ordinals of the features selected so far
     * @param joinMutInfo true to add the pair value itself (JMI), false to add the
     *        pair value divided by the feature pair class entropy (DISR)
     * @return summed contribution, 0 if there is no matching pair
     * @throws IllegalStateException if the entropy needed for DISR has not been registered
     */
    private double sumJointRelevanceWithSelected(int feature, Set<Integer> selectedFeatures, boolean joinMutInfo) {
        double sum = 0;
        for (FeaturePairMutualInfo featurePairMuInfo : featurePairClassMutualInfoList) {
            if (!pairsFeatureWithSelected(featurePairMuInfo, feature, selectedFeatures)) {
                continue;
            }
            if (joinMutInfo) {
                sum += featurePairMuInfo.getRight();
            } else {
                FeaturePairEntropy featurePairEntropy =
                        getFeaturePairClassEntropy(featurePairMuInfo.getLeft(), featurePairMuInfo.getCenter());
                if (featurePairEntropy == null) {
                    throw new IllegalStateException("missing feature pair class entropy for features "
                            + featurePairMuInfo.getLeft() + " and " + featurePairMuInfo.getCenter());
                }
                sum += featurePairMuInfo.getRight() / featurePairEntropy.getRight();
            }
        }
        return sum;
    }

    /**
     * Compares a candidate score with the best score found so far. A NaN score never
     * wins against a number, and any number wins against NaN, so that features with
     * an undefined score end up last instead of blocking the selection.
     *
     * @param score candidate score
     * @param bestSoFar best score found so far
     * @return true if the candidate should replace the current best
     */
    private static boolean isBetterScore(double score, double bestSoFar) {
        return score > bestSoFar || (Double.isNaN(bestSoFar) && !Double.isNaN(score));
    }
}