/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.rbm;
import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.rbm.layer.LogisticLayer;
import org.apache.mahout.classifier.rbm.layer.SoftmaxLayer;
import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM;
import org.apache.mahout.classifier.rbm.model.RBMModel;
import org.apache.mahout.classifier.rbm.model.SimpleRBM;
import org.apache.mahout.classifier.rbm.network.DBMStateIterator;
import org.apache.mahout.classifier.rbm.network.DeepBoltzmannMachine;
import org.apache.mahout.math.Vector;
/**
 * RBMClassifier is an implementation of Mahout's {@link AbstractVectorClassifier} backed by a
 * {@link DeepBoltzmannMachine}, based on the paper
 * http://www.cs.toronto.edu/~hinton/absps/dbm.pdf.
 *
 * <p>The top RBM of the machine is a {@link LabeledSimpleRBM}; the excitations of its softmax
 * layer are used as the classification scores, one entry per category.
 */
public class RBMClassifier extends AbstractVectorClassifier implements Cloneable {

  /** Number of consecutive stable Gibbs samplings used by {@link #classify(Vector)}. */
  private static final int DEFAULT_STABLE_STATES_COUNT = 5;

  /** The dbm which is the actual classifier model. */
  private DeepBoltzmannMachine dbm;

  /**
   * Creates a classifier without a model. Used by {@link #materialize(Path, Configuration)} and
   * {@link #clone()}, which assign {@link #dbm} themselves.
   */
  private RBMClassifier() {
  }

  /**
   * Instantiates a new RBM classifier.
   *
   * <p>With exactly two layer sizes a single labeled RBM is built. With more, the bottom RBMs are
   * plain {@link SimpleRBM}s and only the topmost one is a {@link LabeledSimpleRBM} carrying a
   * softmax layer of {@code numCategories} units. If fewer than two sizes are given, no model is
   * created and the classifier is left without a DBM (kept for backward compatibility).
   *
   * @param numCategories the number of categories (size of the softmax label layer)
   * @param layers the sizes of the layers used to initialize the DBM, bottom (input) first
   */
  public RBMClassifier(int numCategories, int[] layers) {
    if (layers.length < 2) {
      return;
    }
    RBMModel bottomRbm;
    if (layers.length > 2) {
      bottomRbm = new SimpleRBM(new LogisticLayer(layers[0]), new LogisticLayer(layers[1]));
    } else {
      // The only RBM is also the top one, so it has to carry the label layer.
      bottomRbm = new LabeledSimpleRBM(new LogisticLayer(layers[0]), new LogisticLayer(layers[1]),
          new SoftmaxLayer(numCategories));
    }
    dbm = new DeepBoltzmannMachine(bottomRbm);
    for (int i = 1; i < layers.length - 1; i++) {
      if (i < layers.length - 2) {
        dbm.stackRBM(new SimpleRBM(dbm.getLayer(i), new LogisticLayer(layers[i + 1])));
      } else {
        // Topmost RBM: attach the softmax label layer.
        dbm.stackRBM(new LabeledSimpleRBM(dbm.getLayer(i), new LogisticLayer(layers[i + 1]),
            new SoftmaxLayer(numCategories)));
      }
    }
  }

  /**
   * Returns the top RBM of the machine, which carries the softmax (label) layer.
   *
   * @return the topmost RBM, cast to {@link LabeledSimpleRBM}
   */
  private LabeledSimpleRBM labeledTopRbm() {
    return (LabeledSimpleRBM) dbm.getRBM(dbm.getRbmCount() - 1);
  }

  /**
   * Gets the dbm.
   *
   * @return the underlying deep Boltzmann machine (not a copy)
   */
  public DeepBoltzmannMachine getDbm() {
    return dbm;
  }

  /**
   * Returns the number of categories, i.e. the neuron count of the top softmax layer.
   *
   * @see org.apache.mahout.classifier.AbstractVectorClassifier#numCategories()
   */
  @Override
  public int numCategories() {
    return labeledTopRbm().getSoftmaxLayer().getNeuronCount();
  }

  /**
   * Classifies {@code instance} using {@value #DEFAULT_STABLE_STATES_COUNT} stable states.
   *
   * @param instance the feature vector to classify
   * @return the softmax excitations, one score per category (index 0 is category 0)
   * @see #classify(Vector, int)
   */
  @Override
  public Vector classify(Vector instance) {
    return classify(instance, DEFAULT_STABLE_STATES_COUNT);
  }

  /**
   * Classify: sample until the dbm has sampled (Gibbs sampling) the same output n times in a
   * row, or 5 times n samplings if the dbm was not stable for n times in a sequence until then,
   * where n is specified by the parameter.
   *
   * @param instance the instance
   * @param stableStatesCount number of Gibbs samplings = n
   * @return a copy of the softmax excitations, one score per category; unlike the base-class
   *     convention this includes category 0
   */
  public Vector classify(Vector instance, int stableStatesCount) {
    dbm.getLayer(0).setActivations(instance);
    dbm.upPass();
    SoftmaxLayer layer = labeledTopRbm().getSoftmaxLayer();
    DBMStateIterator.iterateUntilStableLayer(layer, dbm, stableStatesCount);
    // Copy so later sampling on the shared layer does not change the returned scores.
    return layer.getExcitations().clone();
  }

  /**
   * Binary-classification shortcut. Per the {@link AbstractVectorClassifier} contract this
   * returns the score for category 1. {@link #classify(Vector)} returns the full score vector
   * (category 0 at index 0), so that score is at index 1.
   *
   * @param instance the feature vector to classify
   * @return the score for category 1
   * @see org.apache.mahout.classifier.AbstractVectorClassifier#classifyScalar(org.apache.mahout.math.Vector)
   */
  @Override
  public double classifyScalar(Vector instance) {
    return classify(instance).get(1);
  }

  /**
   * Serialize the underlying DBM.
   *
   * @param output the output path of serialization
   * @param conf the hadoop configuration
   * @throws IOException Signals that an I/O exception has occurred.
   */
  public void serialize(Path output, Configuration conf) throws IOException {
    dbm.serialize(output, conf);
  }

  /**
   * Materialize a classifier from a serialized DBM.
   *
   * @param input path to the model
   * @param conf the hadoop configuration
   * @return the RBM classifier
   * @throws IOException Signals that an I/O exception has occurred.
   */
  public static RBMClassifier materialize(Path input, Configuration conf) throws IOException {
    RBMClassifier cl = new RBMClassifier();
    cl.dbm = DeepBoltzmannMachine.materialize(input, conf);
    return cl;
  }

  /**
   * Initialize multi layer neural network for backpropagation training.
   *
   * <p>The lower RBMs are reused as they are. The top labeled RBM is split into two plain RBMs:
   * visible-to-hidden using its weight matrix, and hidden-to-softmax using its transposed label
   * weight matrix. Layers and the lower RBMs are shared with this classifier's DBM, not copied.
   *
   * @return the deep boltzmann machine, consisting of a stack of just SimpleRBMs
   */
  public DeepBoltzmannMachine initializeMultiLayerNN() {
    int rbmCount = dbm.getRbmCount();
    LabeledSimpleRBM rbm = labeledTopRbm();
    DeepBoltzmannMachine ret;
    if (rbmCount == 1) {
      // The only RBM is the labeled one: its unlabeled part must become the bottom of the net,
      // otherwise the labeled RBM would be included and its layers stacked twice.
      ret = new DeepBoltzmannMachine(
          new SimpleRBM(rbm.getVisibleLayer(), rbm.getHiddenLayer(), rbm.getWeightMatrix()));
    } else {
      ret = new DeepBoltzmannMachine(dbm.getRBM(0));
      for (int i = 1; i < rbmCount - 1; i++) {
        ret.stackRBM(dbm.getRBM(i));
      }
      ret.stackRBM(
          new SimpleRBM(rbm.getVisibleLayer(), rbm.getHiddenLayer(), rbm.getWeightMatrix()));
    }
    // NOTE(review): the label matrix is transposed, presumably so that its orientation matches
    // a (visible = hidden layer, hidden = softmax layer) RBM -- confirm against SimpleRBM.
    ret.stackRBM(new SimpleRBM(rbm.getHiddenLayer(), rbm.getSoftmaxLayer(),
        rbm.getWeightLabelMatrix().transpose()));
    return ret;
  }

  /**
   * Gets the current scores.
   *
   * @return the top softmax layer's excitation vector itself (not a copy), so it changes when
   *     the DBM is sampled again
   */
  public Vector getCurrentScores() {
    return labeledTopRbm().getSoftmaxLayer().getExcitations();
  }

  /**
   * Returns a new classifier wrapping a clone of the underlying DBM.
   *
   * @see java.lang.Object#clone()
   */
  @Override
  public RBMClassifier clone() {
    RBMClassifier rbmCl = new RBMClassifier();
    rbmCl.dbm = dbm.clone();
    return rbmCl;
  }
}