/**
*
*/
package qa.qcri.aidr.predict;
import static org.junit.Assert.assertEquals;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Properties;
import javax.ejb.EJB;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.StringUtils;
import org.apache.log4j.Logger;
import org.codehaus.jackson.map.JsonMappingException;
import org.codehaus.jackson.map.ObjectMapper;
import org.glassfish.jersey.jackson.JacksonFeature;
import org.json.JSONArray;
import org.json.JSONObject;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import qa.qcri.aidr.common.code.JacksonWrapper;
import qa.qcri.aidr.dbmanager.dto.CollectionDTO;
import qa.qcri.aidr.dbmanager.dto.CrisisTypeDTO;
import qa.qcri.aidr.dbmanager.dto.ModelFamilyDTO;
import qa.qcri.aidr.dbmanager.dto.NominalAttributeDTO;
import qa.qcri.aidr.dbmanager.dto.NominalLabelDTO;
import qa.qcri.aidr.dbmanager.dto.UsersDTO;
import qa.qcri.aidr.dbmanager.ejb.remote.facade.CollectionResourceFacade;
import qa.qcri.aidr.predict.TaggerTesterHelper.LabelCode;
import qa.qcri.aidr.predict.common.TaggerConfigurationProperty;
import qa.qcri.aidr.predict.common.TaggerConfigurator;
import qa.qcri.aidr.predict.util.ResponseWrapper;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPubSub;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
/**
* @author Kushal
*
*/
public class TaggerTesterTest {
private static Logger logger = Logger.getLogger(TaggerTesterTest.class.getName());
public static final String TAGGER_TESTER_CODE = "tagger_tester";
public static final String TAGGER_TESTER_USER = "Tagger Tester User";
public static final String TAGGER_TESTER_CRISIS_NAME = "Tagger Tester Crisis";
public static final String TAGGER_TESTER_CRISIS_CODE = "tagger_tester";
public static final String TAGGER_TESTER_NOMINAL_ATTRIBUTE_NAME = "Tagger Tester Classifier";
public static final String TAGGER_TESTER_NOMINAL_ATTRIBUTE_CODE = "tagger_tester_classifier";
public static final String TAGGER_TESTER_NOMINAL_ATTRIBUTE_DESC = "Tagger Tester Classifier Desc";
public static final String remoteEJBJNDIName = TaggerConfigurator
.getInstance().getProperty(
TaggerConfigurationProperty.REMOTE_TASK_MANAGER_JNDI_NAME);
private static TaggerConfigurator taggerConfig = (TaggerConfigurator) TaggerConfigurator.getInstance();
private Long nominalAttributeId;
private WebTarget webResource;
private String jsonResponse;
private Response response;
private ObjectMapper objectMapper;
private Client client;
private Integer crisisID;
private Long userID;
private Long modelFamilyID;
private int whiteClassifiedCount, blackClassifiedCount;
private Boolean quiet;
private int itemsToTrain;
private int itemsToTest;
private TaggerSubscriber taggerSubscriber;
@EJB
private CollectionResourceFacade crisisResourceFacade;
@Before
public void setUp() {
objectMapper = JacksonWrapper.getObjectMapper();
client = ClientBuilder.newBuilder().register(JacksonFeature.class).build();
itemsToTrain = Integer.parseInt(System.getProperty("nitems-train"));
itemsToTest = Integer.parseInt(System.getProperty("nitems-test"));
quiet = Boolean.parseBoolean(System.getProperty("quiet"));
String config = System.getProperty("config");
if(StringUtils.isNotEmpty(config)){
try (InputStream input = new FileInputStream(config);){
Properties properties = new Properties();
properties.load(input);
for (Object property : properties.keySet()) {
taggerConfig.setProperty(property.toString(), properties.get(property).toString());
}
} catch (IOException e) {
logger.error("Error in reading config properties file: " + config, e);
}
}
taggerSubscriber = new TaggerSubscriber();
}
@Test
public void testTagger() throws JsonMappingException, IOException {
/*
1. Make sure there is no data with code="tagger_tester"
in aidr-pridict database in case the tagger tester died abnormally in a previous run.
If there is data, write a warning message,
run the CLEANUP routine, and FAIL (forcing the user to run the tagger tester again)
*/
// fetch user
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/user/" + TAGGER_TESTER_USER);
response = webResource.request(MediaType.APPLICATION_JSON).get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
UsersDTO usersDTO = objectMapper.readValue(jsonResponse, UsersDTO.class);
if(usersDTO != null) {
userID = usersDTO.getUserID();
}
// fetch crisis
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API) + "/crisis/code/"
+ TAGGER_TESTER_CODE);
response = webResource.request(MediaType.APPLICATION_JSON).get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
HashMap<String,Object> result =
objectMapper.readValue(jsonResponse, HashMap.class);
if(result != null && result.get("crisisId") != null && result.get("crisisId") != new Integer(0)) {
crisisID = (Integer) result.get("crisisId");
}
// get attribute id
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/attribute/code/" + TAGGER_TESTER_NOMINAL_ATTRIBUTE_CODE);
response = webResource.request(MediaType.APPLICATION_JSON).get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
JsonParser jsonParser = new JsonParser();
JsonObject jsonObject = (JsonObject) jsonParser.parse(jsonResponse);
if( jsonObject != null && jsonObject.get("nominalAttributeID") != null ) {
nominalAttributeId = jsonObject.get("nominalAttributeID").getAsLong();
}
// fetch model family
if(crisisID != null && crisisID != 0) {
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/modelfamily/crisis/" + crisisID);
response = webResource.request(MediaType.APPLICATION_JSON).get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
ResponseWrapper responseWrapper = objectMapper.readValue(jsonResponse, ResponseWrapper.class);
if(responseWrapper != null && responseWrapper.getModelFamilies() != null && responseWrapper.getModelFamilies().length > 0) {
modelFamilyID = responseWrapper.getModelFamilies()[0].getModelFamilyId();
}
}
// If any data was found, run the cleanup code
if( userID != null || crisisID != null && crisisID != 0 || nominalAttributeId != null && nominalAttributeId != 0 || modelFamilyID != null ) {
Assert.fail("Tester data found in database. Clean up required");
}
//2. Create a test user Tagger Tester User
UsersDTO user = new UsersDTO();
user.setName(TAGGER_TESTER_USER);
user.setRole("normal");
// create user
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/user");
response = webResource.request(
MediaType.APPLICATION_JSON).post(Entity.json(user), Response.class);
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
// check for user created
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API) + "/user/"
+ TAGGER_TESTER_USER);
response = webResource.request(MediaType.APPLICATION_JSON).get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
UsersDTO userDTO = objectMapper.readValue(jsonResponse, UsersDTO.class);
if(userDTO == null || userDTO.getUserID() == null) {
Assert.fail("User not created with name : " + TAGGER_TESTER_USER);
}
userID = userDTO.getUserID();
//3. Create a collection (name="Tagger Tester Crisis", code="tagger_tester")
CollectionDTO crisis = new CollectionDTO();
crisis.setCode(TAGGER_TESTER_CRISIS_CODE);
crisis.setName(TAGGER_TESTER_CRISIS_NAME);
crisis.setCrisisTypeDTO(new CrisisTypeDTO(1100L, "Natural Hazard: Geophysical: Earthquake and/or Tsunami"));
crisis.setUsersDTO(userDTO);
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/crisis");
response = webResource.request(
MediaType.APPLICATION_JSON).post(Entity.json(crisis), Response.class);
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
assertEquals("SUCCESS", jsonResponse);
// check for crisis created
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API) + "/crisis/code/"
+ TAGGER_TESTER_CODE);
response = webResource.request(MediaType.APPLICATION_JSON).get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
HashMap<String,Object> map =
objectMapper.readValue(jsonResponse, HashMap.class);
if(map == null || map.get("crisisId") == null || map.get("crisisId") == new Integer(0)) {
Assert.fail("Crisis not created with code : " + TAGGER_TESTER_CRISIS_CODE);
}
crisisID = (Integer) map.get("crisisId");
//4. Create a classifier using the following steps:
//a. Create an attribute (name="tagger_tester_classifier")
NominalAttributeDTO attributeDTO = new NominalAttributeDTO();
attributeDTO.setUsersDTO(userDTO);
attributeDTO.setName(TAGGER_TESTER_NOMINAL_ATTRIBUTE_NAME);
attributeDTO.setCode(TAGGER_TESTER_NOMINAL_ATTRIBUTE_CODE);
attributeDTO.setDescription(TAGGER_TESTER_NOMINAL_ATTRIBUTE_DESC);
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/attribute");
response = webResource.request(
MediaType.APPLICATION_JSON).post(Entity.json(attributeDTO), Response.class);
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
attributeDTO = objectMapper.readValue(jsonResponse, NominalAttributeDTO.class);
if(attributeDTO == null || attributeDTO.getNominalAttributeId() == null) {
Assert.fail("NominalAttribute not created with code : " + TAGGER_TESTER_NOMINAL_ATTRIBUTE_CODE);
} else {
nominalAttributeId = attributeDTO.getNominalAttributeId();
}
//b. Create three labels white, black, null
createNominalLabel(LabelCode.WHITE, attributeDTO);
createNominalLabel(LabelCode.BLACK, attributeDTO);
createNominalLabel(LabelCode.DOES_NOT_APPLY, attributeDTO);
//5. Create a ModelFamily
ModelFamilyDTO modelFamilyDTO = new ModelFamilyDTO();
CollectionDTO crisisDTO = new CollectionDTO();
crisisDTO.setCrisisID(new Long (crisisID.intValue()));
modelFamilyDTO.setCrisisDTO(crisisDTO);
modelFamilyDTO.setNominalAttributeDTO(attributeDTO);
modelFamilyDTO.setIsActive(true);
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/modelfamily");
response = webResource.request(
MediaType.APPLICATION_JSON).post(Entity.json(modelFamilyDTO), Response.class);
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
JsonParser parser = new JsonParser();
JsonObject json = (JsonObject) parser.parse(jsonResponse);
modelFamilyID = json.get("entityID").getAsLong();
if(modelFamilyID == -1) {
Assert.fail("Failed to create ModelFamily with crisis : " + crisisID + " attribute : " + attributeDTO.getNominalAttributeId());
}
// 6. Push training data to redis queue
crisisDTO.setCode(TAGGER_TESTER_CRISIS_CODE);
crisisDTO.setName(TAGGER_TESTER_CRISIS_NAME);
final TaggerTesterHelper helper = new TaggerTesterHelper(new Long(crisisID), userID, attributeDTO.getNominalAttributeId(), modelFamilyID, itemsToTrain, quiet);
helper.startPublishing(true, null); // NULL IS THERE FOR Labelcode as we need to push both white and black docs
// 7. tag training data set : human tagging
helper.tagDocuments();
final Jedis subscriberJedis = DataStore.getJedisConnection();
final String outputChannel = TaggerConfigurator.getInstance().getProperty(TaggerConfigurationProperty.REDIS_OUTPUT_CHANNEL_PREFIX) + "." + TAGGER_TESTER_CRISIS_CODE;
new Thread(new Runnable() {
@Override
public void run() {
try {
logger.info("Subscribing to Redis channel "+ outputChannel);
subscriberJedis.subscribe(taggerSubscriber, outputChannel);
} catch (Exception e) {
logger.error("Subscribing to Redis channel " + outputChannel + " failed.", e);
Assert.fail("Failed to subscribe to channel : " + outputChannel);
}
}
}).start();
//8. push white items - unlabeled
helper.setNItems(itemsToTest);
helper.startPublishing(false, LabelCode.WHITE);
if(whiteClassifiedCount < (int)(itemsToTest*(80/100.0f))) {
Assert.fail("Failed to tagged documents with label : white : " + whiteClassifiedCount);
}
// push black items -unlabeled
helper.setNItems(itemsToTest);
helper.startPublishing(false, LabelCode.BLACK);
if(blackClassifiedCount < (int)(itemsToTest*(80/100.0f))) {
Assert.fail("Failed to tagged documents with label : black : " + blackClassifiedCount);
}
}
@After
public void tearDown() {
// delete complete data related to modelFamily, model, modelNominalLabel
if(modelFamilyID != null) {
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/modelfamily/" + modelFamilyID);
response = webResource.request(MediaType.APPLICATION_JSON).delete();
}
// delete nominal attribute data : nominalAttribute, nominalLabel and documentNominalLabel
if(nominalAttributeId != null) {
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/attribute/" + nominalAttributeId);
response = webResource.request(MediaType.APPLICATION_JSON).delete();
}
if(crisisID != null && userID != null) {
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/document/delete/" + crisisID + "/" + userID);
response = webResource.request(MediaType.APPLICATION_JSON).delete();
}
if(crisisID != null) {
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/crisis/" + crisisID);
response = webResource.request(MediaType.APPLICATION_JSON).delete();
}
if(userID != null) {
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/user/" + userID);
response = webResource.request(MediaType.APPLICATION_JSON).delete();
}
}
private void createNominalLabel(LabelCode labelCode, NominalAttributeDTO attributeDTO) {
Long nominalLabelID = 0L;
try {
NominalLabelDTO nominalLabelDTO = new NominalLabelDTO();
nominalLabelDTO.setNominalAttributeDTO(attributeDTO);
nominalLabelDTO.setName(labelCode.getName());
nominalLabelDTO.setNominalLabelCode(labelCode.getCode());
nominalLabelDTO.setDescription(labelCode.getName());
nominalLabelDTO.setSequence(101);
webResource = client.target(taggerConfig.getProperty(TaggerConfigurationProperty.TAGGER_API)+"/label");
response = webResource.request(
MediaType.APPLICATION_JSON).post(Entity.json(nominalLabelDTO), Response.class);
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
nominalLabelDTO = objectMapper.readValue(jsonResponse, NominalLabelDTO.class);
if(nominalLabelDTO == null || nominalLabelDTO.getNominalLabelId() == null) {
Assert.fail("NominalLabel not created with code : " + labelCode.getCode());
} else {
nominalLabelID = nominalLabelDTO.getNominalLabelId();
}
} catch (Exception e) {
// TODO Auto-generated catch block
logger.error(e.getMessage());
Assert.fail("Failed to create NominalLabel :" + e.getMessage());
}
}
class TaggerSubscriber extends JedisPubSub {
@Override
public void onMessage(String channel, String message) {
// TODO Auto-generated method stub
JSONObject jsonObject = new JSONObject(message);
String tweetid = jsonObject.getString("tweetid");
JSONObject aidrJson = jsonObject.getJSONObject("aidr");
JSONArray jsonArray = aidrJson.getJSONArray("nominal_labels");
JSONObject jsonObject2 = jsonArray.getJSONObject(0);
String code = jsonObject2.getString("label_code");
if(!quiet) {
System.out.println("Received tweet : "+ tweetid + " with label : " + code);
}
if(LabelCode.WHITE.getCode().equals(code)) {
whiteClassifiedCount++;
} else if(LabelCode.BLACK.getCode().equals(code)) {
blackClassifiedCount++;
} else {
logger.equals("Irrelevant lable code for tester : " + code);
}
}
@Override
public void onPMessage(String pattern, String channel, String message) {
// TODO Auto-generated method stub
}
@Override
public void onSubscribe(String channel, int subscribedChannels) {
// TODO Auto-generated method stub
}
@Override
public void onUnsubscribe(String channel, int subscribedChannels) {
// TODO Auto-generated method stub
}
@Override
public void onPUnsubscribe(String pattern, int subscribedChannels) {
// TODO Auto-generated method stub
}
@Override
public void onPSubscribe(String pattern, int subscribedChannels) {
// TODO Auto-generated method stub
}
}
}