/* * Copyright 2010 Ted Dunning. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are * permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this list of * conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, this list * of conditions and the following disclaimer in the documentation and/or other materials * provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY <COPYRIGHT HOLDER> ``AS IS'' AND ANY EXPRESS OR IMPLIED * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> OR * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * The views and conclusions contained in the software and documentation are those of the * authors and should not be interpreted as representing official policies, either expressed * or implied, of <copyright holder>. 
*/ package mia.classifier.ch16.server; import com.google.common.base.Charsets; import mia.classifier.ch16.generated.Classifier; import org.apache.mahout.classifier.AbstractVectorClassifier; import org.apache.mahout.classifier.sgd.ModelSerializer; import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.server.TServer; import org.apache.thrift.server.TThreadPoolServer; import org.apache.thrift.transport.TServerSocket; import org.apache.thrift.transport.TTransportException; import org.apache.zookeeper.CreateMode; import org.apache.zookeeper.KeeperException; import org.apache.zookeeper.WatchedEvent; import org.apache.zookeeper.Watcher; import org.apache.zookeeper.ZooDefs; import org.apache.zookeeper.ZooKeeper; import org.apache.zookeeper.data.Stat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.net.InetAddress; import java.net.URL; import java.net.UnknownHostException; import java.util.Timer; import java.util.TimerTask; /** * Basic classification server. This server watches a Zookeeper cluster to * determine what models to load and what models to serve. * <p/> * The structure of data in ZK is as follows * <p/> * * <pre> * /model-service/ * current-servers/ Contains one file per live server. * model-to-serve Contains URL of live model. Reread on changes. 
* </pre>
 * <p/>
 * Note that the constructor starts the Thrift server and calls
 * {@code serve()} itself, so it does not return until serving ends.
 */
public class Server {
  public static final String ZK_BASE = "/model-service";
  public static final String ZK_CURRENT_SERVERS = ZK_BASE + "/current-servers";
  public static final String ZK_MODEL = ZK_BASE + "/model-to-serve";

  /** ZooKeeper connect string, with the port spelled out rather than left to the client default. */
  private static final String ZK_CONNECT = "localhost:2181";

  /**
   * Session timeout in milliseconds requested from ZooKeeper. The second
   * argument of the {@code ZooKeeper} constructor is a timeout, not a port.
   * NOTE(review): the value is kept from the original call, where it looks
   * like it was meant as the port number -- confirm the intended timeout.
   */
  private static final int ZK_SESSION_TIMEOUT_MS = 2181;

  /** Delay in milliseconds between two periodic model/status refreshes. */
  private static final long REFRESH_PERIOD_MS = 3000;

  private final TServer server;
  private final Logger log = LoggerFactory.getLogger(this.getClass());
  private Timer timer;
  private ZooKeeper zk;
  private final ServerWatcher modelWatcher = new ServerWatcher();

  /**
   * Makes sure the base znodes exist, starts the periodic model refresh and
   * then serves classification requests on the given port. This call blocks
   * inside {@code serve()}; when serving ends (normally or with an
   * exception) the refresh timer is cancelled and the watcher is closed.
   *
   * @param port TCP port the Thrift server listens on.
   * @throws TTransportException  if the server socket cannot be opened.
   * @throws IOException          if the ZooKeeper client cannot be created.
   * @throws InterruptedException if a ZooKeeper call is interrupted.
   * @throws KeeperException      if the base znodes cannot be checked or created.
   */
  public Server(int port) throws TTransportException, IOException,
      InterruptedException, KeeperException {
    zk = new ZooKeeper(ZK_CONNECT, ZK_SESSION_TIMEOUT_MS, new Watcher() {
      @Override
      public void process(WatchedEvent watchedEvent) {
        // connection state changes are of no interest during bootstrap
      }
    });
    try {
      ensurePersistentNode(zk, ZK_BASE);
      ensurePersistentNode(zk, ZK_CURRENT_SERVERS);
    } finally {
      // this handle is only needed for bootstrap; release it even on failure
      zk.close();
    }

    Ops modelHandler = new Ops();
    modelWatcher.setModelHandler(modelHandler);

    // Re-run the watcher every three seconds in case the ZK watch could not
    // be reset and to restore our ephemeral status file if it disappeared.
    timer = new Timer();
    timer.scheduleAtFixedRate(new TimerTask() {
      @Override
      public void run() {
        try {
          modelWatcher.process(null);
        } catch (RuntimeException e) {
          // An exception escaping a TimerTask terminates the Timer thread,
          // which would silently stop every later refresh.
          log.error("Unexpected failure while refreshing model", e);
        }
      }
    }, 0, REFRESH_PERIOD_MS);

    try {
      TServerSocket socket = new TServerSocket(port);
      Classifier.Processor processor = new Classifier.Processor(modelHandler);
      // No protocol factory is configured, so the Thrift default is used.
      server = new TThreadPoolServer(
          new TThreadPoolServer.Args(socket).processor(processor));
      log.warn("Starting server on port {}", port);
      server.serve();
    } finally {
      timer.cancel();
      modelWatcher.close();
    }
  }

  /**
   * Creates a persistent, empty znode at {@code path} unless it is already
   * there. Losing a creation race against another server is not an error.
   */
  private void ensurePersistentNode(ZooKeeper keeper, String path)
      throws KeeperException, InterruptedException {
    if (keeper.exists(path, false) != null) {
      return;
    }
    log.warn("Creating " + path);
    try {
      keeper.create(path, new byte[0], ZooDefs.Ids.OPEN_ACL_UNSAFE,
          CreateMode.PERSISTENT);
    } catch (KeeperException.NodeExistsException e) {
      // another server created it between our exists() and create()
      log.info("{} was created concurrently", path);
    }
  }

  /**
   * Stops the Thrift server, cancels the refresh timer and releases the
   * ZooKeeper handles held by this server and its watcher.
   *
   * @throws InterruptedException if closing a ZooKeeper handle is interrupted.
   */
  public void close() throws InterruptedException {
    log.warn("Exiting");
    server.stop();
    timer.cancel();
    modelWatcher.close();
    zk.close();
  }

  /** Starts a server on port 7908. */
  public static void main(String[] args) throws IOException,
      TTransportException, InterruptedException, KeeperException {
    new Server(7908);
  }

  /**
   * Keeps the served model in step with the URL stored in {@link #ZK_MODEL}
   * and advertises this host under {@link #ZK_CURRENT_SERVERS}.
   * <p/>
   * {@link #process(WatchedEvent)} is invoked both by the refresh timer and
   * by ZooKeeper (it registers itself as the watch on the model znode), so
   * it and {@link #close()} are synchronized on this object.
   */
  private static class ServerWatcher implements Watcher {
    private final Logger log = LoggerFactory.getLogger(this.getClass());
    private Ops modelHandler;
    private String currentUrl = null;
    private int version;
    private ZooKeeper zk = null;
    private String hostname;
    // set by close(); stops a late timer/watch callback from reconnecting
    private boolean closed = false;

    /** Resolves the local host name; exits the JVM if it cannot be determined. */
    private ServerWatcher() {
      hostname = null;
      try {
        hostname = InetAddress.getLocalHost().getHostName();
      } catch (UnknownHostException e) {
        // continue with null hostname
      }

      if (hostname == null) {
        log.error("Must have hostname ... exiting");
        System.exit(1);
      }
    }

    /**
     * Loads or reloads the model by looking at ZK to get the model URL,
     * then loads that URL to get the serialized model. Afterwards the
     * ephemeral status file for this host is created, or rewritten if a
     * new model was just loaded.
     *
     * @param watchedEvent Ignored.
     */
    @Override
    public synchronized void process(WatchedEvent watchedEvent) {
      if (closed) {
        return;
      }

      if (zk == null) {
        try {
          // A no-op default watcher is passed instead of null because the
          // client invokes the default watcher for connection state events.
          zk = new ZooKeeper(ZK_CONNECT, ZK_SESSION_TIMEOUT_MS, new Watcher() {
            @Override
            public void process(WatchedEvent event) {
              // ignore
            }
          });
        } catch (IOException e) {
          log.error("Could not connect to ZooKeeper", e);
          zk = null;
          return;
        }
      }

      String url = null;
      try {
        // get new URL
        Stat stat = new Stat();
        byte[] urlAsBytes = zk.getData(ZK_MODEL, this, stat);
        if (urlAsBytes == null) {
          // a znode may exist without any data attached to it
          throw new IOException("No model URL stored in " + ZK_MODEL);
        }
        int latestVersion = stat.getVersion();
        url = new String(urlAsBytes, Charsets.UTF_8);

        // check for change
        URL modelUrl = new URL(url);
        boolean needUpdate = false;
        if (currentUrl == null || latestVersion != version) {
          log.warn("Loading model from " + modelUrl);
          java.io.InputStream in = modelUrl.openStream();
          AbstractVectorClassifier model;
          try {
            model = ModelSerializer.readBinary(in, OnlineLogisticRegression.class);
          } finally {
            // make sure the connection is released even if parsing fails
            in.close();
          }
          modelHandler.setModel(model);

          currentUrl = url;
          version = latestVersion;
          log.info("done loading version " + version);
          needUpdate = true;
        }

        // update status file so clients find us
        String statusFile = ZK_CURRENT_SERVERS + "/" + hostname;
        byte[] status = modelUrl.toString().getBytes(Charsets.UTF_8);

        // Tell ZK what model we loaded. We try to do this often because we
        // might have previously updated a lingering ephemeral file
        // belonging to a previous incarnation. After a short time, that
        // ephemeral may disappear and we would need to restore it.
        try {
          zk.create(statusFile, status, ZooDefs.Ids.OPEN_ACL_UNSAFE,
              CreateMode.EPHEMERAL);
          log.info("created server file {}", statusFile);
        } catch (KeeperException.NodeExistsException e) {
          if (needUpdate) {
            zk.setData(statusFile, status, -1);
            log.info("updated server file {}", statusFile);
          }
        } catch (KeeperException e) {
          log.error("Couldn't write server status file", e);
        }
        return;
      } catch (KeeperException.NoNodeException e) {
        // if no such data on ZK, log it and continue.
        log.error("Could not find model URL in ZK file: " + ZK_MODEL, e);
        return;
      } catch (KeeperException.SessionExpiredException e) {
        log.error("Session expired", e);
        // an expired handle is unusable; drop it so the next call reconnects
        discardConnection();
      } catch (KeeperException e) {
        log.error("Failed to load model due to ZK exception", e);
      } catch (InterruptedException e) {
        log.error("Operation interrupted should never happen", e);
        // keep the interrupt visible to the code that owns this thread
        Thread.currentThread().interrupt();
      } catch (IOException e) {
        log.error("Failed to load model from " + url, e);
      }

      // only get here on error
      log.warn("Clearing current URL due to error");
      currentUrl = null;
      version = -1;
    }

    /** Closes the current ZooKeeper handle, if any, and forgets it. */
    private void discardConnection() {
      if (zk == null) {
        return;
      }
      try {
        zk.close();
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      } finally {
        zk = null;
      }
    }

    /** Sets the handler that receives each newly loaded model. */
    public void setModelHandler(Ops modelHandler) {
      this.modelHandler = modelHandler;
    }

    /**
     * Releases the ZooKeeper handle, if one was ever opened, and prevents
     * later callbacks from opening a new one.
     *
     * @throws InterruptedException if closing the handle is interrupted.
     */
    public synchronized void close() throws InterruptedException {
      closed = true;
      if (zk != null) {
        try {
          zk.close();
        } finally {
          zk = null;
        }
      }
    }
  }
}