/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.rpc.control;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.DrillBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.util.internal.ThreadLocalRandom;

import java.util.Arrays;
import java.util.Objects;
import java.util.Random;

import org.apache.drill.BaseTestQuery;
import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint;
import org.apache.drill.exec.proto.UserBitShared.QueryId;
import org.apache.drill.exec.rpc.UserRpcException;
import org.apache.drill.exec.rpc.control.ControlTunnel.CustomFuture;
import org.apache.drill.exec.rpc.control.ControlTunnel.CustomTunnel;
import org.apache.drill.exec.rpc.control.Controller.CustomMessageHandler;
import org.apache.drill.exec.rpc.control.Controller.CustomResponse;
import org.apache.drill.exec.server.DrillbitContext;
import org.junit.Test;

/**
 * Round-trip tests for custom messages on the control RPC layer.
 *
 * <p>Each test registers a custom handler with the drillbit's controller under its own
 * RPC type id, obtains a tunnel for the drillbit's own endpoint, sends a message through
 * it and checks the response produced by the handler.
 */
public class TestCustomTunnel extends BaseTestQuery {

  /** Size in bytes of the random payload used as the raw message body. */
  private static final int PAYLOAD_SIZE = 1024;

  /** Response message returned by the protobuf-based handler; random per test instance. */
  private final QueryId expectedId = QueryId
      .newBuilder()
      .setPart1(ThreadLocalRandom.current().nextLong())
      .setPart2(ThreadLocalRandom.current().nextLong())
      .build();

  /** Response message returned by the Jackson-based handler. */
  private final MesgB expectedB = new MesgB().set("hello", "bye", "friend");

  /** Buffer holding a copy of {@link #expected}; used as the request and response body. */
  private final ByteBuf buf1;

  /** Random bytes written into {@link #buf1}; the reference for all payload comparisons. */
  private final byte[] expected;

  /** Generates the random payload and copies it into a freshly allocated unpooled buffer. */
  public TestCustomTunnel() {
    expected = new byte[PAYLOAD_SIZE];
    new Random().nextBytes(expected);
    buf1 = UnpooledByteBufAllocator.DEFAULT.buffer(PAYLOAD_SIZE);
    buf1.writeBytes(expected);
  }

  /**
   * Sends a {@link DrillbitEndpoint} with no data body and expects {@link #expectedId} back.
   *
   * @throws Exception if registering, sending or waiting for the response fails
   */
  @Test
  public void ensureRoundTrip() throws Exception {
    final DrillbitContext context = getDrillbitContext();
    final TestCustomMessageHandler handler = new TestCustomMessageHandler(context.getEndpoint(), false);
    context.getController().registerCustomHandler(1001, handler, DrillbitEndpoint.PARSER);

    final ControlTunnel loopbackTunnel = context.getController().getTunnel(context.getEndpoint());
    final CustomTunnel<DrillbitEndpoint, QueryId> tunnel =
        loopbackTunnel.getCustomTunnel(1001, DrillbitEndpoint.class, QueryId.PARSER);

    final CustomFuture<QueryId> future = tunnel.send(context.getEndpoint());
    assertEquals(expectedId, future.get());
  }

  /**
   * Sends a {@link DrillbitEndpoint} together with {@link #buf1} as data body and expects
   * both {@link #expectedId} and an identical copy of the payload in the response.
   *
   * @throws Exception if registering, sending or waiting for the response fails
   */
  @Test
  public void ensureRoundTripBytes() throws Exception {
    final DrillbitContext context = getDrillbitContext();
    final TestCustomMessageHandler handler = new TestCustomMessageHandler(context.getEndpoint(), true);
    context.getController().registerCustomHandler(1002, handler, DrillbitEndpoint.PARSER);

    final ControlTunnel loopbackTunnel = context.getController().getTunnel(context.getEndpoint());
    final CustomTunnel<DrillbitEndpoint, QueryId> tunnel =
        loopbackTunnel.getCustomTunnel(1002, DrillbitEndpoint.class, QueryId.PARSER);

    // Take an extra reference before handing the buffer to the tunnel.
    // NOTE(review): presumably the send path releases the reference it is given - confirm.
    buf1.retain();
    final CustomFuture<QueryId> future = tunnel.send(context.getEndpoint(), buf1);

    final QueryId actualId = future.get();
    final ByteBuf response = future.getBuffer();
    final byte[] actual = new byte[PAYLOAD_SIZE];
    try {
      assertEquals(expectedId, actualId);
      response.getBytes(0, actual);
    } finally {
      // Release even when the checks above fail so the response buffer is never leaked.
      response.release();
    }
    assertTrue("Response body differs from the payload that was sent", Arrays.equals(expected, actual));
  }

  /**
   * Sends a {@link MesgA} serialized with Jackson and expects {@link #expectedB} back.
   *
   * @throws Exception if registering, sending or waiting for the response fails
   */
  @Test
  public void ensureRoundTripJackson() throws Exception {
    final DrillbitContext context = getDrillbitContext();

    final MesgA mesgA = new MesgA();
    mesgA.fieldA = "123";
    mesgA.fieldB = "okra";

    final TestCustomMessageHandlerJackson handler = new TestCustomMessageHandlerJackson(mesgA);
    context.getController().registerCustomHandler(
        1003,
        handler,
        new ControlTunnel.JacksonSerDe<MesgA>(MesgA.class),
        new ControlTunnel.JacksonSerDe<MesgB>(MesgB.class));

    final ControlTunnel loopbackTunnel = context.getController().getTunnel(context.getEndpoint());
    final CustomTunnel<MesgA, MesgB> tunnel = loopbackTunnel.getCustomTunnel(
        1003,
        new ControlTunnel.JacksonSerDe<MesgA>(MesgA.class),
        new ControlTunnel.JacksonSerDe<MesgB>(MesgB.class));

    final CustomFuture<MesgB> future = tunnel.send(mesgA);
    assertEquals(expectedB, future.get());
  }

  /** Two-field request message used by the Jackson round-trip test. */
  public static class MesgA {
    public String fieldA;
    public String fieldB;

    @Override
    public int hashCode() {
      return Objects.hash(fieldA, fieldB);
    }

    @Override
    public boolean equals(Object obj) {
      if (this == obj) {
        return true;
      }
      if (obj == null || getClass() != obj.getClass()) {
        return false;
      }
      final MesgA other = (MesgA) obj;
      return Objects.equals(fieldA, other.fieldA)
          && Objects.equals(fieldB, other.fieldB);
    }
  }

  /** Three-field response message used by the Jackson round-trip test. */
  public static class MesgB {
    public String fieldA;
    public String fieldB;
    public String fieldC;

    /**
     * Assigns all three fields at once.
     *
     * @param a value for {@code fieldA}
     * @param b value for {@code fieldB}
     * @param c value for {@code fieldC}
     * @return this instance, to allow chaining off the constructor
     */
    public MesgB set(String a, String b, String c) {
      fieldA = a;
      fieldB = b;
      fieldC = c;
      return this;
    }

    @Override
    public int hashCode() {
      return Objects.hash(fieldA, fieldB, fieldC);
    }

    @Override
    public boolean equals(Object obj) {
      if (this == obj) {
        return true;
      }
      if (obj == null || getClass() != obj.getClass()) {
        return false;
      }
      final MesgB other = (MesgB) obj;
      return Objects.equals(fieldA, other.fieldA)
          && Objects.equals(fieldB, other.fieldB)
          && Objects.equals(fieldC, other.fieldC);
    }
  }

  /**
   * Handler for the protobuf-based tests: verifies the incoming endpoint (and, when
   * {@code returnBytes} is set, the incoming data body) and answers with
   * {@link #expectedId}, optionally accompanied by {@link #buf1}.
   */
  private class TestCustomMessageHandler implements CustomMessageHandler<DrillbitEndpoint, QueryId> {
    private final DrillbitEndpoint expectedValue;
    private final boolean returnBytes;

    /**
     * @param expectedValue endpoint every incoming message must equal
     * @param returnBytes whether to verify the incoming data body and send {@link #buf1} back
     */
    public TestCustomMessageHandler(DrillbitEndpoint expectedValue, boolean returnBytes) {
      this.expectedValue = expectedValue;
      this.returnBytes = returnBytes;
    }

    /**
     * Validates the request and builds the response.
     *
     * @param pBody the received message
     * @param dBody the received data body; only read when {@code returnBytes} is set
     * @return a response carrying {@link #expectedId}
     * @throws UserRpcException if the message or the data body is not the expected one
     */
    @Override
    public CustomResponse<QueryId> onMessage(DrillbitEndpoint pBody, DrillBuf dBody) throws UserRpcException {
      if (!expectedValue.equals(pBody)) {
        throw new UserRpcException(expectedValue, "Invalid expected downstream value.", new IllegalStateException());
      }

      if (returnBytes) {
        final byte[] actual = new byte[PAYLOAD_SIZE];
        dBody.getBytes(0, actual);
        if (!Arrays.equals(expected, actual)) {
          throw new UserRpcException(expectedValue, "Invalid expected downstream value.", new IllegalStateException());
        }
      }

      return new CustomResponse<QueryId>() {

        @Override
        public QueryId getMessage() {
          return expectedId;
        }

        @Override
        public ByteBuf[] getBodies() {
          if (!returnBytes) {
            return null;
          }
          // Take an extra reference before handing the buffer out as the response body.
          // NOTE(review): presumably the RPC layer releases the bodies it is given - confirm.
          buf1.retain();
          return new ByteBuf[] { buf1 };
        }
      };
    }
  }

  /**
   * Handler for the Jackson-based test: verifies the incoming {@link MesgA} and answers
   * with {@link #expectedB} and no data body.
   */
  private class TestCustomMessageHandlerJackson implements CustomMessageHandler<MesgA, MesgB> {
    private final MesgA expectedValue;

    /**
     * @param expectedValue message every incoming request must equal
     */
    public TestCustomMessageHandlerJackson(MesgA expectedValue) {
      this.expectedValue = expectedValue;
    }

    /**
     * Validates the request and builds the response.
     *
     * @param pBody the received message
     * @param dBody the received data body; not used by this handler
     * @return a response carrying {@link #expectedB}
     * @throws UserRpcException if the message is not the expected one
     */
    @Override
    public CustomResponse<MesgB> onMessage(MesgA pBody, DrillBuf dBody) throws UserRpcException {
      if (!expectedValue.equals(pBody)) {
        throw new UserRpcException(DrillbitEndpoint.getDefaultInstance(),
            "Invalid expected downstream value.", new IllegalStateException());
      }

      return new CustomResponse<MesgB>() {

        @Override
        public MesgB getMessage() {
          return expectedB;
        }

        @Override
        public ByteBuf[] getBodies() {
          return null;
        }
      };
    }
  }
}