}
template <typename Dtype>
-inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
+inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {
+ CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv));
+}
+
+template <typename Dtype>
+inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter,
int pad_h, int pad_w, int stride_h, int stride_w) {
- CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv));
CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
}