/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.io.mqtt; import static com.google.common.base.Preconditions.checkArgument; import com.google.auto.value.AutoValue; import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.NoSuchElementException; import java.util.UUID; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PDone; import org.fusesource.mqtt.client.BlockingConnection; import org.fusesource.mqtt.client.MQTT; import org.fusesource.mqtt.client.Message; import org.fusesource.mqtt.client.QoS; import org.fusesource.mqtt.client.Topic; import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * An unbounded source for MQTT broker. * * <h3>Reading from a MQTT broker</h3> * * <p>MqttIO source returns an unbounded {@link PCollection} containing MQTT message * payloads (as {@code byte[]}). * * <p>To configure a MQTT source, you have to provide a MQTT connection configuration including * {@code ClientId}, a {@code ServerURI}, a {@code Topic} pattern, and optionally {@code * username} and {@code password} to connect to the MQTT broker. The following * example illustrates various options for configuring the source: * * <pre>{@code * * pipeline.apply( * MqttIO.read() * .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create( * "tcp://host:11883", * "my_topic")) * * }</pre> * * <h3>Writing to a MQTT broker</h3> * * <p>MqttIO sink supports writing {@code byte[]} to a topic on a MQTT broker. * * <p>To configure a MQTT sink, as for the read, you have to specify a MQTT connection * configuration with {@code ServerURI}, {@code Topic}, ... * * <p>The MqttIO only fully supports QoS 1 (at least once). It's the only QoS level guaranteed * due to potential retries on bundles. * * <p>For instance: * * <pre>{@code * * pipeline * .apply(...) // provide PCollection<byte[]> * .MqttIO.write() * .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create( * "tcp://host:11883", * "my_topic")) * * }</pre> */ @Experimental public class MqttIO { private static final Logger LOG = LoggerFactory.getLogger(MqttIO.class); public static Read read() { return new AutoValue_MqttIO_Read.Builder() .setMaxReadTime(null).setMaxNumRecords(Long.MAX_VALUE).build(); } public static Write write() { return new AutoValue_MqttIO_Write.Builder() .setRetained(false) .build(); } private MqttIO() { } /** * A POJO describing a MQTT connection. */ @AutoValue public abstract static class ConnectionConfiguration implements Serializable { @Nullable abstract String getServerUri(); @Nullable abstract String getTopic(); @Nullable abstract String getClientId(); @Nullable abstract String getUsername(); @Nullable abstract String getPassword(); abstract Builder builder(); @AutoValue.Builder abstract static class Builder { abstract Builder setServerUri(String serverUri); abstract Builder setTopic(String topic); abstract Builder setClientId(String clientId); abstract Builder setUsername(String username); abstract Builder setPassword(String password); abstract ConnectionConfiguration build(); } /** * Describe a connection configuration to the MQTT broker. This method creates an unique random * MQTT client ID. * * @param serverUri The MQTT broker URI. * @param topic The MQTT getTopic pattern. * @return A connection configuration to the MQTT broker. */ public static ConnectionConfiguration create(String serverUri, String topic) { checkArgument(serverUri != null, "MqttIO.ConnectionConfiguration.create(serverUri, topic) called with null " + "serverUri"); checkArgument(topic != null, "MqttIO.ConnectionConfiguration.create(serverUri, topic) called with null " + "topic"); return new AutoValue_MqttIO_ConnectionConfiguration.Builder().setServerUri(serverUri) .setTopic(topic).build(); } /** * Describe a connection configuration to the MQTT broker. * * @param serverUri The MQTT broker URI. * @param topic The MQTT getTopic pattern. * @param clientId A client ID prefix, used to construct an unique client ID. * @return A connection configuration to the MQTT broker. */ public static ConnectionConfiguration create(String serverUri, String topic, String clientId) { checkArgument(serverUri != null, "MqttIO.ConnectionConfiguration.create(serverUri, topic) called with null " + "serverUri"); checkArgument(topic != null, "MqttIO.ConnectionConfiguration.create(serverUri, topic) called with null " + "topic"); checkArgument(clientId != null, "MqttIO.ConnectionConfiguration.create(serverUri," + "topic, clientId) called with null clientId"); return new AutoValue_MqttIO_ConnectionConfiguration.Builder().setServerUri(serverUri) .setTopic(topic).setClientId(clientId).build(); } public ConnectionConfiguration withUsername(String username) { return builder().setUsername(username).build(); } public ConnectionConfiguration withPassword(String password) { return builder().setPassword(password).build(); } private void populateDisplayData(DisplayData.Builder builder) { builder.add(DisplayData.item("serverUri", getServerUri())); builder.add(DisplayData.item("topic", getTopic())); builder.addIfNotNull(DisplayData.item("clientId", getClientId())); builder.addIfNotNull(DisplayData.item("username", getUsername())); } private MQTT createClient() throws Exception { LOG.debug("Creating MQTT client to {}", getServerUri()); MQTT client = new MQTT(); client.setHost(getServerUri()); if (getUsername() != null) { LOG.debug("MQTT client uses username {}", getUsername()); client.setUserName(getUsername()); client.setPassword(getPassword()); } if (getClientId() != null) { String clientId = getClientId() + "-" + UUID.randomUUID().toString(); LOG.debug("MQTT client id set to {}", clientId); client.setClientId(clientId); } else { String clientId = UUID.randomUUID().toString(); LOG.debug("MQTT client id set to random value {}", clientId); client.setClientId(clientId); } return client; } } /** * A {@link PTransform} to read from a MQTT broker. */ @AutoValue public abstract static class Read extends PTransform<PBegin, PCollection<byte[]>> { @Nullable abstract ConnectionConfiguration connectionConfiguration(); abstract long maxNumRecords(); @Nullable abstract Duration maxReadTime(); abstract Builder builder(); @AutoValue.Builder abstract static class Builder { abstract Builder setConnectionConfiguration(ConnectionConfiguration config); abstract Builder setMaxNumRecords(long maxNumRecords); abstract Builder setMaxReadTime(Duration maxReadTime); abstract Read build(); } /** * Define the MQTT connection configuration used to connect to the MQTT broker. */ public Read withConnectionConfiguration(ConnectionConfiguration configuration) { checkArgument(configuration != null, "MqttIO.read().withConnectionConfiguration(configuration) called with null " + "configuration or not called at all"); return builder().setConnectionConfiguration(configuration).build(); } /** * Define the max number of records received by the {@link Read}. * When this max number of records is lower than {@code Long.MAX_VALUE}, the {@link Read} * will provide a bounded {@link PCollection}. */ public Read withMaxNumRecords(long maxNumRecords) { checkArgument(maxReadTime() == null, "maxNumRecord and maxReadTime are exclusive"); return builder().setMaxNumRecords(maxNumRecords).build(); } /** * Define the max read time (duration) while the {@link Read} will receive messages. * When this max read time is not null, the {@link Read} will provide a bounded * {@link PCollection}. */ public Read withMaxReadTime(Duration maxReadTime) { checkArgument(maxNumRecords() == Long.MAX_VALUE, "maxNumRecord and maxReadTime are exclusive"); return builder().setMaxReadTime(maxReadTime).build(); } @Override public PCollection<byte[]> expand(PBegin input) { org.apache.beam.sdk.io.Read.Unbounded<byte[]> unbounded = org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource(this)); PTransform<PBegin, PCollection<byte[]>> transform = unbounded; if (maxNumRecords() != Long.MAX_VALUE) { transform = unbounded.withMaxNumRecords(maxNumRecords()); } else if (maxReadTime() != null) { transform = unbounded.withMaxReadTime(maxReadTime()); } return input.getPipeline().apply(transform); } @Override public void validate(PipelineOptions options) { // validation is performed in the ConnectionConfiguration create() } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); connectionConfiguration().populateDisplayData(builder); if (maxNumRecords() != Long.MAX_VALUE) { builder.add(DisplayData.item("maxNumRecords", maxNumRecords())); } builder.addIfNotNull(DisplayData.item("maxReadTime", maxReadTime())); } } /** * Checkpoint for an unbounded MQTT source. Consists of the MQTT messages waiting to be * acknowledged and oldest pending message timestamp. */ private static class MqttCheckpointMark implements UnboundedSource.CheckpointMark, Serializable { private String clientId; private Instant oldestMessageTimestamp = Instant.now(); private transient List<Message> messages = new ArrayList<>(); public MqttCheckpointMark() { } public void add(Message message, Instant timestamp) { if (timestamp.isBefore(oldestMessageTimestamp)) { oldestMessageTimestamp = timestamp; } messages.add(message); } @Override public void finalizeCheckpoint() { LOG.debug("Finalizing checkpoint acknowledging pending messages for client ID {}", clientId); for (Message message : messages) { try { message.ack(); } catch (Exception e) { LOG.warn("Can't ack message for client ID {}", clientId, e); } } oldestMessageTimestamp = Instant.now(); messages.clear(); } // set an empty list to messages when deserialize private void readObject(java.io.ObjectInputStream stream) throws java.io.IOException, ClassNotFoundException { messages = new ArrayList<>(); } } private static class UnboundedMqttSource extends UnboundedSource<byte[], MqttCheckpointMark> { private final Read spec; public UnboundedMqttSource(Read spec) { this.spec = spec; } @Override public UnboundedReader<byte[]> createReader(PipelineOptions options, MqttCheckpointMark checkpointMark) { return new UnboundedMqttReader(this, checkpointMark); } @Override public List<UnboundedMqttSource> split(int desiredNumSplits, PipelineOptions options) { // MQTT is based on a pub/sub pattern // so, if we create several subscribers on the same topic, they all will receive the same // message, resulting to duplicate messages in the PCollection. // So, for MQTT, we limit to number of split ot 1 (unique source). return Collections.singletonList(new UnboundedMqttSource(spec)); } @Override public void validate() { spec.validate(null); } @Override public void populateDisplayData(DisplayData.Builder builder) { spec.populateDisplayData(builder); } @Override public Coder<MqttCheckpointMark> getCheckpointMarkCoder() { return SerializableCoder.of(MqttCheckpointMark.class); } @Override public Coder<byte[]> getDefaultOutputCoder() { return ByteArrayCoder.of(); } } private static class UnboundedMqttReader extends UnboundedSource.UnboundedReader<byte[]> { private final UnboundedMqttSource source; private MQTT client; private BlockingConnection connection; private byte[] current; private Instant currentTimestamp; private MqttCheckpointMark checkpointMark; public UnboundedMqttReader(UnboundedMqttSource source, MqttCheckpointMark checkpointMark) { this.source = source; this.current = null; if (checkpointMark != null) { this.checkpointMark = checkpointMark; } else { this.checkpointMark = new MqttCheckpointMark(); } } @Override public boolean start() throws IOException { LOG.debug("Starting MQTT reader ..."); Read spec = source.spec; try { client = spec.connectionConfiguration().createClient(); LOG.debug("Reader client ID is {}", client.getClientId()); checkpointMark.clientId = client.getClientId().toString(); connection = client.blockingConnection(); connection.connect(); connection.subscribe(new Topic[]{ new Topic(spec.connectionConfiguration().getTopic(), QoS.AT_LEAST_ONCE)}); return advance(); } catch (Exception e) { throw new IOException(e); } } @Override public boolean advance() throws IOException { try { LOG.debug("MQTT reader (client ID {}) waiting message ...", client.getClientId()); Message message = connection.receive(); current = message.getPayload(); currentTimestamp = Instant.now(); checkpointMark.add(message, currentTimestamp); } catch (Exception e) { throw new IOException(e); } return true; } @Override public void close() throws IOException { LOG.debug("Closing MQTT reader (client ID {})", client.getClientId()); try { if (connection != null) { connection.disconnect(); } } catch (Exception e) { throw new IOException(e); } } @Override public Instant getWatermark() { return checkpointMark.oldestMessageTimestamp; } @Override public UnboundedSource.CheckpointMark getCheckpointMark() { return checkpointMark; } @Override public byte[] getCurrent() { if (current == null) { throw new NoSuchElementException(); } return current; } @Override public Instant getCurrentTimestamp() { if (current == null) { throw new NoSuchElementException(); } return currentTimestamp; } @Override public UnboundedMqttSource getCurrentSource() { return source; } } /** * A {@link PTransform} to write and send a message to a MQTT server. */ @AutoValue public abstract static class Write extends PTransform<PCollection<byte[]>, PDone> { @Nullable abstract ConnectionConfiguration connectionConfiguration(); abstract boolean retained(); abstract Builder builder(); @AutoValue.Builder abstract static class Builder { abstract Builder setConnectionConfiguration(ConnectionConfiguration configuration); abstract Builder setRetained(boolean retained); abstract Write build(); } /** * Define MQTT connection configuration used to connect to the MQTT broker. */ public Write withConnectionConfiguration(ConnectionConfiguration configuration) { checkArgument(configuration != null, "MqttIO.write().withConnectionConfiguration(configuration) called with null " + "configuration or not called at all"); return builder().setConnectionConfiguration(configuration).build(); } /** * Whether or not the publish message should be retained by the messaging engine. * Sending a message with the retained set to {@code false} will clear the * retained message from the server. The default value is {@code false}. * When a subscriber connects, it gets the latest retained message (else it doesn't get any * existing message, it will have to wait a new incoming message). * * @param retained Whether or not the messaging engine should retain the message. * @return The {@link Write} {@link PTransform} with the corresponding retained configuration. */ public Write withRetained(boolean retained) { return builder().setRetained(retained).build(); } @Override public PDone expand(PCollection<byte[]> input) { input.apply(ParDo.of(new WriteFn(this))); return PDone.in(input.getPipeline()); } @Override public void validate(PipelineOptions options) { // validate is done in connection configuration } @Override public void populateDisplayData(DisplayData.Builder builder) { connectionConfiguration().populateDisplayData(builder); builder.add(DisplayData.item("retained", retained())); } private static class WriteFn extends DoFn<byte[], Void> { private final Write spec; private transient MQTT client; private transient BlockingConnection connection; public WriteFn(Write spec) { this.spec = spec; } @Setup public void createMqttClient() throws Exception { LOG.debug("Starting MQTT writer"); client = spec.connectionConfiguration().createClient(); LOG.debug("MQTT writer client ID is {}", client.getClientId()); connection = client.blockingConnection(); connection.connect(); } @ProcessElement public void processElement(ProcessContext context) throws Exception { byte[] payload = context.element(); LOG.debug("Sending message {}", new String(payload)); connection.publish(spec.connectionConfiguration().getTopic(), payload, QoS.AT_LEAST_ONCE, false); } @Teardown public void closeMqttClient() throws Exception { if (connection != null) { LOG.debug("Disconnecting MQTT connection (client ID {})", client.getClientId()); connection.disconnect(); } } } } }