package com.ausregistry.jtoolkit2.session; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.net.ConnectException; import java.net.InetAddress; import java.net.SocketTimeoutException; import java.util.HashMap; import java.util.Map; import java.util.logging.Logger; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLSocket; import org.xml.sax.SAXException; import com.ausregistry.jtoolkit2.ErrorPkg; import com.ausregistry.jtoolkit2.Timer; import com.ausregistry.jtoolkit2.se.CLTRID; import com.ausregistry.jtoolkit2.se.Command; import com.ausregistry.jtoolkit2.se.CommandType; import com.ausregistry.jtoolkit2.se.Greeting; import com.ausregistry.jtoolkit2.se.LoginCommand; import com.ausregistry.jtoolkit2.se.LogoutCommand; import com.ausregistry.jtoolkit2.se.PollRequestCommand; import com.ausregistry.jtoolkit2.se.Response; import com.ausregistry.jtoolkit2.se.Result; import com.ausregistry.jtoolkit2.se.ResultCode; import com.ausregistry.jtoolkit2.xml.ParsingException; import com.ausregistry.jtoolkit2.xml.XMLDocument; import com.ausregistry.jtoolkit2.xml.XMLParser; /** * <p> * RFC5734 specifies a transport mapping for EPP over TCP; this class implements that mapping. That specification * requires that the session is layered over TLS (Transport Layer Security). This class complies with RFC5734 in * implementing the Session interface. It also implements its own StatsManager since it is tightly coupled with the * Session implementation. * </p> * * <p> * Uses the debug, support and user level loggers. * </p> */ public class TLSSession implements Session, StatsManager { private static final String[] TYPE_INTERVAL_ARR = new String[] {"<<type>>", "<<interval>>" }; private static final String[] TIME_COUNT_ARR = new String[] {"<<time>>", "<<count>>" }; private static final int BUF_SIZE = 4096; private static String pollXML; private static CommandType pollCmdType; static { PollRequestCommand poll = new PollRequestCommand(); try { pollXML = poll.toXML(); } catch (SAXException saxe) { } pollCmdType = poll.getCommandType(); } private java.io.DataInputStream in; private java.io.DataOutputStream out; private TLSContext ctx; private XMLParser parser; private SSLSocket socket; private boolean inUse; private boolean isOpen; private boolean isInvalid; private CommandCounter commandCounter; private ResultCounter resultCounter; private long totalTime; private Map<CommandType, Long> commandTimeMap; private long mruTime; private long acquireTimeout; private InetAddress inaddr; private int port; private int soTimeout; private String username; private String password; private String newPW; private String eppVersion; private String language; private String[] objURIs, extURIs; private Greeting greeting; private String pname; private Logger debugLogger; private Logger supportLogger; private Logger userLogger; { pname = TLSSession.class.getPackage().getName(); debugLogger = Logger.getLogger(pname + ".debug"); supportLogger = Logger.getLogger(pname + ".support"); userLogger = Logger.getLogger(pname + ".user"); parser = new XMLParser(); resultCounter = new ResultCounter(); commandTimeMap = new HashMap<CommandType, Long>(); totalTime = 0L; isInvalid = true; greeting = null; } protected TLSSession() { } protected TLSSession(SessionProperties props) throws SessionConfigurationException { configure(props); } /** * Configure the session as described in the Session interface. * * @throws SessionConfigurationException * Possible causes: * <ul> * <li>KeyStoreNotFoundException</li> * <li>KeyStoreTypeException</li> * <li>NoSuchAlgorithmException</li> * <li>UnrecoverableKeyException</li> * <li>CertificateException</li> * <li>KeyStoreReadException</li> * <li>UnknownHostException</li> * <li>FileNotFoundException</li> * </ul> */ @Override public void configure(SessionProperties properties) throws SessionConfigurationException { this.port = properties.getPort(); this.username = properties.getClientID(); this.password = properties.getClientPW(); this.newPW = null; this.eppVersion = properties.getVersion(); this.language = properties.getLanguage(); this.objURIs = properties.getObjURIs(); this.extURIs = properties.getExtURIs(); commandCounter = new CommandCounter(properties.getCommandLimitInterval()); this.acquireTimeout = properties.getAcquireTimeout(); this.soTimeout = properties.getSocketTimeout(); try { inaddr = InetAddress.getByName(properties.getHostname()); if (ctx == null) { ctx = new TLSContext(properties.getKeyStoreFilename(), properties.getKeyStorePassphrase(), properties.getTrustStoreFilename(), properties.getTrustStorePassphrase(), properties.getKeyStoreType(), properties.getSSLAlgorithm()); } } catch (Exception e) { throw new SessionConfigurationException(e); } } @Override public boolean isInvalid() { return isInvalid; } @Override public boolean isOpen() { return isOpen; } /** * * An EPP session is opened by first establishing a connection using the server location information and * authentication sources provided in <code>configure</code>, then issuing a login command with further * authentication data and options provided in <code>configure</code>. Service information provided immediately upon * connection establishment may affect options provided in the login command. * * @throws SessionOpenException * The getCause() method should be invoked on the exception thrown. The cause may be one of: * <dl> * <dt>SSLHandshakeException</dt> * <dd>The SSL handshake failed. The reason is described in the exception message, and is also recorded * in the user logs.</dd> * <dt>IOException</dt> * <dd>See the log record published to the user logs, or check the exception message.</dd> * <dt>GreetingException</dt> * <dd>The service element received from the server upon connection establishment was not a valid EPP * greeting.</dd> * <dt>LoginException</dt> * <dd>The login command failed to establish an EPP session. The cause, available via getCause(), * describes the specific reason for failure.</dd> * </dl> */ @Override public void open() throws SessionOpenException { isInvalid = true; try { openSocket(); processGreeting(); login(); isOpen = true; isInvalid = false; } catch (SSLHandshakeException e) { userLogger.severe(ErrorPkg.getMessage("TLSContext.createSocket.0", "<<java.home>>", System.getProperty("java.home", "java.home"))); userLogger.severe(e.getMessage()); throw new SessionOpenException(e); } catch (ConnectException e) { final String errorMessage = ErrorPkg.getMessage("net.socket.open.fail", new String[] {"<<port>>", "<<host>>" }, new String[] {String.valueOf(port), inaddr.getHostAddress() }); userLogger.severe(errorMessage); userLogger.severe(e.getMessage()); throw new SessionOpenException(errorMessage); } catch (IOException e) { userLogger.severe(ErrorPkg.getMessage("net.socket.open.fail", new String[] {"<<port>>", "<<host>>" }, new String[] {String.valueOf(port), inaddr.getHostAddress() })); userLogger.severe(e.getMessage()); throw new SessionOpenException(ErrorPkg.getMessage("net.socket.open.fail", new String[] {"<<port>>", "<<host>>" }, new String[] {String.valueOf(port), inaddr.getHostAddress() })); } catch (SessionLimitExceededException e) { e.printStackTrace(); throw new SessionOpenException(e); } catch (GreetingException e) { e.printStackTrace(); throw new SessionOpenException(e); } catch (LoginException e) { e.printStackTrace(); throw new SessionOpenException(e); } } @Override public void changePassword(String newPassword) throws SessionOpenException { newPW = newPassword; open(); close(); } private void openSocket() throws SSLHandshakeException, IOException { socket = ctx.createSocket(inaddr.getHostAddress(), port, soTimeout); in = new DataInputStream(new BufferedInputStream(socket.getInputStream())); out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream(), BUF_SIZE)); } @Override public Greeting getGreeting() { return greeting; } private void processGreeting() throws SessionLimitExceededException, GreetingException { try { if (greeting == null) { greeting = new Greeting(); } greeting.fromXML(readToDocument()); } catch (IOException pe) { throw new SessionLimitExceededException(); } catch (ParsingException pe) { pe.printStackTrace(); throw new GreetingException(pe); } } private void login() throws LoginException, IOException { LoginCommand login = new LoginCommand(username, password, newPW, eppVersion, language, objURIs, extURIs); CLTRID.setClID(username); long startTime = Timer.now(); try { write(login.toXML(), login.getCommandType()); } catch (SAXException saxe) { throw new LoginException(saxe); } mruTime = Timer.now(); Response response = new Response(); try { response.fromXML(readToDocument()); long responseTime = Timer.msDiff(startTime); recordResponseTime(login.getCommandType(), responseTime); } catch (NumberFormatException nfe) { // These two exceptions usually indicate a fatal error. throw new LoginException(nfe); } catch (ParsingException pe) { throw new LoginException(pe); } Result result = response.getResults()[0]; assert result != null; String msg = result.getResultMessage(); switch (result.getResultCode()) { case ResultCode.SUCCESS: return; case ResultCode.CMD_USE_ERR: userLogger.warning(ErrorPkg.getMessage("epp.login.fail.loggedin")); throw new LoginException(msg); case ResultCode.AUTHENT_ERROR_CLOSING: case ResultCode.AUTHENT_ERR: String cn = ctx.getCertificateCommonName(); if (cn.equals(username)) { userLogger.severe(ErrorPkg.getMessage("epp.login.fail.auth.pw", new String[] {"<<clID>>", "<<pw>>" }, new String[] {username, password })); throw new UserPassMismatchException(msg); } else { userLogger.severe(ErrorPkg.getMessage("epp.login.fail.auth.match", new String[] {"<<clID>>", "<<cn>>" }, new String[] {username, cn })); throw new CertificateUserMismatchException(username, cn); } case ResultCode.UNIMPL_OBJ_SVC: if (result.getResultValue() != null) { msg = result.getResultValue().item(0).getTextContent(); } userLogger.severe(ErrorPkg.getMessage("epp.login.fail.unimpl.objsvc", "<<uri>>", msg)); throw new LoginException(msg); case ResultCode.PARAM_VAL_SYNTAX_ERR: throw new LoginException(new ParameterSyntaxException(result.getResultMessage())); case ResultCode.SESS_LIM_EXCEEDED_CLOSING: throw new SessionLimitExceededException(); case ResultCode.CMD_FAILED: raiseLoginException(result.hasResultExtReasons(), result.getResultExtReason(0)); case ResultCode.CMD_FAILED_CLOSING: raiseLoginException(result.hasResultExtReasons(), result.getResultExtReason(0)); default: // do nothing } } private void raiseLoginException(boolean hasValues, String msg) throws LoginException { if (hasValues) { throw new LoginException(new CommandFailedException(msg)); } else { throw new LoginException(new CommandFailedException()); } } @Override public void close() { inUse = true; isOpen = false; try { logout(); } catch (LogoutException le) { debugLogger.info(le.getMessage()); debugLogger.info(ErrorPkg.getMessage("net.event.socket_closed")); } closeSocket(); inUse = false; } private void logout() throws LogoutException { LogoutCommand logout = new LogoutCommand(); Response response = new Response(); try { long startTime = Timer.now(); try { write(logout.toXML(), logout.getCommandType()); } catch (SAXException saxe) { debugLogger.warning(saxe.getMessage()); throw new LogoutException(saxe); } response.fromXML(readToDocument()); long responseTime = Timer.msDiff(startTime); recordResponseTime(logout.getCommandType(), responseTime); Result[] results = response.getResults(); switch (results[0].getResultCode()) { case ResultCode.SUCCESS: return; case ResultCode.CMD_FAILED: default: throw new LogoutException(new CommandFailedException(results[0].getResultMessage())); } } catch (IOException ioe) { throw new LogoutException(ioe); } catch (ParsingException pe) { throw new LogoutException(pe); } } private void closeSocket() { try { in.close(); } catch (IOException ioe) { userLogger.warning(ioe.getMessage()); } try { out.close(); } catch (IOException ioe) { userLogger.warning(ioe.getMessage()); } try { socket.close(); } catch (IOException ioe) { userLogger.warning(ioe.getMessage()); } } /** * Receive data from the peer. This method is unsynchronised; the caller MUST provide synchronisation against other * calls to read. * * @return the details from the socket * @throws IOException Signals that an I/O exception has occurred. */ @Override public String read() throws IOException { try { int n = readSize(); debugLogger.finer("PDU size: " + n); String data = readData(n); supportLogger.info(data); return data; } catch (SocketTimeoutException ste) { userLogger.severe(ste.getMessage()); userLogger.severe(ErrorPkg.getMessage("epp.session.read.timeout")); throw ste; } catch (IOException ioe) { userLogger.severe(ioe.getMessage()); throw ioe; } } @Override public void read(Response response) throws IOException, ParsingException { response.fromXML(readToDocument()); } @Override public XMLDocument readToDocument() throws IOException, ParsingException { String xml = read(); assert parser != null; return parser.parse(xml); } /** * Send data to peer. This method is unsynchronised; the caller MUST provide synchronisation against other calls to * <code>write(String)</code>. * * @param xml * the XML to be sent to the EPP Server * @throws IOException * Signals that an I/O exception has occurred. */ @Override public void write(String xml) throws IOException { doWrite(xml); mruTime = Timer.now(); } @Override public void write(Command command) throws IOException, ParsingException { try { write(command.toXML()); } catch (SAXException saxe) { throw new ParsingException(saxe); } } private void doWrite(String xml) throws IOException { if (out == null) { throw new UninitialisedSessionException(); } try { final byte[] xmlBytes = xml.getBytes(); writeSize(xmlBytes.length); writeData(xmlBytes); } catch (IOException ioe) { isInvalid = true; throw ioe; } } private void write(String xml, CommandType cmdType) throws IOException { debugLogger.finer("writing command " + cmdType.toString()); doWrite(xml); incCommandCounter(cmdType); } /** * Send a poll command to the EPP server in order to prevent the session timing out. This operation does not affect * the most-recently-used statistic. */ @Override public void keepAlive() throws IOException { try { acquire(); write(pollXML, pollCmdType); read(); // not interested in response. release(); } catch (TimeoutException te) { userLogger.info(te.getMessage()); } catch (InterruptedException ie) { userLogger.info(ie.getMessage()); } } private int readSize() throws IOException { return in.readInt() - 4; } private void writeSize(int size) throws IOException { out.writeInt(size + 4); } private String readData(int length) throws IOException { byte[] inputBuffer = new byte[length]; in.readFully(inputBuffer, 0, length); return new String(inputBuffer); } private void writeData(final byte[] xml) throws IOException { out.write(xml); out.flush(); } @Override public void incCommandCounter(CommandType type) { commandCounter.increment(type); } @Override public void incResultCounter(int resultCode) { resultCounter.increment(resultCode); } @Override public void recordResponseTime(CommandType type, long responseTime) { totalTime += responseTime; if (!commandTimeMap.containsKey(type)) { commandTimeMap.put(type, responseTime); debugLogger.info(ErrorPkg.getMessage("epp.server.response_time.new_cmd", TYPE_INTERVAL_ARR, new String[] { type.getCommandName(), String.valueOf(responseTime) })); } else { commandTimeMap.put(type, commandTimeMap.get(type) + responseTime); debugLogger.info(ErrorPkg.getMessage("epp.server.response_time.previous_cmd", TYPE_INTERVAL_ARR, new String[] {type.getCommandName(), String.valueOf(responseTime) })); } } @Override public long getAverageResponseTime() { long totalCount = commandCounter.getTotal(); if (totalCount == 0L) { return 0L; } debugLogger.info(ErrorPkg.getMessage("epp.server.response_time.avg", TIME_COUNT_ARR, new String[] {String.valueOf(totalTime), String.valueOf(totalCount) })); return totalTime / totalCount; } @Override public long getAverageResponseTime(CommandType type) { if (!commandTimeMap.containsKey(type)) { return 0L; } long cmdCount = commandCounter.getCount(type); if (cmdCount == 0L) { return 0L; } return commandTimeMap.get(type) / cmdCount; } @Override public long getCommandCount() { return commandCounter.getTotal(); } @Override public long getCommandCount(CommandType type) { return commandCounter.getCount(type); } @Override public int getRecentCommandCount() { return commandCounter.getExactRecentTotal(); } @Override public int getRecentCommandCount(CommandType type) { return commandCounter.getRecentCount(type); } @Override public long getResultCodeCount(int resultCode) { return resultCounter.getValue(resultCode); } /** * Get the length of time (in milliseconds) since the most recent use (mru) of the session. The session is * considered to be used when the write method is invoked. */ @Override public long getMruInterval() { return Timer.msDiff(mruTime); } @Override public StatsManager getStatsManager() { return this; } @Override public boolean isAvailable() { return (isOpen && !inUse); } @Override public void acquire() throws InterruptedException, TimeoutException { synchronized (this) { while (inUse) { wait(acquireTimeout); if (inUse) { throw new TimeoutException(ErrorPkg.getMessage("epp.session.acquire.timeout", "<<timeout>>", String.valueOf(acquireTimeout))); } } inUse = true; } } @Override public void release() { synchronized (this) { inUse = false; notify(); } } }