/*
* JavaCV version of OpenCV caffe_googlenet.cpp
* https://github.com/ludv1x/opencv_contrib/blob/master/modules/dnn/samples/caffe_googlenet.cpp
*
* Paolo Bolettieri <paolo.bolettieri@gmail.com>
*/
import static org.bytedeco.javacpp.opencv_core.minMaxLoc;
import static org.bytedeco.javacpp.opencv_dnn.createCaffeImporter;
import static org.bytedeco.javacpp.opencv_imgcodecs.imread;
import static org.bytedeco.javacpp.opencv_imgproc.resize;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.Point;
import org.bytedeco.javacpp.opencv_core.Size;
import org.bytedeco.javacpp.opencv_dnn.Blob;
import org.bytedeco.javacpp.opencv_dnn.Importer;
import org.bytedeco.javacpp.opencv_dnn.Net;
/**
 * Classifies an image with the BVLC GoogLeNet Caffe model using the OpenCV dnn module.
 * Expects {@code bvlc_googlenet.prototxt}, {@code bvlc_googlenet.caffemodel} and
 * {@code synset_words.txt} in the working directory.
 */
public class CaffeGooglenet {

    /**
     * Finds the best class for the blob, i.e. the class with maximal probability.
     *
     * @param probBlob  network output blob holding one probability per class
     * @param classId   out-parameter: receives the location of the maximum
     *                  ({@code classId.x()} is the winning class index)
     * @param classProb out-parameter: {@code classProb[0]} receives the maximal probability
     */
    public static void getMaxClass(Blob probBlob, Point classId, double[] classProb) {
        // Reshape the blob to a 1x1000 single-row matrix so minMaxLoc scans all classes.
        Mat probMat = probBlob.matRefConst().reshape(1, 1);
        minMaxLoc(probMat, null, classProb, null, classId, null);
    }

    /**
     * Reads the ILSVRC class labels from {@code synset_words.txt}, one
     * "&lt;wnid&gt; &lt;label&gt;" pair per line. Exits the process when the file
     * cannot be read (this is a sample program, not a library).
     *
     * @return list of human-readable class names, indexed by class id
     */
    public static List<String> readClassNames() {
        String filename = "synset_words.txt";
        List<String> classNames = null;
        // Read explicitly as UTF-8; a plain FileReader would use the platform
        // default charset and could garble non-ASCII labels.
        try (BufferedReader br = Files.newBufferedReader(Paths.get(filename), StandardCharsets.UTF_8)) {
            classNames = new ArrayList<String>();
            String name = null;
            while ((name = br.readLine()) != null) {
                // Keep only the label after the first space. If no space exists,
                // indexOf returns -1, so substring(0) keeps the whole line.
                classNames.add(name.substring(name.indexOf(' ') + 1));
            }
        } catch (IOException ex) {
            System.err.println("File with classes labels not found " + filename);
            System.exit(-1);
        }
        return classNames;
    }

    /**
     * Loads the GoogLeNet model, runs a forward pass on the given (or default)
     * image, and prints the most probable class.
     *
     * @param args optional: args[0] is the path of the image to classify
     */
    public static void main(String[] args) throws Exception {
        String modelTxt = "bvlc_googlenet.prototxt";
        String modelBin = "bvlc_googlenet.caffemodel";
        String imageFile = (args.length > 0) ? args[0] : "space_shuttle.jpg";

        //! [Create the importer of Caffe model]
        Importer importer = null;
        try { //Try to import Caffe GoogleNet model
            importer = createCaffeImporter(modelTxt, modelBin);
        } catch (Exception e) { //Importer can throw errors, we will catch them
            e.printStackTrace();
        }
        //! [Create the importer of Caffe model]

        if (importer == null) {
            System.err.println("Can't load network by using the following files: ");
            System.err.println("prototxt: " + modelTxt);
            System.err.println("caffemodel: " + modelBin);
            System.err.println("bvlc_googlenet.caffemodel can be downloaded here:");
            System.err.println("http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel");
            System.exit(-1);
        }

        //! [Initialize network]
        Net net = new Net();
        importer.populateNet(net);
        importer.close(); //We don't need importer anymore
        //! [Initialize network]

        //! [Prepare blob]
        Mat img = imread(imageFile);
        if (img.empty()) {
            System.err.println("Can't read image from the file: " + imageFile);
            System.exit(-1);
        }
        resize(img, img, new Size(224, 224)); //GoogLeNet accepts only 224x224 RGB-images
        Blob inputBlob = new Blob(img); //Convert Mat to dnn::Blob image batch
        //! [Prepare blob]

        //! [Set input blob]
        net.setBlob(".data", inputBlob); //set the network input
        //! [Set input blob]

        //! [Make forward pass]
        net.forward(); //compute output
        //! [Make forward pass]

        //! [Gather output]
        Blob prob = net.getBlob("prob"); //gather output of "prob" layer
        Point classId = new Point();
        double[] classProb = new double[1];
        getMaxClass(prob, classId, classProb); //find the best class
        //! [Gather output]

        //! [Print results]
        // Print the winning class exactly once (the original sample printed the
        // "Best class: #N" line twice).
        List<String> classNames = readClassNames();
        System.out.println("Best class: #" + classId.x() + " '" + classNames.get(classId.x()) + "'");
        System.out.println("Probability: " + classProb[0] * 100 + "%");
        //! [Print results]
    } //main
}