/* * Copyright (c) 2011-2016, Peter Abeles. All Rights Reserved. * * This file is part of BoofCV (http://boofcv.org). * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package boofcv.deepboof; import boofcv.alg.misc.GImageMiscOps; import boofcv.struct.image.GrayF32; import boofcv.struct.image.Planar; import deepboof.Function; import deepboof.graph.FunctionSequence; import deepboof.graph.Node; import deepboof.impl.forward.standard.FunctionLinear_F32; import deepboof.misc.TensorFactory_F32; import deepboof.tensors.Tensor_F32; import org.junit.Test; import java.util.ArrayList; import java.util.List; import java.util.Random; import static deepboof.misc.TensorOps.WI; import static org.junit.Assert.assertTrue; /** * @author Peter Abeles */ public abstract class CheckBaseImageClassifier { protected Random rand = new Random(234); protected int numCategories = 8; /** * Basic test which sees if it blows up. Does not validate quality of results since a fake network * is provided. Regression test is required to validate correctness. * * The real network is not used because it requires downloading external data and can be slow. */ @Test public void checkForBlowUp() { Planar<GrayF32> input = createImage(); GImageMiscOps.fillUniform(input,rand,0,255); BaseImageClassifier classifier = createClassifier(); createDummyNetwork(classifier, input.width, input.height); classifier.classify(input); int best = classifier.getBestResult(); assertTrue(best>=0 && best < numCategories); } public abstract Planar<GrayF32> createImage(); public abstract BaseImageClassifier createClassifier(); private void createDummyNetwork(BaseImageClassifier alg, int width , int height ) { for (int i = 0; i < numCategories; i++) { alg.getCategories().add("Category "+i); } FunctionLinear_F32 function = new FunctionLinear_F32(numCategories); function.initialize(3,height,width); List<Tensor_F32> parameters = new ArrayList<>(); parameters.add(TensorFactory_F32.random(rand,false,function.getParameterShapes().get(0))); parameters.add(TensorFactory_F32.random(rand,false,function.getParameterShapes().get(1))); function.setParameters(parameters); Node<Tensor_F32, Function<Tensor_F32>> node = new Node<>(); node.function = function; List<Node<Tensor_F32, Function<Tensor_F32>>> sequence = new ArrayList<>(); sequence.add(node); alg.network = new FunctionSequence<>(sequence, Tensor_F32.class); alg.tensorOutput = new Tensor_F32(WI(1,alg.network.getOutputShape())); } }