/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.rbm.training;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM;
import org.apache.mahout.classifier.rbm.model.SimpleRBM;
import org.apache.mahout.classifier.rbm.network.DeepBoltzmannMachine;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
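
/*
 * A minimal driver-side sketch of how this mapper could be wired up, assuming only the
 * configuration keys read in setup() below and that the serialized DeepBoltzmannMachine has been
 * placed in the distributed cache (HadoopUtil.cachedFile() resolves it from there). The key names
 * are taken from setup(); the job name and values are illustrative, not the actual training job:
 *
 *   Configuration conf = new Configuration();
 *   conf.set("learningrate", "0.005");     // contrastive-divergence learning rate
 *   conf.set("rbmNr", "0");                // index of the RBM to pre-train in this pass
 *   conf.set("nrGibbsSampling", "1");      // Gibbs steps per CD update
 *   conf.set("labelcount", "10");          // size of the label/softmax layer
 *
 *   Job job = new Job(conf, "rbm-greedy-pretraining");
 *   job.setMapperClass(RBMGreedyPreTrainingMapper.class);
 *   job.setMapOutputKeyClass(IntWritable.class);
 *   job.setMapOutputValueClass(MatrixWritable.class);
 */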

/**
 * Mapper for the greedy, layer-wise pre-training of a {@link DeepBoltzmannMachine}. For each
 * input record the data vector is clamped to the visible layer, propagated bottom-up to the RBM
 * currently being trained (selected by the "rbmNr" configuration entry), and the
 * contrastive-divergence weight updates for that RBM are emitted, keyed by the RBM's index.
 */
public class RBMGreedyPreTrainingMapper extends Mapper<IntWritable, VectorWritable, IntWritable, MatrixWritable> {

  /**
   * Counter used to report how many input records this mapper has processed (the batch size).
   */
  static enum BATCH {
    /** The size of the batch, i.e. the number of processed input vectors. */
    SIZE
  }

  /** The deep Boltzmann machine whose RBMs are pre-trained layer by layer. */
  DeepBoltzmannMachine dbm;

  /** The learning rate passed to the contrastive-divergence trainer. */
  double learningRate;

  /** Reusable one-hot vector holding the label of the current record. */
  private Vector label;

  /** Index of the RBM (layer) that is being trained in this pass. */
  private int nr;

  /** Number of Gibbs sampling steps per contrastive-divergence update. */
  private int nrGibbsSampling;

  /* (non-Javadoc)
   * @see org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper.Context)
   */
  @Override
  protected void setup(Context context) throws java.io.IOException, InterruptedException {
    Configuration conf = context.getConfiguration();

    // materialize the DBM that the driver put into the distributed cache
    Path p = HadoopUtil.cachedFile(conf);
    dbm = DeepBoltzmannMachine.materialize(p, conf);

    // training parameters passed in through the job configuration
    learningRate = Double.parseDouble(conf.get("learningrate"));
    nr = Integer.parseInt(conf.get("rbmNr"));
    nrGibbsSampling = Integer.parseInt(conf.get("nrGibbsSampling"));

    // sparse one-hot label vector with one slot per class
    int labelCount = Integer.parseInt(conf.get("labelcount"));
    label = new RandomAccessSparseVector(labelCount);
  }
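
  /*
   * The matrices emitted by map() are assumed to be standard CD-k weight updates, i.e. roughly
   *
   *   deltaW = learningRate * (<v h>_data - <v h>_reconstruction)
   *
   * with the expectations estimated from nrGibbsSampling Gibbs steps; see CDTrainer for the
   * authoritative computation.
   */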

  /* (non-Javadoc)
   * @see org.apache.hadoop.mapreduce.Mapper#map(KEYIN, VALUEIN, org.apache.hadoop.mapreduce.Mapper.Context)
   */
  @Override
  protected void map(IntWritable key, VectorWritable value, Context context)
      throws java.io.IOException, InterruptedException {
    CDTrainer trainer = new CDTrainer(learningRate, nrGibbsSampling);

    // the key is the label index; clear the reusable vector so it stays one-hot across records
    label.assign(0);
    label.set(key.get(), 1);

    // clamp the input vector to the visible layer of the bottom RBM
    dbm.getRBM(0).getVisibleLayer().setActivations(value.get());

    // propagate the data bottom-up to the RBM that is being trained
    for (int i = 0; i < nr; i++) {
      // double the bottom-up connections to make up for the missing top-down input
      // during greedy layer-wise pre-training
      dbm.getRBM(i).exciteHiddenLayer(2, false);
      if (i == nr - 1) {
        // use the probabilities as activations for the data the next RBM trains on
        dbm.getRBM(i).getHiddenLayer().setProbabilitiesAsActivation();
      } else {
        dbm.getRBM(i).getHiddenLayer().updateNeurons();
      }
    }

    context.getCounter(BATCH.SIZE).increment(1);

    if (nr == dbm.getRbmCount() - 1) {
      // the top RBM also sees the label through its softmax layer
      ((LabeledSimpleRBM) dbm.getRBM(nr)).getSoftmaxLayer().setActivations(label);
      Matrix updates = trainer.calculateWeightUpdates((LabeledSimpleRBM) dbm.getRBM(nr), true, false);
      context.write(new IntWritable(nr), new MatrixWritable(updates));
    } else {
      Matrix updates = trainer.calculateWeightUpdates((SimpleRBM) dbm.getRBM(nr), false, nr == 0);
      context.write(new IntWritable(nr), new MatrixWritable(updates));
    }
  }
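
  /*
   * Downstream handling is an assumption, not something this class enforces: each emitted
   * (rbmNr, MatrixWritable) pair is expected to be combined per RBM index, e.g. summed and
   * divided by the BATCH.SIZE counter to obtain the average update for the batch, before being
   * applied to the model.
   */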
}