package org.infinispan.client.hotrod.impl.transport.tcp; import java.io.BufferedInputStream; import java.io.DataInputStream; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; /** * SaslInputStream. * * @author Tristan Tarrant * @since 7.0 */ public class SaslInputStream extends InputStream { private final SaslClient saslClient; private final DataInputStream inStream; private byte[] buffer; private int bufferPtr = 0; private int bufferLength = 0; public SaslInputStream(InputStream inStream, SaslClient saslClient) { this.inStream = new DataInputStream(new BufferedInputStream(inStream)); this.saslClient = saslClient; } @Override public int read() throws IOException { if (readBuffer()) return -1; return ((int) buffer[bufferPtr++] & 0xff); } @Override public int read(byte[] b) throws IOException { return read(b, 0, b.length); } @Override public int read(byte[] b, int off, int len) throws IOException { if ( readBuffer()) return -1; if (len <= 0) { return 0; } int available = bufferLength - bufferPtr; if (len < available) available = len; if (b != null) { System.arraycopy(buffer, bufferPtr, b, off, available); } bufferPtr = bufferPtr + available; return available; } private boolean readBuffer() throws IOException { if (bufferPtr >= bufferLength) { int i = 0; while (i == 0) i = fillBuffer(); if (i == -1) return true; } return false; } @Override public long skip(long n) throws IOException { int available = bufferLength - bufferPtr; if (n > available) { n = available; } if (n < 0) { return 0; } bufferPtr += n; return n; } @Override public int available() throws IOException { return (bufferLength - bufferPtr); } @Override public void close() throws IOException { disposeSasl(); bufferPtr = 0; bufferLength = 0; inStream.close(); } @Override public boolean markSupported() { return false; } private int fillBuffer() throws IOException { byte[] saslToken; try { int length = inStream.readInt(); saslToken = new byte[length]; inStream.readFully(saslToken); buffer = saslClient.unwrap(saslToken, 0, length); } catch (EOFException e) { return -1; } catch (SaslException se) { try { disposeSasl(); } catch (SaslException e) { } throw se; } bufferPtr = 0; bufferLength = buffer.length; return bufferLength; } private void disposeSasl() throws SaslException { if (saslClient != null) { saslClient.dispose(); } } }