package io.termd.core.ssh.netty;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.sshd.common.Factory;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.NamedResource;
import org.apache.sshd.common.PropertyResolverUtils;
import org.apache.sshd.common.Service;
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.closeable.AbstractCloseable;
import org.apache.sshd.server.ServerAuthenticationManager;
import org.apache.sshd.server.ServerFactoryManager;
import org.apache.sshd.server.auth.UserAuth;
import org.apache.sshd.server.auth.UserAuthNoneFactory;
import org.apache.sshd.server.session.ServerSession;
import org.apache.sshd.server.session.ServerSessionHolder;
/**
 * Server-side {@link Service} that processes SSH user-authentication messages and
 * additionally supports authenticators that complete asynchronously: when a
 * {@link UserAuth} throws an {@link AsyncAuth}, the outcome is delivered later through
 * the listener registered on that {@link AsyncAuth} instead of being decided inline.
 *
 * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
 */
public class AsyncUserAuthService extends AbstractCloseable implements Service, ServerSessionHolder {

    /** Session being authenticated. */
    private final ServerSession serverSession;

    /** Mutable copy of the session's authenticator factories, looked up by method name. */
    private final List<NamedFactory<UserAuth>> userAuthFactories;

    /**
     * Alternative chains of method names. The head of a chain is removed when that method
     * succeeds; authentication is complete as soon as any chain becomes empty.
     */
    private final List<List<String>> authMethods;

    /** Limit read from {@link ServerAuthenticationManager#MAX_AUTH_REQUESTS}. */
    private final int maxAuthRequests;

    /** User name of the first request; later requests must repeat it. */
    private String authUserName;

    /** Method name of the most recent request. */
    private String authMethod;

    /** Service name of the first request; later requests must repeat it. */
    private String authService;

    /** Authenticator created for the most recent request, or {@code null} if none is active. */
    private UserAuth currentAuth;

    /** Pending asynchronous authentication, or {@code null} if none is outstanding. */
    private AsyncAuth async;

    /** Number of repeated requests counted so far (the first request is not counted). */
    private int nbAuthRequests;

    /**
     * Creates the service for a not-yet-authenticated server session and resolves the
     * authentication method chains to enforce.
     *
     * <p>If {@link ServerFactoryManager#AUTH_METHODS} is not configured, every available
     * factory forms its own single-element chain. Otherwise the configured value is split
     * on whitespace into chains, and each chain is split on {@code ','} into method names.
     *
     * @param s the session; must be a {@link ServerSession}
     * @throws SshException if the session is already authenticated or a configured method
     *                      has no matching factory
     */
    public AsyncUserAuthService(Session s) throws SshException {
        ValidateUtils.checkTrue(s instanceof ServerSession, "Server side service used on client side");
        if (s.isAuthenticated()) {
            throw new SshException("Session already authenticated");
        }

        serverSession = (ServerSession) s;
        maxAuthRequests = PropertyResolverUtils.getIntProperty(
                s, ServerAuthenticationManager.MAX_AUTH_REQUESTS, ServerAuthenticationManager.DEFAULT_MAX_AUTH_REQUESTS);

        List<NamedFactory<UserAuth>> factories = ValidateUtils.checkNotNullAndNotEmpty(
                serverSession.getUserAuthFactories(), "No user auth factories for %s", s);
        userAuthFactories = new ArrayList<>(factories);

        // Get authentication methods
        authMethods = new ArrayList<>();
        String mths = PropertyResolverUtils.getString(s, ServerFactoryManager.AUTH_METHODS);
        if (GenericUtils.isEmpty(mths)) {
            for (NamedFactory<UserAuth> uaf : factories) {
                // Wrapped in ArrayList because chain heads are removed as methods succeed.
                authMethods.add(new ArrayList<>(Collections.singletonList(uaf.getName())));
            }
        } else {
            if (log.isDebugEnabled()) {
                log.debug("ServerUserAuthService({}) using configured methods={}", s, mths);
            }
            for (String mthl : mths.split("\\s")) {
                authMethods.add(new ArrayList<>(Arrays.asList(GenericUtils.split(mthl, ','))));
            }
        }

        // Verify all required methods are supported
        for (List<String> chain : authMethods) {
            for (String m : chain) {
                NamedFactory<UserAuth> factory =
                        NamedResource.Utils.findByName(m, String.CASE_INSENSITIVE_ORDER, userAuthFactories);
                if (factory == null) {
                    throw new SshException("Configured method is not supported: " + m);
                }
            }
        }

        if (log.isDebugEnabled()) {
            log.debug("ServerUserAuthService({}) authorized authentication methods: {}",
                    s, NamedResource.Utils.getNames(userAuthFactories));
        }
    }

    /** Nothing to start: the service only reacts to incoming messages. */
    @Override
    public void start() {
        // do nothing
    }

    /** @return the server session this service authenticates */
    @Override
    public ServerSession getSession() {
        return getServerSession();
    }

    /** @return the server session this service authenticates */
    @Override
    public ServerSession getServerSession() {
        return serverSession;
    }

    /**
     * Handles one authentication message.
     *
     * <p>{@code SSH_MSG_USERAUTH_REQUEST} discards any active authenticator, checks that the
     * user name and service match the first request, creates an authenticator for the
     * requested method and runs it. Any other command is forwarded to the active
     * authenticator. The tri-state result is then dispatched: {@code null} means in
     * progress, {@code true} success, {@code false} failure. An authenticator exception that
     * is an {@link AsyncAuth} defers the outcome to its listener; any other exception is
     * logged and treated as a failure.
     *
     * @param cmd    the SSH command code of the message
     * @param buffer the message payload, positioned after the command byte
     * @throws IllegalStateException if a non-request message arrives with no active authenticator
     * @throws Exception             if disconnecting or one of the outcome handlers fails
     */
    @Override
    public void process(int cmd, Buffer buffer) throws Exception {
        Boolean authed = Boolean.FALSE;
        ServerSession session = getServerSession();

        if (cmd == SshConstants.SSH_MSG_USERAUTH_REQUEST) {
            // A new request supersedes whatever exchange was in flight.
            destroyCurrentAuth();

            String username = buffer.getString();
            String service = buffer.getString();
            String method = buffer.getString();
            if (log.isDebugEnabled()) {
                log.debug("process({}) Received SSH_MSG_USERAUTH_REQUEST user={}, service={}, method={}",
                        session, username, service, method);
            }

            if (this.authUserName == null || this.authService == null) {
                this.authUserName = username;
                this.authService = service;
            } else if (this.authUserName.equals(username) && this.authService.equals(service)) {
                // NOTE(review): the comparison uses the count *before* this attempt and the
                // first request is never counted, so more than maxAuthRequests requests are
                // tolerated before disconnecting - confirm this is the intended limit.
                if (nbAuthRequests++ > maxAuthRequests) {
                    session.disconnect(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR,
                            "Too many authentication failures: " + nbAuthRequests);
                    return;
                }
            } else {
                session.disconnect(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR,
                        "Change of username or service is not allowed (" + this.authUserName + ", " + this.authService + ") -> ("
                                + username + ", " + service + ")");
                return;
            }

            // TODO: verify that the service is supported
            this.authMethod = method;
            if (log.isDebugEnabled()) {
                log.debug("process({}) Authenticating user '{}' with service '{}' and method '{}' (attempt {} / {})",
                        session, username, service, method, nbAuthRequests, maxAuthRequests);
            }

            Factory<UserAuth> factory =
                    NamedResource.Utils.findByName(method, String.CASE_INSENSITIVE_ORDER, userAuthFactories);
            if (factory == null) {
                if (log.isDebugEnabled()) {
                    log.debug("process({}) no authentication factory for method={}", session, method);
                }
            } else {
                currentAuth = ValidateUtils.checkNotNull(factory.create(), "No authenticator created for method=%s", method);
                try {
                    authed = currentAuth.auth(session, username, service, buffer);
                } catch (Exception e) {
                    if (asyncAuth(cmd, buffer, e)) {
                        return;
                    }
                    if (log.isDebugEnabled()) {
                        log.debug("process({}) Failed ({}) to authenticate using factory method={}: {}",
                                session, e.getClass().getSimpleName(), method, e.getMessage());
                    }
                    if (log.isTraceEnabled()) {
                        log.trace("process(" + session + ") factory authentication=" + method + " failure details", e);
                    }
                }
            }
        } else {
            if (this.currentAuth == null) {
                // This should not happen
                throw new IllegalStateException(
                        "No current authentication mechanism for cmd=" + SshConstants.getCommandMessageName(cmd));
            }
            if (log.isDebugEnabled()) {
                log.debug("process({}) Received authentication message={} for mechanism={}",
                        session, SshConstants.getCommandMessageName(cmd), currentAuth.getName());
            }

            // Step the read position back one byte before handing the buffer to the authenticator.
            buffer.rpos(buffer.rpos() - 1);
            try {
                authed = currentAuth.next(buffer);
            } catch (Exception e) {
                if (asyncAuth(cmd, buffer, e)) {
                    return;
                }
                // Continue
                if (log.isDebugEnabled()) {
                    log.debug("process({}) Failed ({}) to authenticate using current method={}: {}",
                            session, e.getClass().getSimpleName(), currentAuth.getName(), e.getMessage());
                }
                if (log.isTraceEnabled()) {
                    log.trace("process(" + session + ") current authentiaction=" + currentAuth.getName() + " failure details", e);
                }
            }
        }

        if (authed == null) {
            handleAuthenticationInProgress(cmd, buffer);
        } else if (authed.booleanValue()) {
            handleAuthenticationSuccess(cmd, buffer);
        } else {
            handleAuthenticationFailure(cmd, buffer);
        }
    }

    /**
     * Registers a completion listener if the given exception is an {@link AsyncAuth}.
     *
     * <p>The listener clears the pending marker and dispatches to the success or failure
     * handler. A failure inside that dispatch cannot propagate to the caller of
     * {@link #process(int, Buffer)}, so it is reported through the logger.
     *
     * <p>NOTE(review): the listener acts on whatever {@code currentAuth} is when it fires;
     * a new request processed while the asynchronous authentication is pending replaces
     * {@code currentAuth} - confirm that messages cannot interleave here.
     *
     * @param cmd    the command being processed when the exception was thrown
     * @param buffer the message payload passed on to the outcome handler
     * @param e      the exception thrown by the authenticator
     * @return {@code true} if the outcome was deferred to the listener, {@code false} if
     *         {@code e} is an ordinary failure the caller must handle
     */
    private boolean asyncAuth(int cmd, Buffer buffer, Exception e) {
        if (!(e instanceof AsyncAuth)) {
            return false;
        }

        async = (AsyncAuth) e;
        async.setListener(authenticated -> {
            async = null;
            try {
                if (authenticated) {
                    handleAuthenticationSuccess(cmd, buffer);
                } else {
                    handleAuthenticationFailure(cmd, buffer);
                }
            } catch (Exception e1) {
                log.warn("asyncAuth(" + getServerSession() + ") failed to complete asynchronous authentication for "
                        + SshConstants.getCommandMessageName(cmd), e1);
            }
        });
        return true;
    }

    /**
     * Called when the authenticator needs further messages; only logs.
     *
     * @param cmd    the command that was processed
     * @param buffer the message payload (unused)
     * @throws Exception never thrown by this implementation; declared for overriding subclasses
     */
    protected void handleAuthenticationInProgress(int cmd, Buffer buffer) throws Exception {
        String username = (currentAuth == null) ? null : currentAuth.getUsername();
        if (log.isDebugEnabled()) {
            log.debug("handleAuthenticationInProgress({}@{}) {}",
                    username, getServerSession(), SshConstants.getCommandMessageName(cmd));
        }
    }

    /**
     * Called when the current method succeeded. Advances the method chains; if one is now
     * complete, enforces the concurrent-session limit, sends the optional banner and
     * {@code SSH_MSG_USERAUTH_SUCCESS}, marks the session authenticated and starts the
     * requested service. Otherwise replies with {@code SSH_MSG_USERAUTH_FAILURE} flagged as
     * partial success, listing the next method of every unfinished chain.
     *
     * <p>The current authenticator is destroyed on every exit path, including the
     * too-many-sessions disconnect and exceptions raised while writing packets.
     *
     * @param cmd    the command that was processed
     * @param buffer the message payload (unused; replies are written to fresh buffers)
     * @throws Exception if there is no current authenticator or a session operation fails
     */
    protected void handleAuthenticationSuccess(int cmd, Buffer buffer) throws Exception {
        try {
            String username = ValidateUtils.checkNotNull(currentAuth, "No current auth").getUsername();
            ServerSession session = getServerSession();
            if (log.isDebugEnabled()) {
                log.debug("handleAuthenticationSuccess({}@{}) {}",
                        username, session, SshConstants.getCommandMessageName(cmd));
            }

            if (advanceMethodChains()) {
                Integer maxSessionCount =
                        PropertyResolverUtils.getInteger(session, ServerFactoryManager.MAX_CONCURRENT_SESSIONS);
                if (maxSessionCount != null) {
                    int currentSessionCount = session.getActiveSessionCountForUser(username);
                    if (currentSessionCount >= maxSessionCount) {
                        session.disconnect(SshConstants.SSH2_DISCONNECT_TOO_MANY_CONNECTIONS,
                                "Too many concurrent connections (" + currentSessionCount + ") - max. allowed: " + maxSessionCount);
                        return;
                    }
                }

                sendWelcomeBanner(session, username);

                Buffer reply = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_SUCCESS, Byte.SIZE);
                session.writePacket(reply);
                session.setUsername(username);
                session.setAuthenticated();
                session.startService(authService);
                session.resetIdleTimeout();
                log.info("Session {}@{} authenticated", username, session.getIoSession().getRemoteAddress());
            } else {
                String remaining = remainingMethods(false);
                if (log.isDebugEnabled()) {
                    log.debug("handleAuthenticationSuccess({}@{}) remaining methods={}", username, session, remaining);
                }

                Buffer reply = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_FAILURE, remaining.length() + Byte.SIZE);
                reply.putString(remaining);
                reply.putBoolean(true); // partial success ...
                session.writePacket(reply);
            }
        } finally {
            destroyCurrentAuth();
        }
    }

    /**
     * Called when the current method failed (or no authenticator exists for it). Replies
     * with {@code SSH_MSG_USERAUTH_FAILURE} listing the next method of every unfinished
     * chain except "none", without the partial-success flag, then destroys the current
     * authenticator (also when writing the reply throws).
     *
     * @param cmd    the command that was processed
     * @param buffer the message payload (unused; the reply is written to a fresh buffer)
     * @throws Exception if writing the reply fails
     */
    protected void handleAuthenticationFailure(int cmd, Buffer buffer) throws Exception {
        try {
            String username = (currentAuth == null) ? null : currentAuth.getUsername();
            ServerSession session = getServerSession();
            if (log.isDebugEnabled()) {
                log.debug("handleAuthenticationFailure({}@{}) {}",
                        username, session, SshConstants.getCommandMessageName(cmd));
            }

            String remaining = remainingMethods(true);
            if (log.isDebugEnabled()) {
                log.debug("handleAuthenticationFailure({}@{}) remaining methods: {}", username, session, remaining);
            }

            Buffer reply = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_FAILURE, remaining.length() + Byte.SIZE);
            reply.putString(remaining);
            reply.putBoolean(false); // no partial success ...
            session.writePacket(reply);
        } finally {
            destroyCurrentAuth();
        }
    }

    /** @return the factory manager of the underlying server session */
    public ServerFactoryManager getFactoryManager() {
        return serverSession.getFactoryManager();
    }

    /** Destroys the active authenticator, if any, and always clears the reference. */
    private void destroyCurrentAuth() {
        UserAuth auth = currentAuth;
        if (auth == null) {
            return;
        }
        try {
            auth.destroy();
        } finally {
            currentAuth = null;
        }
    }

    /**
     * Removes the head of every chain whose next method equals the method just used.
     *
     * @return {@code true} if at least one chain became empty, i.e. authentication is complete
     */
    private boolean advanceMethodChains() {
        boolean completed = false;
        for (List<String> chain : authMethods) {
            if ((GenericUtils.size(chain) > 0) && chain.get(0).equals(authMethod)) {
                chain.remove(0);
                completed |= chain.isEmpty();
            }
        }
        return completed;
    }

    /**
     * Joins the next pending method of every non-empty chain with commas.
     *
     * @param excludeNone whether chains whose next method is {@link UserAuthNoneFactory#NAME} are skipped
     * @return the comma-separated method names, possibly empty
     */
    private String remainingMethods(boolean excludeNone) {
        StringBuilder sb = new StringBuilder();
        for (List<String> chain : authMethods) {
            if (GenericUtils.size(chain) <= 0) {
                continue;
            }
            String next = chain.get(0);
            if (excludeNone && UserAuthNoneFactory.NAME.equals(next)) {
                continue;
            }
            if (sb.length() > 0) {
                sb.append(",");
            }
            sb.append(next);
        }
        return sb.toString();
    }

    /**
     * Sends {@code SSH_MSG_USERAUTH_BANNER} if a non-empty welcome banner is configured.
     *
     * @param session  the session to write to
     * @param username the authenticated user name, used for logging only
     * @throws Exception if writing the packet fails
     */
    private void sendWelcomeBanner(ServerSession session, String username) throws Exception {
        /*
         * TODO check if we can send the banner sooner. According to RFC-4252 section 5.4:
         *
         * The SSH server may send an SSH_MSG_USERAUTH_BANNER message at any
         * time after this authentication protocol starts and before
         * authentication is successful. This message contains text to be
         * displayed to the client user before authentication is attempted.
         */
        String welcomeBanner = PropertyResolverUtils.getString(session, ServerFactoryManager.WELCOME_BANNER);
        if (GenericUtils.length(welcomeBanner) <= 0) {
            return;
        }

        String lang = PropertyResolverUtils.getStringProperty(session,
                ServerFactoryManager.WELCOME_BANNER_LANGUAGE,
                ServerFactoryManager.DEFAULT_WELCOME_BANNER_LANGUAGE);
        Buffer banner = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_BANNER,
                welcomeBanner.length() + GenericUtils.length(lang) + Long.SIZE);
        banner.putString(welcomeBanner);
        banner.putString(lang);
        if (log.isDebugEnabled()) {
            log.debug("handleAuthenticationSuccess({}@{}) send banner (length={}, lang={})",
                    username, session, welcomeBanner.length(), lang);
        }
        session.writePacket(banner);
    }
}