/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.security; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.ByteArrayInputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.FilterInputStream; import java.io.FilterOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.regex.Pattern; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.NameCallback; import javax.security.auth.callback.PasswordCallback; import javax.security.auth.callback.UnsupportedCallbackException; import javax.security.auth.kerberos.KerberosPrincipal; import javax.security.sasl.RealmCallback; import javax.security.sasl.RealmChoiceCallback; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; import javax.security.sasl.SaslClient; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceStability; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.GlobPattern; import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper; import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper; import org.apache.hadoop.ipc.RPC.RpcKind; import org.apache.hadoop.ipc.RemoteException; import org.apache.hadoop.ipc.RpcConstants; import org.apache.hadoop.ipc.Server.AuthProtocol; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto.OperationProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.SaslAuth; import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcSaslProto.SaslState; import org.apache.hadoop.security.SaslRpcServer.AuthMethod; import org.apache.hadoop.security.authentication.util.KerberosName; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.security.token.TokenInfo; import org.apache.hadoop.security.token.TokenSelector; import org.apache.hadoop.util.ProtoUtil; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; /** * A utility class that encapsulates SASL logic for RPC client */ @InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"}) @InterfaceStability.Evolving public class SaslRpcClient { public static final Log LOG = LogFactory.getLog(SaslRpcClient.class); private final UserGroupInformation ugi; private final Class<?> protocol; private final InetSocketAddress serverAddr; private final Configuration conf; private SaslClient saslClient; private AuthMethod authMethod; private static final RpcRequestHeaderProto saslHeader = ProtoUtil .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER, OperationProto.RPC_FINAL_PACKET, AuthProtocol.SASL.callId, RpcConstants.INVALID_RETRY_COUNT, RpcConstants.DUMMY_CLIENT_ID); private static final RpcSaslProto negotiateRequest = RpcSaslProto.newBuilder().setState(SaslState.NEGOTIATE).build(); /** * Create a SaslRpcClient that can be used by a RPC client to negotiate * SASL authentication with a RPC server * @param ugi - connecting user * @param protocol - RPC protocol * @param serverAddr - InetSocketAddress of remote server * @param conf - Configuration */ public SaslRpcClient(UserGroupInformation ugi, Class<?> protocol, InetSocketAddress serverAddr, Configuration conf) { this.ugi = ugi; this.protocol = protocol; this.serverAddr = serverAddr; this.conf = conf; } @VisibleForTesting @InterfaceAudience.Private public Object getNegotiatedProperty(String key) { return (saslClient != null) ? saslClient.getNegotiatedProperty(key) : null; } // the RPC Client has an inelegant way of handling expiration of TGTs // acquired via a keytab. any connection failure causes a relogin, so // the Client needs to know what authMethod was being attempted if an // exception occurs. the SASL prep for a kerberos connection should // ideally relogin if necessary instead of exposing this detail to the // Client @InterfaceAudience.Private public AuthMethod getAuthMethod() { return authMethod; } /** * Instantiate a sasl client for the first supported auth type in the * given list. The auth type must be defined, enabled, and the user * must possess the required credentials, else the next auth is tried. * * @param authTypes to attempt in the given order * @return SaslAuth of instantiated client * @throws AccessControlException - client doesn't support any of the auths * @throws IOException - misc errors */ private SaslAuth selectSaslClient(List<SaslAuth> authTypes) throws SaslException, AccessControlException, IOException { SaslAuth selectedAuthType = null; boolean switchToSimple = false; for (SaslAuth authType : authTypes) { if (!isValidAuthType(authType)) { continue; // don't know what it is, try next } AuthMethod authMethod = AuthMethod.valueOf(authType.getMethod()); if (authMethod == AuthMethod.SIMPLE) { switchToSimple = true; } else { saslClient = createSaslClient(authType); if (saslClient == null) { // client lacks credentials, try next continue; } } selectedAuthType = authType; break; } if (saslClient == null && !switchToSimple) { List<String> serverAuthMethods = new ArrayList<String>(); for (SaslAuth authType : authTypes) { serverAuthMethods.add(authType.getMethod()); } throw new AccessControlException( "Client cannot authenticate via:" + serverAuthMethods); } if (LOG.isDebugEnabled()) { LOG.debug("Use " + selectedAuthType.getMethod() + " authentication for protocol " + protocol.getSimpleName()); } return selectedAuthType; } private boolean isValidAuthType(SaslAuth authType) { AuthMethod authMethod; try { authMethod = AuthMethod.valueOf(authType.getMethod()); } catch (IllegalArgumentException iae) { // unknown auth authMethod = null; } // do we know what it is? is it using our mechanism? return authMethod != null && authMethod.getMechanismName().equals(authType.getMechanism()); } /** * Try to create a SaslClient for an authentication type. May return * null if the type isn't supported or the client lacks the required * credentials. * * @param authType - the requested authentication method * @return SaslClient for the authType or null * @throws SaslException - error instantiating client * @throws IOException - misc errors */ private SaslClient createSaslClient(SaslAuth authType) throws SaslException, IOException { String saslUser = null; // SASL requires the client and server to use the same proto and serverId // if necessary, auth types below will verify they are valid final String saslProtocol = authType.getProtocol(); final String saslServerName = authType.getServerId(); Map<String, String> saslProperties = SaslRpcServer.SASL_PROPS; CallbackHandler saslCallback = null; final AuthMethod method = AuthMethod.valueOf(authType.getMethod()); switch (method) { case TOKEN: { Token<?> token = getServerToken(authType); if (token == null) { return null; // tokens aren't supported or user doesn't have one } saslCallback = new SaslClientCallbackHandler(token); break; } case KERBEROS: { if (ugi.getRealAuthenticationMethod().getAuthMethod() != AuthMethod.KERBEROS) { return null; // client isn't using kerberos } String serverPrincipal = getServerPrincipal(authType); if (serverPrincipal == null) { return null; // protocol doesn't use kerberos } if (LOG.isDebugEnabled()) { LOG.debug("RPC Server's Kerberos principal name for protocol=" + protocol.getCanonicalName() + " is " + serverPrincipal); } break; } default: throw new IOException("Unknown authentication method " + method); } String mechanism = method.getMechanismName(); if (LOG.isDebugEnabled()) { LOG.debug("Creating SASL " + mechanism + "(" + method + ") " + " client to authenticate to service at " + saslServerName); } return Sasl.createSaslClient( new String[] { mechanism }, saslUser, saslProtocol, saslServerName, saslProperties, saslCallback); } /** * Try to locate the required token for the server. * * @param authType of the SASL client * @return Token<?> for server, or null if no token available * @throws IOException - token selector cannot be instantiated */ private Token<?> getServerToken(SaslAuth authType) throws IOException { TokenInfo tokenInfo = SecurityUtil.getTokenInfo(protocol, conf); LOG.debug("Get token info proto:"+protocol+" info:"+tokenInfo); if (tokenInfo == null) { // protocol has no support for tokens return null; } TokenSelector<?> tokenSelector = null; try { tokenSelector = tokenInfo.value().newInstance(); } catch (InstantiationException e) { throw new IOException(e.toString()); } catch (IllegalAccessException e) { throw new IOException(e.toString()); } return tokenSelector.selectToken( SecurityUtil.buildTokenService(serverAddr), ugi.getTokens()); } /** * Get the remote server's principal. The value will be obtained from * the config and cross-checked against the server's advertised principal. * * @param authType of the SASL client * @return String of the server's principal * @throws IOException - error determining configured principal */ @VisibleForTesting String getServerPrincipal(SaslAuth authType) throws IOException { KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(protocol, conf); LOG.debug("Get kerberos info proto:"+protocol+" info:"+krbInfo); if (krbInfo == null) { // protocol has no support for kerberos return null; } String serverKey = krbInfo.serverPrincipal(); if (serverKey == null) { throw new IllegalArgumentException( "Can't obtain server Kerberos config key from protocol=" + protocol.getCanonicalName()); } // construct server advertised principal for comparision String serverPrincipal = new KerberosPrincipal( authType.getProtocol() + "/" + authType.getServerId()).getName(); boolean isPrincipalValid = false; // use the pattern if defined String serverKeyPattern = conf.get(serverKey + ".pattern"); if (serverKeyPattern != null && !serverKeyPattern.isEmpty()) { Pattern pattern = GlobPattern.compile(serverKeyPattern); isPrincipalValid = pattern.matcher(serverPrincipal).matches(); } else { // check that the server advertised principal matches our conf String confPrincipal = SecurityUtil.getServerPrincipal( conf.get(serverKey), serverAddr.getAddress()); if (confPrincipal == null || confPrincipal.isEmpty()) { throw new IllegalArgumentException( "Failed to specify server's Kerberos principal name"); } KerberosName name = new KerberosName(confPrincipal); if (name.getHostName() == null) { throw new IllegalArgumentException( "Kerberos principal name does NOT have the expected hostname part: " + confPrincipal); } isPrincipalValid = serverPrincipal.equals(confPrincipal); } if (!isPrincipalValid) { throw new IllegalArgumentException( "Server has invalid Kerberos principal: " + serverPrincipal); } return serverPrincipal; } /** * Do client side SASL authentication with server via the given InputStream * and OutputStream * * @param inS * InputStream to use * @param outS * OutputStream to use * @return AuthMethod used to negotiate the connection * @throws IOException */ public AuthMethod saslConnect(InputStream inS, OutputStream outS) throws IOException { DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS)); DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream( outS)); // redefined if/when a SASL negotiation starts, can be queried if the // negotiation fails authMethod = AuthMethod.SIMPLE; sendSaslMessage(outStream, negotiateRequest); // loop until sasl is complete or a rpc error occurs boolean done = false; do { int totalLen = inStream.readInt(); RpcResponseMessageWrapper responseWrapper = new RpcResponseMessageWrapper(); responseWrapper.readFields(inStream); RpcResponseHeaderProto header = responseWrapper.getMessageHeader(); switch (header.getStatus()) { case ERROR: // might get a RPC error during case FATAL: throw new RemoteException(header.getExceptionClassName(), header.getErrorMsg()); default: break; } if (totalLen != responseWrapper.getLength()) { throw new SaslException("Received malformed response length"); } if (header.getCallId() != AuthProtocol.SASL.callId) { throw new SaslException("Non-SASL response during negotiation"); } RpcSaslProto saslMessage = RpcSaslProto.parseFrom(responseWrapper.getMessageBytes()); if (LOG.isDebugEnabled()) { LOG.debug("Received SASL message "+saslMessage); } // handle sasl negotiation process RpcSaslProto.Builder response = null; switch (saslMessage.getState()) { case NEGOTIATE: { // create a compatible SASL client, throws if no supported auths SaslAuth saslAuthType = selectSaslClient(saslMessage.getAuthsList()); // define auth being attempted, caller can query if connect fails authMethod = AuthMethod.valueOf(saslAuthType.getMethod()); byte[] responseToken = null; if (authMethod == AuthMethod.SIMPLE) { // switching to SIMPLE done = true; // not going to wait for success ack } else { byte[] challengeToken = null; if (saslAuthType.hasChallenge()) { // server provided the first challenge challengeToken = saslAuthType.getChallenge().toByteArray(); saslAuthType = SaslAuth.newBuilder(saslAuthType).clearChallenge().build(); } else if (saslClient.hasInitialResponse()) { challengeToken = new byte[0]; } responseToken = (challengeToken != null) ? saslClient.evaluateChallenge(challengeToken) : new byte[0]; } response = createSaslReply(SaslState.INITIATE, responseToken); response.addAuths(saslAuthType); break; } case CHALLENGE: { if (saslClient == null) { // should probably instantiate a client to allow a server to // demand a specific negotiation throw new SaslException("Server sent unsolicited challenge"); } byte[] responseToken = saslEvaluateToken(saslMessage, false); response = createSaslReply(SaslState.RESPONSE, responseToken); break; } case SUCCESS: { // simple server sends immediate success to a SASL client for // switch to simple if (saslClient == null) { authMethod = AuthMethod.SIMPLE; } else { saslEvaluateToken(saslMessage, true); } done = true; break; } default: { throw new SaslException( "RPC client doesn't support SASL " + saslMessage.getState()); } } if (response != null) { sendSaslMessage(outStream, response.build()); } } while (!done); return authMethod; } private void sendSaslMessage(DataOutputStream out, RpcSaslProto message) throws IOException { if (LOG.isDebugEnabled()) { LOG.debug("Sending sasl message "+message); } RpcRequestMessageWrapper request = new RpcRequestMessageWrapper(saslHeader, message); out.writeInt(request.getLength()); request.write(out); out.flush(); } /** * Evaluate the server provided challenge. The server must send a token * if it's not done. If the server is done, the challenge token is * optional because not all mechanisms send a final token for the client to * update its internal state. The client must also be done after * evaluating the optional token to ensure a malicious server doesn't * prematurely end the negotiation with a phony success. * * @param saslResponse - client response to challenge * @param serverIsDone - server negotiation state * @throws SaslException - any problems with negotiation */ private byte[] saslEvaluateToken(RpcSaslProto saslResponse, boolean serverIsDone) throws SaslException { byte[] saslToken = null; if (saslResponse.hasToken()) { saslToken = saslResponse.getToken().toByteArray(); saslToken = saslClient.evaluateChallenge(saslToken); } else if (!serverIsDone) { // the server may only omit a token when it's done throw new SaslException("Server challenge contains no token"); } if (serverIsDone) { // server tried to report success before our client completed if (!saslClient.isComplete()) { throw new SaslException("Client is out of sync with server"); } // a client cannot generate a response to a success message if (saslToken != null) { throw new SaslException("Client generated spurious response"); } } return saslToken; } private RpcSaslProto.Builder createSaslReply(SaslState state, byte[] responseToken) { RpcSaslProto.Builder response = RpcSaslProto.newBuilder(); response.setState(state); if (responseToken != null) { response.setToken(ByteString.copyFrom(responseToken)); } return response; } private boolean useWrap() { // getNegotiatedProperty throws if client isn't complete String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP); // SASL wrapping is only used if the connection has a QOP, and // the value is not auth. ex. auth-int & auth-priv return qop != null && !"auth".equalsIgnoreCase(qop); } /** * Get SASL wrapped InputStream if SASL QoP requires unwrapping, * otherwise return original stream. Can be called only after * saslConnect() has been called. * * @param in - InputStream used to make the connection * @return InputStream that may be using SASL unwrap * @throws IOException */ public InputStream getInputStream(InputStream in) throws IOException { if (useWrap()) { in = new WrappedInputStream(in); } return in; } /** * Get SASL wrapped OutputStream if SASL QoP requires wrapping, * otherwise return original stream. Can be called only after * saslConnect() has been called. * * @param in - InputStream used to make the connection * @return InputStream that may be using SASL unwrap * @throws IOException */ public OutputStream getOutputStream(OutputStream out) throws IOException { if (useWrap()) { // the client and server negotiate a maximum buffer size that can be // wrapped String maxBuf = (String)saslClient.getNegotiatedProperty(Sasl.RAW_SEND_SIZE); out = new BufferedOutputStream(new WrappedOutputStream(out), Integer.parseInt(maxBuf)); } return out; } // ideally this should be folded into the RPC decoding loop but it's // currently split across Client and SaslRpcClient... class WrappedInputStream extends FilterInputStream { private ByteBuffer unwrappedRpcBuffer = ByteBuffer.allocate(0); public WrappedInputStream(InputStream in) throws IOException { super(in); } @Override public int read() throws IOException { byte[] b = new byte[1]; int n = read(b, 0, 1); return (n != -1) ? b[0] : -1; } @Override public int read(byte b[]) throws IOException { return read(b, 0, b.length); } @Override public int read(byte[] buf, int off, int len) throws IOException { synchronized(unwrappedRpcBuffer) { // fill the buffer with the next RPC message if (unwrappedRpcBuffer.remaining() == 0) { readNextRpcPacket(); } // satisfy as much of the request as possible int readLen = Math.min(len, unwrappedRpcBuffer.remaining()); unwrappedRpcBuffer.get(buf, off, readLen); return readLen; } } // all messages must be RPC SASL wrapped, else an exception is thrown private void readNextRpcPacket() throws IOException { LOG.debug("reading next wrapped RPC packet"); DataInputStream dis = new DataInputStream(in); int rpcLen = dis.readInt(); byte[] rpcBuf = new byte[rpcLen]; dis.readFully(rpcBuf); // decode the RPC header ByteArrayInputStream bis = new ByteArrayInputStream(rpcBuf); RpcResponseHeaderProto.Builder headerBuilder = RpcResponseHeaderProto.newBuilder(); headerBuilder.mergeDelimitedFrom(bis); boolean isWrapped = false; // Must be SASL wrapped, verify and decode. if (headerBuilder.getCallId() == AuthProtocol.SASL.callId) { RpcSaslProto.Builder saslMessage = RpcSaslProto.newBuilder(); saslMessage.mergeDelimitedFrom(bis); if (saslMessage.getState() == SaslState.WRAP) { isWrapped = true; byte[] token = saslMessage.getToken().toByteArray(); if (LOG.isDebugEnabled()) { LOG.debug("unwrapping token of length:" + token.length); } token = saslClient.unwrap(token, 0, token.length); unwrappedRpcBuffer = ByteBuffer.wrap(token); } } if (!isWrapped) { throw new SaslException("Server sent non-wrapped response"); } } } class WrappedOutputStream extends FilterOutputStream { public WrappedOutputStream(OutputStream out) throws IOException { super(out); } @Override public void write(byte[] buf, int off, int len) throws IOException { if (LOG.isDebugEnabled()) { LOG.debug("wrapping token of length:" + len); } buf = saslClient.wrap(buf, off, len); RpcSaslProto saslMessage = RpcSaslProto.newBuilder() .setState(SaslState.WRAP) .setToken(ByteString.copyFrom(buf, 0, buf.length)) .build(); RpcRequestMessageWrapper request = new RpcRequestMessageWrapper(saslHeader, saslMessage); DataOutputStream dob = new DataOutputStream(out); dob.writeInt(request.getLength()); request.write(dob); } } /** Release resources used by wrapped saslClient */ public void dispose() throws SaslException { if (saslClient != null) { saslClient.dispose(); saslClient = null; } } private static class SaslClientCallbackHandler implements CallbackHandler { private final String userName; private final char[] userPassword; public SaslClientCallbackHandler(Token<? extends TokenIdentifier> token) { this.userName = SaslRpcServer.encodeIdentifier(token.getIdentifier()); this.userPassword = SaslRpcServer.encodePassword(token.getPassword()); } @Override public void handle(Callback[] callbacks) throws UnsupportedCallbackException { NameCallback nc = null; PasswordCallback pc = null; RealmCallback rc = null; for (Callback callback : callbacks) { if (callback instanceof RealmChoiceCallback) { continue; } else if (callback instanceof NameCallback) { nc = (NameCallback) callback; } else if (callback instanceof PasswordCallback) { pc = (PasswordCallback) callback; } else if (callback instanceof RealmCallback) { rc = (RealmCallback) callback; } else { throw new UnsupportedCallbackException(callback, "Unrecognized SASL client callback"); } } if (nc != null) { if (LOG.isDebugEnabled()) LOG.debug("SASL client callback: setting username: " + userName); nc.setName(userName); } if (pc != null) { if (LOG.isDebugEnabled()) LOG.debug("SASL client callback: setting userPassword"); pc.setPassword(userPassword); } if (rc != null) { if (LOG.isDebugEnabled()) LOG.debug("SASL client callback: setting realm: " + rc.getDefaultText()); rc.setText(rc.getDefaultText()); } } } }