/*
* Copyright (c) 2014 Villu Ruusmann
*
* This file is part of Openscoring
*
* Openscoring is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Openscoring is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Openscoring. If not, see <http://www.gnu.org/licenses/>.
*/
package org.openscoring.service;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import javax.ws.rs.client.Entity;
import javax.ws.rs.core.Application;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;

import com.google.common.collect.Maps;
import org.dmg.pmml.FieldName;
import org.glassfish.jersey.client.ClientConfig;
import org.glassfish.jersey.media.multipart.FormDataBodyPart;
import org.glassfish.jersey.media.multipart.FormDataMultiPart;
import org.glassfish.jersey.media.multipart.MultiPartFeature;
import org.glassfish.jersey.test.JerseyTest;
import org.junit.Test;
import org.openscoring.common.BatchEvaluationRequest;
import org.openscoring.common.BatchEvaluationResponse;
import org.openscoring.common.EvaluationRequest;
import org.openscoring.common.EvaluationResponse;
import org.openscoring.common.ModelResponse;
import org.openscoring.common.SimpleResponse;
import org.supercsv.prefs.CsvPreference;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
 * Integration tests for the {@code model} REST resource, run against an in-process {@link Openscoring} application.
 *
 * <p>Test fixtures are loaded from the classpath: PMML documents from {@code /pmml/<id>.pmml} and
 * tab-separated input records from {@code /csv/<suffix>.csv}, where {@code <suffix>} is derived from the
 * model id by {@link #extractSuffix(String)}.
 */
public class ModelResourceTest extends JerseyTest {

	@Override
	protected Application configure(){
		Openscoring openscoring = new Openscoring();

		return openscoring;
	}

	@Override
	protected void configureClient(ClientConfig clientConfig){
		// Required by the multipart/form-data requests issued in deployForm(..) and evaluateCsvForm(..)
		clientConfig.register(MultiPartFeature.class);

		// Ideally, should use the client-side ObjectMapperProvider class instead of the server-side one
		clientConfig.register(ObjectMapperProvider.class);
	}

	/**
	 * Deploys, downloads, evaluates (single and batch) and undeploys the decision tree model.
	 * The batch contains one deliberately invalidated record, whose response must carry an error message.
	 */
	@Test
	public void decisionTreeIris() throws Exception {
		String id = "DecisionTreeIris";

		assertEquals("Iris", extractSuffix(id));

		deploy(id);

		download(id);

		List<EvaluationRequest> records = loadRecords(id);

		EvaluationRequest request = records.get(0);

		EvaluationResponse response = evaluate(id, request);

		// The single evaluation response should be correlated with its request
		assertEquals(request.getId(), response.getId());

		// The middle record has its argument values replaced with non-sensical strings (see invalidate(..))
		List<EvaluationRequest> requests = Arrays.asList(records.get(0), invalidate(records.get(50)), records.get(100));

		BatchEvaluationRequest batchRequest = new BatchEvaluationRequest();
		batchRequest.setRequests(requests);

		BatchEvaluationResponse batchResponse = evaluateBatch(id, batchRequest);

		assertEquals(batchRequest.getId(), batchResponse.getId());

		List<EvaluationResponse> responses = batchResponse.getResponses();

		assertEquals(requests.size(), responses.size());

		EvaluationRequest invalidRequest = requests.get(1);
		EvaluationResponse invalidResponse = responses.get(1);

		assertEquals(invalidRequest.getId(), invalidResponse.getId());
		assertNotNull(invalidResponse.getMessage());

		undeploy(id);
	}

	/**
	 * Deploys the association rules model via a form, evaluates it as a plain batch, as an aggregated batch
	 * and via both CSV endpoints, and then undeploys it via a method-override POST.
	 */
	@Test
	public void associationRulesShopping() throws Exception {
		String id = "AssociationRulesShopping";

		assertEquals("Shopping", extractSuffix(id));

		deployForm(id);

		List<EvaluationRequest> records = loadRecords(id);

		// First pass: one evaluation request per CSV row
		BatchEvaluationRequest batchRequest = new BatchEvaluationRequest();
		batchRequest.setRequests(records);

		BatchEvaluationResponse batchResponse = evaluateBatch(id, batchRequest);

		assertEquals(batchRequest.getId(), batchResponse.getId());

		// Second pass: rows aggregated on the "transaction" field
		// NOTE(review): presumably one request per distinct transaction value; see ModelResource.aggregateRequests
		List<EvaluationRequest> aggregatedRecords = ModelResource.aggregateRequests(FieldName.create("transaction"), records);

		batchRequest = new BatchEvaluationRequest("aggregate");
		batchRequest.setRequests(aggregatedRecords);

		batchResponse = evaluateBatch(id, batchRequest);

		assertEquals(batchRequest.getId(), batchResponse.getId());

		evaluateCsv(id);

		evaluateCsvForm(id);

		undeployForm(id);
	}

	/**
	 * Deploys a model by PUT-ting its raw PMML document to {@code model/<id>}.
	 *
	 * @param id the model id, also used to locate the PMML resource
	 * @return the deserialized deployment response
	 */
	private ModelResponse deploy(String id) throws IOException {
		Response response;

		try(InputStream is = openPMML(id)){
			Entity<InputStream> entity = Entity.entity(is, MediaType.APPLICATION_XML);

			response = target("model/" + id).request(MediaType.APPLICATION_JSON).put(entity);
		}

		assertEquals(201, response.getStatus());

		return response.readEntity(ModelResponse.class);
	}

	/**
	 * Deploys a model by POST-ing a multipart form (fields {@code id} and {@code pmml}) to {@code model}.
	 * Asserts that the {@code Location} header points at the newly created model.
	 *
	 * @param id the model id, also used to locate the PMML resource
	 * @return the deserialized deployment response
	 */
	private ModelResponse deployForm(String id) throws IOException {
		Response response;

		// The multipart is a resource too; closing it via try-with-resources avoids a leak if the request fails
		try(InputStream is = openPMML(id); FormDataMultiPart formData = new FormDataMultiPart()){
			formData.field("id", id);
			formData.bodyPart(new FormDataBodyPart("pmml", is, MediaType.APPLICATION_XML_TYPE));

			Entity<FormDataMultiPart> entity = Entity.entity(formData, MediaType.MULTIPART_FORM_DATA);

			response = target("model").request(MediaType.APPLICATION_JSON).post(entity);
		}

		assertEquals(201, response.getStatus());

		URI location = response.getLocation();

		assertEquals("/model/" + id, location.getPath());

		return response.readEntity(ModelResponse.class);
	}

	/**
	 * Downloads the deployed PMML document and asserts that it is served as UTF-8 encoded XML.
	 *
	 * @param id the model id
	 * @return the raw HTTP response
	 */
	private Response download(String id){
		Response response = target("model/" + id + "/pmml").request(MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML).get();

		assertEquals(200, response.getStatus());
		assertEquals(MediaType.APPLICATION_XML_TYPE.withCharset(CHARSET_UTF_8), response.getMediaType());

		return response;
	}

	/**
	 * Evaluates a single request by POST-ing it as JSON to {@code model/<id>}.
	 *
	 * @return the deserialized evaluation response
	 */
	private EvaluationResponse evaluate(String id, EvaluationRequest request){
		Entity<EvaluationRequest> entity = Entity.json(request);

		Response response = target("model/" + id).request(MediaType.APPLICATION_JSON).post(entity);

		assertEquals(200, response.getStatus());

		return response.readEntity(EvaluationResponse.class);
	}

	/**
	 * Evaluates a batch of requests by POST-ing it as JSON to {@code model/<id>/batch}.
	 *
	 * @return the deserialized batch evaluation response
	 */
	private BatchEvaluationResponse evaluateBatch(String id, BatchEvaluationRequest batchRequest){
		Entity<BatchEvaluationRequest> entity = Entity.json(batchRequest);

		Response response = target("model/" + id + "/batch").request(MediaType.APPLICATION_JSON).post(entity);

		assertEquals(200, response.getStatus());

		return response.readEntity(BatchEvaluationResponse.class);
	}

	/**
	 * POSTs the CSV fixture as an ISO-8859-1 text body to {@code model/<id>/csv}, with explicit
	 * delimiter and quote characters passed as query parameters.
	 * The test expects the response to be returned in the same charset as the request.
	 *
	 * @return the raw HTTP response
	 */
	private Response evaluateCsv(String id) throws IOException {
		Response response;

		try(InputStream is = openCSV(id)){
			Entity<InputStream> entity = Entity.entity(is, MediaType.TEXT_PLAIN_TYPE.withCharset(CHARSET_ISO_8859_1));

			// The query parameter values are the backslash-escaped forms "\t" and "\""
			response = target("model/" + id + "/csv").queryParam("delimiterChar", "\\t").queryParam("quoteChar", "\\\"").request(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN).post(entity);
		}

		assertEquals(200, response.getStatus());
		assertEquals(MediaType.TEXT_PLAIN_TYPE.withCharset(CHARSET_ISO_8859_1), response.getMediaType());

		return response;
	}

	/**
	 * POSTs the CSV fixture as a multipart form (field {@code csv}) to {@code model/<id>/csv}.
	 * The test expects a UTF-8 text response.
	 *
	 * @return the raw HTTP response
	 */
	private Response evaluateCsvForm(String id) throws IOException {
		Response response;

		try(InputStream is = openCSV(id); FormDataMultiPart formData = new FormDataMultiPart()){
			formData.bodyPart(new FormDataBodyPart("csv", is, MediaType.TEXT_PLAIN_TYPE));

			Entity<FormDataMultiPart> entity = Entity.entity(formData, MediaType.MULTIPART_FORM_DATA);

			response = target("model/" + id + "/csv").request(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN).post(entity);
		}

		assertEquals(200, response.getStatus());
		assertEquals(MediaType.TEXT_PLAIN_TYPE.withCharset(CHARSET_UTF_8), response.getMediaType());

		return response;
	}

	/**
	 * Undeploys a model with an HTTP DELETE on {@code model/<id>}.
	 */
	private SimpleResponse undeploy(String id){
		Response response = target("model/" + id).request(MediaType.APPLICATION_JSON).delete();

		assertEquals(200, response.getStatus());

		return response.readEntity(SimpleResponse.class);
	}

	/**
	 * Undeploys a model with a POST that carries an {@code X-HTTP-Method-Override: DELETE} header,
	 * i.e. tunnels the DELETE through a POST request.
	 */
	private SimpleResponse undeployForm(String id){
		Response response = target("model/" + id).request(MediaType.APPLICATION_JSON).header("X-HTTP-Method-Override", "DELETE").post(null);

		assertEquals(200, response.getStatus());

		return response.readEntity(SimpleResponse.class);
	}

	/**
	 * Creates a copy of the record (same id, same argument names) where every argument value is replaced
	 * by its own argument name spelled backwards, e.g. {@code "abc" -> "cba"}.
	 * This yields arguments that the model is not expected to accept.
	 *
	 * <p>Note: {@link Maps#transformEntries} returns a lazy view over the original arguments map.
	 */
	static
	private EvaluationRequest invalidate(EvaluationRequest record){
		Maps.EntryTransformer<String, Object, String> transformer = new Maps.EntryTransformer<String, Object, String>(){

			@Override
			public String transformEntry(String key, Object value){
				StringBuilder sb = new StringBuilder(key);
				sb.reverse();

				return sb.toString();
			}
		};

		EvaluationRequest invalidRecord = new EvaluationRequest(record.getId());
		invalidRecord.setArguments(Maps.transformEntries(record.getArguments(), transformer));

		return invalidRecord;
	}

	/**
	 * Reads the tab-separated CSV fixture for the given model id as UTF-8 text and converts its rows
	 * into evaluation requests.
	 */
	static
	private List<EvaluationRequest> loadRecords(String id) throws Exception {
		// Closing the reader also closes the underlying classpath resource stream
		try(BufferedReader reader = new BufferedReader(new InputStreamReader(openCSV(id), StandardCharsets.UTF_8))){
			CsvUtil.Table<EvaluationRequest> table = CsvUtil.readTable(reader, CsvPreference.TAB_PREFERENCE);

			return table.getRows();
		}
	}

	/**
	 * Opens the classpath resource {@code /pmml/<id>.pmml}.
	 *
	 * @throws NullPointerException if the resource does not exist
	 */
	static
	private InputStream openPMML(String id){
		String path = "/pmml/" + id + ".pmml";

		return Objects.requireNonNull(ModelResourceTest.class.getResourceAsStream(path), "Missing test resource " + path);
	}

	/**
	 * Opens the classpath resource {@code /csv/<suffix>.csv}, where the suffix is {@link #extractSuffix(String)} of the id.
	 *
	 * @throws NullPointerException if the resource does not exist
	 */
	static
	private InputStream openCSV(String id){
		String path = "/csv/" + extractSuffix(id) + ".csv";

		return Objects.requireNonNull(ModelResourceTest.class.getResourceAsStream(path), "Missing test resource " + path);
	}

	/**
	 * Returns the substring of the id that starts at its last upper-case character,
	 * e.g. {@code "DecisionTreeIris" -> "Iris"}.
	 *
	 * @throws IllegalArgumentException if the id contains no upper-case character
	 */
	static
	private String extractSuffix(String id){

		for(int i = id.length() - 1; i > -1; i--){
			char c = id.charAt(i);

			if(Character.isUpperCase(c)){
				return id.substring(i);
			}
		}

		throw new IllegalArgumentException("Identifier \"" + id + "\" does not contain an upper-case character");
	}

	private static final String CHARSET_UTF_8 = "UTF-8";

	private static final String CHARSET_ISO_8859_1 = "ISO-8859-1";
}