/* * JBoss, Home of Professional Open Source. * Copyright 2014 Red Hat, Inc., and individual contributors * as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.undertow.conduits; import org.xnio.channels.StreamSourceChannel; import org.xnio.conduits.AbstractStreamSinkConduit; import org.xnio.conduits.StreamSinkConduit; import java.io.IOException; import java.io.InterruptedIOException; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.util.concurrent.TimeUnit; import io.undertow.util.WorkerUtils; /** * Class that implements the token bucket algorithm. * <p> * Allows send speed to be throttled * <p> * Note that throttling is applied after an initial write, so if a big write is performed initially * it may be a while before it can write again. * * @author Stuart Douglas */ public class RateLimitingStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> { private final long time; private final int bytes; private boolean writesResumed = false; private int byteCount = 0; private long startTime = 0; private long nextSendTime = 0; private boolean scheduled = false; /** * @param next The next conduit * @param bytes The number of bytes that are allowed per time frame * @param time The time frame * @param timeUnit The time unit */ public RateLimitingStreamSinkConduit(StreamSinkConduit next, int bytes, long time, TimeUnit timeUnit) { super(next); writesResumed = next.isWriteResumed(); this.time = timeUnit.toMillis(time); this.bytes = bytes; } @Override public int write(ByteBuffer src) throws IOException { if (!canSend()) { return 0; } int bytes = this.bytes - this.byteCount; int old = src.limit(); if (src.remaining() > bytes) { src.limit(src.position() + bytes); } try { int written = super.write(src); handleWritten(written); return written; } finally { src.limit(old); } } @Override public long transferFrom(FileChannel src, long position, long count) throws IOException { if (!canSend()) { return 0; } int bytes = this.bytes - this.byteCount; long written = super.transferFrom(src, position, Math.min(count, bytes)); handleWritten(written); return written; } @Override public long transferFrom(StreamSourceChannel source, long count, ByteBuffer throughBuffer) throws IOException { if (!canSend()) { return 0; } int bytes = this.bytes - this.byteCount; long written = super.transferFrom(source, Math.min(count, bytes), throughBuffer); handleWritten(written); return written; } @Override public long write(ByteBuffer[] srcs, int offs, int len) throws IOException { if (!canSend()) { return 0; } int old = 0; int adjPos = -1; long rem = bytes - byteCount; for (int i = offs; i < offs + len; ++i) { ByteBuffer buf = srcs[i]; rem -= buf.remaining(); if (rem < 0) { adjPos = i; old = buf.limit(); buf.limit((int) (buf.limit() + rem)); break; } } try { long written; if (adjPos == -1) { written = super.write(srcs, offs, len); } else { written = super.write(srcs, offs, adjPos - offs + 1); } handleWritten(written); return written; } finally { if (adjPos != -1) { ByteBuffer buf = srcs[adjPos]; buf.limit(old); } } } @Override public int writeFinal(ByteBuffer src) throws IOException { if (!canSend()) { return 0; } int bytes = this.bytes - this.byteCount; int old = src.limit(); if (src.remaining() > bytes) { src.limit(src.position() + bytes); } try { int written = super.writeFinal(src); handleWritten(written); return written; } finally { src.limit(old); } } @Override public long writeFinal(ByteBuffer[] srcs, int offs, int len) throws IOException { if (!canSend()) { return 0; } int old = 0; int adjPos = -1; long rem = bytes - byteCount; for (int i = offs; i < offs + len; ++i) { ByteBuffer buf = srcs[i]; rem -= buf.remaining(); if (rem < 0) { adjPos = i; old = buf.limit(); buf.limit((int) (buf.limit() + rem)); break; } } try { long written; if (adjPos == -1) { written = super.writeFinal(srcs, offs, len); } else { written = super.writeFinal(srcs, offs, adjPos - offs + 1); } handleWritten(written); return written; } finally { if (adjPos != -1) { ByteBuffer buf = srcs[adjPos]; buf.limit(old); } } } @Override public void resumeWrites() { writesResumed = true; if (canSend()) { super.resumeWrites(); } } @Override public void suspendWrites() { writesResumed = false; super.suspendWrites(); } @Override public void wakeupWrites() { writesResumed = true; if (canSend()) { super.wakeupWrites(); } } @Override public boolean isWriteResumed() { return writesResumed; } @Override public void awaitWritable() throws IOException { long toGo = nextSendTime - System.currentTimeMillis(); if (toGo > 0) { try { Thread.sleep(toGo); } catch (InterruptedException e) { throw new InterruptedIOException(); } } super.awaitWritable(); } @Override public void awaitWritable(long time, TimeUnit timeUnit) throws IOException { long toGo = nextSendTime - System.currentTimeMillis(); if (toGo > 0) { try { Thread.sleep(Math.min(toGo, timeUnit.toMillis(time))); } catch (InterruptedException e) { throw new InterruptedIOException(); } return; } super.awaitWritable(time, timeUnit); } private boolean canSend() { if (byteCount < bytes) { return true; } if (System.currentTimeMillis() > nextSendTime) { byteCount = 0; startTime = 0; nextSendTime = 0; return true; } if (writesResumed) { handleWritesResumedWhenBlocked(); } return false; } private void handleWritten(long written) { if (written == 0) { return; } byteCount += written; if (byteCount < bytes) { //we are still allowed to send if (startTime == 0) { startTime = System.currentTimeMillis(); nextSendTime = System.currentTimeMillis() + time; } } else { //we have gone over, we need to wait till we are allowed to send again if (startTime == 0) { startTime = System.currentTimeMillis(); } nextSendTime = startTime + time; if (writesResumed) { handleWritesResumedWhenBlocked(); } } } private void handleWritesResumedWhenBlocked() { if (scheduled) { return; } scheduled = true; next.suspendWrites(); long millis = nextSendTime - System.currentTimeMillis(); WorkerUtils.executeAfter(getWriteThread(), new Runnable() { @Override public void run() { scheduled = false; if (writesResumed) { next.wakeupWrites(); } } }, millis, TimeUnit.MILLISECONDS); } }