/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
package com.facebook.crypto.streams;
import com.facebook.crypto.mac.NativeMac;
import java.io.IOException;
import java.io.InputStream;
/**
 * An {@link InputStream} that feeds every byte it returns into a {@link NativeMac} and verifies
 * the result once the data has been consumed.
 *
 * <p>The wrapped stream is placed inside a {@link TailInputStream} sized to
 * {@code mac.getMacLength()}. The tail captured by that wrapper (presumably the trailing MAC bytes
 * of the input) is compared in constant time against the MAC computed over the data returned to the
 * caller. The comparison happens when a read reaches end of stream, or on {@link #close()} if no
 * comparison has happened yet. A mismatch is reported as an {@link IOException}. Callers should
 * therefore read to end of stream before trusting any of the data.
 *
 * <p>A failed verification is sticky: every later read that reaches end of stream throws again,
 * instead of reporting a clean end of stream.
 */
public class NativeMacLayeredInputStream extends InputStream {
  private static final String MAC_DOES_NOT_MATCH = "Mac does not match";

  /** Chunk size used by {@link #skip(long)} when it reads and discards data through the MAC. */
  private static final int SKIP_BUFFER_SIZE = 4096;

  private final NativeMac mMac;
  private final TailInputStream mInputDelegate;

  /** True once the MAC has been finalized; the MAC is only ever finalized once. */
  private boolean mMacChecked = false;

  /** Outcome of the one-time comparison; stays false if finalization itself threw. */
  private boolean mMacValid = false;

  /**
   * Creates a new input stream to read from.
   *
   * @param mac The object used to compute the mac.
   * @param inputDelegate The stream to read the data from.
   */
  public NativeMacLayeredInputStream(NativeMac mac, InputStream inputDelegate) {
    mMac = mac;
    mInputDelegate = new TailInputStream(inputDelegate, mac.getMacLength());
  }

  @Override
  public int available() throws IOException {
    return mInputDelegate.available();
  }

  /**
   * Verifies the MAC if it has not been verified yet, then closes the delegate.
   *
   * <p>The delegate is closed even if verification fails. Closing before all data has been read
   * compares a MAC over only part of the data, so it will generally fail.
   *
   * @throws IOException if the (first) verification fails, or if closing the delegate fails.
   */
  @Override
  public void close() throws IOException {
    try {
      // A failure from an earlier end-of-stream read has already been reported to the caller;
      // only verify here if no verification has been attempted yet.
      if (!mMacChecked) {
        ensureMacValid();
      }
    } finally {
      mInputDelegate.close();
    }
  }

  @Override
  public void mark(int readlimit) {
    throw new UnsupportedOperationException();
  }

  @Override
  public boolean markSupported() {
    return false;
  }

  @Override
  public int read() throws IOException {
    byte[] buffer = new byte[1];
    int read;
    // The multi-byte read may return 0 without being at end of stream; keep trying until a byte
    // or end of stream is produced.
    do {
      read = read(buffer, 0, 1);
    } while (read == 0);
    return read == -1 ? -1 : buffer[0] & 0xFF;
  }

  @Override
  public int read(byte[] buffer) throws IOException {
    return read(buffer, 0, buffer.length);
  }

  /**
   * Reads data from the delegate and feeds every returned byte into the MAC.
   *
   * @return the number of bytes read, or -1 at end of stream.
   * @throws IOException if the delegate fails, or if end of stream is reached and the MAC does
   *     not verify (including on every end-of-stream read after a failed verification).
   */
  @Override
  public int read(byte[] buffer, int offset, int length) throws IOException {
    int read = mInputDelegate.read(buffer, offset, length);
    if (read == -1) {
      ensureMacValid();
      return -1;
    }
    if (read > 0) {
      mMac.update(buffer, offset, read);
    }
    return read;
  }

  /**
   * Finalizes the MAC on the first call and compares it with the delegate's tail. Throws on
   * this and every later call if the comparison failed or finalization threw.
   */
  private void ensureMacValid() throws IOException {
    if (!mMacChecked) {
      // Mark as checked before finalizing: the native MAC is destroyed below and must never be
      // finalized twice. mMacValid stays false if doFinal()/getTail() throws (fail closed).
      mMacChecked = true;
      try {
        byte[] mac = mMac.doFinal();
        mMacValid = constantTimeEquals(mInputDelegate.getTail(), mac);
      } finally {
        mMac.destroy();
      }
    }
    if (!mMacValid) {
      throw new IOException(MAC_DOES_NOT_MATCH);
    }
  }

  @Override
  public synchronized void reset() throws IOException {
    throw new UnsupportedOperationException();
  }

  /**
   * Skips up to {@code byteCount} bytes by reading and discarding them.
   *
   * <p>Skipped bytes are still passed through {@link #read(byte[], int, int)}, so they are
   * included in the MAC and the stream remains verifiable.
   *
   * @return the number of bytes actually skipped; 0 if {@code byteCount <= 0}.
   * @throws IOException if reading fails, or if end of stream is reached and the MAC does not
   *     verify.
   */
  @Override
  public long skip(long byteCount) throws IOException {
    if (byteCount <= 0) {
      return 0;
    }
    byte[] scratch = new byte[(int) Math.min(SKIP_BUFFER_SIZE, byteCount)];
    long remaining = byteCount;
    while (remaining > 0) {
      int read = read(scratch, 0, (int) Math.min(scratch.length, remaining));
      if (read == -1) {
        break;
      }
      remaining -= read;
    }
    return byteCount - remaining;
  }

  /**
   * Compares two byte arrays without short-circuiting on the first differing byte, so timing
   * does not reveal where they differ (the length check is not constant time).
   */
  private static boolean constantTimeEquals(byte[] a, byte[] b) {
    if (a.length != b.length) {
      return false;
    }
    int diff = 0;
    for (int i = 0; i < a.length; i++) {
      diff |= a[i] ^ b[i];
    }
    return diff == 0;
  }
}