/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.util; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; import java.util.ArrayList; /** * Implements a WritableByteChannel that may contain multiple output shards. * * <p>This provides {@link #writeToShard}, which takes a shard number for * writing to a particular shard. * * <p>The channel is considered open if all downstream channels are open, and * closes all downstream channels when closed. */ public class ShardingWritableByteChannel implements WritableByteChannel { /** * Special shard number that causes a write to all shards. */ public static final int ALL_SHARDS = -2; private final ArrayList<WritableByteChannel> writers = new ArrayList<>(); /** * Returns the number of output shards. */ public int getNumShards() { return writers.size(); } /** * Adds another shard output channel. */ public void addChannel(WritableByteChannel writer) { writers.add(writer); } /** * Returns the WritableByteChannel associated with the given shard number. */ public WritableByteChannel getChannel(int shardNum) { return writers.get(shardNum); } /** * Writes the buffer to the given shard. * * <p>This does not change the current output shard. * * @return The total number of bytes written. If the shard number is * {@link #ALL_SHARDS}, then the total is the sum of each individual shard * write. */ public int writeToShard(int shardNum, ByteBuffer src) throws IOException { if (shardNum >= 0) { return writers.get(shardNum).write(src); } switch (shardNum) { case ALL_SHARDS: int size = 0; for (WritableByteChannel writer : writers) { size += writer.write(src); } return size; default: throw new IllegalArgumentException("Illegal shard number: " + shardNum); } } /** * Writes a buffer to all shards. * * <p>Same as calling {@code writeToShard(ALL_SHARDS, buf)}. */ @Override public int write(ByteBuffer src) throws IOException { return writeToShard(ALL_SHARDS, src); } @Override public boolean isOpen() { for (WritableByteChannel writer : writers) { if (!writer.isOpen()) { return false; } } return true; } @Override public void close() throws IOException { for (WritableByteChannel writer : writers) { writer.close(); } } }