/*
* Copyright 2004-2010 Information & Software Engineering Group (188/1)
* Institute of Software Technology and Interactive Systems
* Vienna University of Technology, Austria
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.ifs.tuwien.ac.at/dm/somtoolbox/license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package at.tuwien.ifs.somtoolbox.apps.analysis;
import java.util.Arrays;
import org.apache.commons.math.stat.StatUtils;
import com.martiansoftware.jsap.JSAPResult;
import at.tuwien.ifs.somtoolbox.SOMToolboxException;
import at.tuwien.ifs.somtoolbox.apps.config.OptionFactory;
import at.tuwien.ifs.somtoolbox.data.InputDatum;
import at.tuwien.ifs.somtoolbox.data.SOMLibClassInformation;
import at.tuwien.ifs.somtoolbox.data.SOMLibSparseInputData;
import at.tuwien.ifs.somtoolbox.layers.metrics.L2Metric;
import at.tuwien.ifs.somtoolbox.util.ElementCounter;
import at.tuwien.ifs.somtoolbox.util.InverseComparator;
import at.tuwien.ifs.somtoolbox.util.StringUtils;
import at.tuwien.ifs.somtoolbox.util.VectorTools;
/**
 * Command-line tool that prints several plain-text reports about how features are distributed over the classes of
 * a labelled input data set:
 * <ol>
 * <li>per feature: mean and variance inside every class, plus mean of the class means, variance over all data,
 * variance of the class means and the number of classes in which the feature has a positive mean,</li>
 * <li>per input vector: how its nearest neighbours are spread over the classes,</li>
 * <li>a histogram of how many features occur in how many classes,</li>
 * <li>per class: with how many <em>other</em> classes its features are shared.</li>
 * </ol>
 * All output goes to <code>System.out</code>.
 *
 * @author Rudolf Mayer
 * @version $Id: FeatureDistributionAnalysis.java 3589 2010-05-21 10:42:01Z mayer $
 */
public class FeatureDistributionAnalysis {

    private static final String SEPARATOR = "---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------";

    /** Column width of the per-feature statistics table. */
    private static final int COLUMN_WIDTH = 9;

    /** Column width used by the nearest-neighbour and co-occurrence tables. */
    private static final int NARROW_COLUMN_WIDTH = COLUMN_WIDTH - 2;

    /** Width of the leading class-name column of the co-occurrence table. */
    private static final int WIDE_COLUMN_WIDTH = COLUMN_WIDTH + 4;

    /** Width of the leading vector-label column of the nearest-neighbour table. */
    private static final int LABEL_COLUMN_WIDTH = 15;

    /**
     * Parses the command line, loads the input vectors and class information, computes the per-class and
     * per-feature statistics and prints all reports.
     *
     * @param args command-line arguments; the options <code>inputVectorFile</code> and
     *            <code>classInformationFile</code> are read from them
     * @throws SOMToolboxException propagated from the SOMToolbox calls made while loading and analysing the data
     */
    public static void main(String[] args) throws SOMToolboxException {
        // register and parse all options
        JSAPResult config = OptionFactory.parseResults(args, OptionFactory.getOptInputVectorFile(true),
                OptionFactory.getOptClassInformationFile(false));
        String inputVectorFileName = config.getString("inputVectorFile");
        String classInfoFileName = config.getString("classInformationFile");

        SOMLibSparseInputData input = new SOMLibSparseInputData(inputVectorFileName);
        // NOTE(review): the class information option is registered with 'false' (presumably "not required"), yet
        // every report below depends on it - confirm whether a missing file should be rejected explicitly here.
        SOMLibClassInformation classInfo = new SOMLibClassInformation(classInfoFileName);
        input.setClassInfo(classInfo);

        String[] classNames = classInfo.getClassNames();
        final int dim = input.dim();

        double[][] means = new double[classNames.length][];
        double[][] variances = new double[classNames.length][dim];
        double[] classVariances = new double[dim];
        double[] totalVariances = new double[dim];
        double[] aggMeans = new double[dim];
        int[] occurrences = new int[dim];

        final L2Metric metric = new L2Metric();
        input.initDistanceMatrix(metric);
        final double[][] data = input.getData();

        // per class: mean vector, per-feature variance, and count of classes in which a feature has a positive mean
        for (int i = 0; i < classNames.length; i++) {
            double[][] classData = input.getData(classNames[i]);
            means[i] = VectorTools.meanVector(classData);
            for (int j = 0; j < means[i].length; j++) {
                if (means[i][j] > 0) {
                    occurrences[j]++;
                }
            }
            for (int j = 0; j < variances[i].length; j++) {
                variances[i][j] = StatUtils.variance(VectorTools.slice(classData, j));
            }
        }

        // per feature: variance and mean of the class means, and variance over the complete data set
        for (int i = 0; i < dim; i++) {
            double[] classMeansOfFeature = VectorTools.slice(means, i);
            classVariances[i] = StatUtils.variance(classMeansOfFeature);
            aggMeans[i] = StatUtils.mean(classMeansOfFeature);
            totalVariances[i] = StatUtils.variance(VectorTools.slice(data, i));
        }

        // output
        System.out.println("\n");
        String[] classNameDifferences = StringUtils.getDifferences(classNames);

        printFeatureStatistics(classNameDifferences, means, variances, aggMeans, totalVariances, classVariances,
                occurrences);
        System.out.println(SEPARATOR + "\n");

        printNearestNeighbours(input, classInfo, metric, classNames.length, classNameDifferences);
        System.out.println(SEPARATOR + "\n");

        printOccurrenceHistogram(dim, occurrences);
        System.out.println(SEPARATOR + "\n");

        printCoOccurrence(classNames.length, classNameDifferences, means, dim);
    }

    /**
     * Prints one row per feature: mean and variance inside each class, followed by the summary columns
     * "Mean", "Var", "Var.Mean" and "#Occ".
     */
    private static void printFeatureStatistics(String[] classNameDifferences, double[][] means,
            double[][] variances, double[] aggMeans, double[] totalVariances, double[] classVariances,
            int[] occurrences) {
        for (int i = 0; i < classNameDifferences.length; i++) {
            System.out.print(StringUtils.pad(StringUtils.formatEndMaxLengthEllipsis(classNameDifferences[i],
                    COLUMN_WIDTH, ".."), COLUMN_WIDTH));
            System.out.print(StringUtils.pad("Var", COLUMN_WIDTH)
                    + (i + 1 < classNameDifferences.length ? " | " : ""));
        }
        System.out.print("|| " + StringUtils.pad("Mean", COLUMN_WIDTH) + StringUtils.pad("Var", COLUMN_WIDTH)
                + StringUtils.pad("Var.Mean", COLUMN_WIDTH) + StringUtils.pad("#Occ", COLUMN_WIDTH));
        System.out.println("\n" + SEPARATOR);

        // Iterate over the per-feature arrays rather than means[0], so that a data set without classes
        // does not fail with an ArrayIndexOutOfBoundsException.
        for (int feature = 0; feature < aggMeans.length; feature++) {
            for (int c = 0; c < means.length; c++) {
                System.out.print(StringUtils.pad(means[c][feature], COLUMN_WIDTH));
                System.out.print(StringUtils.pad(variances[c][feature], COLUMN_WIDTH)
                        + (c + 1 < means.length ? " | " : ""));
            }
            // The value order must follow the header: "Mean", "Var" (variance over all data),
            // "Var.Mean" (variance of the class means), "#Occ".
            System.out.print("|| " + StringUtils.pad(aggMeans[feature], COLUMN_WIDTH));
            System.out.print(StringUtils.pad(totalVariances[feature], COLUMN_WIDTH));
            System.out.print(StringUtils.pad(classVariances[feature], COLUMN_WIDTH));
            System.out.print(StringUtils.pad(occurrences[feature], COLUMN_WIDTH));
            System.out.println();
        }
    }

    /**
     * Prints, for every input vector, how its nearest neighbours are distributed over the classes. The number of
     * neighbours requested is the size of the vector's own class minus one; "Purity" is the share of those
     * neighbours that belong to the vector's own class.
     */
    private static void printNearestNeighbours(SOMLibSparseInputData input, SOMLibClassInformation classInfo,
            L2Metric metric, int classCount, String[] classNameDifferences) throws SOMToolboxException {
        System.out.println("Nearest neighbours");
        System.out.print(StringUtils.pad("Weight-Vec", LABEL_COLUMN_WIDTH));
        for (String classNameDifference : classNameDifferences) {
            System.out.print(StringUtils.pad(StringUtils.formatEndMaxLengthEllipsis(classNameDifference,
                    NARROW_COLUMN_WIDTH, ".."), NARROW_COLUMN_WIDTH));
            System.out.print(" | ");
        }
        System.out.print(StringUtils.pad("Purity", NARROW_COLUMN_WIDTH));
        System.out.print(StringUtils.pad("MapSize", NARROW_COLUMN_WIDTH));
        System.out.println(StringUtils.pad("# Neighb", NARROW_COLUMN_WIDTH));

        String[] differences = StringUtils.getDifferences(input.getLabels());
        for (int i = 0; i < input.numVectors(); i++) {
            int[] perClass = new int[classCount];
            InputDatum inputDatum = input.getInputDatum(i);
            int classIndex = classInfo.getClassIndexForInput(inputDatum.getLabel());
            int classSize = classInfo.getNumberOfClassMembers(classIndex);
            // the vector itself is not one of its own neighbours
            int neighbourCount = classSize - 1;

            // A class with a single member has no neighbours to look up; skipping the query also avoids
            // the 0/0 division (NaN) in the purity computation below.
            if (neighbourCount > 0) {
                final InputDatum[] nearestN = input.getNearestN(i, metric, neighbourCount);
                for (InputDatum neighbour : nearestN) {
                    perClass[classInfo.getClassIndexForInput(neighbour.getLabel())]++;
                }
            }

            System.out.print(StringUtils.pad(differences[i], LABEL_COLUMN_WIDTH));
            for (int count : perClass) {
                System.out.print(StringUtils.pad(count, NARROW_COLUMN_WIDTH) + " | ");
            }
            String purity;
            if (neighbourCount > 0) {
                purity = StringUtils.format(perClass[classIndex] * 100.0 / neighbourCount, 2) + "%";
            } else {
                purity = "n/a";
            }
            System.out.print(StringUtils.pad(purity, NARROW_COLUMN_WIDTH));
            System.out.print(StringUtils.pad(classSize, NARROW_COLUMN_WIDTH));
            System.out.println(neighbourCount);
        }
    }

    /**
     * Prints the total number of features and, for every distinct occurrence count, how many features have a
     * positive mean in exactly that many classes.
     */
    private static void printOccurrenceHistogram(int dim, int[] occurrences) {
        System.out.println("Total features: " + dim);
        ElementCounter<Integer> counter = new ElementCounter<Integer>();
        for (int occurrence : occurrences) {
            counter.incCount(occurrence);
        }
        Integer[] keys = counter.keySet().toArray(new Integer[counter.size()]);
        // presumably sorts descending (inverse of natural order) - verify against InverseComparator
        Arrays.sort(keys, new InverseComparator<Integer>());
        for (Integer key : keys) {
            System.out.println(key + " times: " + counter.getCount(key));
        }
    }

    /**
     * Prints, for every class, how many of its features (those with a positive class mean) are shared with
     * 0, 1, 2, ... other classes, followed by the number of features the class uses and the data dimensionality.
     */
    private static void printCoOccurrence(int classCount, String[] classNameDifferences, double[][] means, int dim) {
        System.out.println("Co-occurence of terms with other classes ");
        System.out.print(StringUtils.pad("Class/Count", WIDE_COLUMN_WIDTH) + " | ");
        for (int i = 0; i < classCount; i++) {
            System.out.print(StringUtils.pad(i, NARROW_COLUMN_WIDTH));
        }
        System.out.print(StringUtils.pad(" | Total", NARROW_COLUMN_WIDTH));
        System.out.print(StringUtils.pad(" | Dim", NARROW_COLUMN_WIDTH));
        System.out.println();
        System.out.println(StringUtils.repeatString((classCount + 2) * NARROW_COLUMN_WIDTH + WIDE_COLUMN_WIDTH,
                "-"));

        for (int i = 0; i < classNameDifferences.length; i++) {
            System.out.print(StringUtils.pad(classNameDifferences[i], WIDE_COLUMN_WIDTH) + " | ");
            // check terms this class uses
            ElementCounter<Integer> sharedWith = new ElementCounter<Integer>();
            for (int j = 0; j < means[i].length; j++) {
                if (means[i][j] > 0) {
                    // count how often these terms are used in other classes
                    int otherClassTerms = 0;
                    for (int k = 0; k < classNameDifferences.length; k++) {
                        if (k != i && means[k][j] > 0) {
                            otherClassTerms++;
                        }
                    }
                    sharedWith.incCount(otherClassTerms);
                }
            }
            for (int j = 0; j < classNameDifferences.length; j++) {
                System.out.print(StringUtils.pad(sharedWith.getCount(j), NARROW_COLUMN_WIDTH));
            }
            System.out.print(StringUtils.pad(" | " + sharedWith.totalCount(), NARROW_COLUMN_WIDTH));
            System.out.print(StringUtils.pad(" | " + dim, NARROW_COLUMN_WIDTH));
            System.out.println();
        }
    }
}