package org.apache.mahout.classifier.rbm;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM;
import org.apache.mahout.classifier.rbm.network.DBMStateIterator;
import org.apache.mahout.classifier.rbm.network.DeepBoltzmannMachine;
import org.apache.mahout.classifier.rbm.test.TestRBMClassifierJob;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.Matrix;
public class TestHardWiring extends AbstractJob{
public static void main(String[] args) {
new TestHardWiring().run(new String[]{"/home/dirk/models/model_440chunks_nofine",
"/home/dirk/models/experimentalModel",
"/home/dirk/mnist/chunks9",
"/home/dirk/mnist/chunks9/chunk0"});
}
public int run(String[] args) {
setConf(new Configuration());
String inputPath = args[0];
String outputPath = args[1];
String trainingDataPath = args[2];
String testDataPath = args[3];
try {
return transformLayer(new Path(inputPath),
new Path(outputPath),
new Path(trainingDataPath),
new Path(testDataPath));
} catch (Exception e) {
e.printStackTrace();
}
return -1;
}
public int transformLayer(Path input, Path output, Path trainingData, Path testData) throws Exception {
RBMClassifier rbmCl = RBMClassifier.materialize(input, getConf());
DeepBoltzmannMachine dbm = rbmCl.getDbm();
FileSystem fs = input.getFileSystem(getConf());
Path[] batches;
if(fs.isFile(trainingData))
batches = new Path[]{trainingData};
else {
FileStatus[] stati = fs.listStatus(trainingData);
batches = new Path[stati.length];
for (int i = 0; i < stati.length; i++) {
batches[i] = stati[i].getPath();
}
}
Vector[] probs = new Vector[10];
for (int i = 0; i < probs.length; i++) {
probs[i] = new DenseVector(dbm.getLayer(dbm.getRbmCount()).getNeuronCount());
}
int[] counter = new int[10];
int count = 0;
for (int i = 0; i < batches.length; i++) {
SequenceFileIterable<IntWritable, VectorWritable> dirIterable =
new SequenceFileIterable<IntWritable, VectorWritable>( batches[i], getConf());
for (Pair<IntWritable, VectorWritable> record : dirIterable) {
dbm.getLayer(0).setActivations(record.getSecond().get());
dbm.upPass();
DBMStateIterator.iterateUntilStableLayer(dbm.getLayer(0), dbm, 3);
probs[record.getFirst().get()]=probs[record.getFirst().get()].plus(dbm.getLayer(dbm.getRbmCount()).getActivations());
counter[record.getFirst().get()]++;
count++;
if(count%1000==0)
System.out.println(count);
}
}
Vector total = null;
for (int i = 0; i < counter.length; i++) {
if(total==null)
total=probs[i].clone();
else
total=total.plus(probs[i]);
probs[i] = probs[i].divide(counter[i]);
}
total = total.divide(count);
Vector logTotal = total.clone();
Vector negativeLogTotal = total.clone();
double log2= Math.log(2);
for (int i = 0; i < total.size(); i++) {
logTotal.set(i, Math.log(total.get(i))/log2);
negativeLogTotal.set(i, Math.log(1-total.get(i))/log2);
}
LabeledSimpleRBM lrbm = (LabeledSimpleRBM)dbm.getRBM(dbm.getRbmCount()-1);
Vector biases = new DenseVector(10);
Matrix weights = lrbm.getWeightLabelMatrix().clone();
for (int i = 0; i < probs.length; i++) {
Vector negInformation = probs[i].times(-1).plus(1).
plus(total.plus(-1)).
times(negativeLogTotal);
Vector posInformation = probs[i].
minus(total).
times(logTotal);
biases.set(i, negInformation.zSum());
weights.assignRow(i, posInformation.minus(negInformation));
}
lrbm.getSoftmaxLayer().setBiases(biases);
lrbm.setWeightLabelMatrix(weights);
rbmCl.serialize(output, getConf());
TestRBMClassifierJob tester = new TestRBMClassifierJob();
tester.run(new String[]{"-m",output.toUri().getPath(),
"-labelcount","10",
"-i",testData.toUri().getPath()});
return 0;
}
}