/*
* Copyright 2010 Proofpoint, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.airlift.http.server;
import com.google.common.base.Preconditions;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Locale;
class TimingFilter
implements Filter
{
public static final String FIRST_BYTE_TIME = TimingFilter.class.getName() + ".FIRST_BYTE_TIME";
/**
 * No-op initialization hook; this method body is empty and the supplied
 * configuration is not read.
 *
 * @param filterConfig filter configuration; unused
 * @throws ServletException declared to match the {@link Filter} signature; this body never throws it
 */
@Override
public void init(FilterConfig filterConfig)
throws ServletException
{
}
/**
 * Runs the remainder of the filter chain against a {@link TimedResponse} wrapper and,
 * once the chain returns or throws, copies the recorded first-write timestamp (if one
 * exists) onto the request under the {@link #FIRST_BYTE_TIME} attribute.
 *
 * @param servletRequest request that receives the {@link #FIRST_BYTE_TIME} attribute
 * @param servletResponse response to wrap; cast to {@link HttpServletResponse}
 * @param chain remaining filters/servlet to invoke
 * @throws IOException propagated from the chain
 * @throws ServletException propagated from the chain
 */
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain)
        throws IOException, ServletException
{
    HttpServletResponse httpResponse = (HttpServletResponse) servletResponse;
    TimedResponse timedResponse = new TimedResponse(httpResponse);
    try {
        chain.doFilter(servletRequest, timedResponse);
    }
    finally {
        // Publish the timestamp even when the chain fails; null means nothing was written
        // through the wrapper's stream or writer.
        Long firstWriteMillis = timedResponse.getFirstByteTime();
        if (firstWriteMillis != null) {
            servletRequest.setAttribute(FIRST_BYTE_TIME, firstWriteMillis);
        }
    }
}
/**
 * No-op teardown hook; this method body is empty.
 */
@Override
public void destroy()
{
}
/**
 * Response wrapper that hands out timing decorators around the wrapped response's
 * output stream or writer so the time of the first write can be queried afterwards.
 *
 * <p>At most one of {@link #getOutputStream()} / {@link #getWriter()} may be used between
 * resets; requesting the other one fails with {@link IllegalStateException}.
 */
private static class TimedResponse extends HttpServletResponseWrapper
{
    // Lazily created decorators; at most one of the two is non-null at any time.
    private TimedServletOutputStream outputStream;
    private TimedPrintWriter printWriter;

    private TimedResponse(HttpServletResponse response)
    {
        super(response);
    }

    /**
     * Returns the (cached) timing decorator around the wrapped response's output stream.
     *
     * @throws IllegalStateException if {@link #getWriter()} was already called
     * @throws IOException propagated from the wrapped response
     */
    @Override
    public ServletOutputStream getOutputStream()
            throws IOException
    {
        Preconditions.checkState(printWriter == null, "getWriter() has already been called");
        if (outputStream == null) {
            outputStream = new TimedServletOutputStream(super.getOutputStream());
        }
        return outputStream;
    }

    /**
     * Returns the (cached) timing decorator around the wrapped response's writer.
     *
     * @throws IllegalStateException if {@link #getOutputStream()} was already called
     * @throws IOException propagated from the wrapped response
     */
    @Override
    public PrintWriter getWriter()
            throws IOException
    {
        Preconditions.checkState(outputStream == null, "getOutputStream() has already been called");
        if (printWriter == null) {
            printWriter = new TimedPrintWriter(super.getWriter());
        }
        return printWriter;
    }

    /**
     * Resets the wrapped response and forgets the cached decorators.
     *
     * <p>The Servlet API specifies that {@code reset()} also clears the
     * getWriter()/getOutputStream() state, making it legal to call one, reset, and then
     * call the other. Without dropping the cached decorators here, the state checks in
     * this wrapper would reject that legal sequence, and a first-write time recorded for
     * discarded output would still be reported.
     *
     * @throws IllegalStateException propagated from the wrapped response; in that case
     *         the cached decorators are left untouched
     */
    @Override
    public void reset()
    {
        // Delegate first: if the wrapped response refuses to reset, keep our state as-is.
        super.reset();
        outputStream = null;
        printWriter = null;
    }

    /**
     * @return the first-write timestamp recorded by whichever decorator is in use,
     *         or {@code null} if neither has been requested or nothing has been written
     */
    public Long getFirstByteTime()
    {
        if (outputStream != null) {
            return outputStream.getFirstByteTime();
        }
        if (printWriter != null) {
            return printWriter.getFirstByteTime();
        }
        return null;
    }
}
/**
 * {@link ServletOutputStream} decorator that notes {@link System#currentTimeMillis()}
 * on the first write/print/println call and then forwards the call to the wrapped stream.
 * {@code flush}, {@code close}, {@code isReady} and {@code setWriteListener} are forwarded
 * without touching the timestamp.
 */
private static class TimedServletOutputStream extends ServletOutputStream
{
    private final ServletOutputStream target;
    // null until the first write/print/println call
    private Long firstWriteMillis;

    private TimedServletOutputStream(ServletOutputStream delegate)
    {
        target = delegate;
    }

    /**
     * @return the time of the first write/print/println call in epoch milliseconds,
     *         or {@code null} if none has happened yet
     */
    public Long getFirstByteTime()
    {
        return firstWriteMillis;
    }

    // Stamps the current time once; later calls leave the stored value alone.
    private void stampFirstWrite()
    {
        if (firstWriteMillis != null) {
            return;
        }
        firstWriteMillis = System.currentTimeMillis();
    }

    // ---- raw byte writes ----

    @Override
    public void write(int b) throws IOException
    {
        stampFirstWrite();
        target.write(b);
    }

    @Override
    public void write(byte[] b) throws IOException
    {
        stampFirstWrite();
        target.write(b);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException
    {
        stampFirstWrite();
        target.write(b, off, len);
    }

    // ---- print overloads ----

    @Override
    public void print(String s) throws IOException
    {
        stampFirstWrite();
        target.print(s);
    }

    @Override
    public void print(boolean b) throws IOException
    {
        stampFirstWrite();
        target.print(b);
    }

    @Override
    public void print(char c) throws IOException
    {
        stampFirstWrite();
        target.print(c);
    }

    @Override
    public void print(int i) throws IOException
    {
        stampFirstWrite();
        target.print(i);
    }

    @Override
    public void print(long l) throws IOException
    {
        stampFirstWrite();
        target.print(l);
    }

    @Override
    public void print(float f) throws IOException
    {
        stampFirstWrite();
        target.print(f);
    }

    @Override
    public void print(double d) throws IOException
    {
        stampFirstWrite();
        target.print(d);
    }

    // ---- println overloads ----

    @Override
    public void println() throws IOException
    {
        stampFirstWrite();
        target.println();
    }

    @Override
    public void println(String s) throws IOException
    {
        stampFirstWrite();
        target.println(s);
    }

    @Override
    public void println(boolean b) throws IOException
    {
        stampFirstWrite();
        target.println(b);
    }

    @Override
    public void println(char c) throws IOException
    {
        stampFirstWrite();
        target.println(c);
    }

    @Override
    public void println(int i) throws IOException
    {
        stampFirstWrite();
        target.println(i);
    }

    @Override
    public void println(long l) throws IOException
    {
        stampFirstWrite();
        target.println(l);
    }

    @Override
    public void println(float f) throws IOException
    {
        stampFirstWrite();
        target.println(f);
    }

    @Override
    public void println(double d) throws IOException
    {
        stampFirstWrite();
        target.println(d);
    }

    // ---- pass-through operations (no timestamp) ----

    @Override
    public void flush() throws IOException
    {
        target.flush();
    }

    @Override
    public void close() throws IOException
    {
        target.close();
    }

    @Override
    public boolean isReady()
    {
        return target.isReady();
    }

    @Override
    public void setWriteListener(WriteListener writeListener)
    {
        target.setWriteListener(writeListener);
    }
}
/**
 * {@link PrintWriter} layered over another writer that notes
 * {@link System#currentTimeMillis()} on the first write/print/println/printf/format/append
 * call before handing the call to the {@link PrintWriter} implementation.
 */
private static class TimedPrintWriter extends PrintWriter
{
    // null until the first output call
    private Long firstWriteMillis;

    private TimedPrintWriter(PrintWriter delegate)
    {
        super(delegate);
    }

    /**
     * @return the time of the first output call in epoch milliseconds,
     *         or {@code null} if none has happened yet
     */
    public Long getFirstByteTime()
    {
        return firstWriteMillis;
    }

    // Stamps the current time once; later calls leave the stored value alone.
    private void stampFirstWrite()
    {
        if (firstWriteMillis != null) {
            return;
        }
        firstWriteMillis = System.currentTimeMillis();
    }

    // ---- write overloads ----

    @Override
    public void write(int c)
    {
        stampFirstWrite();
        super.write(c);
    }

    @Override
    public void write(char[] buf, int off, int len)
    {
        stampFirstWrite();
        super.write(buf, off, len);
    }

    @Override
    public void write(char[] buf)
    {
        stampFirstWrite();
        super.write(buf);
    }

    @Override
    public void write(String s, int off, int len)
    {
        stampFirstWrite();
        super.write(s, off, len);
    }

    @Override
    public void write(String s)
    {
        stampFirstWrite();
        super.write(s);
    }

    // ---- print overloads ----

    @Override
    public void print(boolean b)
    {
        stampFirstWrite();
        super.print(b);
    }

    @Override
    public void print(char c)
    {
        stampFirstWrite();
        super.print(c);
    }

    @Override
    public void print(int i)
    {
        stampFirstWrite();
        super.print(i);
    }

    @Override
    public void print(long l)
    {
        stampFirstWrite();
        super.print(l);
    }

    @Override
    public void print(float f)
    {
        stampFirstWrite();
        super.print(f);
    }

    @Override
    public void print(double d)
    {
        stampFirstWrite();
        super.print(d);
    }

    @Override
    public void print(char[] s)
    {
        stampFirstWrite();
        super.print(s);
    }

    @Override
    public void print(String s)
    {
        stampFirstWrite();
        super.print(s);
    }

    @Override
    public void print(Object obj)
    {
        stampFirstWrite();
        super.print(obj);
    }

    // ---- println overloads ----

    @Override
    public void println()
    {
        stampFirstWrite();
        super.println();
    }

    @Override
    public void println(boolean x)
    {
        stampFirstWrite();
        super.println(x);
    }

    @Override
    public void println(char x)
    {
        stampFirstWrite();
        super.println(x);
    }

    @Override
    public void println(int x)
    {
        stampFirstWrite();
        super.println(x);
    }

    @Override
    public void println(long x)
    {
        stampFirstWrite();
        super.println(x);
    }

    @Override
    public void println(float x)
    {
        stampFirstWrite();
        super.println(x);
    }

    @Override
    public void println(double x)
    {
        stampFirstWrite();
        super.println(x);
    }

    @Override
    public void println(char[] x)
    {
        stampFirstWrite();
        super.println(x);
    }

    @Override
    public void println(String x)
    {
        stampFirstWrite();
        super.println(x);
    }

    @Override
    public void println(Object x)
    {
        stampFirstWrite();
        super.println(x);
    }

    // ---- formatted output; each returns this writer for chaining ----

    @Override
    public PrintWriter printf(String format, Object... args)
    {
        stampFirstWrite();
        super.printf(format, args);
        return this;
    }

    @Override
    public PrintWriter printf(Locale l, String format, Object... args)
    {
        stampFirstWrite();
        super.printf(l, format, args);
        return this;
    }

    @Override
    public PrintWriter format(String format, Object... args)
    {
        stampFirstWrite();
        super.format(format, args);
        return this;
    }

    @Override
    public PrintWriter format(Locale l, String format, Object... args)
    {
        stampFirstWrite();
        super.format(l, format, args);
        return this;
    }

    // ---- append overloads; each returns this writer for chaining ----

    @Override
    public PrintWriter append(CharSequence csq)
    {
        stampFirstWrite();
        super.append(csq);
        return this;
    }

    @Override
    public PrintWriter append(CharSequence csq, int start, int end)
    {
        stampFirstWrite();
        super.append(csq, start, end);
        return this;
    }

    @Override
    public PrintWriter append(char c)
    {
        stampFirstWrite();
        super.append(c);
        return this;
    }
}
}