/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg;
import org.junit.After;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
/**
 * Base class for Nd4j tests that are run once per discovered {@link Nd4jBackend}.
 * <p>
 * Backends are discovered through {@link ServiceLoader}. A backend is used only when
 * {@link Nd4jBackend#canRun()} returns {@code true} AND it is either named in
 * {@code Nd4jTestSuite.backendsToRun()} or that list is empty.
 *
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public abstract class BaseNd4jTest {
    private static final Logger log = LoggerFactory.getLogger(BaseNd4jTest.class);

    /** System property naming the backend class instantiated by {@link #getDefaultBackend()}. */
    public final static String DEFAULT_BACKEND = "org.nd4j.linalg.defaultbackend";

    /** Backend class used by {@link #getDefaultBackend()} when {@link #DEFAULT_BACKEND} is not set. */
    private static final String CPU_BACKEND = "org.nd4j.linalg.cpu.nativecpu.CpuBackend";

    /** Backends this test class is parameterized over; computed once when the class is loaded. */
    private static final List<Nd4jBackend> backends = loadBackends();

    protected Nd4jBackend backend;
    protected String name;

    public BaseNd4jTest() {
        this("", getDefaultBackend());
    }

    public BaseNd4jTest(String name) {
        this(name, getDefaultBackend());
    }

    public BaseNd4jTest(String name, Nd4jBackend backend) {
        this.backend = backend;
        this.name = name;
    }

    /**
     * Constructor invoked by the {@link Parameterized} runner with one entry of {@link #configs()}.
     * The test name is the backend class name followed by a random UUID.
     */
    public BaseNd4jTest(Nd4jBackend backend) {
        this(backend.getClass().getName() + UUID.randomUUID().toString(), backend);
    }

    /**
     * Discovers all {@link Nd4jBackend} implementations via {@link ServiceLoader} and keeps those
     * that can run and were requested (or all runnable ones when no specific backends were requested).
     *
     * @return the selected backends, in service-loader discovery order
     */
    private static List<Nd4jBackend> loadBackends() {
        List<String> backendsToRun = Nd4jTestSuite.backendsToRun();
        List<Nd4jBackend> selected = new ArrayList<>();
        for (Nd4jBackend candidate : ServiceLoader.load(Nd4jBackend.class)) {
            // canRun() must hold in every case; the request filter only narrows the set further.
            // (Previously an empty request list also admitted backends that cannot run.)
            if (candidate.canRun()
                            && (backendsToRun.isEmpty() || backendsToRun.contains(candidate.getClass().getName()))) {
                selected.add(candidate);
            }
        }
        return selected;
    }

    /**
     * Supplies the parameter sets for the {@link Parameterized} runner: one single-element
     * array per selected backend.
     *
     * @return one {@code Object[]{backend}} per selected backend
     */
    @Parameterized.Parameters(name = "{index}: backend({0})")
    public static Collection<Object[]> configs() {
        List<Object[]> ret = new ArrayList<>(backends.size());
        for (Nd4jBackend backend : backends) {
            ret.add(new Object[] {backend});
        }
        return ret;
    }

    /**
     * Instantiates the default backend, which is the native CPU backend
     * ({@code org.nd4j.linalg.cpu.nativecpu.CpuBackend}).
     * The default backend can be overridden by passing, for example:
     * {@code -Dorg.nd4j.linalg.defaultbackend=org.nd4j.linalg.jcublas.JCublasBackend}
     *
     * @return a new instance of the backend class named by the {@value #DEFAULT_BACKEND}
     *         system property, or of the CPU backend when the property is unset
     * @throws RuntimeException if the backend class cannot be loaded or instantiated
     */
    public static Nd4jBackend getDefaultBackend() {
        String clazz = System.getProperty(DEFAULT_BACKEND, CPU_BACKEND);
        try {
            return (Nd4jBackend) Class.forName(clazz).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException("Unable to instantiate Nd4j backend " + clazz, e);
        }
    }

    /** Initialises Nd4j with this test's backend and ordering before each test. */
    @Before
    public void before() throws Exception {
        log.info("Running {} on backend {}", getClass().getName(), backend.getClass().getName());
        initializeNd4j();
    }

    /**
     * Re-initialises Nd4j with this test's backend after each test, unless the {@code backends}
     * system property is set and does not mention this backend's class name.
     * NOTE(review): presumably this resets global Nd4j state for subsequent tests — confirm.
     */
    @After
    public void after() throws Exception {
        log.info("Ending {}", getClass().getName());
        String requestedBackends = System.getProperty("backends");
        if (requestedBackends != null && !requestedBackends.contains(backend.getClass().getName())) {
            return;
        }
        initializeNd4j();
    }

    /**
     * Initialises Nd4j with {@link #backend}, applies {@link #ordering()} to the factory and
     * sets both print limits ({@code MAX_ELEMENTS_PER_SLICE}, {@code MAX_SLICES_TO_PRINT}) to -1.
     */
    private void initializeNd4j() throws Exception {
        Nd4j nd4j = new Nd4j();
        nd4j.initWithBackend(backend);
        Nd4j.factory().setOrder(ordering());
        Nd4j.MAX_ELEMENTS_PER_SLICE = -1;
        Nd4j.MAX_SLICES_TO_PRINT = -1;
    }

    /**
     * The array ordering this test runs with; it is passed to {@code Nd4j.factory().setOrder(...)}
     * before (and after) each test. Subclasses override this to force a specific ordering.
     * NOTE(review): the default {@code 'a'} presumably means "no forced ordering" — confirm.
     *
     * @return the ordering for this test
     */
    public char ordering() {
        return 'a';
    }

    /** @return a message identifying the backend and ordering, for use in assertion failures */
    public String getFailureMessage() {
        return "Failed with backend " + backend.getClass().getName() + " and ordering " + ordering();
    }
}