/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.zeppelin.spark; import static scala.collection.JavaConversions.asJavaCollection; import static scala.collection.JavaConversions.asJavaIterable; import static scala.collection.JavaConversions.collectionAsScalaIterable; import java.io.PrintStream; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import org.apache.spark.SparkContext; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SQLContext.QueryExecution; import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.hive.HiveContext; import org.apache.zeppelin.display.AngularObject; import org.apache.zeppelin.display.AngularObjectRegistry; import org.apache.zeppelin.display.AngularObjectWatcher; import org.apache.zeppelin.display.GUI; import org.apache.zeppelin.display.Input.ParamOption; import org.apache.zeppelin.interpreter.InterpreterContext; import org.apache.zeppelin.interpreter.InterpreterContextRunner; import org.apache.zeppelin.interpreter.InterpreterException; import org.apache.zeppelin.spark.dep.DependencyResolver; import scala.Tuple2; import scala.Unit; import scala.collection.Iterable; /** * Spark context for zeppelin. */ public class ZeppelinContext extends HashMap<String, Object> { private DependencyResolver dep; private PrintStream out; private InterpreterContext interpreterContext; private int maxResult; public ZeppelinContext(SparkContext sc, SQLContext sql, InterpreterContext interpreterContext, DependencyResolver dep, PrintStream printStream, int maxResult) { this.sc = sc; this.sqlContext = sql; this.interpreterContext = interpreterContext; this.dep = dep; this.out = printStream; this.maxResult = maxResult; } public SparkContext sc; public SQLContext sqlContext; public HiveContext hiveContext; private GUI gui; /** * Load dependency for interpreter and runtime (driver). * And distribute them to spark cluster (sc.add()) * * @param artifact "group:artifact:version" or file path like "/somepath/your.jar" * @return * @throws Exception */ public Iterable<String> load(String artifact) throws Exception { return collectionAsScalaIterable(dep.load(artifact, true)); } /** * Load dependency and it's transitive dependencies for interpreter and runtime (driver). * And distribute them to spark cluster (sc.add()) * * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string. * @return * @throws Exception */ public Iterable<String> load(String artifact, scala.collection.Iterable<String> excludes) throws Exception { return collectionAsScalaIterable( dep.load(artifact, asJavaCollection(excludes), true)); } /** * Load dependency and it's transitive dependencies for interpreter and runtime (driver). * And distribute them to spark cluster (sc.add()) * * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string. * @return * @throws Exception */ public Iterable<String> load(String artifact, Collection<String> excludes) throws Exception { return collectionAsScalaIterable(dep.load(artifact, excludes, true)); } /** * Load dependency for interpreter and runtime, and then add to sparkContext. * But not adding them to spark cluster * * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" * @return * @throws Exception */ public Iterable<String> loadLocal(String artifact) throws Exception { return collectionAsScalaIterable(dep.load(artifact, false)); } /** * Load dependency and it's transitive dependencies and then add to sparkContext. * But not adding them to spark cluster * * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string. * @return * @throws Exception */ public Iterable<String> loadLocal(String artifact, scala.collection.Iterable<String> excludes) throws Exception { return collectionAsScalaIterable(dep.load(artifact, asJavaCollection(excludes), false)); } /** * Load dependency and it's transitive dependencies and then add to sparkContext. * But not adding them to spark cluster * * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string. * @return * @throws Exception */ public Iterable<String> loadLocal(String artifact, Collection<String> excludes) throws Exception { return collectionAsScalaIterable(dep.load(artifact, excludes, false)); } /** * Add maven repository * * @param id id of repository ex) oss, local, snapshot * @param url url of repository. supported protocol : file, http, https */ public void addRepo(String id, String url) { addRepo(id, url, false); } /** * Add maven repository * * @param id id of repository * @param url url of repository. supported protocol : file, http, https * @param snapshot true if it is snapshot repository */ public void addRepo(String id, String url, boolean snapshot) { dep.addRepo(id, url, snapshot); } /** * Remove maven repository by id * @param id id of repository */ public void removeRepo(String id){ dep.delRepo(id); } /** * Load dependency only interpreter. * * @param name * @return */ public Object input(String name) { return input(name, ""); } public Object input(String name, Object defaultValue) { return gui.input(name, defaultValue); } public Object select(String name, scala.collection.Iterable<Tuple2<Object, String>> options) { return select(name, "", options); } public Object select(String name, Object defaultValue, scala.collection.Iterable<Tuple2<Object, String>> options) { int n = options.size(); ParamOption[] paramOptions = new ParamOption[n]; Iterator<Tuple2<Object, String>> it = asJavaIterable(options).iterator(); int i = 0; while (it.hasNext()) { Tuple2<Object, String> valueAndDisplayValue = it.next(); paramOptions[i++] = new ParamOption(valueAndDisplayValue._1(), valueAndDisplayValue._2()); } return gui.select(name, "", paramOptions); } public void setGui(GUI o) { this.gui = o; } private void restartInterpreter() { } public InterpreterContext getInterpreterContext() { return interpreterContext; } public void setInterpreterContext(InterpreterContext interpreterContext) { this.interpreterContext = interpreterContext; } public void setMaxResult(int maxResult) { this.maxResult = maxResult; } /** * show DataFrame or SchemaRDD * @param o DataFrame or SchemaRDD object */ public void show(Object o) { show(o, maxResult); } /** * show DataFrame or SchemaRDD * @param o DataFrame or SchemaRDD object * @param maxResult maximum number of rows to display */ public void show(Object o, int maxResult) { Class cls = null; try { cls = this.getClass().forName("org.apache.spark.sql.DataFrame"); } catch (ClassNotFoundException e) { } if (cls == null) { try { cls = this.getClass().forName("org.apache.spark.sql.SchemaRDD"); } catch (ClassNotFoundException e) { } } if (cls == null) { throw new InterpreterException("Can not road DataFrame/SchemaRDD class"); } if (cls.isInstance(o)) { out.print(showDF(sc, interpreterContext, o, maxResult)); } else { out.print(o.toString()); } } public static String showDF(ZeppelinContext z, Object df) { return showDF(z.sc, z.interpreterContext, df, z.maxResult); } public static String showDF(SparkContext sc, InterpreterContext interpreterContext, Object df, int maxResult) { Object[] rows = null; Method take; String jobGroup = "zeppelin-" + interpreterContext.getParagraphId(); sc.setJobGroup(jobGroup, "Zeppelin", false); try { take = df.getClass().getMethod("take", int.class); rows = (Object[]) take.invoke(df, maxResult + 1); } catch (NoSuchMethodException | SecurityException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | ClassCastException e) { sc.clearJobGroup(); throw new InterpreterException(e); } String msg = null; // get field names Method queryExecution; QueryExecution qe; try { queryExecution = df.getClass().getMethod("queryExecution"); qe = (QueryExecution) queryExecution.invoke(df); } catch (NoSuchMethodException | SecurityException | IllegalAccessException | IllegalArgumentException | InvocationTargetException e) { throw new InterpreterException(e); } List<Attribute> columns = scala.collection.JavaConverters.asJavaListConverter( qe.analyzed().output()).asJava(); for (Attribute col : columns) { if (msg == null) { msg = col.name(); } else { msg += "\t" + col.name(); } } msg += "\n"; // ArrayType, BinaryType, BooleanType, ByteType, DecimalType, DoubleType, DynamicType, // FloatType, FractionalType, IntegerType, IntegralType, LongType, MapType, NativeType, // NullType, NumericType, ShortType, StringType, StructType try { for (int r = 0; r < maxResult && r < rows.length; r++) { Object row = rows[r]; Method isNullAt = row.getClass().getMethod("isNullAt", int.class); Method apply = row.getClass().getMethod("apply", int.class); for (int i = 0; i < columns.size(); i++) { if (!(Boolean) isNullAt.invoke(row, i)) { msg += apply.invoke(row, i).toString(); } else { msg += "null"; } if (i != columns.size() - 1) { msg += "\t"; } } msg += "\n"; } } catch (NoSuchMethodException | SecurityException | IllegalAccessException | IllegalArgumentException | InvocationTargetException e) { throw new InterpreterException(e); } if (rows.length > maxResult) { msg += "\n<font color=red>Results are limited by " + maxResult + ".</font>"; } sc.clearJobGroup(); return "%table " + msg; } /** * Run paragraph by id * @param id */ public void run(String id) { run(id, interpreterContext); } /** * Run paragraph by id * @param id * @param context */ public void run(String id, InterpreterContext context) { if (id.equals(context.getParagraphId())) { throw new InterpreterException("Can not run current Paragraph"); } for (InterpreterContextRunner r : context.getRunners()) { if (id.equals(r.getParagraphId())) { r.run(); return; } } throw new InterpreterException("Paragraph " + id + " not found"); } /** * Run paragraph at idx * @param idx */ public void run(int idx) { run(idx, interpreterContext); } /** * Run paragraph at index * @param idx index starting from 0 * @param context interpreter context */ public void run(int idx, InterpreterContext context) { if (idx >= context.getRunners().size()) { throw new InterpreterException("Index out of bound"); } InterpreterContextRunner runner = context.getRunners().get(idx); if (runner.getParagraphId().equals(context.getParagraphId())) { throw new InterpreterException("Can not run current Paragraph"); } runner.run(); } public void run(List<Object> paragraphIdOrIdx) { run(paragraphIdOrIdx, interpreterContext); } /** * Run paragraphs * @param paragraphIdOrIdxs list of paragraph id or idx */ public void run(List<Object> paragraphIdOrIdx, InterpreterContext context) { for (Object idOrIdx : paragraphIdOrIdx) { if (idOrIdx instanceof String) { String id = (String) idOrIdx; run(id, context); } else if (idOrIdx instanceof Integer) { Integer idx = (Integer) idOrIdx; run(idx, context); } else { throw new InterpreterException("Paragraph " + idOrIdx + " not found"); } } } public void runAll() { runAll(interpreterContext); } /** * Run all paragraphs. except this. */ public void runAll(InterpreterContext context) { for (InterpreterContextRunner r : context.getRunners()) { if (r.getParagraphId().equals(context.getParagraphId())) { // skip itself continue; } r.run(); } } public List<String> listParagraphs() { List<String> paragraphs = new LinkedList<String>(); for (InterpreterContextRunner r : interpreterContext.getRunners()) { paragraphs.add(r.getParagraphId()); } return paragraphs; } private AngularObject getAngularObject(String name, InterpreterContext interpreterContext) { AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry(); String noteId = interpreterContext.getNoteId(); // try get local object AngularObject ao = registry.get(name, interpreterContext.getNoteId()); if (ao == null) { // then global object ao = registry.get(name, null); } return ao; } /** * Get angular object. Look up local registry first and then global registry * @param name variable name * @return value */ public Object angular(String name) { AngularObject ao = getAngularObject(name, interpreterContext); if (ao == null) { return null; } else { return ao.get(); } } /** * Get angular object. Look up global registry * @param name variable name * @return value */ public Object angularGlobal(String name) { AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry(); AngularObject ao = registry.get(name, null); if (ao == null) { return null; } else { return ao.get(); } } /** * Create angular variable in local registry and bind with front end Angular display system. * If variable exists, it'll be overwritten. * @param name name of the variable * @param o value */ public void angularBind(String name, Object o) { angularBind(name, o, interpreterContext.getNoteId()); } /** * Create angular variable in global registry and bind with front end Angular display system. * If variable exists, it'll be overwritten. * @param name name of the variable * @param o value */ public void angularBindGlobal(String name, Object o) { angularBind(name, o, (String) null); } /** * Create angular variable in local registry and bind with front end Angular display system. * If variable exists, value will be overwritten and watcher will be added. * @param name name of variable * @param o value * @param watcher watcher of the variable */ public void angularBind(String name, Object o, AngularObjectWatcher watcher) { angularBind(name, o, interpreterContext.getNoteId(), watcher); } /** * Create angular variable in global registry and bind with front end Angular display system. * If variable exists, value will be overwritten and watcher will be added. * @param name name of variable * @param o value * @param watcher watcher of the variable */ public void angularBindGlobal(String name, Object o, AngularObjectWatcher watcher) { angularBind(name, o, null, watcher); } /** * Add watcher into angular variable (local registry) * @param name name of the variable * @param watcher watcher */ public void angularWatch(String name, AngularObjectWatcher watcher) { angularWatch(name, interpreterContext.getNoteId(), watcher); } /** * Add watcher into angular variable (global registry) * @param name name of the variable * @param watcher watcher */ public void angularWatchGlobal(String name, AngularObjectWatcher watcher) { angularWatch(name, null, watcher); } public void angularWatch(String name, final scala.Function2<Object, Object, Unit> func) { angularWatch(name, interpreterContext.getNoteId(), func); } public void angularWatchGlobal(String name, final scala.Function2<Object, Object, Unit> func) { angularWatch(name, null, func); } public void angularWatch( String name, final scala.Function3<Object, Object, InterpreterContext, Unit> func) { angularWatch(name, interpreterContext.getNoteId(), func); } public void angularWatchGlobal( String name, final scala.Function3<Object, Object, InterpreterContext, Unit> func) { angularWatch(name, null, func); } /** * Remove watcher from angular variable (local) * @param name * @param watcher */ public void angularUnwatch(String name, AngularObjectWatcher watcher) { angularUnwatch(name, interpreterContext.getNoteId(), watcher); } /** * Remove watcher from angular variable (global) * @param name * @param watcher */ public void angularUnwatchGlobal(String name, AngularObjectWatcher watcher) { angularUnwatch(name, null, watcher); } /** * Remove all watchers for the angular variable (local) * @param name */ public void angularUnwatch(String name) { angularUnwatch(name, interpreterContext.getNoteId()); } /** * Remove all watchers for the angular variable (global) * @param name */ public void angularUnwatchGlobal(String name) { angularUnwatch(name, (String) null); } /** * Remove angular variable and all the watchers. * @param name */ public void angularUnbind(String name) { String noteId = interpreterContext.getNoteId(); angularUnbind(name, noteId); } /** * Remove angular variable and all the watchers. * @param name */ public void angularUnbindGlobal(String name) { angularUnbind(name, null); } /** * Create angular variable in local registry and bind with front end Angular display system. * If variable exists, it'll be overwritten. * @param name name of the variable * @param o value */ private void angularBind(String name, Object o, String noteId) { AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry(); if (registry.get(name, noteId) == null) { registry.add(name, o, noteId); } else { registry.get(name, noteId).set(o); } } /** * Create angular variable in local registry and bind with front end Angular display system. * If variable exists, value will be overwritten and watcher will be added. * @param name name of variable * @param o value * @param watcher watcher of the variable */ private void angularBind(String name, Object o, String noteId, AngularObjectWatcher watcher) { AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry(); if (registry.get(name, noteId) == null) { registry.add(name, o, noteId); } else { registry.get(name, noteId).set(o); } angularWatch(name, watcher); } /** * Add watcher into angular binding variable * @param name name of the variable * @param watcher watcher */ private void angularWatch(String name, String noteId, AngularObjectWatcher watcher) { AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry(); if (registry.get(name, noteId) != null) { registry.get(name, noteId).addWatcher(watcher); } } private void angularWatch(String name, String noteId, final scala.Function2<Object, Object, Unit> func) { AngularObjectWatcher w = new AngularObjectWatcher(getInterpreterContext()) { @Override public void watch(Object oldObject, Object newObject, InterpreterContext context) { func.apply(newObject, newObject); } }; angularWatch(name, noteId, w); } private void angularWatch( String name, String noteId, final scala.Function3<Object, Object, InterpreterContext, Unit> func) { AngularObjectWatcher w = new AngularObjectWatcher(getInterpreterContext()) { @Override public void watch(Object oldObject, Object newObject, InterpreterContext context) { func.apply(oldObject, newObject, context); } }; angularWatch(name, noteId, w); } /** * Remove watcher * @param name * @param watcher */ private void angularUnwatch(String name, String noteId, AngularObjectWatcher watcher) { AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry(); if (registry.get(name, noteId) != null) { registry.get(name, noteId).removeWatcher(watcher); } } /** * Remove all watchers for the angular variable * @param name */ private void angularUnwatch(String name, String noteId) { AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry(); if (registry.get(name, noteId) != null) { registry.get(name, noteId).clearAllWatchers(); } } /** * Remove angular variable and all the watchers. * @param name */ private void angularUnbind(String name, String noteId) { AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry(); registry.remove(name, noteId); } }