/*
* JASA Java Auction Simulator API
* Copyright (C) 2013 Steve Phelps
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of
* the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU General Public License for more details.
*/
package net.sourceforge.jabm.learning;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import net.sourceforge.jabm.learning.EpsilonGreedyActionSelector;
import net.sourceforge.jabm.learning.QLearner;
import net.sourceforge.jabm.test.PRNGTestSeeds;
import net.sourceforge.jabm.util.SummaryStats;
import org.apache.log4j.Logger;
import cern.jet.random.engine.MersenneTwister64;
import cern.jet.random.engine.RandomEngine;
public class QLearnerTest extends TestCase {
QLearner learner1;
EpsilonGreedyActionSelector actionSelector;
double score;
RandomEngine prng;
static final double EPSILON = 0.05;
static final double LEARNING_RATE = 0.8;
static final double DISCOUNT_RATE = 0.9;
static final int NUM_ACTIONS = 10;
static final int CORRECT_ACTION = 2;
static final int NUM_TRIALS = 20000;
static Logger logger = Logger.getLogger(QLearnerTest.class);
public QLearnerTest(String name) {
super(name);
}
public void setUp() {
prng = new MersenneTwister64(PRNGTestSeeds.UNIT_TEST_SEED);
learner1 = new QLearner(1, NUM_ACTIONS, LEARNING_RATE,
DISCOUNT_RATE, prng);
actionSelector = new EpsilonGreedyActionSelector(prng);
actionSelector.setPrng(prng);
actionSelector.setEpsilon(EPSILON);
learner1.setActionSelector(actionSelector);
score = 0;
}
public void testBestAction() {
((EpsilonGreedyActionSelector) learner1.getActionSelector()).setEpsilon(0.0);
System.out.println("testBestAction()");
SummaryStats stats = new SummaryStats("action");
int correctActions = 0;
for (int i = 0; i < NUM_TRIALS; i++) {
int action = learner1.act();
stats.newData(action);
if (action == CORRECT_ACTION) {
learner1.newState(1.0, 0);
correctActions++;
} else {
learner1.newState(0.0, 0);
}
}
logger.info("final state of learner1 = " + learner1);
logger.info("learner1 score = " + score(correctActions) + "%");
assertTrue(learner1.bestAction(0) == CORRECT_ACTION);
System.out.println(stats);
}
public void testMinimumScore() {
logger.info("testMinimumScore()");
SummaryStats stats = new SummaryStats("action");
int correctActions = 0;
int bestActionChosen = 0;
for (int i = 0; i < NUM_TRIALS; i++) {
int action = learner1.act();
stats.newData(action);
assertTrue(action == learner1.getLastActionChosen());
int bestAction = learner1.bestAction(0);
if (bestAction == action) {
bestActionChosen++;
}
if (action == CORRECT_ACTION) {
learner1.newState(1.0, 0);
correctActions++;
} else {
learner1.newState(0.0, 0);
}
}
logger.info("final state of learner1 = " + learner1);
double score = score(correctActions);
double bestActionPercent = score(bestActionChosen);
logger.info("learner1 score = " + score + "%");
logger.info(stats);
logger.info("chose best action " + bestActionPercent + "% of the time.");
assertTrue(score > 80);
assertTrue(1 - (bestActionPercent / 100) <= EPSILON);
}
public void testStates() {
System.out.println("testStates()");
int[] correctChoices = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
int correctActions = 0;
learner1.setStatesAndActions(correctChoices.length, NUM_ACTIONS);
int s = 3;
for (int i = 0; i < NUM_TRIALS; i++) {
int action = learner1.act();
double reward = 0;
if (action == correctChoices[s]) {
reward = 1.0;
correctActions++;
if (++s > 9) {
s = 0;
}
}
learner1.newState(reward, s);
assertTrue(learner1.getState() == s);
}
score = score(correctActions);
System.out.println("score = " + score + "%");
System.out.println("learner1 = " + learner1);
assertTrue(score >= 70);
}
public void testReset() {
RandomEngine prng = new MersenneTwister64(PRNGTestSeeds.UNIT_TEST_SEED);
System.out.println("testReset()");
System.out.println("virgin learner1 = " + learner1);
learner1.setPrng(prng);
actionSelector.setPrng(prng);
learner1.reset();
testStates();
double score1 = score;
System.out.println("score1 = " + score1);
prng = new MersenneTwister64(PRNGTestSeeds.UNIT_TEST_SEED);
learner1.setPrng(prng);
actionSelector.setPrng(prng);
learner1.reset();
System.out.println("reseted learner1 = " + learner1);
testStates();
double score2 = score;
System.out.println("score2 = " + score2);
assertTrue(score1 == score2);
}
public double score(int numCorrect) {
return ((double) numCorrect / (double) NUM_TRIALS) * 100;
}
public static void main(String[] args) {
junit.textui.TestRunner.run(suite());
}
public static Test suite() {
return new TestSuite(QLearnerTest.class);
}
}