/*
 * Copyright [2013-2017] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.correlation;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;

import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.OutputCommitter;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.StatusReporter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.util.ReflectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Copied from {@link org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper} with some customization: sub-mapper
 * output results are merged in this class and only the merged copy is written to the reducer.
 * 
 * @author Zhang David (pengzhang@paypal.com)
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class CorrelationMultithreadedMapper extends Mapper<LongWritable, Text, IntWritable, CorrelationWritable> {

    private final static Logger LOG = LoggerFactory.getLogger(CorrelationMultithreadedMapper.class);

    public static String NUM_THREADS = "mapreduce.mapper.multithreadedmapper.threads";
    public static String MAP_CLASS = "mapreduce.mapper.multithreadedmapper.mapclass";

    private Class<? extends Mapper<LongWritable, Text, IntWritable, CorrelationWritable>> mapClass;
    private Context outer;
    private List<MapRunner> runners;

    /**
     * Only one correlation map is kept to save memory. If each CorrelationMapper instance held its own copy, memory
     * usage would be multiplied by the number of threads, which is very large with more than 3000 variables. The
     * field is static so that it is easy to access from CorrelationMapper. This is ugly and should be changed in the
     * future.
     */
    static Map<Integer, CorrelationWritable> finalCorrelationMap = new HashMap<Integer, CorrelationWritable>();

    /**
     * Output key cache to avoid new operation.
     */
    private IntWritable outputKey;

    /**
     * Column config list to initialize {@link #finalCorrelationMap}.
     */
    private List<ColumnConfig> columnConfigList;
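
    /*
     * A minimal driver sketch for wiring this class into a job (hypothetical illustration only; the real job setup
     * lives elsewhere in Shifu, and CorrelationMapper is assumed to be the per-thread mapper implementation):
     *
     *     Job job = new Job(conf, "shifu correlation");
     *     job.setMapperClass(CorrelationMultithreadedMapper.class);
     *     CorrelationMultithreadedMapper.setMapperClass(job, CorrelationMapper.class);
     *     CorrelationMultithreadedMapper.setNumberOfThreads(job, 8);
     */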

    /**
     * The number of threads in the thread pool that will run the map function.
     * 
     * @param job
     *            the job
     * @return the number of threads
     */
    public static int getNumberOfThreads(JobContext job) {
        return job.getConfiguration().getInt(NUM_THREADS, 10);
    }

    /**
     * Set the number of threads in the pool for running maps.
     * 
     * @param job
     *            the job to modify
     * @param threads
     *            the new number of threads
     */
    public static void setNumberOfThreads(Job job, int threads) {
        job.getConfiguration().setInt(NUM_THREADS, threads);
    }

    /**
     * Get the application's mapper class.
     * 
     * @param <K1>
     *            the map's input key type
     * @param <V1>
     *            the map's input value type
     * @param <K2>
     *            the map's output key type
     * @param <V2>
     *            the map's output value type
     * @param job
     *            the job
     * @return the mapper class to run
     */
    @SuppressWarnings("unchecked")
    public static <K1, V1, K2, V2> Class<Mapper<K1, V1, K2, V2>> getMapperClass(JobContext job) {
        return (Class<Mapper<K1, V1, K2, V2>>) job.getConfiguration().getClass(MAP_CLASS, Mapper.class);
    }

    /**
     * Set the application's mapper class.
     * 
     * @param <K1>
     *            the map input key type
     * @param <V1>
     *            the map input value type
     * @param <K2>
     *            the map output key type
     * @param <V2>
     *            the map output value type
     * @param job
     *            the job to modify
     * @param cls
     *            the class to use as the mapper
     */
    public static <K1, V1, K2, V2> void setMapperClass(Job job, Class<? extends Mapper<K1, V1, K2, V2>> cls) {
        if(CorrelationMultithreadedMapper.class.isAssignableFrom(cls)) {
            throw new IllegalArgumentException("Can't have recursive MultithreadedMapper instances.");
        }
        job.getConfiguration().setClass(MAP_CLASS, cls, Mapper.class);
    }

    private void loadConfigFiles(final Context context) {
        try {
            SourceType sourceType = SourceType.valueOf(context.getConfiguration().get(
                    Constants.SHIFU_MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
            this.columnConfigList = CommonUtils.loadColumnConfigList(
                    context.getConfiguration().get(Constants.SHIFU_COLUMN_CONFIG), sourceType);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Run the application's maps using a thread pool.
     */
    @Override
    public void run(Context context) throws IOException, InterruptedException {
        outer = context;
        loadConfigFiles(context);
        // initialize each CorrelationWritable instance so that it is easy to synchronize on in CorrelationMapper
        for(int i = 0; i < this.columnConfigList.size(); i++) {
            CorrelationMultithreadedMapper.finalCorrelationMap.put(this.columnConfigList.get(i).getColumnNum(),
                    new CorrelationWritable());
        }
        int numberOfThreads = getNumberOfThreads(context);
        mapClass = getMapperClass(context);
        if(LOG.isDebugEnabled()) {
            LOG.debug("Configuring multithread runner to use " + numberOfThreads + " threads");
        }

        runners = new ArrayList<MapRunner>(numberOfThreads);
        for(int i = 0; i < numberOfThreads; ++i) {
            MapRunner thread = new MapRunner(context);
            thread.start();
            runners.add(i, thread);
        }
        for(int i = 0; i < numberOfThreads; ++i) {
            MapRunner thread = runners.get(i);
            thread.join();
            Throwable th = thread.throwable;
            if(th != null) {
                if(th instanceof IOException) {
                    throw (IOException) th;
                } else if(th instanceof InterruptedException) {
                    throw (InterruptedException) th;
                } else {
                    throw new RuntimeException(th);
                }
            }
        }

        outputKey = new IntWritable();
        // after all sub-mappers complete, finalCorrelationMap holds the merged global results; send them to the
        // reducer as a single merged copy no matter how many threads were used
        for(Entry<Integer, CorrelationWritable> entry: finalCorrelationMap.entrySet()) {
            outputKey.set(entry.getKey());
            context.write(outputKey, entry.getValue());
        }
    }
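
    /*
     * Sketch of the merge pattern this class expects from the per-thread mappers (hypothetical code; the actual
     * accumulation logic lives in CorrelationMapper, and the combine(...) method name is an assumption):
     *
     *     CorrelationWritable cw = CorrelationMultithreadedMapper.finalCorrelationMap.get(columnNum);
     *     synchronized(cw) {
     *         cw.combine(threadLocalResult); // merge one thread's partial statistics into the shared copy
     *     }
     *
     * Because each CorrelationWritable is created up front in run(), mapper threads can lock on the per-column
     * instance without null checks.
     */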

    private class SubMapRecordReader extends RecordReader<LongWritable, Text> {
        private LongWritable key;
        private Text value;
        private Configuration conf;

        @Override
        public void close() throws IOException {
        }

        @Override
        public float getProgress() throws IOException, InterruptedException {
            return 0;
        }

        @Override
        public void initialize(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException {
            conf = context.getConfiguration();
        }

        @Override
        public boolean nextKeyValue() throws IOException, InterruptedException {
            // the outer context is shared by all threads, so record reads must be serialized
            synchronized(outer) {
                if(!outer.nextKeyValue()) {
                    return false;
                }
                key = ReflectionUtils.copy(outer.getConfiguration(), outer.getCurrentKey(), key);
                value = ReflectionUtils.copy(conf, outer.getCurrentValue(), value);
                return true;
            }
        }

        @Override
        public LongWritable getCurrentKey() {
            return key;
        }

        @Override
        public Text getCurrentValue() {
            return value;
        }
    }

    private class SubMapRecordWriter extends RecordWriter<IntWritable, CorrelationWritable> {

        @Override
        public void close(TaskAttemptContext context) throws IOException, InterruptedException {
        }

        @Override
        public void write(IntWritable key, CorrelationWritable value) throws IOException, InterruptedException {
        }
    }

    private class SubMapStatusReporter extends StatusReporter {

        @Override
        public Counter getCounter(Enum<?> name) {
            return outer.getCounter(name);
        }

        @Override
        public Counter getCounter(String group, String name) {
            return outer.getCounter(group, name);
        }

        @Override
        public void progress() {
            outer.progress();
        }

        @Override
        public void setStatus(String status) {
            outer.setStatus(status);
        }

        public float getProgress() {
            try {
                Method method = outer.getClass().getDeclaredMethod("getProgress", new Class[] {});
                if(method != null) {
                    return (Float) (method.invoke(outer, new Object[] {}));
                }
            } catch (Throwable e) {
                return 0f;
            }
            return 0f;
        }
    }

    private class MapRunner extends Thread {
        private Mapper<LongWritable, Text, IntWritable, CorrelationWritable> mapper;
        private Context subcontext;
        private Throwable throwable;
        private RecordReader<LongWritable, Text> reader = new SubMapRecordReader();

        MapRunner(Context context) throws IOException, InterruptedException {
            mapper = ReflectionUtils.newInstance(mapClass, context.getConfiguration());
            subcontext = createSubContext(context);
            reader.initialize(context.getInputSplit(), context);
        }

        private Context createSubContext(Context context) {
            boolean isHadoop2 = false;
            Class<?> mapContextImplClazz = null;
            try {
                mapContextImplClazz = Class.forName("org.apache.hadoop.mapreduce.task.MapContextImpl");
                isHadoop2 = true;
            } catch (ClassNotFoundException e) {
                isHadoop2 = false;
            }
            if(mapContextImplClazz == null) {
                isHadoop2 = false;
            }

            try {
                if(isHadoop2) {
                    return createSubContextForHadoop2(context, mapContextImplClazz);
                } else {
                    return createSubContextForHadoop1(context);
                }
            } catch (Throwable t) {
                throw new RuntimeException(t);
            }
        }
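
        /*
         * For reference, the reflective Hadoop 2 construction in createSubContextForHadoop2 below is equivalent to
         * the direct calls sketched here (this sketch compiles only against Hadoop 2 APIs, which is exactly why
         * reflection is used instead):
         *
         *     MapContext<LongWritable, Text, IntWritable, CorrelationWritable> mapContext =
         *             new MapContextImpl<LongWritable, Text, IntWritable, CorrelationWritable>(
         *                     outer.getConfiguration(), outer.getTaskAttemptID(), reader, new SubMapRecordWriter(),
         *                     context.getOutputCommitter(), new SubMapStatusReporter(), outer.getInputSplit());
         *     return new WrappedMapper<LongWritable, Text, IntWritable, CorrelationWritable>()
         *             .getMapContext(mapContext);
         */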

        @SuppressWarnings("unchecked")
        private Context createSubContextForHadoop2(Context context, Class<?> mapContextImplClazz)
                throws NoSuchMethodException, InstantiationException, IllegalAccessException,
                InvocationTargetException, ClassNotFoundException {
            Constructor<?> constructor = mapContextImplClazz.getDeclaredConstructor(Configuration.class,
                    TaskAttemptID.class, RecordReader.class, RecordWriter.class, OutputCommitter.class,
                    StatusReporter.class, InputSplit.class);
            constructor.setAccessible(true);
            Object mapContext = constructor.newInstance(outer.getConfiguration(), outer.getTaskAttemptID(), reader,
                    new SubMapRecordWriter(), context.getOutputCommitter(), new SubMapStatusReporter(),
                    outer.getInputSplit());
            Class<?> wrappedMapperClazz = Class.forName("org.apache.hadoop.mapreduce.lib.map.WrappedMapper");
            Object wrappedMapper = wrappedMapperClazz.newInstance();
            Method method = wrappedMapperClazz.getDeclaredMethod("getMapContext",
                    Class.forName("org.apache.hadoop.mapreduce.MapContext"));
            return (Context) (method.invoke(wrappedMapper, mapContext));
        }

        @SuppressWarnings("unchecked")
        private Context createSubContextForHadoop1(Context context) throws NoSuchMethodException,
                InstantiationException, IllegalAccessException, InvocationTargetException {
            Constructor<?> constructor = Context.class.getDeclaredConstructor(Mapper.class, Configuration.class,
                    TaskAttemptID.class, RecordReader.class, RecordWriter.class, OutputCommitter.class,
                    StatusReporter.class, InputSplit.class);
            constructor.setAccessible(true);
            return (Context) constructor.newInstance(mapper, outer.getConfiguration(), outer.getTaskAttemptID(),
                    reader, new SubMapRecordWriter(), context.getOutputCommitter(), new SubMapStatusReporter(),
                    outer.getInputSplit());
        }

        @SuppressWarnings("unused")
        public Throwable getThrowable() {
            return throwable;
        }

        @Override
        public void run() {
            try {
                mapper.run(subcontext);
            } catch (Throwable ie) {
                throwable = ie;
            } finally {
                try {
                    reader.close();
                } catch (IOException ignore) {
                }
            }
        }
    }
}