/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.io; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.hash.HashFunction; import com.google.common.hash.Hashing; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; import java.util.NoSuchElementException; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.io.Read.Bounded; import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.MatchResult.Metadata; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.util.MimeTypes; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; /** * {@link PTransform}s for reading and writing TensorFlow TFRecord files. */ public class TFRecordIO { /** The default coder, which returns each record of the input file as a byte array. */ public static final Coder<byte[]> DEFAULT_BYTE_ARRAY_CODER = ByteArrayCoder.of(); /** * A {@link PTransform} that reads from a TFRecord file (or multiple TFRecord * files matching a pattern) and returns a {@link PCollection} containing * the decoding of each of the records of the TFRecord file(s) as a byte array. */ public static Read read() { return new AutoValue_TFRecordIO_Read.Builder() .setValidate(true) .setCompressionType(CompressionType.AUTO) .build(); } /** * A {@link PTransform} that writes a {@link PCollection} to TFRecord file (or * multiple TFRecord files matching a sharding pattern), with each * element of the input collection encoded into its own record. */ public static Write write() { return new AutoValue_TFRecordIO_Write.Builder() .setShardTemplate(null) .setFilenameSuffix(null) .setNumShards(0) .setCompressionType(CompressionType.NONE) .build(); } /** Implementation of {@link #read}. */ @AutoValue public abstract static class Read extends PTransform<PBegin, PCollection<byte[]>> { @Nullable abstract ValueProvider<String> getFilepattern(); abstract boolean getValidate(); abstract CompressionType getCompressionType(); abstract Builder toBuilder(); @AutoValue.Builder abstract static class Builder { abstract Builder setFilepattern(ValueProvider<String> filepattern); abstract Builder setValidate(boolean validate); abstract Builder setCompressionType(CompressionType compressionType); abstract Read build(); } /** * Returns a transform for reading TFRecord files that reads from the file(s) * with the given filename or filename pattern. This can be a local path (if running locally), * or a Google Cloud Storage filename or filename pattern of the form * {@code "gs://<bucket>/<filepath>"} (if running locally or using remote * execution). Standard <a href="http://docs.oracle.com/javase/tutorial/essential/io/find.html" * >Java Filesystem glob patterns</a> ("*", "?", "[..]") are supported. */ public Read from(String filepattern) { return from(StaticValueProvider.of(filepattern)); } /** * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}. */ public Read from(ValueProvider<String> filepattern) { return toBuilder().setFilepattern(filepattern).build(); } /** * Returns a transform for reading TFRecord files that has GCS path validation on * pipeline creation disabled. * * <p>This can be useful in the case where the GCS input does not * exist at the pipeline creation time, but is expected to be * available at execution time. */ public Read withoutValidation() { return toBuilder().setValidate(false).build(); } /** * Returns a transform for reading TFRecord files that decompresses all input files * using the specified compression type. * * <p>If no compression type is specified, the default is * {@link TFRecordIO.CompressionType#AUTO}. * In this mode, the compression type of the file is determined by its extension * (e.g., {@code *.gz} is gzipped, {@code *.zlib} is zlib compressed, and all other * extensions are uncompressed). */ public Read withCompressionType(TFRecordIO.CompressionType compressionType) { return toBuilder().setCompressionType(compressionType).build(); } @Override public PCollection<byte[]> expand(PBegin input) { if (getFilepattern() == null) { throw new IllegalStateException( "Need to set the filepattern of a TFRecordIO.Read transform"); } if (getValidate()) { checkState(getFilepattern().isAccessible(), "Cannot validate with a RVP."); try { MatchResult matches = FileSystems.match(getFilepattern().get()); checkState( !matches.metadata().isEmpty(), "Unable to find any files matching %s", getFilepattern().get()); } catch (IOException e) { throw new IllegalStateException( String.format("Failed to validate %s", getFilepattern().get()), e); } } final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource()); PCollection<byte[]> pcol = input.getPipeline().apply("Read", read); // Honor the default output coder that would have been used by this PTransform. pcol.setCoder(getDefaultOutputCoder()); return pcol; } // Helper to create a source specific to the requested compression type. protected FileBasedSource<byte[]> getSource() { switch (getCompressionType()) { case NONE: return new TFRecordSource(getFilepattern()); case AUTO: return CompressedSource.from(new TFRecordSource(getFilepattern())); case GZIP: return CompressedSource.from(new TFRecordSource(getFilepattern())) .withDecompression(CompressedSource.CompressionMode.GZIP); case ZLIB: return CompressedSource.from(new TFRecordSource(getFilepattern())) .withDecompression(CompressedSource.CompressionMode.DEFLATE); default: throw new IllegalArgumentException("Unknown compression type: " + getCompressionType()); } } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); String filepatternDisplay = getFilepattern().isAccessible() ? getFilepattern().get() : getFilepattern().toString(); builder .add(DisplayData.item("compressionType", getCompressionType().toString()) .withLabel("Compression Type")) .addIfNotDefault(DisplayData.item("validation", getValidate()) .withLabel("Validation Enabled"), true) .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay) .withLabel("File Pattern")); } @Override protected Coder<byte[]> getDefaultOutputCoder() { return ByteArrayCoder.of(); } } ///////////////////////////////////////////////////////////////////////////// /** Implementation of {@link #write}. */ @AutoValue public abstract static class Write extends PTransform<PCollection<byte[]>, PDone> { /** The directory to which files will be written. */ @Nullable abstract ValueProvider<ResourceId> getOutputPrefix(); /** The suffix of each file written, combined with prefix and shardTemplate. */ @Nullable abstract String getFilenameSuffix(); /** Requested number of shards. 0 for automatic. */ abstract int getNumShards(); /** The shard template of each file written, combined with prefix and suffix. */ @Nullable abstract String getShardTemplate(); /** Option to indicate the output sink's compression type. Default is NONE. */ abstract CompressionType getCompressionType(); abstract Builder toBuilder(); @AutoValue.Builder abstract static class Builder { abstract Builder setOutputPrefix(ValueProvider<ResourceId> outputPrefix); abstract Builder setShardTemplate(String shardTemplate); abstract Builder setFilenameSuffix(String filenameSuffix); abstract Builder setNumShards(int numShards); abstract Builder setCompressionType(CompressionType compressionType); abstract Write build(); } /** * Writes TFRecord file(s) with the given output prefix. The {@code prefix} will be used as a * to generate a {@link ResourceId} using any supported {@link FileSystem}. * * <p>In addition to their prefix, created files will have a shard identifier (see * {@link #withNumShards(int)}), and end in a common suffix, if given by * {@link #withSuffix(String)}. * * <p>For more information on filenames, see {@link DefaultFilenamePolicy}. */ public Write to(String outputPrefix) { return to(FileBasedSink.convertToFileResourceIfPossible(outputPrefix)); } /** * Writes TFRecord file(s) with a prefix given by the specified resource. * * <p>In addition to their prefix, created files will have a shard identifier (see * {@link #withNumShards(int)}), and end in a common suffix, if given by * {@link #withSuffix(String)}. * * <p>For more information on filenames, see {@link DefaultFilenamePolicy}. */ @Experimental(Kind.FILESYSTEM) public Write to(ResourceId outputResource) { return toResource(StaticValueProvider.of(outputResource)); } /** * Like {@link #to(ResourceId)}. */ @Experimental(Kind.FILESYSTEM) public Write toResource(ValueProvider<ResourceId> outputResource) { return toBuilder().setOutputPrefix(outputResource).build(); } /** * Writes to the file(s) with the given filename suffix. * * @see ShardNameTemplate */ public Write withSuffix(String suffix) { return toBuilder().setFilenameSuffix(suffix).build(); } /** * Writes to the provided number of shards. * * <p>Constraining the number of shards is likely to reduce * the performance of a pipeline. Setting this value is not recommended * unless you require a specific number of output files. * * @param numShards the number of shards to use, or 0 to let the system * decide. * @see ShardNameTemplate */ public Write withNumShards(int numShards) { checkArgument(numShards >= 0, "Number of shards %s must be >= 0", numShards); return toBuilder().setNumShards(numShards).build(); } /** * Uses the given shard name template. * * @see ShardNameTemplate */ public Write withShardNameTemplate(String shardTemplate) { return toBuilder().setShardTemplate(shardTemplate).build(); } /** * Forces a single file as output. * * <p>Constraining the number of shards is likely to reduce * the performance of a pipeline. Using this setting is not recommended * unless you truly require a single output file. * * <p>This is a shortcut for * {@code .withNumShards(1).withShardNameTemplate("")} */ public Write withoutSharding() { return withNumShards(1).withShardNameTemplate(""); } /** * Writes to output files using the specified compression type. * * <p>If no compression type is specified, the default is * {@link TFRecordIO.CompressionType#NONE}. * See {@link TFRecordIO.Read#withCompressionType} for more details. */ public Write withCompressionType(CompressionType compressionType) { return toBuilder().setCompressionType(compressionType).build(); } @Override public PDone expand(PCollection<byte[]> input) { checkState(getOutputPrefix() != null, "need to set the output prefix of a TFRecordIO.Write transform"); WriteFiles<byte[]> write = WriteFiles.to( new TFRecordSink( getOutputPrefix(), getShardTemplate(), getFilenameSuffix(), getCompressionType())); if (getNumShards() > 0) { write = write.withNumShards(getNumShards()); } return input.apply("Write", write); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); String outputPrefixString = null; if (getOutputPrefix().isAccessible()) { ResourceId dir = getOutputPrefix().get(); outputPrefixString = dir.toString(); } else { outputPrefixString = getOutputPrefix().toString(); } builder .add(DisplayData.item("filePrefix", outputPrefixString) .withLabel("Output File Prefix")) .addIfNotNull(DisplayData.item("fileSuffix", getFilenameSuffix()) .withLabel("Output File Suffix")) .addIfNotNull(DisplayData.item("shardNameTemplate", getShardTemplate()) .withLabel("Output Shard Name Template")) .addIfNotDefault(DisplayData.item("numShards", getNumShards()) .withLabel("Maximum Output Shards"), 0) .add(DisplayData.item("compressionType", getCompressionType().toString()) .withLabel("Compression Type")); } @Override protected Coder<Void> getDefaultOutputCoder() { return VoidCoder.of(); } } /** * Possible TFRecord file compression types. */ public enum CompressionType { /** * Automatically determine the compression type based on filename extension. */ AUTO(""), /** * Uncompressed. */ NONE(""), /** * GZipped. */ GZIP(".gz"), /** * ZLIB compressed. */ ZLIB(".zlib"); private String filenameSuffix; CompressionType(String suffix) { this.filenameSuffix = suffix; } /** * Determine if a given filename matches a compression type based on its extension. * @param filename the filename to match * @return true iff the filename ends with the compression type's known extension. */ public boolean matches(String filename) { return filename.toLowerCase().endsWith(filenameSuffix.toLowerCase()); } } ////////////////////////////////////////////////////////////////////////////// /** Disable construction of utility class. */ private TFRecordIO() {} /** * A {@link FileBasedSource} which can decode records in TFRecord files. */ @VisibleForTesting static class TFRecordSource extends FileBasedSource<byte[]> { @VisibleForTesting TFRecordSource(String fileSpec) { super(StaticValueProvider.of(fileSpec), 1L); } @VisibleForTesting TFRecordSource(ValueProvider<String> fileSpec) { super(fileSpec, Long.MAX_VALUE); } private TFRecordSource(Metadata metadata, long start, long end) { super(metadata, Long.MAX_VALUE, start, end); } @Override protected FileBasedSource<byte[]> createForSubrangeOfFile( Metadata metadata, long start, long end) { checkArgument(start == 0, "TFRecordSource is not splittable"); return new TFRecordSource(metadata, start, end); } @Override protected FileBasedReader<byte[]> createSingleFileReader(PipelineOptions options) { return new TFRecordReader(this); } @Override public Coder<byte[]> getDefaultOutputCoder() { return DEFAULT_BYTE_ARRAY_CODER; } @Override protected boolean isSplittable() throws Exception { // TFRecord files are not splittable return false; } /** * A {@link org.apache.beam.sdk.io.FileBasedSource.FileBasedReader FileBasedReader} * which can decode records in TFRecord files. * * <p>See {@link TFRecordIO.TFRecordSource} for further details. */ @VisibleForTesting static class TFRecordReader extends FileBasedReader<byte[]> { private long startOfRecord; private volatile long startOfNextRecord; private volatile boolean elementIsPresent; private byte[] currentValue; private ReadableByteChannel inChannel; private TFRecordCodec codec; private TFRecordReader(TFRecordSource source) { super(source); } @Override public boolean allowsDynamicSplitting() { /* TFRecords cannot be dynamically split. */ return false; } @Override protected long getCurrentOffset() throws NoSuchElementException { if (!elementIsPresent) { throw new NoSuchElementException(); } return startOfRecord; } @Override public byte[] getCurrent() throws NoSuchElementException { if (!elementIsPresent) { throw new NoSuchElementException(); } return currentValue; } @Override protected void startReading(ReadableByteChannel channel) throws IOException { this.inChannel = channel; this.codec = new TFRecordCodec(); } @Override protected boolean readNextRecord() throws IOException { startOfRecord = startOfNextRecord; currentValue = codec.read(inChannel); if (currentValue != null) { elementIsPresent = true; startOfNextRecord = startOfRecord + codec.recordLength(currentValue); return true; } else { elementIsPresent = false; return false; } } } } /** * A {@link FileBasedSink} for TFRecord files. Produces TFRecord files. */ @VisibleForTesting static class TFRecordSink extends FileBasedSink<byte[]> { @VisibleForTesting TFRecordSink(ValueProvider<ResourceId> outputPrefix, @Nullable String shardTemplate, @Nullable String suffix, TFRecordIO.CompressionType compressionType) { super( outputPrefix, DefaultFilenamePolicy.constructUsingStandardParameters( outputPrefix, shardTemplate, suffix), writableByteChannelFactory(compressionType)); } private static class ExtractDirectory implements SerializableFunction<ResourceId, ResourceId> { @Override public ResourceId apply(ResourceId input) { return input.getCurrentDirectory(); } } @Override public WriteOperation<byte[]> createWriteOperation() { return new TFRecordWriteOperation(this); } private static WritableByteChannelFactory writableByteChannelFactory( TFRecordIO.CompressionType compressionType) { switch (compressionType) { case AUTO: throw new IllegalArgumentException("Unsupported compression type AUTO"); case NONE: return CompressionType.UNCOMPRESSED; case GZIP: return CompressionType.GZIP; case ZLIB: return CompressionType.DEFLATE; } return CompressionType.UNCOMPRESSED; } /** * A {@link WriteOperation * WriteOperation} for TFRecord files. */ private static class TFRecordWriteOperation extends WriteOperation<byte[]> { private TFRecordWriteOperation(TFRecordSink sink) { super(sink); } @Override public Writer<byte[]> createWriter() throws Exception { return new TFRecordWriter(this); } } /** * A {@link Writer Writer} * for TFRecord files. */ private static class TFRecordWriter extends Writer<byte[]> { private WritableByteChannel outChannel; private TFRecordCodec codec; private TFRecordWriter(WriteOperation<byte[]> writeOperation) { super(writeOperation, MimeTypes.BINARY); } @Override protected void prepareWrite(WritableByteChannel channel) throws Exception { this.outChannel = channel; this.codec = new TFRecordCodec(); } @Override public void write(byte[] value) throws Exception { codec.write(outChannel, value); } } } ////////////////////////////////////////////////////////////////////////////// /** * Codec for TFRecords file format. * See https://www.tensorflow.org/api_guides/python/python_io#TFRecords_Format_Details */ private static class TFRecordCodec { private static final int HEADER_LEN = (Long.SIZE + Integer.SIZE) / Byte.SIZE; private static final int FOOTER_LEN = Integer.SIZE / Byte.SIZE; private static HashFunction crc32c = Hashing.crc32c(); private ByteBuffer header = ByteBuffer.allocate(HEADER_LEN).order(ByteOrder.LITTLE_ENDIAN); private ByteBuffer footer = ByteBuffer.allocate(FOOTER_LEN).order(ByteOrder.LITTLE_ENDIAN); private int mask(int crc) { return ((crc >>> 15) | (crc << 17)) + 0xa282ead8; } private int hashLong(long x) { return mask(crc32c.hashLong(x).asInt()); } private int hashBytes(byte[] x) { return mask(crc32c.hashBytes(x).asInt()); } public int recordLength(byte[] data) { return HEADER_LEN + data.length + FOOTER_LEN; } public byte[] read(ReadableByteChannel inChannel) throws IOException { header.clear(); int headerBytes = inChannel.read(header); if (headerBytes <= 0) { return null; } checkState( headerBytes == HEADER_LEN, "Not a valid TFRecord. Fewer than 12 bytes."); header.rewind(); long length = header.getLong(); int maskedCrc32OfLength = header.getInt(); checkState( hashLong(length) == maskedCrc32OfLength, "Mismatch of length mask"); ByteBuffer data = ByteBuffer.allocate((int) length); checkState(inChannel.read(data) == length, "Invalid data"); footer.clear(); inChannel.read(footer); footer.rewind(); int maskedCrc32OfData = footer.getInt(); checkState( hashBytes(data.array()) == maskedCrc32OfData, "Mismatch of data mask"); return data.array(); } public void write(WritableByteChannel outChannel, byte[] data) throws IOException { int maskedCrc32OfLength = hashLong(data.length); int maskedCrc32OfData = hashBytes(data); header.clear(); header.putLong(data.length).putInt(maskedCrc32OfLength); header.rewind(); outChannel.write(header); outChannel.write(ByteBuffer.wrap(data)); footer.clear(); footer.putInt(maskedCrc32OfData); footer.rewind(); outChannel.write(footer); } } }