/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.convolution;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.AllocUtil;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Arrays;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
* Created by agibsonccc on 9/6/14.
*/
@RunWith(Parameterized.class)
public class ConvolutionTests extends BaseNd4jTest {
/**
 * Creates the test instance for one backend by handing it straight to the
 * {@code BaseNd4jTest} constructor.
 *
 * @param backend the ND4J backend forwarded to the superclass
 */
public ConvolutionTests(Nd4jBackend backend) {
super(backend);
}
/**
 * Compares {@code Convolution.im2col} with hand-written patches for a batch of two examples with two
 * channels of 3x3 pixels, using a 2x2 kernel, stride 1 and zero padding.
 *
 * The comparison is made three times: once with the overload that returns its own result, and twice
 * with a caller-supplied result array that is a permuted view of a differently ordered buffer.
 */
@Test
public void testIm2ColKnownValues() {
    final int batchSize = 2;
    final int channels = 2;
    final int imgRows = 3;
    final int imgCols = 3;
    final int kernelRows = 2;
    final int kernelCols = 2;
    final int outRows = 2;
    final int outCols = 2;
    final int strideRows = 1;
    final int strideCols = 1;
    final int padRows = 0;
    final int padCols = 0;

    // Source pixels, indexed [example][channel][row][column]: the values 0..35 in row-major order.
    double[][][][] images = {
                    {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}},
                    {{{18, 19, 20}, {21, 22, 23}, {24, 25, 26}}, {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}}}};

    // Input array: shape [batchSize, channels, imgRows, imgCols]
    INDArray input = Nd4j.create(new int[] {batchSize, channels, imgRows, imgCols}, 'c');
    for (int ex = 0; ex < batchSize; ex++) {
        for (int ch = 0; ch < channels; ch++) {
            INDArrayIndex[] plane = {NDArrayIndex.point(ex), NDArrayIndex.point(ch), NDArrayIndex.all(),
                            NDArrayIndex.all()};
            input.put(plane, Nd4j.create(images[ex][ch]));
        }
    }

    // Kernel-sized patches, indexed [example][channel][outRow][outCol]; each leaf is kernelRows x kernelCols.
    double[][][][][][] patches = {
                    { // example 0
                                    { // channel 0
                                                    {{{0, 1}, {3, 4}}, {{1, 2}, {4, 5}}},
                                                    {{{3, 4}, {6, 7}}, {{4, 5}, {7, 8}}}},
                                    { // channel 1
                                                    {{{9, 10}, {12, 13}}, {{10, 11}, {13, 14}}},
                                                    {{{12, 13}, {15, 16}}, {{13, 14}, {16, 17}}}}},
                    { // example 1
                                    { // channel 0
                                                    {{{18, 19}, {21, 22}}, {{19, 20}, {22, 23}}},
                                                    {{{21, 22}, {24, 25}}, {{22, 23}, {25, 26}}}},
                                    { // channel 1
                                                    {{{27, 28}, {30, 31}}, {{28, 29}, {31, 32}}},
                                                    {{{30, 31}, {33, 34}}, {{31, 32}, {34, 35}}}}}};

    // Expected array: shape [batchSize, channels, kernelRows, kernelCols, outRows, outCols]
    INDArray expected = Nd4j.create(new int[] {batchSize, channels, kernelRows, kernelCols, outRows, outCols},
                    'c');
    for (int ex = 0; ex < batchSize; ex++) {
        for (int ch = 0; ch < channels; ch++) {
            for (int row = 0; row < outRows; row++) {
                for (int col = 0; col < outCols; col++) {
                    INDArrayIndex[] slot = {NDArrayIndex.point(ex), NDArrayIndex.point(ch),
                                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(row),
                                    NDArrayIndex.point(col)};
                    expected.put(slot, Nd4j.create(patches[ex][ch][row][col]));
                }
            }
        }
    }

    // 1) Overload that hands back its own result array
    INDArray returned = Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padRows,
                    padCols, false);
    assertEquals(expected, returned);

    // 2) Caller-supplied result whose buffer is laid out [batch, channel, outRow, outCol, kRow, kCol],
    //    so the permuted view handed to im2col has non-standard strides
    INDArray stridedA = Nd4j.create(new int[] {batchSize, channels, outRows, outCols, kernelRows, kernelCols},
                    'c').permute(0, 1, 4, 5, 2, 3);
    Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padRows, padCols, false,
                    stridedA);
    assertEquals(expected, stridedA);

    // 3) Same idea with the buffer laid out [batch, outRow, outCol, channel, kRow, kCol]
    INDArray stridedB = Nd4j.create(new int[] {batchSize, outRows, outCols, channels, kernelRows, kernelCols},
                    'c').permute(0, 3, 4, 5, 1, 2);
    Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padRows, padCols, false,
                    stridedB);
    assertEquals(expected, stridedB);
}
/**
 * Same scenario as the two-example known-values check, but with a batch of three examples: two
 * channels of 3x3 pixels, 2x2 kernel, stride 1, zero padding.
 *
 * {@code Convolution.im2col} is compared with hand-written patches once through the overload that
 * returns its own result and twice through a caller-supplied, permuted result array.
 */
@Test
public void testIm2ColKnownValuesMiniBatch3() {
    final int batchSize = 3;
    final int channels = 2;
    final int imgRows = 3;
    final int imgCols = 3;
    final int kernelRows = 2;
    final int kernelCols = 2;
    final int outRows = 2;
    final int outCols = 2;
    final int strideRows = 1;
    final int strideCols = 1;
    final int padRows = 0;
    final int padCols = 0;

    // Source pixels, indexed [example][channel][row][column]: the values 0..53 in row-major order.
    double[][][][] images = {
                    {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}},
                    {{{18, 19, 20}, {21, 22, 23}, {24, 25, 26}}, {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}}},
                    {{{36, 37, 38}, {39, 40, 41}, {42, 43, 44}}, {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}}}};

    // Input array: shape [batchSize, channels, imgRows, imgCols]
    INDArray input = Nd4j.create(new int[] {batchSize, channels, imgRows, imgCols}, 'c');
    for (int ex = 0; ex < batchSize; ex++) {
        for (int ch = 0; ch < channels; ch++) {
            INDArrayIndex[] plane = {NDArrayIndex.point(ex), NDArrayIndex.point(ch), NDArrayIndex.all(),
                            NDArrayIndex.all()};
            input.put(plane, Nd4j.create(images[ex][ch]));
        }
    }

    // Kernel-sized patches, indexed [example][channel][outRow][outCol]; each leaf is kernelRows x kernelCols.
    double[][][][][][] patches = {
                    { // example 0
                                    { // channel 0
                                                    {{{0, 1}, {3, 4}}, {{1, 2}, {4, 5}}},
                                                    {{{3, 4}, {6, 7}}, {{4, 5}, {7, 8}}}},
                                    { // channel 1
                                                    {{{9, 10}, {12, 13}}, {{10, 11}, {13, 14}}},
                                                    {{{12, 13}, {15, 16}}, {{13, 14}, {16, 17}}}}},
                    { // example 1
                                    { // channel 0
                                                    {{{18, 19}, {21, 22}}, {{19, 20}, {22, 23}}},
                                                    {{{21, 22}, {24, 25}}, {{22, 23}, {25, 26}}}},
                                    { // channel 1
                                                    {{{27, 28}, {30, 31}}, {{28, 29}, {31, 32}}},
                                                    {{{30, 31}, {33, 34}}, {{31, 32}, {34, 35}}}}},
                    { // example 2
                                    { // channel 0
                                                    {{{36, 37}, {39, 40}}, {{37, 38}, {40, 41}}},
                                                    {{{39, 40}, {42, 43}}, {{40, 41}, {43, 44}}}},
                                    { // channel 1
                                                    {{{45, 46}, {48, 49}}, {{46, 47}, {49, 50}}},
                                                    {{{48, 49}, {51, 52}}, {{49, 50}, {52, 53}}}}}};

    // Expected array: shape [batchSize, channels, kernelRows, kernelCols, outRows, outCols]
    INDArray expected = Nd4j.create(new int[] {batchSize, channels, kernelRows, kernelCols, outRows, outCols},
                    'c');
    for (int ex = 0; ex < batchSize; ex++) {
        for (int ch = 0; ch < channels; ch++) {
            for (int row = 0; row < outRows; row++) {
                for (int col = 0; col < outCols; col++) {
                    INDArrayIndex[] slot = {NDArrayIndex.point(ex), NDArrayIndex.point(ch),
                                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(row),
                                    NDArrayIndex.point(col)};
                    expected.put(slot, Nd4j.create(patches[ex][ch][row][col]));
                }
            }
        }
    }

    // 1) Overload that hands back its own result array
    INDArray returned = Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padRows,
                    padCols, false);
    assertEquals(expected, returned);

    // 2) Caller-supplied result whose buffer is laid out [batch, channel, outRow, outCol, kRow, kCol],
    //    so the permuted view handed to im2col has non-standard strides
    INDArray stridedA = Nd4j.create(new int[] {batchSize, channels, outRows, outCols, kernelRows, kernelCols},
                    'c').permute(0, 1, 4, 5, 2, 3);
    Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padRows, padCols, false,
                    stridedA);
    assertEquals(expected, stridedA);

    // 3) Same idea with the buffer laid out [batch, outRow, outCol, channel, kRow, kCol]
    INDArray stridedB = Nd4j.create(new int[] {batchSize, outRows, outCols, channels, kernelRows, kernelCols},
                    'c').permute(0, 3, 4, 5, 1, 2);
    Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padRows, padCols, false,
                    stridedB);
    assertEquals(expected, stridedB);
}
/**
 * Checks {@code Convolution.im2col} with "same"-style padding at stride 1: two examples, two channels
 * of 3x3 pixels, 2x2 kernel.
 *
 * The output size is taken as ceil(in / stride), which for stride 1 must equal the input size. The
 * total padding works out to one row and one column; integer division puts none of it on the
 * top/left, so the expected patches are zero-filled along the bottom and right edges only.
 *
 * Results are compared with hand-written patches for a plain pre-allocated result array and for two
 * permuted result views. Finally {@code Convolution.col2im} is run on the result purely as a smoke
 * test; its values are printed, not asserted.
 */
@Test
public void testIm2ColSamePadding() {
    int miniBatch = 2;
    int depth = 2;
    int inH = 3;
    int inW = 3;
    int strideH = 1;
    int strideW = 1;
    int kH = 2;
    int kW = 2;

    //Same-padding output size: ceil(in / stride)
    int outH = (int) Math.ceil(inH / ((double) strideH));
    int outW = (int) Math.ceil(inW / ((double) strideW));
    //JUnit order is (expected, actual): with stride 1 the output must be as large as the input
    assertEquals(inH, outH);
    assertEquals(inW, outW);

    //Total padding needed per dimension; the odd unit goes to the bottom/right
    int sumPadHeight = ((outH - 1) * strideH + kH - inH);
    int padTop = sumPadHeight / 2;
    int padBottom = sumPadHeight - padTop;
    int sumPadWidth = ((outW - 1) * strideW + kW - inW);
    int padLeft = sumPadWidth / 2;
    int padRight = sumPadWidth - padLeft;

    System.out.println("Output size: " + outH + ", " + outW);
    System.out.println("Pad top/bottom: " + padTop + "\t" + padBottom);
    System.out.println("Pad left/right: " + padLeft + "\t" + padRight);

    //Source pixels, indexed [example][depth][row][column]: the values 0..35 in row-major order
    double[][][][] images = {
                    {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}},
                    {{{18, 19, 20}, {21, 22, 23}, {24, 25, 26}}, {{27, 28, 29}, {30, 31, 32}, {33, 34, 35}}}};

    //Input data: shape [miniBatch,depth,height,width]
    INDArray input = Nd4j.create(new int[] {miniBatch, depth, inH, inW}, 'c');
    for (int mb = 0; mb < miniBatch; mb++) {
        for (int d = 0; d < depth; d++) {
            input.put(new INDArrayIndex[] {NDArrayIndex.point(mb), NDArrayIndex.point(d), NDArrayIndex.all(),
                            NDArrayIndex.all()}, Nd4j.create(images[mb][d]));
        }
    }

    //Expected kH x kW patches, indexed [example][depth][outY][outX].
    //A 0 in the right column / bottom row of a patch is padding, not image data.
    double[][][][][][] patches = {
                    { //example 0
                                    { //depth 0
                                                    {{{0, 1}, {3, 4}}, {{1, 2}, {4, 5}}, {{2, 0}, {5, 0}}},
                                                    {{{3, 4}, {6, 7}}, {{4, 5}, {7, 8}}, {{5, 0}, {8, 0}}},
                                                    {{{6, 7}, {0, 0}}, {{7, 8}, {0, 0}}, {{8, 0}, {0, 0}}}},
                                    { //depth 1
                                                    {{{9, 10}, {12, 13}}, {{10, 11}, {13, 14}}, {{11, 0}, {14, 0}}},
                                                    {{{12, 13}, {15, 16}}, {{13, 14}, {16, 17}}, {{14, 0}, {17, 0}}},
                                                    {{{15, 16}, {0, 0}}, {{16, 17}, {0, 0}}, {{17, 0}, {0, 0}}}}},
                    { //example 1
                                    { //depth 0
                                                    {{{18, 19}, {21, 22}}, {{19, 20}, {22, 23}}, {{20, 0}, {23, 0}}},
                                                    {{{21, 22}, {24, 25}}, {{22, 23}, {25, 26}}, {{23, 0}, {26, 0}}},
                                                    {{{24, 25}, {0, 0}}, {{25, 26}, {0, 0}}, {{26, 0}, {0, 0}}}},
                                    { //depth 1
                                                    {{{27, 28}, {30, 31}}, {{28, 29}, {31, 32}}, {{29, 0}, {32, 0}}},
                                                    {{{30, 31}, {33, 34}}, {{31, 32}, {34, 35}}, {{32, 0}, {35, 0}}},
                                                    {{{33, 34}, {0, 0}}, {{34, 35}, {0, 0}}, {{35, 0}, {0, 0}}}}}};

    //Expected data: shape [miniBatch,depth,kH,kW,outH,outW]
    INDArray expected = Nd4j.create(new int[] {miniBatch, depth, kH, kW, outH, outW}, 'c');
    for (int mb = 0; mb < miniBatch; mb++) {
        for (int d = 0; d < depth; d++) {
            for (int h = 0; h < outH; h++) {
                for (int w = 0; w < outW; w++) {
                    expected.put(new INDArrayIndex[] {NDArrayIndex.point(mb), NDArrayIndex.point(d),
                                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(h),
                                    NDArrayIndex.point(w)}, Nd4j.create(patches[mb][d][h][w]));
                }
            }
        }
    }

    //Plain pre-allocated result: [miniBatch,depth,kH,kW,outH,outW]
    //NOTE(review): the boolean argument is presumably the same-padding switch - confirm against Convolution.im2col
    INDArray outAlloc = Nd4j.create(miniBatch, depth, kH, kW, outH, outW);
    INDArray out = Convolution.im2col(input, kH, kW, strideH, strideW, padTop, padLeft, true, outAlloc);
    assertEquals(expected, out);

    //Now: test with a provided results array, where the results array has weird strides
    INDArray out2 = Nd4j.create(new int[] {miniBatch, depth, outH, outW, kH, kW}, 'c');
    INDArray out2p = out2.permute(0, 1, 4, 5, 2, 3);
    Convolution.im2col(input, kH, kW, strideH, strideW, padTop, padLeft, true, out2p);
    assertEquals(expected, out2p);

    INDArray out3 = Nd4j.create(new int[] {miniBatch, outH, outW, depth, kH, kW}, 'c');
    INDArray out3p = out3.permute(0, 3, 4, 5, 1, 2);
    Convolution.im2col(input, kH, kW, strideH, strideW, padTop, padLeft, true, out3p);
    assertEquals(expected, out3p);

    //Finally: run col2im with the same shapes. This doesn't check the results, more 'does it crash or not'
    INDArray col2imResult = Nd4j.create(input.shape());
    INDArray col2im = Convolution.col2im(out, col2imResult, strideH, strideW, padTop, padLeft, inH, inW);
    System.out.println(Arrays.toString(col2im.data().asDouble()));
}
/**
 * Checks {@code Convolution.im2col} with "same"-style padding at stride 2: one example, two channels
 * of 3x4 pixels (rows x columns), 3x3 kernel.
 *
 * Output size is ceil(in / stride) = 2x2. The padding arithmetic is asserted along the way: one row
 * on top and one on the bottom, no column on the left and one on the right. Results are compared with
 * hand-written patches for a plain pre-allocated result array and for two permuted result views, and
 * {@code Convolution.col2im} is then run as a smoke test whose values are only printed.
 */
@Test
public void testIm2ColSamePaddingStride2() {
    final int batchSize = 1;
    final int channels = 2;
    final int imgRows = 3;
    final int imgCols = 4;
    final int strideRows = 2;
    final int strideCols = 2;
    final int kernelRows = 3;
    final int kernelCols = 3;

    // Same-padding output size: ceil(in / stride)
    final int outRows = (int) Math.ceil(imgRows / ((double) strideRows));
    final int outCols = (int) Math.ceil(imgCols / ((double) strideCols));
    assertEquals(2, outRows); // ceil(3/2)
    assertEquals(2, outCols); // ceil(4/2)

    // Vertical padding: 2 rows in total, split evenly
    final int totalPadRows = ((outRows - 1) * strideRows + kernelRows - imgRows);
    final int padTop = totalPadRows / 2;
    final int padBottom = totalPadRows - padTop;
    assertEquals(1, padTop);
    assertEquals(1, padBottom);

    // Horizontal padding: 1 column in total, which integer division pushes to the right-hand side
    final int totalPadCols = ((outCols - 1) * strideCols + kernelCols - imgCols);
    final int padLeft = totalPadCols / 2;
    final int padRight = totalPadCols - padLeft;
    assertEquals(0, padLeft);
    assertEquals(1, padRight);

    System.out.println("Output size: " + outRows + ", " + outCols);
    System.out.println("Pad top/bottom: " + padTop + "\t" + padBottom);
    System.out.println("Pad left/right: " + padLeft + "\t" + padRight);

    // Source pixels, indexed [example][channel][row][column]: the values 0..23 in row-major order.
    double[][][][] images = {
                    {{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
                                    {{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}}};

    // Input array: shape [batchSize, channels, imgRows, imgCols]
    INDArray input = Nd4j.create(new int[] {batchSize, channels, imgRows, imgCols}, 'c');
    for (int ex = 0; ex < batchSize; ex++) {
        for (int ch = 0; ch < channels; ch++) {
            INDArrayIndex[] plane = {NDArrayIndex.point(ex), NDArrayIndex.point(ch), NDArrayIndex.all(),
                            NDArrayIndex.all()};
            input.put(plane, Nd4j.create(images[ex][ch]));
        }
    }

    // Expected 3x3 patches, indexed [example][channel][outRow][outCol].
    // All-zero first/last rows and the zero last column in the right-hand patches are padding.
    double[][][][][][] patches = {
                    { // example 0
                                    { // channel 0
                                                    {{{0, 0, 0}, {0, 1, 2}, {4, 5, 6}},
                                                                    {{0, 0, 0}, {2, 3, 0}, {6, 7, 0}}},
                                                    {{{4, 5, 6}, {8, 9, 10}, {0, 0, 0}},
                                                                    {{6, 7, 0}, {10, 11, 0}, {0, 0, 0}}}},
                                    { // channel 1
                                                    {{{0, 0, 0}, {12, 13, 14}, {16, 17, 18}},
                                                                    {{0, 0, 0}, {14, 15, 0}, {18, 19, 0}}},
                                                    {{{16, 17, 18}, {20, 21, 22}, {0, 0, 0}},
                                                                    {{18, 19, 0}, {22, 23, 0}, {0, 0, 0}}}}}};

    // Expected array: shape [batchSize, channels, kernelRows, kernelCols, outRows, outCols]
    INDArray expected = Nd4j.create(new int[] {batchSize, channels, kernelRows, kernelCols, outRows, outCols},
                    'c');
    for (int ex = 0; ex < batchSize; ex++) {
        for (int ch = 0; ch < channels; ch++) {
            for (int row = 0; row < outRows; row++) {
                for (int col = 0; col < outCols; col++) {
                    INDArrayIndex[] slot = {NDArrayIndex.point(ex), NDArrayIndex.point(ch),
                                    NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(row),
                                    NDArrayIndex.point(col)};
                    expected.put(slot, Nd4j.create(patches[ex][ch][row][col]));
                }
            }
        }
    }

    // 1) Plain pre-allocated result: [batchSize, channels, kernelRows, kernelCols, outRows, outCols]
    // NOTE(review): the boolean argument is presumably the same-padding switch - confirm against Convolution.im2col
    INDArray preallocated = Nd4j.create(batchSize, channels, kernelRows, kernelCols, outRows, outCols);
    INDArray columns = Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padTop,
                    padLeft, true, preallocated);
    assertEquals(expected, columns);

    // 2) Result buffer laid out [batch, channel, outRow, outCol, kRow, kCol], passed as a permuted view
    INDArray stridedA = Nd4j.create(new int[] {batchSize, channels, outRows, outCols, kernelRows, kernelCols},
                    'c').permute(0, 1, 4, 5, 2, 3);
    Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padTop, padLeft, true,
                    stridedA);
    assertEquals(expected, stridedA);

    // 3) Result buffer laid out [batch, outRow, outCol, channel, kRow, kCol], passed as a permuted view
    INDArray stridedB = Nd4j.create(new int[] {batchSize, outRows, outCols, channels, kernelRows, kernelCols},
                    'c').permute(0, 3, 4, 5, 1, 2);
    Convolution.im2col(input, kernelRows, kernelCols, strideRows, strideCols, padTop, padLeft, true,
                    stridedB);
    assertEquals(expected, stridedB);

    // Smoke test only: col2im with matching shapes must run; its values are printed, not asserted.
    INDArray imageBuffer = Nd4j.create(input.shape());
    INDArray restored = Convolution.col2im(columns, imageBuffer, strideRows, strideCols, padTop, padLeft,
                    imgRows, imgCols);
    System.out.println(Arrays.toString(restored.data().asDouble()));
}
/**
 * Checks {@code Convolution.col2im} against hand-computed values when the padding is derived from the
 * "same" rule (output size = ceil(input size / stride)) with a stride of 2.
 * <p>
 * Setup: minibatch 1, depth 2, input 3x4 (height x width), kernel 3x3, stride 2x2.
 */
@Test
public void testCol2ImSamePaddingStride2() {
    int miniBatch = 1;
    int depth = 2;
    int inH = 3;
    int inW = 4;
    int kH = 3;
    int kW = 3;
    int strideH = 2;
    int strideW = 2;

    //Same padding: outH = ceil(inH / strideH), outW = ceil(inW / strideW)
    int outH = (int) Math.ceil((double) inH / strideH);
    int outW = (int) Math.ceil((double) inW / strideW);
    assertEquals(2, outH); //ceil(3/2) = 2
    assertEquals(2, outW); //ceil(4/2) = 2

    //Total padding per dimension; any odd remainder goes to the bottom/right side
    int totalPadH = (outH - 1) * strideH + kH - inH;
    int padTop = totalPadH / 2;
    int padBottom = totalPadH - padTop;
    assertEquals(1, padTop);
    assertEquals(1, padBottom);

    int totalPadW = (outW - 1) * strideW + kW - inW;
    int padLeft = totalPadW / 2;
    int padRight = totalPadW - padLeft;
    assertEquals(0, padLeft);
    assertEquals(1, padRight);

    System.out.println("Output size: " + outH + ", " + outW);
    System.out.println("Pad top/bottom: " + padTop + "\t" + padBottom);
    System.out.println("Pad left/right: " + padLeft + "\t" + padRight);

    /*
    ----- Image the column data below was derived from -----
    example 0:
    depth 0                depth 1
    [ 0  1  2  3           [12 13 14 15
      4  5  6  7            16 17 18 19
      8  9 10 11]           20 21 22 23]

    ----- Column data (col2im input) -----
    Shape: [miniBatch,depth,kH,kW,outH,outW]; one 3x3 window per output position (h,w)
    depth 0                          depth 1
    h0,w0       h0,w1                h0,w0       h0,w1
    0 0 0       0 0 0                 0  0  0     0  0  0
    0 1 2       2 3 0                12 13 14    14 15  0
    4 5 6       6 7 0                16 17 18    18 19  0
    h1,w0       h1,w1                h1,w0       h1,w1
    4 5  6       6  7 0              16 17 18    18 19  0
    8 9 10      10 11 0              20 21 22    22 23  0
    0 0  0       0  0 0               0  0  0     0  0  0

    ----- Expected col2im result -----
    Values that appear in more than one window are summed:
    depth 0                depth 1
    [ 0  1  4  3           [12 13 28 15
      8 10 24 14            32 34 72 38
      8  9 20 11]           20 21 44 23]
    */

    //Window values, indexed as [depth][outH][outW][kH][kW]
    double[][][][][] windows = {
        { //depth 0
            {{{0, 0, 0}, {0, 1, 2}, {4, 5, 6}}, {{0, 0, 0}, {2, 3, 0}, {6, 7, 0}}},
            {{{4, 5, 6}, {8, 9, 10}, {0, 0, 0}}, {{6, 7, 0}, {10, 11, 0}, {0, 0, 0}}}
        },
        { //depth 1
            {{{0, 0, 0}, {12, 13, 14}, {16, 17, 18}}, {{0, 0, 0}, {14, 15, 0}, {18, 19, 0}}},
            {{{16, 17, 18}, {20, 21, 22}, {0, 0, 0}}, {{18, 19, 0}, {22, 23, 0}, {0, 0, 0}}}
        }
    };

    INDArray col6d = Nd4j.create(new int[] {miniBatch, depth, kH, kW, outH, outW}, 'c');
    for (int d = 0; d < depth; d++) {
        for (int oh = 0; oh < outH; oh++) {
            for (int ow = 0; ow < outW; ow++) {
                INDArrayIndex[] windowAt = {NDArrayIndex.point(0), NDArrayIndex.point(d), NDArrayIndex.all(),
                                NDArrayIndex.all(), NDArrayIndex.point(oh), NDArrayIndex.point(ow)};
                col6d.put(windowAt, Nd4j.create(windows[d][oh][ow]));
            }
        }
    }

    //Expected image values, indexed as [depth][height][width]
    double[][][] summedImage = {
        {{0, 1, 4, 3}, {8, 10, 24, 14}, {8, 9, 20, 11}},
        {{12, 13, 28, 15}, {32, 34, 72, 38}, {20, 21, 44, 23}}
    };

    INDArray expectedImage = Nd4j.create(miniBatch, depth, inH, inW);
    for (int d = 0; d < depth; d++) {
        INDArrayIndex[] channelAt = {NDArrayIndex.point(0), NDArrayIndex.point(d), NDArrayIndex.all(),
                        NDArrayIndex.all()};
        expectedImage.put(channelAt, Nd4j.create(summedImage[d]));
    }

    INDArray resultBuffer = Nd4j.create(miniBatch, depth, inH, inW);
    INDArray actualImage =
                    Convolution.col2im(col6d, resultBuffer, strideH, strideW, padTop, padLeft, inH, inW);
    assertEquals(expectedImage, actualImage);
}
/**
 * Verifies that {@code Convolution.outSize(2, 1, 1, 2, false)} evaluates to 6.
 */
@Test
public void testConvOutWidthAndHeight() {
    // NOTE(review): arguments are presumably (size, kernel, stride, padding, flag) - confirm against Convolution.outSize
    assertEquals(6, Convolution.outSize(2, 1, 1, 2, false));
}
/**
 * Smoke test for {@code Convolution.im2col}: runs it on a 2x2x2x2 array holding the values 1..16 and prints
 * the result. Nothing is asserted; the test only fails if the call throws.
 */
@Test
public void testIm2Col() {
    INDArray image = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
    System.out.println(Convolution.im2col(image, 1, 1, 1, 1, 2, 2, 0, false));
}
/**
 * Compares {@code Convolution.im2col} with {@code OldConvolution.im2col} on random input over a grid of
 * minibatch sizes, depths, input heights/widths, strides, kernel sizes and paddings, for both values of the
 * trailing boolean flag, and for each configured data type / allocation mode pair.
 * <p>
 * Combinations where {@code (w - kw + 2 * pw)} is not divisible by {@code sw}, or {@code (h - kh + 2 * ph)} is
 * not divisible by {@code sh}, are skipped.
 * <p>
 * The test modifies the global data type and {@code Nd4j.alloc}. Both are captured up front and restored in a
 * {@code finally} block, so a failing assertion cannot leak the modified settings into other tests.
 */
@Test
@Ignore
public void testCompareIm2ColImpl() {
    int[] miniBatches = {1, 3, 5};
    int[] depths = {1, 3, 5};
    int[] inHeights = {5, 21};
    int[] inWidths = {5, 21};
    int[] strideH = {1, 2};
    int[] strideW = {1, 2};
    int[] sizeW = {1, 2, 3};
    int[] sizeH = {1, 2, 3};
    int[] padH = {0, 1, 2};
    int[] padW = {0, 1, 2};
    boolean[] coverall = {false, true};

    //types[i] is paired with modes[i]
    DataBuffer.Type[] types = new DataBuffer.Type[] {DataBuffer.Type.FLOAT, DataBuffer.Type.DOUBLE,
                    DataBuffer.Type.FLOAT, DataBuffer.Type.DOUBLE};
    DataBuffer.AllocationMode[] modes =
                    new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.HEAP, DataBuffer.AllocationMode.HEAP,
                                    DataBuffer.AllocationMode.DIRECT, DataBuffer.AllocationMode.DIRECT};

    String factoryClassName = Nd4j.factory().getClass().toString().toLowerCase();
    if (factoryClassName.contains("jcublas") || factoryClassName.contains("cuda")) {
        //Only test direct for CUDA; test all for CPU
        types = new DataBuffer.Type[] {DataBuffer.Type.FLOAT, DataBuffer.Type.DOUBLE};
        modes = new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.DIRECT,
                        DataBuffer.AllocationMode.DIRECT};
    }

    //Remember the global settings this test overwrites so they can be put back afterwards
    DataBuffer.Type initialType = Nd4j.dataType();
    DataBuffer.AllocationMode initialAlloc = Nd4j.alloc;
    try {
        for (int i = 0; i < types.length; i++) {
            DataBuffer.Type type = types[i];
            DataBuffer.AllocationMode mode = modes[i];

            DataTypeUtil.setDTypeForContext(type);
            Nd4j.alloc = mode;
            // NOTE(review): the previous AllocUtil context mode is not captured here, so it is not restored below - confirm whether that matters
            AllocUtil.setAllocationModeForContext(mode);

            for (int m : miniBatches) {
                for (int d : depths) {
                    for (int h : inHeights) {
                        for (int w : inWidths) {
                            for (int sh : strideH) {
                                for (int sw : strideW) {
                                    for (int kh : sizeH) {
                                        for (int kw : sizeW) {
                                            for (int ph : padH) {
                                                for (int pw : padW) {
                                                    if ((w - kw + 2 * pw) % sw != 0 || (h - kh + 2 * ph) % sh != 0) {
                                                        continue; //(w-kp+2*pw)/sw + 1 is not an integer, i.e., number of outputs doesn't fit
                                                    }

                                                    System.out.println("Running " + m + " " + d + " " + h + " " + w);
                                                    for (boolean cAll : coverall) {
                                                        INDArray in = Nd4j.rand(new int[] {m, d, h, w});
                                                        //assertEquals(in.data().allocationMode(), mode);
                                                        //assertEquals(in.data().dataType(), type);
                                                        INDArray outOrig = OldConvolution.im2col(in, kh, kw, sh, sw, ph,
                                                                        pw, -1, cAll); //Old implementation
                                                        INDArray outNew = Convolution.im2col(in, kh, kw, sh, sw, ph, pw,
                                                                        cAll); //Current implementation

                                                        assertArrayEquals(outOrig.data().asFloat(),
                                                                        outNew.data().asFloat(), 0.01f);
                                                        assertEquals(outOrig, outNew);
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    } finally {
        //Runs on success and on assertion failure alike
        Nd4j.alloc = initialAlloc;
        DataTypeUtil.setDTypeForContext(initialType);
    }
}
/**
 * Compares {@code Convolution.col2im} with {@code OldConvolution.col2im} over a grid of minibatch sizes,
 * depths, input heights/widths, strides, kernel sizes and paddings, for each configured data type /
 * allocation mode pair. The column input is produced by running {@code Convolution.im2col} on random data.
 * <p>
 * Combinations where {@code (w - kw + 2 * pw)} is not divisible by {@code sw}, or {@code (h - kh + 2 * ph)} is
 * not divisible by {@code sh}, are skipped.
 * <p>
 * The test modifies the global data type and {@code Nd4j.alloc}. Both are captured up front and restored in a
 * {@code finally} block, so a failing assertion cannot leak the modified settings into other tests.
 *
 * @throws Exception declared by the original signature; nothing in the body throws a checked exception
 */
@Test
@Ignore
public void testCompareIm2Col() throws Exception {
    int[] miniBatches = {1, 3, 5};
    int[] depths = {1, 3, 5};
    int[] inHeights = {5, 21};
    int[] inWidths = {5, 21};
    int[] strideH = {1, 2};
    int[] strideW = {1, 2};
    int[] sizeW = {1, 2, 3};
    int[] sizeH = {1, 2, 3};
    int[] padH = {0, 1, 2};
    int[] padW = {0, 1, 2};

    //types[i] is paired with modes[i]
    DataBuffer.Type[] types = new DataBuffer.Type[] {DataBuffer.Type.FLOAT, DataBuffer.Type.DOUBLE,
                    DataBuffer.Type.FLOAT, DataBuffer.Type.DOUBLE};
    DataBuffer.AllocationMode[] modes =
                    new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.HEAP, DataBuffer.AllocationMode.HEAP,
                                    DataBuffer.AllocationMode.DIRECT, DataBuffer.AllocationMode.DIRECT};

    String factoryClassName = Nd4j.factory().getClass().toString().toLowerCase();
    if (factoryClassName.contains("jcublas") || factoryClassName.contains("cuda")) {
        //Only test direct for CUDA; test all for CPU
        types = new DataBuffer.Type[] {DataBuffer.Type.FLOAT, DataBuffer.Type.DOUBLE};
        modes = new DataBuffer.AllocationMode[] {DataBuffer.AllocationMode.DIRECT,
                        DataBuffer.AllocationMode.DIRECT};
    }

    //Remember the global settings this test overwrites so they can be put back afterwards
    DataBuffer.Type initialType = Nd4j.dataType();
    DataBuffer.AllocationMode initialAlloc = Nd4j.alloc;
    try {
        for (int i = 0; i < types.length; i++) {
            DataBuffer.Type type = types[i];
            DataBuffer.AllocationMode mode = modes[i];

            DataTypeUtil.setDTypeForContext(type);
            Nd4j.alloc = mode;

            for (int m : miniBatches) {
                for (int d : depths) {
                    for (int h : inHeights) {
                        for (int w : inWidths) {
                            for (int sh : strideH) {
                                for (int sw : strideW) {
                                    for (int kh : sizeH) {
                                        for (int kw : sizeW) {
                                            for (int ph : padH) {
                                                for (int pw : padW) {
                                                    System.out.println("Before assertion");
                                                    if ((w - kw + 2 * pw) % sw != 0 || (h - kh + 2 * ph) % sh != 0) {
                                                        continue; //(w-kp+2*pw)/sw + 1 is not an integer, i.e., number of outputs doesn't fit
                                                    }

                                                    INDArray in = Nd4j.rand(new int[] {m, d, h, w});
                                                    //JUnit convention: expected value first, actual value second
                                                    assertEquals(mode, in.data().allocationMode());
                                                    assertEquals(type, in.data().dataType());

                                                    INDArray im2col = Convolution.im2col(in, kh, kw, sh, sw, ph, pw,
                                                                    false); //Cheating, to get correct shape for input

                                                    INDArray imgOutOld =
                                                                    OldConvolution.col2im(im2col, sh, sw, ph, pw, h, w);
                                                    INDArray imgOutNew =
                                                                    Convolution.col2im(im2col, sh, sw, ph, pw, h, w);
                                                    System.out.println("F order test");
                                                    System.out.println(imgOutOld);
                                                    System.out.println(imgOutNew);
                                                    assertEquals(imgOutOld, imgOutNew);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    } finally {
        //Runs on success and on assertion failure alike
        Nd4j.alloc = initialAlloc;
        DataTypeUtil.setDTypeForContext(initialType);
    }
}
/**
 * Runs {@code Convolution.col2im} and {@code OldConvolution.col2im} on the same 2x2x2x2x2x2 array holding the
 * values 1..64 and checks that both produce equal results.
 */
@Test
public void testCol2Im() {
    int strideY = 1;
    int strideX = 1;
    int padH = 1;
    int padW = 1;
    int imgH = 2;
    int imgW = 2;

    INDArray columns = Nd4j.linspace(1, 64, 64).reshape(2, 2, 2, 2, 2, 2);

    INDArray actual = Convolution.col2im(columns, strideY, strideX, padH, padW, imgH, imgW);
    INDArray reference = OldConvolution.col2im(columns, strideY, strideX, padH, padW, imgH, imgW);

    System.out.println("Ordering of the result, new test: " + actual.ordering());
    System.out.println("Assertion dimensions: " + Arrays.toString(reference.shape()));
    assertEquals(reference, actual);
}
/**
 * Round trip check: {@code Convolution.im2col} must match {@code OldConvolution.im2col} on a 2x3x7x7 array of
 * consecutive values, and {@code Convolution.col2im} applied to that result must match
 * {@code OldConvolution.col2im}.
 */
@Test
public void testimcolim() {
    int[] kernel = {3, 2};
    int[] stride = {2, 3};
    int[] padding = {1, 2};

    int numExamples = 2;
    int channels = 3;
    int w = 7;
    int h = 7;
    int total = numExamples * channels * w * h;
    INDArray image = Nd4j.linspace(1, total, total).reshape(numExamples, channels, w, h);

    //Forward direction: image -> columns
    INDArray referenceColumns = OldConvolution.im2col(image, kernel, stride, padding);
    INDArray columns = Convolution.im2col(image, kernel, stride, padding);
    assertEquals(referenceColumns, columns);

    //Backward direction: columns -> image
    INDArray referenceImage = OldConvolution.col2im(columns, stride, padding, h, w);
    INDArray rebuiltImage = Convolution.col2im(columns, stride, padding, h, w);
    assertEquals(referenceImage, rebuiltImage);
}
/**
 * Ordering character reported by this test class.
 *
 * @return always {@code 'f'}
 */
@Override
public char ordering() {
    final char order = 'f';
    return order;
}
}