/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.matrix.data; import java.io.Serializable; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.util.ConvolutionUtils; /** * This class is container that stores parameters required for executing following operations: * conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, maxpooling_backward */ public class ConvolutionParameters implements Serializable { private static final long serialVersionUID = -212362627205772829L; public int N; public int C; public int H; public int W; public int K; public int R; public int S; public int stride_h; public int stride_w; public int pad_h; public int pad_w; public int P; public int Q; public int numThreads; public boolean enableNative = false; public MatrixBlock input1; public MatrixBlock input2; public MatrixBlock output; public MatrixBlock bias; public int [] start_indexes_h, end_indexes_h, start_indexes_w, end_indexes_w; private int convertToInt(long val) throws DMLRuntimeException { if( val > Integer.MAX_VALUE ) { throw new DMLRuntimeException("The value for ConvolutionParameters is too large:" + val); } return (int) val; } public boolean compare(ConvolutionParameters that) { if(this.N == that.N && this.C == that.C && this.H == that.H && this.W == that.W && this.K == that.K && this.R == that.R && this.S == that.S && this.stride_h == that.stride_h && this.stride_w == that.stride_w && this.pad_h == that.pad_h && this.pad_w == that.pad_w && this.numThreads == that.numThreads) { return true; } return false; } public String toString() { return "(" + N + " " + C + " " + H + " " + W + " " + K + " " + R + " " + S + ")"; } public ConvolutionParameters(long N, long C, long H, long W, long K, long R, long S, long stride_h, long stride_w, long pad_h, long pad_w, int numThreads) throws DMLRuntimeException { this.N = convertToInt(N); this.C = convertToInt(C); this.H = convertToInt(H); this.W = convertToInt(W); this.K = convertToInt(K); this.R = convertToInt(R); this.S = convertToInt(S); this.stride_h = convertToInt(stride_h); this.stride_w = convertToInt(stride_w); this.pad_h = convertToInt(pad_h); this.pad_w = convertToInt(pad_w); if(H >= 0 && pad_h >= 0 && R >= 0 && stride_h >= 0) P = (int) ((H + 2 * pad_h - R) / stride_h + 1); else P = -1; // P = convertToInt(ConvolutionUtils.getP(H, R, stride_h, pad_h)); if(W >= 0 && pad_w >= 0 && S >= 0 && stride_w >= 0) Q = (int) ((W + 2 * pad_w - S) / stride_w + 1); else Q = -1; // Q = convertToInt(ConvolutionUtils.getQ(W, S, stride_w, pad_w)); this.numThreads = numThreads; } public ConvolutionParameters(int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int numThreads) { this.N = N; this.C = C; this.H = H; this.W = W; this.K = K; this.R = R; this.S = S; this.stride_h = stride_h; this.stride_w = stride_w; this.pad_h = pad_h; this.pad_w = pad_w; P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h); Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w); this.numThreads = numThreads; } public boolean isOutputThreadSafe() { return output.isThreadSafe(); } }