/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  * Licensed under the Apache License, Version 2.0 (the "License");
 *  * you may not use this file except in compliance with the License.
 *  * You may obtain a copy of the License at
 *  *
 *  *    http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS,
 *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  * See the License for the specific language governing permissions and
 *  * limitations under the License.
 *
 */
package org.deeplearning4j.spark;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.junit.After;
import org.junit.Before;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Created by agibsonccc on 1/23/15.
 */
public abstract class BaseSparkTest implements Serializable {
    protected transient JavaSparkContext sc;
    protected transient INDArray labels;
    protected transient INDArray input;
    protected transient INDArray rowSums;
    protected transient int nRows = 200;
    protected transient int nIn = 4;
    protected transient int nOut = 3;
    protected transient DataSet data;
    protected transient JavaRDD<DataSet> sparkData;

    @Before
    public void before() {
        sc = getContext();
        Random r = new Random(12345);
        labels = Nd4j.create(nRows, nOut);
        input = Nd4j.rand(nRows, nIn);
        rowSums = input.sum(1);
        input.diviColumnVector(rowSums);

        // Assign a random one-hot label to each row
        for (int i = 0; i < nRows; i++) {
            int x1 = r.nextInt(nOut);
            labels.putScalar(new int[] {i, x1}, 1.0);
        }

        sparkData = getBasicSparkDataSet(nRows, input, labels);
    }

    @After
    public void after() {
        sc.close();
        sc = null;
    }

    /**
     * @return a local JavaSparkContext for testing, creating one if none exists yet
     */
    public JavaSparkContext getContext() {
        if (sc != null)
            return sc;

        // set to test mode
        SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");

        sc = new JavaSparkContext(sparkConf);
        return sc;
    }

    protected JavaRDD<DataSet> getBasicSparkDataSet(int nRows, INDArray input, INDArray labels) {
        List<DataSet> list = new ArrayList<>();
        for (int i = 0; i < nRows; i++) {
            INDArray inRow = input.getRow(i).dup();
            INDArray outRow = labels.getRow(i).dup();

            DataSet ds = new DataSet(inRow, outRow);
            list.add(ds);
        }

        data = DataSet.merge(list);
        return sc.parallelize(list);
    }

    protected SparkDl4jMultiLayer getBasicNetwork() {
        return new SparkDl4jMultiLayer(sc, getBasicConf(),
                        new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0));
    }

    protected int numExecutors() {
        int numProc = Runtime.getRuntime().availableProcessors();
        return Math.min(4, numProc);
    }

    protected MultiLayerConfiguration getBasicConf() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123).updater(Updater.NESTEROVS)
                        .learningRate(0.1).momentum(0.9).list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3)
                                        .activation(Activation.TANH).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                        LossFunctions.LossFunction.MCXENT).nIn(3).nOut(nOut)
                                                        .activation(Activation.SOFTMAX).build())
                        .backprop(true).pretrain(false).build();
        return conf;
    }
}
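
// ---------------------------------------------------------------------------
// A minimal, hypothetical sketch (not part of the original file) of how a
// concrete test might extend BaseSparkTest: the @Before fixture above provides
// sc, sparkData, and the helper methods. The class name, test name, and
// assertion below are illustrative assumptions only, kept in a comment so the
// file's compiled contents are unchanged.
// ---------------------------------------------------------------------------
/*
class ExampleSparkFitTest extends BaseSparkTest {

    @org.junit.Test
    public void testFitOnSparkData() {
        // getBasicNetwork() wraps getBasicConf() in a SparkDl4jMultiLayer using
        // a ParameterAveragingTrainingMaster; fit() trains on the test RDD
        SparkDl4jMultiLayer network = getBasicNetwork();
        network.fit(sparkData);
        org.junit.Assert.assertNotNull(network.getNetwork().params());
    }
}
*/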