package org.nd4j.parameterserver.distributed; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.math3.util.Pair; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.enums.NodeRole; import org.nd4j.parameterserver.distributed.logic.*; import org.nd4j.parameterserver.distributed.logic.completion.Clipboard; import org.nd4j.parameterserver.distributed.logic.sequence.BasicSequenceProvider; import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage; import org.nd4j.parameterserver.distributed.messages.*; import org.nd4j.parameterserver.distributed.messages.requests.*; import org.nd4j.parameterserver.distributed.training.TrainingDriver; import org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer; import org.nd4j.parameterserver.distributed.transport.MulticastTransport; import org.nd4j.parameterserver.distributed.transport.RoutedTransport; import org.nd4j.parameterserver.distributed.transport.Transport; import java.net.InterfaceAddress; import java.net.NetworkInterface; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.atomic.AtomicBoolean; /** * This is "special case" distributed P2P parameter server implementation, suitable for Spark Word2Vec/ParagraphVectors/DeepWalk implementations. * Aeron is used as backbone for messaging system here. * * Highlights: * a) It does ONLY one-way messaging. Clients are NOT getting anything back from ParamServer. * b) It works sharded. Amount of shards is defined in runtime. * c) Data replication strategy is defined in runtime. * d) It's supposed to be used as singleton in Spark environment ONLY. * * @author raver119@gmail.com */ @Slf4j public class VoidParameterServer { private static final VoidParameterServer INSTANCE = new VoidParameterServer(); @Getter protected volatile NodeRole nodeRole; protected volatile VoidConfiguration voidConfiguration; protected AtomicBoolean initLocker = new AtomicBoolean(false); protected AtomicBoolean initFinished = new AtomicBoolean(false); protected AtomicBoolean shutdownLocker = new AtomicBoolean(false); protected AtomicBoolean shutdownFinished = new AtomicBoolean(false); protected transient Transport transport; protected transient AtomicBoolean manualMode = new AtomicBoolean(false); protected transient AtomicBoolean runner = new AtomicBoolean(false); protected transient Thread[] processingThreads; protected transient Runnable[] processingRunnables; // FIXME: we want trainer to be configurable here protected transient TrainingDriver<? extends TrainingMessage> trainer; protected short shardIndex; protected Clipboard clipboard = new Clipboard(); protected Storage storage = new WordVectorStorage(); protected Map<String, Frame<TrainingMessage>> frames = new ConcurrentHashMap<>(); protected static final int numThreads = Runtime.getRuntime().availableProcessors() * 2; protected ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2); ////////////////////// SeqVec part protected static double MAX_EXP = 6; ////////////////////// end of SeqVec part protected VoidParameterServer() { nodeRole = NodeRole.NONE; } protected VoidParameterServer(@NonNull NodeRole nodeRole) { this.nodeRole = nodeRole; } protected VoidParameterServer(boolean manualMode) { this(); this.manualMode.set(manualMode); } public static VoidParameterServer getInstance() { return INSTANCE; } public void setTrainingDriver(@NonNull TrainingDriver<? extends TrainingMessage> trainer) { this.trainer = trainer; } /** * This method returns shardIndex value. * If current node is Shard - it reutrns it's value * If current node is client - it returns Shard index of paired Shard * @return */ public short getShardIndex() { return shardIndex; } protected void setIpPortForShard(String ip, int port) { transport.setIpAndPort(ip, port); } protected void setShardIndex(short idx) { shardIndex = idx; } protected Transport getTransport() { return transport; } protected INDArray getSyn0() { return storage.getArray(WordVectorStorage.SYN_0); } protected INDArray getSyn1() { return storage.getArray(WordVectorStorage.SYN_1); } protected INDArray getSyn1Neg() { return storage.getArray(WordVectorStorage.SYN_1_NEGATIVE); } protected INDArray getExpTable() { return storage.getArray(WordVectorStorage.EXP_TABLE); } protected INDArray getNegTable() { return storage.getArray(WordVectorStorage.NEGATIVE_TABLE); } protected void init(@NonNull VoidConfiguration voidConfiguration) { init(voidConfiguration, new RoutedTransport(), new SkipGramTrainer()); } /** * This method starts ParameterServer instance * * PLEASE NOTE: This method is blocking for first caller only */ public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, TrainingDriver<? extends TrainingMessage> trainer) { /** * Basic plan here: * start publishers/listeners/subscribers * determine role of the current instance: * Shard * Backup * Client * shutdown unwanted aeron helpers (according to role) * wait for incoming task queries (according to role * */ if (initFinished.get()) return; synchronized (this) { if (initLocker.compareAndSet(false, true)) { this.trainer = trainer; this.voidConfiguration = voidConfiguration; this.transport = transport; // first we need to check, if our current IP matches designated shards or backup if (nodeRole == NodeRole.NONE && (voidConfiguration.getForcedRole() == null || voidConfiguration.getForcedRole() == NodeRole.NONE)) { Pair<NodeRole, String> pair = null; if (voidConfiguration.getShardAddresses().size() == 1 && voidConfiguration.getShardAddresses().get(0).contains("127.0.0.1")) { pair = Pair.create(NodeRole.SHARD, voidConfiguration.getShardAddresses().get(0)); } else { pair = getRole(voidConfiguration, getLocalAddresses()); } nodeRole = pair.getFirst(); String ipAndPort = pair.getSecond(); String ip = "127.0.0.1"; int port = 0; // if we're Shard and have port enforced if (ipAndPort.contains(":")) { String[] split = ipAndPort.split(":"); ip = split[0]; port = Integer.valueOf(split[1]); } else { ip = ipAndPort; port = voidConfiguration.getUnicastPort(); } // if we're Shard here, we should define shardIndex if (nodeRole == NodeRole.SHARD && voidConfiguration.getShardAddresses().size() > 1) { short cnt = 0; for (String shard : voidConfiguration.getShardAddresses()) { String lIp = null; if (shard.contains(":")) { String[] split = ipAndPort.split(":"); lIp = split[0]; } else lIp = shard; if (lIp.equals(ip)) { shardIndex = cnt; } cnt++; } } this.transport.init(voidConfiguration, clipboard, nodeRole, ip, port, shardIndex); } else { if (nodeRole == NodeRole.NONE) nodeRole = voidConfiguration.getForcedRole(); this.transport.init(voidConfiguration, clipboard, nodeRole, "127.0.0.1", voidConfiguration.getUnicastPort(), shardIndex); } // TODO: we need real ip only if this is a shard *FOR NOW*, but later we'll need it for client as well // we launch message processing if we're not in debug mode if (!manualMode.get()) { processingThreads = new Thread[numThreads]; processingRunnables = new Runnable[numThreads]; for (int x = 0; x < numThreads; x++) { processingThreads[x] = new Thread(() -> { runner.set(true); while (runner.get()) { try { //VoidMessage message = transport.takeMessage(); // if (nodeRole == NodeRole.SHARD) // log.info("Processing message: {}", message.getClass().getSimpleName()); handleMessage(transport.takeMessage()); } catch (ND4JIllegalStateException e) { throw new RuntimeException(e); } catch (Exception e) { throw new RuntimeException(e); } } }); //executor.submit(processingRunnables[x); processingThreads[x].setDaemon(true); processingThreads[x].setName("VoidParameterServer messages handling thread"); processingThreads[x].start(); } } // TODO: uncomment this line on later stages //if (!(NodeRole.SHARD == nodeRole && voidConfiguration.getShardAddresses().size() == 1)) { log.info("Launching transport..."); transport.launch(Transport.ThreadingModel.DEDICATED_THREADS); //} trainer.init(this.voidConfiguration, this.transport, storage, clipboard); initFinished.set(true); } } } /** * This method is available for debug purposes only * * @param mode */ protected VoidParameterServer toggleManualMode(boolean mode) { manualMode.set(mode); return this; } /** * This method checks for designated role, according to local IP addresses and configuration passed into method * * @param voidConfiguration * @param localIPs * @return */ protected Pair<NodeRole, String> getRole(@NonNull VoidConfiguration voidConfiguration, @NonNull Collection<String> localIPs) { NodeRole result = NodeRole.CLIENT; for (String ip : voidConfiguration.getShardAddresses()) { String cleansed = ip.replaceAll(":.*", ""); if (localIPs.contains(cleansed)) return Pair.create(NodeRole.SHARD, ip); } if (voidConfiguration.getBackupAddresses() != null) for (String ip : voidConfiguration.getBackupAddresses()) { String cleansed = ip.replaceAll(":.*", ""); if (localIPs.contains(cleansed)) return Pair.create(NodeRole.BACKUP, ip); } String sparkIp = System.getenv("SPARK_PUBLIC_DNS"); log.info("Got [{}] as sparkIp", sparkIp); // local IP from pair is used for shard only, so we don't care return Pair.create(result, sparkIp + ":" + voidConfiguration.getUnicastPort()); } /** * This method initiates shutdown sequence for this instance. * * PLEASE NOTE: This method is blocking for first caller only */ public void shutdown() { /** * Probably we don't need this method in practice */ if (initLocker.get() && shutdownLocker.compareAndSet(false, true)) { // do shutdown log.info("Shutting down transport..."); // we just sending out ShutdownRequestMessage //transport.sendMessage(new ShutdownRequestMessage()); transport.shutdown(); executor.shutdown(); } } /** * This method returns set of local IP addresses available in system. * * PLEASE NOTE: loopback, disabled interfaces, IPv6 addresses are ignored here. * * @return */ public static Set<String> getLocalAddresses() { try { List<NetworkInterface> interfaces = Collections.list(NetworkInterface.getNetworkInterfaces()); Set<String> result = new HashSet<>(); for (NetworkInterface networkInterface : interfaces) { if (networkInterface.isLoopback() || !networkInterface.isUp()) continue; for (InterfaceAddress address : networkInterface.getInterfaceAddresses()) { String addr = address.getAddress().getHostAddress(); if (addr == null || addr.isEmpty() || addr.contains(":")) continue; result.add(addr); } } return result; } catch (Exception e) { throw new RuntimeException(e); } } // TODO: remove @NonNull check here protected void handleMessage(@NonNull VoidMessage message) { if (message == null) { // log.info("sI_{} got null message", getShardIndex()); return; } if (message.getTargetId() >= 0 && message.getTargetId() != shardIndex) { log.warn("sI_{}: Skipping message: [{}]; TargetIdx: [{}]", shardIndex, message.getClass().getSimpleName(), message.getTargetId()); return; } // log.info("sI_{}: Processing message: [{}]", shardIndex, message.getClass().getSimpleName()); message.attachContext(voidConfiguration, trainer, clipboard, transport, storage, nodeRole, shardIndex); message.processMessage(); } /** * This method handles Shards initialization * * PLEASE NOTE: This method is blocking */ // TODO: right now we support only columnar splits over tables public void initializeSeqVec(int vectorLength, int numWords, long seed, int columnsPerShard, boolean useHs, boolean useNegSampling) { InitializationRequestMessage dim = new InitializationRequestMessage(vectorLength, numWords, seed, useHs, useNegSampling, columnsPerShard); transport.sendMessage(dim); } /** * This method dispatches TrainingMessage to ParameterServer network * * PLEASE NOTE: This method is synchronized and *periodically* becomes blocking by design * @param message */ public synchronized void execDistributed(@NonNull TrainingMessage message) { /** * Basically we should batch messages coming from different TrainingFunctions on spark executor side here. * So we pack them into batches, and send over the wire to selected Shard */ Frame currentFrame; if ((currentFrame = frames.get(message.getClass().getSimpleName())) == null) { currentFrame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue()); frames.put(message.getClass().getSimpleName(), currentFrame); } currentFrame.stackMessage(message); // TODO: make this threshold variable if (currentFrame.size() >= 128) { transport.sendMessage(currentFrame); currentFrame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue()); frames.put(message.getClass().getSimpleName(), currentFrame); } //transport.sendMessage(message); } public void execDistributed(@NonNull Frame<? extends TrainingMessage> messages) { transport.sendMessage(messages); } public INDArray getVector(int rowIdx) { return getVector(WordVectorStorage.SYN_0, rowIdx); } /** * This method returns INDArray matching requested storageId value * * PLEASE NOTE: This method IS blocking * * @param rowIdx * @return */ public INDArray getVector(@NonNull Integer key, int rowIdx) { /** * we create VoidMessage, send it, and block until it gets responded */ VectorRequestMessage message = new VectorRequestMessage(key, rowIdx); MeaningfulMessage response = transport.sendMessageAndGetResponse(message); return response.getPayload(); } }