package org.deeplearning4j.streaming.routes;
import com.google.common.io.Files;
import org.apache.camel.*;
import org.apache.camel.builder.RouteBuilder;
import org.apache.camel.component.kafka.KafkaConstants;
import org.apache.camel.model.ProcessorDefinition;
import org.apache.camel.test.junit4.CamelTestSupport;
import org.apache.commons.io.FileUtils;
import org.apache.commons.net.util.Base64;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.streaming.embedded.EmbeddedKafkaCluster;
import org.deeplearning4j.streaming.embedded.EmbeddedZookeeper;
import org.deeplearning4j.streaming.embedded.TestUtils;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.UUID;
/**
 * End-to-end test for {@link DL4jServeRouteBuilder}.
 *
 * <p>An embedded ZooKeeper and Kafka broker are started once for the class. The test trains a
 * small network on the (normalized) Iris data set and saves it to a per-test temporary
 * directory. A {@code direct:start} route publishes the Base64-encoded Iris feature matrix to a
 * Kafka topic. The serving route built by {@link DL4jServeRouteBuilder} consumes that topic with
 * the saved model and writes to a file endpoint in the same temporary directory.
 *
 * <p>Created by agibsonccc on 6/12/16.
 */
public class Dl4jServingRouteTest extends CamelTestSupport {
    public static final String LOCALHOST = "localhost";
    /** Kafka topic the producer route publishes to and the serving route consumes from. */
    private static final String TOPIC_NAME = "predict";
    /** File name (inside {@link #dir}) that the serving route's file endpoint writes to. */
    private static final String OUTPUT_FILE_NAME = "tmp.txt";
    /** File name (inside {@link #dir}) the trained model is serialized to. */
    private static final String MODEL_FILE_NAME = "networktest.zip";

    private static EmbeddedZookeeper zookeeper;
    private static EmbeddedKafkaCluster kafkaCluster;
    private static int zkPort;

    /** Per-test scratch directory for the model and the route output; deleted in {@link #after()}. */
    private final File dir = Files.createTempDir();
    /** Normalized Iris data set; populated in {@link #createRouteBuilder()}. */
    private DataSet next;

    /** Starts the shared embedded ZooKeeper and Kafka broker and creates the test topic. */
    @BeforeClass
    public static void init() throws Exception {
        zkPort = TestUtils.getAvailablePort();
        zookeeper = new EmbeddedZookeeper(zkPort);
        zookeeper.startup();
        kafkaCluster = new EmbeddedKafkaCluster(LOCALHOST + ":" + zkPort);
        kafkaCluster.startup();
        kafkaCluster.createTopics(TOPIC_NAME);
    }

    /**
     * Shuts down the embedded Kafka broker and then ZooKeeper.
     *
     * <p>The null checks keep a failure part-way through {@link #init()} from being masked by a
     * NullPointerException here. The try/finally ensures ZooKeeper is still stopped if the Kafka
     * shutdown fails.
     */
    @AfterClass
    public static void after2() {
        try {
            if (kafkaCluster != null) {
                kafkaCluster.shutdown();
            }
        } finally {
            if (zookeeper != null) {
                zookeeper.shutdown();
            }
        }
    }

    /**
     * Loads and normalizes the Iris data set into {@link #next}, then defines the producer route.
     * Any message sent to {@code direct:start} has its body replaced by the Base64-encoded,
     * Nd4j-serialized feature matrix and is published to {@link #TOPIC_NAME}.
     */
    @Override
    protected RouteBuilder createRouteBuilder() throws Exception {
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        next = iter.next();
        next.normalizeZeroMeanZeroUnitVariance();
        return new RouteBuilder() {
            @Override
            public void configure() throws Exception {
                final String kafkaUri = String.format("kafka:%s?topic=%s&groupId=dl4j-serving",
                                kafkaCluster.getBrokerList(), TOPIC_NAME);
                from("direct:start").process(new Processor() {
                    @Override
                    public void process(Exchange exchange) throws Exception {
                        // The incoming body is ignored; the payload is always the full feature matrix.
                        exchange.getIn().setBody(encodeFeatures(next.getFeatureMatrix()), String.class);
                        exchange.getIn().setHeader(KafkaConstants.KEY, UUID.randomUUID().toString());
                        exchange.getIn().setHeader(KafkaConstants.PARTITION_KEY, "1");
                    }
                }).to(kafkaUri);
            }
        };
    }

    /**
     * Serializes {@code arr} with {@link Nd4j#write} and returns the bytes Base64-encoded.
     *
     * @param arr the array to serialize
     * @return Base64 string of the serialized array
     * @throws IOException if serialization fails
     */
    private static String encodeFeatures(INDArray arr) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        // Close (and thereby flush) the data stream before reading the buffered bytes.
        try (DataOutputStream dos = new DataOutputStream(bos)) {
            Nd4j.write(arr, dos);
        }
        return Base64.encodeBase64String(bos.toByteArray());
    }

    @Override
    public boolean isUseDebugger() {
        // must enable debugger so that debugBefore() is invoked
        return true;
    }

    /** Logs each processor's definition and the exchange body before the processor runs. */
    @Override
    protected void debugBefore(Exchange exchange, Processor processor, ProcessorDefinition<?> definition, String id,
                    String shortName) {
        // this method is invoked before we are about to enter the given processor
        // from your Java editor you can just add a breakpoint in the code line below
        log.info("Before {} with body {}", definition, exchange.getIn().getBody());
    }

    /** Removes the per-test scratch directory, including the saved model and any route output. */
    @After
    public void after() throws Exception {
        FileUtils.deleteDirectory(dir);
    }

    /**
     * Trains and saves a model, starts the serving route, and pushes one message through the
     * {@code direct:start} -> Kafka -> serving route -> file pipeline.
     */
    @Test
    public void testServingRoute() throws Exception {
        MultiLayerNetwork network = trainNetwork();

        dir.mkdirs();
        File outputFile = new File(dir, OUTPUT_FILE_NAME);
        outputFile.createNewFile();

        // Save the model inside the per-test temp dir so after() cleans it up. Otherwise it
        // would be left behind in the process working directory.
        String modelPath = new File(dir, MODEL_FILE_NAME).getAbsolutePath();
        ModelSerializer.writeModel(network, modelPath, false);

        final String uri = String.format("file://%s?fileName=%s", dir.getAbsolutePath(), OUTPUT_FILE_NAME);
        context.addRoutes(DL4jServeRouteBuilder.builder().computationGraph(false)
                        .zooKeeperPort(zookeeper.getPort()).kafkaBroker(kafkaCluster.getBrokerList())
                        .consumingTopic(TOPIC_NAME).modelUri(modelPath).outputUri(uri)
                        .finalProcessor(new Processor() {
                            @Override
                            public void process(Exchange exchange) throws Exception {
                                exchange.getIn().setBody(exchange.getIn().getBody().toString());
                            }
                        }).build());
        context.startAllRoutes();

        // NOTE(review): relies on route ordering. Index 0 is direct:start from createRouteBuilder();
        // index 1 is assumed to be the serving route added above.
        Endpoint endpoint = context.getRoutes().get(1).getConsumer().getEndpoint();
        ConsumerTemplate consumerTemplate = context.createConsumerTemplate();
        ProducerTemplate producerTemplate = context.createProducerTemplate();
        try {
            producerTemplate.sendBody("direct:start", "hello");
            consumerTemplate.receiveBody(endpoint, 3000, String.class);
        } finally {
            producerTemplate.stop();
            consumerTemplate.stop();
        }

        String contents = FileUtils.readFileToString(outputFile, StandardCharsets.UTF_8);
        // NOTE(review): no assertion is made on the served output. The test only checks that the
        // pipeline runs without throwing. Pin the expected output format and assert on it.
        log.info("Serving route output: {}", contents);
    }

    /** Builds a small 4-3-2-3 feed-forward network and fits it on the Iris data in {@link #next}. */
    private MultiLayerNetwork trainNetwork() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).iterations(5).seed(123).list()
                        .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER)
                                        .activation(Activation.TANH).build())
                        .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER)
                                        .activation(Activation.TANH).build())
                        .layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                        LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER)
                                                        .activation(Activation.SOFTMAX).nIn(2).nOut(3).build())
                        .backprop(true).pretrain(false).build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.setListeners(new ScoreIterationListener(1));
        network.fit(next);
        return network;
    }
}