package org.nd4j.parameterserver.client; import com.mashape.unirest.http.Unirest; import io.aeron.Aeron; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.nd4j.aeron.ipc.*; import org.nd4j.aeron.ipc.response.HostPortPublisher; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.model.MasterStatus; import org.nd4j.parameterserver.model.ServerTypeJson; import org.nd4j.parameterserver.model.SubscriberState; import org.nd4j.shade.jackson.databind.ObjectMapper; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; /** * Parameter server * client for * publishing and * retrieving ndarrays * * pushNDArray will send the given ndarray to the send url. * This is used for updating the master's current state. * * getArray() is used for retrieving the master ndarray's current * state from the parameter server. * * @author Adam Gibson */ @Data @AllArgsConstructor @Builder @Slf4j public class ParameterServerClient implements NDArrayCallback { //the url to send ndarrays to private String ndarraySendUrl; //the url to retrieve ndarrays from private String ndarrayRetrieveUrl; private AeronNDArraySubscriber subscriber; //host to listen on for the subscriber private String subscriberHost; //port to listen on for the subscriber private int subscriberPort; //the stream to listen on for the subscriber private int subscriberStream = 11; //the "current" ndarray private AtomicReference<INDArray> arr; private INDArray none = Nd4j.scalar(1.0); private AtomicBoolean running; private String masterStatusHost; private int masterStatusPort; private ObjectMapper objectMapper = new ObjectMapper(); private Aeron aeron; private boolean compressArray = true; /** * Tracks number of * arrays send to responder. * @return */ public int arraysSentToResponder() { if (objectMapper == null) objectMapper = new ObjectMapper(); try { String type = objectMapper.readValue( Unirest.get(String.format("http://%s:%d/type", masterStatusHost, masterStatusPort)).asJson() .getBody().toString(), ServerTypeJson.class).getType(); if (!type.equals("master")) throw new IllegalStateException("Wrong type " + type); Unirest.get(String.format("http://%s:%d/started", masterStatusHost, masterStatusPort)).asJson().getBody(); return objectMapper.readValue( Unirest.get(String.format("http://%s:%d/started", masterStatusHost, masterStatusPort)) .asJson().getBody().toString(), MasterStatus.class).getResponderN(); } catch (Exception e) { e.printStackTrace(); } return 0; } /** * Block the clint till ready * for next phase. * */ public void blockTillReady() { while (!isReadyForNext()) try { Thread.sleep(1000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } /** * Returns true if the client is * ready for a next array or not * @return true if the client is * ready for the next array or not,false otherwise */ public boolean isReadyForNext() { if (objectMapper == null) objectMapper = new ObjectMapper(); try { int masterStream = Integer.parseInt(ndarraySendUrl.split(":")[2]); SubscriberState subscriberState = objectMapper.readValue(Unirest .get(String.format("http://%s:%d/state/%d", masterStatusHost, masterStatusPort, masterStream)) .asJson().getBody().toString(), SubscriberState.class); return subscriberState.isReady(); } catch (Exception e) { e.printStackTrace(); } return false; } /** * Sends a post request to the * status server to determine if the master node is started. * @return */ public boolean masterStarted() { if (objectMapper == null) objectMapper = new ObjectMapper(); try { String type = objectMapper.readValue( Unirest.get(String.format("http://%s:%d/type", masterStatusHost, masterStatusPort)).asJson() .getBody().toString(), ServerTypeJson.class).getType(); if (!type.equals("master")) throw new IllegalStateException("Wrong type " + type); Unirest.get(String.format("http://%s:%d/started", masterStatusHost, masterStatusPort)).asJson().getBody(); return objectMapper.readValue( Unirest.get(String.format("http://%s:%d/started", masterStatusHost, masterStatusPort)) .asJson().getBody().toString(), MasterStatus.class).started(); } catch (Exception e) { e.printStackTrace(); } return false; } /** * Push an ndarray message to the specified * ndarray send url in the form of: * host;port:stream * where stream is the stream for connecting * to a listening aeron server * @param message the array to send */ public void pushNDArrayMessage(NDArrayMessage message) { //start a subscriber that can send us ndarrays if (subscriber == null) { running = new AtomicBoolean(true); subscriber = AeronNDArraySubscriber.startSubscriber(aeron, subscriberHost, subscriberPort, this, subscriberStream, running); log.debug("Started parameter server client on " + subscriber.connectionUrl()); } String[] split = ndarraySendUrl.split(":"); int port = Integer.parseInt(split[1]); int streamToPublish = Integer.parseInt(split[2]); String channel = AeronUtil.aeronChannel(split[0], port); log.debug("Parameter server client publishing to " + ndarraySendUrl); try (AeronNDArrayPublisher publisher = AeronNDArrayPublisher.builder().streamId(streamToPublish) .compress(isCompressArray()).aeron(aeron).channel(channel).build()) { publisher.publish(message); } catch (Exception e) { throw new RuntimeException(e); } } /** * Push an ndarray to the specified * ndarray send url in the form of: * host;port:stream * where stream is the stream for connecting * to a listening aeron server * @param arr the array to send */ public void pushNDArray(INDArray arr) { pushNDArrayMessage(NDArrayMessage.wholeArrayUpdate(arr)); } /** * Get the connection url for the subscriber * in the format: * host:port:stream * @return the connection url for the subscriber * for this client */ public String connectionUrl() { return AeronConnectionInformation.of(subscriberHost, subscriberPort, subscriberStream).toString(); } /** * Get an ndarray from the * designated ndarray retrieve url. * This will "pull" the current ndarray * from the master * @return the current ndarray from the master. */ public INDArray getArray() { //start a subscriber that can send us ndarrays if (subscriber == null) { running = new AtomicBoolean(true); subscriber = AeronNDArraySubscriber.startSubscriber(aeron, subscriberHost, subscriberPort, this, subscriberStream, running); log.debug("Started parameter server client on " + subscriber.connectionUrl()); } if (arr == null) arr = new AtomicReference<>(none); log.debug("Parameter server client retrieving url from " + ndarrayRetrieveUrl); //note here that this is the "master url" String[] split = ndarrayRetrieveUrl.split(":"); //The response daemon is always the master daemon's port + 1 //A "master daemon" is one that holds both the //parameter averaging daemon AND the response daemon for being able to send //the "current state ndarray" int port = Integer.parseInt(split[1]); int streamToPublish = Integer.parseInt(split[2]); //the channel here is the master node host with the port + 1 //pointing at the response node where we can request ndarrays to be sent to //the listening daemon String channel = AeronUtil.aeronChannel(split[0], port); //publish the address of our subscriber //note here that we send the ndarray send url, because the //master also hosts try (HostPortPublisher hostPortPublisher = HostPortPublisher.builder().channel(channel).aeron(aeron) //note here that we send our subscriber's listening information .streamId(streamToPublish) .uriToSend(AeronConnectionInformation .of(subscriberHost, subscriberPort, subscriberStream) .toString()) .build()) { hostPortPublisher.send(); log.debug("Sent subscriber information " + AeronConnectionInformation .of(subscriberHost, subscriberPort, subscriberStream).toString()); //wait for array to be available while (arr.get() == none) { Thread.sleep(1000); log.info("Waiting on array to be updated."); } } catch (Exception e) { log.error("Error with publishing", e); } INDArray arr2 = arr.get(); arr.set(none); return arr2; } /** * A listener for ndarray message * * @param message the message for the callback */ @Override public void onNDArrayMessage(NDArrayMessage message) { INDArray arr = message.getArr(); //of note for ndarrays int[] dimensions = message.getDimensions(); boolean whole = dimensions.length == 1 && dimensions[0] == -1; if (!whole) onNDArrayPartial(arr, message.getIndex(), dimensions); else onNDArray(arr); } /** * Used for partial updates using tensor along * dimension * @param arr the array to count as an update * @param idx the index for the tensor along dimension * @param dimensions the dimensions to act on for the tensor along dimension */ @Override public void onNDArrayPartial(INDArray arr, long idx, int... dimensions) { INDArray get = this.arr.get(); get.tensorAlongDimension((int) idx, dimensions).assign(arr); } /** * Setup an ndarray * * @param arr */ @Override public void onNDArray(INDArray arr) { log.info("Received array"); this.arr.set(arr); } }