/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.zeppelin.rest;

import static org.junit.Assert.assertEquals;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.zeppelin.interpreter.InterpreterSetting;
import org.apache.zeppelin.notebook.Note;
import org.apache.zeppelin.notebook.Paragraph;
import org.apache.zeppelin.scheduler.Job.Status;
import org.apache.zeppelin.server.ZeppelinServer;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import com.google.gson.Gson;

/**
 * Tests run against a Spark cluster.
 * The cluster is started by the CI server using testing/startSparkCluster.sh.
 */
public class ZeppelinSparkClusterTest extends AbstractTestRestApi {
Gson gson = new Gson();

@BeforeClass
public static void init() throws Exception {
AbstractTestRestApi.startUp();
}

@AfterClass
public static void destroy() throws Exception {
AbstractTestRestApi.shutDown();
}
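
/**
 * Poll every 100 ms until the paragraph reaches a terminal state
 * (FINISHED, ERROR or ABORT). Note that this spins indefinitely if the
 * paragraph never terminates; see the bounded variant below.
 */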
private void waitForFinish(Paragraph p) {
while (p.getStatus() != Status.FINISHED
&& p.getStatus() != Status.ERROR
&& p.getStatus() != Status.ABORT) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
// preserve the interrupt status and stop waiting instead of swallowing it
Thread.currentThread().interrupt();
break;
}
}
}
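
// A bounded variant of waitForFinish, sketched here as an alternative for
// callers that would rather fail fast than hang the build on a stuck
// paragraph. It is not called by the tests above, and the timeout handling
// and failure message are illustrative choices, not a project convention.
private void waitForFinish(Paragraph p, long timeoutMs) {
long deadline = System.currentTimeMillis() + timeoutMs;
while (p.getStatus() != Status.FINISHED
&& p.getStatus() != Status.ERROR
&& p.getStatus() != Status.ABORT) {
if (System.currentTimeMillis() > deadline) {
throw new IllegalStateException("paragraph did not finish within " + timeoutMs + " ms");
}
try {
Thread.sleep(100);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
}
}
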
@Test
public void basicRDDTransformationAndActionTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote();
// run a spark paragraph that sums the integers 1 to 10
Paragraph p = note.addParagraph();
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark print(sc.parallelize(1 to 10).reduce(_ + _))");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
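// 1 + 2 + ... + 10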
assertEquals("55", p.getResult().message());
ZeppelinServer.notebook.removeNote(note.id());
}

@Test
public void pySparkTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote();
int sparkVersion = getSparkVersionNumber(note);
if (isPyspark() && sparkVersion >= 12) { // pyspark supported from 1.2.1
// run a pyspark paragraph
Paragraph p = note.addParagraph();
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark print(sc.parallelize(range(1, 11)).reduce(lambda a, b: a + b))");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("55\n", p.getResult().message());
}
ZeppelinServer.notebook.removeNote(note.id());
}

@Test
public void pySparkAutoConvertOptionTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote();
int sparkVersion = getSparkVersionNumber(note);
if (isPyspark() && sparkVersion >= 14) { // auto_convert enabled from spark 1.4
// run a pyspark paragraph that relies on auto_convert (Column * Python float)
Paragraph p = note.addParagraph();
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark\nfrom pyspark.sql.functions import *\n"
+ "print(sqlContext.range(0, 10).withColumn('uniform', rand(seed=10) * 3.14).count())");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("10\n", p.getResult().message());
}
ZeppelinServer.notebook.removeNote(note.id());
}

@Test
public void zRunTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote();
Paragraph p0 = note.addParagraph();
Map config0 = p0.getConfig();
config0.put("enabled", true);
p0.setConfig(config0);
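// z.run(1) schedules the paragraph at index 1 (p1 below), which defines 'a'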
p0.setText("%spark z.run(1)");
Paragraph p1 = note.addParagraph();
Map config1 = p1.getConfig();
config1.put("enabled", true);
p1.setConfig(config1);
p1.setText("%spark val a=10");
Paragraph p2 = note.addParagraph();
Map config2 = p2.getConfig();
config2.put("enabled", true);
p2.setConfig(config2);
p2.setText("%spark print(a)");
note.run(p0.getId());
waitForFinish(p0);
assertEquals(Status.FINISHED, p0.getStatus());
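// p1 was queued by z.run(1) before p2 is submitted, so the interpreter's
// FIFO scheduler should have defined 'a' by the time p2 executes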
note.run(p2.getId());
waitForFinish(p2);
assertEquals(Status.FINISHED, p2.getStatus());
assertEquals("10", p2.getResult().message());
ZeppelinServer.notebook.removeNote(note.id());
}

@Test
public void pySparkDepLoaderTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote();
if (isPyspark() && getSparkVersionNumber(note) >= 14) {
// restart the spark interpreter so %dep is evaluated before the SparkContext is created
List<InterpreterSetting> settings =
ZeppelinServer.notebook.getBindedInterpreterSettings(note.id());
for (InterpreterSetting setting : settings) {
if (setting.getGroup().equals("spark")) {
ZeppelinServer.notebook.getInterpreterFactory().restart(setting.id());
break;
}
}
// load dep
Paragraph p0 = note.addParagraph();
Map config = p0.getConfig();
config.put("enabled", true);
p0.setConfig(config);
p0.setText("%dep z.load(\"com.databricks:spark-csv_2.11:1.2.0\")");
note.run(p0.getId());
waitForFinish(p0);
assertEquals(Status.FINISHED, p0.getStatus());
// write test csv file
File tmpFile = File.createTempFile("test", ".csv");
FileUtils.write(tmpFile, "a,b\n1,2");
// load data using libraries from dep loader
Paragraph p1 = note.addParagraph();
p1.setConfig(config);
p1.setText("%pyspark\n" +
"from pyspark.sql import SQLContext\n" +
"print(sqlContext.read.format('com.databricks.spark.csv')" +
".load('"+ tmpFile.getAbsolutePath() +"').count())");
note.run(p1.getId());
waitForFinish(p1);
assertEquals(Status.FINISHED, p1.getStatus());
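// no 'header' option is set, so both lines of the csv file count as rows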
assertEquals("2\n", p1.getResult().message());
}
// remove the note even when the spark version check skipped the test body
ZeppelinServer.notebook.removeNote(note.id());
}

/**
 * Get the spark version as a two-digit number,
 * e.g. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ...
 */
private int getSparkVersionNumber(Note note) {
Paragraph p = note.addParagraph();
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark print(sc.version)");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
String sparkVersion = p.getResult().message();
System.out.println("Spark version detected " + sparkVersion);
String[] split = sparkVersion.split("\\.");
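// encode major.minor as a two-digit number, e.g. "1.3.1" -> 13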
int version = Integer.parseInt(split[0]) * 10 + Integer.parseInt(split[1]);
return version;
}
}