package ch.ge.ve.commons.streamutils;
/*-
* #%L
* Common crypto utilities
* %%
* Copyright (C) 2015 - 2016 République et Canton de Genève
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
* #L%
*/
import java.io.*;
import java.util.List;
/**
* This class provides for safer object deserialization, by limiting length of input, number and type of objects read.
*/
public class SafeObjectReader {
private SafeObjectReader() {
// utility class, do not allow to instantiate it
}
/**
* This method should be used to replace unsafe calls to ObjectInputStream.readObject() built into Java.
* It checks that only allowed classes are read, that the number of objects and bytes read stay within the given parameters.
* Also, it casts the read object to the expected type.
*
* @param expectedType Class of the expected object
* @param safeClasses The list of Classes allowed to be read (on top of primitive arrays, numbers and Strings which are always considered safe)
* @param maxObjects The maximum number of objects allowed to be read
* @param maxBytes The maximum number of bytes allowed to be read
* @param in The InputStream containing an object from an untrusted source
* @param <T> The type the object will be cast to.
* @return the object read from the stream, cast to the type parameter
* @throws IOException
* @throws ClassNotFoundException
*/
@SuppressWarnings("unchecked")
public static <T> T safeReadObject(final Class<? extends T> expectedType, final List<Class<?>> safeClasses, final long maxObjects, final long maxBytes, InputStream in) throws IOException, ClassNotFoundException {
// Create a FilterInputStream that checks the length of the input as it is being read.
InputStream fis = new LimitedLengthFilterInputStream(in, maxBytes);
ObjectInputStream ois = new SafeObjectInputStream(fis, maxObjects, expectedType, safeClasses);
return (T) ois.readObject();
}
/**
* This class limits the allowed length for reading objects
*/
private static class LimitedLengthFilterInputStream extends FilterInputStream {
private final long maxBytes;
private long length;
public LimitedLengthFilterInputStream(InputStream in, long maxBytes) {
super(in);
this.maxBytes = maxBytes;
length = 0;
}
@Override
public int read() throws IOException {
int val = super.read();
if (val != -1) {
length++;
checkLength();
}
return val;
}
@Override
public int read(byte[] b, int off, int length) throws IOException {
int val = super.read(b, off, length);
if (val > 0) {
this.length += val;
checkLength();
}
return val;
}
private void checkLength() {
if (length > maxBytes) {
throw new SafeObjectSecurityRuntimeException("Security violation: attempt to deserialize too many bytes from stream. Limit is " + maxBytes);
}
}
}
/**
* This specialized ObjectInputStream prevents too many objects from being unserialized, as well as filtering the
* types of objects allowed.
*/
private static class SafeObjectInputStream<T> extends ObjectInputStream {
private final long maxObjects;
private final Class<? extends T> type;
private final List<Class<?>> safeClasses;
boolean shouldResolveObjects;
private int objectCount;
public SafeObjectInputStream(InputStream fis, long maxObjects, Class<? extends T> type, List<Class<?>> safeClasses) throws IOException {
super(fis);
this.maxObjects = maxObjects;
this.type = type;
this.safeClasses = safeClasses;
shouldResolveObjects = enableResolveObject(true);
objectCount = 0;
}
@Override
protected Object resolveObject(Object obj) throws IOException {
if (objectCount++ > maxObjects) {
throw new SafeObjectSecurityRuntimeException("Security violation: attempt to deserialize too many objects from stream. Limit is " + maxObjects);
}
return super.resolveObject(obj);
}
@Override
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
Class<?> clazz = super.resolveClass(desc);
if (isSafeClass(clazz)) {
return clazz;
} else {
throw new SafeObjectSecurityRuntimeException("Security violation: attempt to deserialize unauthorized " + clazz);
}
}
private boolean isSafeClass(Class<?> clazz) {
return clazz.isArray() ||
clazz.equals(type) ||
clazz.equals(String.class) ||
Number.class.isAssignableFrom(clazz) ||
safeClasses.contains(clazz);
}
}
}