/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.security; import java.io.DataInputStream; import java.io.EOFException; import java.io.InputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import javax.security.sasl.Sasl; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceStability; /** * A SaslInputStream is composed of an InputStream and a SaslServer (or * SaslClient) so that read() methods return data that are read in from the * underlying InputStream but have been additionally processed by the SaslServer * (or SaslClient) object. The SaslServer (or SaslClient) object must be fully * initialized before being used by a SaslInputStream. */ @InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"}) @InterfaceStability.Evolving public class SaslInputStream extends InputStream implements ReadableByteChannel { public static final Log LOG = LogFactory.getLog(SaslInputStream.class); private final DataInputStream inStream; /** Should we wrap the communication channel? */ private final boolean useWrap; /* * data read from the underlying input stream before being processed by SASL */ private byte[] saslToken; private final SaslClient saslClient; private final SaslServer saslServer; private byte[] lengthBuf = new byte[4]; /* * buffer holding data that have been processed by SASL, but have not been * read out */ private byte[] obuffer; // position of the next "new" byte private int ostart = 0; // position of the last "new" byte private int ofinish = 0; // whether or not this stream is open private boolean isOpen = true; private static int unsignedBytesToInt(byte[] buf) { if (buf.length != 4) { throw new IllegalArgumentException( "Cannot handle byte array other than 4 bytes"); } int result = 0; for (int i = 0; i < 4; i++) { result <<= 8; result |= ((int) buf[i] & 0xff); } return result; } /** * Read more data and get them processed <br> * Entry condition: ostart = ofinish <br> * Exit condition: ostart <= ofinish <br> * * return (ofinish-ostart) (we have this many bytes for you), 0 (no data now, * but could have more later), or -1 (absolutely no more data) */ private int readMoreData() throws IOException { try { inStream.readFully(lengthBuf); int length = unsignedBytesToInt(lengthBuf); if (LOG.isDebugEnabled()) LOG.debug("Actual length is " + length); saslToken = new byte[length]; inStream.readFully(saslToken); } catch (EOFException e) { return -1; } try { if (saslServer != null) { // using saslServer obuffer = saslServer.unwrap(saslToken, 0, saslToken.length); } else { // using saslClient obuffer = saslClient.unwrap(saslToken, 0, saslToken.length); } } catch (SaslException se) { try { disposeSasl(); } catch (SaslException ignored) { } throw se; } ostart = 0; if (obuffer == null) ofinish = 0; else ofinish = obuffer.length; return ofinish; } /** * Disposes of any system resources or security-sensitive information Sasl * might be using. * * @exception SaslException * if a SASL error occurs. */ private void disposeSasl() throws SaslException { if (saslClient != null) { saslClient.dispose(); } if (saslServer != null) { saslServer.dispose(); } } /** * Constructs a SASLInputStream from an InputStream and a SaslServer <br> * Note: if the specified InputStream or SaslServer is null, a * NullPointerException may be thrown later when they are used. * * @param inStream * the InputStream to be processed * @param saslServer * an initialized SaslServer object */ public SaslInputStream(InputStream inStream, SaslServer saslServer) { this.inStream = new DataInputStream(inStream); this.saslServer = saslServer; this.saslClient = null; String qop = (String) saslServer.getNegotiatedProperty(Sasl.QOP); this.useWrap = qop != null && !"auth".equalsIgnoreCase(qop); } /** * Constructs a SASLInputStream from an InputStream and a SaslClient <br> * Note: if the specified InputStream or SaslClient is null, a * NullPointerException may be thrown later when they are used. * * @param inStream * the InputStream to be processed * @param saslClient * an initialized SaslClient object */ public SaslInputStream(InputStream inStream, SaslClient saslClient) { this.inStream = new DataInputStream(inStream); this.saslServer = null; this.saslClient = saslClient; String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP); this.useWrap = qop != null && !"auth".equalsIgnoreCase(qop); } /** * Reads the next byte of data from this input stream. The value byte is * returned as an <code>int</code> in the range <code>0</code> to * <code>255</code>. If no byte is available because the end of the stream has * been reached, the value <code>-1</code> is returned. This method blocks * until input data is available, the end of the stream is detected, or an * exception is thrown. * <p> * * @return the next byte of data, or <code>-1</code> if the end of the stream * is reached. * @exception IOException * if an I/O error occurs. */ @Override public int read() throws IOException { if (!useWrap) { return inStream.read(); } if (ostart >= ofinish) { // we loop for new data as we are blocking int i = 0; while (i == 0) i = readMoreData(); if (i == -1) return -1; } return ((int) obuffer[ostart++] & 0xff); } /** * Reads up to <code>b.length</code> bytes of data from this input stream into * an array of bytes. * <p> * The <code>read</code> method of <code>InputStream</code> calls the * <code>read</code> method of three arguments with the arguments * <code>b</code>, <code>0</code>, and <code>b.length</code>. * * @param b * the buffer into which the data is read. * @return the total number of bytes read into the buffer, or <code>-1</code> * is there is no more data because the end of the stream has been * reached. * @exception IOException * if an I/O error occurs. */ @Override public int read(byte[] b) throws IOException { return read(b, 0, b.length); } /** * Reads up to <code>len</code> bytes of data from this input stream into an * array of bytes. This method blocks until some input is available. If the * first argument is <code>null,</code> up to <code>len</code> bytes are read * and discarded. * * @param b * the buffer into which the data is read. * @param off * the start offset of the data. * @param len * the maximum number of bytes read. * @return the total number of bytes read into the buffer, or <code>-1</code> * if there is no more data because the end of the stream has been * reached. * @exception IOException * if an I/O error occurs. */ @Override public int read(byte[] b, int off, int len) throws IOException { if (!useWrap) { return inStream.read(b, off, len); } if (ostart >= ofinish) { // we loop for new data as we are blocking int i = 0; while (i == 0) i = readMoreData(); if (i == -1) return -1; } if (len <= 0) { return 0; } int available = ofinish - ostart; if (len < available) available = len; if (b != null) { System.arraycopy(obuffer, ostart, b, off, available); } ostart = ostart + available; return available; } /** * Skips <code>n</code> bytes of input from the bytes that can be read from * this input stream without blocking. * * <p> * Fewer bytes than requested might be skipped. The actual number of bytes * skipped is equal to <code>n</code> or the result of a call to * {@link #available() <code>available</code>}, whichever is smaller. If * <code>n</code> is less than zero, no bytes are skipped. * * <p> * The actual number of bytes skipped is returned. * * @param n * the number of bytes to be skipped. * @return the actual number of bytes skipped. * @exception IOException * if an I/O error occurs. */ @Override public long skip(long n) throws IOException { if (!useWrap) { return inStream.skip(n); } int available = ofinish - ostart; if (n > available) { n = available; } if (n < 0) { return 0; } ostart += n; return n; } /** * Returns the number of bytes that can be read from this input stream without * blocking. The <code>available</code> method of <code>InputStream</code> * returns <code>0</code>. This method <B>should</B> be overridden by * subclasses. * * @return the number of bytes that can be read from this input stream without * blocking. * @exception IOException * if an I/O error occurs. */ @Override public int available() throws IOException { if (!useWrap) { return inStream.available(); } return (ofinish - ostart); } /** * Closes this input stream and releases any system resources associated with * the stream. * <p> * The <code>close</code> method of <code>SASLInputStream</code> calls the * <code>close</code> method of its underlying input stream. * * @exception IOException * if an I/O error occurs. */ @Override public void close() throws IOException { disposeSasl(); ostart = 0; ofinish = 0; inStream.close(); isOpen = false; } /** * Tests if this input stream supports the <code>mark</code> and * <code>reset</code> methods, which it does not. * * @return <code>false</code>, since this class does not support the * <code>mark</code> and <code>reset</code> methods. */ @Override public boolean markSupported() { return false; } @Override public boolean isOpen() { return isOpen; } @Override public int read(ByteBuffer dst) throws IOException { int bytesRead = 0; if (dst.hasArray()) { bytesRead = read(dst.array(), dst.arrayOffset() + dst.position(), dst.remaining()); if (bytesRead > -1) { dst.position(dst.position() + bytesRead); } } else { byte[] buf = new byte[dst.remaining()]; bytesRead = read(buf); if (bytesRead > -1) { dst.put(buf, 0, bytesRead); } } return bytesRead; } }