package org.deeplearning4j.streaming.routes;
import lombok.AllArgsConstructor;
import lombok.Builder;
import org.apache.camel.Exchange;
import org.apache.camel.Processor;
import org.apache.camel.builder.RouteBuilder;
import org.apache.commons.net.util.Base64;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
/**
 * Camel route that serves model predictions for messages consumed from a
 * Kafka topic.
 * <p>
 * The route built by {@link #configure()} is:
 * {@code kafka endpoint -> beforeProcessor -> prediction -> finalProcessor -> outputUri}.
 * <p>
 * The message body reaching the prediction step must be either
 * <ul>
 * <li>a {@code byte[]} containing Base64 text which decodes to the binary
 * form read by {@code Nd4j.read(DataInputStream)}, or</li>
 * <li>an {@link INDArray} instance.</li>
 * </ul>
 * <p>
 * Field order is significant: it defines the parameter order of the Lombok
 * generated all-args constructor.
 *
 * @author Adam Gibson
 */
@AllArgsConstructor
@Builder
public class DL4jServeRouteBuilder extends RouteBuilder {
    /** Location handed to {@link ModelSerializer} to restore the model. */
    protected String modelUri;
    /** Broker part of the {@code kafka:} endpoint URI. */
    protected String kafkaBroker;
    /** Value of the {@code topic} option of the {@code kafka:} endpoint URI. */
    protected String consumingTopic;
    /**
     * When {@code true} the model is restored as a {@link ComputationGraph},
     * otherwise as a {@link MultiLayerNetwork}.
     */
    protected boolean computationGraph;
    /** Endpoint URI the route sends its result to. */
    protected String outputUri;
    /** Runs after the prediction step; {@code null} is replaced by a no-op in {@link #configure()}. */
    protected Processor finalProcessor;
    // The three initializers below are not applied by the Lombok builder
    // (that needs @Builder.Default) and can be overridden with null/0 through
    // the all-args constructor, so configure() re-applies them.
    /** Value of the {@code groupId} option of the {@code kafka:} endpoint URI. */
    protected String groupId = "dl4j-serving";
    /** Not referenced by the endpoint URI built in {@link #configure()}. */
    protected String zooKeeperHost = "localhost";
    /** Not referenced by the endpoint URI built in {@link #configure()}. */
    protected int zooKeeperPort = 2181;
    //default no-op
    /** Runs before the prediction step; {@code null} is replaced by a no-op in {@link #configure()}. */
    protected Processor beforeProcessor;

    /**
     * Builds the serving route using the Camel Java fluent builder syntax.
     * <p>
     * Unset optional settings ({@code groupId}, {@code zooKeeperHost},
     * {@code zooKeeperPort}, {@code beforeProcessor}, {@code finalProcessor})
     * are replaced by their defaults before the route is defined.
     *
     * @throws Exception can be thrown during configuration
     */
    @Override
    public void configure() throws Exception {
        if (groupId == null)
            groupId = "dl4j-serving";
        if (zooKeeperHost == null)
            zooKeeperHost = "localhost";
        // Same fallback as the two settings above: a builder-created instance
        // that never set the port carries 0 instead of the initializer value.
        if (zooKeeperPort <= 0)
            zooKeeperPort = 2181;

        String kafkaUri = String.format("kafka:%s?topic=%s&groupId=%s", kafkaBroker, consumingTopic, groupId);

        if (beforeProcessor == null)
            beforeProcessor = passThroughProcessor();
        // finalProcessor is optional in exactly the same way as beforeProcessor;
        // without this fallback a null would be handed to process(...).
        if (finalProcessor == null)
            finalProcessor = passThroughProcessor();

        from(kafkaUri).process(beforeProcessor).process(new Processor() {
            @Override
            public void process(Exchange exchange) throws Exception {
                INDArray predict = readInput(exchange.getIn().getBody());

                // NOTE(review): the model is restored from modelUri for every
                // message. Caching it would avoid repeated deserialization but
                // would share one network instance between consumer threads -
                // confirm thread-safety requirements before changing this.
                Object output;
                if (computationGraph) {
                    ComputationGraph graph = ModelSerializer.restoreComputationGraph(modelUri);
                    INDArray[] graphOutput = graph.output(predict);
                    output = graphOutput;
                } else {
                    MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelUri);
                    output = network.output(predict);
                }

                // The result is written to both the out and the in message,
                // in this order, as the original route did.
                exchange.getOut().setBody(output);
                exchange.getIn().setBody(output);
            }
        }).process(finalProcessor).to(outputUri);
    }

    /**
     * Converts a message body into the array to run the model on.
     *
     * @param body a {@code byte[]} of Base64 text wrapping a serialized
     *             ndarray, or an {@link INDArray}
     * @return the array to predict on
     * @throws ClassCastException if {@code body} is neither a {@code byte[]}
     *                            nor an {@link INDArray}
     * @throws Exception          if the decoded bytes cannot be read as an ndarray
     */
    private static INDArray readInput(Object body) throws Exception {
        if (body instanceof byte[]) {
            // Decode straight from the bytes: going through new String(byte[])
            // would depend on the platform default charset.
            byte[] serialized = Base64.decodeBase64((byte[]) body);
            DataInputStream dis = new DataInputStream(new ByteArrayInputStream(serialized));
            return Nd4j.read(dis);
        }
        return (INDArray) body;
    }

    /**
     * Creates a processor that leaves the exchange untouched.
     *
     * @return a new no-op processor
     */
    private static Processor passThroughProcessor() {
        return new Processor() {
            @Override
            public void process(Exchange exchange) throws Exception {
                // intentionally empty
            }
        };
    }
}