package edu.uncc.cs.watsonsim.scripts;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.sql.SQLException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.util.concurrent.AtomicDouble;
import edu.uncc.cs.watsonsim.Answer;
import edu.uncc.cs.watsonsim.DBQuestionSource;
import edu.uncc.cs.watsonsim.DefaultPipeline;
import edu.uncc.cs.watsonsim.Environment;
import edu.uncc.cs.watsonsim.StringUtils;
import org.apache.http.NameValuePair;
import org.apache.http.client.fluent.Form;
import org.apache.http.client.fluent.Request;
import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.eclipse.jgit.lib.Repository;
import org.eclipse.jgit.storage.file.FileRepositoryBuilder;
/**
 * Interactive entry point for a statistics run.
 * <p>
 * Asks on standard input which slice of the question set to use and for a
 * free-form description of the setup, then hands both to
 * {@link StatsGenerator}.
 *
 * @author Sean Gallagher
 * @author Matt Gibson
 */
public class ParallelStats {
    /** Mode used when the answer is blank or unrecognized, as the prompt advertises. */
    private static final String DEFAULT_MODE = "minitest";

    /**
     * Prompt for a mode and a description, then run the statistics.
     *
     * @param args the command line arguments (unused)
     * @throws Exception if reading standard input fails, or if the run
     *         throws anything other than a {@link SQLException}
     */
    public static void main(String[] args) throws Exception {
        BasicConfigurator.configure();
        Logger.getRootLogger().setLevel(Level.WARN);
        Logger log = Logger.getLogger(ParallelStats.class);

        System.out.print("Train, test, minitrain or minitest [minitest]: ");
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        // readLine() returns null at end of input (closed or redirected
        // stdin); treat that the same as just pressing enter.
        String mode = br.readLine();
        mode = (mode == null) ? "" : mode.trim();

        String sql;
        switch (mode) {
        case "test":
            sql = questionQuery(2000, 0);
            break;
        case "train":
            sql = questionQuery(10000, 2000);
            break;
        case "minitrain":
            sql = questionQuery(1000, 0);
            break;
        default:
            if (!mode.isEmpty() && !mode.equals(DEFAULT_MODE)) {
                log.warn("Unrecognized mode '" + mode + "'; using " + DEFAULT_MODE + ".");
            }
            // Label the run with the mode that is really executed so the
            // uploaded dataset name matches the questions that were used.
            mode = DEFAULT_MODE;
            sql = questionQuery(1000, 2000);
            break;
        }

        System.out.print("Describe the setup: ");
        String description = br.readLine();
        if (description == null) {
            // End of input: avoid uploading the literal text "null".
            description = "";
        }

        try {
            new StatsGenerator(description + ": " + mode, sql).run();
        } catch (SQLException e) {
            // The logger already prints the stack trace; no printStackTrace().
            log.error("Database missing, invalid, or out of date. Check that you "
                    + "have the latest version.", e);
        }
        System.out.println("Done.");
    }

    /**
     * Build the SQL suffix selecting one slice of the permuted questions.
     * Plain concatenation is used so the digits never depend on the default
     * locale.
     */
    private static String questionQuery(int limit, int offset) {
        return "ORDER BY permute LIMIT " + limit + " OFFSET " + offset;
    }
}
/**
 * Runs every question of a question source through the pipeline and
 * collects statistics about the answers.
 * <p>
 * It measures:<p>
 * 1. Overall (top) accuracy<p>
 * 2. Top-3 accuracy<p>
 * 3. Mean Reciprocal Rank (MRR), aka mean inverse rank.
 *    It is only calculated on questions where the correct answer was one
 *    of the candidate answers. Thus, Scorers and the Learner should use
 *    MRR as a guide, looking to approach 1.0. <p>
 * 4. Availability, aka binary recall.
 *    This is more of an issue with the Searchers, which should strive for
 *    high binary recall. Precision matters as well, but it is not
 *    calculated because the intention is for scorers to improve it
 *    instead of filtering candidates out early in Searchers. <p>
 * 5. A histogram of accuracy by confidence. In theory, it should be more
 *    accurate when it is more confident. That has not yet panned out.
 *    (The histograms currently uploaded are all zeros.) <p>
 * 6. Miscellaneous facts like the Git commit, for later reference.<p>
 * <p>
 * While running it prints one marker character per question and a summary
 * line every 50 questions, simply to relieve some of the boredom of
 * watching it calculate.
 * <p>
 * There is only one method to call, {@link #run()}, which is basically just
 * a procedure. Internally there are several private functions to aid
 * organization.
 *
 * @author Phani Rahul
 * @author Sean Gallagher
 */
class StatsGenerator {
    /** Number of rank buckets in the histogram returned by the stream. */
    private static final int RANK_BUCKETS = 100;
    /** Last bucket: shared by "correct answer missing" and ranks beyond it. */
    private static final int MISSING_RANK = RANK_BUCKETS - 1;
    /** Longest commit hash prefix that is uploaded. */
    private static final int COMMIT_HASH_LENGTH = 10;

    private final String dataset;
    private final DBQuestionSource questionsource;
    // All counters are updated from the parallel stream in run(), so they
    // have to be atomic.
    private final AtomicInteger available = new AtomicInteger(0);
    private final AtomicDouble total_inverse_rank = new AtomicDouble(0);
    private final AtomicInteger total_questions = new AtomicInteger(0);
    private final AtomicInteger total_correct = new AtomicInteger(0);
    private final AtomicInteger total_answers = new AtomicInteger(0);
    private double runtime;
    // Set at construction; not read anywhere in this class.
    private long run_start;
    private final Logger log = Logger.getLogger(getClass());

    /**
     * Generate statistics on a specific set of questions
     *
     * To understand the query, see {@link DBQuestionSource}.
     * @param dataset What to name the result when it is posted online.
     * @param question_query The SQL filters for the questions.
     * @throws SQLException if the question source cannot be created
     */
    public StatsGenerator(String dataset, String question_query) throws SQLException {
        this.dataset = dataset;
        questionsource = new DBQuestionSource(new Environment(), question_query);
        this.run_start = System.currentTimeMillis();
    }

    /** Run statistics, then upload to the server */
    public void run() {
        final long start_time = System.nanoTime();
        // BasicConfigurator.configure() adds a console appender on every
        // call; skip it when the root logger already has one so messages
        // are not printed twice.
        if (!Logger.getRootLogger().getAllAppenders().hasMoreElements()) {
            BasicConfigurator.configure();
        }
        Logger.getRootLogger().setLevel(Level.ERROR);

        System.out.println("Performing train/test session\n"
                + " #=top o=top3 .=recall ' '=missing");

        // One pipeline per worker thread, created lazily.
        ConcurrentHashMap<Long, DefaultPipeline> pipes =
                new ConcurrentHashMap<>();

        int[] all_ranks = questionsource.parallelStream().mapToInt(q -> {
            long tid = Thread.currentThread().getId();
            DefaultPipeline pipe = pipes.computeIfAbsent(tid, (i) -> new DefaultPipeline());

            List<Answer> answers;
            try {
                answers = pipe.ask(q, message -> {});
            } catch (Exception e) {
                // A failed question is logged and counted as missing; it is
                // not included in total_questions.
                log.fatal(e, e);
                return MISSING_RANK;
            }

            int tq = total_questions.incrementAndGet();
            if (tq % 50 == 0) {
                // Read each counter once so the line is self-consistent
                // even while other threads keep counting.
                int right = total_correct.get();
                System.out.println(
                        String.format(
                                "[%d]: %d (%.02f%%) accurate",
                                tq,
                                right,
                                right * 100.0 / tq));
            }

            if (answers.isEmpty()) {
                System.out.print('!');
                return MISSING_RANK;
            }

            int correct_rank = recordCorrectRank(answers);
            if (correct_rank == 0) {
                total_correct.incrementAndGet();
            }
            System.out.print(markerFor(correct_rank));
            total_answers.addAndGet(answers.size());
            return correct_rank;
        }).mapToObj(x -> {
            // One-hot histogram for a single question; summed by add().
            int[] xs = new int[RANK_BUCKETS];
            xs[x] = 1;
            return xs;
        }).reduce(new int[RANK_BUCKETS], StatsGenerator::add);

        // Only count the rank of questions whose correct answer was actually
        // among the candidates. The stream has finished, so nothing else is
        // touching the counters. With no such question the mean is reported
        // as 0 instead of NaN (0.0 / 0).
        int found = available.get();
        total_inverse_rank.set(found == 0
                ? 0.0
                : total_inverse_rank.doubleValue() / found);

        // Finish the timing (nanoseconds to seconds)
        runtime = System.nanoTime() - start_time;
        runtime /= 1e9;

        report(all_ranks);
    }

    /**
     * Find the first candidate marked correct and update the availability
     * and inverse-rank totals for it.
     *
     * @param answers candidate answers, in the order the pipeline returned them
     * @return the zero-based rank of the first correct candidate, capped at
     *         {@link #MISSING_RANK}; {@link #MISSING_RANK} if none is correct
     */
    private int recordCorrectRank(List<Answer> answers) {
        for (int rank = 0; rank < answers.size(); rank++) {
            Answer candidate = answers.get(rank);
            // NOTE(review): assumes every candidate carries a "CORRECT"
            // score - confirm against Answer.
            if (candidate.scores.get("CORRECT") > 0.99) {
                total_inverse_rank.addAndGet(1 / ((double) rank + 1));
                available.incrementAndGet();
                return Math.min(rank, MISSING_RANK);
            }
        }
        return MISSING_RANK;
    }

    /** Progress character for one question; the legend is printed by run(). */
    private static char markerFor(int correct_rank) {
        if (correct_rank == 0) {
            return '#';
        } else if (correct_rank < 3) {
            return 'o';
        } else if (correct_rank < MISSING_RANK) {
            return '.';
        }
        return ' ';
    }

    /** Element-wise sum of two equally long arrays, as a new array. */
    private static int[] add(int[] a, int[] b) {
        int[] c = new int[a.length];
        for (int i = 0; i < a.length; i++) {
            c[i] = a[i] + b[i];
        }
        return c;
    }

    /**
     * Send Statistics to the server
     *
     * @param correct histogram of correct-answer ranks; index 0 is the
     *        number of questions answered correctly at the top
     */
    private void report(int[] correct) {
        // At worst, give an empty branch and commit
        String branch = "", commit = "";
        if (System.getenv("TRAVIS_BRANCH") != null) {
            // Use CI information if possible.
            branch = System.getenv("TRAVIS_BRANCH");
            commit = Optional.ofNullable(System.getenv("TRAVIS_COMMIT")).orElse("");
        } else {
            // Otherwise take a stab at it ourselves.
            try {
                Repository repo = new FileRepositoryBuilder()
                        .readEnvironment()
                        .findGitDir()
                        .build();
                try {
                    // resolve() returns null when HEAD cannot be resolved,
                    // e.g. in a repository without commits.
                    commit = Optional.ofNullable(repo.resolve("HEAD"))
                            .map(head -> head.abbreviate(COMMIT_HASH_LENGTH).name())
                            .orElse("");
                    branch = Optional.ofNullable(repo.getBranch()).orElse("");
                } finally {
                    repo.close();
                }
            } catch (IOException | IllegalArgumentException ex) {
                // Well at least we tried. build() rejects a missing git
                // directory with IllegalArgumentException.
                log.warn("Could not read the git repository.", ex);
            }
            if (commit.isEmpty()) {
                log.warn("Problem finding git repository.\n"
                        + "Resulting stats will be missing information.");
            }
        }

        // Never ask for more characters than the hash has; it is empty when
        // the lookup above failed.
        String short_commit = commit.substring(
                0, Math.min(COMMIT_HASH_LENGTH, commit.length()));

        // Generate report
        List<NameValuePair> form = Form.form()
                .add("run[branch]", branch)
                .add("run[commit_hash]", short_commit)
                .add("run[dataset]", dataset)
                .add("run[top]", String.valueOf(correct[0]))
                .add("run[top3]", String.valueOf(correct[0] + correct[1] + correct[2]))
                .add("run[available]", String.valueOf(available))
                .add("run[rank]", String.valueOf(total_inverse_rank))
                .add("run[total_questions]", String.valueOf(questionsource.size()))
                .add("run[total_answers]", String.valueOf(total_answers))
                // Confidence histograms are not collected; zeros are sent.
                .add("run[confidence_histogram]", StringUtils.join(new int[100], " "))
                .add("run[confidence_correct_histogram]", StringUtils.join(new int[100], " "))
                .add("run[runtime]", String.valueOf(runtime))
                .build();
        try {
            // The reply is not needed; discard it so the connection is released.
            Request.Post("http://watsonsim.herokuapp.com/runs.json")
                    .bodyForm(form)
                    .execute()
                    .discardContent();
        } catch (IOException e) {
            log.warn("Error uploading stats. Ignoring. "
                    + "Details follow.", e);
        }

        // NOTE(review): run() sets the root logger to ERROR, so these INFO
        // lines are presumably suppressed unless the level is changed.
        log.info(correct[0] + " of " + questionsource.size() + " correct");
        log.info(available + " of " + questionsource.size() + " could have been");
        log.info("Mean Inverse Rank " + total_inverse_rank);
    }
}