/*
* Copyright (c) 2017 OBiBa. All rights reserved.
*
* This program and the accompanying materials
* are made available under the terms of the GNU Public License v3.0.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.obiba.magma.math.summary;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.validation.constraints.NotNull;
import org.obiba.magma.Category;
import org.obiba.magma.Value;
import org.obiba.magma.ValueSource;
import org.obiba.magma.ValueTable;
import org.obiba.magma.Variable;
import org.obiba.magma.type.BooleanType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
/**
 * Summary of a variable treated as categorical: counts how often each value occurs and exposes the
 * resulting frequencies, the most frequent value (mode) and the total number of counted observations.
 * <p>
 * Instances are created and populated through {@link Builder}. Null values are counted under
 * {@link #NULL_NAME}. Unless {@link #isDistinct() distinct} is set, non-null values that are not one of the
 * variable's category names (or, for boolean variables, the true/false values) are counted together and
 * reported by {@link #getOtherFrequency()}.
 */
public class CategoricalVariableSummary extends AbstractVariableSummary implements Serializable {

  private static final long serialVersionUID = 203198842420473154L;

  private static final Logger log = LoggerFactory.getLogger(CategoricalVariableSummary.class);

  /**
   * Name under which null values are counted.
   */
  public static final String NULL_NAME = "N/A";

  /**
   * Name under which non-null values that match no category name are counted when not in distinct mode.
   */
  private static final String OTHER_NAME = "OTHER_VALUES";

  /**
   * Raw counts, keyed by the string form of each value (or {@link #NULL_NAME} / {@link #OTHER_NAME}).
   */
  private final org.apache.commons.math3.stat.Frequency frequencyDist = new org.apache.commons.math3.stat.Frequency();

  /**
   * Mode is the most frequent value
   */
  private String mode = NULL_NAME;

  /**
   * Sum of all counts held by the frequency distribution, set when the summary is built.
   */
  private long n;

  /**
   * When true, every observed value gets its own frequency entry instead of being folded into "other".
   */
  private boolean distinct;

  /**
   * True until at least one value has been added through the builder.
   */
  private boolean empty = true;

  private final Collection<Frequency> frequencies = new ArrayList<>();

  /**
   * Count of values recorded under {@link #OTHER_NAME}, set when the summary is built.
   */
  private long otherFrequency;

  private CategoricalVariableSummary(@NotNull Variable variable) {
    super(variable);
  }

  /**
   * Returns the cache key for this summary, as computed by
   * {@code CategoricalVariableSummaryFactory.getCacheKey} from the variable, the given table, the distinct
   * flag and the offset/limit filter.
   *
   * @param table the table the summary applies to
   * @return the cache key
   */
  @Override
  public String getCacheKey(ValueTable table) {
    return CategoricalVariableSummaryFactory.getCacheKey(variable, table, distinct, getOffset(), getLimit());
  }

  /**
   * Returns an immutable snapshot of the frequencies computed by the last {@link Builder#build()}.
   *
   * @return the frequencies, never null
   */
  @NotNull
  public Iterable<Frequency> getFrequencies() {
    return ImmutableList.copyOf(frequencies);
  }

  /**
   * @return the most frequent value among the reported frequencies, or {@link #NULL_NAME} when none of them
   * has a positive count
   */
  public String getMode() {
    return mode;
  }

  /**
   * @return the sum of all counts (null and "other" values included)
   */
  public long getN() {
    return n;
  }

  public boolean isDistinct() {
    return distinct;
  }

  public void setDistinct(boolean distinct) {
    this.distinct = distinct;
  }

  /**
   * @return true if no value was ever added to this summary
   */
  public boolean isEmpty() {
    return empty;
  }

  /**
   * @return the number of values counted under the "other values" bucket
   */
  public long getOtherFrequency() {
    return otherFrequency;
  }

  /**
   * Immutable frequency entry for one value of the distribution.
   */
  public static class Frequency implements Serializable {

    private static final long serialVersionUID = -2876592652764310324L;

    private final String value;

    private final long freq;

    private final double pct;

    private final boolean missing;

    /**
     * @param value the value name
     * @param freq the number of occurrences of the value
     * @param pct the share of the value in the distribution, as supplied by the caller
     * @param missing whether the value is considered missing
     */
    public Frequency(String value, long freq, double pct, boolean missing) {
      this.value = value;
      this.freq = freq;
      this.pct = pct;
      this.missing = missing;
    }

    public String getValue() {
      return value;
    }

    public long getFreq() {
      return freq;
    }

    public double getPct() {
      return pct;
    }

    public boolean isMissing() {
      return missing;
    }
  }

  /**
   * Builds a {@link CategoricalVariableSummary} either from individual values ({@link #addValue(Value)}) or
   * from a whole table ({@link #addTable(ValueTable, ValueSource)}); the two modes cannot be mixed.
   */
  @SuppressWarnings("ParameterHidesMemberVariable")
  public static class Builder implements VariableSummaryBuilder<CategoricalVariableSummary, Builder> {

    private final CategoricalVariableSummary summary;

    @NotNull
    private final Variable variable;

    private boolean addedTable;

    private boolean addedValue;

    public Builder(@NotNull Variable variable) {
      this.variable = variable;
      summary = new CategoricalVariableSummary(variable);
    }

    /**
     * Counts a single value.
     *
     * @param value the value to count, must not be null (a null {@link Value} object is allowed)
     * @return this builder
     * @throws IllegalStateException if values were already added with {@link #addTable(ValueTable, ValueSource)}
     */
    @Override
    public Builder addValue(@NotNull Value value) {
      if(addedTable) {
        throw new IllegalStateException("Cannot add value for variable " + summary.variable.getName() +
            " because values where previously added from the whole table with addTable().");
      }
      add(value, categoryNames());
      addedValue = true;
      return this;
    }

    /**
     * Counts all the values that the value source provides for the filtered entities of the table.
     *
     * @param table the table providing the entities
     * @param valueSource the source of the variable's values
     * @return this builder
     * @throws IllegalStateException if values were already added with {@link #addValue(Value)}
     */
    @Override
    public Builder addTable(@NotNull ValueTable table, @NotNull ValueSource valueSource) {
      if(addedValue) {
        throw new IllegalStateException("Cannot add table for variable " + summary.variable.getName() +
            " because values where previously added with addValue().");
      }
      add(table, valueSource);
      addedTable = true;
      return this;
    }

    /**
     * Counts every value of the vector source; does nothing when the value source has no vector support.
     */
    private void add(@NotNull ValueTable table, @NotNull ValueSource variableValueSource) {
      //noinspection ConstantConditions
      Preconditions.checkArgument(table != null, "table cannot be null");
      //noinspection ConstantConditions
      Preconditions.checkArgument(variableValueSource != null, "variableValueSource cannot be null");
      if(!variableValueSource.supportVectorSource()) return;

      // The category names do not depend on the value being added: build the list once, not once per value.
      List<String> categoryNames = categoryNames();
      for(Value value : variableValueSource.asVectorSource().getValues(summary.getFilteredVariableEntities(table))) {
        add(value, categoryNames);
      }
    }

    /**
     * Counts one value. A null value (sequence or not) is counted once under {@link #NULL_NAME}; a non-null
     * sequence is counted item by item; any other value is counted under its string form when in distinct
     * mode or when it is a known category name, and under {@link #OTHER_NAME} otherwise.
     */
    private void add(@NotNull Value value, List<String> categoryNames) {
      //noinspection ConstantConditions
      Preconditions.checkArgument(value != null, "value cannot be null");
      summary.empty = false;

      if(value.isNull()) {
        summary.frequencyDist.addValue(NULL_NAME);
        return;
      }
      if(value.isSequence()) {
        for(Value item : value.asSequence().getValue()) {
          add(item, categoryNames);
        }
        return;
      }
      String name = value.toString();
      boolean ownEntry = summary.distinct || categoryNames.contains(name);
      summary.frequencyDist.addValue(ownEntry ? name : OTHER_NAME);
    }

    /**
     * Returns an iterator of frequencyDist names
     */
    private Iterator<String> freqNames(org.apache.commons.math3.stat.Frequency freq) {
      return Iterators.transform(freq.valuesIterator(), new Function<Comparable<?>, String>() {
        @Override
        public String apply(Comparable<?> input) {
          return input.toString();
        }
      });
    }

    /**
     * Returns the list of category names: the string forms of the boolean true and false values (in that
     * order) for a boolean variable, the names of the variable's categories otherwise.
     */
    private List<String> categoryNames() {
      if(variable.getValueType().equals(BooleanType.get())) {
        return ImmutableList.<String>builder() //
            .add(BooleanType.get().trueValue().toString()) //
            .add(BooleanType.get().falseValue().toString()).build();
      }
      return Lists.newArrayList(Iterables.transform(variable.getCategories(), new Function<Category, String>() {
        @Override
        public String apply(Category from) {
          return from.getName();
        }
      }));
    }

    /**
     * Returns the variable's categories indexed by their name.
     */
    private Map<String, Category> getCategoriesByName() {
      return Maps.uniqueIndex(variable.getCategories(), new Function<Category, String>() {
        @Override
        public String apply(Category input) {
          return input.getName();
        }
      });
    }

    /**
     * Tells whether a frequency name denotes a non-missing value: boolean true/false for a boolean variable,
     * a category that is not flagged as missing otherwise.
     */
    private boolean isNonMissing(String value, boolean booleanVariable, Map<String, Category> categoriesByName) {
      if(booleanVariable) {
        return value.equals(BooleanType.get().trueValue().toString()) ||
            value.equals(BooleanType.get().falseValue().toString());
      }
      Category category = categoriesByName.get(value);
      return category != null && !category.isMissing();
    }

    /**
     * Derives the frequencies, the mode, the "other" count and the total count from the values counted so far.
     */
    private void compute() {
      log.trace("Start compute categorical {}", summary.variable.getName());

      // Start from a clean state so that calling build() again does not duplicate the frequency entries.
      summary.frequencies.clear();
      summary.mode = NULL_NAME;

      org.apache.commons.math3.stat.Frequency dist = summary.frequencyDist;
      Iterator<String> names = summary.distinct //
          ? freqNames(dist) // category names, null values and distinct values
          : Iterators.concat(categoryNames().iterator(),
              ImmutableList.of(NULL_NAME).iterator()); // category names and null values

      Map<String, Category> categoriesByName = getCategoriesByName();
      boolean booleanVariable = variable.getValueType().equals(BooleanType.get());

      // Iterate over all category names including or not distinct values.
      // The loop will also determine the mode of the distribution (most frequent value);
      // on ties the first name encountered is kept.
      long max = 0;
      while(names.hasNext()) {
        String value = names.next();
        long count = dist.getCount(value);
        if(count > max) {
          max = count;
          summary.mode = value;
        }
        // getPct() is NaN when nothing was counted at all: report 0.0 in that case.
        double pct = dist.getPct(value);
        if(Double.isNaN(pct)) pct = 0.0;

        boolean missing = !isNonMissing(value, booleanVariable, categoriesByName);
        summary.frequencies.add(new Frequency(value, count, pct, missing));
      }
      summary.otherFrequency = dist.getCount(OTHER_NAME);
      summary.n = dist.getSumFreq();
    }

    /**
     * Sets whether each observed value gets its own frequency entry.
     *
     * @return this builder
     */
    public Builder distinct(boolean distinct) {
      summary.setDistinct(distinct);
      return this;
    }

    /**
     * Sets the offset and limit of the summary.
     *
     * @return this builder
     */
    public Builder filter(Integer offset, Integer limit) {
      summary.setOffset(offset);
      summary.setLimit(limit);
      return this;
    }

    /**
     * Computes the statistics from the values counted so far and returns the summary. Each call recomputes
     * the statistics on the same summary instance.
     *
     * @return the summary, never null
     */
    @Override
    @NotNull
    public CategoricalVariableSummary build() {
      compute();
      return summary;
    }

    @NotNull
    @Override
    public Variable getVariable() {
      return variable;
    }
  }
}