package test.r2.filter; import com.linkedin.common.callback.Callback; import com.linkedin.data.ByteString; import com.linkedin.r2.filter.FilterChain; import com.linkedin.r2.filter.FilterChains; import com.linkedin.r2.filter.NextFilter; import com.linkedin.r2.filter.message.rest.RestFilter; import com.linkedin.r2.filter.message.stream.StreamFilter; import com.linkedin.r2.filter.message.stream.StreamFilterAdapters; import com.linkedin.r2.message.RequestContext; import com.linkedin.r2.message.rest.RestException; import com.linkedin.r2.message.rest.RestRequest; import com.linkedin.r2.message.rest.RestRequestBuilder; import com.linkedin.r2.message.rest.RestResponse; import com.linkedin.r2.message.rest.RestResponseBuilder; import com.linkedin.r2.message.stream.StreamException; import com.linkedin.r2.message.stream.StreamRequest; import com.linkedin.r2.message.stream.StreamRequestBuilder; import com.linkedin.r2.message.stream.StreamResponse; import com.linkedin.r2.message.stream.StreamResponseBuilder; import com.linkedin.r2.message.stream.entitystream.ByteStringWriter; import com.linkedin.r2.message.stream.entitystream.EntityStreams; import com.linkedin.r2.message.stream.entitystream.FullEntityReader; import com.linkedin.r2.testutils.filter.FilterUtil; import org.testng.Assert; import org.testng.annotations.Test; import java.net.URI; import java.util.Map; /** * @author Zhenkai Zhu */ public class TestStreamFilterAdapters { private static final URI SIMPLE_URI = URI.create("simple_uri"); private CaptureFilter _beforeFilter; private CaptureFilter _afterFilter; @Test public void testRequestFilterAdapterPassThrough() { FilterChain fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestRequest(RestRequest req, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onRequest(req, requestContext, wireAttrs); } }); fc.onStreamRequest(simpleStreamRequest("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); StreamRequest capturedReq = _afterFilter.getRequest(); Assert.assertEquals(capturedReq.getURI(), SIMPLE_URI); capturedReq.getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("shouldn't have error"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "12345"); } })); } @Test public void testRequestFilterAdapterChangeRequest() { FilterChain fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestRequest(RestRequest req, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onRequest(req.builder(). setEntity(req.getEntity().asString("UTF8").replace('1', '0').getBytes()).build(), requestContext, wireAttrs); } }); fc.onStreamRequest(simpleStreamRequest("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); StreamRequest capturedReq = _afterFilter.getRequest(); Assert.assertEquals(capturedReq.getURI(), SIMPLE_URI); capturedReq.getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("shouldn't have error"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "02345"); } })); } @Test public void testRequestFilterAdapterCallsOnResponse() { FilterChain fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestRequest(RestRequest req, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onResponse(simpleRestResponse(req.getEntity().asString("UTF8")), requestContext, wireAttrs); } }); fc.onStreamRequest(simpleStreamRequest("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); StreamResponse capturedReq = _beforeFilter.getResponse(); capturedReq.getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("shouldn't have error"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "12345"); } })); } @Test public void testRequestFilterAdapterCallsOnError() { final Exception runTimeException = new RuntimeException(); FilterChain fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestRequest(RestRequest req, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onError(runTimeException, requestContext, wireAttrs); } }); fc.onStreamRequest(simpleStreamRequest("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); Throwable ex = _beforeFilter.getThrowable(); Assert.assertSame(ex, runTimeException); fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestRequest(RestRequest req, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onError(simpleRestException(req.getEntity().asString("UTF8")), requestContext, wireAttrs); } }); fc.onStreamRequest(simpleStreamRequest("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); ex = _beforeFilter.getThrowable(); Assert.assertTrue(ex instanceof StreamException); StreamResponse errorResponse = ((StreamException) ex).getResponse(); errorResponse.getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("should not happen"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "12345"); } })); } @Test public void testResponseFilterAdapterPassThrough() { FilterChain fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestResponse(RestResponse res, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onResponse(res, requestContext, wireAttrs); } @Override public void onRestError(Throwable ex, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onError(ex, requestContext, wireAttrs); } }); fc.onStreamResponse(simpleStreamResponse("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); StreamResponse capturedResponse = _beforeFilter.getResponse(); capturedResponse.getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("should not happen"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "12345"); } })); fc.onStreamError(simpleStreamException("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); Throwable capturedEx = _beforeFilter.getThrowable(); Assert.assertTrue(capturedEx instanceof StreamException); ((StreamException) capturedEx).getResponse().getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("should not happen"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "12345"); } })); } @Test public void testResponseFilterAdapterChangeResponse() { FilterChain fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestResponse(RestResponse res, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onResponse(res.builder().setEntity(res.getEntity().asString("UTF8").replace('1', '0').getBytes()).build(), requestContext, wireAttrs); } @Override public void onRestError(Throwable ex, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { } }); fc.onStreamResponse(simpleStreamResponse("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); StreamResponse capturedResponse = _beforeFilter.getResponse(); capturedResponse.getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("should not happen"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "02345"); } })); } @Test public void testResponseFilterAdapterChangeError() { FilterChain fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestResponse(RestResponse res, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { } @Override public void onRestError(Throwable ex, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { if (ex instanceof RestException) { RestResponse res = ((RestException) ex).getResponse(); String newEntityStr = res.getEntity().asString("UTF8").replace('1', '0'); nextFilter.onError(new RestException( (res.builder().setEntity(newEntityStr.getBytes()).build())), requestContext, wireAttrs); } else { nextFilter.onError(new IllegalStateException(), requestContext, wireAttrs); } } }); fc.onStreamError(simpleStreamException("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); Throwable capturedEx = _beforeFilter.getThrowable(); Assert.assertTrue(capturedEx instanceof StreamException); ((StreamException) capturedEx).getResponse().getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("should not happen"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "02345"); } })); fc.onStreamError(new IllegalArgumentException(), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); capturedEx = _beforeFilter.getThrowable(); Assert.assertTrue(capturedEx instanceof IllegalStateException); } @Test public void testResponseFilterAdapterCallsOnErrorInOnResponse() { FilterChain fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestResponse(RestResponse res, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onError(simpleRestException(res.getEntity().asString("UTF8")), requestContext, wireAttrs); } @Override public void onRestError(Throwable ex, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { } }); fc.onStreamResponse(simpleStreamResponse("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); Throwable capturedEx = _beforeFilter.getThrowable(); Assert.assertTrue(capturedEx instanceof StreamException); ((StreamException) capturedEx).getResponse().getEntityStream().setReader(new FullEntityReader(new Callback<ByteString>() { @Override public void onError(Throwable e) { Assert.fail("should not happen"); } @Override public void onSuccess(ByteString result) { Assert.assertEquals(result.asString("UTF8"), "12345"); } })); fc = adaptAndCreateFilterChain(new RestFilter() { @Override public void onRestResponse(RestResponse res, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { nextFilter.onError(new IllegalStateException(), requestContext, wireAttrs); } @Override public void onRestError(Throwable ex, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<RestRequest, RestResponse> nextFilter) { } }); fc.onStreamResponse(simpleStreamResponse("12345"), FilterUtil.emptyRequestContext(), FilterUtil.emptyWireAttrs()); capturedEx = _beforeFilter.getThrowable(); Assert.assertTrue(capturedEx instanceof IllegalStateException); } private FilterChain adaptAndCreateFilterChain(RestFilter filter) { _beforeFilter = new CaptureFilter(); _afterFilter = new CaptureFilter(); return FilterChains.createStreamChain(_beforeFilter, StreamFilterAdapters.adaptRestFilter(filter), _afterFilter); } private static class CaptureFilter implements StreamFilter { private StreamRequest _req = null; private StreamResponse _res = null; private Throwable _ex = null; @Override public void onStreamRequest(StreamRequest req, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<StreamRequest, StreamResponse> nextFilter) { _req = req; nextFilter.onRequest(req, requestContext, wireAttrs); } @Override public void onStreamResponse(StreamResponse res, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<StreamRequest, StreamResponse> nextFilter) { _res = res; nextFilter.onResponse(res, requestContext, wireAttrs); } @Override public void onStreamError(Throwable ex, RequestContext requestContext, Map<String, String> wireAttrs, NextFilter<StreamRequest, StreamResponse> nextFilter) { _ex = ex; nextFilter.onError(ex, requestContext, wireAttrs); } public StreamRequest getRequest() { return _req; } public StreamResponse getResponse() { return _res; } public Throwable getThrowable() { return _ex; } } private static StreamRequest simpleStreamRequest(String str) { return new StreamRequestBuilder(SIMPLE_URI) .build(EntityStreams.newEntityStream(new ByteStringWriter(ByteString.copy(str.getBytes())))); } private static StreamResponse simpleStreamResponse(String str) { return new StreamResponseBuilder() .build(EntityStreams.newEntityStream(new ByteStringWriter(ByteString.copy(str.getBytes())))); } private static StreamException simpleStreamException(String str) { return new StreamException(simpleStreamResponse(str)); } private static RestRequest simpleRestRequest(String str) { return new RestRequestBuilder(SIMPLE_URI).setEntity(str.getBytes()) .build(); } private static RestResponse simpleRestResponse(String str) { return new RestResponseBuilder().setEntity(str.getBytes()) .build(); } private static RestException simpleRestException(String str) { return new RestException(simpleRestResponse(str)); } }