/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.aliyun.odps.udf;
import com.aliyun.odps.io.Writable;
/**
 * Base class for user-defined aggregation functions (UDAFs): extend {@code Aggregator} to
 * implement a UDAF.
 *
 * <p>
 * A UDAF (User Defined Aggregation Function) has a many-to-one relationship between its inputs
 * and its output: it aggregates multiple input records into a single output value. It can be
 * used together with a {@code GROUP BY} clause in SQL.
 * </p>
 *
 * <p>
 * A Java UDAF class must extend {@code Aggregator}. The aggregation flow consists of four
 * parts, each corresponding to one of the following methods:
 * </p>
 * <ul>
 * <li>{@link #newBuffer()}: creates and initializes the intermediate aggregation buffer.</li>
 * <li>{@link #iterate(Writable, Writable[])}: computes over the input data and aggregates it
 * into the intermediate buffer. The first argument is a buffer produced by
 * {@code newBuffer()}; the second argument is the input data.</li>
 * <li>{@link #merge(Writable, Writable)}: merges two intermediate values together. The first
 * argument is a buffer produced by {@code newBuffer()}; the second argument is an intermediate
 * result produced once {@code iterate} has finished.</li>
 * <li>{@link #terminate(Writable)}: converts the intermediate result produced once
 * {@code merge} has finished into an ODPS SQL basic type.</li>
 * </ul>
 *
 * <p>
 * Initialization is performed in {@link #setup(ExecutionContext)}. Override it to implement
 * one-time initialization, for example reading shared resources.
 * </p>
 *
 * <p>
 * The intermediate buffer type of the aggregation implements {@link Writable}. Besides the
 * built-in types, users may implement {@code Writable} to define their own buffer classes.
 * The buffer size should not grow with the amount of data and should preferably not exceed
 * 2MB; otherwise memory usage becomes excessive.
 * </p>
 *
 * <p>
 * Example (computing an average):
 * </p>
 * <pre>
 * &#64;Resolve({"double-&gt;double"})
 * public class AggrAvg extends Aggregator {
 *   private static class AvgBuffer implements Writable {
 *     private double sum = 0;
 *     private long count = 0;
 *
 *     &#64;Override
 *     public void write(DataOutput out) throws IOException {
 *       out.writeDouble(sum);
 *       out.writeLong(count);
 *     }
 *
 *     &#64;Override
 *     public void readFields(DataInput in) throws IOException {
 *       sum = in.readDouble();
 *       count = in.readLong();
 *     }
 *   }
 *
 *   private DoubleWritable ret = new DoubleWritable();
 *
 *   &#64;Override
 *   public Writable newBuffer() {
 *     return new AvgBuffer();
 *   }
 *
 *   &#64;Override
 *   public void iterate(Writable buffer, Writable[] args) throws UDFException {
 *     DoubleWritable arg = (DoubleWritable) args[0];
 *     AvgBuffer buf = (AvgBuffer) buffer;
 *     if (arg != null) {
 *       buf.count += 1;
 *       buf.sum += arg.get();
 *     }
 *   }
 *
 *   &#64;Override
 *   public void merge(Writable buffer, Writable partial) throws UDFException {
 *     AvgBuffer buf = (AvgBuffer) buffer;
 *     AvgBuffer p = (AvgBuffer) partial;
 *     buf.sum += p.sum;
 *     buf.count += p.count;
 *   }
 *
 *   &#64;Override
 *   public Writable terminate(Writable buffer) throws UDFException {
 *     AvgBuffer buf = (AvgBuffer) buffer;
 *     if (buf.count == 0) {
 *       ret.set(0);
 *     } else {
 *       ret.set(buf.sum / buf.count);
 *     }
 *     return ret;
 *   }
 * }
 * </pre>
 */
public abstract class Aggregator implements ContextFunction {

  /**
   * Performs initialization work, such as loading shared resources.
   *
   * <p>Called only at initialization time, so one-time operations belong here. Resources are
   * loaded through the {@link ExecutionContext}. The default implementation does nothing.
   *
   * @param ctx the execution context through which resources are loaded
   * @throws UDFException if an overriding implementation fails to initialize
   */
  @Override
  public void setup(ExecutionContext ctx) throws UDFException {
  }

  /**
   * Performs cleanup work. The default implementation does nothing.
   *
   * @throws UDFException if an overriding implementation fails during cleanup
   */
  @Override
  public void close() throws UDFException {
  }

  /**
   * Creates an aggregation buffer.
   *
   * @return a new aggregation buffer
   */
  public abstract Writable newBuffer();

  /**
   * Computes over the input and accumulates the intermediate result into {@code buffer}.
   *
   * @param buffer the aggregation buffer
   * @param args the arguments specified when the UDAF is called in SQL
   * @throws UDFException if the input cannot be processed
   */
  public abstract void iterate(Writable buffer, Writable[] args) throws UDFException;

  /**
   * Produces the final result.
   *
   * @param buffer the aggregation buffer holding the intermediate result
   * @return the final result of the UDAF
   * @throws UDFException if the final result cannot be produced
   */
  public abstract Writable terminate(Writable buffer) throws UDFException;

  /**
   * Aggregates intermediate results by merging {@code partial} into {@code buffer}.
   *
   * @param buffer the aggregation buffer that receives the merged result
   * @param partial the intermediate result to merge into {@code buffer}
   * @throws UDFException if the intermediate results cannot be merged
   */
  public abstract void merge(Writable buffer, Writable partial) throws UDFException;
}