package org.deeplearning4j.keras;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import org.apache.commons.io.FileUtils;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.reflections.Reflections;
import org.reflections.scanners.ResourcesScanner;

import javax.annotation.Nullable;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Set;
import java.util.regex.Pattern;

import static org.deeplearning4j.keras.StringsEndsWithPredicate.endsWith;
/**
 * Integration test for {@link DeepLearning4jEntryPoint}.
 *
 * <p>Copies a Keras model file and its feature/label batch files from the test classpath into
 * temporary files, then runs {@code fit} on them.
 */
public class DeepLearning4jEntryPointTest {

    private DeepLearning4jEntryPoint deepLearning4jEntryPoint = new DeepLearning4jEntryPoint();

    // ExpectedException.none() expects no exception unless a test configures it otherwise.
    public @Rule ExpectedException thrown = ExpectedException.none();

    @Test
    public void shouldFitTheSampleSequentialModel() throws Exception {
        // Given
        final File model = prepareResource("theano_mnist/model.h5");
        final Path features = prepareFeatures("theano_mnist");
        final Path labels = prepareLabels("theano_mnist");

        EntryPointFitParameters entryPointParameters =
                EntryPointFitParameters.builder().modelFilePath(model.getAbsolutePath())
                        .trainFeaturesDirectory(features.toAbsolutePath().toString())
                        .trainLabelsDirectory(labels.toAbsolutePath().toString()).batchSize(128)
                        .nbEpoch(2).type(KerasModelType.SEQUENTIAL).build();

        // When
        deepLearning4jEntryPoint.fit(entryPointParameters);

        // Then
        // No explicit assertion: the test passes iff fit(...) completes without throwing.
        // Any exception propagates out of this method and JUnit reports the test as failed.
    }

    /** Copies the {@code labels} batch files of {@code dataSet} into a fresh temp directory. */
    private Path prepareLabels(String dataSet) throws IOException {
        return prepareDataSet(dataSet, "labels");
    }

    /** Copies the {@code features} batch files of {@code dataSet} into a fresh temp directory. */
    private Path prepareFeatures(String dataSet) throws IOException {
        return prepareDataSet(dataSet, "features");
    }

    /**
     * Copies every {@code .h5} classpath resource found under {@code dataSet/batchesSubset} into a
     * newly created temporary directory, keeping each resource's base file name.
     *
     * @param dataSet classpath directory of the data set, e.g. {@code "theano_mnist"}
     * @param batchesSubset sub-directory to copy, e.g. {@code "features"} or {@code "labels"}
     * @return the temporary directory holding the copied batch files
     * @throws IOException if the directory cannot be created or a resource cannot be copied
     * @throws IllegalStateException if no {@code .h5} batch files are found
     */
    private Path prepareDataSet(String dataSet, String batchesSubset) throws IOException {
        final String classpathDirectory = dataSet + "/" + batchesSubset;
        final Set<String> batchFiles = listBatchFiles(classpathDirectory);
        if (batchFiles.isEmpty()) {
            // Fail fast: an empty directory would otherwise surface as an obscure training
            // failure (or a fit over zero batches) far from the real cause.
            throw new IllegalStateException(
                    "No .h5 batch files found on the classpath under " + classpathDirectory);
        }

        final Path tempDirectory = Files.createTempDirectory("dl4j-" + batchesSubset);
        // Register the directory before its children: deleteOnExit processes entries in reverse
        // registration order, so the files are removed first and the directory is empty by then.
        tempDirectory.toFile().deleteOnExit();

        for (String batchFile : batchFiles) {
            final String batchFileName = Paths.get(batchFile).getFileName().toString();
            final File target = targetFile(tempDirectory, batchFileName);
            target.deleteOnExit();
            copyClasspathResourceToFile(batchFile, target);
        }
        return tempDirectory;
    }

    /** Returns the absolute file named {@code batchFileName} inside {@code directory}. */
    private File targetFile(Path directory, String batchFileName) {
        return directory.toAbsolutePath().resolve(batchFileName).toFile();
    }

    /**
     * Lists classpath resources under {@code classpathDirectory} whose names end with
     * {@code .h5}, as reported by a Reflections {@link ResourcesScanner}.
     */
    private Set<String> listBatchFiles(String classpathDirectory) {
        return new Reflections(classpathDirectory, new ResourcesScanner()).getResources(endsWith(".h5"));
    }

    /**
     * Copies the classpath resource {@code resourceName} into a new temporary file that is
     * deleted when the JVM exits.
     *
     * @throws IOException if the temp file cannot be created or the resource cannot be copied
     */
    private File prepareResource(String resourceName) throws IOException {
        final File file = File.createTempFile("dl4j", "test");
        // Only the unique name is needed; the copy below recreates the file.
        file.delete();
        file.deleteOnExit();
        copyClasspathResourceToFile(resourceName, file);
        return file;
    }

    /**
     * Copies the classpath resource {@code resourceName} to {@code file}, creating parent
     * directories and overwriting any existing file.
     *
     * @throws FileNotFoundException if the resource does not exist on the classpath
     * @throws IOException if copying fails
     */
    private void copyClasspathResourceToFile(String resourceName, File file) throws IOException {
        try (InputStream in = getClass().getClassLoader().getResourceAsStream(resourceName)) {
            if (in == null) {
                throw new FileNotFoundException("Classpath resource not found: " + resourceName);
            }
            FileUtils.copyInputStreamToFile(in, file);
        }
    }
}