package org.apache.mahout.classifier.rbm.training; import java.io.IOException; import java.util.Random; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.SequenceFile.Writer; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.mapreduce.Reducer; import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob; import org.apache.mahout.classifier.naivebayes.training.WeightsMapper; import org.apache.mahout.classifier.rbm.layer.Layer; import org.apache.mahout.classifier.rbm.layer.LogisticLayer; import org.apache.mahout.classifier.rbm.layer.SoftmaxLayer; import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM; import org.apache.mahout.classifier.rbm.model.SimpleRBM; import org.apache.mahout.classifier.rbm.network.DeepBoltzmannMachine; import org.apache.mahout.classifier.rbm.training.RBMClassifierTrainingJob; import org.apache.mahout.classifier.rbm.training.RBMGreedyPreTrainingMapper; import org.apache.mahout.classifier.rbm.training.RBMGreedyPreTrainingReducer; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.common.Pair; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.common.commandline.DefaultOptionCreator; import org.apache.mahout.math.DenseMatrix; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.MatrixWritable; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.easymock.ConstructorArgs; import org.easymock.EasyMock; import org.junit.Test; import com.google.common.io.Closeables; public class rbmClassifierTest extends MahoutTestCase { private Path output; private Path input; private Configuration conf; private FileSystem fileSystem; private DeepBoltzmannMachine dbm; private Pair<Integer,Vector> testPair; @Override public void setUp() throws Exception { input = getTestTempFilePath("input"); output = getTestTempDirPath("output"); conf = new Configuration(); fileSystem = input.getFileSystem(conf); fileSystem.delete(output, true); fileSystem.mkdirs(input.getParent()); fileSystem.create(input, true); fileSystem.mkdirs(output); Writer writer = SequenceFile.createWriter(fileSystem, conf, input, IntWritable.class, VectorWritable.class); //TODO randomutils gave always the same value?? Random rand = new Random(); for (int i = 0; i < 4; i++) { VectorWritable vectorWritable = new VectorWritable( new DenseVector(new double[]{rand.nextDouble(), rand.nextDouble(), rand.nextDouble(), rand.nextDouble(), rand.nextDouble()})); if(i==0) testPair = new Pair<Integer, Vector>(0, vectorWritable.get()); writer.append(new IntWritable(i%2), vectorWritable); } Closeables.closeQuietly(writer); String[] layers = {"5","10","10"}; Layer[] layers2 = new Layer[layers.length]; layers2[0]= new LogisticLayer(Integer.parseInt(layers[0])); for (int i = 1; i < layers.length; i++) { layers2[i]= new LogisticLayer(Integer.parseInt(layers[i])); if(i==1) dbm = new DeepBoltzmannMachine(new SimpleRBM(layers2[0], layers2[1])); else if(i==layers.length-1) dbm.stackRBM(new LabeledSimpleRBM(layers2[i-1],layers2[i],new SoftmaxLayer(2))); else dbm.stackRBM(new SimpleRBM(layers2[i-1],layers2[i])); } dbm.serialize(output, conf); super.setUp(); } /*@Test public void testMapper() throws IOException, InterruptedException, NoSuchFieldException, IllegalAccessException { Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class); Vector instance1 = new DenseVector(new double[] { 1, 0, 0.5, 0.5, 0 }); RBMClassifierTrainingMapper mapper = new RBMClassifierTrainingMapper(); setField(mapper, "dbm", dbm); mapper.map(new VectorWritable(instance1), new VectorWritable(instance1), ctx); }*/ @Test public void testClassification() throws Exception { /*DeepBoltzmannMachine dbm = null; String[] layers = {"10","15","20"}; Layer layer2 = new LogisticLayer(Integer.parseInt(layers[0])); for (int i = 0; i < layers.length-1; i++) { Layer layer = layer2; layer2 = new LogisticLayer(Integer.parseInt(layers[i+1])); if(i==0) dbm = new DeepBoltzmannMachine(new SimpleRBM(layer, layer2)); else if(i==layers.length-2) dbm.stackRBM(new LabeledSimpleRBM(layer,layer2,new SoftmaxLayer(10))); else dbm.stackRBM(new SimpleRBM(layer, layer2)); } dbm.serialize(output, conf);*/ RBMClassifierTrainingJob job = new RBMClassifierTrainingJob(); job.setConf(conf); String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), input.toUri().getPath(), optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toUri().getPath(), //optKey("structure"), "5,5,5", optKey("labelcount"), "2" , optKey("maxIter"), "2"}; double errorBefore = dbm.getRBM(0).getReconstructionError(testPair.getSecond()); assertEquals(0, job.run(args)); DeepBoltzmannMachine dbm2 = DeepBoltzmannMachine.materialize(output, conf); double errorAfter = dbm2.getRBM(0).getReconstructionError(testPair.getSecond()); assertTrue(errorAfter<=errorBefore); } @Override public void tearDown() throws Exception { fileSystem.delete(input.getParent(), true); super.tearDown(); } }