package org.multiverse.stms.gamma.benchmarks;
import org.benchy.BenchmarkDriver;
import org.benchy.TestCaseResult;
import org.multiverse.TestThread;
import org.multiverse.api.Txn;
import org.multiverse.api.TxnExecutor;
import org.multiverse.api.callables.TxnDoubleCallable;
import org.multiverse.api.callables.TxnVoidCallable;
import org.multiverse.api.references.TxnDouble;
import org.multiverse.stms.gamma.GammaStm;
import java.util.Random;
import static org.multiverse.TestUtils.joinAll;
import static org.multiverse.TestUtils.startAll;
public class AccountDriver extends BenchmarkDriver {
public static final int WARMUP_PHASE = 1;
public static final int TEST_PHASE = 2;
public static final int SHUTDOWN_PHASE = 3;
private GammaStm stm;
private int accountCount;
private Bank bank;
private int threadCount;
private BenchmarkThread[] threads;
private int readFrequency;
private int writeFrequency;
@Override
public void setUp() {
stm = new GammaStm();
bank = new Bank(accountCount);
threads = new BenchmarkThread[threadCount];
for (int k = 0; k < threads.length; k++) {
//threads[k] = new BenchmarkThread(k, );
}
}
@Override
public void run(TestCaseResult testCaseResult) {
startAll(threads);
joinAll(threads);
}
@Override
public void processResults(TestCaseResult testCaseResult) {
//To change body of implemented methods use File | Settings | File Templates.
}
class BenchmarkThread extends TestThread {
final private int id;
final private int nb;
final private int max;
final private int readThreads;
final private int writeThreads;
int transferCount;
int readCount;
int writeCount;
final private Random random;
volatile private int phase;
private int steps;
BenchmarkThread(int id, int nb, int max, int readThreads, int writeThreads) {
phase = WARMUP_PHASE;
steps = 0;
this.id = id;
this.nb = nb;
this.max = max;
this.readThreads = readThreads;
this.writeThreads = writeThreads;
transferCount = readCount = writeCount = 0;
random = new Random();
}
public void setPhase(int phase) {
this.phase = phase;
}
public int getSteps() {
return steps;
}
public void doRun() {
while (phase == WARMUP_PHASE) {
step(WARMUP_PHASE);
}
while (phase == TEST_PHASE) {
step(TEST_PHASE);
steps++;
}
}
protected void step(int phase) {
if (id < readThreads) {
// Compute total of all accounts (read-all transaction)
bank.computeTotal();
if (phase == TEST_PHASE)
readCount++;
} else if (id < readThreads + writeThreads) {
// Add 0% interest (write-all transaction)
bank.addInterest(0);
if (phase == TEST_PHASE)
writeCount++;
} else {
int i = random.nextInt(100);
if (i < readFrequency) {
// Compute total of all accounts (read-all transaction)
bank.computeTotal();
if (phase == TEST_PHASE)
readCount++;
} else if (i < readFrequency + writeFrequency) {
// Add 0% interest (write-all transaction)
bank.addInterest(0);
if (phase == TEST_PHASE)
writeCount++;
} else {
int amount = random.nextInt(max) + 1;
Account src;
Account dst;
if (s_disjoint && nb <= bank.accounts.length) {
src = bank.accounts[random.nextInt(bank.accounts.length / nb) * nb + id];
dst = bank.accounts[random.nextInt(bank.accounts.length / nb) * nb + id];
} else {
src = bank.accounts[random.nextInt(bank.accounts.length)];
dst = bank.accounts[random.nextInt(bank.accounts.length)];
}
try {
bank.transfer(src, dst, amount);
if (phase == TEST_PHASE)
transferCount++;
} catch (OverdraftException e) {
System.err.println("Overdraft: " + e.getMessage());
}
}
}
}
public String getStats() {
return "T=" + transferCount + ", R=" + readCount + ", W=" + writeCount;
}
}
static volatile boolean s_disjoint = false;
static volatile boolean s_yield = false;
class Bank {
private final Account[] accounts;
private final TxnExecutor addInterrestBlock = stm.newTxnFactoryBuilder().newTxnExecutor();
private final TxnExecutor computeTotalBlock = stm.newTxnFactoryBuilder().newTxnExecutor();
private final TxnExecutor transferBlock = stm.newTxnFactoryBuilder().newTxnExecutor();
public Bank(int accountCount) {
accounts = new Account[accountCount];
for (int k = 0; k < accounts.length; k++) {
accounts[k] = new Account("user-" + k, 0);
}
}
public void addInterest(final float rate) {
addInterrestBlock.execute(new TxnVoidCallable() {
@Override
public void call(Txn tx) throws Exception {
for (Account a : accounts) {
a.deposit(a.getBalance() * rate);
if (s_yield)
Thread.yield();
}
}
});
}
public double computeTotal() {
return computeTotalBlock.execute(new TxnDoubleCallable() {
@Override
public double call(Txn tx) throws Exception {
double total = 0.0;
for (Account a : accounts) {
total += a.getBalance();
if (s_yield)
Thread.yield();
}
return total;
}
});
}
public void transfer(final Account src, final Account dst, final float amount) throws OverdraftException {
transferBlock.execute(new TxnVoidCallable() {
@Override
public void call(Txn tx) throws Exception {
dst.deposit(amount);
if (s_yield)
Thread.yield();
src.withdraw(amount);
}
});
}
}
public class Account {
private final String name;
private final TxnDouble balance;
public Account(String name, double balance) {
this.balance = stm.getDefaultRefFactory().newTxnDouble(balance);
this.name = name;
}
public String getName() {
return name;
}
public double getBalance() {
return balance.get();
}
public void deposit(double amount) {
balance.incrementAndGet(amount);
}
public void withdraw(double amount) throws OverdraftException {
if (balance.get() < amount)
throw new OverdraftException("Cannot withdraw $" + amount + " from $" + balance.get());
balance.incrementAndGet(-amount);
}
}
public class OverdraftException extends Exception {
public OverdraftException(String reason) {
super(reason);
}
}
}