package com.jenjinstudios.core;
import com.jenjinstudios.core.io.Message;
import com.jenjinstudios.core.io.MessageRegistry;
import com.jenjinstudios.core.io.MessageTypeException;
import com.jenjinstudios.core.xml.MessageType;
import java.io.EOFException;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.net.InetAddress;
import java.net.SocketException;
import java.security.Key;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * A {@code Connection} reads and writes {@code Message} objects through a {@code MessageIO}. It is not itself a
 * thread; instead it owns a reader thread which, once {@link #start()} has been called, repeatedly reads messages from
 * the input stream and invokes the {@code ExecutableMessage} implementations registered for each one. The reader loop
 * ends when reading fails with an {@code IOException} (for example after the streams have been closed) or when too
 * many invalid messages have been received.
 *
 * @author Caleb Brinkman
 */
public class Connection
{
    private static final Logger LOGGER = Logger.getLogger(Connection.class.getName());
    // NOTE(review): a 512-bit RSA modulus is far too small to be considered secure today. The value is left
    // unchanged because raising it may affect the wire protocol and remote peers -- confirm before changing.
    private static final int KEYSIZE = 512;
    private final PingTracker pingTracker;
    private final ExecutableMessageQueue executableMessageQueue;
    private final MessageIO messageIO;
    private final Thread messageReaderThread;
    private String name = "Connection";
    private final Map<InetAddress, Key> verifiedKeys = new HashMap<>(10);

    /**
     * Construct a new {@code Connection} that utilizes the specified {@code MessageIO} to read and write messages.
     * The reader thread is created here but not started; call {@link #start()} to begin reading.
     *
     * @param streams The {@code MessageIO} containing streams used to read and write messages.
     */
    protected Connection(MessageIO streams) {
        this.messageIO = streams;
        pingTracker = new PingTracker();
        executableMessageQueue = new ExecutableMessageQueue();
        messageReaderThread = new Thread(new RunnableMessageReader(this));
    }

    /**
     * Generate a PublicKeyMessage for the given public key.
     *
     * @param publicKey The key whose encoded form is placed in the "publicKey" argument of the message.
     *
     * @return The generated message.
     */
    public static Message generatePublicKeyMessage(Key publicKey) {
        Message publicKeyMessage = MessageRegistry.getInstance().createMessage("PublicKeyMessage");
        publicKeyMessage.setArgument("publicKey", publicKey.getEncoded());
        return publicKeyMessage;
    }

    /**
     * Generate an RSA-512 Public-Private Key Pair.
     *
     * @return The generated {@code KeyPair}, or null if the KeyPair could not be created (the failure is logged).
     */
    public static KeyPair generateRSAKeyPair() {
        KeyPair keyPair = null;
        try
        {
            KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
            keyPairGenerator.initialize(KEYSIZE);
            keyPair = keyPairGenerator.generateKeyPair();
        } catch (NoSuchAlgorithmException e)
        {
            LOGGER.log(Level.SEVERE, "Unable to create RSA key pair!", e);
        }
        return keyPair;
    }

    /**
     * Start the message reader thread managed by this connection.
     */
    public void start() {
        messageReaderThread.start();
    }

    /**
     * Set the RSA public/private key pair used by this connection: the private key is handed to the input stream,
     * and a message containing the public key is queued for sending. A {@code null} key pair is ignored.
     *
     * @param rsaKeyPair The keypair to use for encryption/decryption; may be {@code null}, in which case nothing
     * happens.
     */
    public void setRSAKeyPair(KeyPair rsaKeyPair) {
        if (rsaKeyPair != null)
        {
            messageIO.getIn().setPrivateKey(rsaKeyPair.getPrivate());
            Message message = generatePublicKeyMessage(rsaKeyPair.getPublic());
            messageIO.queueOutgoingMessage(message);
        }
    }

    /**
     * Get the MessageIO containing the keys and streams used by this connection.
     *
     * @return The MessageIO containing the keys and streams used by this connection.
     */
    public MessageIO getMessageIO() { return messageIO; }

    /**
     * Get the PingTracker used by this connection to track latency.
     *
     * @return The PingTracker used by this connection to track latency.
     */
    public PingTracker getPingTracker() { return pingTracker; }

    /**
     * Get the {@code ExecutableMessageQueue} maintained by this connection.
     *
     * @return The {@code ExecutableMessageQueue} maintained by this connection.
     */
    public ExecutableMessageQueue getExecutableMessageQueue() { return executableMessageQueue; }

    /**
     * Close the input and output streams of this connection. The reader loop stops once its pending or next read
     * fails with an {@code IOException}.
     */
    public void shutdown() {
        messageIO.closeInputStream();
        messageIO.closeOutputStream();
    }

    /**
     * Get the name of this {@code Connection}.
     *
     * @return The name of this {@code Connection}.
     */
    public String getName() { return name; }

    /**
     * Set the name of this {@code Connection}.
     *
     * @param name The name of this {@code Connection}.
     */
    public void setName(String name) { this.name = name; }

    /**
     * Get the map of addresses and verified keys for this connection. The returned map is the live internal map,
     * not a copy, so changes made through it are visible to this connection.
     *
     * @return The map of addresses and verified keys for this connection.
     */
    @SuppressWarnings("ReturnOfCollectionOrArrayField")
    public Map<InetAddress, Key> getVerifiedKeys() { return verifiedKeys; }

    /**
     * This class is used to continuously read {@code Message} objects from the connection's input stream, invoke
     * {@code runImmediate} on each {@code ExecutableMessage} created for them, and then place those executable
     * messages on the connection's {@code ExecutableMessageQueue}.
     *
     * @author Caleb Brinkman
     */
    public static class RunnableMessageReader implements Runnable
    {
        private static final int MAX_INVALID_MESSAGES = 10;
        private static final Logger INNER_LOGGER = Logger.getLogger(RunnableMessageReader.class.getName());
        private final Connection connection;
        // The factory only holds a reference to the connection, so one instance is reused for every message.
        private final ExecutableMessageFactory messageFactory;
        private int invalidMsgCount;

        /**
         * Construct a new {@code RunnableMessageReader} working for the given Connection.
         *
         * @param connection The {@code Connection} managing this reader.
         */
        public RunnableMessageReader(Connection connection) {
            this.connection = connection;
            this.messageFactory = new ExecutableMessageFactory(connection);
        }

        /**
         * Generate an InvalidMessage message for the given invalid ID and message name.
         *
         * @param id The ID of the invalid message.
         * @param name The Name of the invalid message.
         *
         * @return The generated InvalidMessage object.
         */
        private static Message generateInvalidMessage(short id, String name) {
            Message invalid = MessageRegistry.getInstance().createMessage("InvalidMessage");
            invalid.setArgument("messageName", name);
            invalid.setArgument("messageID", id);
            return invalid;
        }

        /**
         * Read and execute messages until an {@code IOException} occurs while reading or until
         * {@code MAX_INVALID_MESSAGES} invalid messages have been reported.
         */
        @Override
        public void run() {
            boolean success = true;
            while ((invalidMsgCount < MAX_INVALID_MESSAGES) && success)
            {
                try
                {
                    Message currentMessage = connection.getMessageIO().getIn().readMessage();
                    executeMessage(currentMessage);
                } catch (MessageTypeException e)
                {
                    // An unrecognized message is reported to the peer and counted, but does not end the loop.
                    reportInvalidMessage(e);
                } catch (EOFException | SocketException e)
                {
                    INNER_LOGGER.log(Level.FINER, "Connection closed: " + connection.getName(), e);
                    success = false;
                } catch (IOException e)
                {
                    INNER_LOGGER.log(Level.FINE, "IOException when attempting to read from stream.", e);
                    success = false;
                }
            }
        }

        /**
         * Create the executable messages for the given message and process each of them. An executable message
         * that could not be instantiated (a {@code null} entry) causes an InvalidMessage reply to be queued.
         *
         * @param message The message that was read from the stream.
         */
        void executeMessage(Message message) {
            Collection<ExecutableMessage> execs = messageFactory.getExecutableMessagesFor(message);
            for (ExecutableMessage exec : execs)
            {
                if (exec != null)
                {
                    processExecutableMessage(exec);
                } else
                {
                    processInvalidMessage(message);
                }
            }
        }

        /**
         * Queue an InvalidMessage reply carrying the ID and name of the given message.
         *
         * @param message The message that could not be executed.
         */
        private void processInvalidMessage(Message message) {
            Message invalid = generateInvalidMessage(message.getID(), message.name);
            connection.getMessageIO().queueOutgoingMessage(invalid);
        }

        /**
         * Run the immediate part of the executable message, then place it on the connection's queue.
         *
         * @param exec The executable message to process.
         */
        private void processExecutableMessage(ExecutableMessage exec) {
            exec.runImmediate();
            connection.getExecutableMessageQueue().queueExecutableMessage(exec);
        }

        /**
         * Log the invalid message, queue an InvalidMessage reply named "Unknown" carrying the offending ID, and
         * increment the invalid message counter that bounds the read loop.
         *
         * @param e The exception raised by the input stream for the invalid message.
         */
        void reportInvalidMessage(MessageTypeException e) {
            INNER_LOGGER.log(Level.WARNING, "Input stream reported invalid message receipt.", e);
            Message unknown = generateInvalidMessage(e.getId(), "Unknown");
            connection.getMessageIO().queueOutgoingMessage(unknown);
            invalidMsgCount++;
        }
    }

    /**
     * Used to generate ExecutableMessages.
     *
     * @author Caleb Brinkman
     */
    public static class ExecutableMessageFactory
    {
        private static final Logger INNER_LOGGER = Logger.getLogger(ExecutableMessageFactory.class.getName());
        private static final Constructor<?>[] EMPTY_CONSTRUCTOR_ARRAY = new Constructor<?>[0];
        private final Connection connection;

        /**
         * Construct an ExecutableMessageFactory for the specified connection.
         *
         * @param connection The connection for which this factory will produce ExecutableMessages.
         */
        public ExecutableMessageFactory(Connection connection) { this.connection = connection; }

        /**
         * Create the {@code ExecutableMessage} instances for the given {@code Message}, one per executable class
         * listed by the message's type for which a suitable constructor was found.
         *
         * @param message The {@code Message} for which the {@code ExecutableMessage}s are being created.
         *
         * @return The created executable messages. An element is {@code null} when a constructor was found but
         * instantiation failed; executable classes with no suitable constructor are logged and contribute no
         * element.
         */
        public List<ExecutableMessage> getExecutableMessagesFor(Message message) {
            List<ExecutableMessage> executableMessages = new ArrayList<>();
            Collection<Constructor<?>> execConstructors = getExecConstructors(message);
            for (Constructor<?> constructor : execConstructors)
            {
                if (constructor != null)
                {
                    executableMessages.add(createExec(message, constructor));
                } else
                {
                    Object[] args = {connection.getClass().getName(), message.name};
                    String report = "No constructor containing Connection or {0} as first argument type found for {1}";
                    INNER_LOGGER.log(Level.SEVERE, report, args);
                }
            }
            return executableMessages;
        }

        /**
         * Look up the executable class names registered for the message's type and pick a constructor for each.
         *
         * @param message The message whose executables are being resolved.
         *
         * @return One entry per executable class name; an entry is {@code null} when the class could not be loaded
         * or has no suitable public constructor.
         */
        private Collection<Constructor<?>> getExecConstructors(Message message) {
            Collection<Constructor<?>> constructors = new ArrayList<>();
            MessageType messageType = MessageRegistry.getInstance().getMessageType(message.getID());
            if (messageType == null)
            {
                // Defensive: the registry's behavior for an unknown ID is not visible here; avoid a
                // NullPointerException on the reader thread if it yields no type.
                INNER_LOGGER.log(Level.WARNING, "No message type registered for message: " + message.name);
                return constructors;
            }
            for (String className : messageType.getExecutables())
            {
                Constructor<?>[] execConstructors = EMPTY_CONSTRUCTOR_ARRAY;
                try
                {
                    Class<?> execClass = Class.forName(className);
                    execConstructors = execClass.getConstructors();
                } catch (ClassNotFoundException ex)
                {
                    INNER_LOGGER.log(Level.WARNING, "Could not find class: " + className, ex);
                }
                constructors.add(getAppropriateConstructor(execConstructors));
            }
            return constructors;
        }

        /**
         * Instantiate an executable message by invoking the constructor with the connection and the message.
         *
         * @param msg The message passed as the second constructor argument.
         * @param constructor The constructor to invoke.
         *
         * @return The new executable message, or {@code null} if the constructor could not be invoked (the failure
         * is logged).
         */
        private ExecutableMessage createExec(Message msg, Constructor<?> constructor) {
            ExecutableMessage executableMessage = null;
            try
            {
                executableMessage = (ExecutableMessage) constructor.newInstance(connection, msg);
            } catch (InvocationTargetException | InstantiationException | IllegalAccessException
                    | IllegalArgumentException e)
            {
                // IllegalArgumentException is raised by newInstance itself when the constructor's parameter list
                // does not accept (connection, msg); only the first parameter is checked during selection.
                INNER_LOGGER.log(Level.SEVERE, "Constructor not correct", e);
            }
            return executableMessage;
        }

        /**
         * Select the constructor whose first parameter type can accept this factory's connection. When several
         * constructors qualify, the last one in the supplied order is returned.
         *
         * @param execConstructors The candidate constructors.
         *
         * @return The selected constructor, or {@code null} if none qualifies.
         */
        private Constructor<?> getAppropriateConstructor(Constructor<?>... execConstructors) {
            Constructor<?> correctConstructor = null;
            for (Constructor<?> constructor : execConstructors)
            {
                Class<?>[] parameterTypes = constructor.getParameterTypes();
                // Constructors without parameters cannot take a connection; skip them instead of indexing into
                // an empty array.
                if ((parameterTypes.length > 0) && parameterTypes[0].isAssignableFrom(connection.getClass()))
                {
                    correctConstructor = constructor;
                }
            }
            return correctConstructor;
        }
    }
}