/*
 * Seldon -- open source prediction engine
 * =======================================
 * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
 *
 **********************************************************************************************
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 **********************************************************************************************
 */
package io.seldon.vw;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.databind.JsonNode;

import io.seldon.api.rpc.ClassificationReply;
import io.seldon.api.rpc.ClassificationRequest;
import io.seldon.clustering.recommender.RecommendationContext.OptionsHolder;
import io.seldon.prediction.PredictionAlgorithm;
import io.seldon.prediction.PredictionResult;
import io.seldon.prediction.PredictionServiceResult;
import io.seldon.vw.VwFeatureExtractor.Namespace;
import io.seldon.vw.VwModelManager.VwModel;

/**
 * Prediction algorithm that scores a JSON request against the weights of a
 * client's {@link VwModel}.
 *
 * <p>For each class index {@code 1..model.oaa} a linear score is built from the
 * hashed feature weights plus the hashed constant weight. The score is squashed
 * with a sigmoid to obtain a confidence, and the confidences are finally
 * normalised so that they sum to one.
 */
@Component
public class VwClassifier implements PredictionAlgorithm {

	private static final Logger logger = Logger.getLogger(VwClassifier.class.getName());

	VwModelManager modelManager;
	VwFeatureExtractor featureExtractor;

	/**
	 * @param modelManager     source of the per-client models
	 * @param featureExtractor converts an incoming JSON document into namespaces of features
	 */
	@Autowired
	public VwClassifier(VwModelManager modelManager, VwFeatureExtractor featureExtractor) {
		this.modelManager = modelManager;
		this.featureExtractor = featureExtractor;
	}

	/**
	 * Logistic function {@code 1 / (1 + e^-val)}.
	 *
	 * @param val raw score
	 * @return a value in {@code [0, 1]}; exactly {@code 0.0} once {@code Math.exp(-val)} overflows
	 */
	private double sigmoid(double val) {
		return 1.0 / (1.0 + Math.exp(-val));
	}

	/**
	 * Rescales the confidences in place so that they sum to one.
	 *
	 * <p>When the confidences sum to zero (every sigmoid underflowed, or the list
	 * is empty) the list is returned untouched; dividing would otherwise turn
	 * every confidence into NaN.
	 *
	 * <p>NOTE(review): with a single prediction and a non-zero confidence the
	 * result is always 1.0 regardless of the score - confirm this is intended for
	 * the binary case.
	 *
	 * @param scores predictions whose {@code confidence} fields are rewritten
	 * @return the same list instance
	 */
	private List<PredictionResult> normalise(List<PredictionResult> scores) {
		double sum = 0.0;
		for (PredictionResult p : scores) {
			sum = sum + p.confidence;
		}
		if (sum == 0.0) {
			// Nothing to scale by; avoid producing NaN via 0/0.
			return scores;
		}
		for (PredictionResult p : scores) {
			p.confidence = p.confidence / sum;
		}
		return scores;
	}

	/**
	 * Computes the raw linear score of one class: the sum of
	 * {@code featureValue * weight} over every extracted feature that has a weight
	 * in the model, plus the class's constant weight when present.
	 *
	 * @param model      model supplying the hasher and the weights
	 * @param namespaces features extracted from the request
	 * @param classIndex 1-based class index passed to the hasher
	 * @return the accumulated score
	 */
	private float scoreClass(VwModel model, List<Namespace> namespaces, int classIndex) {
		float score = 0;
		for (Namespace n : namespaces) {
			for (Map.Entry<String, Float> e : n.features.entrySet()) {
				Integer index = model.hasher.getFeatureHash(classIndex, n.name, e.getKey());
				Float weight = model.weights.get(index);
				// Features without a weight in the model contribute nothing.
				if (weight != null) {
					score = score + (e.getValue() * weight);
				}
			}
		}
		Integer constantIndex = model.hasher.getConstantHash(classIndex);
		Float constantWeight = model.weights.get(constantIndex);
		if (constantWeight != null) {
			score = score + constantWeight;
		}
		return score;
	}

	/**
	 * Scores the request against the client's model.
	 *
	 * @param client   client whose model should be used
	 * @param jsonNode request features
	 * @param options  options forwarded to the model manager lookup
	 * @return the normalised predictions, or {@code null} when the client has no model
	 */
	@Override
	public PredictionServiceResult predictFromJSON(String client, JsonNode jsonNode, OptionsHolder options) {
		VwModel model = modelManager.getClientStore(client, options);
		if (model == null) {
			logger.warn("No model found for client " + client);
			return null;
		}

		List<Namespace> namespaces = featureExtractor.extract(jsonNode);
		List<PredictionResult> predictions = new ArrayList<PredictionResult>();
		for (int i = 0; i < model.oaa; i++) {
			int classIndex = i + 1;
			float score = scoreClass(model, namespaces, classIndex);

			// Fall back to the numeric index when the model has no label for this class.
			String classId = "" + classIndex;
			if (model.classIdMap.containsKey(classIndex)) {
				classId = model.classIdMap.get(classIndex);
			}
			predictions.add(new PredictionResult((double) score, classId, sigmoid(score)));
		}

		// Arbitrary decision point at 0.0 for binary classification: a negative
		// score on the single class is reported as class -1 (or its mapped label).
		if (model.oaa == 1 && predictions.get(0).prediction < 0) {
			PredictionResult only = predictions.get(0);
			if (model.classIdMap.containsKey(-1)) {
				only.predictedClass = model.classIdMap.get(-1);
			} else {
				only.predictedClass = "-1";
			}
		}
		return new PredictionServiceResult(null, normalise(predictions), null);
	}

	/**
	 * Not implemented.
	 *
	 * @return always {@code null}
	 */
	@Override
	public ClassificationReply predictFromProto(String client, ClassificationRequest request, OptionsHolder options) {
		// TODO Auto-generated method stub
		return null;
	}
}