package com.github.elazarl.multireducers; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.io.serializer.Deserializer; import org.apache.hadoop.io.serializer.Serialization; import org.apache.hadoop.io.serializer.SerializationFactory; import org.apache.hadoop.io.serializer.Serializer; import org.apache.hadoop.util.ReflectionUtils; import java.io.*; import java.util.Collection; /** * MultiSerializer serializes PerMapperOutputValue and PerMapperOutputKey */ public class MultiSerializer implements Serialization<PerInternalMapper>, Configurable { @Override public boolean accept(Class c) { return c.equals(PerMapperOutputKey.class) || c.equals(PerMapperOutputValue.class); } @Override public Serializer<PerInternalMapper> getSerializer(Class c) { Class[] serClasses = (c.equals(PerMapperOutputKey.class)) ? mapperOutputKeyClasses : mapperOutputValueClasses; return new PerInternalMapperSerializer(factory, serClasses); } @Override public Deserializer<PerInternalMapper> getDeserializer(final Class c) { Class[] serClasses = (c.equals(PerMapperOutputKey.class)) ? mapperOutputKeyClasses : mapperOutputValueClasses; return new PerInternalMapperDeserializer(factory, serClasses, c, conf); } @Override public void setConf(Configuration conf) { this.conf = conf; mapperOutputKeyClasses = conf.getClasses(MultiReducer.INPUT_KEY_CLASSES); mapperOutputValueClasses = conf.getClasses(MultiReducer.INPUT_VALUE_CLASSES); factory = new SerializationFactory(removeMyClass(conf)); } private Configuration removeMyClass(Configuration conf) { Configuration withoutMe = new Configuration(conf); Collection<String> classes = withoutMe.getStringCollection(CommonConfigurationKeys.IO_SERIALIZATIONS_KEY); classes.remove(MultiSerializer.class.getName()); withoutMe.setStrings(CommonConfigurationKeys.IO_SERIALIZATIONS_KEY, classes.toArray(new String[classes.size()])); return withoutMe; } @Override public Configuration getConf() { return conf; } private Configuration conf; Class[] mapperOutputKeyClasses; Class[] mapperOutputValueClasses; SerializationFactory factory; private static class PerInternalMapperSerializer implements Serializer<PerInternalMapper> { private final Serializer[] serializers; private OutputStream nopClose; private DataOutputStream dataOut; private Serializer[] serializerUsed; public PerInternalMapperSerializer(SerializationFactory factory, Class[] serClasses) { this.serializers = new Serializer[serClasses.length]; for (int i = 0; i < serializers.length; i++) { serializers[i] = factory.getSerializer(serClasses[i]); } serializerUsed = new Serializer[serializers.length]; } @Override public void open(OutputStream out) throws IOException { this.nopClose = new NopCloseOutputStream(out); if (out instanceof DataOutputStream) { dataOut = (DataOutputStream) out; } else { dataOut = new DataOutputStream(out); } } @SuppressWarnings("unchecked") @Override public void serialize(PerInternalMapper perInternalMapper) throws IOException { WritableUtils.writeVInt(dataOut, perInternalMapper.targetReducer); getSerializer(perInternalMapper.targetReducer).serialize(perInternalMapper.data); } private Serializer getSerializer(int i) throws IOException { if (serializerUsed[i] == null) { serializerUsed[i] = serializers[i]; serializerUsed[i].open(nopClose); } return serializerUsed[i]; } @Override public void close() throws IOException { for (Serializer serializer : serializerUsed) { if (serializer != null) { serializer.close(); } } dataOut.close(); } } private static class PerInternalMapperDeserializer implements Deserializer<PerInternalMapper> { private final Deserializer[] deserializers; private final Class[] serClasses; private Class c; private Configuration conf; InputStream nopClose; DataInputStream dataIn; private Deserializer[] deserializersUsed; public PerInternalMapperDeserializer(SerializationFactory factory, Class[] serClasses, Class c, Configuration conf) { this.serClasses = serClasses; this.deserializers = new Deserializer[serClasses.length]; for (int i = 0; i < deserializers.length; i++) { deserializers[i] = factory.getDeserializer(serClasses[i]); } this.c = c; this.conf = conf; deserializersUsed = new Deserializer[deserializers.length]; } @Override public void open(InputStream in) throws IOException { nopClose = new NopCloseInputStream(in); if (in instanceof DataInputStream) { dataIn = (DataInputStream) in; } else { dataIn = new DataInputStream(in); } } @SuppressWarnings("unchecked") @Override public PerInternalMapper deserialize(PerInternalMapper m) throws IOException { if (m == null) { m = (PerInternalMapper) ReflectionUtils.newInstance(c, conf); } m.targetReducer = WritableUtils.readVInt(dataIn); if (!serClasses[m.targetReducer].isInstance(m.data)) { // TODO: cache perInternalMapper per targetReducer m.data = ReflectionUtils.newInstance(serClasses[m.targetReducer], conf); } m.data = getDesiralizer(m.targetReducer).deserialize(m.data); return m; } @Override public void close() throws IOException { for (Deserializer deserializer : deserializersUsed) { if (deserializer != null) { deserializer.close(); } } dataIn.close(); } private Deserializer getDesiralizer(int i) throws IOException { if (deserializersUsed[i] == null) { deserializersUsed[i] = deserializers[i]; deserializersUsed[i].open(nopClose); } return deserializersUsed[i]; } } }