/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.classifier.rbm.training;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.rbm.RBMClassifier;
import org.apache.mahout.classifier.rbm.network.DeepBoltzmannMachine;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.VectorWritable;

/**
 * Mapper for the backpropagation (fine-tuning) phase of DBM training. For each labeled
 * input vector it computes the weight updates of the unrolled multi-layer network and
 * emits one update matrix per RBM, keyed by that RBM's index.
 */
public class DBMBackPropTrainingMapper extends Mapper<IntWritable, VectorWritable, IntWritable, MatrixWritable> {

  /** Counter group used to report how many records this batch contained. */
  enum BATCHES {
    /** Number of processed records in the batch. */
    SIZE
  }

  /** The deep Boltzmann machine, unrolled into a multi-layer feed-forward network. */
  DeepBoltzmannMachine dbm;

  /** The learning rate used by the backpropagation trainer. */
  double learningrate;

  /** Reusable one-hot label vector; the component at the record's key is set to 1. */
  private DenseVector label;

  @Override
  protected void setup(Context context) throws java.io.IOException, InterruptedException {
    Configuration conf = context.getConfiguration();
    // The serialized RBMClassifier is shipped to the mappers via the distributed cache.
    Path p = HadoopUtil.cachedFile(conf);
    dbm = RBMClassifier.materialize(p, conf).initializeMultiLayerNN();
    learningrate = Double.parseDouble(conf.get("learningrate"));
    int labelcount = Integer.parseInt(conf.get("labelcount"));
    label = new DenseVector(labelcount);
  }

  @Override
  protected void map(IntWritable key, VectorWritable value, Context context)
      throws java.io.IOException, InterruptedException {
    // Build the one-hot label vector for this record.
    for (int i = 0; i < label.size(); i++) {
      label.setQuick(i, 0);
    }
    label.set(key.get(), 1);

    BackPropTrainer trainer = new BackPropTrainer(learningrate);
    Matrix[] result = trainer.calculateWeightUpdates(dbm, value.get(), label);

    context.getCounter(BATCHES.SIZE).increment(1);

    // Write one update matrix per RBM i (key = number of the RBM). The last two matrices
    // are stacked into a single matrix because they refer to just one labeled RBM, which
    // was split into two weight matrices for the training.
    for (int i = 0; i < result.length - 1; i++) {
      if (i == result.length - 2) {
        Matrix updates =
            new DenseMatrix(result[i].rowSize() + result[i + 1].columnSize(), result[i].columnSize());
        for (int j = 0; j < updates.rowSize(); j++) {
          for (int k = 0; k < updates.columnSize(); k++) {
            if (j < result[i].rowSize()) {
              updates.set(j, k, result[i].get(j, k));
            } else {
              // Rows below the first block hold the second matrix, transposed.
              updates.set(j, k, result[i + 1].get(k, j - result[i].rowSize()));
            }
          }
        }
        context.write(new IntWritable(i), new MatrixWritable(updates));
      } else {
        context.write(new IntWritable(i), new MatrixWritable(result[i]));
      }
    }
  }
}