};
template <typename Dtype>
-inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
+inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) {  // Dtype is absent from the signature, so callers must name it explicitly
+ CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));  // creation only; type, dims and strides are applied by setTensor4dDesc
+}
+
+template <typename Dtype>
+inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,  // explicit-stride overload: only sets type, dims and strides on *desc; no create call here
int n, int c, int h, int w,
int stride_n, int stride_c, int stride_h, int stride_w) {
- CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}
template <typename Dtype>
-inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
+inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,  // overload without strides: computes them from c, h, w with stride_w = 1
int n, int c, int h, int w) {
const int stride_w = 1;
const int stride_h = w * stride_w;
const int stride_c = h * stride_h;
const int stride_n = c * stride_c;
- createTensor4dDesc<Dtype>(desc, n, c, h, w,
+ setTensor4dDesc<Dtype>(desc, n, c, h, w,  // forward to the explicit-stride overload
stride_n, stride_c, stride_h, stride_w);
}