package org.nd4j.rng;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.rng.deallocator.NativePack;
import org.nd4j.rng.deallocator.NativeRandomDeallocator;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Basic NativeRandom implementation
*
* @author raver119@gmail.com
*/
@Slf4j
public abstract class NativeRandom implements Random {
protected NativeOps nativeOps;
protected DataBuffer stateBuffer;
protected Pointer statePointer;
protected long seed;
protected long amplifier;
protected long generation;
protected long numberOfElements;
protected AtomicInteger position = new AtomicInteger(0);
protected LongPointer hostPointer;
protected boolean isDestroyed = false;
protected NativeRandomDeallocator deallocator;
// special stuff for gaussian
protected double z0, z1, u0, u1;
protected boolean generated = false;
protected double mean = 0.0;
protected double stdDev = 1.0;
// hack to attach deallocator
protected NativePack pack;
public long getBufferSize() {
return numberOfElements;
}
public int getPosition() {
return position.get();
}
public long getGeneration() {
return generation;
}
public NativeRandom() {
this(System.currentTimeMillis());
}
public NativeRandom(long seed) {
this(seed, 10000000);
}
public NativeRandom(long seed, long numberOfElements) {
this.amplifier = seed;
this.generation = 1;
this.seed = seed;
this.numberOfElements = numberOfElements;
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
stateBuffer = Nd4j.getDataBufferFactory().createDouble(numberOfElements);
init();
hostPointer = new LongPointer(stateBuffer.addressPointer());
deallocator = NativeRandomDeallocator.getInstance();
pack = new NativePack(statePointer.address(), statePointer);
deallocator.trackStatePointer(pack);
}
public abstract void init();
@Override
public void setSeed(int seed) {
setSeed((long) seed);
}
@Override
public void setSeed(int[] seed) {
long sd = 0;
for (int em : seed) {
sd *= em;
}
setSeed(sd);
}
@Override
public void setSeed(long seed) {
synchronized (this) {
this.seed = seed;
this.amplifier = seed;
this.position.set(0);
nativeOps.refreshBuffer(getExtraPointers(), seed, statePointer);
}
}
@Override
public long getSeed() {
return seed;
}
@Override
public void nextBytes(byte[] bytes) {
throw new UnsupportedOperationException();
}
@Override
public int nextInt() {
int next = (int) (amplifier == seed ? nextLong() : nextLong() * amplifier + 11);
return next < 0 ? -1 * next : next;
}
@Override
public int nextInt(int to) {
int r = nextInt();
int m = to - 1;
if ((to & m) == 0) // i.e., bound is a power of 2
r = (int) ((to * (long) r) >> 31);
else {
for (int u = r; u - (r = u % to) + m < 0; u = nextInt());
}
return r;
}
@Override
public long nextLong() {
long next = 0;
synchronized (this) {
if (position.get() >= numberOfElements) {
position.set(0);
generation++;
}
next = hostPointer.get(position.getAndIncrement());
if (generation > 1)
next = next ^ generation + 11;
if (amplifier != seed)
next = next ^ amplifier + 11;
}
return next < 0 ? -1 * next : next;
}
public abstract PointerPointer getExtraPointers();
@Override
public boolean nextBoolean() {
return nextInt() % 2 == 0;
}
@Override
public float nextFloat() {
return (float) nextInt() / (float) Integer.MAX_VALUE;
}
@Override
public double nextDouble() {
return (double) nextInt() / (double) Integer.MAX_VALUE;
}
@Override
public double nextGaussian() {
double epsilon = 1e-15;
double two_pi = 2.0 * 3.14159265358979323846;
if (!generated) {
do {
u0 = nextDouble();
u1 = nextDouble();
} while (u0 <= epsilon);
z0 = Math.sqrt(-2.0 * Math.log(u0)) * Math.cos(two_pi * u1);
z1 = Math.sqrt(-2.0 * Math.log(u0)) * Math.sin(two_pi * u1);
generated = true;
return z0 * stdDev + mean;
} else {
generated = false;
return z1 * stdDev + mean;
}
}
@Override
public INDArray nextGaussian(int[] shape) {
return nextGaussian(Nd4j.order(), shape);
}
@Override
public INDArray nextGaussian(char order, int[] shape) {
INDArray array = Nd4j.createUninitialized(shape, order);
GaussianDistribution op = new GaussianDistribution(array, 0.0, 1.0);
Nd4j.getExecutioner().exec(op, this);
return array;
}
@Override
public INDArray nextDouble(int[] shape) {
return nextDouble(Nd4j.order(), shape);
}
@Override
public INDArray nextDouble(char order, int[] shape) {
INDArray array = Nd4j.createUninitialized(shape, order);
UniformDistribution op = new UniformDistribution(array, 0.0, 1.0);
Nd4j.getExecutioner().exec(op, this);
return array;
}
@Override
public INDArray nextFloat(int[] shape) {
return nextFloat(Nd4j.order(), shape);
}
@Override
public INDArray nextFloat(char order, int[] shape) {
INDArray array = Nd4j.createUninitialized(shape, order);
UniformDistribution op = new UniformDistribution(array, 0.0, 1.0);
Nd4j.getExecutioner().exec(op, this);
return array;
}
@Override
public INDArray nextInt(int[] shape) {
throw new UnsupportedOperationException();
}
@Override
public INDArray nextInt(int n, int[] shape) {
throw new UnsupportedOperationException();
}
/**
* This method returns pointer to RNG state structure.
* Please note: DefaultRandom implementation returns NULL here, making it impossible to use with RandomOps
*
* @return
*/
@Override
public Pointer getStatePointer() {
return statePointer;
}
/**
* This method returns pointer to RNG buffer
*
* @return
*/
@Override
public DataBuffer getStateBuffer() {
return stateBuffer;
}
@Override
public void reSeed() {
reSeed(System.currentTimeMillis());
}
@Override
public void reSeed(long amplifier) {
this.amplifier = amplifier;
nativeOps.reSeedBuffer(getExtraPointers(), amplifier, getStatePointer());
}
@Override
public void close() throws Exception {
/*
Do nothing here, since we use WeakReferences for actual deallocation
*/
}
}