/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.transforms.join;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.StructuredCoder;
import org.apache.beam.sdk.util.VarInt;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;

/**
 * A UnionCoder encodes RawUnionValues.
 *
 * <p>The encoded form is the union tag written as a variable-length integer, followed by the
 * value encoded with the element coder registered at that tag's position.
 */
public class UnionCoder extends StructuredCoder<RawUnionValue> {
  // TODO: Think about how to integrate this with a schema object (i.e.
  // a tuple of tuple tags).

  /**
   * Builds a union coder with the given list of element coders.  This list
   * corresponds to a mapping of union tag to Coder.  Union tags start at 0.
   *
   * <p>The list is copied, so later changes to {@code elementCoders} do not affect the
   * returned coder.
   */
  public static UnionCoder of(List<Coder<?>> elementCoders) {
    return new UnionCoder(elementCoders);
  }

  /**
   * Returns the union tag of {@code union} after checking that it can be encoded.
   *
   * @throws IllegalArgumentException if {@code union} is null or its tag has no element coder
   */
  private int getIndexForEncoding(RawUnionValue union) {
    if (union == null) {
      throw new IllegalArgumentException("cannot encode a null tagged union");
    }
    int index = union.getUnionTag();
    if (index < 0 || index >= elementCoders.size()) {
      throw new IllegalArgumentException(
          "union value index " + index + " not in range [0.."
          + (elementCoders.size() - 1) + "]");
    }
    return index;
  }

  /** Encodes {@code union} using {@link Context#NESTED}. */
  @Override
  public void encode(RawUnionValue union, OutputStream outStream)
      throws IOException, CoderException {
    encode(union, outStream, Context.NESTED);
  }

  /**
   * Writes the union tag as a variable-length integer, then the value using the element coder
   * registered for that tag.
   *
   * @throws IllegalArgumentException if {@code union} is null or its tag is out of range
   */
  @SuppressWarnings("unchecked")
  @Override
  public void encode(
      RawUnionValue union,
      OutputStream outStream,
      Context context)
      throws IOException, CoderException {
    int index = getIndexForEncoding(union);
    // Write out the union tag.
    VarInt.encode(index, outStream);

    // Write out the actual value.
    Coder<Object> coder = (Coder<Object>) elementCoders.get(index);
    coder.encode(
        union.getValue(),
        outStream,
        context);
  }

  /** Decodes a value using {@link Context#NESTED}. */
  @Override
  public RawUnionValue decode(InputStream inStream) throws IOException, CoderException {
    return decode(inStream, Context.NESTED);
  }

  /**
   * Reads a union tag followed by the value encoded with the element coder for that tag.
   *
   * @throws CoderException if the tag read from the stream does not correspond to any
   *     element coder of this union
   */
  @Override
  public RawUnionValue decode(InputStream inStream, Context context)
      throws IOException, CoderException {
    int index = VarInt.decodeInt(inStream);
    // The tag comes from the stream, so it must be validated before it is used as a list
    // index; otherwise malformed input surfaces as a bare IndexOutOfBoundsException.
    if (index < 0 || index >= elementCoders.size()) {
      throw new CoderException(
          "decoded union tag " + index + " not in range [0.."
          + (elementCoders.size() - 1) + "]");
    }
    Object value = elementCoders.get(index).decode(inStream, context);
    return new RawUnionValue(index, value);
  }

  /** Returns an empty list; the element coders are exposed via {@link #getComponents()}. */
  @Override
  public List<? extends Coder<?>> getCoderArguments() {
    return Collections.emptyList();
  }

  /** Returns the element coders, indexed by union tag, as an unmodifiable list. */
  @Override
  public List<? extends Coder<?>> getComponents() {
    return elementCoders;
  }

  /** Returns the element coders, indexed by union tag, as an unmodifiable list. */
  public List<? extends Coder<?>> getElementCoders() {
    return elementCoders;
  }

  /**
   * Delegates to the element coder selected by the union tag of {@code union}: registering the
   * observer is cheap exactly when that coder reports it is cheap for the wrapped value.
   *
   * @throws IllegalArgumentException if {@code union} is null or its tag is out of range
   */
  @Override
  public boolean isRegisterByteSizeObserverCheap(RawUnionValue union) {
    int index = getIndexForEncoding(union);
    @SuppressWarnings("unchecked")
    Coder<Object> coder = (Coder<Object>) elementCoders.get(index);
    return coder.isRegisterByteSizeObserverCheap(union.getValue());
  }

  /**
   * Notifies ElementByteSizeObserver about the byte size of the encoded value using this coder.
   *
   * @throws IllegalArgumentException if {@code union} is null or its tag is out of range
   */
  @Override
  public void registerByteSizeObserver(
      RawUnionValue union, ElementByteSizeObserver observer)
      throws Exception {
    int index = getIndexForEncoding(union);
    // Account for the bytes taken by the union tag.
    observer.update(VarInt.getLength(index));
    // Let the element coder account for the bytes taken by the actual value.
    @SuppressWarnings("unchecked")
    Coder<Object> coder = (Coder<Object>) elementCoders.get(index);
    coder.registerByteSizeObserver(union.getValue(), observer);
  }

  /////////////////////////////////////////////////////////////////////////////

  /** Element coders indexed by union tag; an unmodifiable private copy. */
  private final List<Coder<?>> elementCoders;

  private UnionCoder(List<Coder<?>> elementCoders) {
    // Copy defensively so the tag-to-coder mapping cannot change after construction, either
    // through the caller's list or through the lists returned by the accessors.
    this.elementCoders =
        Collections.unmodifiableList(new ArrayList<Coder<?>>(elementCoders));
  }

  /**
   * Verifies that every element coder is deterministic.
   *
   * @throws NonDeterministicException if any element coder is not deterministic
   */
  @Override
  public void verifyDeterministic() throws NonDeterministicException {
    verifyDeterministic(
        this,
        "UnionCoder is only deterministic if all element coders are",
        elementCoders);
  }
}