/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.aliyun.odps.io; import java.io.DataInput; import java.io.IOException; import java.util.Comparator; import java.util.HashMap; import com.aliyun.odps.utils.ReflectionUtils; /** * WritableComparator 提供对 {@link WritableComparable} 对象的通用比较函数. * * <p> * 通用比较函数使用对象的自然顺序进行排序,如果用户需要自定义对象顺序,则需要重载 * {@link #compare(WritableComparable, WritableComparable)}方法。 * * <p> * 如果对排序性能敏感,可以重载 {@link #compare(byte[], int, int, byte[], int, int)} 方法。 */ @SuppressWarnings("rawtypes") public class WritableComparator implements RawComparator { private static HashMap<Class, WritableComparator> comparators = new HashMap<Class, WritableComparator>(); // registry /** * 此静态方法用于获取为类型 c 注册的 WritableComparator 实现. * * <p> * 此方法返回通过 {@link #define(Class, WritableComparator)} 方法为类型 c 注册的 * WritableComparator 实现,如果没有注册,则返回 WritableComparator 这一通用实现。 * * @param c * 待比较的 {@link WritableComparable} 类型 * @return WritableComparator 实现 * @see JobConf#getOutputKeyComparator() */ public static synchronized WritableComparator get( Class<? extends WritableComparable> c) { WritableComparator comparator = comparators.get(c); if (comparator == null) { comparator = new WritableComparator(c, true); } return comparator; } /** * 此静态方法用于为指定类型注册更高效的 WritableComparator 实现,否则默认使用本实现. * * <p> * 注意:只能注册线程安全的比较器,注册的对象可能在多线程场景中使用 * * @param c * 待比较的 {@link WritableComparable} 类型 * @param comparator * 更高效的 WritableComparator 实现 * @see #get(Class) */ public static synchronized void define(Class c, WritableComparator comparator) { comparators.put(c, comparator); } private final Class<? extends WritableComparable> keyClass; private final WritableComparable key1; private final WritableComparable key2; private final DataInputBuffer buffer; /** * 构造函数,传入待比较的 {@link WritableComparable} 类型(通常是 {@link Mapper} 输出的 Key类型). * * @param keyClass * 待比较的 {@link WritableComparable} 类型 */ protected WritableComparator(Class<? extends WritableComparable> keyClass) { this(keyClass, false); } /** * 构造函数,传入待比较的 Key 类型和是否创建 Key 对象 * * @param keyClass * @param createInstances */ protected WritableComparator(Class<? extends WritableComparable> keyClass, boolean createInstances) { this.keyClass = keyClass; if (createInstances) { key1 = newKey(); key2 = newKey(); buffer = new DataInputBuffer(); } else { key1 = key2 = null; buffer = null; } } /** * 返回待比较的 {@link WritableComparable} 实现类 * * @return {@link WritableComparable} 实现类 */ public Class<? extends WritableComparable> getKeyClass() { return keyClass; } /** * 新建一个新的 {@link WritableComparable} 对象. * * @return 新创建的 {@link WritableComparable} 对象 */ public WritableComparable newKey() { return ReflectionUtils.newInstance(keyClass, null); } /** * 本方法是 {@link RawComparator} 的低效实现,如果能提供此方法的高效实现,请重载此方法. * * <p> * 本方法的默认实现是:先通过 {@link Writable#readFields(DataInput)} 将二进制表示反序列化为 * {@link WritableComparable} 对象,然后调用 * {@link #compare(WritableComparable, WritableComparable)} 进行比较。 */ @Override public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) { try { buffer.reset(b1, s1, l1); // parse key1 key1.readFields(buffer); buffer.reset(b2, s2, l2); // parse key2 key2.readFields(buffer); } catch (IOException e) { throw new RuntimeException(e); } return compare(key1, key2); // compare them } /** * 比较两个 {@link WritableComparable} 对象. * * <p> * 按自然顺序比较两个 {@link WritableComparable} 对象,直接调用 * {@link Comparable#compareTo(Object)}. * * @param a * 左边 {@link WritableComparable} 对象 * @param b * 右边 {@link WritableComparable} 对象 * @return a.compareTo(b) */ @SuppressWarnings("unchecked") public int compare(WritableComparable a, WritableComparable b) { return a.compareTo(b); } /** * 重载 {@link Comparator#compare(Object, Object)} 方法,定义为 final,不允许子类重载. * * <p> * 本比较函数直接调用 {@link #compare(WritableComparable, WritableComparable)}. */ @Override public final int compare(Object a, Object b) { return compare((WritableComparable) a, (WritableComparable) b); } /** * 逐字节比较两组二进制数据. * * @param b1 * @param s1 * @param l1 * @param b2 * @param s2 * @param l2 * @return * @see RawComparator#compare(byte[], int, int, byte[], int, int) */ public static int compareBytes(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) { int end1 = s1 + l1; int end2 = s2 + l2; for (int i = s1, j = s2; i < end1 && j < end2; i++, j++) { int a = (b1[i] & 0xff); int b = (b2[j] & 0xff); if (a != b) { return a - b; } } return l1 - l2; } /** * 计算二进制数据的哈希值. * * @param bytes * byte数组 * @param length * 数据长度 * @return 哈希值 */ public static int hashBytes(byte[] bytes, int length) { int hash = 1; for (int i = 0; i < length; i++) { hash = (31 * hash) + (int) bytes[i]; } return hash; } /** * 从 byte 数组读取无符号 short. * * @param bytes * 字节数组 * @param start * 起始位置,读取内容bytes[start, start+1] * @return */ public static int readUnsignedShort(byte[] bytes, int start) { return (((bytes[start] & 0xff) << 8) + ((bytes[start + 1] & 0xff))); } /** * 从 byte 数组读取 int. * * @param bytes * 字节数组 * @param start * 起始位置,读取内容bytes[start, start+3] * @return */ public static int readInt(byte[] bytes, int start) { return (((bytes[start] & 0xff) << 24) + ((bytes[start + 1] & 0xff) << 16) + ((bytes[start + 2] & 0xff) << 8) + ((bytes[start + 3] & 0xff))); } /** * 从 byte 数组读取 float. * * @param bytes * 字节数组 * @param start * 起始位置,读取内容bytes[start, start+3] * @return */ public static float readFloat(byte[] bytes, int start) { return Float.intBitsToFloat(readInt(bytes, start)); } /** * 从 byte 数组读取 long. * * @param bytes * 字节数组 * @param start * 起始位置,读取内容bytes[start, start+7] * @return */ public static long readLong(byte[] bytes, int start) { return ((long) (readInt(bytes, start)) << 32) + (readInt(bytes, start + 4) & 0xFFFFFFFFL); } /** * 从 byte 数组读取double. * * @param bytes * 字节数组 * @param start * 起始位置,读取内容bytes[start, start+7] * @return */ public static double readDouble(byte[] bytes, int start) { return Double.longBitsToDouble(readLong(bytes, start)); } /** * 从byte数组读取压缩编码过的 long. * * @param bytes * 字节数组 * @param start * 起始位置,读取内容长度依 long 值的压缩编码特征而定 * @return * @throws IOException */ public static long readVLong(byte[] bytes, int start) throws IOException { int len = bytes[start]; if (len >= -112) { return len; } boolean isNegative = (len < -120); len = isNegative ? -(len + 120) : -(len + 112); if (start + 1 + len > bytes.length) { throw new IOException( "Not enough number of bytes for a zero-compressed integer"); } long i = 0; for (int idx = 0; idx < len; idx++) { i = i << 8; i = i | (bytes[start + 1 + idx] & 0xFF); } return (isNegative ? (i ^ -1L) : i); } /** * 从byte数组读取压缩编码过的 int. * * @param bytes * 字节数组 * @param start * 起始位置,读取内容长度依 int 值的压缩编码特征而定 * @return * @throws IOException */ public static int readVInt(byte[] bytes, int start) throws IOException { return (int) readVLong(bytes, start); } }