/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.rbm.training; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.classifier.rbm.RBMClassifier; import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM; import org.apache.mahout.classifier.rbm.model.RBMModel; import org.apache.mahout.classifier.rbm.model.SimpleRBM; import org.apache.mahout.classifier.rbm.network.DeepBoltzmannMachine; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.common.HadoopUtil; import org.apache.mahout.common.Pair; import org.apache.mahout.common.commandline.DefaultOptionCreator; import org.apache.mahout.common.iterator.sequencefile.PathFilters; import org.apache.mahout.common.iterator.sequencefile.PathType; import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; import org.apache.mahout.math.DenseMatrix; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.MatrixWritable; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * The Class RBMClassifierTrainingJob. */ public class RBMClassifierTrainingJob extends AbstractJob{ /** The Constant WEIGHT_UPDATES. */ public static final String WEIGHT_UPDATES = "weightupdates"; /** The Constant logger. */ private static final Logger logger = LoggerFactory.getLogger(RBMClassifierTrainingJob.class); /** The last update which is needed for use of the momentum. */ Matrix[] lastUpdate; /** The rbm classifier. */ RBMClassifier rbmCl=null; /** The number of iterations (epochs). */ int epochs; /** The learningrate. */ double learningrate; /** The momentum used. */ double momentum; /** monitor if true. */ boolean monitor; /** initial biases if true. */ boolean initbiases; /** train greedy if true. */ boolean greedy; /** finetune if true. */ boolean finetuning; /** The batches to train on. */ Path[] batches = null; /** The labelcount. */ int labelcount; /** The nr gibbs sampling. */ int nrGibbsSampling; /** The rbm nr to train. */ int rbmNrtoTrain; /** * The main method. * * @param args the arguments * @throws Exception the exception */ public static void main(String[] args) throws Exception { ToolRunner.run(new Configuration(), new RBMClassifierTrainingJob(), args); } /* (non-Javadoc) * @see org.apache.hadoop.util.Tool#run(java.lang.String[]) */ @Override public int run(String[] args) throws Exception { addInputOption(); addOutputOption(); addOption("epochs","e","number of training epochs through the trainingset",true); addOption("structure", "s", "comma-separated list of layer sizes", false); addOption("labelcount", "lc", "total count of labels existent in the training set", true); addOption("learningrate", "lr", "learning rate at the beginning of training", "0.005"); addOption("momentum", "m", "momentum of learning at the beginning", "0.5"); addOption("rbmnr", "nr", "rbm to train, < 0 means train all", "-1"); addOption("nrgibbs", "gn", "number of gibbs sampling used in contrastive divergence", "5"); addOption(new DefaultOptionBuilder() .withLongName(DefaultOptionCreator.MAPREDUCE_METHOD) .withRequired(false) .withDescription("Run training with map/reduce") .withShortName("mr").create()); addOption(new DefaultOptionBuilder() .withLongName("nogreedy") .withRequired(false) .withDescription("Don't run greedy pre training") .withShortName("ng").create()); addOption(new DefaultOptionBuilder() .withLongName("nofinetuning") .withRequired(false) .withDescription("Don't run fine tuning at the end") .withShortName("nf").create()); addOption(new DefaultOptionBuilder() .withLongName("nobiases") .withRequired(false) .withDescription("Don't initialize biases") .withShortName("nb").create()); addOption(new DefaultOptionBuilder() .withLongName("monitor") .withRequired(false) .withDescription("If present, errors can be monitored in cosole") .withShortName("mon").create()); addOption(DefaultOptionCreator.overwriteOption().create()); Map<String, String> parsedArgs = parseArguments(args); if (parsedArgs == null) { return -1; } Path input = getInputPath(); Path output = getOutputPath(); FileSystem fs = FileSystem.get(output.toUri(),getConf()); labelcount = Integer.parseInt(getOption("labelcount")); boolean local = !hasOption("mapreduce"); monitor = hasOption("monitor"); initbiases = !hasOption("nobiases"); finetuning = !hasOption("nofinetuning"); greedy = !hasOption("nogreedy"); if(fs.isFile(input)) batches = new Path[]{input}; else { FileStatus[] stati = fs.listStatus(input); batches = new Path[stati.length]; for (int i = 0; i < stati.length; i++) { batches[i] = stati[i].getPath(); } } epochs = Integer.valueOf(getOption("epochs")); learningrate = Double.parseDouble(getOption("learningrate")); momentum = Double.parseDouble(getOption("momentum")); rbmNrtoTrain = Integer.parseInt(getOption("rbmnr")); nrGibbsSampling = Integer.parseInt(getOption("nrgibbs")); boolean initialize = hasOption(DefaultOptionCreator.OVERWRITE_OPTION)||!fs.exists(output)||fs.listStatus(output).length<=0; if (initialize) { String structure = getOption("structure"); if(structure==null||structure.isEmpty()) return -1; String[] layers = structure.split(","); if (layers.length<2) { return -1; } int[] actualLayerSizes = new int[layers.length]; for (int i = 0; i < layers.length; i++) { actualLayerSizes[i] = Integer.parseInt(layers[i]); } rbmCl = new RBMClassifier(labelcount, actualLayerSizes); logger.info("New model initialized!"); } else { rbmCl = RBMClassifier.materialize(output, getConf()); logger.info("Model found and materialized!"); } HadoopUtil.setSerializations(getConf()); lastUpdate = new Matrix[rbmCl.getDbm().getRbmCount()]; if(initbiases) { //init biases! Vector biases = null; int counter = 0; for(Path batch : batches) { for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { if(biases==null) biases = record.getSecond().get().clone(); else biases.plus(record.getSecond().get()); counter++; } } if(biases==null) { logger.info("No training data found!"); return -1; } rbmCl.getDbm().getLayer(0).setBiases(biases.divide(counter)); logger.info("Biases initialized"); } //greedy pre training with gradually decreasing learningrates if(greedy) { if(!local) rbmCl.serialize(output, getConf()); double tempLearningrate = learningrate; if(rbmNrtoTrain<0) //train all rbms for(int rbmNr=0; rbmNr<rbmCl.getDbm().getRbmCount(); rbmNr++) { tempLearningrate = learningrate; //double weights if dbm was materialized, because it was halved after greedy pretraining if(!initialize&&rbmNrtoTrain>0&&rbmNrtoTrain<rbmCl.getDbm().getRbmCount()-1) { ((SimpleRBM)rbmCl.getDbm().getRBM(rbmNr)).setWeightMatrix( ((SimpleRBM)rbmCl.getDbm().getRBM(rbmNr)).getWeightMatrix().times(2)); } for (int j = 0; j < epochs; j++) { logger.info("Greedy training, epoch "+(j+1)+"\nCurrent learningrate: "+tempLearningrate); for(int b=0; b<batches.length;b++) { tempLearningrate -= learningrate/(epochs*batches.length+epochs); if(local) { if(!trainGreedySeq(rbmNr, batches[b], j, tempLearningrate)) return -1; } else if(!trainGreedyMR(rbmNr, batches[b], j, tempLearningrate)) return -1; if(monitor&&(batches.length>19)&&(b+1)%(batches.length/20)==0) logger.info(rbmNr+"-RBM: "+Math.round(((double)b+1)/batches.length*100.0)+"% in epoch done!"); } logger.info(Math.round(((double)j+1)/epochs*100)+"% of training on rbm number "+rbmNr+" is done!"); if(monitor) { double error = rbmError(batches[0], rbmNr); logger.info("Average reconstruction error on batch "+batches[0].getName()+": "+error); } rbmCl.serialize(output, getConf()); } //weight normalization to avoid double counting if(rbmNr>0&&rbmNr<rbmCl.getDbm().getRbmCount()-1) { ((SimpleRBM)rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix( ((SimpleRBM)rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(0.5)); } } else { //double weights if dbm was materialized, because it was halved after greedy pretraining if(!initialize&&rbmNrtoTrain>0&&rbmNrtoTrain<rbmCl.getDbm().getRbmCount()-1) { ((SimpleRBM)rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix( ((SimpleRBM)rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(2)); } //train just wanted rbm for (int j = 0; j < epochs; j++) { logger.info("Greedy training, epoch "+(j+1)+"\nCurrent learningrate: "+tempLearningrate); for(int b=0; b<batches.length;b++) { tempLearningrate -= learningrate/(epochs*batches.length+epochs); if(local) { if(!trainGreedySeq(rbmNrtoTrain, batches[b], j,tempLearningrate)) return -1; } else if(!trainGreedyMR(rbmNrtoTrain, batches[b], j,tempLearningrate)) return -1; if(monitor&&(batches.length>19)&&(b+1)%(batches.length/20)==0) logger.info(rbmNrtoTrain+"-RBM: "+Math.round(((double)b+1)/batches.length*100.0)+"% in epoch done!"); } logger.info(Math.round(((double)j+1)/epochs*100)+"% of training is done!"); if(monitor) { double error = rbmError(batches[0], rbmNrtoTrain); logger.info("Average reconstruction error on batch "+batches[0].getName()+": "+error); } } //weight normalization to avoid double counting if(rbmNrtoTrain>0&&rbmNrtoTrain<rbmCl.getDbm().getRbmCount()-1) { ((SimpleRBM)rbmCl.getDbm().getRBM(rbmNrtoTrain)).setWeightMatrix( ((SimpleRBM)rbmCl.getDbm().getRBM(rbmNrtoTrain)).getWeightMatrix().times(0.5)); } } rbmCl.serialize(output, getConf()); logger.info("Pretraining done and model written to output"); } if(finetuning) { DeepBoltzmannMachine multiLayerDbm = null; double tempLearningrate = learningrate; //finetuning job for (int j = 0; j < epochs; j++) { for(int b=0; b<batches.length;b++) { multiLayerDbm = rbmCl.initializeMultiLayerNN(); logger.info("Finetuning on batch "+batches[b].getName()+"\nCurrent learningrate: "+tempLearningrate); tempLearningrate -= learningrate/(epochs*batches.length+epochs); if(local) { if(!finetuneSeq(batches[b], j, multiLayerDbm, tempLearningrate)) return -1; } else if(!fintuneMR(batches[b], j, tempLearningrate)) return -1; logger.info("Finetuning: "+Math.round(((double)b+1)/batches.length*100.0)+"% in epoch done!"); } logger.info(Math.round(((double)j+1)/epochs*100)+"% of training is done!"); if(monitor) { double error = feedForwardError(multiLayerDbm, batches[0]); logger.info("Average discriminative error on batch "+batches[0].getName()+": "+error); } } //final serialization rbmCl.serialize(output, getConf()); logger.info("RBM finetuning done and model written to output"); } if(executor!=null) executor.shutdownNow(); return 0; } /** * The Class BackpropTrainingThread is the callable thread for the local backprop task. */ class BackpropTrainingThread implements Callable<Matrix[]> { /** The dbm. */ private DeepBoltzmannMachine dbm; /** The input. */ private Vector input; /** The label. */ private Vector label; /** The trainer. */ private BackPropTrainer trainer; /** * Instantiates a new backprop training thread. * * @param dbm the dbm * @param label the label * @param input the input * @param trainer the trainer */ public BackpropTrainingThread(DeepBoltzmannMachine dbm, Vector label, Vector input, BackPropTrainer trainer) { this.dbm = dbm; this.label = label; this.input = input; this.trainer = trainer; } /* (non-Javadoc) * @see java.util.concurrent.Callable#call() */ @Override public Matrix[] call() throws Exception { Matrix[] result = trainer.calculateWeightUpdates(dbm, input, label); Matrix[] weightUpdates =new Matrix[dbm.getRbmCount()-1]; //write for each RBM i (key, number of rbm) the result and put together the last two //matrices since they refer to just one labeled rbm, which was split to two for the training for (int i = 0; i < result.length-1; i++) { if(i==result.length-2) { weightUpdates[i] = new DenseMatrix(result[i].rowSize()+result[i+1].columnSize(), result[i].columnSize()); for(int j = 0; j<weightUpdates[i].rowSize(); j++) for(int k = 0; k<weightUpdates[i].columnSize(); k++) { if(j<result[i].rowSize()) weightUpdates[i].set(j, k, result[i].get(j, k)); else weightUpdates[i].set(j, k, result[i+1].get(k, j-result[i].rowSize())); } } else weightUpdates[i]= result[i]; } return weightUpdates; } } /** The backprop training tasks. */ List<BackpropTrainingThread> backpropTrainingTasks; /** * Finetune locally. * * @param batch the batch * @param iteration the iteration * @param multiLayerDbm the multilayer dbm * @param learningrate the learningrate * @return true, if successful * @throws InterruptedException the interrupted exception * @throws ExecutionException the execution exception */ private boolean finetuneSeq(Path batch, int iteration, DeepBoltzmannMachine multiLayerDbm, double learningrate) throws InterruptedException, ExecutionException { Vector label = new DenseVector(labelcount); Map<Integer, Matrix> updates = new HashMap<Integer, Matrix>(); int batchsize = 0; //maximum number of threads that are used, I think 20 is ok int threadCount = 20; Matrix[] weightUpdates; //initialize the tasks, which are run by the executor if(backpropTrainingTasks==null) backpropTrainingTasks = new ArrayList<BackpropTrainingThread>(); //initialize the executor if not already done if(executor==null) executor = Executors.newFixedThreadPool(threadCount); for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { for (int i = 0; i < label.size(); i++) label.setQuick(i, 0); label.set(record.getFirst().get(), 1); BackPropTrainer trainer = new BackPropTrainer(learningrate); //prepare the tasks if(backpropTrainingTasks.size()<threadCount) backpropTrainingTasks.add(new BackpropTrainingThread(multiLayerDbm.clone(), label.clone(), record.getSecond().get(), trainer)); else { backpropTrainingTasks.get(batchsize%threadCount).input = record.getSecond().get(); backpropTrainingTasks.get(batchsize%threadCount).label = label.clone(); if(batchsize<threadCount){ backpropTrainingTasks.get(batchsize%threadCount).dbm = multiLayerDbm.clone(); } } //run the tasks and save results if(batchsize%threadCount==threadCount-1) { List<Future<Matrix[]>> futureUpdates = executor.invokeAll(backpropTrainingTasks); for (int i = 0; i < futureUpdates.size(); i++) { weightUpdates = futureUpdates.get(i).get(); for (int j = 0; j < weightUpdates.length; j++) { if(updates.containsKey(j)) updates.put(j, weightUpdates[j].plus(updates.get(j))); else updates.put(j, weightUpdates[j]); } } } batchsize++; } //run remaining tasks if(batchsize%20!=0) { List<Future<Matrix[]>> futureUpdates = executor.invokeAll(backpropTrainingTasks.subList(0, (batchsize-1) %20)); for (int i = 0; i < futureUpdates.size(); i++) { weightUpdates = futureUpdates.get(i).get(); for (int j = 0; j < weightUpdates.length; j++) { if(updates.containsKey(j)) updates.put(j, weightUpdates[j].plus(updates.get(j))); else updates.put(j, weightUpdates[j]); } } } updateRbmCl(batchsize, (iteration==0)?0:momentum, updates); return true; } /** * Fintune using map/reduce. * * @param batch the batch * @param iteration the iteration * @param learningrate the learningrate * @return true, if successful * @throws IOException Signals that an I/O exception has occurred. * @throws InterruptedException the interrupted exception * @throws ClassNotFoundException the class not found exception */ private boolean fintuneMR(Path batch, int iteration, double learningrate) throws IOException, InterruptedException, ClassNotFoundException { //prepare and run finetune job long batchsize; HadoopUtil.delete(getConf(), getTempPath(WEIGHT_UPDATES)); HadoopUtil.cacheFiles(getOutputPath(), getConf()); Job trainDBM = prepareJob(batch, getTempPath(WEIGHT_UPDATES), SequenceFileInputFormat.class, DBMBackPropTrainingMapper.class, IntWritable.class, MatrixWritable.class, DBMBackPropTrainingReducer.class, IntWritable.class, MatrixWritable.class, SequenceFileOutputFormat.class); trainDBM.getConfiguration().set("labelcount", String.valueOf(labelcount)); trainDBM.getConfiguration().set("learningrate", String.valueOf(learningrate)); trainDBM.setCombinerClass(DBMBackPropTrainingReducer.class); if(!trainDBM.waitForCompletion(true)) return false; batchsize = trainDBM.getCounters().findCounter(DBMBackPropTrainingMapper.BATCHES.SIZE).getValue(); changeAndSaveModel(getOutputPath(), batchsize, (iteration==0)?0:momentum); return true; } /** * The Class GreedyTrainingThread. */ class GreedyTrainingThread implements Callable<Matrix> { /** The dbm. */ private DeepBoltzmannMachine dbm; /** The input. */ private Vector input; /** The label. */ private Vector label; /** The trainer. */ private CDTrainer trainer; /** The rbm nr to train. */ int rbmNr; /** * Instantiates a new greedy training thread. * * @param dbm the dbm * @param label the label * @param input the input * @param trainer the trainer * @param rbmNr the rbm nr */ public GreedyTrainingThread(DeepBoltzmannMachine dbm, Vector label, Vector input, CDTrainer trainer, int rbmNr) { this.dbm = dbm; this.label = label; this.input = input; this.trainer = trainer; this.rbmNr = rbmNr; } /* (non-Javadoc) * @see java.util.concurrent.Callable#call() */ @Override public Matrix call() throws Exception { Matrix updates = null; dbm.getRBM(0).getVisibleLayer().setActivations(input); for(int i = 0; i<rbmNr; i++){ //double the bottom up connection for initialization dbm.getRBM(i).exciteHiddenLayer(2, false); if(i==rbmNr-1) //probabilities as activation for the data the rbm should train on dbm.getRBM(i).getHiddenLayer().setProbabilitiesAsActivation(); else dbm.getRBM(i).getHiddenLayer().updateNeurons(); } if(rbmNr==dbm.getRbmCount()-1) { ((LabeledSimpleRBM)dbm.getRBM(rbmNr)).getSoftmaxLayer().setActivations(label); updates = trainer.calculateWeightUpdates((LabeledSimpleRBM)dbm.getRBM(rbmNr), true, false); } else { updates = trainer.calculateWeightUpdates((SimpleRBM)dbm.getRBM(rbmNr), false, rbmNr==0); } return updates; } } /** The executor. */ private ExecutorService executor; /** The greedy training tasks. */ List<GreedyTrainingThread> greedyTrainingTasks; /** * Train greedy seq. * * @param rbmNr the rbm nr * @param batch the batch * @param iteration the iteration * @param learningrate the learningrate * @return true, if successful * @throws InterruptedException the interrupted exception * @throws ExecutionException the execution exception */ private boolean trainGreedySeq(int rbmNr, Path batch, int iteration, double learningrate) throws InterruptedException, ExecutionException { int batchsize = 0; DeepBoltzmannMachine dbm = rbmCl.getDbm(); Vector label = new DenseVector(labelcount); Matrix updates = null; //number of threads running the tasks int threadCount =20; if(executor==null) executor = Executors.newFixedThreadPool(threadCount); if(greedyTrainingTasks==null) greedyTrainingTasks = new ArrayList<RBMClassifierTrainingJob.GreedyTrainingThread>(); for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { CDTrainer trainer = new CDTrainer(learningrate, nrGibbsSampling); label.assign(0); label.set(record.getFirst().get(), 1); //prepare the tasks if(greedyTrainingTasks.size()<threadCount) greedyTrainingTasks.add(new GreedyTrainingThread(dbm.clone(), label.clone(), record.getSecond().get(), trainer, rbmNr)); else { greedyTrainingTasks.get(batchsize%threadCount).input = record.getSecond().get(); greedyTrainingTasks.get(batchsize%threadCount).label = label.clone(); if(batchsize<threadCount){ greedyTrainingTasks.get(batchsize%threadCount).dbm = dbm.clone(); greedyTrainingTasks.get(batchsize%threadCount).rbmNr = rbmNr; } } //run tasks if(batchsize%threadCount==threadCount-1) { List<Future<Matrix>> futureUpdates = executor.invokeAll(greedyTrainingTasks); for (int i = 0; i < futureUpdates.size(); i++) { if(updates==null) updates = futureUpdates.get(i).get(); else updates = updates.plus(futureUpdates.get(i).get()); } } batchsize++; } //run remaining tasks if(batchsize%20!=0) { List<Future<Matrix>> futureUpdates = executor.invokeAll(greedyTrainingTasks.subList(0, (batchsize-1) %20)); for (int i = 0; i < futureUpdates.size(); i++) { if(updates==null) updates = futureUpdates.get(i).get(); else updates = updates.plus(futureUpdates.get(i).get()); } } Map<Integer,Matrix> updateMap = new HashMap<Integer,Matrix>(); updateMap.put(rbmNr, updates); updateRbmCl(batchsize, (lastUpdate[rbmNr]==null)?0:momentum, updateMap); return true; } /** * Train greedy mr. * * @param rbmNr the rbm nr * @param batch the batch * @param iteration the iteration * @param learningrate the learningrate * @return true, if successful * @throws IOException Signals that an I/O exception has occurred. * @throws InterruptedException the interrupted exception * @throws ClassNotFoundException the class not found exception */ private boolean trainGreedyMR(int rbmNr, Path batch, int iteration, double learningrate) throws IOException, InterruptedException, ClassNotFoundException { //run greedy pretraining as map reduce job long batchsize; HadoopUtil.delete(getConf(), getTempPath(WEIGHT_UPDATES)); HadoopUtil.cacheFiles(getOutputPath(), getConf()); Job trainRBM = prepareJob(batch, getTempPath(WEIGHT_UPDATES), SequenceFileInputFormat.class, RBMGreedyPreTrainingMapper.class, IntWritable.class, MatrixWritable.class, RBMGreedyPreTrainingReducer.class, IntWritable.class, MatrixWritable.class, SequenceFileOutputFormat.class); trainRBM.getConfiguration().set("rbmNr", String.valueOf(rbmNr)); trainRBM.getConfiguration().set("labelcount", String.valueOf(labelcount)); trainRBM.getConfiguration().set("learningrate", String.valueOf(learningrate)); trainRBM.getConfiguration().set("nrGibbsSampling", String.valueOf(nrGibbsSampling)); trainRBM.setCombinerClass(RBMGreedyPreTrainingReducer.class); if(!trainRBM.waitForCompletion(true)) return false; batchsize = trainRBM.getCounters().findCounter(RBMGreedyPreTrainingMapper.BATCH.SIZE).getValue(); changeAndSaveModel(getOutputPath(), batchsize, (lastUpdate[rbmNr]==null)?0:momentum); return true; } /** * calculate classifiers error after 1 iteration of sampling. * * @param batch the batch * @return the error */ @SuppressWarnings("unused") private double classifierError(Path batch) { double error = 0; int counter = 0; Vector scores; for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { scores = rbmCl.classify(record.getSecond().get(),1); error += 1-scores.get(record.getFirst().get()); counter++; } error /= counter; return error; } /** * Calculates error of fann. * * @param feedForwardNet the feed forward net * @param batch the batch * @return the error */ private double feedForwardError(DeepBoltzmannMachine feedForwardNet, Path batch) { double error = 0; int counter = 0; RBMModel currentRBM =null; for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { feedForwardNet.getRBM(0).getVisibleLayer().setActivations(record.getSecond().get()); for(int i = 0; i< feedForwardNet.getRbmCount(); i++) { currentRBM = feedForwardNet.getRBM(i); currentRBM.exciteHiddenLayer(1, false); currentRBM.getHiddenLayer().setProbabilitiesAsActivation(); } error+= 1-currentRBM.getHiddenLayer().getActivations().get(record.getFirst().get()); counter++; } error /= counter; return error; } /** * Rbms reconstruction error. * * @param batch the batch * @param rbmNr the rbm nr * @return the error */ private double rbmError(Path batch, int rbmNr) { DeepBoltzmannMachine dbm = rbmCl.getDbm(); Vector label = new DenseVector(((LabeledSimpleRBM)dbm.getRBM(dbm.getRbmCount()-1)).getSoftmaxLayer().getNeuronCount()); double error = 0; int counter = 0; for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(batch, getConf())) { dbm.getRBM(0).getVisibleLayer().setActivations(record.getSecond().get()); for(int i = 0; i<rbmNr; i++){ //double the bottom up connection for initialization dbm.getRBM(i).exciteHiddenLayer(2, false); if(i==rbmNr-1) dbm.getRBM(i).getHiddenLayer().setProbabilitiesAsActivation(); else dbm.getRBM(i).getHiddenLayer().updateNeurons(); } if(dbm.getRBM(rbmNr) instanceof LabeledSimpleRBM) { label.assign(0); label.set(record.getFirst().get(), 1); ((LabeledSimpleRBM)dbm.getRBM(rbmNr)).getSoftmaxLayer().setActivations(label); } error += dbm.getRBM(rbmNr).getReconstructionError(); counter++; } error/=counter; return error; } /** * Change and save model. * * @param output the output * @param batchsize the batchsize * @param momentum the momentum * @throws IOException Signals that an I/O exception has occurred. */ private void changeAndSaveModel(Path output, long batchsize, double momentum) throws IOException { Map<Integer,Matrix> updates = new HashMap<Integer,Matrix>(); for (Pair<IntWritable, MatrixWritable> record : new SequenceFileDirIterable<IntWritable, MatrixWritable>( getTempPath(WEIGHT_UPDATES), PathType.LIST, PathFilters.partFilter(), getConf())) { if(!updates.containsKey(record.getFirst().get())) updates.put(record.getFirst().get(), record.getSecond().get()); else updates.put(record.getFirst().get(), record.getSecond().get().plus(updates.get(record.getFirst().get()))); } updateRbmCl(batchsize, momentum, updates); //serialization for mappers to have actual version of the dbm rbmCl.serialize(output, getConf()); } /** * Update rbm classifier with given updates. * * @param batchsize the batchsize * @param momentum the momentum * @param updates the updates */ private void updateRbmCl(long batchsize, double momentum, Map<Integer, Matrix> updates) { for(Integer rbmNr : updates.keySet()) { if(momentum>0) updates.put(rbmNr, (updates.get(rbmNr).divide(batchsize).times(1-momentum)). plus(lastUpdate[rbmNr].times(momentum)) ); else updates.put(rbmNr,updates.get(rbmNr).divide(batchsize)); if(rbmNr<rbmCl.getDbm().getRbmCount()-1) { SimpleRBM simpleRBM = (SimpleRBM)rbmCl.getDbm().getRBM(rbmNr); simpleRBM.setWeightMatrix( simpleRBM.getWeightMatrix().plus(updates.get(rbmNr))); }else { LabeledSimpleRBM lrbm = (LabeledSimpleRBM)rbmCl.getDbm().getRBM(rbmNr); int rowSize = lrbm.getWeightMatrix().rowSize(); Matrix weightUpdates = updates.get(rbmNr).viewPart(0, rowSize, 0, updates.get(rbmNr).columnSize()); Matrix weightLabelUpdates = updates.get(rbmNr).viewPart(rowSize, updates.get(rbmNr).rowSize()-rowSize, 0, updates.get(rbmNr).columnSize()); lrbm.setWeightMatrix(lrbm.getWeightMatrix().plus(weightUpdates)); lrbm.setWeightLabelMatrix(lrbm.getWeightLabelMatrix().plus(weightLabelUpdates)); } lastUpdate[rbmNr] = updates.get(rbmNr); } } }