/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.activemq.artemis.tests.integration.stomp.util; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import org.apache.activemq.artemis.core.protocol.stomp.Stomp; import org.apache.activemq.artemis.tests.integration.IntegrationTestLogger; public abstract class AbstractStompClientConnection implements StompClientConnection { protected Pinger pinger; protected String version; protected String host; protected int port; protected String username; protected String passcode; protected StompFrameFactory factory; protected final SocketChannel socketChannel; protected ByteBuffer readBuffer; protected List<Byte> receiveList; protected BlockingQueue<ClientStompFrame> frameQueue = new LinkedBlockingQueue<>(); protected boolean connected = false; protected int serverPingCounter; protected ReaderThread readerThread; public AbstractStompClientConnection(String version, String host, int port) throws IOException { this.version = version; this.host = host; this.port = port; this.factory = StompFrameFactoryFactory.getFactory(version); socketChannel = SocketChannel.open(); initSocket(); } private void initSocket() throws IOException { socketChannel.configureBlocking(true); InetSocketAddress remoteAddr = new InetSocketAddress(host, port); socketChannel.connect(remoteAddr); startReaderThread(); } private void startReaderThread() { readBuffer = ByteBuffer.allocateDirect(10240); receiveList = new ArrayList<>(10240); readerThread = new ReaderThread(); readerThread.start(); } public void killReaderThread() { readerThread.stop(); } private ClientStompFrame sendFrameInternal(ClientStompFrame frame, boolean wicked) throws IOException, InterruptedException { ClientStompFrame response = null; IntegrationTestLogger.LOGGER.trace("Sending " + (wicked ? "*wicked* " : "") + "frame:\n" + frame); ByteBuffer buffer; if (wicked) { buffer = frame.toByteBufferWithExtra("\n"); } else { buffer = frame.toByteBuffer(); } while (buffer.remaining() > 0) { socketChannel.write(buffer); } //now response if (frame.needsReply()) { response = receiveFrame(); //filter out server ping while (response != null) { if (response.getCommand().equals(Stomp.Commands.STOMP)) { response = receiveFrame(); } else { break; } } } IntegrationTestLogger.LOGGER.trace("Received:\n" + response); return response; } @Override public ClientStompFrame sendFrame(ClientStompFrame frame) throws IOException, InterruptedException { return sendFrameInternal(frame, false); } @Override public ClientStompFrame sendWickedFrame(ClientStompFrame frame) throws IOException, InterruptedException { return sendFrameInternal(frame, true); } @Override public ClientStompFrame receiveFrame() throws InterruptedException { return frameQueue.poll(10, TimeUnit.SECONDS); } @Override public ClientStompFrame receiveFrame(long timeout) throws InterruptedException { return frameQueue.poll(timeout, TimeUnit.MILLISECONDS); } //put bytes to byte array. private void receiveBytes(int n) { readBuffer.rewind(); for (int i = 0; i < n; i++) { byte b = readBuffer.get(); if (b == 0) { //a new frame got. int sz = receiveList.size(); if (sz > 0) { byte[] frameBytes = new byte[sz]; for (int j = 0; j < sz; j++) { frameBytes[j] = receiveList.get(j); } ClientStompFrame frame = factory.createFrame(new String(frameBytes, StandardCharsets.UTF_8)); if (validateFrame(frame)) { frameQueue.offer(frame); receiveList.clear(); } else { receiveList.add(b); } } } else { if (b == 10 && receiveList.size() == 0) { //may be a ping incrementServerPing(); } else { receiveList.add(b); } } } //clear readbuffer readBuffer.rewind(); } protected void incrementServerPing() { serverPingCounter++; } private boolean validateFrame(ClientStompFrame f) { String h = f.getHeader(Stomp.Headers.CONTENT_LENGTH); if (h != null) { int len = Integer.valueOf(h); if (f.getBody().getBytes(StandardCharsets.UTF_8).length < len) { return false; } } return true; } protected void close() throws IOException { socketChannel.close(); } private class ReaderThread extends Thread { @Override public void run() { try { int n = socketChannel.read(readBuffer); while (n >= 0) { if (n > 0) { receiveBytes(n); } n = socketChannel.read(readBuffer); } //peer closed close(); } catch (IOException e) { try { close(); } catch (IOException e1) { //ignore } } } } @Override public ClientStompFrame connect() throws Exception { return connect(null, null); } @Override public void destroy() { try { close(); } catch (IOException e) { } finally { this.connected = false; } } @Override public ClientStompFrame connect(String username, String password) throws Exception { throw new RuntimeException("connect method not implemented!"); } @Override public boolean isConnected() { return connected && socketChannel.isConnected(); } @Override public String getVersion() { return version; } @Override public int getFrameQueueSize() { return this.frameQueue.size(); } protected class Pinger extends Thread { long pingInterval; ClientStompFrame pingFrame; volatile boolean stop = false; Pinger(long interval) { this.pingInterval = interval; pingFrame = createFrame(Stomp.Commands.STOMP); pingFrame.setBody("\n"); pingFrame.setForceOneway(); pingFrame.setPing(true); } public void startPing() { start(); } public synchronized void stopPing() { stop = true; this.notify(); } @Override public void run() { synchronized (this) { while (!stop) { try { sendFrame(pingFrame); this.wait(pingInterval); } catch (Exception e) { stop = true; e.printStackTrace(); } } } } } }