/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms;
import static com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import org.apache.beam.sdk.TestUtils;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.junit.runners.Parameterized;
import org.junit.runners.Suite;
/**
 * Tests for Sample transform.
 */
@RunWith(Suite.class)
@Suite.SuiteClasses({
    SampleTest.PickAnyTest.class,
    SampleTest.MiscTest.class
})
public class SampleTest {
  private static final Integer[] EMPTY = new Integer[] { };
  private static final Integer[] DATA = new Integer[] {1, 2, 3, 4, 5};
  private static final Integer[] REPEATED_DATA = new Integer[] {1, 1, 2, 2, 3, 3, 4, 4, 5, 5};

  /**
   * Test variations for Sample transform.
   */
  @RunWith(Parameterized.class)
  public static class PickAnyTest {
    @Rule
    public final transient TestPipeline p = TestPipeline.create();

    /**
     * Parameterizations of {@code (lines, limit)}: empty and non-empty inputs combined with limits
     * below, equal to, and above the input size.
     */
    @Parameterized.Parameters(name = "limit_{1}")
    public static Iterable<Object[]> data() throws IOException {
      return ImmutableList.<Object[]>builder()
          .add(
              new Object[] {
                  TestUtils.NO_LINES,
                  0
              },
              new Object[] {
                  TestUtils.NO_LINES,
                  1
              },
              new Object[] {
                  TestUtils.LINES,
                  1
              },
              new Object[] {
                  TestUtils.LINES,
                  TestUtils.LINES.size() / 2
              },
              new Object[] {
                  TestUtils.LINES,
                  TestUtils.LINES.size() * 2
              },
              new Object[] {
                  TestUtils.LINES,
                  TestUtils.LINES.size() - 1
              },
              new Object[] {
                  TestUtils.LINES,
                  TestUtils.LINES.size()
              },
              new Object[] {
                  TestUtils.LINES,
                  TestUtils.LINES.size() + 1
              }
          )
          .build();
    }

    /** Input lines for this parameterization; {@link #runPickAnyTest} rejects duplicates. */
    @SuppressWarnings("DefaultAnnotationParam")
    @Parameterized.Parameter(0)
    public List<String> lines;

    /** Limit passed to {@code Sample.any}. */
    @Parameterized.Parameter(1)
    public int limit;

    /**
     * Checks that a sample holds exactly {@code min(limit, lines.size())} distinct elements, all
     * of which appear in {@code lines}.
     */
    private static class VerifyAnySample implements SerializableFunction<Iterable<String>, Void> {
      private final List<String> lines;
      private final int limit;

      private VerifyAnySample(List<String> lines, int limit) {
        this.lines = lines;
        this.limit = limit;
      }

      @Override
      public Void apply(Iterable<String> actualIter) {
        final int expectedSize = Math.min(limit, lines.size());

        // Make sure actual is the right length, and is a
        // subset of expected.
        List<String> actual = new ArrayList<>();
        for (String s : actualIter) {
          actual.add(s);
        }
        assertEquals(expectedSize, actual.size());
        Set<String> actualAsSet = new TreeSet<>(actual);
        Set<String> linesAsSet = new TreeSet<>(lines);
        // Distinct inputs must yield distinct outputs.
        assertEquals(actual.size(), actualAsSet.size());
        assertEquals(lines.size(), linesAsSet.size());
        assertTrue(linesAsSet.containsAll(actualAsSet));
        return null;
      }
    }

    /**
     * Applies {@code Sample.any(limit)} to {@code lines} and verifies the output.
     *
     * @param lines the input elements; must be free of duplicates
     * @param limit the sample limit passed to {@code Sample.any}
     * @throws IllegalArgumentException if {@code lines} contains duplicates
     */
    void runPickAnyTest(final List<String> lines, int limit) {
      checkArgument(new HashSet<>(lines).size() == lines.size(),
          "Duplicates are unsupported.");
      PCollection<String> input = p.apply(Create.of(lines)
          .withCoder(StringUtf8Coder.of()));

      PCollection<String> output =
          input.apply(Sample.<String>any(limit));

      PAssert.that(output)
          .satisfies(new VerifyAnySample(lines, limit));

      p.run();
    }

    @Test
    @Category(ValidatesRunner.class)
    public void testPickAny() {
      runPickAnyTest(lines, limit);
    }
  }

  /**
   * Further tests for Sample transform.
   */
  @RunWith(JUnit4.class)
  public static class MiscTest {
    @Rule
    public final transient TestPipeline pipeline = TestPipeline.create();

    /**
     * Verifies that the result of a Sample operation contains the expected number of elements,
     * and that those elements are a sub-multiset of the items in expected (each expected item can
     * be matched by at most one sampled element).
     */
    @SuppressWarnings("rawtypes")
    public static class VerifyCorrectSample<T extends Comparable>
        implements SerializableFunction<Iterable<T>, Void> {
      /** Elements the sample may contain, sorted ascending for single-pass matching. */
      private final List<T> expectedValues;
      /** Exact number of elements the sample must contain. */
      private final int expectedSize;

      /**
       * Creates a verifier from an array of candidate values.
       *
       * @param expectedSize the number of elements that the Sample should contain
       * @param expected the elements that the sample may contain, in any order; must not contain
       *     nulls
       */
      @SafeVarargs
      VerifyCorrectSample(int expectedSize, T... expected) {
        // Copy into a Collection so that the sorting constructor below handles both entry points.
        this(expectedSize, ImmutableList.copyOf(expected));
      }

      /**
       * Creates a verifier from a collection of candidate values.
       *
       * <p>The values are copied and sorted here, so callers need not pre-sort them. The previous
       * implementation cast {@code expected.toArray()} (an {@code Object[]}) to {@code T[]}, which
       * erases to {@code Comparable[]} and therefore threw {@code ClassCastException}.
       *
       * @param expectedSize the number of elements that the Sample should contain
       * @param expected the elements that the sample may contain, in any order; must not contain
       *     nulls
       */
      @SuppressWarnings("unchecked") // Collections.sort against the raw Comparable bound.
      VerifyCorrectSample(int expectedSize, Collection<T> expected) {
        List<T> sorted = new ArrayList<>(expected);
        Collections.sort(sorted);
        this.expectedValues = sorted;
        this.expectedSize = expectedSize;
      }

      @Override
      @SuppressWarnings("unchecked")
      public Void apply(Iterable<T> in) {
        List<T> actual = new ArrayList<>();
        for (T elem : in) {
          actual.add(elem);
        }
        assertEquals(expectedSize, actual.size());

        // Both lists are sorted. One forward pass over expectedValues therefore suffices: each
        // sampled element consumes the next equal expected element. Consumed elements are never
        // revisited, which enforces multiplicity.
        Collections.sort(actual);
        int i = 0; // Index into expectedValues.
        for (T s : actual) {
          boolean matchFound = false;
          for (; i < expectedValues.size(); i++) {
            if (s.equals(expectedValues.get(i))) {
              matchFound = true;
              break;
            }
          }
          assertTrue("Invalid sample: " + Joiner.on(',').join(actual), matchFound);
          i++; // Don't match the same element again.
        }
        return null;
      }
    }

    @Test
    @Category(ValidatesRunner.class)
    public void testSample() {

      PCollection<Integer> input =
          pipeline.apply(
              Create.of(ImmutableList.copyOf(DATA)).withCoder(BigEndianIntegerCoder.of()));
      PCollection<Iterable<Integer>> output = input.apply(Sample.<Integer>fixedSizeGlobally(3));

      PAssert.thatSingletonIterable(output)
          .satisfies(new VerifyCorrectSample<>(3, DATA));
      pipeline.run();
    }

    @Test
    @Category(ValidatesRunner.class)
    public void testSampleEmpty() {

      PCollection<Integer> input = pipeline.apply(Create.empty(BigEndianIntegerCoder.of()));
      PCollection<Iterable<Integer>> output = input.apply(
          Sample.<Integer>fixedSizeGlobally(3));

      PAssert.thatSingletonIterable(output)
          .satisfies(new VerifyCorrectSample<>(0, EMPTY));
      pipeline.run();
    }

    @Test
    @Category(ValidatesRunner.class)
    public void testSampleZero() {

      PCollection<Integer> input = pipeline.apply(Create.of(ImmutableList.copyOf(DATA))
          .withCoder(BigEndianIntegerCoder.of()));
      PCollection<Iterable<Integer>> output = input.apply(
          Sample.<Integer>fixedSizeGlobally(0));

      PAssert.thatSingletonIterable(output)
          .satisfies(new VerifyCorrectSample<>(0, DATA));
      pipeline.run();
    }

    @Test
    @Category(ValidatesRunner.class)
    public void testSampleInsufficientElements() {

      PCollection<Integer> input =
          pipeline.apply(
              Create.of(ImmutableList.copyOf(DATA)).withCoder(BigEndianIntegerCoder.of()));
      PCollection<Iterable<Integer>> output = input.apply(
          Sample.<Integer>fixedSizeGlobally(10));

      // Only five inputs exist, so the sample is capped at five elements.
      PAssert.thatSingletonIterable(output)
          .satisfies(new VerifyCorrectSample<>(5, DATA));
      pipeline.run();
    }

    @Test(expected = IllegalArgumentException.class)
    public void testSampleNegative() {
      // The pipeline is never run, so don't fail on the dangling input node.
      pipeline.enableAbandonedNodeEnforcement(false);

      PCollection<Integer> input =
          pipeline.apply(
              Create.of(ImmutableList.copyOf(DATA)).withCoder(BigEndianIntegerCoder.of()));
      input.apply(Sample.<Integer>fixedSizeGlobally(-1));
    }

    @Test
    @Category(ValidatesRunner.class)
    public void testSampleMultiplicity() {

      PCollection<Integer> input =
          pipeline.apply(
              Create.of(ImmutableList.copyOf(REPEATED_DATA)).withCoder(BigEndianIntegerCoder.of()));
      // At least one value must be selected with multiplicity.
      PCollection<Iterable<Integer>> output = input.apply(
          Sample.<Integer>fixedSizeGlobally(6));

      PAssert.thatSingletonIterable(output)
          .satisfies(new VerifyCorrectSample<>(6, REPEATED_DATA));
      pipeline.run();
    }

    @Test
    public void testSampleGetName() {
      assertEquals("Sample.Any", Sample.<String>any(1).getName());
    }

    @Test
    public void testDisplayData() {
      PTransform<?, ?> sampleAny = Sample.any(1234);
      DisplayData sampleAnyDisplayData = DisplayData.from(sampleAny);
      assertThat(sampleAnyDisplayData, hasDisplayItem("sampleSize", 1234));

      PTransform<?, ?> samplePerKey = Sample.fixedSizePerKey(2345);
      DisplayData perKeyDisplayData = DisplayData.from(samplePerKey);
      assertThat(perKeyDisplayData, hasDisplayItem("sampleSize", 2345));
    }
  }
}