/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.rbm.network; import java.io.IOException; import java.util.ArrayList; import java.util.List; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.mahout.classifier.rbm.layer.Layer; import org.apache.mahout.classifier.rbm.layer.SoftmaxLayer; import org.apache.mahout.classifier.rbm.model.LabeledSimpleRBM; import org.apache.mahout.classifier.rbm.model.RBMModel; import org.apache.mahout.classifier.rbm.model.SimpleRBM; import org.apache.mahout.common.ClassUtils; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.MatrixWritable; import com.google.common.io.Closeables; /** * A DeepBoltzmannMachine is a (deep belief) neural network consisting of a stack of restricted boltzmann machines. */ public class DeepBoltzmannMachine implements DeepBeliefNetwork, Cloneable{ /** The restricted boltzmann machines where nr 0 is lowest. */ private List<RBMModel> rbms; /** * Instantiates a new deep boltzmann machine. * * @param lowestRBM the lowest rbm */ public DeepBoltzmannMachine(RBMModel lowestRBM) { rbms = new ArrayList<RBMModel>(); rbms.add(lowestRBM); } /** * Put a new RBM on the stack. * * @param rbm the RBM * @return true, if successful */ public boolean stackRBM(RBMModel rbm) { if(rbm.getVisibleLayer().equals(rbms.get(rbms.size()-1).getHiddenLayer())) { rbms.add(rbm); return true; } else return false; } /** * Serialize to the output. * * @param output the output * @param conf the conf * @throws IOException Signals that an I/O exception has occurred. */ public void serialize(Path output, Configuration conf) throws IOException { FileSystem fs = output.getFileSystem(conf); FSDataOutputStream out = fs.create(output, true); try { new IntWritable(rbms.size()).write(out); for (int i = 0; i < rbms.size(); i++) { if(i== 0) out.writeChars(rbms.get(i).getVisibleLayer().getClass().getName()+" "); out.writeChars(rbms.get(i).getHiddenLayer().getClass().getName()+" "); if(i<rbms.size()-1) MatrixWritable.writeMatrix(out, ((SimpleRBM)rbms.get(i)).getWeightMatrix()); else { MatrixWritable.writeMatrix(out, ((LabeledSimpleRBM)rbms.get(i)).getWeightMatrix()); MatrixWritable.writeMatrix(out, ((LabeledSimpleRBM)rbms.get(i)).getWeightLabelMatrix()); } } } finally { Closeables.closeQuietly(out); } } /** * Materialize from input path. * * @param input the input path * @param conf the hadoop config * @return the deep boltzmann machine * @throws IOException Signals that an I/O exception has occurred. */ public static DeepBoltzmannMachine materialize(Path input, Configuration conf) throws IOException { FileSystem fs = input.getFileSystem(conf); String visLayerType = ""; String hidLayerType = ""; FSDataInputStream in = fs.open(input); DeepBoltzmannMachine dbm = null; try { int rbmSize = in.readInt(); for (int i = 0; i < rbmSize; i++) { RBMModel rbm = null; hidLayerType=""; visLayerType=""; char chr; if(i==0) while((chr=in.readChar())!=' ') visLayerType += chr; while((chr=in.readChar())!=' ') hidLayerType += chr; Matrix weightMatrix = MatrixWritable.readMatrix(in); Layer vl; if(i==0) vl = ClassUtils.instantiateAs(visLayerType, Layer.class,new Class[]{int.class},new Object[]{weightMatrix.rowSize()}); else vl = dbm.rbms.get(dbm.getRbmCount()-1).getHiddenLayer(); Layer hl = ClassUtils.instantiateAs(hidLayerType, Layer.class,new Class[]{int.class},new Object[]{weightMatrix.columnSize()}); if(i<rbmSize-1){ rbm = new SimpleRBM(vl, hl); ((SimpleRBM)rbm).setWeightMatrix(weightMatrix); } else { Matrix weightLabelMatrix =MatrixWritable.readMatrix(in); rbm = new LabeledSimpleRBM(vl, hl, new SoftmaxLayer(weightLabelMatrix.rowSize())); ((LabeledSimpleRBM)rbm).setWeightMatrix(weightMatrix); ((LabeledSimpleRBM)rbm).setWeightLabelMatrix(weightLabelMatrix); } if(i==0) dbm = new DeepBoltzmannMachine(rbm); else dbm.stackRBM(rbm); } } finally { Closeables.closeQuietly(in); } return dbm; } /** * Get the i-th RBM. * * @param i the i * @return the rBM */ public RBMModel getRBM(Integer i) { if(i<=rbms.size()) return rbms.get(i); else return null; } /** * Gets the size of the rbm stack. * * @return the stacksize of rbms */ public int getRbmCount() { return rbms.size(); } /** * Gets the layer count. * * @return the layer count */ public int getLayerCount() { return rbms.size()+1; } /* (non-Javadoc) * @see org.apache.mahout.classifier.rbm.network.DeepBeliefNetwork#exciteLayer(int) */ @Override public void exciteLayer(int l) { boolean addInput = (l<getRbmCount()); if(addInput) { RBMModel upperRbm = getRBM(l); upperRbm.exciteVisibleLayer(1, false); } if(l>0){ RBMModel lowerRbm = getRBM(l-1); lowerRbm.exciteHiddenLayer(1, addInput); } } /* (non-Javadoc) * @see org.apache.mahout.classifier.rbm.network.DeepBeliefNetwork#getLayer(int) */ @Override public Layer getLayer(int l) { if(l<getRbmCount()) return getRBM(l).getVisibleLayer(); return getRBM(l-1).getHiddenLayer(); } /* (non-Javadoc) * @see org.apache.mahout.classifier.rbm.network.DeepBeliefNetwork#upPass() */ @Override public void upPass() { for (int i = 0; i < getRbmCount(); i++) { RBMModel rbm = rbms.get(i); rbm.exciteHiddenLayer((i<getRbmCount()-1)?2:1, false); rbm.updateHiddenLayer(); } } /* (non-Javadoc) * @see org.apache.mahout.classifier.rbm.network.DeepBeliefNetwork#updateLayer(int) */ @Override public void updateLayer(int l) { if(l<getRbmCount()){ RBMModel rbm = getRBM(l); rbm.updateVisibleLayer(); } else getRBM(l-1).updateHiddenLayer(); } /* (non-Javadoc) * @see java.lang.Object#clone() */ public DeepBoltzmannMachine clone(){ DeepBoltzmannMachine dbm = new DeepBoltzmannMachine(rbms.get(0).clone()); for (int i = 1; i < rbms.size(); i++) { RBMModel clonedRbm = getRBM(i).clone(); clonedRbm.setVisibleLayer(dbm.getRBM(i-1).getHiddenLayer()); dbm.stackRBM(clonedRbm); } return dbm; } }