/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zeppelin.spark;
import com.google.common.collect.Lists;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.zeppelin.annotation.ZeppelinApi;
import org.apache.zeppelin.display.AngularObjectWatcher;
import org.apache.zeppelin.display.Input;
import org.apache.zeppelin.display.ui.OptionInput;
import org.apache.zeppelin.interpreter.*;
import scala.Tuple2;
import scala.Unit;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import static scala.collection.JavaConversions.asJavaCollection;
import static scala.collection.JavaConversions.asJavaIterable;
import static scala.collection.JavaConversions.collectionAsScalaIterable;
/**
* ZeppelinContext for Spark
*/
/**
 * Spark-specific {@link BaseZeppelinContext}: exposes Spark's DataFrame-like
 * classes for display, renders query results as Zeppelin "%table" output, and
 * provides Scala-friendly overloads of the dynamic-form (select/checkbox) and
 * angular-binding APIs.
 */
public class SparkZeppelinContext extends BaseZeppelinContext {

  private SparkContext sc;
  public SQLContext sqlContext;
  private List<Class> supportedClasses;
  private Map<String, String> interpreterClassMap;

  public SparkZeppelinContext(
      SparkContext sc, SQLContext sql,
      InterpreterHookRegistry hooks,
      int maxResult) {
    super(hooks, maxResult);
    this.sc = sc;
    this.sqlContext = sql;

    interpreterClassMap = new HashMap<>();
    interpreterClassMap.put("spark", "org.apache.zeppelin.spark.SparkInterpreter");
    interpreterClassMap.put("sql", "org.apache.zeppelin.spark.SparkSqlInterpreter");
    interpreterClassMap.put("dep", "org.apache.zeppelin.spark.DepInterpreter");
    interpreterClassMap.put("pyspark", "org.apache.zeppelin.spark.PySparkInterpreter");

    this.supportedClasses = new ArrayList<>();
    // Probe reflectively: which of these exist depends on the Spark version on
    // the classpath (Dataset in Spark 2.x, DataFrame in 1.3+, SchemaRDD before 1.3).
    addSupportedClassIfPresent("org.apache.spark.sql.Dataset");
    addSupportedClassIfPresent("org.apache.spark.sql.DataFrame");
    addSupportedClassIfPresent("org.apache.spark.sql.SchemaRDD");
    if (supportedClasses.isEmpty()) {
      throw new InterpreterException("Can not load Dataset/DataFrame/SchemaRDD class");
    }
  }

  /**
   * Adds the named class to {@link #supportedClasses} if it is loadable.
   * Absence is expected for classes belonging to other Spark versions,
   * so {@link ClassNotFoundException} is deliberately ignored.
   */
  private void addSupportedClassIfPresent(String className) {
    try {
      supportedClasses.add(Class.forName(className));
    } catch (ClassNotFoundException ignored) {
      // Expected: this class does not exist in the running Spark version.
    }
  }

  @Override
  public List<Class> getSupportedClasses() {
    return supportedClasses;
  }

  @Override
  public Map<String, String> getInterpreterClassMap() {
    return interpreterClassMap;
  }

  /**
   * Renders a Dataset/DataFrame/SchemaRDD as Zeppelin "%table" text, taking at
   * most {@code maxResult} rows and appending a truncation notice when the
   * result exceeds that limit.
   *
   * <p>All access to {@code df} is reflective because the concrete class (and
   * the class returned by {@code queryExecution}) differs across Spark versions.
   *
   * @param df a Spark Dataset, DataFrame or SchemaRDD instance
   * @return a "%table"-formatted string of column names and row values
   */
  @Override
  public String showData(Object df) {
    Object[] rows = null;
    Method take;
    String jobGroup = Utils.buildJobGroupId(interpreterContext);
    sc.setJobGroup(jobGroup, "Zeppelin", false);

    try {
      // convert it to DataFrame if it is Dataset, as we will iterate all the records
      // and assume it is type Row.
      if (df.getClass().getCanonicalName().equals("org.apache.spark.sql.Dataset")) {
        Method convertToDFMethod = df.getClass().getMethod("toDF");
        df = convertToDFMethod.invoke(df);
      }
      // Fetch one extra row so we can detect whether the result was truncated.
      take = df.getClass().getMethod("take", int.class);
      rows = (Object[]) take.invoke(df, maxResult + 1);
    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
        | IllegalArgumentException | InvocationTargetException | ClassCastException e) {
      sc.clearJobGroup();
      throw new InterpreterException(e);
    }

    List<Attribute> columns = null;
    // get field names
    try {
      // Use reflection because of classname returned by queryExecution changes from
      // Spark <1.5.2 org.apache.spark.sql.SQLContext$QueryExecution
      // Spark 1.6.0> org.apache.spark.sql.hive.HiveContext$QueryExecution
      Object qe = df.getClass().getMethod("queryExecution").invoke(df);
      Object a = qe.getClass().getMethod("analyzed").invoke(qe);
      scala.collection.Seq seq = (scala.collection.Seq) a.getClass().getMethod("output").invoke(a);

      columns = (List<Attribute>) scala.collection.JavaConverters.seqAsJavaListConverter(seq)
          .asJava();
    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
        | IllegalArgumentException | InvocationTargetException e) {
      throw new InterpreterException(e);
    }

    StringBuilder msg = new StringBuilder();
    msg.append("%table ");
    for (Attribute col : columns) {
      msg.append(col.name()).append("\t");
    }
    // Drop the trailing tab after the last column name.
    String trim = msg.toString().trim();
    msg = new StringBuilder(trim);
    msg.append("\n");

    // ArrayType, BinaryType, BooleanType, ByteType, DecimalType, DoubleType, DynamicType,
    // FloatType, FractionalType, IntegerType, IntegralType, LongType, MapType, NativeType,
    // NullType, NumericType, ShortType, StringType, StructType
    try {
      for (int r = 0; r < maxResult && r < rows.length; r++) {
        Object row = rows[r];
        Method isNullAt = row.getClass().getMethod("isNullAt", int.class);
        Method apply = row.getClass().getMethod("apply", int.class);

        for (int i = 0; i < columns.size(); i++) {
          if (!(Boolean) isNullAt.invoke(row, i)) {
            msg.append(apply.invoke(row, i).toString());
          } else {
            msg.append("null");
          }
          if (i != columns.size() - 1) {
            msg.append("\t");
          }
        }
        msg.append("\n");
      }
    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
        | IllegalArgumentException | InvocationTargetException e) {
      throw new InterpreterException(e);
    }

    // We fetched maxResult + 1 rows above, so strictly more than maxResult
    // means the full result set was truncated.
    if (rows.length > maxResult) {
      msg.append("\n");
      msg.append(ResultMessages.getExceedsLimitRowsMessage(maxResult,
          SparkSqlInterpreter.MAX_RESULTS));
    }
    sc.clearJobGroup();
    return msg.toString();
  }

  /**
   * Creates a single-select dynamic form with no pre-selected value.
   *
   * @param name form name
   * @param options (value, displayValue) pairs
   * @return the currently selected value
   */
  @ZeppelinApi
  public Object select(String name, scala.collection.Iterable<Tuple2<Object, String>> options) {
    return select(name, "", options);
  }

  /**
   * Creates a single-select dynamic form.
   *
   * @param name form name
   * @param defaultValue value selected before the user interacts with the form
   * @param options (value, displayValue) pairs
   * @return the currently selected value
   */
  @ZeppelinApi
  public Object select(String name, Object defaultValue,
                       scala.collection.Iterable<Tuple2<Object, String>> options) {
    return select(name, defaultValue, tuplesToParamOptions(options));
  }

  /**
   * Creates a checkbox dynamic form with every option checked by default.
   *
   * @param name form name
   * @param options (value, displayValue) pairs
   * @return the currently checked values
   */
  @ZeppelinApi
  public scala.collection.Seq<Object> checkbox(
      String name,
      scala.collection.Iterable<Tuple2<Object, String>> options) {
    List<Object> allChecked = new LinkedList<>();
    for (Tuple2<Object, String> option : asJavaIterable(options)) {
      allChecked.add(option._1());
    }
    return checkbox(name, collectionAsScalaIterable(allChecked), options);
  }

  /**
   * Creates a checkbox dynamic form.
   *
   * @param name form name
   * @param defaultChecked values checked before the user interacts with the form
   * @param options (value, displayValue) pairs
   * @return the currently checked values
   */
  @ZeppelinApi
  public scala.collection.Seq<Object> checkbox(
      String name,
      scala.collection.Iterable<Object> defaultChecked,
      scala.collection.Iterable<Tuple2<Object, String>> options) {
    return scala.collection.JavaConversions.asScalaBuffer(
        gui.checkbox(name, asJavaCollection(defaultChecked),
            tuplesToParamOptions(options))).toSeq();
  }

  /** Converts Scala (value, displayValue) tuples into Zeppelin {@code ParamOption}s. */
  private OptionInput.ParamOption[] tuplesToParamOptions(
      scala.collection.Iterable<Tuple2<Object, String>> options) {
    int n = options.size();
    OptionInput.ParamOption[] paramOptions = new OptionInput.ParamOption[n];
    Iterator<Tuple2<Object, String>> it = asJavaIterable(options).iterator();

    int i = 0;
    while (it.hasNext()) {
      Tuple2<Object, String> valueAndDisplayValue = it.next();
      paramOptions[i++] = new OptionInput.ParamOption(valueAndDisplayValue._1(),
          valueAndDisplayValue._2());
    }
    return paramOptions;
  }

  /**
   * Watches an angular variable scoped to the current note.
   *
   * @param name angular variable name
   * @param func callback invoked with (oldValue, newValue)
   */
  @ZeppelinApi
  public void angularWatch(String name,
                           final scala.Function2<Object, Object, Unit> func) {
    angularWatch(name, interpreterContext.getNoteId(), func);
  }

  /**
   * Watches a globally scoped angular variable.
   *
   * @deprecated use the note-scoped {@link #angularWatch(String, scala.Function2)} instead
   */
  @Deprecated
  public void angularWatchGlobal(String name,
                                 final scala.Function2<Object, Object, Unit> func) {
    angularWatch(name, null, func);
  }

  /**
   * Watches an angular variable scoped to the current note.
   *
   * @param name angular variable name
   * @param func callback invoked with (oldValue, newValue, interpreterContext)
   */
  @ZeppelinApi
  public void angularWatch(
      String name,
      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
    angularWatch(name, interpreterContext.getNoteId(), func);
  }

  /**
   * Watches a globally scoped angular variable.
   *
   * @deprecated use the note-scoped {@link #angularWatch(String, scala.Function3)} instead
   */
  @Deprecated
  public void angularWatchGlobal(
      String name,
      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
    angularWatch(name, null, func);
  }

  private void angularWatch(String name, String noteId,
                            final scala.Function2<Object, Object, Unit> func) {
    AngularObjectWatcher w = new AngularObjectWatcher(getInterpreterContext()) {
      @Override
      public void watch(Object oldObject, Object newObject,
                        InterpreterContext context) {
        // Bug fix: previously passed newObject twice, dropping the old value.
        func.apply(oldObject, newObject);
      }
    };
    angularWatch(name, noteId, w);
  }

  private void angularWatch(
      String name,
      String noteId,
      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
    AngularObjectWatcher w = new AngularObjectWatcher(getInterpreterContext()) {
      @Override
      public void watch(Object oldObject, Object newObject,
                        InterpreterContext context) {
        func.apply(oldObject, newObject, context);
      }
    };
    angularWatch(name, noteId, w);
  }
}