/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.common.typeutils; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import static org.junit.Assert.*; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.util.TestLogger; import org.junit.Assert; import org.junit.Test; /** * Abstract test base for comparators. * * @param <T> */ public abstract class ComparatorTestBase<T> extends TestLogger { // Same as in the NormalizedKeySorter private static final int DEFAULT_MAX_NORMALIZED_KEY_LEN = 8; protected abstract TypeComparator<T> createComparator(boolean ascending); protected abstract TypeSerializer<T> createSerializer(); /** * Returns the sorted data set. * <p> * Note: every element needs to be *strictly greater* than the previous element. * * @return sorted test data set */ protected abstract T[] getSortedTestData(); // -------------------------------- test duplication ------------------------------------------ @Test public void testDuplicate() { try { TypeComparator<T> comparator = getComparator(true); TypeComparator<T> clone = comparator.duplicate(); T[] data = getSortedData(); comparator.setReference(data[0]); clone.setReference(data[1]); assertTrue("Comparator duplication does not work: Altering the reference in a duplicated comparator alters the original comparator's reference.", comparator.equalToReference(data[0]) && clone.equalToReference(data[1])); } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); Assert.fail(e.getMessage()); } } // --------------------------------- equality tests ------------------------------------------- @Test public void testEquality() { testEquals(true); testEquals(false); } protected void testEquals(boolean ascending) { try { // Just setup two identical output/inputViews and go over their data to see if compare works TestOutputView out1; TestOutputView out2; TestInputView in1; TestInputView in2; // Now use comparator and compare TypeComparator<T> comparator = getComparator(ascending); T[] data = getSortedData(); for (T d : data) { out2 = new TestOutputView(); writeSortedData(d, out2); in2 = out2.getInputView(); out1 = new TestOutputView(); writeSortedData(d, out1); in1 = out1.getInputView(); assertTrue(comparator.compareSerialized(in1, in2) == 0); } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail("Exception in test: " + e.getMessage()); } } @Test public void testEqualityWithReference() { try { TypeSerializer<T> serializer = createSerializer(); TypeComparator<T> comparator = getComparator(true); TypeComparator<T> comparator2 = getComparator(true); T[] data = getSortedData(); for (T d : data) { comparator.setReference(d); // Make a copy to compare T copy = serializer.copy(d, serializer.createInstance()); // And then test equalTo and compareToReference method of comparator assertTrue(comparator.equalToReference(d)); comparator2.setReference(copy); assertTrue(comparator.compareToReference(comparator2) == 0); } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail("Exception in test: " + e.getMessage()); } } // --------------------------------- inequality tests ---------------------------------------- @Test public void testInequality() { testGreatSmallAscDesc(true, true); testGreatSmallAscDesc(false, true); testGreatSmallAscDesc(true, false); testGreatSmallAscDesc(false, false); } protected void testGreatSmallAscDesc(boolean ascending, boolean greater) { try { //split data into low and high part T[] data = getSortedData(); TypeComparator<T> comparator = getComparator(ascending); TestOutputView out1; TestOutputView out2; TestInputView in1; TestInputView in2; //compares every element in high with every element in low for (int x = 0; x < data.length - 1; x++) { for (int y = x + 1; y < data.length; y++) { out1 = new TestOutputView(); writeSortedData(data[x], out1); in1 = out1.getInputView(); out2 = new TestOutputView(); writeSortedData(data[y], out2); in2 = out2.getInputView(); if (greater && ascending) { assertTrue(comparator.compareSerialized(in1, in2) < 0); } if (greater && !ascending) { assertTrue(comparator.compareSerialized(in1, in2) > 0); } if (!greater && ascending) { assertTrue(comparator.compareSerialized(in2, in1) > 0); } if (!greater && !ascending) { assertTrue(comparator.compareSerialized(in2, in1) < 0); } } } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail("Exception in test: " + e.getMessage()); } } @Test public void testInequalityWithReference() { testGreatSmallAscDescWithReference(true, true); testGreatSmallAscDescWithReference(true, false); testGreatSmallAscDescWithReference(false, true); testGreatSmallAscDescWithReference(false, false); } protected void testGreatSmallAscDescWithReference(boolean ascending, boolean greater) { try { T[] data = getSortedData(); TypeComparator<T> comparatorLow = getComparator(ascending); TypeComparator<T> comparatorHigh = getComparator(ascending); //compares every element in high with every element in low for (int x = 0; x < data.length - 1; x++) { for (int y = x + 1; y < data.length; y++) { comparatorLow.setReference(data[x]); comparatorHigh.setReference(data[y]); if (greater && ascending) { assertTrue(comparatorLow.compareToReference(comparatorHigh) > 0); } if (greater && !ascending) { assertTrue(comparatorLow.compareToReference(comparatorHigh) < 0); } if (!greater && ascending) { assertTrue(comparatorHigh.compareToReference(comparatorLow) < 0); } if (!greater && !ascending) { assertTrue(comparatorHigh.compareToReference(comparatorLow) > 0); } } } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail("Exception in test: " + e.getMessage()); } } // --------------------------------- Normalized key tests ------------------------------------- // Help Function for setting up a memory segment and normalize the keys of the data array in it public MemorySegment setupNormalizedKeysMemSegment(T[] data, int normKeyLen, TypeComparator<T> comparator) { MemorySegment memSeg = MemorySegmentFactory.allocateUnpooledSegment(2048); // Setup normalized Keys in the memory segment int offset = 0; for (T e : data) { comparator.putNormalizedKey(e, memSeg, offset, normKeyLen); offset += normKeyLen; } return memSeg; } // Help Function which return a normalizedKeyLength, either as done in the NormalizedKeySorter or it's half private int getNormKeyLen(boolean halfLength, T[] data, TypeComparator<T> comparator) throws Exception { // Same as in the NormalizedKeySorter int keyLen = Math.min(comparator.getNormalizeKeyLen(), DEFAULT_MAX_NORMALIZED_KEY_LEN); if (keyLen < comparator.getNormalizeKeyLen()) { assertTrue(comparator.isNormalizedKeyPrefixOnly(keyLen)); } if (halfLength) { keyLen = keyLen / 2; assertTrue(comparator.isNormalizedKeyPrefixOnly(keyLen)); } return keyLen; } @Test public void testNormalizedKeysEqualsFullLength() { // Ascending or descending does not matter in this case TypeComparator<T> comparator = getComparator(true); if (!comparator.supportsNormalizedKey()) { return; } testNormalizedKeysEquals(false); } @Test public void testNormalizedKeysEqualsHalfLength() { TypeComparator<T> comparator = getComparator(true); if (!comparator.supportsNormalizedKey()) { return; } testNormalizedKeysEquals(true); } public void testNormalizedKeysEquals(boolean halfLength) { try { TypeComparator<T> comparator = getComparator(true); T[] data = getSortedData(); int normKeyLen = getNormKeyLen(halfLength, data, comparator); MemorySegment memSeg1 = setupNormalizedKeysMemSegment(data, normKeyLen, comparator); MemorySegment memSeg2 = setupNormalizedKeysMemSegment(data, normKeyLen, comparator); for (int i = 0; i < data.length; i++) { assertTrue(memSeg1.compare(memSeg2, i * normKeyLen, i * normKeyLen, normKeyLen) == 0); } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail("Exception in test: " + e.getMessage()); } } @Test public void testNormalizedKeysGreatSmallFullLength() { // ascending/descending in comparator doesn't matter for normalized keys TypeComparator<T> comparator = getComparator(true); if (!comparator.supportsNormalizedKey()) { return; } testNormalizedKeysGreatSmall(true, comparator, false); testNormalizedKeysGreatSmall(false, comparator, false); } @Test public void testNormalizedKeysGreatSmallAscDescHalfLength() { // ascending/descending in comparator doesn't matter for normalized keys TypeComparator<T> comparator = getComparator(true); if (!comparator.supportsNormalizedKey()) { return; } testNormalizedKeysGreatSmall(true, comparator, true); testNormalizedKeysGreatSmall(false, comparator, true); } protected void testNormalizedKeysGreatSmall(boolean greater, TypeComparator<T> comparator, boolean halfLength) { try { T[] data = getSortedData(); // Get the normKeyLen on which we are testing int normKeyLen = getNormKeyLen(halfLength, data, comparator); // Write the data into different 2 memory segments MemorySegment memSegLow = setupNormalizedKeysMemSegment(data, normKeyLen, comparator); MemorySegment memSegHigh = setupNormalizedKeysMemSegment(data, normKeyLen, comparator); boolean fullyDetermines = !comparator.isNormalizedKeyPrefixOnly(normKeyLen); // Compare every element with every bigger element for (int l = 0; l < data.length - 1; l++) { for (int h = l + 1; h < data.length; h++) { int cmp; if (greater) { cmp = memSegLow.compare(memSegHigh, l * normKeyLen, h * normKeyLen, normKeyLen); if (fullyDetermines) { assertTrue(cmp < 0); } else { assertTrue(cmp <= 0); } } else { cmp = memSegHigh.compare(memSegLow, h * normKeyLen, l * normKeyLen, normKeyLen); if (fullyDetermines) { assertTrue(cmp > 0); } else { assertTrue(cmp >= 0); } } } } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail("Exception in test: " + e.getMessage()); } } @Test public void testNormalizedKeyReadWriter() { try { T[] data = getSortedData(); T reuse = getSortedData()[0]; TypeComparator<T> comp1 = getComparator(true); if(!comp1.supportsSerializationWithKeyNormalization()){ return; } TypeComparator<T> comp2 = comp1.duplicate(); comp2.setReference(reuse); TestOutputView out = new TestOutputView(); TestInputView in; for (T value : data) { comp1.setReference(value); comp1.writeWithKeyNormalization(value, out); in = out.getInputView(); comp1.readWithKeyDenormalization(reuse, in); assertTrue(comp1.compareToReference(comp2) == 0); } } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); fail("Exception in test: " + e.getMessage()); } } // -------------------------------- Key extraction tests -------------------------------------- @Test @SuppressWarnings("unchecked") public void testKeyExtraction() { TypeComparator<T> comparator = getComparator(true); T[] data = getSortedData(); for (T value : data) { TypeComparator[] comparators = comparator.getFlatComparators(); Object[] extractedKeys = new Object[comparators.length]; int insertedKeys = comparator.extractKeys(value, extractedKeys, 0); assertTrue(insertedKeys == comparators.length); for (int i = 0; i < insertedKeys; i++) { // check if some keys are null, although this is not supported if (!supportsNullKeys()) { assertNotNull(extractedKeys[i]); } // compare the extracted key with itself as a basic check // if the extracted key corresponds to the comparator assertTrue(comparators[i].compare(extractedKeys[i], extractedKeys[i]) == 0); } } } // -------------------------------------------------------------------------------------------- protected void deepEquals(String message, T should, T is) { assertEquals(message, should, is); } // -------------------------------------------------------------------------------------------- protected TypeComparator<T> getComparator(boolean ascending) { TypeComparator<T> comparator = createComparator(ascending); if (comparator == null) { throw new RuntimeException("Test case corrupt. Returns null as comparator."); } return comparator; } protected T[] getSortedData() { T[] data = getSortedTestData(); if (data == null) { throw new RuntimeException("Test case corrupt. Returns null as test data."); } if (data.length < 2) { throw new RuntimeException("Test case does not provide enough sorted test data."); } return data; } protected TypeSerializer<T> getSerializer() { TypeSerializer<T> serializer = createSerializer(); if (serializer == null) { throw new RuntimeException("Test case corrupt. Returns null as serializer."); } return serializer; } protected void writeSortedData(T value, TestOutputView out) throws IOException { TypeSerializer<T> serializer = getSerializer(); // Write data into a outputView serializer.serialize(value, out); // This are the same tests like in the serializer // Just look if the data is really there after serialization, before testing comparator on it TestInputView in = out.getInputView(); assertTrue("No data available during deserialization.", in.available() > 0); T deserialized = serializer.deserialize(serializer.createInstance(), in); deepEquals("Deserialized value is wrong.", value, deserialized); } protected boolean supportsNullKeys() { return false; } // -------------------------------------------------------------------------------------------- public static final class TestOutputView extends DataOutputStream implements DataOutputView { public TestOutputView() { super(new ByteArrayOutputStream(4096)); } public TestInputView getInputView() { ByteArrayOutputStream baos = (ByteArrayOutputStream) out; return new TestInputView(baos.toByteArray()); } @Override public void skipBytesToWrite(int numBytes) throws IOException { for (int i = 0; i < numBytes; i++) { write(0); } } @Override public void write(DataInputView source, int numBytes) throws IOException { byte[] buffer = new byte[numBytes]; source.readFully(buffer); write(buffer); } } public static final class TestInputView extends DataInputStream implements DataInputView { public TestInputView(byte[] data) { super(new ByteArrayInputStream(data)); } @Override public void skipBytesToRead(int numBytes) throws IOException { while (numBytes > 0) { int skipped = skipBytes(numBytes); numBytes -= skipped; } } } }