/** * diqube: Distributed Query Base. * * Copyright (C) 2015 Bastian Gloeckle * * This file is part of diqube. * * diqube is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package org.diqube.remote.cluster; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import org.diqube.data.column.ColumnType; import org.diqube.function.IntermediaryResult; import org.diqube.function.aggregate.result.IntermediaryResultValueIterator; import org.diqube.function.aggregate.result.serialization.IntermediateResultSerialization; import org.diqube.function.aggregate.result.serialization.IntermediateResultSerializationResolver; import org.diqube.remote.cluster.thrift.RColumnType; import org.diqube.remote.cluster.thrift.RIntermediateAggregationResult; import org.diqube.remote.cluster.thrift.RIntermediateAggregationResultValue; import org.diqube.thrift.base.thrift.RValue; import org.diqube.thrift.base.util.RValueUtil; import org.diqube.util.SafeObjectInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.collect.ImmutableSet; import com.google.common.reflect.ClassPath; import com.google.common.reflect.ClassPath.ClassInfo; /** * Util for {@link RIntermediateAggregationResult}. * * Serialization/deserialization adheres to {@link IntermediateResultSerialization}. * * @author Bastian Gloeckle */ public class RIntermediateAggregationResultUtil { private static final String ROOT_PKG = "org.diqube"; private static final Logger logger = LoggerFactory.getLogger(RIntermediateAggregationResultUtil.class); private volatile static Set<String> whitelistedSerializableClassNames = null; /** * Deserialize a {@link RIntermediateAggregationResult} to a {@link IntermediaryResult}. * * @throws IllegalArgumentException * if data cannot be deserialized. */ public static IntermediaryResult buildIntermediateAggregationResult(RIntermediateAggregationResult input) throws IllegalArgumentException { if (whitelistedSerializableClassNames == null) initialize(); ColumnType type = null; if (input.isSetInputColumnType()) { switch (input.getInputColumnType()) { case LONG: type = ColumnType.LONG; break; case DOUBLE: type = ColumnType.DOUBLE; break; default: type = ColumnType.STRING; break; } } IntermediaryResult res = new IntermediaryResult(input.getOutputColName(), type); for (RIntermediateAggregationResultValue val : input.getValues()) { if (val.isSetValue()) { res.pushValue(RValueUtil.createValue(val.getValue())); } else { byte[] serialized = val.getSerialized(); try (ByteArrayInputStream bais = new ByteArrayInputStream(serialized)) { try (ObjectInputStream ois = new SafeObjectInputStream(bais, whitelistedSerializableClassNames)) { res.pushValue(ois.readObject()); } } catch (IOException | ClassNotFoundException e) { logger.error("Could not deserialize intermediate result", e); throw new IllegalArgumentException("Could not deserialize intermediate result", e); } } } return res; } /** * Serialize a {@link IntermediaryResult}. * * @throws IllegalArgumentException * If cannot be serialized */ public static RIntermediateAggregationResult buildRIntermediateAggregationResult(IntermediaryResult input) throws IllegalArgumentException { if (whitelistedSerializableClassNames == null) initialize(); RIntermediateAggregationResult res = new RIntermediateAggregationResult(); res.setOutputColName(input.getOutputColName()); if (input.getInputColumnType() != null) { switch (input.getInputColumnType()) { case STRING: res.setInputColumnType(RColumnType.STRING); break; case LONG: res.setInputColumnType(RColumnType.LONG); break; case DOUBLE: res.setInputColumnType(RColumnType.DOUBLE); break; } } List<RIntermediateAggregationResultValue> values = new ArrayList<>(); IntermediaryResultValueIterator it = input.createValueIterator(); while (it.hasNext()) { Object valueObject = it.next(); RIntermediateAggregationResultValue resValue = new RIntermediateAggregationResultValue(); RValue rvalue = RValueUtil.createRValue(valueObject); if (rvalue != null) { resValue.setValue(rvalue); } else { if (!whitelistedSerializableClassNames.contains(valueObject.getClass().getName())) // only a shallow check, but better than no check at all. throw new IllegalArgumentException("Class " + valueObject.getClass().getName() + " is not whitelisted."); try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { oos.writeObject(valueObject); } resValue.setSerialized(baos.toByteArray()); } catch (IOException e) { logger.error("Could not serialize intermediary result", e); throw new IllegalArgumentException("Could not serialize intermediary result", e); } } values.add(resValue); } res.setValues(values); return res; } private synchronized static void initialize() { if (whitelistedSerializableClassNames != null) return; ClassPath cp; try { cp = ClassPath.from(RIntermediateAggregationResultUtil.class.getClassLoader()); } catch (IOException e) { throw new RuntimeException("Could not initialize classpath scanning!", e); } ImmutableSet<ClassInfo> classInfos = cp.getTopLevelClassesRecursive(ROOT_PKG); whitelistedSerializableClassNames = new HashSet<>(); for (ClassInfo classInfo : classInfos) { Class<?> clazz = classInfo.load(); if (clazz.getAnnotation(IntermediateResultSerialization.class) != null) { if (!IntermediateResultSerializationResolver.class.isAssignableFrom(clazz)) { logger.warn("Class {} has {} annotation, but does not implement {}. Ignoring.", clazz.getName(), IntermediateResultSerialization.class.getSimpleName(), IntermediateResultSerializationResolver.class.getName()); continue; } try { IntermediateResultSerializationResolver resolver = (IntermediateResultSerializationResolver) clazz.newInstance(); resolver.resolve(cls -> { whitelistedSerializableClassNames.add(cls.getName()); logger.debug("Whitelisted class {} for being de-/serialized for intermediate aggregation results", cls); }); } catch (InstantiationException | IllegalAccessException e) { logger.warn("Could not instantiate {}. Ignoring.", clazz.getName(), e); } } } } }