package org.nd4j; import lombok.AllArgsConstructor; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.broadcast.Broadcast; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import scala.Tuple2; import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; /** * Created by Alex on 04/07/2016. */ public class TestNd4jKryoSerialization { private JavaSparkContext sc; @Before public void before() { SparkConf sparkConf = new SparkConf(); sparkConf.setMaster("local[*]"); sparkConf.setAppName("Iris"); sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator"); sc = new JavaSparkContext(sparkConf); } @Test public void testSerialization() { Tuple2<INDArray, INDArray> t2 = new Tuple2<>(Nd4j.linspace(1, 10, 10), Nd4j.linspace(10, 20, 10)); Broadcast<Tuple2<INDArray, INDArray>> b = sc.broadcast(t2); List<INDArray> list = new ArrayList<>(); for (int i = 0; i < 100; i++) { list.add(Nd4j.ones(5)); } JavaRDD<INDArray> rdd = sc.parallelize(list); rdd.foreach(new AssertFn(b)); } @After public void after() { if (sc != null) sc.close(); } @AllArgsConstructor public static class AssertFn implements VoidFunction<INDArray> { private Broadcast<Tuple2<INDArray, INDArray>> b; @Override public void call(INDArray arr) throws Exception { Tuple2<INDArray, INDArray> t2 = b.getValue(); assertEquals(Nd4j.linspace(1, 10, 10), t2._1()); assertEquals(Nd4j.linspace(10, 20, 10), t2._2()); assertEquals(Nd4j.ones(5), arr); } } }