/* * Copyright 2016 The Simple File Server Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.sfs.block; import com.google.common.base.Optional; import io.vertx.core.logging.Logger; import org.sfs.SfsVertx; import org.sfs.rx.Defer; import org.sfs.rx.Holder2; import org.sfs.rx.ObservableFuture; import org.sfs.rx.RxHelper; import rx.Observable; import rx.Subscriber; import rx.functions.Func0; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static com.google.common.base.Optional.fromNullable; import static com.google.common.base.Preconditions.checkState; import static com.google.common.math.LongMath.checkedAdd; import static io.vertx.core.logging.LoggerFactory.getLogger; import static java.lang.Boolean.FALSE; import static java.lang.Boolean.TRUE; import static java.lang.Long.MAX_VALUE; import static java.lang.Math.floor; import static java.lang.Math.random; import static java.lang.System.currentTimeMillis; import static org.sfs.math.Rounding.up; import static rx.Observable.just; import static rx.Observable.using; public class RangeLock { private static final Logger LOGGER = getLogger(RangeLock.class); private final int blockSize; private final Object mutex = new Object(); private final List<LockedRange> readLocks = new ArrayList<>(); private final List<LockedRange> writeLocks = new ArrayList<>(); private AtomicInteger lockCount = new AtomicInteger(0); public RangeLock(int blockSize) { this.blockSize = blockSize; } public static <R> Observable<R> lockedObservable(SfsVertx vertx, Func0<Optional<Lock>> lockFactory, Func0<? extends Observable<R>> observableFactory, long lockWaitTimeoutMs) { ObservableFuture<R> handler = RxHelper.observableFuture(); long now = currentTimeMillis(); AtomicBoolean unSubscribed = new AtomicBoolean(false); lockedObservable0(vertx, now, lockFactory, observableFactory, lockWaitTimeoutMs, handler, unSubscribed); return handler .doOnUnsubscribe(() -> unSubscribed.set(true)); } private static <R> void lockedObservable0(SfsVertx vertx, long startTimeMs, Func0<Optional<Lock>> lockFactory, Func0<? extends Observable<R>> observableFactory, long lockWaitTimeoutMs, ObservableFuture<R> handler, AtomicBoolean unSubscribed) { using( lockFactory::call, oLock -> { if (oLock.isPresent()) { Lock lock = oLock.get(); ObservableFuture<Holder2<Boolean, R>> innerHandler = RxHelper.observableFuture(); Defer.aVoid() .flatMap(aVoid -> observableFactory.call()) .subscribe(new Subscriber<R>() { R value; @Override public void onCompleted() { lock.unlock(); innerHandler.complete(new Holder2<>(TRUE, value)); } @Override public void onError(Throwable e) { lock.unlock(); innerHandler.fail(e); } @Override public void onNext(R r) { value = r; } }); return innerHandler; } else { return just(new Holder2<>(FALSE, (R) null)); } }, oLock -> { if (oLock.isPresent()) { oLock.get().unlock(); } }) .subscribe(new Subscriber<Holder2<Boolean, R>>() { Holder2<Boolean, ? extends R> statusHolder; @Override public void onCompleted() { if (TRUE.equals(statusHolder.value0())) { handler.complete(statusHolder.value1()); } else { if (currentTimeMillis() - startTimeMs >= lockWaitTimeoutMs) { TimedOutException timedOutException = new TimedOutException(); handler.fail(timedOutException); } else { if (!unSubscribed.get()) { vertx.setTimer((long) (floor(random() * 100) + 1), event -> lockedObservable0(vertx, startTimeMs, lockFactory, observableFactory, lockWaitTimeoutMs, handler, unSubscribed)); } } } } @Override public void onError(Throwable e) { handler.fail(e); } @Override public void onNext(Holder2<Boolean, R> status) { statusHolder = status; } }); } public int getLockCount() { return lockCount.get(); } public Optional<Lock> tryWriteLock(long position, long length) { return fromNullable(lockWrite0(position, length)); } public Optional<Lock> tryReadLock(long position, long length) { return fromNullable(lockRead0(position, length)); } protected Lock lockWrite0(long position, long length) { Range range = new Range(position, computeLast(position, length)); synchronized (mutex) { if (isWriteConflict(range) || isReadConflict(range)) { return null; } LockedRange lockedRange = new LockedRange(range); addWriteLockedRange(lockedRange); lockCount.incrementAndGet(); return new Lock() { @Override void unlock0() { synchronized (mutex) { removeWriteLockedRange(lockedRange); lockCount.decrementAndGet(); } } }; } } protected Lock lockRead0(long position, long length) { Range range = new Range(position, computeLast(position, length)); synchronized (mutex) { if (isWriteConflict(range)) { return null; } LockedRange lockedRange = new LockedRange(range); addReadLockedRange(lockedRange); lockCount.incrementAndGet(); return new Lock() { @Override void unlock0() { synchronized (mutex) { removeReadLockedRange(lockedRange); lockCount.decrementAndGet(); } } }; } } private void removeReadLockedRange(LockedRange lockedRange) { checkState(readLocks.remove(lockedRange)); } private void removeWriteLockedRange(LockedRange lockedRange) { checkState(writeLocks.remove(lockedRange)); } private void addWriteLockedRange(LockedRange lockedRange) { checkState(writeLocks.add(lockedRange)); } private void addReadLockedRange(LockedRange lockedRange) { checkState(readLocks.add(lockedRange)); } protected boolean isReadConflict(Range range) { for (LockedRange lockedRange : readLocks) { if (lockedRange.intersects(range)) { return true; } } return false; } protected boolean isWriteConflict(Range range) { for (LockedRange lockedRange : writeLocks) { if (lockedRange.intersects(range)) { return true; } } return false; } protected long computeLast(long first, long length) { long last; try { last = checkedAdd(first, up(length, blockSize)); } catch (ArithmeticException e) { last = MAX_VALUE; } return last - 1; } public static class LockedRange { private final Range range; public LockedRange(Range range) { this.range = range; } public boolean adjacent(Range other) { return range.adjacent(other); } public Range merge(Range other) { return range.merge(other); } public boolean encloses(Range other) { return range.encloses(other); } public Range[] remove(Range toRemove) { return range.remove(toRemove); } public long getLast() { return range.getLast(); } public long getFirst() { return range.getFirst(); } public boolean intersects(Range other) { return range.intersects(other); } public boolean isEmpty() { return range.isEmpty(); } public long getBlockCount() { return range.getBlockCount(); } public boolean encloses(long first, long last) { return range.encloses(first, last); } public Range[] remove(long first, long last) { return range.remove(first, last); } @Override public String toString() { return "LockedRange{" + "range=" + range + '}'; } } public abstract static class Lock { private final AtomicBoolean unlocked = new AtomicBoolean(false); public Lock() { } public boolean unlock() { if (unlocked.compareAndSet(false, true)) { unlock0(); return true; } return false; } abstract void unlock0(); } public static class TimedOutException extends RuntimeException { } }