/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.runners.direct; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; import java.io.File; import java.io.FileReader; import java.io.Reader; import java.nio.CharBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.UUID; import org.apache.beam.runners.direct.WriteWithShardingFactory.CalculateShardsFn; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.DefaultFilenamePolicy; import org.apache.beam.sdk.io.FileBasedSink; import org.apache.beam.sdk.io.FileBasedSink.FilenamePolicy; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.LocalResources; import org.apache.beam.sdk.io.TextIO; import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.io.fs.MatchResult.Metadata; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnTester; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** * Tests for {@link WriteWithShardingFactory}. */ @RunWith(JUnit4.class) public class WriteWithShardingFactoryTest { private static final int INPUT_SIZE = 10000; @Rule public TemporaryFolder tmp = new TemporaryFolder(); private WriteWithShardingFactory<Object> factory = new WriteWithShardingFactory<>(); @Rule public final TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); @Test public void dynamicallyReshardedWrite() throws Exception { List<String> strs = new ArrayList<>(INPUT_SIZE); for (int i = 0; i < INPUT_SIZE; i++) { strs.add(UUID.randomUUID().toString()); } Collections.shuffle(strs); String fileName = "resharded_write"; String targetLocation = tmp.getRoot().toPath().resolve(fileName).toString(); String targetLocationGlob = targetLocation + '*'; // TextIO is implemented in terms of the WriteFiles PTransform. When sharding is not specified, // resharding should be automatically applied p.apply(Create.of(strs)).apply(TextIO.write().to(targetLocation)); p.run(); List<Metadata> matches = FileSystems.match(targetLocationGlob).metadata(); List<String> actuals = new ArrayList<>(strs.size()); List<String> files = new ArrayList<>(strs.size()); for (Metadata match : matches) { String filename = match.resourceId().toString(); files.add(filename); CharBuffer buf = CharBuffer.allocate((int) new File(filename).length()); try (Reader reader = new FileReader(filename)) { reader.read(buf); buf.flip(); } String[] readStrs = buf.toString().split("\n"); for (String read : readStrs) { if (read.length() > 0) { actuals.add(read); } } } assertThat(actuals, containsInAnyOrder(strs.toArray())); assertThat( files, hasSize( allOf( greaterThan(1), lessThan( (int) (Math.log10(INPUT_SIZE) + WriteWithShardingFactory.MAX_RANDOM_EXTRA_SHARDS))))); } @Test public void withNoShardingSpecifiedReturnsNewTransform() { ResourceId outputDirectory = LocalResources.fromString("/foo", true /* isDirectory */); FilenamePolicy policy = DefaultFilenamePolicy.constructUsingStandardParameters( StaticValueProvider.of(outputDirectory), DefaultFilenamePolicy.DEFAULT_SHARD_TEMPLATE, ""); WriteFiles<Object> original = WriteFiles.to( new FileBasedSink<Object>(StaticValueProvider.of(outputDirectory), policy) { @Override public WriteOperation<Object> createWriteOperation() { throw new IllegalArgumentException("Should not be used"); } }); @SuppressWarnings("unchecked") PCollection<Object> objs = (PCollection) p.apply(Create.empty(VoidCoder.of())); AppliedPTransform<PCollection<Object>, PDone, WriteFiles<Object>> originalApplication = AppliedPTransform.of( "write", objs.expand(), Collections.<TupleTag<?>, PValue>emptyMap(), original, p); assertThat( factory.getReplacementTransform(originalApplication).getTransform(), not(equalTo((Object) original))); } @Test public void keyBasedOnCountFnWithNoElements() throws Exception { CalculateShardsFn fn = new CalculateShardsFn(0); DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); List<Integer> outputs = fnTester.processBundle(0L); assertThat( outputs, containsInAnyOrder(1)); } @Test public void keyBasedOnCountFnWithOneElement() throws Exception { CalculateShardsFn fn = new CalculateShardsFn(0); DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); List<Integer> outputs = fnTester.processBundle(1L); assertThat( outputs, containsInAnyOrder(1)); } @Test public void keyBasedOnCountFnWithTwoElements() throws Exception { CalculateShardsFn fn = new CalculateShardsFn(0); DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); List<Integer> outputs = fnTester.processBundle(2L); assertThat(outputs, containsInAnyOrder(2)); } @Test public void keyBasedOnCountFnFewElementsThreeShards() throws Exception { CalculateShardsFn fn = new CalculateShardsFn(0); DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); List<Integer> outputs = fnTester.processBundle(5L); assertThat(outputs, containsInAnyOrder(3)); } @Test public void keyBasedOnCountFnManyElements() throws Exception { DoFn<Long, Integer> fn = new CalculateShardsFn(0); DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); List<Integer> shard = fnTester.processBundle((long) Math.pow(10, 10)); assertThat(shard, containsInAnyOrder(10)); } @Test public void keyBasedOnCountFnFewElementsExtraShards() throws Exception { long countValue = (long) WriteWithShardingFactory.MIN_SHARDS_FOR_LOG + 3; PCollection<Long> inputCount = p.apply(Create.of(countValue)); PCollectionView<Long> elementCountView = PCollectionViews.singletonView( inputCount, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); CalculateShardsFn fn = new CalculateShardsFn(3); DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, countValue); List<Integer> kvs = fnTester.processBundle(10L); assertThat(kvs, containsInAnyOrder(6)); } @Test public void keyBasedOnCountFnManyElementsExtraShards() throws Exception { CalculateShardsFn fn = new CalculateShardsFn(3); DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); double count = Math.pow(10, 10); List<Integer> shards = fnTester.processBundle((long) count); assertThat(shards, containsInAnyOrder(13)); } }