/**
 * Copyright 2013-2015 Pierre Merienne
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.github.pmerienne.trident.ml.evaluation;

import java.util.Objects;

import storm.trident.operation.CombinerAggregator;
import storm.trident.tuple.TridentTuple;

import com.github.pmerienne.trident.ml.core.Instance;
import com.github.pmerienne.trident.ml.evaluation.AccuracyAggregator.AccuracyState;

/**
 * Combiner aggregator that counts how many predictions differ from the label
 * carried by the corresponding {@link Instance}.
 * <p>
 * Each tuple is turned into an {@link AccuracyState} holding one observation,
 * and states are merged by summing their counters.
 *
 * @param <L>
 *            type of the label / prediction
 */
public class AccuracyAggregator<L> implements CombinerAggregator<AccuracyState<L>> {

	private static final long serialVersionUID = 1136784137149485843L;

	/**
	 * Builds the state for a single tuple.
	 * <p>
	 * Field 0 of the tuple is read as the {@link Instance} and field 1 as the
	 * prediction. The prediction is compared with the instance label using
	 * {@link Objects#equals(Object, Object)}, so two {@code null} values are
	 * considered equal.
	 *
	 * @param tuple
	 *            tuple holding the instance (index 0) and the prediction
	 *            (index 1)
	 * @return a state with a total count of 1 and an error count of 0 when the
	 *         prediction equals the label, 1 otherwise
	 */
	// The tuple API only exposes Object values, so both casts below are
	// unchecked; the suppression is limited to this method on purpose.
	@SuppressWarnings("unchecked")
	@Override
	public AccuracyState<L> init(TridentTuple tuple) {
		Instance<L> instance = (Instance<L>) tuple.getValue(0);
		L prediction = (L) tuple.getValue(1);
		L expected = instance.getLabel();

		long errors = Objects.equals(expected, prediction) ? 0L : 1L;
		return new AccuracyState<L>(1L, errors);
	}

	/**
	 * Merges two partial states by summing their total and error counts.
	 *
	 * @param val1
	 *            first partial state
	 * @param val2
	 *            second partial state
	 * @return a new state holding the summed counters
	 */
	@Override
	public AccuracyState<L> combine(AccuracyState<L> val1, AccuracyState<L> val2) {
		long total = val1.totalCount + val2.totalCount;
		long errors = val1.errorCount + val2.errorCount;
		return new AccuracyState<L>(total, errors);
	}

	/**
	 * @return an empty state (both counters at zero)
	 */
	@Override
	public AccuracyState<L> zero() {
		return new AccuracyState<L>();
	}

	/**
	 * Immutable pair of counters: number of evaluated instances and number of
	 * wrong predictions among them.
	 *
	 * @param <L>
	 *            type of the label / prediction
	 */
	public static class AccuracyState<L> implements Evaluation<L> {

		private static final long serialVersionUID = 938679193655075913L;

		// Both fields are always assigned from a primitive long, hence never null.
		private final Long totalCount;
		private final Long errorCount;

		/**
		 * Creates an empty state with both counters set to zero.
		 */
		public AccuracyState() {
			this(0L, 0L);
		}

		/**
		 * @param totalCount
		 *            number of evaluated instances
		 * @param errorCount
		 *            number of wrong predictions
		 */
		public AccuracyState(long totalCount, long errorCount) {
			this.totalCount = totalCount;
			this.errorCount = errorCount;
		}

		/**
		 * Computes {@code 1 - errorCount / totalCount} using floating point
		 * division.
		 *
		 * @return the ratio of correct predictions; {@link Double#NaN} when
		 *         the total count is zero (0.0 / 0.0)
		 */
		@Override
		public double getEvaluation() {
			return 1 - errorCount.doubleValue() / totalCount.doubleValue();
		}

		@Override
		public int hashCode() {
			// Same value as the classic "31 * result + field hash" expansion
			// over (errorCount, totalCount), in that order.
			return Objects.hash(errorCount, totalCount);
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj) {
				return true;
			}
			if (obj == null || getClass() != obj.getClass()) {
				return false;
			}
			AccuracyState<?> other = (AccuracyState<?>) obj;
			return Objects.equals(errorCount, other.errorCount)
					&& Objects.equals(totalCount, other.totalCount);
		}

		@Override
		public String toString() {
			return "AccuracyState [totalCount=" + totalCount + ", errorCount=" + errorCount + "]";
		}
	}
}