/*
* Copyright 2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ratpack.exec.util.internal;
import io.netty.util.concurrent.ScheduledFuture;
import ratpack.exec.Downstream;
import ratpack.exec.Promise;
import ratpack.exec.Upstream;
import ratpack.exec.internal.Continuation;
import ratpack.exec.internal.DefaultExecution;
import ratpack.exec.util.ReadWriteAccess;
import java.time.Duration;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
public class DefaultReadWriteAccess implements ReadWriteAccess {
private final Queue<Access<?>> queue = new ConcurrentLinkedQueue<>();
private final AtomicBoolean draining = new AtomicBoolean();
private final AtomicInteger activeReaders = new AtomicInteger();
private final Duration defaultTimeout;
private AtomicReference<Access<?>> pendingWriteRef = new AtomicReference<>();
public DefaultReadWriteAccess(Duration defaultTimeout) {
if (defaultTimeout.isNegative()) {
throw new IllegalArgumentException("defaultTimeout must not be negative");
}
this.defaultTimeout = defaultTimeout;
}
@Override
public Duration getDefaultTimeout() {
return defaultTimeout;
}
@Override
public <T> Promise<T> read(Promise<T> promise) {
return promise.transform(up -> down -> new Access<T>(true, up, defaultTimeout, down));
}
@Override
public <T> Promise<T> read(Promise<T> promise, Duration timeout) {
return promise.transform(up -> down -> new Access<T>(true, up, timeout, down));
}
@Override
public <T> Promise<T> write(Promise<T> promise) {
return promise.transform(up -> down -> new Access<T>(false, up, defaultTimeout, down));
}
@Override
public <T> Promise<T> write(Promise<T> promise, Duration timeout) {
return promise.transform(up -> down -> new Access<T>(false, up, timeout, down));
}
private class Access<T> {
private final boolean read;
private final Upstream<? extends T> upstream;
private final Duration timeout;
private final Downstream<? super T> downstream;
private final DefaultExecution execution;
private boolean fired;
private Continuation continuation;
private ScheduledFuture<?> timeoutFuture;
private Access(boolean read, Upstream<? extends T> upstream, Duration timeout, Downstream<? super T> downstream) {
if (timeout.isNegative()) {
throw new IllegalArgumentException("Timeout value must not be negative");
}
this.read = read;
this.upstream = upstream;
this.timeout = timeout;
this.downstream = downstream;
this.execution = DefaultExecution.get();
execution.delimit(e -> {
relinquish(false);
if (fire()) {
downstream.error(e);
}
}, continuation -> {
if (!timeout.isZero()) {
timeoutFuture = execution.getEventLoop().schedule(this::timeout, timeout.toMillis(), TimeUnit.MILLISECONDS);
}
this.continuation = continuation;
queue.add(this);
drain();
});
}
private boolean fire() {
//noinspection SimplifiableIfStatement
if (fired) {
return false;
} else {
return fired = true;
}
}
private void timeout() {
if (fire()) {
continuation.resume(() -> {
drain();
downstream.error(new TimeoutException("Could not acquire " + (read ? "read" : "write") + " access within " + timeout));
});
}
}
private void access() {
if (read) {
activeReaders.incrementAndGet();
}
if (fire()) {
if (timeoutFuture != null) {
timeoutFuture.cancel(false);
}
continuation.resume(() ->
upstream.connect(new Downstream<T>() {
@Override
public void success(T value) {
relinquish(false);
downstream.success(value);
}
@Override
public void error(Throwable throwable) {
relinquish(false);
downstream.error(throwable);
}
@Override
public void complete() {
relinquish(false);
downstream.complete();
}
})
);
} else {
relinquish(true);
}
}
private void relinquish(boolean didTimeout) {
if (read) {
if (activeReaders.decrementAndGet() == 0) {
Access<?> pendingWrite = pendingWriteRef.getAndSet(null);
if (pendingWrite != null) {
pendingWrite.access();
return;
}
}
} else {
draining.set(false);
}
drain();
}
}
private void drain() {
if (draining.compareAndSet(false, true)) {
Access<?> access = queue.poll();
while (access != null) {
if (access.read) {
access.access();
access = queue.poll();
} else {
if (activeReaders.get() == 0) {
access.access();
} else {
pendingWriteRef.set(access);
if (activeReaders.get() == 0) {
access = pendingWriteRef.getAndSet(null);
if (access != null) {
access.access();
}
}
}
return;
}
}
draining.set(false);
if (!queue.isEmpty()) {
drain();
}
}
}
}