/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.udf.stats;
import java.util.*;
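/**
* Counter for a categorical variable: tallies how often each known category value occurs and how many values are
* missing or unrecognized.
*/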
public class CategoryCounter extends Counter {

    // Per-category values, read by the category's position in `categories`.
    // NOTE(review): presumably the positive rate of each bin - confirm against the caller.
    private final List<Double> binPosRate;

    // Known category values, in the order getCounter() reports them.
    private final List<String> categories;

    // Raw values that are always counted as missing.
    private final Set<String> missingValSet = new HashSet<String>();

    // Category value -> its position in `categories` (the last position wins if a value is listed twice).
    private final Map<String, Integer> categoryValIndex = new HashMap<String, Integer>();

    // Category value -> number of times it has been seen. Values are never null.
    private final Map<String, Long> categoryMap = new HashMap<String, Long>();

    // Number of null, configured-missing, or unrecognized values seen so far.
    private long missCounter = 0L;

    // Running sum of the binPosRate entries of every recognized value.
    private double unitSum = 0.0;

    /**
     * Creates a counter with a zero count for every given category.
     *
     * @param missingInvalidValues
     *            raw values to be counted as missing; {@code null} is treated as an empty list
     * @param categories
     *            the known category values; must not be {@code null}
     * @param binPosRate
     *            values looked up by a category's position in {@code categories} when that category is counted
     *            (presumably the bin's positive rate - TODO confirm)
     */
    public CategoryCounter(List<String> missingInvalidValues, List<String> categories, List<Double> binPosRate) {
        // A null list simply means there are no configured missing values.
        if(missingInvalidValues != null) {
            this.missingValSet.addAll(missingInvalidValues);
        }
        this.categories = categories;
        this.binPosRate = binPosRate;
        for(int i = 0; i < categories.size(); i++) {
            String category = categories.get(i);
            categoryMap.put(category, 0L);
            categoryValIndex.put(category, i);
        }
    }

    /**
     * Counts one raw value. A {@code null} value, a value in the missing set, or a value that is not a known
     * category increments the missing counter; the missing set takes precedence over the category list. Any other
     * value increments its category count and adds the category's {@code binPosRate} entry to the unit sum.
     *
     * @param val
     *            the raw value to count, may be {@code null}
     */
    @Override
    public void addData(String val) {
        Long seen = null;
        if(val != null && !this.missingValSet.contains(val)) {
            // Stored counts are never null, so a null result means "not a known category".
            seen = categoryMap.get(val);
        }
        if(seen == null) {
            missCounter++;
            return;
        }

        // Read the rate before touching any state so a failing lookup cannot leave
        // the category count incremented without the matching unitSum update.
        int index = categoryValIndex.get(val);
        double rate = this.binPosRate.get(index);

        categoryMap.put(val, seen + 1);
        this.unitSum += rate;
    }

    /**
     * Returns the count of each category, in the order of the category list, followed by the missing count as the
     * last element.
     *
     * @return a new list of size {@code categories.size() + 1}
     */
    @Override
    public List<Long> getCounter() {
        List<Long> counters = new ArrayList<Long>(categories.size() + 1);
        for(String category: categories) {
            counters.add(categoryMap.get(category));
        }
        counters.add(missCounter);
        return counters;
    }

    /**
     * Returns the unit sum divided by the total instance count (missing values included in the divisor).
     *
     * @return the mean, or {@code Double.NaN} when no recognized category value has been counted
     */
    @Override
    public double getUnitMean() {
        long total = getTotalInstCnt();
        if(total == 0 || total == missCounter) {
            return Double.NaN;
        }
        return this.unitSum / total;
    }

    /**
     * Returns the fraction of counted values that were missing.
     *
     * @return missing count divided by total count, or {@code 0.0} when nothing has been counted
     */
    @Override
    public double getMissingRate() {
        long total = getTotalInstCnt();
        if(total == 0) {
            return 0.0;
        }
        // Promote to double so the division is not truncated.
        return (double) missCounter / total;
    }

    /**
     * Returns the number of values counted so far: the sum of all category counts plus the missing count.
     *
     * @return the total instance count
     */
    @Override
    public long getTotalInstCnt() {
        long total = missCounter;
        for(long count: categoryMap.values()) {
            total += count;
        }
        return total;
    }
}