/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.websockets.jsr.test.stress;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.SendHandler;
import javax.websocket.SendResult;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import io.undertow.Handlers;
import io.undertow.servlet.api.DeploymentInfo;
import io.undertow.servlet.api.DeploymentManager;
import io.undertow.servlet.api.ServletContainer;
import io.undertow.servlet.test.util.TestClassIntrospector;
import io.undertow.servlet.test.util.TestResourceLoader;
import io.undertow.testutils.DefaultServer;
import io.undertow.testutils.HttpOneOnly;
import io.undertow.websockets.jsr.ServerWebSocketContainer;
import io.undertow.websockets.jsr.WebSocketDeploymentInfo;
/**
* @author <a href="mailto:nmaurer@redhat.com">Norman Maurer</a>
*/
@RunWith(DefaultServer.class)
@HttpOneOnly
public class WebsocketStressTestCase {
public static final int NUM_THREADS = 100;
public static final int NUM_REQUESTS = 1000;
private static ServerWebSocketContainer deployment;
private static WebSocketContainer defaultContainer = ContainerProvider.getWebSocketContainer();
static ExecutorService executor;
@BeforeClass
public static void setup() throws Exception {
executor = Executors.newFixedThreadPool(NUM_THREADS);
final ServletContainer container = ServletContainer.Factory.newInstance();
DeploymentInfo builder = new DeploymentInfo()
.setClassLoader(WebsocketStressTestCase.class.getClassLoader())
.setContextPath("/ws")
.setResourceManager(new TestResourceLoader(WebsocketStressTestCase.class))
.setClassIntrospecter(TestClassIntrospector.INSTANCE)
.addServletContextAttribute(WebSocketDeploymentInfo.ATTRIBUTE_NAME,
new WebSocketDeploymentInfo()
.setBuffers(DefaultServer.getBufferPool())
.setWorker(DefaultServer.getWorker())
.addEndpoint(StressEndpoint.class)
.addListener(new WebSocketDeploymentInfo.ContainerReadyListener() {
@Override
public void ready(ServerWebSocketContainer container) {
deployment = container;
}
})
)
.setDeploymentName("servletContext.war");
DeploymentManager manager = container.addDeployment(builder);
manager.deploy();
DefaultServer.setRootHandler(Handlers.path().addPrefixPath("/ws", manager.start()));
}
@AfterClass
public static void after() {
StressEndpoint.MESSAGES.clear();
deployment = null;
executor.shutdownNow();
executor = null;
}
@Test
public void webSocketStringStressTestCase() throws Exception {
List<CountDownLatch> latches = new ArrayList<>();
for (int i = 0; i < NUM_THREADS; ++i) {
final CountDownLatch latch = new CountDownLatch(1);
latches.add(latch);
final Session session = deployment.connectToServer(new Endpoint() {
@Override
public void onOpen(Session session, EndpointConfig config) {
}
@Override
public void onClose(Session session, CloseReason closeReason) {
latch.countDown();
}
@Override
public void onError(Session session, Throwable thr) {
latch.countDown();
}
}, null, new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/stress"));
final int thread = i;
executor.submit(new Runnable() {
@Override
public void run() {
try {
executor.submit(new SendRunnable(session, thread, executor));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
}
for (CountDownLatch future : latches) {
future.await(40, TimeUnit.SECONDS);
}
for (int t = 0; t < NUM_THREADS; ++t) {
for (int i = 0; i < NUM_REQUESTS; ++i) {
String msg = "t-" + t + "-m-" + i;
Assert.assertTrue(msg, StressEndpoint.MESSAGES.remove(msg));
}
}
Assert.assertEquals(0, StressEndpoint.MESSAGES.size());
}
@Test
public void websocketFragmentationStressTestCase() throws Exception {
final ByteArrayOutputStream out = new ByteArrayOutputStream();
final CountDownLatch done = new CountDownLatch(1);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 10000; ++i) {
sb.append("message ");
sb.append(i);
}
String toSend = sb.toString();
final Session session = defaultContainer.connectToServer(new Endpoint() {
@Override
public void onOpen(Session session, EndpointConfig config) {
session.addMessageHandler(new MessageHandler.Partial<byte[]>() {
@Override
public void onMessage(byte[] bytes, boolean b) {
try {
out.write(bytes);
} catch (IOException e) {
e.printStackTrace();
done.countDown();
}
if (b) {
done.countDown();
}
}
});
}
@Override
public void onClose(Session session, CloseReason closeReason) {
done.countDown();
}
@Override
public void onError(Session session, Throwable thr) {
thr.printStackTrace();
done.countDown();
}
}, null, new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/stress"));
OutputStream stream = session.getBasicRemote().getSendStream();
for (int i = 0; i < toSend.length(); ++i) {
stream.write(toSend.charAt(i));
stream.flush();
}
stream.close();
done.await(40, TimeUnit.SECONDS);
Assert.assertEquals(toSend, new String(out.toByteArray()));
}
private static class SendRunnable implements Runnable {
private final Session session;
private final int thread;
private final AtomicInteger count = new AtomicInteger();
private final ExecutorService executor;
SendRunnable(Session session, int thread, ExecutorService executor) {
this.session = session;
this.thread = thread;
this.executor = executor;
}
@Override
public void run() {
session.getAsyncRemote().sendText("t-" + thread + "-m-" + count.get(), new SendHandler() {
@Override
public void onResult(SendResult result) {
if (!result.isOK()) {
try {
result.getException().printStackTrace();
session.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
if (count.incrementAndGet() != NUM_REQUESTS) {
executor.submit(SendRunnable.this);
} else {
executor.submit(new Runnable() {
@Override
public void run() {
session.getAsyncRemote().sendText("close");
}
});
}
}
});
}
}
}