/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.rbm;

import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Logger;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;
import com.google.common.io.Closeables;
/**
 * Prepares the MNIST data set for RBM training: reads the raw MNIST label and
 * image files, normalizes each image to a 784-dimensional vector of values in
 * [0,1], and distributes the (label, image) pairs across a number of
 * SequenceFile chunks so that every chunk contains the same number of
 * examples of each digit.
 */
public class MnistPreparer extends AbstractJob {
/**
 * The main method.
 *
 * @param args the command-line arguments
 * @throws Exception if the job fails
 */
public static void main(String[] args) throws Exception {
ToolRunner.run(new Configuration(), new MnistPreparer(), args);
}
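/*
 * Example invocation; the jar name and input file names below are
 * illustrative assumptions (the inputs are the standard raw MNIST
 * training files):
 *
 *   hadoop jar mahout-examples-job.jar org.apache.mahout.classifier.rbm.MnistPreparer \
 *     --imagepath train-images-idx3-ubyte \
 *     --labelpath train-labels-idx1-ubyte \
 *     --output mnist-chunks \
 *     --chunknumber 44 --size 44000
 */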
/**
 * Prepares the data set. To process only 44,000 images, as proposed in
 * [Hinton, 2006] (http://www.cs.toronto.edu/~hinton/absps/ncfast.pdf),
 * pass a size of 44000.
 *
 * @param args the command-line arguments
 * @return 0 on success, -1 if the arguments could not be parsed
 * @throws Exception if the input files cannot be read or the output cannot be written
 */
@Override
public int run(String[] args) throws Exception {
addOutputOption();
// FIXME: a chunknumber of 600 causes a NullPointerException
addOption("chunknumber", "cnr", "number of chunks to be created", true);
addOption("labelpath", "l", "path to the MNIST label file", true);
addOption("imagepath", "i", "path to the MNIST image file", true);
addOption("size", "s", "number of (label, image) pairs to be processed", true);
Map<String, String> parsedArgs = parseArguments(args);
if (parsedArgs == null) {
return -1;
}
Path output = getOutputPath();
FileSystem fileSystem = output.getFileSystem(getConf());
// make sure the output directory starts out empty
HadoopUtil.delete(getConf(), output);
fileSystem.mkdirs(output);
DataInputStream dataReader = new DataInputStream(
new FileInputStream(new File(getOption("imagepath"))));
DataInputStream labelReader = new DataInputStream(
new FileInputStream(new File(getOption("labelpath"))));
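// skip the IDX headers: magic number + item count (8 bytes) in the label
// file; magic number + item count + rows + columns (16 bytes) in the image file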
labelReader.skipBytes(8);
dataReader.skipBytes(16);
int label;
IntWritable labelVector = new IntWritable();
VectorWritable imageVector = new VectorWritable(new DenseVector(28 * 28));
double[] pixels = new double[28 * 28];
int chunks = Integer.parseInt(getOption("chunknumber"));
int size = Integer.parseInt(getOption("size"));
SequenceFile.Writer[] writer = new SequenceFile.Writer[chunks];
int writernr = 0;
int closedwriters = 0;
int cntr = 0;
// per-chunk quota for each of the ten digits: every chunk receives
// size / (10 * chunks) examples of each label
int[][] batches = new int[chunks][10];
for (int i = 0; i < batches.length; i++) {
  for (int j = 0; j < 10; j++) {
    batches[i][j] = size / (10 * chunks);
  }
}
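// e.g. with size = 44000 and chunknumber = 44, each chunk is assigned
// 44000 / (10 * 44) = 100 examples of every digit, i.e. 1000 pairs per chunk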
try {
  while (cntr < size) {
    writernr = -1;
    label = labelReader.readUnsignedByte();
    labelVector.set(label);
    // read one 28x28 image and scale the pixel intensities to [0,1]
    for (int i = 0; i < pixels.length; i++) {
      pixels[i] = dataReader.readUnsignedByte() / 255.0;
    }
    // find the first chunk that still needs an example of this label
    for (int i = closedwriters; i < chunks; i++) {
      if (batches[i][label] > 0) {
        writernr = i;
        // open writers lazily, only when they are first needed
        if (writer[writernr] == null) {
          writer[writernr] = new Writer(fileSystem, getConf(),
              new Path(output, "chunk" + i), IntWritable.class, VectorWritable.class);
        }
        break;
      } else {
        // close writers that are still open but whose chunk is already complete
        for (int j = 0; j < 10; j++) {
          if (batches[i][j] != 0) {
            break;
          }
          if (j == 9) {
            // guard against writers that were never opened
            if (writer[i] != null) {
              writer[i].close();
              writer[i] = null;
            }
            closedwriters++;
          }
        }
      }
    }
    // every chunk has received its full quota of examples
    if (closedwriters >= chunks) {
      break;
    }
    // no chunk needs this label anymore; skip this pair
    if (writernr == -1) {
      continue;
    }
    cntr++;
    if (cntr % 1000 == 0) {
      Logger.getLogger(this.getClass()).info(cntr + " pairs processed");
    }
    imageVector.get().assign(pixels);
    writer[writernr].append(labelVector, imageVector);
    batches[writernr][label]--;
  }
}
catch (EOFException ex) {
  // the input files were exhausted before 'size' pairs could be written
  Logger.getLogger(this.getClass()).warn("unexpected end of input after " + cntr + " pairs");
} finally {
  // close the input streams and every writer that is still open
  Closeables.closeQuietly(labelReader);
  Closeables.closeQuietly(dataReader);
  for (SequenceFile.Writer w : writer) {
    Closeables.closeQuietly(w);
  }
}
return 0;
}
}
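/*
 * A minimal sketch of how the emitted chunks can be read back; the chunk
 * path and Configuration below are assumptions for illustration, not part
 * of this class:
 *
 *   Configuration conf = new Configuration();
 *   Path chunk = new Path("mnist-chunks/chunk0");
 *   FileSystem fs = chunk.getFileSystem(conf);
 *   SequenceFile.Reader reader = new SequenceFile.Reader(fs, chunk, conf);
 *   IntWritable label = new IntWritable();
 *   VectorWritable image = new VectorWritable();
 *   while (reader.next(label, image)) {
 *     // image.get() is a 784-dimensional vector with entries in [0,1]
 *   }
 *   reader.close();
 */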