package test.r2.filter; import com.linkedin.r2.filter.FilterChain; import com.linkedin.r2.filter.FilterChains; import com.linkedin.r2.filter.NextFilter; import com.linkedin.r2.filter.message.stream.StreamFilter; import com.linkedin.r2.message.RequestContext; import com.linkedin.r2.message.stream.StreamRequest; import com.linkedin.r2.message.stream.StreamResponse; import com.linkedin.r2.testutils.filter.FilterUtil; import com.linkedin.r2.testutils.filter.StreamCountFilter; import org.easymock.EasyMock; import org.testng.Assert; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import java.util.HashMap; import java.util.Map; /** * @auther Zhenkai Zhu */ public class StreamFilterTest { private StreamFilter _filter; private StreamCountFilter _beforeFilter; private StreamCountFilter _afterFilter; private FilterChain _fc; @BeforeMethod public void setUp() throws Exception { _filter = new StreamFilter() {}; _beforeFilter = new StreamCountFilter(); _afterFilter = new StreamCountFilter(); _fc = FilterChains.createStreamChain(_beforeFilter, _filter, _afterFilter); } @Test public void testStreamRequestCallsNextFilter() { fireStreamRequest(_fc); Assert.assertEquals(1, _afterFilter.getStreamReqCount()); } @Test public void testStreamResponseCallsNextFilter() { fireStreamResponse(_fc); Assert.assertEquals(1, _beforeFilter.getStreamResCount()); } @Test public void testStreamErrorCallsNextFilter() { fireStreamError(_fc); Assert.assertEquals(1, _beforeFilter.getStreamErrCount()); } private void fireStreamRequest(FilterChain fc) { fc.onStreamRequest(EasyMock.createMock(StreamRequest.class), createRequestContext(), createWireAttributes() ); } private void fireStreamResponse(FilterChain fc) { fc.onStreamResponse(EasyMock.createMock(StreamResponse.class), createRequestContext(), createWireAttributes() ); } private void fireStreamError(FilterChain fc) { fc.onStreamError(new Exception(), createRequestContext(), createWireAttributes() ); } private Map<String, String> createWireAttributes() { return new HashMap<String, String>(); } private RequestContext createRequestContext() { return new RequestContext(); } }