/*
 * The Alluxio Open Foundation licenses this work under the Apache License, version 2.0
 * (the "License"). You may not use this work except in compliance with the License, which is
 * available at www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied, as more fully set forth in the License.
 *
 * See the NOTICE file distributed with this work for information regarding copyright ownership.
 */

package alluxio.security.authentication;

import alluxio.Configuration;
import alluxio.PropertyKey;
import alluxio.exception.status.UnauthenticatedException;
import alluxio.security.LoginUser;
import alluxio.security.User;

import org.apache.thrift.transport.TSaslClientTransport;
import org.apache.thrift.transport.TSaslServerTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportFactory;

import java.net.InetSocketAddress;
import java.security.Security;
import java.util.HashMap;
import java.util.Set;

import javax.annotation.concurrent.ThreadSafe;
import javax.security.auth.Subject;
import javax.security.sasl.SaslException;

/**
 * If authentication type is {@link AuthType#SIMPLE} or {@link AuthType#CUSTOM}, this is the
 * default transport provider which uses SASL transport.
 *
 * <p>Loading this class registers a {@link PlainSaslServerProvider} with
 * {@link Security#addProvider(java.security.Provider)}.
 */
@ThreadSafe
public final class PlainSaslTransportProvider implements TransportProvider {
  static {
    Security.addProvider(new PlainSaslServerProvider());
  }

  /** Placeholder password sent by the overloads that are not given an explicit password. */
  private static final String NO_PASSWORD = "noPassword";

  /**
   * Timeout for socket in ms. Assigned once in the constructor; declared final so the value is
   * safely published to every thread that uses this provider.
   */
  private final int mSocketTimeoutMs;

  /**
   * Constructor for transport provider with {@link AuthType#SIMPLE} or {@link AuthType#CUSTOM}.
   *
   * <p>The socket timeout is read from
   * {@link PropertyKey#SECURITY_AUTHENTICATION_SOCKET_TIMEOUT_MS}.
   */
  public PlainSaslTransportProvider() {
    mSocketTimeoutMs = Configuration.getInt(PropertyKey.SECURITY_AUTHENTICATION_SOCKET_TIMEOUT_MS);
  }

  /**
   * Gets a PLAIN mechanism client transport that identifies itself with the name of the user
   * returned by {@link LoginUser#get()} and the placeholder password.
   *
   * @param serverAddress address of the server
   * @return wrapped transport with PLAIN mechanism
   * @throws UnauthenticatedException propagated from {@link LoginUser#get()} or from the
   *         creation of the transport
   */
  @Override
  public TTransport getClientTransport(InetSocketAddress serverAddress)
      throws UnauthenticatedException {
    String username = LoginUser.get().getName();
    return getClientTransport(username, NO_PASSWORD, serverAddress);
  }

  /**
   * Gets a PLAIN mechanism client transport for the given subject.
   *
   * <p>The user name is taken from the first {@link User} principal returned by the iterator of
   * the subject's principal set. When the subject is {@code null}, carries no {@link User}
   * principal, or the resolved name is {@code null} or empty, the name of the user returned by
   * {@link LoginUser#get()} is used instead.
   *
   * @param subject the subject to take the user name from, may be {@code null}
   * @param serverAddress address of the server
   * @return wrapped transport with PLAIN mechanism
   * @throws UnauthenticatedException propagated from {@link LoginUser#get()} or from the
   *         creation of the transport
   */
  @Override
  public TTransport getClientTransport(Subject subject, InetSocketAddress serverAddress)
      throws UnauthenticatedException {
    String username = null;
    if (subject != null) {
      Set<User> users = subject.getPrincipals(User.class);
      if (users != null && !users.isEmpty()) {
        username = users.iterator().next().getName();
      }
    }
    // Fall back to the login user when the subject did not yield a usable name.
    if (username == null || username.isEmpty()) {
      username = LoginUser.get().getName();
    }
    return getClientTransport(username, NO_PASSWORD, serverAddress);
  }

  // TODO(binfan): make this private and use whitebox to access this method in test
  /**
   * Gets a PLAIN mechanism transport for client side.
   *
   * <p>A socket transport for {@code serverAddress} is obtained from
   * {@link TransportProviderUtils#createThriftSocket} using this provider's socket timeout and is
   * then wrapped in a {@link TSaslClientTransport}.
   *
   * @param username User Name of PlainClient
   * @param password Password of PlainClient
   * @param serverAddress Address of the server
   * @return Wrapped transport with PLAIN mechanism
   * @throws UnauthenticatedException if constructing the SASL client transport fails with a
   *         {@link SaslException}; the original exception is kept as the cause
   */
  public TTransport getClientTransport(String username, String password,
      InetSocketAddress serverAddress) throws UnauthenticatedException {
    TTransport wrappedTransport =
        TransportProviderUtils.createThriftSocket(serverAddress, mSocketTimeoutMs);
    try {
      return new TSaslClientTransport(PlainSaslServerProvider.MECHANISM, null, null, null,
          new HashMap<String, String>(), new PlainSaslClientCallbackHandler(username, password),
          wrappedTransport);
    } catch (SaslException e) {
      throw new UnauthenticatedException(e.getMessage(), e);
    }
  }

  /**
   * Gets the server side transport factory, supplying a callback that does nothing.
   *
   * @return a SASL server transport factory for the PLAIN mechanism
   * @throws SaslException declared by the overload this method delegates to
   */
  @Override
  public TTransportFactory getServerTransportFactory() throws SaslException {
    return getServerTransportFactory(new Runnable() {
      @Override
      public void run() {}
    });
  }

  /**
   * Gets the server side transport factory.
   *
   * <p>The {@link AuthenticationProvider} is created for the authentication type read from
   * {@link PropertyKey#SECURITY_AUTHENTICATION_TYPE}. Both the provider and {@code runnable} are
   * handed to a new {@link PlainSaslServerCallbackHandler}, which is registered for the PLAIN
   * mechanism on the returned factory.
   *
   * @param runnable callback passed to the {@link PlainSaslServerCallbackHandler}; when it is
   *        invoked is decided by that handler
   * @return a SASL server transport factory for the PLAIN mechanism
   * @throws SaslException declared by the {@link TransportProvider} contract
   */
  @Override
  public TTransportFactory getServerTransportFactory(Runnable runnable) throws SaslException {
    AuthType authType =
        Configuration.getEnum(PropertyKey.SECURITY_AUTHENTICATION_TYPE, AuthType.class);
    TSaslServerTransport.Factory saslFactory = new TSaslServerTransport.Factory();
    AuthenticationProvider provider = AuthenticationProvider.Factory.create(authType);
    saslFactory.addServerDefinition(PlainSaslServerProvider.MECHANISM, null, null,
        new HashMap<String, String>(), new PlainSaslServerCallbackHandler(provider, runnable));
    return saslFactory;
  }
}