separate setTensor4dDesc from createTensor4dDesc
authorJonathan L Long <jonlong@cs.berkeley.edu>
Thu, 11 Sep 2014 04:51:58 +0000 (21:51 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 18 Sep 2014 19:41:46 +0000 (12:41 -0700)
This will make it possible to add reshaping to cuDNN layers.

include/caffe/util/cudnn.hpp

index e7ddea7..aef3e21 100644 (file)
@@ -56,22 +56,26 @@ template<> class dataType<double> {
 };
 
 template <typename Dtype>
-inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
+inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) {
+  CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
+}
+
+template <typename Dtype>
+inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
     int n, int c, int h, int w,
     int stride_n, int stride_c, int stride_h, int stride_w) {
-  CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
   CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
       n, c, h, w, stride_n, stride_c, stride_h, stride_w));
 }
 
 template <typename Dtype>
-inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
+inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
     int n, int c, int h, int w) {
   const int stride_w = 1;
   const int stride_h = w * stride_w;
   const int stride_c = h * stride_h;
   const int stride_n = c * stride_c;
-  createTensor4dDesc<Dtype>(desc, n, c, h, w,
+  setTensor4dDesc<Dtype>(desc, n, c, h, w,
       stride_n, stride_c, stride_h, stride_w);
 }