/*
* Copyright (c) 2013 Villu Ruusmann
*
* This file is part of Openscoring
*
* Openscoring is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Openscoring is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Openscoring. If not, see <http://www.gnu.org/licenses/>.
*/
package org.openscoring.service;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.security.PermitAll;
import javax.annotation.security.RolesAllowed;
import javax.inject.Inject;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.InternalServerErrorException;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.StreamingOutput;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriInfo;
import javax.xml.bind.JAXBException;
import com.codahale.metrics.Counter;
import com.codahale.metrics.Metric;
import com.codahale.metrics.MetricFilter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import org.dmg.pmml.FieldName;
import org.glassfish.jersey.media.multipart.FormDataParam;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.HasGroupFields;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.openscoring.common.BatchEvaluationRequest;
import org.openscoring.common.BatchEvaluationResponse;
import org.openscoring.common.BatchModelResponse;
import org.openscoring.common.EvaluationRequest;
import org.openscoring.common.EvaluationResponse;
import org.openscoring.common.ModelResponse;
import org.openscoring.common.SimpleResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.supercsv.prefs.CsvPreference;
@Path("model")
@PermitAll
public class ModelResource {
// Request URI details; annotated for field injection via @Context, hence not final
@Context
private UriInfo uriInfo;

// Registry that is consulted for loading, storing and looking up models
private final ModelRegistry modelRegistry;

// Registry that holds the per-model timers and counters
private final MetricRegistry metricRegistry;

/**
 * Creates a resource that operates on the given registries.
 *
 * @param modelRegistry The registry of models.
 * @param metricRegistry The registry of metrics.
 */
@Inject
public ModelResource(ModelRegistry modelRegistry, MetricRegistry metricRegistry){
	this.modelRegistry = modelRegistry;
	this.metricRegistry = metricRegistry;
}
/**
 * Lists all models of the model registry in their non-expanded form.
 *
 * @return A batch response whose elements are ordered by model identifier, ignoring case.
 */
@GET
@Produces(MediaType.APPLICATION_JSON)
public BatchModelResponse queryBatch(){
	BatchModelResponse result = new BatchModelResponse();

	List<ModelResponse> modelResponses = new ArrayList<>();

	for(Map.Entry<String, Model> registryEntry : this.modelRegistry.entries()){
		String id = registryEntry.getKey();
		Model model = registryEntry.getValue();

		modelResponses.add(createModelResponse(id, model, false));
	}

	// Case-insensitive ordering by model identifier
	Collections.sort(modelResponses, new Comparator<ModelResponse>(){

		@Override
		public int compare(ModelResponse first, ModelResponse second){
			String firstId = first.getId();
			String secondId = second.getId();

			return firstId.compareToIgnoreCase(secondId);
		}
	});

	result.setResponses(modelResponses);

	return result;
}
/**
 * Describes a single model in its expanded form.
 *
 * @param id The model identifier.
 *
 * @return The description of the model.
 *
 * @throws NotFoundException If the model registry does not hold a model for the identifier.
 */
@GET
@Path("{id:" + ModelRegistry.ID_REGEX + "}")
@Produces(MediaType.APPLICATION_JSON)
public ModelResponse query(@PathParam("id") String id){
	Model model = this.modelRegistry.get(id);

	if(model != null){
		return createModelResponse(id, model, true);
	}

	throw new NotFoundException("Model \"" + id + "\" not found");
}
/**
 * Deploys a PMML document that is sent as the request body. Restricted to the "admin" role.
 *
 * @param id The model identifier.
 * @param is The PMML document.
 *
 * @return The response as produced by {@link #doDeploy(String, InputStream)}.
 */
@PUT
@Path("{id:" + ModelRegistry.ID_REGEX + "}")
@RolesAllowed("admin")
@Consumes({MediaType.APPLICATION_XML, MediaType.TEXT_XML})
@Produces(MediaType.APPLICATION_JSON)
public Response deploy(@PathParam("id") String id, InputStream is){
	Response response = doDeploy(id, is);

	return response;
}
/**
 * Deploys a PMML document that is sent as a multipart form. Restricted to the "admin" role.
 *
 * @param id The value of the "id" form field. Unlike path parameters, it is not constrained by a path regex, so it is validated explicitly.
 * @param is The value of the "pmml" form field.
 *
 * @return The response as produced by {@link #doDeploy(String, InputStream)}.
 *
 * @throws BadRequestException If the identifier is rejected by {@link ModelRegistry#validateId(String)}.
 */
@POST
@RolesAllowed("admin")
@Consumes(MediaType.MULTIPART_FORM_DATA)
@Produces(MediaType.APPLICATION_JSON)
public Response deployForm(@FormDataParam("id") String id, @FormDataParam("pmml") InputStream is){
	boolean valid = ModelRegistry.validateId(id);

	if(valid){
		return doDeploy(id, is);
	}

	throw new BadRequestException("Invalid identifier");
}
/**
 * Loads a PMML document and registers the resulting model under the given identifier.
 *
 * @param id The model identifier.
 * @param is The PMML document.
 *
 * @return "200 OK" if an existing model was replaced, "201 Created" (with the location of the model) if a new model was added.
 *
 * @throws BadRequestException If the PMML document cannot be loaded.
 * @throws InternalServerErrorException If the model registry refuses the replacement or addition.
 */
private Response doDeploy(String id, InputStream is){
	Model newModel;

	try {
		newModel = this.modelRegistry.load(is);
	} catch(Exception e){
		logger.error("Failed to load PMML document", e);

		throw new BadRequestException(e);
	}

	Model previousModel = this.modelRegistry.get(id);

	boolean redeploy = (previousModel != null);

	// The registry reports failure if its state has changed since the lookup above
	boolean stored = redeploy ? this.modelRegistry.replace(id, previousModel, newModel) : this.modelRegistry.put(id, newModel);
	if(!stored){
		throw new InternalServerErrorException("Concurrent modification");
	}

	ModelResponse entity = createModelResponse(id, newModel, true);

	if(redeploy){
		return Response.ok()
			.entity(entity)
			.build();
	}

	URI location = this.uriInfo.getBaseUriBuilder()
		.path(ModelResource.class)
		.path(id)
		.build();

	return Response.created(location)
		.entity(entity)
		.build();
}
/**
 * Downloads a model as a PMML document. Restricted to the "admin" role.
 *
 * @param id The model identifier.
 *
 * @return A streaming response that is flagged as an attachment named after the model identifier.
 *
 * @throws NotFoundException If the model registry does not hold a model for the identifier.
 */
@GET
@Path("{id:" + ModelRegistry.ID_REGEX + "}/pmml")
@RolesAllowed("admin")
@Produces({MediaType.APPLICATION_JSON, MediaType.TEXT_XML})
public Response download(@PathParam("id") String id){
	// NOTE(review): the meaning of the boolean flag is defined by ModelRegistry - confirm there
	final Model model = this.modelRegistry.get(id, true);

	if(model == null){
		throw new NotFoundException("Model \"" + id + "\" not found");
	}

	final ModelRegistry registry = this.modelRegistry;

	StreamingOutput entity = new StreamingOutput(){

		@Override
		public void write(OutputStream os) throws IOException {
			BufferedOutputStream flushOnlyOs = new BufferedOutputStream(os){

				@Override
				public void close() throws IOException {
					// The closing of the underlying java.io.OutputStream is handled elsewhere
					flush();
				}
			};

			try {
				registry.store(model, flushOnlyOs);
			} catch(JAXBException je){
				throw new InternalServerErrorException(je);
			} finally {
				flushOnlyOs.close();
			}
		}
	};

	MediaType contentType = MediaType.APPLICATION_XML_TYPE.withCharset(ModelResource.CHARSET_UTF8.name());
	String contentDisposition = "attachment; filename=" + id + ".pmml.xml"; // XXX

	return Response.ok()
		.entity(entity)
		.type(contentType)
		.header(HttpHeaders.CONTENT_DISPOSITION, contentDisposition)
		.build();
}
/**
 * Evaluates a single request against a model.
 *
 * @param id The model identifier.
 * @param request The evaluation request.
 *
 * @return The evaluation response. Failures are not embedded into it, because the evaluation is performed in all-or-nothing mode.
 */
@POST
@Path("{id:" + ModelRegistry.ID_REGEX + "}")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public EvaluationResponse evaluate(@PathParam("id") String id, EvaluationRequest request){
	List<EvaluationResponse> responses = doEvaluate(id, Collections.singletonList(request), true, "evaluate");

	EvaluationResponse response = responses.get(0);

	return response;
}
/**
 * Evaluates a batch of requests against a model.
 *
 * The evaluation is not performed in all-or-nothing mode: the failure of an individual request
 * is reported via the message of the corresponding response.
 *
 * @param id The model identifier.
 * @param request The batch evaluation request.
 *
 * @return The batch evaluation response, which carries the identifier of the batch evaluation request.
 *
 * @throws BadRequestException If the batch evaluation request is missing.
 */
@POST
@Path("{id: " + ModelRegistry.ID_REGEX + "}/batch")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public BatchEvaluationResponse evaluateBatch(@PathParam("id") String id, BatchEvaluationRequest request){
	// A request without a usable entity (eg. an empty body) may arrive here as null.
	// Report it as a client error, instead of failing with a NullPointerException below.
	if(request == null){
		throw new BadRequestException("Missing batch evaluation request");
	}

	BatchEvaluationResponse batchResponse = new BatchEvaluationResponse(request.getId());

	List<EvaluationRequest> requests = request.getRequests();

	List<EvaluationResponse> responses = doEvaluate(id, requests, false, "evaluate.batch");

	batchResponse.setResponses(responses);

	return batchResponse;
}
/**
 * Evaluates a CSV document that is sent as the request body.
 *
 * The character encoding is taken from the charset parameter of the Content-Type header,
 * falling back to UTF-8 if the header or the parameter is absent.
 *
 * @param id The model identifier.
 * @param delimiterChar The CSV delimiter character. If <code>null</code>, then the CSV format is detected from the document.
 * @param quoteChar The CSV quote character.
 * @param contentType The value of the Content-Type header.
 * @param is The CSV document.
 *
 * @return The response as produced by {@link #doEvaluateCsv(String, String, String, Charset, InputStream)}.
 *
 * @throws BadRequestException If the Content-Type header cannot be parsed, or if it specifies an illegal, unsupported or ambiguous charset.
 */
@POST
@Path("{id:" + ModelRegistry.ID_REGEX + "}/csv")
@Consumes(MediaType.TEXT_PLAIN)
@Produces({MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN})
public Response evaluateCsv(@PathParam("id") String id, @QueryParam("delimiterChar") String delimiterChar, @QueryParam("quoteChar") String quoteChar, @HeaderParam(HttpHeaders.CONTENT_TYPE) String contentType, InputStream is){
	Charset charset = ModelResource.CHARSET_UTF8;

	if(contentType != null){

		try {
			com.google.common.net.MediaType mediaType = com.google.common.net.MediaType.parse(contentType);

			charset = (mediaType.charset()).or(ModelResource.CHARSET_UTF8);
		} catch(IllegalArgumentException | IllegalStateException e){
			// Guava signals malformed media types and illegal/unsupported charsets via IllegalArgumentException,
			// and multiple charset values via IllegalStateException. These are client errors, not server errors.
			logger.error("Failed to parse Content-Type header", e);

			throw new BadRequestException(e);
		}
	}

	return doEvaluateCsv(id, delimiterChar, quoteChar, charset, is);
}
/**
 * Evaluates a CSV document that is sent as a multipart form. The document is always decoded as UTF-8.
 *
 * @param id The model identifier.
 * @param delimiterChar The CSV delimiter character. If <code>null</code>, then the CSV format is detected from the document.
 * @param quoteChar The CSV quote character.
 * @param is The value of the "csv" form field.
 *
 * @return The response as produced by {@link #doEvaluateCsv(String, String, String, Charset, InputStream)}.
 */
@POST
@Path("{id:" + ModelRegistry.ID_REGEX + "}/csv")
@Consumes(MediaType.MULTIPART_FORM_DATA)
@Produces({MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN})
public Response evaluateCsvForm(@PathParam("id") String id, @QueryParam("delimiterChar") String delimiterChar, @QueryParam("quoteChar") String quoteChar, @FormDataParam("csv") InputStream is){
	return doEvaluateCsv(id, delimiterChar, quoteChar, ModelResource.CHARSET_UTF8, is);
}
/**
 * Reads evaluation requests from a CSV document, evaluates them in all-or-nothing mode, and streams the evaluation responses back as a CSV document.
 *
 * @param id The model identifier.
 * @param delimiterChar The CSV delimiter character. If <code>null</code>, then the CSV format is detected from the document.
 * @param quoteChar The CSV quote character. Only consulted if the delimiter character is specified.
 * @param charset The character encoding of both the request and the response document.
 * @param is The CSV document.
 *
 * @return A streaming response that is flagged as an attachment named after the model identifier.
 *
 * @throws BadRequestException If the CSV document cannot be read.
 */
private Response doEvaluateCsv(String id, String delimiterChar, String quoteChar, final Charset charset, InputStream is){
	final CsvPreference csvFormat;

	final CsvUtil.Table<EvaluationRequest> inputTable;

	try {
		BufferedReader csvReader = new BufferedReader(new InputStreamReader(is, charset)){

			@Override
			public void close(){
				// The closing of the underlying java.io.InputStream is handled elsewhere
			}
		};

		try {
			csvFormat = (delimiterChar != null) ? CsvUtil.getFormat(delimiterChar, quoteChar) : CsvUtil.getFormat(csvReader);

			inputTable = CsvUtil.readTable(csvReader, csvFormat);
		} finally {
			csvReader.close();
		}
	} catch(Exception e){
		logger.error("Failed to load CSV document", e);

		throw new BadRequestException(e);
	}

	List<EvaluationResponse> responses = doEvaluate(id, inputTable.getRows(), true, "evaluate.csv");

	// The response table inherits the identifier of the request table
	final CsvUtil.Table<EvaluationResponse> outputTable = new CsvUtil.Table<>();
	outputTable.setId(inputTable.getId());
	outputTable.setRows(responses);

	StreamingOutput entity = new StreamingOutput(){

		@Override
		public void write(OutputStream os) throws IOException {
			BufferedWriter csvWriter = new BufferedWriter(new OutputStreamWriter(os, charset)){

				@Override
				public void close() throws IOException {
					// The closing of the underlying java.io.OutputStream is handled elsewhere
					flush();
				}
			};

			try {
				CsvUtil.writeTable(csvWriter, csvFormat, outputTable);
			} finally {
				csvWriter.close();
			}
		}
	};

	MediaType contentType = MediaType.TEXT_PLAIN_TYPE.withCharset(charset.name());
	String contentDisposition = "attachment; filename=" + id + ".csv"; // XXX

	return Response.ok()
		.entity(entity)
		.type(contentType)
		.header(HttpHeaders.CONTENT_DISPOSITION, contentDisposition)
		.build();
}
/**
 * Evaluates a list of requests against a model, and updates the timer and the record counter of the model.
 *
 * If the model evaluator declares exactly one group field, then the requests are first aggregated by the value of that field.
 *
 * The timer is stopped and the counter is incremented only if this method completes without throwing.
 *
 * @param id The model identifier.
 * @param requests The evaluation requests.
 * @param allOrNothing If <code>true</code>, then the first failing request fails the whole method.
 *  If <code>false</code>, then a failing request yields a response that carries the string representation of the exception as its message.
 * @param method The name component of the timer.
 *
 * @return The evaluation responses, in the iteration order of the (possibly aggregated) requests.
 *
 * @throws NotFoundException If the model registry does not hold a model for the identifier.
 * @throws BadRequestException If the evaluation fails.
 */
@SuppressWarnings("resource")
private List<EvaluationResponse> doEvaluate(String id, List<EvaluationRequest> requests, boolean allOrNothing, String method){
	// NOTE(review): the meaning of the boolean flag is defined by ModelRegistry - confirm there
	Model model = this.modelRegistry.get(id, true);

	if(model == null){
		throw new NotFoundException("Model \"" + id + "\" not found");
	}

	List<EvaluationResponse> results = new ArrayList<>();

	Timer.Context stopwatch = this.metricRegistry.timer(createName(id, method)).time();

	try {
		ModelEvaluator<?> evaluator = model.getEvaluator();

		if(evaluator instanceof HasGroupFields){
			List<InputField> groupFields = ((HasGroupFields)evaluator).getGroupFields();

			int groupFieldCount = groupFields.size();

			if(groupFieldCount > 1){
				throw new EvaluationException("Too many group fields");
			}

			if(groupFieldCount == 1){
				FieldName groupName = groupFields.get(0).getName();

				requests = aggregateRequests(groupName, requests);
			}
		}

		for(EvaluationRequest request : requests){
			EvaluationResponse result;

			try {
				result = evaluate(evaluator, request);
			} catch(Exception e){

				// Escalate to the outer exception handler
				if(allOrNothing){
					throw e;
				}

				result = new EvaluationResponse(request.getId());
				result.setMessage(e.toString());
			}

			results.add(result);
		}
	} catch(Exception e){
		logger.error("Failed to evaluate", e);

		throw new BadRequestException(e);
	}

	stopwatch.stop();

	this.metricRegistry.counter(createName(id, "records")).inc(results.size());

	return results;
}
/**
 * Removes a model from the model registry, together with all metrics whose name starts with the name prefix of the model.
 * Restricted to the "admin" role.
 *
 * @param id The model identifier.
 *
 * @return An empty simple response.
 *
 * @throws NotFoundException If the model registry does not hold a model for the identifier.
 * @throws InternalServerErrorException If the model registry refuses the removal.
 */
@DELETE
@Path("{id:" + ModelRegistry.ID_REGEX + "}")
@RolesAllowed("admin")
@Produces(MediaType.APPLICATION_JSON)
public SimpleResponse undeploy(@PathParam("id") String id){
	Model model = this.modelRegistry.get(id);

	if(model == null){
		throw new NotFoundException("Model \"" + id + "\" not found");
	}

	// The registry reports failure if its state has changed since the lookup above
	if(!this.modelRegistry.remove(id, model)){
		throw new InternalServerErrorException("Concurrent modification");
	}

	final String metricPrefix = createNamePrefix(id);

	this.metricRegistry.removeMatching(new MetricFilter(){

		@Override
		public boolean matches(String name, Metric metric){
			return name.startsWith(metricPrefix);
		}
	});

	return new SimpleResponse();
}
/**
 * Builds a metric name from this class and the given name components, by delegating to {@link MetricRegistry#name(Class, String...)}.
 *
 * @param strings The name components.
 *
 * @return The metric name.
 */
protected static String createName(String... strings){
	String name = MetricRegistry.name(ModelResource.class, strings);

	return name;
}
/**
 * Builds a metric name prefix, which is the metric name of the given name components followed by a dot.
 *
 * @param strings The name components.
 *
 * @return The metric name prefix.
 */
protected static String createNamePrefix(String... strings){
	String name = createName(strings);

	return name + ".";
}
/**
 * Merges evaluation requests that share the same value of the group field.
 *
 * @param groupName The name of the group field.
 * @param requests The evaluation requests.
 *
 * @return The original list if every request has a distinct group field value.
 *  Otherwise, a new list that holds one request per distinct group field value, in the order of first occurrence.
 *  In a merged request, every argument is the collection of the values of the member requests, except for the group field argument, which is the single shared value.
 *  Merged requests do not carry over the identifiers of their member requests.
 */
protected static List<EvaluationRequest> aggregateRequests(FieldName groupName, List<EvaluationRequest> requests){
	Map<Object, ListMultimap<String, Object>> groups = new LinkedHashMap<>();

	String groupKey = groupName.getValue();

	for(EvaluationRequest request : requests){
		Map<String, ?> requestArguments = request.getArguments();

		Object groupValue = requestArguments.get(groupKey);
		if(groupValue == null){

			// Distinguish between an explicit null value and a missing value
			if(!requestArguments.containsKey(groupKey)){
				logger.warn("Evaluation request {} does not specify a group field {}", request.getId(), groupKey);
			}
		}

		ListMultimap<String, Object> group = groups.get(groupValue);
		if(group == null){
			group = ArrayListMultimap.create();

			groups.put(groupValue, group);
		}

		for(Map.Entry<String, ?> requestArgument : requestArguments.entrySet()){
			group.put(requestArgument.getKey(), requestArgument.getValue());
		}
	}

	// Only continue with request modification if there is a clear need to do so
	if(groups.size() == requests.size()){
		return requests;
	}

	List<EvaluationRequest> aggregatedRequests = new ArrayList<>();

	for(Map.Entry<Object, ListMultimap<String, Object>> groupEntry : groups.entrySet()){
		Map<String, Object> aggregatedArguments = new LinkedHashMap<String, Object>((groupEntry.getValue()).asMap());

		// The value of the "group by" column is a single Object, not a Collection (ie. java.util.List) of Objects
		aggregatedArguments.put(groupKey, groupEntry.getKey());

		EvaluationRequest aggregatedRequest = new EvaluationRequest();
		aggregatedRequest.setArguments(aggregatedArguments);

		aggregatedRequests.add(aggregatedRequest);
	}

	return aggregatedRequests;
}
/**
 * Evaluates a single request using the given evaluator.
 *
 * Request arguments are looked up by the names of the active fields of the evaluator, and prepared by the corresponding active field.
 * A warning is logged for every active field that the request does not specify.
 *
 * @param evaluator The evaluator.
 * @param request The evaluation request.
 *
 * @return The evaluation response, which carries the identifier of the evaluation request and the decoded evaluation result.
 */
protected static EvaluationResponse evaluate(Evaluator evaluator, EvaluationRequest request){
	logger.info("Received {}", request);

	Map<String, ?> rawArguments = request.getArguments();

	EvaluationResponse response = new EvaluationResponse(request.getId());

	Map<FieldName, FieldValue> preparedArguments = new LinkedHashMap<>();

	for(InputField inputField : evaluator.getActiveFields()){
		FieldName name = inputField.getName();

		String key = name.getValue();

		Object rawValue = rawArguments.get(key);
		if(rawValue == null){

			// Distinguish between an explicit null value and a missing value
			if(!rawArguments.containsKey(key)){
				logger.warn("Evaluation request {} does not specify an active field {}", request.getId(), key);
			}
		}

		preparedArguments.put(name, inputField.prepare(rawValue));
	}

	logger.debug("Evaluation request {} has prepared arguments: {}", request.getId(), preparedArguments);

	Map<FieldName, ?> rawResult = evaluator.evaluate(preparedArguments);

	// Jackson does not support the JSON serialization of <code>null</code> map keys
	Map<FieldName, ?> result = replaceNullKey(rawResult);

	logger.debug("Evaluation response {} has result: {}", response.getId(), result);

	response.setResult(EvaluatorUtil.decode(result));

	logger.info("Returned {}", response);

	return response;
}
/**
 * Re-maps the value of the <code>null</code> key to the {@link #DEFAULT_NAME} key.
 *
 * @param map The map.
 *
 * @return The original map if it does not contain the <code>null</code> key.
 *  Otherwise, a modified copy of the map; the original map is left untouched.
 */
private static <V> Map<FieldName, V> replaceNullKey(Map<FieldName, V> map){

	if(!map.containsKey(null)){
		return map;
	}

	Map<FieldName, V> copy = new LinkedHashMap<>(map);

	V defaultValue = copy.remove(null);
	copy.put(ModelResource.DEFAULT_NAME, defaultValue);

	return copy;
}
/**
 * Describes a model.
 *
 * @param id The model identifier.
 * @param model The model.
 * @param expand If <code>true</code>, then the description includes the schema of the model, in addition to its mining function, summary and properties.
 *
 * @return The description of the model.
 */
private static ModelResponse createModelResponse(String id, Model model, boolean expand){
	ModelResponse result = new ModelResponse(id);

	result.setMiningFunction(model.getMiningFunction());
	result.setSummary(model.getSummary());
	result.setProperties(model.getProperties());

	if(!expand){
		return result;
	}

	result.setSchema(model.getSchema());

	return result;
}
// The substitute for the null key in evaluation results (see replaceNullKey)
public static final FieldName DEFAULT_NAME = FieldName.create("_default");
// The fallback character encoding for CSV requests, and the declared character encoding of PMML downloads
private static final Charset CHARSET_UTF8 = Charset.forName("UTF-8");
// The logger that is shared by all instances of this class
private static final Logger logger = LoggerFactory.getLogger(ModelResource.class);
}