/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sshd.server.session; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.net.MalformedURLException; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.apache.sshd.common.Factory; import org.apache.sshd.common.NamedFactory; import org.apache.sshd.common.NamedResource; import org.apache.sshd.common.PropertyResolverUtils; import org.apache.sshd.common.Service; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.SshException; import org.apache.sshd.common.config.keys.KeyRandomArt; import org.apache.sshd.common.io.IoWriteFuture; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.NumberUtils; import org.apache.sshd.common.util.ValidateUtils; import org.apache.sshd.common.util.buffer.Buffer; import org.apache.sshd.common.util.closeable.AbstractCloseable; import org.apache.sshd.common.util.io.IoUtils; import org.apache.sshd.server.ServerAuthenticationManager; import org.apache.sshd.server.ServerFactoryManager; import org.apache.sshd.server.auth.UserAuth; import org.apache.sshd.server.auth.UserAuthNoneFactory; import org.apache.sshd.server.auth.WelcomeBannerPhase; /** * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a> */ public class ServerUserAuthService extends AbstractCloseable implements Service, ServerSessionHolder { private final ServerSession serverSession; private final AtomicBoolean welcomeSent = new AtomicBoolean(false); private final WelcomeBannerPhase welcomePhase; private List<NamedFactory<UserAuth>> userAuthFactories; private List<List<String>> authMethods; private String authUserName; private String authMethod; private String authService; private UserAuth currentAuth; private int maxAuthRequests; private int nbAuthRequests; public ServerUserAuthService(Session s) throws IOException { serverSession = ValidateUtils.checkInstanceOf(s, ServerSession.class, "Server side service used on client side: %s", s); if (s.isAuthenticated()) { throw new SshException("Session already authenticated"); } Object phase = PropertyResolverUtils.getObject(s, ServerAuthenticationManager.WELCOME_BANNER_PHASE); phase = PropertyResolverUtils.toEnum(WelcomeBannerPhase.class, phase, true, WelcomeBannerPhase.VALUES); welcomePhase = (phase == null) ? ServerAuthenticationManager.DEFAULT_BANNER_PHASE : (WelcomeBannerPhase) phase; maxAuthRequests = s.getIntProperty(ServerAuthenticationManager.MAX_AUTH_REQUESTS, ServerAuthenticationManager.DEFAULT_MAX_AUTH_REQUESTS); List<NamedFactory<UserAuth>> factories = ValidateUtils.checkNotNullAndNotEmpty( serverSession.getUserAuthFactories(), "No user auth factories for %s", s); userAuthFactories = new ArrayList<>(factories); // Get authentication methods authMethods = new ArrayList<>(); String mths = s.getString(ServerAuthenticationManager.AUTH_METHODS); if (GenericUtils.isEmpty(mths)) { for (NamedFactory<UserAuth> uaf : factories) { authMethods.add(new ArrayList<>(Collections.singletonList(uaf.getName()))); } } else { if (log.isDebugEnabled()) { log.debug("ServerUserAuthService({}) using configured methods={}", s, mths); } for (String mthl : mths.split("\\s")) { authMethods.add(new ArrayList<>(Arrays.asList(GenericUtils.split(mthl, ',')))); } } // Verify all required methods are supported for (List<String> l : authMethods) { for (String m : l) { NamedFactory<UserAuth> factory = NamedResource.findByName(m, String.CASE_INSENSITIVE_ORDER, userAuthFactories); if (factory == null) { throw new SshException("Configured method is not supported: " + m); } } } if (log.isDebugEnabled()) { log.debug("ServerUserAuthService({}) authorized authentication methods: {}", s, NamedResource.getNames(userAuthFactories)); } } public WelcomeBannerPhase getWelcomePhase() { return welcomePhase; } @Override public void start() { // do nothing } @Override public ServerSession getSession() { return getServerSession(); } @Override public ServerSession getServerSession() { return serverSession; } @Override public void process(int cmd, Buffer buffer) throws Exception { Boolean authed = Boolean.FALSE; ServerSession session = getServerSession(); if (cmd == SshConstants.SSH_MSG_USERAUTH_REQUEST) { if (WelcomeBannerPhase.FIRST_REQUEST.equals(getWelcomePhase())) { sendWelcomeBanner(session); } if (currentAuth != null) { try { currentAuth.destroy(); } finally { currentAuth = null; } } String username = buffer.getString(); String service = buffer.getString(); String method = buffer.getString(); if (log.isDebugEnabled()) { log.debug("process({}) Received SSH_MSG_USERAUTH_REQUEST user={}, service={}, method={}", session, username, service, method); } if (this.authUserName == null || this.authService == null) { this.authUserName = username; this.authService = service; } else if (this.authUserName.equals(username) && this.authService.equals(service)) { if (nbAuthRequests++ > maxAuthRequests) { session.disconnect(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR, "Too many authentication failures: " + nbAuthRequests); return; } } else { session.disconnect(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR, "Change of username or service is not allowed (" + this.authUserName + ", " + this.authService + ") -> (" + username + ", " + service + ")"); return; } // TODO: verify that the service is supported this.authMethod = method; if (log.isDebugEnabled()) { log.debug("process({}) Authenticating user '{}' with service '{}' and method '{}' (attempt {} / {})", session, username, service, method, nbAuthRequests, maxAuthRequests); } Factory<UserAuth> factory = NamedResource.findByName(method, String.CASE_INSENSITIVE_ORDER, userAuthFactories); if (factory != null) { currentAuth = ValidateUtils.checkNotNull(factory.create(), "No authenticator created for method=%s", method); try { authed = currentAuth.auth(session, username, service, buffer); } catch (Exception e) { if (log.isDebugEnabled()) { log.debug("process({}) Failed ({}) to authenticate using factory method={}: {}", session, e.getClass().getSimpleName(), method, e.getMessage()); } if (log.isTraceEnabled()) { log.trace("process(" + session + ") factory authentication=" + method + " failure details", e); } } } else { if (log.isDebugEnabled()) { log.debug("process({}) no authentication factory for method={}", session, method); } } } else { if (WelcomeBannerPhase.FIRST_AUTHCMD.equals(getWelcomePhase())) { sendWelcomeBanner(session); } if (this.currentAuth == null) { // This should not happen throw new IllegalStateException("No current authentication mechanism for cmd=" + SshConstants.getCommandMessageName(cmd)); } if (log.isDebugEnabled()) { log.debug("process({}) Received authentication message={} for mechanism={}", session, SshConstants.getCommandMessageName(cmd), currentAuth.getName()); } buffer.rpos(buffer.rpos() - 1); try { authed = currentAuth.next(buffer); } catch (Exception e) { // Continue if (log.isDebugEnabled()) { log.debug("process({}) Failed ({}) to authenticate using current method={}: {}", session, e.getClass().getSimpleName(), currentAuth.getName(), e.getMessage()); } if (log.isTraceEnabled()) { log.trace("process(" + session + ") current authentication=" + currentAuth.getName() + " failure details", e); } } } if (authed == null) { handleAuthenticationInProgress(cmd, buffer); } else if (authed) { handleAuthenticationSuccess(cmd, buffer); } else { handleAuthenticationFailure(cmd, buffer); } } protected void handleAuthenticationInProgress(int cmd, Buffer buffer) throws Exception { String username = (currentAuth == null) ? null : currentAuth.getUsername(); if (log.isDebugEnabled()) { log.debug("handleAuthenticationInProgress({}@{}) {}", username, getServerSession(), SshConstants.getCommandMessageName(cmd)); } } protected void handleAuthenticationSuccess(int cmd, Buffer buffer) throws Exception { String username = Objects.requireNonNull(currentAuth, "No current auth").getUsername(); ServerSession session = getServerSession(); if (log.isDebugEnabled()) { log.debug("handleAuthenticationSuccess({}@{}) {}", username, session, SshConstants.getCommandMessageName(cmd)); } boolean success = false; for (List<String> l : authMethods) { if ((GenericUtils.size(l) > 0) && l.get(0).equals(authMethod)) { l.remove(0); success |= l.isEmpty(); } } if (success) { Integer maxSessionCount = session.getInteger(ServerFactoryManager.MAX_CONCURRENT_SESSIONS); if (maxSessionCount != null) { int currentSessionCount = session.getActiveSessionCountForUser(username); if (currentSessionCount >= maxSessionCount) { session.disconnect(SshConstants.SSH2_DISCONNECT_TOO_MANY_CONNECTIONS, "Too many concurrent connections (" + currentSessionCount + ") - max. allowed: " + maxSessionCount); return; } } if (WelcomeBannerPhase.POST_SUCCESS.equals(getWelcomePhase())) { sendWelcomeBanner(session); } buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_SUCCESS, Byte.SIZE); session.writePacket(buffer); session.setUsername(username); session.setAuthenticated(); session.startService(authService); session.resetIdleTimeout(); log.info("Session {}@{} authenticated", username, session.getIoSession().getRemoteAddress()); } else { String remaining = authMethods.stream() .filter(GenericUtils::isNotEmpty) .map(l -> l.get(0)) .collect(Collectors.joining(",")); if (log.isDebugEnabled()) { log.debug("handleAuthenticationSuccess({}@{}) remaining methods={}", username, session, remaining); } buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_FAILURE, remaining.length() + Byte.SIZE); buffer.putString(remaining); buffer.putBoolean(true); // partial success ... session.writePacket(buffer); } try { currentAuth.destroy(); } finally { currentAuth = null; } } protected void handleAuthenticationFailure(int cmd, Buffer buffer) throws Exception { ServerSession session = getServerSession(); if (WelcomeBannerPhase.FIRST_FAILURE.equals(getWelcomePhase())) { sendWelcomeBanner(session); } String username = (currentAuth == null) ? null : currentAuth.getUsername(); if (log.isDebugEnabled()) { log.debug("handleAuthenticationFailure({}@{}) {}", username, session, SshConstants.getCommandMessageName(cmd)); } StringBuilder sb = new StringBuilder((authMethods.size() + 1) * Byte.SIZE); for (List<String> l : authMethods) { if (GenericUtils.size(l) > 0) { String m = l.get(0); if (!UserAuthNoneFactory.NAME.equals(m)) { if (sb.length() > 0) { sb.append(","); } sb.append(m); } } } String remaining = sb.toString(); if (log.isDebugEnabled()) { log.debug("handleAuthenticationFailure({}@{}) remaining methods: {}", username, session, remaining); } buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_FAILURE, remaining.length() + Byte.SIZE); buffer.putString(remaining); buffer.putBoolean(false); // no partial success ... session.writePacket(buffer); if (currentAuth != null) { try { currentAuth.destroy(); } finally { currentAuth = null; } } } /** * Sends the welcome banner (if any configured) and if not already invoked * * @param session The {@link ServerSession} to send the welcome banner to * @return The sent welcome banner {@link IoWriteFuture} - {@code null} if none sent * @throws IOException If failed to send the banner */ public IoWriteFuture sendWelcomeBanner(ServerSession session) throws IOException { if (welcomeSent.getAndSet(true)) { if (log.isDebugEnabled()) { log.debug("sendWelcomeBanner({}) already sent", session); } return null; } String welcomeBanner = resolveWelcomeBanner(session); if (GenericUtils.isEmpty(welcomeBanner)) { return null; } String lang = PropertyResolverUtils.getStringProperty(session, ServerAuthenticationManager.WELCOME_BANNER_LANGUAGE, ServerAuthenticationManager.DEFAULT_WELCOME_BANNER_LANGUAGE); Buffer buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_BANNER, welcomeBanner.length() + GenericUtils.length(lang) + Long.SIZE); buffer.putString(welcomeBanner); buffer.putString(lang); if (log.isDebugEnabled()) { log.debug("sendWelcomeBanner({}) send banner (length={}, lang={})", session, welcomeBanner.length(), lang); } return session.writePacket(buffer); } protected String resolveWelcomeBanner(ServerSession session) throws IOException { Object bannerValue = session.getObject(ServerAuthenticationManager.WELCOME_BANNER); if (bannerValue == null) { return null; } if (bannerValue instanceof CharSequence) { String message = bannerValue.toString(); if (GenericUtils.isEmpty(message)) { return null; } if (ServerAuthenticationManager.AUTO_WELCOME_BANNER_VALUE.equalsIgnoreCase(message)) { try { return KeyRandomArt.combine(' ', session.getKeyPairProvider()); } catch (Exception e) { if (e instanceof IOException) { throw (IOException) e; } throw new IOException(e); } } if (!message.contains("://")) { return message; } try { bannerValue = new URI(message); } catch (URISyntaxException e) { log.error("resolveWelcomeBanner({}) bad path URI {}: {}", session, message, e.getMessage()); throw new MalformedURLException(e.getClass().getSimpleName() + " - bad URI (" + message + "): " + e.getMessage()); } if (message.startsWith("file:/")) { bannerValue = Paths.get((URI) bannerValue); } } if (bannerValue instanceof File) { bannerValue = ((File) bannerValue).toPath(); } if (bannerValue instanceof Path) { Path path = (Path) bannerValue; if ((!Files.exists(path)) || (Files.size(path) <= 0L)) { if (log.isDebugEnabled()) { log.debug("resolveWelcomeBanner({}) file is empty/does not exist", session, path); } return null; } bannerValue = path.toUri(); } if (bannerValue instanceof URI) { bannerValue = ((URI) bannerValue).toURL(); } if (bannerValue instanceof URL) { Charset cs = PropertyResolverUtils.getCharset(session, ServerAuthenticationManager.WELCOME_BANNER_CHARSET, Charset.defaultCharset()); return loadWelcomeBanner(session, (URL) bannerValue, cs); } return bannerValue.toString(); } protected String loadWelcomeBanner(ServerSession session, URL url, Charset cs) throws IOException { try (InputStream stream = url.openStream()) { byte[] bytes = IoUtils.toByteArray(stream); return NumberUtils.isEmpty(bytes) ? "" : new String(bytes, cs); } } public ServerFactoryManager getFactoryManager() { return serverSession.getFactoryManager(); } }