/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.smile.tools;
import hivemall.utils.collections.IntArrayList;
import hivemall.utils.lang.Counter;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
@SuppressWarnings("deprecation")
@Description(name = "rf_ensemble",
        value = "_FUNC_(int y) - Returns ensembled prediction results of Random Forest classifiers")
public final class RandomForestEnsembleUDAF extends UDAF {

    /**
     * Legacy (pre-GenericUDAF) evaluator that aggregates the predicted class labels of
     * individual Random Forest trees by majority vote.
     *
     * <p>Lifecycle per Hive's UDAF contract: {@link #init()} resets state,
     * {@link #iterate(Integer)} accumulates one prediction, {@link #terminatePartial()}
     * emits a serializable partial, {@link #merge(Map)} combines partials, and
     * {@link #terminate()} produces the final {@link Result}.
     */
    public static class RandomForestPredictUDAFEvaluator implements UDAFEvaluator {

        /** Vote count per class label; lazily allocated on the first non-null input. */
        private Counter<Integer> partial;

        @Override
        public void init() {
            this.partial = null;
        }

        /**
         * Accumulates one predicted label.
         *
         * @param k predicted class label; {@code null} inputs are skipped
         * @return always {@code true} (Hive's convention for successful iteration)
         */
        public boolean iterate(Integer k) {
            if (k == null) {
                return true;
            }
            if (partial == null) {
                this.partial = new Counter<Integer>();
            }
            partial.increment(k);
            return true;
        }

        /*
         * https://cwiki.apache.org/confluence/display/Hive/GenericUDAFCaseStudy#GenericUDAFCaseStudy-terminatePartial
         */
        public Map<Integer, Integer> terminatePartial() {
            if (partial == null) {
                return null;
            }
            if (partial.size() == 0) {
                return null;
            } else {
                // Hive cannot serialize arbitrary objects across stages, so the raw
                // label->count map is returned instead of the Counter itself.
                return partial.getMap(); // CAN NOT return Counter here
            }
        }

        /**
         * Folds a partial label->count map (from {@link #terminatePartial()}) into this
         * evaluator's running counts.
         *
         * @param o partial aggregation result; {@code null} partials are skipped
         * @return always {@code true}
         */
        public boolean merge(Map<Integer, Integer> o) {
            if (o == null) {
                return true;
            }
            if (partial == null) {
                this.partial = new Counter<Integer>();
            }
            partial.addAll(o);
            return true;
        }

        /**
         * @return the majority-vote {@link Result}, or {@code null} when no votes were seen
         */
        public Result terminate() {
            if (partial == null) {
                return null;
            }
            if (partial.size() == 0) {
                return null;
            }
            return new Result(partial);
        }
    }

    /**
     * Final aggregation result. Fields are read reflectively by Hive's object inspectors,
     * hence the {@code unused} suppressions.
     */
    public static final class Result {
        /** Majority-vote class label. */
        @SuppressWarnings("unused")
        private Integer label;
        /** Fraction of votes received by {@link #label}. */
        @SuppressWarnings("unused")
        private Double probability;
        /** Vote fraction for each label in {@code [0, maxLabel]} (length >= 2). */
        @SuppressWarnings("unused")
        private List<Double> probabilities;

        Result(Counter<Integer> partial) {
            final Map<Integer, Integer> counts = partial.getMap();
            final int size = counts.size();
            assert (size > 0) : size;

            long totalCnt = 0L;
            Integer maxKey = null;
            int maxCnt = Integer.MIN_VALUE;
            // Largest label value seen; determines the probability vector length.
            // (Replaces the previous collect-then-sort pass, which was O(n log n)
            // just to find the maximum key.)
            int maxLabel = Integer.MIN_VALUE;
            for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
                final Integer key = e.getKey();
                final int cnt = e.getValue().intValue();
                totalCnt += cnt;
                // '>=' keeps the original tie-break: the last entry in iteration
                // order with the maximal count wins.
                if (cnt >= maxCnt) {
                    maxCnt = cnt;
                    maxKey = key;
                }
                final int k = key.intValue();
                if (k > maxLabel) {
                    maxLabel = k;
                }
            }

            final double totalCnt_d = (double) totalCnt;
            // One slot per label in [0, maxLabel]; at least 2 slots so binary
            // classification always yields probabilities for both classes.
            final Double[] probabilities = new Double[Math.max(2, maxLabel + 1)];
            for (int i = 0, len = probabilities.length; i < len; i++) {
                final Integer cnt = counts.get(Integer.valueOf(i));
                if (cnt == null) {
                    probabilities[i] = Double.valueOf(0d);
                } else {
                    probabilities[i] = Double.valueOf(cnt.intValue() / totalCnt_d);
                }
            }
            this.label = maxKey;
            this.probability = Double.valueOf(maxCnt / totalCnt_d);
            this.probabilities = Arrays.asList(probabilities);
        }
    }
}