package qa.qcri.aidr.predict;
import static org.junit.Assert.assertEquals;
import java.io.IOException;
import java.util.Random;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.log4j.Logger;
import org.codehaus.jackson.JsonParseException;
import org.codehaus.jackson.map.JsonMappingException;
import org.codehaus.jackson.map.ObjectMapper;
import org.glassfish.jersey.jackson.JacksonFeature;
import org.json.JSONArray;
import org.json.JSONObject;
import org.junit.Assert;
import qa.qcri.aidr.common.code.JacksonWrapper;
import qa.qcri.aidr.predict.common.TaggerConfigurationProperty;
import qa.qcri.aidr.predict.common.TaggerConfigurator;
import qa.qcri.aidr.predict.util.ResponseWrapper;
import redis.clients.jedis.Jedis;
public class TaggerTesterHelper {
private static final Logger logger = Logger
.getLogger(TaggerTesterHelper.class);
private static final Long SLEEP_DURATION_BETWEEN_TWEETS_PUBLISH_IN_MILLIS = 200L;
private static final int TWEET_TEXT_WORD_COUNT = 30;
private static final Long MODEL_CREATION_WAIT_TIME_IN_MILLIS = 10000L;
private static final int TAGGED_TWEET_COUNT_CHECKPOINT_FOR_MODEL_CREATION = 200;
private static final int ADDITIONAL_TWEET_COUNT_TO_PUBLISH = 50;
private int nItems;
private Long crisisID;
private long userID;
private long attributeID;
private long modelFamilyID;
private WebTarget webResource;
private String jsonResponse;
private Response response;
private ObjectMapper objectMapper;
private Client client;
private boolean quiet;
public TaggerTesterHelper(Long crisisID, Long userID, Long attributeID,
Long modelFamilyID, int nItems, boolean quiet) {
this.crisisID = crisisID;
this.userID = userID;
this.attributeID = attributeID;
this.modelFamilyID = modelFamilyID;
this.nItems = nItems;
objectMapper = JacksonWrapper.getObjectMapper();
client = ClientBuilder.newBuilder().register(JacksonFeature.class)
.build();
this.quiet = quiet;
}
public void setNItems(int nItems) {
this.nItems = nItems;
}
// publish to redis queue
public void startPublishing(boolean training, LabelCode labelCode)
throws JsonParseException, JsonMappingException, IOException {
int tempItemCount;
Jedis redis = DataStore.getJedisConnection();
while (true) {
tempItemCount = nItems;
while (tempItemCount > 0) {
String tweet = generateTweet(training, labelCode, tempItemCount);
redis.publish(
TaggerConfigurator
.getInstance()
.getProperty(
TaggerConfigurationProperty.REDIS_INPUT_CHANNEL),
tweet);
tempItemCount--;
try {
Thread.sleep(SLEEP_DURATION_BETWEEN_TWEETS_PUBLISH_IN_MILLIS);
} catch (InterruptedException e) {
logger.error("Thread sleep interrupted"
+ Thread.currentThread().getName());
}
if (training && tempItemCount == 0) {
webResource = client.target(TaggerConfigurator
.getInstance().getProperty(
TaggerConfigurationProperty.TAGGER_API)
+ "/document/unlabeled/count/" + crisisID);
response = webResource.request(MediaType.APPLICATION_JSON)
.get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
Integer count = objectMapper.readValue(jsonResponse,
Integer.class);
System.out.println("Unlabeled count:" + count);
Assert.assertFalse("Unable to insert documents." , count == 0);
if (count < nItems) {
tempItemCount = ADDITIONAL_TWEET_COUNT_TO_PUBLISH;
} else {
nItems = count;
}
}
}
DataStore.close(redis);
return;
}
}
public void tagDocuments() throws JsonParseException, JsonMappingException,
IOException {
int tagCount = nItems;
int checkPoint = TAGGED_TWEET_COUNT_CHECKPOINT_FOR_MODEL_CREATION;
while (true) {
while (tagCount > 0) {
webResource = client.target(TaggerConfigurator.getInstance()
.getProperty(
TaggerConfigurationProperty.TRAINER_API_ROOT)
+ "/document/getassignabletask/"
+ TaggerTesterTest.TAGGER_TESTER_USER
+ "/"
+ crisisID
+ "/1");
response = webResource.request(MediaType.APPLICATION_JSON)
.get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
JSONArray jsonArray = new JSONArray(jsonResponse);
if (jsonArray.length() == 0) {
Assert.fail("No document to tag for crisis : " + crisisID);
}
JSONObject jsonObject = jsonArray.getJSONObject(0);
Integer documentID = (Integer) jsonObject.get("documentID");
String data = jsonObject.getString("data");
String tweetid = new JSONObject(data).getString("tweetid");
org.json.JSONObject infoJson = new org.json.JSONObject();
infoJson.put("crisisID", crisisID);
infoJson.put("documentID", documentID);
infoJson.put("aidrID", userID);
infoJson.put("attributeID", attributeID);
infoJson.put("category",
tweetid.substring(0, tweetid.indexOf("-")));
jsonObject.put("info", infoJson);
jsonArray.put(jsonObject);
webResource = client.target(TaggerConfigurator.getInstance()
.getProperty(
TaggerConfigurationProperty.TRAINER_API_ROOT)
+ "/taskanswer/save");
logger.info("saveTaskAnswer - postData : "
+ jsonArray.toString());
response = webResource
.request(MediaType.APPLICATION_JSON)
.post(Entity.json(jsonArray.toString()), Response.class);
assertEquals(204, response.getStatus());
if (!quiet) {
System.out.println("Labelled tweet : " + tweetid
+ " with label : "
+ tweetid.substring(0, tweetid.indexOf("-")));
}
tagCount--;
if (tagCount == checkPoint) {
webResource = client.target(TaggerConfigurator
.getInstance().getProperty(
TaggerConfigurationProperty.TAGGER_API)
+ "/model/modelFamily/" + modelFamilyID);
response = webResource.request(MediaType.APPLICATION_JSON)
.get();
assertEquals(200, response.getStatus());
jsonResponse = response.readEntity(String.class);
ResponseWrapper responseWrapper = objectMapper.readValue(
jsonResponse, ResponseWrapper.class);
if (responseWrapper == null
|| responseWrapper.getModelHistoryWrapper() == null
|| responseWrapper.getModelHistoryWrapper().length == 0) {
try {
Thread.sleep(MODEL_CREATION_WAIT_TIME_IN_MILLIS);
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
checkPoint -= 50;
} else {
return;
}
}
}
return;
}
}
private String generateTweet(boolean isTrainingTweet, LabelCode labelCode,
int tweetIndex) {
Random random = new Random();
int tweetTextChoice = random.nextInt(2);
String tweetWordSelected = "";
StringBuffer stringBuffer = new StringBuffer();
if (labelCode == null) {
labelCode = LabelCode.values()[tweetTextChoice];
}
if (!quiet && isTrainingTweet) {
System.out.println("Sent tweet : " + labelCode.getCode() + "-"
+ tweetIndex);
}
for (int i = 0; i < TWEET_TEXT_WORD_COUNT; i++) {
int wordChoice = random.nextInt(labelCode.getTweetWords().length);
tweetWordSelected = labelCode.getTweetWords()[wordChoice];
if ("w".equals(tweetWordSelected)) {
tweetWordSelected = "w"
+ String.format("%02d", random.nextInt(100));
}
stringBuffer.append(tweetWordSelected + " ");
}
// preapare tweet
JSONObject tweetObject = new JSONObject();
JSONObject user = new JSONObject();
user.put("id", userID);
JSONObject aidr = new JSONObject();
aidr.put("crisis_code", TaggerTesterTest.TAGGER_TESTER_CRISIS_CODE);
aidr.put("doctype", "twitter");
aidr.put("crisis_name", TaggerTesterTest.TAGGER_TESTER_CRISIS_NAME);
tweetObject.put("user", user);
tweetObject.put("tweetid", labelCode.getCode() + "-" + tweetIndex);
tweetObject.put("text", stringBuffer.toString());
tweetObject.put("aidr", aidr);
return tweetObject.toString();
}
enum LabelCode {
BLACK("black", "Black", new String[] { "neutral", "night", "coal",
"ink", "coffee", "w" }), WHITE("white", "White", new String[] {
"clouds", "snow", "clear", "light", "neutral", "w" }), DOES_NOT_APPLY(
"null", "Does Not Apply", new String[] { "none" });
private String name;
private String code;
private String[] tweetWords;
private LabelCode(String code, String name, String[] tweetWords) {
this.name = name;
this.code = code;
this.tweetWords = tweetWords;
}
public String[] getTweetWords() {
return tweetWords;
}
public String getName() {
return name;
}
public String getCode() {
return code;
}
}
}