/* * Copyright (C)2009 - SSHJ Contributors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package net.schmizz.sshj.userauth; import net.schmizz.concurrent.Promise; import net.schmizz.sshj.AbstractService; import net.schmizz.sshj.Service; import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.transport.Transport; import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.userauth.method.AuthMethod; import java.util.Arrays; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.concurrent.TimeUnit; /** * {@link UserAuth} implementation. */ public class UserAuthImpl extends AbstractService implements UserAuth { private final Promise<Boolean, UserAuthException> authenticated; // Externally available private volatile String banner = ""; private volatile boolean partialSuccess = false; private volatile List<String> allowedMethods = new LinkedList<String>(); // Internal state private volatile AuthMethod currentMethod; private volatile Service nextService; // The next service layer to set on the transport once we're authenticated. public UserAuthImpl(Transport trans) { super("ssh-userauth", trans); authenticated = new Promise<Boolean, UserAuthException>("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory()); } @Override public boolean authenticate(String username, Service nextService, AuthMethod method, int timeoutMs) throws UserAuthException, TransportException { final boolean outcome; authenticated.lock(); try { super.request(); // Request "ssh-userauth" service (if not already active) currentMethod = method; this.nextService = nextService; currentMethod.init(makeAuthParams(username, nextService)); authenticated.clear(); log.debug("Trying `{}` auth...", method.getName()); currentMethod.request(); outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS); if (outcome) { log.debug("`{}` auth successful", method.getName()); } else { log.debug("`{}` auth failed", method.getName()); } } finally { // Clear the internal state. currentMethod = null; this.nextService = null; authenticated.unlock(); } return outcome; } @Override public String getBanner() { return banner; } @Override public boolean hadPartialSuccess() { return partialSuccess; } @Override public Iterable<String> getAllowedMethods() { return Collections.unmodifiableList(allowedMethods); } @Override public void handle(Message msg, SSHPacket buf) throws SSHException { if (!msg.in(50, 80)) // ssh-userauth packets have message numbers between 50-80 throw new TransportException(DisconnectReason.PROTOCOL_ERROR); authenticated.lock(); try { switch (msg) { case USERAUTH_BANNER: banner = buf.readString(); break; case USERAUTH_SUCCESS: // In order to prevent race conditions, we immediately set the authenticated flag on the transport // And change the service before delivering the authenticated promise. // Should fix https://github.com/hierynomus/sshj/issues/237 trans.setAuthenticated(); // So it can put delayed compression into force if applicable trans.setService(nextService); // We aren't in charge anymore, next service is authenticated.deliver(true); break; case USERAUTH_FAILURE: allowedMethods = Arrays.asList(buf.readString().split(",")); partialSuccess |= buf.readBoolean(); if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) { currentMethod.request(); } else { authenticated.deliver(false); } break; default: log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg); try { currentMethod.handle(msg, buf); } catch (UserAuthException e) { authenticated.deliverError(e); } break; } } finally { authenticated.unlock(); } } @Override public void notifyError(SSHException error) { super.notifyError(error); authenticated.deliverError(error); } private AuthParams makeAuthParams(final String username, final Service nextService) { return new AuthParams() { @Override public String getNextServiceName() { return nextService.getName(); } @Override public Transport getTransport() { return trans; } @Override public String getUsername() { return username; } }; } }