package org.deeplearning4j.streaming.embedded; /** * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the * License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ import kafka.admin.AdminUtils; import kafka.admin.RackAwareMode; import kafka.metrics.KafkaMetricsReporter; import kafka.server.KafkaConfig; import kafka.server.KafkaServer; import kafka.utils.ZkUtils; import scala.Option; import scala.collection.mutable.Buffer; import java.io.File; import java.io.FileNotFoundException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Properties; public class EmbeddedKafkaCluster { private final List<Integer> ports; private final String zkConnection; private final Properties baseProperties; private final String brokerList; private final List<KafkaServer> brokers; private final List<File> logDirs; public EmbeddedKafkaCluster(String zkConnection) { this(zkConnection, new Properties()); } public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties) { this(zkConnection, baseProperties, Collections.singletonList(-1)); } public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties, List<Integer> ports) { this.zkConnection = zkConnection; this.ports = resolvePorts(ports); this.baseProperties = baseProperties; this.brokers = new ArrayList<KafkaServer>(); this.logDirs = new ArrayList<File>(); this.brokerList = constructBrokerList(this.ports); } public ZkUtils getZkUtils() { for (KafkaServer server : brokers) { return server.zkUtils(); } return null; } public void createTopic(String topic, int partitionCount) { AdminUtils.createTopic(getZkUtils(), topic, partitionCount, 1, new Properties(), RackAwareMode.Enforced$.MODULE$); } public void createTopics(String... topics) { for (String topic : topics) { AdminUtils.createTopic(getZkUtils(), topic, 2, 1, new Properties(), RackAwareMode.Enforced$.MODULE$); } } private List<Integer> resolvePorts(List<Integer> ports) { List<Integer> resolvedPorts = new ArrayList<Integer>(ports.size()); for (Integer port : ports) { resolvedPorts.add(resolvePort(port)); } return resolvedPorts; } private int resolvePort(int port) { if (port == -1) { return TestUtils.getAvailablePort(); } return port; } private String constructBrokerList(List<Integer> ports) { StringBuilder sb = new StringBuilder(); for (Integer port : ports) { if (sb.length() > 0) { sb.append(","); } sb.append("localhost:").append(port); } return sb.toString(); } public void startup() { for (int i = 0; i < ports.size(); i++) { Integer port = ports.get(i); File logDir = TestUtils.constructTempDir("kafka-local"); Properties properties = new Properties(); properties.putAll(baseProperties); properties.setProperty("zookeeper.connect", zkConnection); properties.setProperty("broker.id", String.valueOf(i + 1)); properties.setProperty("host.name", "localhost"); properties.setProperty("port", Integer.toString(port)); properties.setProperty("log.dir", logDir.getAbsolutePath()); properties.setProperty("num.partitions", String.valueOf(1)); properties.setProperty("auto.create.topics.enable", String.valueOf(Boolean.TRUE)); System.out.println("EmbeddedKafkaCluster: local directory: " + logDir.getAbsolutePath()); properties.setProperty("log.flush.interval.messages", String.valueOf(1)); KafkaServer broker = startBroker(properties); brokers.add(broker); logDirs.add(logDir); } } private KafkaServer startBroker(Properties props) { List<KafkaMetricsReporter> kmrList = new ArrayList<>(); Buffer<KafkaMetricsReporter> metricsList = scala.collection.JavaConversions.asScalaBuffer(kmrList); KafkaServer server = new KafkaServer(new KafkaConfig(props), new SystemTime(), Option.<String>empty(), metricsList); server.startup(); return server; } public Properties getProps() { Properties props = new Properties(); props.putAll(baseProperties); props.put("metadata.broker.list", brokerList); props.put("zookeeper.connect", zkConnection); return props; } public String getBrokerList() { return brokerList; } public List<Integer> getPorts() { return ports; } public String getZkConnection() { return zkConnection; } public void shutdown() { for (KafkaServer broker : brokers) { try { broker.shutdown(); } catch (Exception e) { e.printStackTrace(); } } for (File logDir : logDirs) { try { TestUtils.deleteFile(logDir); } catch (FileNotFoundException e) { e.printStackTrace(); } } } @Override public String toString() { final StringBuilder sb = new StringBuilder("EmbeddedKafkaCluster{"); sb.append("brokerList='").append(brokerList).append('\''); sb.append('}'); return sb.toString(); } }