/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.aliyun.odps.mapred.unittest;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import com.aliyun.odps.data.Record;
import com.aliyun.odps.mapred.Mapper;
import com.aliyun.odps.mapred.Reducer;
import com.aliyun.odps.mapred.conf.JobConf;
import com.aliyun.odps.mapred.local.MapOutputBuffer;
import com.aliyun.odps.pipeline.Pipeline;
import com.aliyun.odps.counter.Counters;

/**
 * Output of a {@link Mapper} or {@link Reducer}.
 *
 * <p>
 * In addition to what {@link MapOutputBuffer} does, this class keeps cloned
 * copies of every key/value pair (grouped by partition) and of every labelled
 * output record, so that they can be read back through the getters below.
 */
public class TaskOutput extends MapOutputBuffer {

  /** Label under which the no-label {@code getOutputRecords} overloads look up records. */
  private static final String DEFAULT_LABEL = "__default__";

  private Counters counters;

  /** Key/value pairs per partition; a linked map so partitions keep first-insertion order. */
  private final Map<Integer, List<KeyValue<Record, Record>>> outputKeyValues =
      new LinkedHashMap<Integer, List<KeyValue<Record, Record>>>();

  /** Output records per label. */
  private final Map<String, List<Record>> outputs = new HashMap<String, List<Record>>();

  public TaskOutput(JobConf conf, int reduceNum) {
    super(conf, reduceNum);
  }

  public TaskOutput(JobConf job, Pipeline pipeline, String taskId, int reduceCopyNum) {
    super(job, pipeline, taskId, reduceCopyNum);
  }

  /**
   * Returns all key/value pairs that the {@link Mapper} emitted towards the
   * {@link Reducer}s.
   *
   * <p>
   * The result is a new list that concatenates the per-partition lists in the
   * order in which the partitions were first stored.
   *
   * @return key/value pairs emitted by the {@link Mapper} to the {@link Reducer}s
   */
  public List<KeyValue<Record, Record>> getOutputKeyValues() {
    List<KeyValue<Record, Record>> all = new ArrayList<KeyValue<Record, Record>>();
    for (List<KeyValue<Record, Record>> perPartition : outputKeyValues.values()) {
      all.addAll(perPartition);
    }
    return all;
  }

  /**
   * Returns the key/value pairs that the {@link Mapper} emitted to the given
   * reduce id.
   *
   * <p>
   * The returned list is the one held internally, not a copy.
   *
   * @param reduceId
   *          index of the {@link Reducer}, counting from 0
   * @return key/value pairs emitted to the given reduce id, or {@code null}
   *         if nothing has been stored for it
   */
  public List<KeyValue<Record, Record>> getOutputKeyValues(int reduceId) {
    return outputKeyValues.get(reduceId);
  }

  void setOutputKeyValues(int reduceId, List<KeyValue<Record, Record>> outputKeyValues) {
    this.outputKeyValues.put(reduceId, outputKeyValues);
  }

  /**
   * Returns the default output as a sorted {@code List<Record>}.
   *
   * @return records of the default output
   */
  public List<Record> getOutputRecords() {
    return getOutputRecords(DEFAULT_LABEL);
  }

  /**
   * Returns the output of the given label as a sorted {@code List<Record>}.
   *
   * @param label
   *          output label
   * @return records written to the given label
   */
  public List<Record> getOutputRecords(String label) {
    return getOutputRecords(label, true);
  }

  /**
   * Returns the default output as a {@code List<Record>}, optionally sorted.
   *
   * @param sort
   *          whether to sort the returned records
   * @return records of the default output
   */
  public List<Record> getOutputRecords(boolean sort) {
    return getOutputRecords(DEFAULT_LABEL, sort);
  }

  /**
   * Returns the output of the given label as a {@code List<Record>},
   * optionally sorted.
   *
   * <p>
   * The result is always a fresh list, so sorting never reorders the records
   * held internally. An unknown label yields an empty list.
   *
   * @param label
   *          output label
   * @param sort
   *          whether to sort the returned records with a
   *          {@link LocalRecordComparator}
   * @return records written to the given label
   */
  public List<Record> getOutputRecords(String label, boolean sort) {
    List<Record> stored = outputs.get(label);
    List<Record> result =
        (stored == null) ? new ArrayList<Record>() : new ArrayList<Record>(stored);
    if (sort) {
      // The comparator is only needed on this path, so it is created here.
      Collections.sort(result, new LocalRecordComparator());
    }
    return result;
  }

  /**
   * Returns the job {@link Counters}.
   *
   * @return the job {@link Counters}, as last set through
   *         {@code setCounters}
   */
  public Counters getCounters() {
    return counters;
  }

  void setCounters(Counters counters) {
    this.counters = counters;
  }

  void setOutputRecords(String label, List<Record> records) {
    outputs.put(label, records);
  }

  /**
   * Adds a key/value pair, using {@code getPartition(key)} to choose the
   * partition.
   */
  @Override
  public void add(Record key, Record value) {
    int partition = getPartition(key);
    this.add(key, value, partition);
  }

  /**
   * Passes the pair to the superclass and additionally stores a cloned copy
   * of it under the given partition.
   */
  @Override
  public void add(Record key, Record value, int partition) {
    super.add(key, value, partition);
    KeyValue<Record, Record> kv = new KeyValue<Record, Record>(key.clone(), value.clone());
    // Read the backing map directly instead of the public getter, so that an
    // overriding getter cannot affect where the pair is stored.
    List<KeyValue<Record, Record>> kvs = outputKeyValues.get(partition);
    if (kvs == null) {
      kvs = new ArrayList<KeyValue<Record, Record>>();
      outputKeyValues.put(partition, kvs);
    }
    kvs.add(kv);
  }

  /**
   * Stores a cloned copy of the record under the given output label.
   */
  @Override
  public void add(Record record, String label) {
    List<Record> records = outputs.get(label);
    if (records == null) {
      records = new ArrayList<Record>();
      outputs.put(label, records);
    }
    records.add(record.clone());
  }

  /**
   * Returns the number of labelled output records held here plus the
   * superclass's own total record count.
   */
  @Override
  public long getTotalRecordCount() {
    long numOutputRecords = 0;
    for (List<Record> records : outputs.values()) {
      numOutputRecords += records.size();
    }
    return numOutputRecords + super.getTotalRecordCount();
  }
}