/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.lens.ml.algo.spark; import java.io.File; import java.io.FilenameFilter; import java.util.ArrayList; import java.util.List; import org.apache.lens.api.LensConf; import org.apache.lens.ml.algo.api.MLAlgo; import org.apache.lens.ml.algo.api.MLDriver; import org.apache.lens.ml.algo.lib.Algorithms; import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo; import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo; import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo; import org.apache.lens.ml.algo.spark.svm.SVMAlgo; import org.apache.lens.server.api.error.LensException; import org.apache.commons.lang.StringUtils; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import lombok.extern.slf4j.Slf4j; /** * The Class SparkMLDriver. */ @Slf4j public class SparkMLDriver implements MLDriver { /** The owns spark context. */ private boolean ownsSparkContext = true; /** * The Enum SparkMasterMode. */ private enum SparkMasterMode { // Embedded mode used in tests /** The embedded. */ EMBEDDED, // Yarn client and Yarn cluster modes are used when deploying the app to Yarn cluster /** The yarn client. */ YARN_CLIENT, /** The yarn cluster. */ YARN_CLUSTER } /** The algorithms. */ private final Algorithms algorithms = new Algorithms(); /** The client mode. */ private SparkMasterMode clientMode = SparkMasterMode.EMBEDDED; /** The is started. */ private boolean isStarted; /** The spark conf. */ private SparkConf sparkConf; /** The spark context. */ private JavaSparkContext sparkContext; /** * Use spark context. * * @param jsc the jsc */ public void useSparkContext(JavaSparkContext jsc) { ownsSparkContext = false; this.sparkContext = jsc; } /* * (non-Javadoc) * * @see org.apache.lens.ml.MLDriver#isAlgoSupported(java.lang.String) */ @Override public boolean isAlgoSupported(String name) { return algorithms.isAlgoSupported(name); } /* * (non-Javadoc) * * @see org.apache.lens.ml.MLDriver#getAlgoInstance(java.lang.String) */ @Override public MLAlgo getAlgoInstance(String name) throws LensException { checkStarted(); if (!isAlgoSupported(name)) { return null; } MLAlgo algo = null; try { algo = algorithms.getAlgoForName(name); if (algo instanceof BaseSparkAlgo) { ((BaseSparkAlgo) algo).setSparkContext(sparkContext); } } catch (LensException exc) { log.error("Error creating algo object", exc); } return algo; } /** * Register algos. */ private void registerAlgos() { algorithms.register(NaiveBayesAlgo.class); algorithms.register(SVMAlgo.class); algorithms.register(LogisticRegressionAlgo.class); algorithms.register(DecisionTreeAlgo.class); } /* * (non-Javadoc) * * @see org.apache.lens.ml.MLDriver#init(org.apache.lens.api.LensConf) */ @Override public void init(LensConf conf) throws LensException { sparkConf = new SparkConf(); registerAlgos(); for (String key : conf.getProperties().keySet()) { if (key.startsWith("lens.ml.sparkdriver.")) { sparkConf.set(key.substring("lens.ml.sparkdriver.".length()), conf.getProperties().get(key)); } } String sparkAppMaster = sparkConf.get("spark.master"); if ("yarn-client".equalsIgnoreCase(sparkAppMaster)) { clientMode = SparkMasterMode.YARN_CLIENT; } else if ("yarn-cluster".equalsIgnoreCase(sparkAppMaster)) { clientMode = SparkMasterMode.YARN_CLUSTER; } else if ("local".equalsIgnoreCase(sparkAppMaster) || StringUtils.isBlank(sparkAppMaster)) { clientMode = SparkMasterMode.EMBEDDED; } else { throw new IllegalArgumentException("Invalid master mode " + sparkAppMaster); } if (clientMode == SparkMasterMode.YARN_CLIENT || clientMode == SparkMasterMode.YARN_CLUSTER) { String sparkHome = System.getenv("SPARK_HOME"); if (StringUtils.isNotBlank(sparkHome)) { sparkConf.setSparkHome(sparkHome); } // If SPARK_HOME is not set, SparkConf can read from the Lens-site.xml or System properties. if (StringUtils.isBlank(sparkConf.get("spark.home"))) { throw new IllegalArgumentException("Spark home is not set"); } // set spark.yarn.jar String yarnJars = System.getenv("SPARK_YARN_JAR"); if (StringUtils.isNotBlank(yarnJars)) { sparkConf.set("spark.yarn.jar", yarnJars); } log.info("Spark home is set to {}", sparkConf.get("spark.home")); log.info("spark.yarn.jar is set to {}", yarnJars); } sparkConf.setAppName("lens-ml"); } /* * (non-Javadoc) * * @see org.apache.lens.ml.MLDriver#start() */ @Override public void start() throws LensException { if (sparkContext == null) { sparkContext = new JavaSparkContext(sparkConf); } // Adding jars to spark context is only required when running in yarn-client mode if (clientMode != SparkMasterMode.EMBEDDED) { // TODO Figure out only necessary set of JARs to be added for HCatalog // Add hcatalog and hive jars String hiveLocation = System.getenv("HIVE_HOME"); if (StringUtils.isBlank(hiveLocation)) { throw new LensException("HIVE_HOME is not set"); } log.info("HIVE_HOME at {}", hiveLocation); File hiveLibDir = new File(hiveLocation, "lib"); FilenameFilter jarFileFilter = new FilenameFilter() { @Override public boolean accept(File file, String s) { return s.endsWith(".jar"); } }; List<String> jarFiles = new ArrayList<String>(); // Add hive jars for (File jarFile : hiveLibDir.listFiles(jarFileFilter)) { jarFiles.add(jarFile.getAbsolutePath()); log.info("Adding HIVE jar {}", jarFile.getAbsolutePath()); sparkContext.addJar(jarFile.getAbsolutePath()); } // Add hcatalog jars File hcatalogDir = new File(hiveLocation + "/hcatalog/share/hcatalog"); for (File jarFile : hcatalogDir.listFiles(jarFileFilter)) { jarFiles.add(jarFile.getAbsolutePath()); log.info("Adding HCATALOG jar {}", jarFile.getAbsolutePath()); sparkContext.addJar(jarFile.getAbsolutePath()); } // Add the current jar String[] lensSparkLibJars = JavaSparkContext.jarOfClass(SparkMLDriver.class); for (String lensSparkJar : lensSparkLibJars) { log.info("Adding Lens JAR {}", lensSparkJar); sparkContext.addJar(lensSparkJar); } } isStarted = true; log.info("Created Spark context for app: '{}', Spark master: {}", sparkContext.appName(), sparkContext.master()); } /* * (non-Javadoc) * * @see org.apache.lens.ml.MLDriver#stop() */ @Override public void stop() throws LensException { if (!isStarted) { log.warn("Spark driver was not started"); return; } isStarted = false; if (ownsSparkContext) { sparkContext.stop(); } log.info("Stopped spark context {}", this); } @Override public List<String> getAlgoNames() { return algorithms.getAlgorithmNames(); } /** * Check started. * * @throws LensException the lens exception */ public void checkStarted() throws LensException { if (!isStarted) { throw new LensException("Spark driver is not started yet"); } } public JavaSparkContext getSparkContext() { return sparkContext; } }