/*
* JBoss, Home of Professional Open Source.
* Copyright 2012, Red Hat, Inc., and individual contributors
* as indicated by the @author tags. See the copyright.txt file in the
* distribution for a full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.jboss.naming.remote.common.ejb;
import org.jboss.ejb.client.ClusterAffinity;
import org.jboss.ejb.client.SessionID;
import org.jboss.ejb.client.remoting.PackedInteger;
import org.jboss.logging.Logger;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.SimpleDataInput;
import org.jboss.remoting3.Channel;
import org.jboss.remoting3.CloseHandler;
import org.jboss.remoting3.Endpoint;
import org.jboss.remoting3.MessageInputStream;
import org.jboss.remoting3.OpenListener;
import org.jboss.remoting3.Registration;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.channels.AcceptingChannel;
import org.xnio.channels.ConnectedStreamChannel;
import javax.ejb.NoSuchEJBException;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.Future;
/**
* @author Jaikiran Pai
*/
public class DummyEJBServer {
private static final Logger logger = Logger.getLogger(DummyEJBServer.class);
private static final String[] supportedMarshallerTypes = new String[]{"river", "java-serial"};
private static final String CLUSTER_NAME = "dummy-cluster";
private AcceptingChannel<? extends ConnectedStreamChannel> server;
private Map<EJBModuleIdentifier, Map<String, Object>> registeredEJBs = new ConcurrentHashMap<EJBModuleIdentifier, Map<String, Object>>();
private final Collection<Channel> openChannels = new CopyOnWriteArraySet<Channel>();
private final Endpoint endpoint;
private volatile Registration ejbChannelRegistration;
public DummyEJBServer(final Endpoint endpoint) {
this.endpoint = endpoint;
}
public synchronized void start() throws IOException {
if (this.ejbChannelRegistration != null) {
throw new IllegalStateException(this.getClass().getSimpleName() + " is already started");
}
this.ejbChannelRegistration = this.registerEJBServer();
}
public synchronized void stop() throws IOException {
if (this.ejbChannelRegistration == null) {
throw new IllegalStateException(this.getClass().getSimpleName() + " is not started");
}
this.ejbChannelRegistration.close();
}
private Registration registerEJBServer() throws IOException {
logger.info("Registering EJB server to endpoint " + endpoint);
return endpoint.registerService("jboss.ejb", new OpenListener() {
@Override
public void channelOpened(Channel channel) {
logger.info("Channel opened " + channel);
channel.addCloseHandler(new CloseHandler<Channel>() {
@Override
public void handleClose(Channel closed, IOException exception) {
logger.info("Bye " + closed);
}
});
try {
this.sendVersionMessage(channel);
} catch (IOException e) {
logger.error("Could not send version message to channel " + channel + " Closing the channel");
IoUtils.safeClose(channel);
}
Channel.Receiver handler = new VersionReceiver();
channel.receiveMessage(handler);
}
@Override
public void registrationTerminated() {
logger.info("Registration terminated for open listener");
}
private void sendVersionMessage(final Channel channel) throws IOException {
final DataOutputStream outputStream = new DataOutputStream(channel.writeMessage());
// write the version
outputStream.write(0x01);
// write the marshaller type count
PackedInteger.writePackedInteger(outputStream, supportedMarshallerTypes.length);
// write the marshaller types
for (int i = 0; i < supportedMarshallerTypes.length; i++) {
outputStream.writeUTF(supportedMarshallerTypes[i]);
}
outputStream.flush();
outputStream.close();
}
}, OptionMap.EMPTY);
}
public String getClusterName() {
return this.CLUSTER_NAME;
}
class Version1Receiver implements Channel.Receiver {
private final DummyProtocolHandler dummyProtocolHandler;
Version1Receiver(final String marshallingType) {
this.dummyProtocolHandler = new DummyProtocolHandler(marshallingType);
}
@Override
public void handleError(Channel channel, IOException error) {
//To change body of implemented methods use File | Settings | File Templates.
}
@Override
public void handleEnd(Channel channel) {
//To change body of implemented methods use File | Settings | File Templates.
}
@Override
public void handleMessage(Channel channel, MessageInputStream messageInputStream) {
final DataInputStream inputStream = new DataInputStream(messageInputStream);
try {
final byte header = inputStream.readByte();
logger.info("Received message with header 0x" + Integer.toHexString(header));
switch (header) {
case 0x03:
final MethodInvocationRequest methodInvocationRequest = this.dummyProtocolHandler.readMethodInvocationRequest(inputStream, this.getClass().getClassLoader());
Object methodInvocationResult = null;
try {
methodInvocationResult = DummyEJBServer.this.handleMethodInvocationRequest(channel, methodInvocationRequest, dummyProtocolHandler);
} catch (NoSuchEJBException nsee) {
final DataOutputStream outputStream = new DataOutputStream(channel.writeMessage());
try {
this.dummyProtocolHandler.writeNoSuchEJBFailureMessage(outputStream, methodInvocationRequest.getInvocationId(), methodInvocationRequest.getAppName(),
methodInvocationRequest.getModuleName(), methodInvocationRequest.getDistinctName(), methodInvocationRequest.getBeanName(),
methodInvocationRequest.getViewClassName());
} finally {
outputStream.close();
}
return;
} catch (Exception e) {
final DataOutputStream outputStream = new DataOutputStream(channel.writeMessage());
try {
this.dummyProtocolHandler.writeException(outputStream, methodInvocationRequest.getInvocationId(), e, methodInvocationRequest.getAttachments());
} finally {
outputStream.close();
}
return;
}
logger.info("Method invocation result on server " + methodInvocationResult);
// write the method invocation result
final DataOutputStream outputStream = new DataOutputStream(channel.writeMessage());
try {
this.dummyProtocolHandler.writeMethodInvocationResponse(outputStream, methodInvocationRequest.getInvocationId(), methodInvocationResult, methodInvocationRequest.getAttachments());
} finally {
outputStream.close();
}
break;
case 0x01:
// session open request
try {
this.handleSessionOpenRequest(channel, messageInputStream);
} catch (Exception e) {
// TODO: Let the client know of this exception
throw new RuntimeException(e);
}
break;
default:
logger.warn("Not supported message header 0x" + Integer.toHexString(header) + " received by " + this);
return;
}
} catch (IOException e) {
throw new RuntimeException(e);
} finally {
// receive next message
channel.receiveMessage(this);
IoUtils.safeClose(inputStream);
}
}
private void handleSessionOpenRequest(Channel channel, MessageInputStream messageInputStream) throws IOException {
if (messageInputStream == null) {
throw new IllegalArgumentException("Cannot read from null message inputstream");
}
final DataInputStream dataInputStream = new DataInputStream(messageInputStream);
// read invocation id
final short invocationId = dataInputStream.readShort();
final String appName = dataInputStream.readUTF();
final String moduleName = dataInputStream.readUTF();
final String distinctName = dataInputStream.readUTF();
final String beanName = dataInputStream.readUTF();
final EJBModuleIdentifier ejbModuleIdentifier = new EJBModuleIdentifier(appName, moduleName, distinctName);
final Map<String, Object> ejbs = DummyEJBServer.this.registeredEJBs.get(ejbModuleIdentifier);
if (ejbs == null || ejbs.get(beanName) == null) {
final DataOutputStream outputStream = new DataOutputStream(channel.writeMessage());
try {
this.dummyProtocolHandler.writeNoSuchEJBFailureMessage(outputStream, invocationId, appName, moduleName, distinctName, beanName, null);
} finally {
outputStream.close();
}
return;
}
final UUID uuid = UUID.randomUUID();
ByteBuffer bb = ByteBuffer.wrap(new byte[16]);
bb.putLong(uuid.getMostSignificantBits());
bb.putLong(uuid.getLeastSignificantBits());
final SessionID sessionID = SessionID.createSessionID(bb.array());
final DataOutputStream outputStream = new DataOutputStream(channel.writeMessage());
try {
final ClusterAffinity hardAffinity = new ClusterAffinity(DummyEJBServer.this.CLUSTER_NAME);
this.dummyProtocolHandler.writeSessionId(outputStream, invocationId, sessionID, hardAffinity);
} finally {
outputStream.close();
}
}
}
public void register(final String appName, final String moduleName, final String distinctName, final String beanName, final Object instance) {
final EJBModuleIdentifier moduleID = new EJBModuleIdentifier(appName, moduleName, distinctName);
Map<String, Object> ejbs = this.registeredEJBs.get(moduleID);
if (ejbs == null) {
ejbs = new HashMap<String, Object>();
this.registeredEJBs.put(moduleID, ejbs);
}
ejbs.put(beanName, instance);
try {
this.sendNewModuleReportToClients(new EJBModuleIdentifier[]{moduleID}, true);
} catch (IOException e) {
logger.warn("Could not send EJB module availability message to clients, for module " + moduleID, e);
}
}
public void unregister(final String appName, final String moduleName, final String distinctName, final String beanName) {
this.unregister(appName, moduleName, distinctName, beanName, true);
}
public void unregister(final String appName, final String moduleName, final String distinctName, final String beanName, final boolean notifyClients) {
final EJBModuleIdentifier moduleID = new EJBModuleIdentifier(appName, moduleName, distinctName);
Map<String, Object> ejbs = this.registeredEJBs.get(moduleID);
if (ejbs != null) {
ejbs.remove(beanName);
}
if (notifyClients) {
try {
this.sendNewModuleReportToClients(new EJBModuleIdentifier[]{moduleID}, false);
} catch (IOException e) {
logger.warn("Could not send EJB module un-availability message to clients, for module " + moduleID, e);
}
}
}
private void sendNewModuleReportToClients(final EJBModuleIdentifier[] modules, final boolean availabilityReport) throws IOException {
if (modules == null) {
return;
}
if (this.openChannels.isEmpty()) {
logger.debug("No open channels to send EJB module availability");
}
for (final Channel channel : this.openChannels) {
final DataOutputStream dataOutputStream = new DataOutputStream(channel.writeMessage());
try {
if (availabilityReport) {
this.writeModuleAvailability(dataOutputStream, modules);
} else {
this.writeModuleUnAvailability(dataOutputStream, modules);
}
} catch (IOException e) {
logger.warn("Could not send module availability message to client", e);
} finally {
dataOutputStream.close();
}
}
}
private void writeModuleAvailability(final DataOutput output, final EJBModuleIdentifier[] ejbModuleIdentifiers) throws IOException {
if (output == null) {
throw new IllegalArgumentException("Cannot write to null output");
}
if (ejbModuleIdentifiers == null) {
throw new IllegalArgumentException("EJB module identifiers cannot be null");
}
// write the header
output.write(0x08);
this.writeModuleReport(output, ejbModuleIdentifiers);
}
private void writeModuleUnAvailability(final DataOutput output, final EJBModuleIdentifier[] ejbModuleIdentifiers) throws IOException {
if (output == null) {
throw new IllegalArgumentException("Cannot write to null output");
}
if (ejbModuleIdentifiers == null) {
throw new IllegalArgumentException("EJB module identifiers cannot be null");
}
// write the header
output.write(0x09);
this.writeModuleReport(output, ejbModuleIdentifiers);
}
private void writeModuleReport(final DataOutput output, final EJBModuleIdentifier[] modules) throws IOException {
// write the count
PackedInteger.writePackedInteger(output, modules.length);
// write the module identifiers
for (int i = 0; i < modules.length; i++) {
// write the app name
final String appName = modules[i].getAppName();
if (appName == null) {
// write out a empty string
output.writeUTF("");
} else {
output.writeUTF(appName);
}
// write the module name
output.writeUTF(modules[i].getModuleName());
// write the distinct name
final String distinctName = modules[i].getDistinctName();
if (distinctName == null) {
// write out an empty string
output.writeUTF("");
} else {
output.writeUTF(distinctName);
}
}
}
private Object handleMethodInvocationRequest(final Channel channel, final MethodInvocationRequest methodInvocationRequest, final DummyProtocolHandler dummyProtocolHandler) throws InvocationTargetException, IllegalAccessException, IOException {
final EJBModuleIdentifier ejbModuleIdentifier = new EJBModuleIdentifier(methodInvocationRequest.getAppName(), methodInvocationRequest.getModuleName(), methodInvocationRequest.getDistinctName());
final Map<String, Object> ejbs = this.registeredEJBs.get(ejbModuleIdentifier);
final Object beanInstance = ejbs.get(methodInvocationRequest.getBeanName());
if (beanInstance == null) {
throw new NoSuchEJBException(methodInvocationRequest.getBeanName() + " EJB not available");
}
Method method = null;
try {
method = this.getRequiredMethod(beanInstance.getClass(), methodInvocationRequest.getMethodName(), methodInvocationRequest.getParamTypes());
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
// check if this is an async method
if (this.isAsyncMethod(method)) {
final DataOutputStream output = new DataOutputStream(channel.writeMessage());
try {
// send a notification to the client that this is an async method
dummyProtocolHandler.writeAsyncMethodNotification(output, methodInvocationRequest.getInvocationId());
} finally {
output.close();
}
}
// invoke on the method
return method.invoke(beanInstance, methodInvocationRequest.getParams());
}
private Method getRequiredMethod(final Class<?> klass, final String methodName, final String[] paramTypes) throws NoSuchMethodException {
final Class<?>[] types = new Class<?>[paramTypes.length];
for (int i = 0; i < paramTypes.length; i++) {
try {
types[i] = Class.forName(paramTypes[i], false, klass.getClassLoader());
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
return klass.getMethod(methodName, types);
}
private boolean isAsyncMethod(final Method method) {
// just check for return type and assume it to be async if it returns Future
return method.getReturnType().equals(Future.class);
}
class VersionReceiver implements Channel.Receiver {
@Override
public void handleError(Channel channel, IOException error) {
try {
channel.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
throw new RuntimeException("NYI: .handleError");
}
@Override
public void handleEnd(Channel channel) {
try {
channel.close();
} catch (IOException e) {
// ignore
}
}
@Override
public void handleMessage(Channel channel, MessageInputStream message) {
final SimpleDataInput input = new SimpleDataInput(Marshalling.createByteInput(message));
try {
final byte version = input.readByte();
final String clientMarshallingType = input.readUTF();
input.close();
switch (version) {
case 0x01:
final Version1Receiver receiver = new Version1Receiver(clientMarshallingType);
DummyEJBServer.this.openChannels.add(channel);
channel.receiveMessage(receiver);
// send module availability report to clients
final Collection<EJBModuleIdentifier> availableModules = DummyEJBServer.this.registeredEJBs.keySet();
DummyEJBServer.this.sendNewModuleReportToClients(availableModules.toArray(new EJBModuleIdentifier[availableModules.size()]), true);
break;
default:
logger.info("Received unsupported version 0x" + Integer.toHexString(version) + " from client, on channel " + channel);
channel.close();
break;
}
} catch (IOException e) {
logger.error("Exception on channel " + channel, e);
try {
logger.info("Shutting down channel " + channel);
channel.writeShutdown();
} catch (IOException e1) {
// ignore
if (logger.isTraceEnabled()) {
logger.trace("Ignoring exception that occurred during channel shutdown", e1);
}
}
}
}
}
private class EJBModuleIdentifier {
private final String appName;
private final String moduleName;
private final String distinctName;
EJBModuleIdentifier(final String appname, final String moduleName, final String distinctName) {
this.appName = appname;
this.moduleName = moduleName;
this.distinctName = distinctName;
}
String getAppName() {
return this.appName;
}
String getModuleName() {
return this.moduleName;
}
String getDistinctName() {
return this.distinctName;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
EJBModuleIdentifier that = (EJBModuleIdentifier) o;
if (appName != null ? !appName.equals(that.appName) : that.appName != null) return false;
if (distinctName != null ? !distinctName.equals(that.distinctName) : that.distinctName != null)
return false;
if (!moduleName.equals(that.moduleName)) return false;
return true;
}
@Override
public int hashCode() {
int result = appName != null ? appName.hashCode() : 0;
result = 31 * result + moduleName.hashCode();
result = 31 * result + (distinctName != null ? distinctName.hashCode() : 0);
return result;
}
}
}