/*
* Seldon -- open source prediction engine
* =======================================
* Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
*
**********************************************************************************************
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**********************************************************************************************
*/
package io.seldon.api.resource.service.business;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.seldon.api.APIException;
import io.seldon.api.Constants;
import io.seldon.api.logging.EventLogger;
import io.seldon.api.logging.PredictLogger;
import io.seldon.api.resource.ConsumerBean;
import io.seldon.api.resource.ErrorBean;
import io.seldon.api.resource.EventBean;
import io.seldon.api.resource.ResourceBean;
import io.seldon.api.service.ApiLoggerServer;
import io.seldon.prediction.PredictionService;
import io.seldon.prediction.PredictionServiceResult;
@Component
public class PredictionBusinessServiceImpl implements PredictionBusinessService {
private static Logger logger = Logger.getLogger(PredictionBusinessServiceImpl.class.getName());
private static final String JSON_KEY = "json";
private static final String CLIENT_KEY = "client";
private static final String TIMESTAMP_KEY = "timestamp";
public static final String PUID_KEY = "puid";
public static final String REQUEST_CUSTOM_DATA_FIELD = "data";
public static final String REPLY_CUSTOM_DATA_FIELD = "custom";
@Autowired
PredictionService predictionService;
@Autowired
PredictLogger predictLogger;
private boolean allowedKey(String key)
{
return (!(Constants.CONSUMER_KEY.equals(key) ||
Constants.CONSUMER_SECRET.equals(key) ||
Constants.OAUTH_TOKEN.equals(key) ||
CLIENT_KEY.equals(key) ||
PUID_KEY.equals(key) ||
"jsonpCallback".equals(key)));
}
private Long getTimeStamp()
{
return System.currentTimeMillis()/1000;
}
private JsonNode getValidatedJson(ConsumerBean consumer,String jsonRaw,boolean addExtraFeatures) throws JsonParseException, IOException
{
ObjectMapper mapper = new ObjectMapper();
JsonFactory factory = mapper.getFactory();
JsonParser parser = factory.createParser(jsonRaw);
JsonNode actualObj = mapper.readTree(parser);
if (addExtraFeatures) // only for events not at predict time
{
((ObjectNode) actualObj).put(CLIENT_KEY,consumer.getShort_name());
if (actualObj.get(TIMESTAMP_KEY) == null)
{
((ObjectNode) actualObj).put(TIMESTAMP_KEY,getTimeStamp());
}
else
{
JsonNode timeNode = actualObj.get(TIMESTAMP_KEY);
if (!(timeNode.isInt() || timeNode.isLong()))
{
throw new APIException(APIException.INVALID_JSON);
}
}
}
return actualObj;
}
private ResourceBean getValidatedJsonResource(ConsumerBean consumer,String jsonRaw)
{
ResourceBean responseBean;
try
{
JsonNode jsonNode = getValidatedJson(consumer, jsonRaw,true);
String json = jsonNode.toString();
EventLogger.log(json);
responseBean = new EventBean(json);
}
catch (IOException e)
{
ApiLoggerServer.log(this, e);
APIException apiEx = new APIException(APIException.INVALID_JSON);
responseBean = new ErrorBean(apiEx);
}
catch (APIException e)
{
ApiLoggerServer.log(this, e);
responseBean = new ErrorBean(e);
}
return responseBean;
}
@Override
public ResourceBean addEvent(ConsumerBean consumerBean,Map<String, String[]> parameters) {
ResourceBean responseBean;
if (parameters.containsKey(JSON_KEY))
{
String jsonRaw = parameters.get(JSON_KEY)[0];
responseBean = getValidatedJsonResource(consumerBean,jsonRaw);
}
else
{
Map<String,Object> keyVals = new HashMap<String,Object>();
for(Map.Entry<String, String[]> reqMapEntry : parameters.entrySet())
{
if (reqMapEntry.getValue().length == 1 && allowedKey(reqMapEntry.getKey()))
{
keyVals.put(reqMapEntry.getKey(), reqMapEntry.getValue()[0]);
}
}
keyVals.put(CLIENT_KEY, consumerBean.getShort_name());
if (!keyVals.containsKey(TIMESTAMP_KEY))
keyVals.put(TIMESTAMP_KEY, getTimeStamp());
else
{
try
{
Long.parseLong(((String)keyVals.get(TIMESTAMP_KEY)));
}
catch (NumberFormatException e)
{
ApiLoggerServer.log(this, e);
APIException apiEx = new APIException(APIException.INVALID_JSON);
responseBean = new ErrorBean(apiEx);
return responseBean;
}
}
ObjectMapper mapper = new ObjectMapper();
try {
String json = mapper.writeValueAsString(keyVals);
EventLogger.log(json);
responseBean = new EventBean(json);
} catch (IOException e) {
ApiLoggerServer.log(this, e);
APIException apiEx = new APIException(APIException.INVALID_JSON);
responseBean = new ErrorBean(apiEx);
}
}
return responseBean;
}
@Override
public ResourceBean addEvent(ConsumerBean consumerBean, String event) {
return getValidatedJsonResource(consumerBean,event);
}
@Override
public PredictionServiceResult predict(ConsumerBean consumer, String puid, String jsonRaw) {
try
{
logger.info("Json raw "+jsonRaw);
JsonNode jsonNode = getValidatedJson(consumer, jsonRaw, false); // used to check valid json but we don't use result
PredictionServiceResult res = predictionService.predict(consumer.getShort_name(), puid, jsonNode);
predictLogger.log(consumer.getShort_name(), jsonNode, res);
return res;
}
catch (IOException e)
{
ApiLoggerServer.log(this, e);
return new PredictionServiceResult();
}
catch (APIException e)
{
return new PredictionServiceResult();
}
}
@Override
public PredictionServiceResult predict(ConsumerBean consumerBean, Map<String, String[]> parameters) {
String puid = null;
if (parameters.containsKey(PUID_KEY))
puid = parameters.get(PUID_KEY)[0];
if (parameters.containsKey(JSON_KEY))
{
String jsonRaw = parameters.get(JSON_KEY)[0];
return predict(consumerBean, puid, jsonRaw);
}
else
{
Map<String,Object> keyVals = new HashMap<String,Object>();
for(Map.Entry<String, String[]> reqMapEntry : parameters.entrySet())
{
if (reqMapEntry.getValue().length == 1 && allowedKey(reqMapEntry.getKey()))
{
try
{
keyVals.put(reqMapEntry.getKey(), Float.parseFloat(reqMapEntry.getValue()[0]));
}
catch (Exception e)
{
keyVals.put(reqMapEntry.getKey(), reqMapEntry.getValue()[0]);
}
}
}
ObjectMapper mapper = new ObjectMapper();
try {
String jsonRaw = mapper.writeValueAsString(keyVals);
return predict(consumerBean, puid, jsonRaw);
} catch (IOException e)
{
ApiLoggerServer.log(this, e);
/*
APIException apiEx = new APIException(APIException.INVALID_JSON);
ResourceBean responseBean = new ErrorBean(apiEx);
JsonNode response = mapper.valueToTree(responseBean);
return response;
*/
return new PredictionServiceResult();
}
}
}
}