package edu.washington.escience.myria.parallel.ipc; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import edu.washington.escience.myria.parallel.ipc.IPCMessage.StreamData; import edu.washington.escience.myria.util.AttachmentableAdapter; import edu.washington.escience.myria.util.concurrent.ClosableReentrantLock; /** * A simple InputBuffer implementation. The number of messages held in this InputBuffer can be as large as * {@link Integer.MAX_VALUE}. All the input data from different input channels are treated by bag semantic. No order is * is guaranteed. * * @param <PAYLOAD> the type of application defined data the input buffer is going to hold. * */ public abstract class BagInputBufferAdapter<PAYLOAD> extends AttachmentableAdapter implements StreamInputBuffer<PAYLOAD> { /** * logger. * */ static final Logger LOGGER = LoggerFactory.getLogger(BagInputBufferAdapter.class); /** * input channel state. * */ private class InputChannelState { /** * EOS bit. * */ private final AtomicBoolean eos = new AtomicBoolean(false); /** * EOS lock. * */ private final ReadWriteLock eosLock = new ReentrantReadWriteLock(); /** * input channel. * */ private final StreamInputChannel<PAYLOAD> inputChannel; /** * @param id input channel id. * */ InputChannelState(final StreamIOChannelID id) { inputChannel = new StreamInputChannel<PAYLOAD>(id, BagInputBufferAdapter.this); } } /** * the storage place of messages. * */ private final LinkedBlockingQueue<IPCMessage.StreamData<PAYLOAD>> storage; /** * Num of EOS. * */ private int numInputEOS; /** * Num threads waiting data by poll(timeout) or take. * */ private int numWaiting; /** * Set of input channels. * */ private final ImmutableMap<StreamIOChannelID, InputChannelState> inputChannels; /** * The processor attached. * */ private final AtomicReference<Object> processor; /** * owner. * */ private final IPCConnectionPool ownerConnectionPool; /** * Input buffer size. * */ private int size = 0; /** * Serialize buffer size access. * */ private final ClosableReentrantLock bufferSizeLock = new ClosableReentrantLock(); /** * @return buffer size lock. * */ public final ClosableReentrantLock getBufferSizeLock() { return bufferSizeLock; } /** * wait on empty. * */ private final Condition emptySize = bufferSizeLock.newCondition(); /** * @param owner the owner IPC pool. * @param remoteChannelIDs from which channels, the data will input. * */ public BagInputBufferAdapter( final IPCConnectionPool owner, final ImmutableSet<StreamIOChannelID> remoteChannelIDs) { storage = new LinkedBlockingQueue<IPCMessage.StreamData<PAYLOAD>>(); ImmutableMap.Builder<StreamIOChannelID, InputChannelState> b = ImmutableMap.builder(); for (StreamIOChannelID ecID : remoteChannelIDs) { InputChannelState ics = new InputChannelState(ecID); b.put(ecID, ics); } inputChannels = b.build(); processor = new AtomicReference<Object>(); numInputEOS = 0; numWaiting = 0; ownerConnectionPool = owner; } /** * Called before {@link #start(Object)} operations are conducted. * * @param processor {@link #start(Object)} * @throws IllegalStateException if the {@link #start(Object)} operation should not be done. * */ protected void preStart(final Object processor) throws IllegalStateException {} /** * Called after {@link #start(Object)} operations are conducted. * * @param processor {@link #start(Object)} * */ protected void postStart(final Object processor) {} @Override public final void start(final Object processor) { Preconditions.checkNotNull(processor); preStart(processor); if (!this.processor.compareAndSet(null, processor)) { throw new IllegalStateException("Already attached to a processor: " + processor); } this.getOwnerConnectionPool().registerStreamInput(this); postStart(processor); } /** * Called before {@link #stop()} operations are conducted. * * @throws IllegalStateException if the {@link #stop()} operation should not be done. * */ protected void preStop() throws IllegalStateException {} /** * Called after {@link #stop()} operations are conducted. * * */ protected void postStop() {} @Override public final void stop() { this.preStop(); this.processor.set(null); this.clear(); this.postStop(); } @Override public final int size() { try (ClosableReentrantLock l = bufferSizeLock.open()) { return this.size; } } @Override public final boolean isEmpty() { try (ClosableReentrantLock l = bufferSizeLock.open()) { return this.size == 0; } } /** * Called before {@link #clear()} is executed. * * @throws IllegalStateException if the {@link #clear()} operation should not be done. * */ protected void preClear() throws IllegalStateException {} /** * Called after {@link #clear()} is executed. * */ protected void postClear() {} @Override public final void clear() { preClear(); storage.clear(); postClear(); try (ClosableReentrantLock l = bufferSizeLock.open()) { this.size = 0; } } /** * Check if the input buffer is attached. * */ private void checkAttached() { if (!isAttached()) { throw new IllegalStateException("Not attached"); } } /** * @param e input data. * @return input state of the data coming channel. * */ private InputChannelState checkValidInputChannel(final IPCMessage.StreamData<PAYLOAD> e) { StreamIOChannelID sourceId = new StreamIOChannelID(e.getStreamID(), e.getRemoteID()); InputChannelState s = inputChannels.get(sourceId); if (s == null) { throw new IllegalArgumentException("Message received from unknown input channel" + e); } return s; } /** * @param ics input channel state * @param e the input data. * @return false if it's an EOS, true if it's not. * */ private boolean checkNotEOS(final InputChannelState ics, final IPCMessage.StreamData<PAYLOAD> e) { if (ics.eos.get()) { if (LOGGER.isDebugEnabled()) { LOGGER.debug( "Message received from an already EOS channele from remote " + e.getRemoteID() + " streamID " + e.getStreamID() + " " + ics.inputChannel); } /* temp solution, better to check if it's from a recover worker */ return false; } return true; } /** * Called before {@link #offer(StreamData)} operations are conducted. * * @param msg {@link #offer(edu.washington.escience.myria.parallel.ipc.IPCMessage.StreamData)} * @throws IllegalStateException if the * {@link #offer(edu.washington.escience.myria.parallel.ipc.IPCMessage.StreamData)} operation should not be * done. * */ protected void preOffer(final IPCMessage.StreamData<PAYLOAD> msg) throws IllegalStateException {} /** * Called after {@link #offer(StreamData)} operations are conducted. * * @param msg {@link #offer(StreamData)} * @param isSucceed if the offer operation succeeds * */ protected void postOffer(final IPCMessage.StreamData<PAYLOAD> msg, final boolean isSucceed) {} @Override public final boolean offer(final IPCMessage.StreamData<PAYLOAD> msg) { Preconditions.checkNotNull(msg); checkAttached(); InputChannelState ics = checkValidInputChannel(msg); if (!checkNotEOS(ics, msg)) { return true; } preOffer(msg); if (msg.getPayload() == null) { // EOS msg ics.eosLock.writeLock().lock(); checkNotEOS(ics, msg); ics.eos.set(true); } else { ics.eosLock.readLock().lock(); checkNotEOS(ics, msg); } boolean inserted = false; try { inserted = storage.offer(msg); if (inserted) { try (ClosableReentrantLock l = bufferSizeLock.open()) { if (msg.getPayload() == null) { this.numInputEOS += 1; } this.size += 1; if (numWaiting > 0) { if (isEOS()) { emptySize.signalAll(); } else if (!isEmpty()) { emptySize.signal(); } } } } } finally { if (msg.getPayload() == null) { ics.eosLock.writeLock().unlock(); } else { ics.eosLock.readLock().unlock(); } } this.postOffer(msg, inserted); return inserted; } /** * Called before {@link #take()} operations are conducted. * * @throws IllegalStateException if the {@link #take()} operation should not be done. * */ protected void preTake() throws IllegalStateException {} /** * Called after {@link #take()} operations are conducted. * * @param msg the result of {@link #take()}. * */ protected void postTake(final IPCMessage.StreamData<PAYLOAD> msg) {} @Override public final IPCMessage.StreamData<PAYLOAD> take() throws InterruptedException { checkAttached(); if (isEOS() && isEmpty()) { return null; } preTake(); try (ClosableReentrantLock l = bufferSizeLock.open()) { if (isEmpty() && !isEOS()) { numWaiting++; try { emptySize.await(); } finally { numWaiting--; } } if (!isEmpty()) { size -= 1; } else { return null; } } IPCMessage.StreamData<PAYLOAD> m = storage.poll(); postTake(m); return m; } /** * Called before {@link #poll(long, TimeUnit)} operations are conducted. * * @param time param of {@link #poll(long, TimeUnit)} * @param unit param of {@link #poll(long, TimeUnit)} * @throws IllegalStateException if the {@link #poll(long, TimeUnit)} operation should not be done. * */ protected void preTimeoutPoll(final long time, final TimeUnit unit) throws IllegalStateException {} /** * Called after {@link #poll(long, TimeUnit)} operations are conducted. * * @param time param of {@link #poll(long, TimeUnit)} * @param unit param of {@link #poll(long, TimeUnit)} * @param msg the result of {@link #poll(long, TimeUnit)}. * */ protected void postTimeoutPoll( final long time, final TimeUnit unit, final IPCMessage.StreamData<PAYLOAD> msg) {} @Override public final IPCMessage.StreamData<PAYLOAD> poll(final long time, final TimeUnit unit) throws InterruptedException { checkAttached(); if (isEOS() && isEmpty()) { return null; } preTimeoutPoll(time, unit); try (ClosableReentrantLock l = bufferSizeLock.open()) { if (isEmpty() && !isEOS()) { numWaiting++; try { emptySize.await(time, unit); } finally { numWaiting--; } if (isEmpty()) { return null; } } this.size -= 1; } IPCMessage.StreamData<PAYLOAD> m = this.storage.poll(); postTimeoutPoll(time, unit, m); return m; } /** * Called before {@link #poll()} operations are conducted. * * @throws IllegalStateException if the {@link #poll()} operation should not be done. * */ protected void prePoll() throws IllegalStateException {} /** * Called after {@link #poll()} operations are conducted. * * @param msg the result of {@link #poll()}. * */ protected void postPoll(final IPCMessage.StreamData<PAYLOAD> msg) {} @Override public final IPCMessage.StreamData<PAYLOAD> poll() { checkAttached(); prePoll(); try (ClosableReentrantLock l = bufferSizeLock.open()) { if (isEmpty()) { return null; } size -= 1; } IPCMessage.StreamData<PAYLOAD> m = storage.poll(); postPoll(m); return m; } @Override public final IPCMessage.StreamData<PAYLOAD> peek() { checkAttached(); return storage.peek(); } @Override public final boolean isAttached() { return processor.get() != null; } @Override public final boolean isEOS() { try (ClosableReentrantLock l = bufferSizeLock.open()) { return numInputEOS >= inputChannels.size(); } } @Override public final StreamInputChannel<PAYLOAD> getInputChannel( final StreamIOChannelID sourceChannelID) { return inputChannels.get(sourceChannelID).inputChannel; } @Override public final ImmutableSet<StreamIOChannelID> getSourceChannels() { return inputChannels.keySet(); } @Override public final IPCConnectionPool getOwnerConnectionPool() { return ownerConnectionPool; } @Override public final Object getProcessor() { return processor.get(); } }