/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.rs.security.jose.jwe;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.logging.Logger;
import javax.crypto.Cipher;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.common.util.Base64UrlUtility;
public class JweOutputStream extends FilterOutputStream {
protected static final Logger LOG = LogUtils.getL7dLogger(JweOutputStream.class);
private Cipher encryptingCipher;
private int blockSize;
private AuthenticationTagProducer authTagProducer;
private byte[] lastRawDataChunk;
private byte[] lastEncryptedDataChunk;
private boolean flushed;
public JweOutputStream(OutputStream out,
Cipher encryptingCipher,
AuthenticationTagProducer authTagProducer) {
super(out);
this.encryptingCipher = encryptingCipher;
this.blockSize = encryptingCipher.getBlockSize();
this.authTagProducer = authTagProducer;
}
@Override
public void write(int value) throws IOException {
byte[] bytes = ByteBuffer.allocate(Integer.SIZE / 8).putInt(value).array();
write(bytes, 0, bytes.length);
}
@Override
public void write(byte b[], int off, int len) throws IOException {
if (lastRawDataChunk != null) {
int remaining = blockSize - lastRawDataChunk.length;
int lenToCopy = remaining < len ? remaining : len;
lastRawDataChunk = newArray(lastRawDataChunk, 0, lastRawDataChunk.length, b, off, lenToCopy);
off = off + lenToCopy;
len -= lenToCopy;
if (lastRawDataChunk.length < blockSize) {
return;
} else {
encryptAndWrite(lastRawDataChunk, 0, lastRawDataChunk.length);
lastRawDataChunk = null;
}
}
int offset = 0;
int chunkSize = blockSize > len ? blockSize : blockSize * (len / blockSize);
for (; offset + chunkSize <= len; offset += chunkSize, off += chunkSize) {
encryptAndWrite(b, off, chunkSize);
}
if (offset < len) {
lastRawDataChunk = newArray(b, off, len - offset);
}
}
private void encryptAndWrite(byte[] chunk, int off, int len) throws IOException {
byte[] encrypted = encryptingCipher.update(chunk, off, len);
if (authTagProducer != null) {
authTagProducer.update(encrypted, 0, encrypted.length);
}
encodeAndWrite(encrypted, 0, encrypted.length, false);
}
private void encodeAndWrite(byte[] encryptedChunk, int off, int len, boolean finalWrite) throws IOException {
byte[] theChunk = lastEncryptedDataChunk;
int lenToEncode = len;
if (theChunk != null) {
theChunk = newArray(theChunk, 0, theChunk.length, encryptedChunk, off, len);
lenToEncode = theChunk.length;
off = 0;
} else {
theChunk = encryptedChunk;
}
int rem = finalWrite ? 0 : lenToEncode % 3;
Base64UrlUtility.encodeAndStream(theChunk, off, lenToEncode - rem, out);
out.flush();
if (rem > 0) {
lastEncryptedDataChunk = newArray(theChunk, lenToEncode - rem, rem);
} else {
lastEncryptedDataChunk = null;
}
}
public void finalFlush() throws IOException {
if (flushed) {
return;
}
try {
byte[] finalBytes = lastRawDataChunk == null
? encryptingCipher.doFinal()
: encryptingCipher.doFinal(lastRawDataChunk, 0, lastRawDataChunk.length);
final int authTagLengthBits = 128;
if (authTagProducer != null) {
authTagProducer.update(finalBytes, 0, finalBytes.length);
encodeAndWrite(finalBytes, 0, finalBytes.length, true);
} else {
encodeAndWrite(finalBytes, 0, finalBytes.length - authTagLengthBits / 8, true);
}
out.write(new byte[]{'.'});
if (authTagProducer == null) {
encodeAndWrite(finalBytes, finalBytes.length - authTagLengthBits / 8, authTagLengthBits / 8, true);
} else {
byte[] authTag = authTagProducer.getTag();
encodeAndWrite(authTag, 0, authTagLengthBits / 8, true);
}
} catch (Exception ex) {
LOG.warning("Content encryption failure");
throw new JweException(JweException.Error.CONTENT_ENCRYPTION_FAILURE, ex);
}
flushed = true;
}
private byte[] newArray(byte[] src, int srcPos, int srcLen) {
byte[] buf = new byte[srcLen];
System.arraycopy(src, srcPos, buf, 0, srcLen);
return buf;
}
private byte[] newArray(byte[] src, int srcPos, int srcLen, byte[] src2, int srcPos2, int srcLen2) {
byte[] buf = new byte[srcLen + srcLen2];
System.arraycopy(src, srcPos, buf, 0, srcLen);
System.arraycopy(src2, srcPos2, buf, srcLen, srcLen2);
return buf;
}
}