/*
* Copyright [2013-2016] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.correlation;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Iterator;
import ml.shifu.guagua.GuaguaRuntimeException;
import org.apache.commons.codec.binary.Base64;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.mapreduce.Reducer;
/**
 * {@link CorrelationReducer} is used to merge all {@link CorrelationWritable}s together to compute pearson correlation
 * between two variables.
 *
 * <p>
 * For each key, all incoming {@link CorrelationWritable} values are combined into a single instance. That instance is
 * serialized with {@link Writable#write(java.io.DataOutput)}, Base64-encoded and emitted as a {@link Text} value under
 * the same key.
 *
 * @author Zhang David (pengzhang@paypal.com)
 */
public class CorrelationReducer extends Reducer<IntWritable, CorrelationWritable, IntWritable, Text> {

    /**
     * Output key cache to avoid creating a new {@link IntWritable} per output record.
     */
    private IntWritable outputKey;

    /**
     * Output value cache to avoid creating a new {@link Text} per output record.
     */
    private Text outputValue;

    /**
     * Allocates the reusable output key and value instances.
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        this.outputKey = new IntWritable();
        this.outputValue = new Text();
    }

    /**
     * Combines all {@link CorrelationWritable}s of one key and writes the merged result as a Base64 string.
     *
     * @param key
     *            the key shared by all values
     * @param values
     *            partial correlation results for that key
     * @param context
     *            reducer context, used for the job configuration and for writing output
     */
    @Override
    protected void reduce(IntWritable key, Iterable<CorrelationWritable> values, Context context) throws IOException,
            InterruptedException {
        // build final correlation column info
        CorrelationWritable finalCw = null;
        Iterator<CorrelationWritable> cwIt = values.iterator();
        while(cwIt.hasNext()) {
            CorrelationWritable cw = cwIt.next();
            if(finalCw == null) {
                // The framework reuses the same value object across iterator steps and refills it on every next().
                // Keeping a plain reference to 'cw' would let later values overwrite the accumulator (and then
                // combine it with itself), so take a deep copy of the first value as the accumulator.
                finalCw = WritableUtils.clone(cw, context.getConfiguration());
            } else {
                finalCw.combine(cw);
            }
        }
        if(finalCw == null) {
            // nothing to merge for this key; avoid serializing a null value
            return;
        }
        this.outputKey.set(key.get());
        this.outputValue.set(new String(Base64.encodeBase64(objectToBytes(finalCw)), "utf-8"));
        context.write(outputKey, outputValue);
    }

    /**
     * Serializes a {@link Writable} into a byte array through its {@link Writable#write(java.io.DataOutput)} method.
     *
     * @param result
     *            the writable to serialize
     * @return the serialized bytes
     * @throws GuaguaRuntimeException
     *             if serialization fails
     */
    public byte[] objectToBytes(Writable result) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        DataOutputStream dataOut = new DataOutputStream(out);
        try {
            result.write(dataOut);
            // Closing flushes dataOut into 'out'. It is done on the success path only, so a close failure can never
            // mask an exception thrown by write(); both streams are in-memory, so nothing leaks otherwise.
            dataOut.close();
        } catch (IOException e) {
            throw new GuaguaRuntimeException(e);
        }
        return out.toByteArray();
    }
}