/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hive.spark.client;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.Serializable;
import java.net.URI;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.jar.JarOutputStream;
import java.util.zip.ZipEntry;

import com.google.common.base.Objects;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.io.ByteStreams;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hive.spark.client.JobHandle.Listener;
import org.apache.hive.spark.counter.SparkCounters;
import org.apache.spark.SparkException;
import org.apache.spark.SparkFiles;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

public class TestSparkClient {
// Timeout, in seconds, for every asynchronous wait in these tests.
private static final long TIMEOUT = 20;
private static final HiveConf HIVECONF = new HiveConf();
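/**
 * Builds the client configuration. Local mode runs the driver in this test JVM
 * (in-process); remote mode launches a separate driver process, which needs the
 * test classpath on both the driver and the executors.
 */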
private Map<String, String> createConf(boolean local) {
Map<String, String> conf = new HashMap<String, String>();
if (local) {
conf.put(SparkClientFactory.CONF_KEY_IN_PROCESS, "true");
conf.put("spark.master", "local");
conf.put("spark.app.name", "SparkClientSuite Local App");
} else {
String classpath = System.getProperty("java.class.path");
conf.put("spark.master", "local");
conf.put("spark.app.name", "SparkClientSuite Remote App");
conf.put("spark.driver.extraClassPath", classpath);
conf.put("spark.executor.extraClassPath", classpath);
}
if (!Strings.isNullOrEmpty(System.getProperty("spark.home"))) {
conf.put("spark.home", System.getProperty("spark.home"));
}
return conf;
}
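// Submits a trivial job and checks both that the result arrives and that the
// listener sees the expected state transitions.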
@Test
public void testJobSubmission() throws Exception {
runTest(true, new TestFunction() {
@Override
public void call(SparkClient client) throws Exception {
JobHandle.Listener<String> listener = newListener();
List<JobHandle.Listener<String>> listeners = Lists.newArrayList(listener);
JobHandle<String> handle = client.submit(new SimpleJob(), listeners);
assertEquals("hello", handle.get(TIMEOUT, TimeUnit.SECONDS));
// Try an invalid state transition on the handle. This ensures that the actual state
// change we're interested in actually happened, since internally the handle serializes
// state changes.
assertFalse(((JobHandleImpl<String>)handle).changeState(JobHandle.State.SENT));
verify(listener).onJobStarted(handle);
verify(listener).onJobSucceeded(same(handle), eq(handle.get()));
}
});
}
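// Runs a small RDD count through an in-process client.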
@Test
public void testSimpleSparkJob() throws Exception {
runTest(true, new TestFunction() {
@Override
public void call(SparkClient client) throws Exception {
JobHandle<Long> handle = client.submit(new SparkJob());
assertEquals(Long.valueOf(5L), handle.get(TIMEOUT, TimeUnit.SECONDS));
}
});
}
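// Submits a job that throws, and checks that the failure is reported both
// through the handle's get() and through the listener callbacks.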
@Test
public void testErrorJob() throws Exception {
runTest(true, new TestFunction() {
@Override
public void call(SparkClient client) throws Exception {
JobHandle.Listener<String> listener = newListener();
List<JobHandle.Listener<String>> listeners = Lists.newArrayList(listener);
JobHandle<String> handle = client.submit(new ErrorJob(), listeners);
try {
handle.get(TIMEOUT, TimeUnit.SECONDS);
fail("Should have thrown an exception.");
} catch (ExecutionException ee) {
assertTrue(ee.getCause() instanceof SparkException);
assertTrue(ee.getCause().getMessage().contains("IllegalStateException: Hello"));
}
// Try an invalid state transition on the handle. This ensures that the actual state
// change we're interested in actually happened, since internally the handle serializes
// state changes.
assertFalse(((JobHandleImpl<String>)handle).changeState(JobHandle.State.SENT));
verify(listener).onJobQueued(handle);
verify(listener).onJobStarted(handle);
verify(listener).onJobFailed(same(handle), any(Throwable.class));
}
});
}
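// Runs a job via client.run(), which executes it as a synchronous RPC rather
// than through the asynchronous job-submission path.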
@Test
public void testSyncRpc() throws Exception {
runTest(true, new TestFunction() {
@Override
public void call(SparkClient client) throws Exception {
Future<String> result = client.run(new SyncRpc());
assertEquals("Hello", result.get(TIMEOUT, TimeUnit.SECONDS));
}
});
}
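// Same job as testSimpleSparkJob, but with an out-of-process driver.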
@Test
public void testRemoteClient() throws Exception {
runTest(false, new TestFunction() {
@Override
public void call(SparkClient client) throws Exception {
JobHandle<Long> handle = client.submit(new SparkJob());
assertEquals(Long.valueOf(5L), handle.get(TIMEOUT, TimeUnit.SECONDS));
}
});
}
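// Submits two jobs and checks that each collects its own metrics under a
// distinct Spark job ID.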
@Test
public void testMetricsCollection() throws Exception {
runTest(true, new TestFunction() {
@Override
public void call(SparkClient client) throws Exception {
JobHandle.Listener<Integer> listener = newListener();
List<JobHandle.Listener<Integer>> listeners = Lists.newArrayList(listener);
JobHandle<Integer> future = client.submit(new AsyncSparkJob(), listeners);
future.get(TIMEOUT, TimeUnit.SECONDS);
MetricsCollection metrics = future.getMetrics();
assertEquals(1, metrics.getJobIds().size());
assertTrue(metrics.getAllMetrics().executorRunTime >= 0L);
verify(listener).onSparkJobStarted(same(future),
eq(metrics.getJobIds().iterator().next()));
JobHandle.Listener<Integer> listener2 = newListener();
List<JobHandle.Listener<Integer>> listeners2 = Lists.newArrayList(listener2);
JobHandle<Integer> future2 = client.submit(new AsyncSparkJob(), listeners2);
future2.get(TIMEOUT, TimeUnit.SECONDS);
MetricsCollection metrics2 = future2.getMetrics();
assertEquals(1, metrics2.getJobIds().size());
assertFalse(Objects.equal(metrics.getJobIds(), metrics2.getJobIds()));
assertTrue(metrics2.getAllMetrics().executorRunTime >= 0L);
verify(listener2).onSparkJobStarted(same(future2),
eq(metrics2.getJobIds().iterator().next()));
}
});
}
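// Checks that jars added via addJar() end up on the executor classpath, and
// that files added via addFile() are visible through SparkFiles.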
@Test
public void testAddJarsAndFiles() throws Exception {
runTest(true, new TestFunction() {
@Override
public void call(SparkClient client) throws Exception {
File jar = null;
File file = null;
try {
// Test that adding a jar to the remote context makes it show up in the classpath.
jar = File.createTempFile("test", ".jar");
JarOutputStream jarFile = new JarOutputStream(new FileOutputStream(jar));
jarFile.putNextEntry(new ZipEntry("test.resource"));
jarFile.write("test resource".getBytes("UTF-8"));
jarFile.closeEntry();
jarFile.close();
client.addJar(new URI("file:" + jar.getAbsolutePath()))
.get(TIMEOUT, TimeUnit.SECONDS);
// Need to run a Spark job to make sure the jar is added to the class loader. Monitoring
// SparkContext#addJar() alone proves nothing; jars are only guaranteed to have been
// distributed once a task runs after the jar has been added.
String result = client.submit(new JarJob()).get(TIMEOUT, TimeUnit.SECONDS);
assertEquals("test resource", result);
// Test that adding a file to the remote context makes it available to executors.
file = File.createTempFile("test", ".file");
FileOutputStream fileStream = new FileOutputStream(file);
fileStream.write("test file".getBytes("UTF-8"));
fileStream.close();
client.addJar(new URI("file:" + file.getAbsolutePath()))
.get(TIMEOUT, TimeUnit.SECONDS);
// The same applies to files added with "addFile". They're only guaranteed to be available
// to tasks started after the addFile() call completes.
result = client.submit(new FileJob(file.getName()))
.get(TIMEOUT, TimeUnit.SECONDS);
assertEquals("test file", result);
} finally {
if (jar != null) {
jar.delete();
}
if (file != null) {
file.delete();
}
}
}
});
}
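// Runs a job that increments two counters once per RDD element; both counters
// should end up holding the sum of the elements 1 through 5.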
@Test
public void testCounters() throws Exception {
runTest(true, new TestFunction() {
@Override
public void call(SparkClient client) throws Exception {
JobHandle<?> job = client.submit(new CounterIncrementJob());
job.get(TIMEOUT, TimeUnit.SECONDS);
SparkCounters counters = job.getSparkCounters();
assertNotNull(counters);
long expected = 1 + 2 + 3 + 4 + 5;
assertEquals(expected, counters.getCounter("group1", "counter1").getValue());
assertEquals(expected, counters.getCounter("group2", "counter2").getValue());
}
});
}
private static final Logger LOG = LoggerFactory.getLogger(TestSparkClient.class);
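/**
 * Creates a mock listener whose callbacks log the corresponding state
 * transition; onJobFailed additionally logs the reported error.
 */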
private <T extends Serializable> JobHandle.Listener<T> newListener() {
@SuppressWarnings("unchecked")
JobHandle.Listener<T> listener = mock(JobHandle.Listener.class);
answerWhen(listener, "cancelled").onJobCancelled(Mockito.<JobHandle<T>>any());
answerWhen(listener, "queued").onJobQueued(Mockito.<JobHandle<T>>any());
answerWhen(listener, "started").onJobStarted(Mockito.<JobHandle<T>>any());
answerWhen(listener, "succeeded").onJobSucceeded(
Mockito.<JobHandle<T>>any(), Mockito.<T>any());
answerWhen(listener, "job started").onSparkJobStarted(
Mockito.<JobHandle<T>>any(), Mockito.anyInt());
Mockito.doAnswer(new Answer<Void>() {
public Void answer(InvocationOnMock invocation) throws Throwable {
@SuppressWarnings("rawtypes")
JobHandleImpl arg = ((JobHandleImpl)invocation.getArguments()[0]);
LOG.info("Job failed " + arg.getClientJobId(),
(Throwable)invocation.getArguments()[1]);
return null;
}
}).when(listener).onJobFailed(Mockito.<JobHandle<T>>any(), Mockito.<Throwable>any());
return listener;
}
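/**
 * Stubs the listener so that the callback configured on the returned object
 * logs the given description together with the client job ID.
 */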
protected <T extends Serializable> Listener<T> answerWhen(
Listener<T> listener, final String logStr) {
return Mockito.doAnswer(new Answer<Void>() {
public Void answer(InvocationOnMock invocation) throws Throwable {
@SuppressWarnings("rawtypes")
JobHandleImpl arg = ((JobHandleImpl)invocation.getArguments()[0]);
LOG.info("Job " + logStr + " " + arg.getClientJobId());
return null;
}
}).when(listener);
}
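/**
 * Shared test harness: initializes the factory, creates a client, runs the
 * test body, and tears both down even when the body throws.
 */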
private void runTest(boolean local, TestFunction test) throws Exception {
Map<String, String> conf = createConf(local);
SparkClientFactory.initialize(conf);
SparkClient client = null;
try {
test.config(conf);
client = SparkClientFactory.createClient(conf, HIVECONF);
test.call(client);
} finally {
if (client != null) {
client.stop();
}
SparkClientFactory.stop();
}
}
private static class SimpleJob implements Job<String> {
@Override
public String call(JobContext jc) {
return "hello";
}
}
private static class ErrorJob implements Job<String> {
@Override
public String call(JobContext jc) {
throw new IllegalStateException("Hello");
}
}
private static class SparkJob implements Job<Long> {
@Override
public Long call(JobContext jc) {
JavaRDD<Integer> rdd = jc.sc().parallelize(Arrays.asList(1, 2, 3, 4, 5));
return rdd.count();
}
}
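// Runs a foreachAsync() and registers it with jc.monitor() so that the client
// collects metrics for the resulting Spark job.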
private static class AsyncSparkJob implements Job<Integer> {
@Override
public Integer call(JobContext jc) throws Exception {
JavaRDD<Integer> rdd = jc.sc().parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaFutureAction<?> future = jc.monitor(rdd.foreachAsync(new VoidFunction<Integer>() {
@Override
public void call(Integer l) throws Exception {
}
}), null, null);
future.get(TIMEOUT, TimeUnit.SECONDS);
return 1;
}
}
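// The map task reads "test.resource" from its context class loader, which only
// succeeds if the added jar reached the executor.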
private static class JarJob implements Job<String>, Function<Integer, String> {
@Override
public String call(JobContext jc) {
return jc.sc().parallelize(Arrays.asList(1)).map(this).collect().get(0);
}
@Override
public String call(Integer i) throws Exception {
ClassLoader ccl = Thread.currentThread().getContextClassLoader();
InputStream in = ccl.getResourceAsStream("test.resource");
byte[] bytes = ByteStreams.toByteArray(in);
in.close();
return new String(bytes, 0, bytes.length, "UTF-8");
}
}
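// The map task resolves the file through SparkFiles.get(), which only succeeds
// if addFile() distributed it to the executors.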
private static class FileJob implements Job<String>, Function<Integer, String> {
private final String fileName;
FileJob(String fileName) {
this.fileName = fileName;
}
@Override
public String call(JobContext jc) {
return jc.sc().parallelize(Arrays.asList(1)).map(this).collect().get(0);
}
@Override
public String call(Integer i) throws Exception {
InputStream in = new FileInputStream(SparkFiles.get(fileName));
byte[] bytes = ByteStreams.toByteArray(in);
in.close();
return new String(bytes, 0, bytes.length, "UTF-8");
}
}
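// Registers two counters with the monitor, then increments each by every RDD
// element's value from within the tasks.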
private static class CounterIncrementJob implements Job<String>, VoidFunction<Integer> {
private SparkCounters counters;
@Override
public String call(JobContext jc) {
counters = new SparkCounters(jc.sc());
counters.createCounter("group1", "counter1");
counters.createCounter("group2", "counter2");
jc.monitor(jc.sc().parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).foreachAsync(this),
counters, null);
return null;
}
@Override
public void call(Integer l) throws Exception {
counters.getCounter("group1", "counter1").increment(l.longValue());
counters.getCounter("group2", "counter2").increment(l.longValue());
}
}
private static class SyncRpc implements Job<String> {
@Override
public String call(JobContext jc) {
return "Hello";
}
}
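/**
 * A test body, plus an optional hook for adjusting the configuration before
 * the client is created.
 */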
private abstract static class TestFunction {
abstract void call(SparkClient client) throws Exception;
void config(Map<String, String> conf) { }
}
}