package de.randi2.core.unit.randomization;

import static de.randi2.core.unit.randomization.RandomizationHelper.randomize;
import static de.randi2.utility.IntegerIterator.upto;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.util.ArrayList;
import java.util.List;

import org.junit.Before;
import org.junit.Test;

import de.randi2.model.TreatmentArm;
import de.randi2.model.Trial;
import de.randi2.model.TrialSubject;
import de.randi2.model.randomization.Urn;
import de.randi2.model.randomization.UrnDesignConfig;
import de.randi2.model.randomization.UrnDesignTempData;

/**
 * Tests the urn design randomization on a two-armed trial.
 *
 * <p>Each test keeps its own book of how many balls per arm the urn is
 * expected to hold and compares that book with the real urn content after
 * every single allocation.
 */
public class UrnDesignTest {

	private Trial trial;
	private TrialSubject s;
	private UrnDesignConfig conf;

	/** Creates a fresh trial that uses a new urn design configuration. */
	@Before
	public void setUp() {
		trial = new Trial();
		conf = new UrnDesignConfig();
		trial.setRandomizationConfiguration(conf);
	}

	/** 40 allocations, arms planned 20/20, 10 initial balls per arm, 4 balls added per draw. */
	@Test
	public void testFourtySubjectAllocations() {
		runAllocations(40, 20, 4, 10);
	}

	/** 100 allocations, arms planned 50/50, 4 initial balls per arm, 2 balls added per draw. */
	@Test
	public void test100SubjectAllocations() {
		runAllocations(100, 50, 2, 4);
	}

	/**
	 * Randomizes {@code subjectCount} times and verifies the urn after each allocation.
	 *
	 * <p>Expectation encoded by this test: the chosen arm loses one ball and
	 * the other arm gains {@code replacedBalls} balls.
	 *
	 * @param subjectCount    number of allocations to perform; also the expected
	 *                        number of subjects in the trial afterwards
	 * @param plannedPerArm   planned size passed to the helper for each of the two arms
	 * @param replacedBalls   value for {@code UrnDesignConfig#setCountReplacedBalls}
	 * @param initializeCount value for {@code UrnDesignConfig#setInitializeCountBalls}
	 */
	private void runAllocations(int subjectCount, int plannedPerArm, int replacedBalls, int initializeCount) {
		RandomizationHelper.addArms(trial, plannedPerArm, plannedPerArm);
		conf.setCountReplacedBalls(replacedBalls);
		conf.setInitializeCountBalls(initializeCount);
		// The same subject instance is handed to every randomize call, as before.
		s = new TrialSubject();

		int[] expectedBalls = new int[] { initializeCount, initializeCount };
		List<TreatmentArm> arms = new ArrayList<TreatmentArm>(trial.getTreatmentArms());
		String firstArmName = arms.get(0).getName();

		// The loop variable is only a counter; its value is not used.
		for (int i : upto(subjectCount)) {
			randomize(trial, s);
			if (s.getArm().getName().equals(firstArmName)) {
				expectedBalls[1] += replacedBalls;
				expectedBalls[0]--;
			} else {
				expectedBalls[0] += replacedBalls;
				expectedBalls[1]--;
			}
			assertTrue("urn content differs from the expected ball counts", checkUrn(expectedBalls));
		}
		assertEquals(subjectCount, trial.getSubjects().size());
	}

	/**
	 * Compares the expected ball counts with the urn stored for the current
	 * subject's stratum.
	 *
	 * <p>Index 0 refers to the first arm of {@code trial.getTreatmentArms()},
	 * index 1 to every other arm. BOTH counts have to match; the former
	 * {@code ||} accepted an urn in which one arm's count was wrong.
	 *
	 * @param countBalls expected number of balls, one entry per arm
	 * @return {@code true} if the urn exists and both counts match
	 */
	private boolean checkUrn(int[] countBalls) {
		String stratum = "";
		if (trial.isStratifyTrialSite()) {
			stratum = s.getTrialSite().getId() + "";
		}
		stratum += s.getStratum();

		Urn urn = ((UrnDesignTempData) conf.getTempData()).getUrn(stratum);
		if (urn == null) {
			// No urn under this key: report a failed check instead of a NullPointerException.
			return false;
		}

		List<TreatmentArm> arms = new ArrayList<TreatmentArm>(trial.getTreatmentArms());
		String firstArmName = arms.get(0).getName();

		int[] count = new int[2];
		for (TreatmentArm arm : urn.getUrn()) {
			if (arm.getName().equals(firstArmName)) {
				count[0]++;
			} else {
				count[1]++;
			}
		}
		return countBalls[0] == count[0] && countBalls[1] == count[1];
	}
}