StreamExecutor support for float64 convolutions and backprop.
author    Brian Patton <bjp@google.com>
Tue, 6 Mar 2018 16:23:04 +0000 (08:23 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 16:29:31 +0000 (08:29 -0800)
PiperOrigin-RevId: 188025477

tensorflow/stream_executor/cuda/cuda_dnn.cc
tensorflow/stream_executor/cuda/cuda_dnn.h
tensorflow/stream_executor/dnn.h
tensorflow/stream_executor/stream.cc
tensorflow/stream_executor/stream.h

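This change replaces the cuDNN plugin's "double-based DNN not yet implemented" stub for forward convolution with a real implementation, adds double overloads for the backward-data and backward-filter passes, and threads the scratch allocator, algorithm config, and profiling arguments through the dnn.h interface and the Stream API. A minimal caller sketch against the new Stream overload (hypothetical variable names; descriptor setup, buffer allocation, and stream creation are elided, and a null ScratchAllocator with a default AlgorithmConfig is assumed to be acceptable, as it is for the existing float overloads):

    dnn::BatchDescriptor input_desc, output_desc;   // dimensions set by the caller
    dnn::FilterDescriptor filter_desc;
    dnn::ConvolutionDescriptor conv_desc;
    DeviceMemory<double> input, filters, output;    // device buffers allocated by the caller
    stream->ThenConvolveWithAlgorithm(
        input_desc, input, filter_desc, filters, conv_desc, output_desc,
        &output, /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
        /*output_profile_result=*/nullptr);
    if (!stream->ok()) {
      LOG(ERROR) << "failed to enqueue float64 convolution";
    }
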
tensorflow/stream_executor/cuda/cuda_dnn.cc
index 61cf4ba..0b3b060 100644
@@ -2281,7 +2281,6 @@ struct ConvDoFP32ComputationFP16Input {
 
 // A group of helper functions to return the internal compute type for
 // convolutions in cudnn.
-// TODO(yangzihao): Add support for float64.
 template <typename T>
 cudnnDataType_t GetConvComputeType() {
   return CUDNN_DATA_FLOAT;
@@ -2296,6 +2295,11 @@ cudnnDataType_t GetConvComputeType<Eigen::half>() {
   }
 }
 
+template <>
+cudnnDataType_t GetConvComputeType<double>() {
+  return CUDNN_DATA_DOUBLE;
+}
+
 }  // namespace
 
 template <class T>
@@ -2324,9 +2328,15 @@ bool CudnnSupport::DoConvolveImpl(
     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   }
   // Alpha is the scaling factor for input.
-  float alpha = 1.0;
+  float falpha = 1.0;
+  double dalpha = 1.0;
+  void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
+                                                : static_cast<void*>(&falpha);
   // Beta is the scaling factor for output.
-  float beta = 0.0;
+  float fbeta = 0.0;
+  double dbeta = 0.0;
+  void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
+                                               : static_cast<void*>(&fbeta);
 
   const bool is_profiling = output_profile_result != nullptr;
   cudnnConvolutionFwdAlgo_t algo;
@@ -2464,11 +2474,11 @@ bool CudnnSupport::DoConvolveImpl(
   }
   status = wrap::cudnnConvolutionForward(
       parent_, ToHandle(dnn_handle_),
-      /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
+      /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
       /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
       /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
       /*algo=*/algo, /*workSpace=*/scratch.opaque(),
-      /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta,
+      /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta,
       /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
 
   if (is_profiling) {
@@ -2943,10 +2953,14 @@ bool CudnnSupport::DoConvolve(
     const FilterDescriptor& filter_descriptor,
     const DeviceMemory<double>& filter_data,
     const ConvolutionDescriptor& convolution_descriptor,
-    const BatchDescriptor& output_descriptor,
-    DeviceMemory<double>* output_data) {
-  LOG(ERROR) << "double-based DNN not yet implemented";
-  return false;
+    const BatchDescriptor& output_descriptor, DeviceMemory<double>* output_data,
+    ScratchAllocator* scratch_allocator,
+    const dnn::AlgorithmConfig& algorithm_config,
+    dnn::ProfileResult* output_profile_result) {
+  return DoConvolveImpl<double>(
+      stream, batch_descriptor, input_data, filter_descriptor, filter_data,
+      convolution_descriptor, output_descriptor, output_data, scratch_allocator,
+      algorithm_config, output_profile_result);
 }
 
 bool CudnnSupport::DoConvolve(
@@ -3151,10 +3165,17 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   }
 
+  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   // Alpha is the scaling factor for input.
-  float alpha = 1.0;
+  float falpha = 1.0;
+  double dalpha = 1.0;
+  void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
+                                                : static_cast<void*>(&falpha);
   // Beta is the scaling factor for output.
-  float beta = 0.0;
+  float fbeta = 0.0;
+  double dbeta = 0.0;
+  void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
+                                               : static_cast<void*>(&fbeta);
 
   // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
   BatchDescriptor output_descriptor;
@@ -3163,7 +3184,6 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
   backward_output_data = MaybeTransformLayout(
       stream, &output_descriptor, backward_output_data, &transform_scratch);
 
-  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type};
   ScopedTensorDescriptor in_back_nd{parent_, input_descriptor, cudnn_type};
   ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor,
@@ -3310,7 +3330,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
   status = wrap::cudnnConvolutionBackwardData_v3(
 #endif
       parent_, ToHandle(dnn_handle_),
-      /*alpha=*/&alpha,
+      /*alpha=*/alpha,
       /*filterDesc=*/filter.handle(),
       /*filterData=*/filter_data.opaque(),
       /*diffDesc=*/out_back_nd.handle(),
@@ -3319,7 +3339,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
       /*algo=*/algo,
       /*workSpace=*/scratch.opaque(),
       /*workSpaceSizeInBytes=*/scratch.size(),
-      /*beta=*/&beta,
+      /*beta=*/beta,
       /*gradDesc=*/in_back_nd.handle(),
       /*gradData=*/backward_input_data->opaque());
   if (is_profiling) {
@@ -3346,8 +3366,26 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
 
 bool CudnnSupport::DoConvolveBackwardData(
     Stream* stream, const FilterDescriptor& filter_descriptor,
+    const DeviceMemory<double>& filter_data,
+    const BatchDescriptor& output_descriptor,
+    DeviceMemory<double> backward_output_data,
+    const ConvolutionDescriptor& convolution_descriptor,
+    const BatchDescriptor& input_descriptor,
+    DeviceMemory<double>* backward_input_data,
+    ScratchAllocator* scratch_allocator,
+    const dnn::AlgorithmConfig& algorithm_config,
+    dnn::ProfileResult* output_profile_result) {
+  return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
+                                    output_descriptor, backward_output_data,
+                                    convolution_descriptor, input_descriptor,
+                                    backward_input_data, scratch_allocator,
+                                    algorithm_config, output_profile_result);
+}
+
+bool CudnnSupport::DoConvolveBackwardData(
+    Stream* stream, const FilterDescriptor& filter_descriptor,
     const DeviceMemory<float>& filter_data,
-    const BatchDescriptor& output_descriptor_in,
+    const BatchDescriptor& output_descriptor,
     DeviceMemory<float> backward_output_data,
     const ConvolutionDescriptor& convolution_descriptor,
     const BatchDescriptor& input_descriptor,
@@ -3356,7 +3394,7 @@ bool CudnnSupport::DoConvolveBackwardData(
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
-                                    output_descriptor_in, backward_output_data,
+                                    output_descriptor, backward_output_data,
                                     convolution_descriptor, input_descriptor,
                                     backward_input_data, scratch_allocator,
                                     algorithm_config, output_profile_result);
@@ -3365,7 +3403,7 @@ bool CudnnSupport::DoConvolveBackwardData(
 bool CudnnSupport::DoConvolveBackwardData(
     Stream* stream, const FilterDescriptor& filter_descriptor,
     const DeviceMemory<Eigen::half>& filter_data,
-    const BatchDescriptor& output_descriptor_in,
+    const BatchDescriptor& output_descriptor,
     DeviceMemory<Eigen::half> backward_output_data,
     const ConvolutionDescriptor& convolution_descriptor,
     const BatchDescriptor& input_descriptor,
@@ -3374,7 +3412,7 @@ bool CudnnSupport::DoConvolveBackwardData(
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
-                                    output_descriptor_in, backward_output_data,
+                                    output_descriptor, backward_output_data,
                                     convolution_descriptor, input_descriptor,
                                     backward_input_data, scratch_allocator,
                                     algorithm_config, output_profile_result);
@@ -3398,10 +3436,17 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
     LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
   }
 
+  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   // Alpha is the scaling factor for input.
-  float alpha = 1.0;
+  float falpha = 1.0;
+  double dalpha = 1.0;
+  void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
+                                                : static_cast<void*>(&falpha);
   // Beta is the scaling factor for output.
-  float beta = 0.0;
+  float fbeta = 0.0;
+  double dbeta = 0.0;
+  void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
+                                               : static_cast<void*>(&fbeta);
 
   // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
   BatchDescriptor output_descriptor;
@@ -3410,7 +3455,6 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
   backward_output_data = MaybeTransformLayout(
       stream, &output_descriptor, backward_output_data, &transform_scratch);
 
-  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
   ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type};
   ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type};
   ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor,
@@ -3557,7 +3601,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
 #else
   status = wrap::cudnnConvolutionBackwardFilter_v3(
 #endif
-      parent_, ToHandle(dnn_handle_), /*alpha=*/&alpha,
+      parent_, ToHandle(dnn_handle_), /*alpha=*/alpha,
       /*srcDesc=*/input_nd.handle(),
       /*srcData=*/input_data.opaque(),
       /*diffDesc=*/out_back_nd.handle(),
@@ -3566,7 +3610,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
       /*algo=*/algo,
       /*workSpace=*/scratch.opaque(),
       /*workSpaceSizeInBytes=*/scratch.size(),
-      /*beta=*/&beta,
+      /*beta=*/beta,
       /*gradDesc=*/filter.handle(),
       /*gradData=*/backward_filter_data->opaque());
 
@@ -3594,8 +3638,26 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
 
 bool CudnnSupport::DoConvolveBackwardFilter(
     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+    const DeviceMemory<double>& input_data,
+    const dnn::BatchDescriptor& output_descriptor,
+    DeviceMemory<double> backward_output_data,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
+    const dnn::FilterDescriptor& filter_descriptor,
+    DeviceMemory<double>* backward_filter_data,
+    ScratchAllocator* scratch_allocator,
+    const dnn::AlgorithmConfig& algorithm_config,
+    dnn::ProfileResult* output_profile_result) {
+  return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+                                      output_descriptor, backward_output_data,
+                                      convolution_descriptor, filter_descriptor,
+                                      backward_filter_data, scratch_allocator,
+                                      algorithm_config, output_profile_result);
+}
+
+bool CudnnSupport::DoConvolveBackwardFilter(
+    Stream* stream, const dnn::BatchDescriptor& input_descriptor,
     const DeviceMemory<float>& input_data,
-    const dnn::BatchDescriptor& output_descriptor_in,
+    const dnn::BatchDescriptor& output_descriptor,
     DeviceMemory<float> backward_output_data,
     const dnn::ConvolutionDescriptor& convolution_descriptor,
     const dnn::FilterDescriptor& filter_descriptor,
@@ -3603,17 +3665,17 @@ bool CudnnSupport::DoConvolveBackwardFilter(
     ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
-  return DoConvolveBackwardFilterImpl(
-      stream, input_descriptor, input_data, output_descriptor_in,
-      backward_output_data, convolution_descriptor, filter_descriptor,
-      backward_filter_data, scratch_allocator, algorithm_config,
-      output_profile_result);
+  return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+                                      output_descriptor, backward_output_data,
+                                      convolution_descriptor, filter_descriptor,
+                                      backward_filter_data, scratch_allocator,
+                                      algorithm_config, output_profile_result);
 }
 
 bool CudnnSupport::DoConvolveBackwardFilter(
     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
     const DeviceMemory<Eigen::half>& input_data,
-    const dnn::BatchDescriptor& output_descriptor_in,
+    const dnn::BatchDescriptor& output_descriptor,
     DeviceMemory<Eigen::half> backward_output_data,
     const dnn::ConvolutionDescriptor& convolution_descriptor,
     const dnn::FilterDescriptor& filter_descriptor,
@@ -3621,11 +3683,11 @@ bool CudnnSupport::DoConvolveBackwardFilter(
     ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
-  return DoConvolveBackwardFilterImpl(
-      stream, input_descriptor, input_data, output_descriptor_in,
-      backward_output_data, convolution_descriptor, filter_descriptor,
-      backward_filter_data, scratch_allocator, algorithm_config,
-      output_profile_result);
+  return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+                                      output_descriptor, backward_output_data,
+                                      convolution_descriptor, filter_descriptor,
+                                      backward_filter_data, scratch_allocator,
+                                      algorithm_config, output_profile_result);
 }
 
 template <class T>
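
Note on the alpha/beta changes in the three code paths above: cuDNN reads the scaling factors through a generic pointer whose pointee width must match the convolution compute type, so for CUDNN_DATA_DOUBLE the pointer has to reference a double rather than a float. That is also why the cudnn_type lookup is hoisted above the alpha/beta setup in the backward paths. In isolation the pattern is (a sketch, not the exact plugin code):

    float falpha = 1.0f, fbeta = 0.0f;   // consumed when computing in FLOAT or HALF
    double dalpha = 1.0, dbeta = 0.0;    // consumed when computing in DOUBLE
    const bool use_double = (cudnn_type == CUDNN_DATA_DOUBLE);
    void* alpha = use_double ? static_cast<void*>(&dalpha) : static_cast<void*>(&falpha);
    void* beta = use_double ? static_cast<void*>(&dbeta) : static_cast<void*>(&fbeta);
    // alpha and beta are then passed as-is (no extra address-of) to
    // cudnnConvolutionForward / cudnnConvolutionBackwardData / cudnnConvolutionBackwardFilter.
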
tensorflow/stream_executor/cuda/cuda_dnn.h
index 40aa974..48d56f7 100644
@@ -259,7 +259,10 @@ class CudnnSupport : public dnn::DnnSupport {
                   const DeviceMemory<double>& filter_data,
                   const dnn::ConvolutionDescriptor& convolution_descriptor,
                   const dnn::BatchDescriptor& output_descriptor,
-                  DeviceMemory<double>* output_data) override;
+                  DeviceMemory<double>* output_data,
+                  ScratchAllocator* scratch_allocator,
+                  const dnn::AlgorithmConfig& algorithm_config,
+                  dnn::ProfileResult* output_profile_result) override;
 
   bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
                   const DeviceMemory<Eigen::half>& input_data,
@@ -373,6 +376,18 @@ class CudnnSupport : public dnn::DnnSupport {
 
   bool DoConvolveBackwardData(
       Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
+      const DeviceMemory<double>& filter_data,
+      const dnn::BatchDescriptor& output_descriptor,
+      DeviceMemory<double> backward_output_data,
+      const dnn::ConvolutionDescriptor& convolution_descriptor,
+      const dnn::BatchDescriptor& input_descriptor,
+      DeviceMemory<double>* backward_input_data,
+      ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result) override;
+
+  bool DoConvolveBackwardData(
+      Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
       const DeviceMemory<float>& filter_data,
       const dnn::BatchDescriptor& output_descriptor,
       DeviceMemory<float> backward_output_data,
@@ -397,6 +412,18 @@ class CudnnSupport : public dnn::DnnSupport {
 
   bool DoConvolveBackwardFilter(
       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+      const DeviceMemory<double>& input_data,
+      const dnn::BatchDescriptor& output_descriptor,
+      DeviceMemory<double> backward_output_data,
+      const dnn::ConvolutionDescriptor& convolution_descriptor,
+      const dnn::FilterDescriptor& filter_descriptor,
+      DeviceMemory<double>* backward_filter_data,
+      ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result) override;
+
+  bool DoConvolveBackwardFilter(
+      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
       const DeviceMemory<float>& input_data,
       const dnn::BatchDescriptor& output_descriptor,
       DeviceMemory<float> backward_output_data,
tensorflow/stream_executor/dnn.h
index aa88fe7..b41536e 100644
@@ -1172,7 +1172,9 @@ class DnnSupport {
       const DeviceMemory<double>& filter_data,
       const dnn::ConvolutionDescriptor& convolution_descriptor,
       const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory<double>* output_data) = 0;
+      DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result) = 0;
 
   // Enqueues a half-precision convolution operation onto the stream.
   // See DoConvolve above for argument details.
@@ -1275,6 +1277,18 @@ class DnnSupport {
 
   virtual bool DoConvolveBackwardData(
       Stream* stream, const FilterDescriptor& filter_descriptor,
+      const DeviceMemory<double>& filter_data,
+      const BatchDescriptor& output_descriptor,
+      DeviceMemory<double> backward_output_data,
+      const ConvolutionDescriptor& convolution_descriptor,
+      const BatchDescriptor& input_descriptor,
+      DeviceMemory<double>* backward_input_data,
+      ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      ProfileResult* output_profile_result) = 0;
+
+  virtual bool DoConvolveBackwardData(
+      Stream* stream, const FilterDescriptor& filter_descriptor,
       const DeviceMemory<Eigen::half>& filter_data,
       const BatchDescriptor& output_descriptor,
       DeviceMemory<Eigen::half> backward_output_data,
@@ -1324,6 +1338,18 @@ class DnnSupport {
 
   virtual bool DoConvolveBackwardFilter(
       Stream* stream, const BatchDescriptor& input_descriptor,
+      const DeviceMemory<double>& input_data,
+      const BatchDescriptor& output_descriptor,
+      DeviceMemory<double> backward_output_data,
+      const ConvolutionDescriptor& convolution_descriptor,
+      const FilterDescriptor& filter_descriptor,
+      DeviceMemory<double>* backward_filter_data,
+      ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      ProfileResult* output_profile_result) = 0;
+
+  virtual bool DoConvolveBackwardFilter(
+      Stream* stream, const BatchDescriptor& input_descriptor,
       const DeviceMemory<Eigen::half>& input_data,
       const BatchDescriptor& output_descriptor,
       DeviceMemory<Eigen::half> backward_output_data,
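
Because the double variants are introduced as pure virtual methods on dnn::DnnSupport (and the existing double DoConvolve signature changes), any backend deriving from dnn::DnnSupport must now provide matching overloads. A minimal in-class stub for a hypothetical out-of-tree backend could mirror the cuDNN plugin's previous behavior:

    // Hypothetical backend class body; reports lack of float64 support.
    bool DoConvolveBackwardData(
        Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
        const DeviceMemory<double>& filter_data,
        const dnn::BatchDescriptor& output_descriptor,
        DeviceMemory<double> backward_output_data,
        const dnn::ConvolutionDescriptor& convolution_descriptor,
        const dnn::BatchDescriptor& input_descriptor,
        DeviceMemory<double>* backward_input_data,
        ScratchAllocator* scratch_allocator,
        const dnn::AlgorithmConfig& algorithm_config,
        dnn::ProfileResult* output_profile_result) override {
      LOG(ERROR) << "double-based DNN not yet implemented";
      return false;
    }
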
tensorflow/stream_executor/stream.cc
index ba5001e..4d852e6 100644
@@ -683,6 +683,37 @@ Stream &Stream::ThenFusedConvolveWithAlgorithm(
 
 Stream &Stream::ThenConvolveWithAlgorithm(
     const dnn::BatchDescriptor &input_descriptor,
+    const DeviceMemory<double> &input_data,
+    const dnn::FilterDescriptor &filter_descriptor,
+    const DeviceMemory<double> &filter_data,
+    const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
+    ScratchAllocator *scratch_allocator,
+    const dnn::AlgorithmConfig &algorithm_config,
+    dnn::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+            PARAM(filter_descriptor), PARAM(filter_data),
+            PARAM(convolution_descriptor), PARAM(output_descriptor),
+            PARAM(output), PARAM(algorithm_config));
+
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      auto status = dnn->DoConvolve(
+          this, input_descriptor, input_data, filter_descriptor, filter_data,
+          convolution_descriptor, output_descriptor, output, scratch_allocator,
+          algorithm_config, output_profile_result);
+      if (!status && !output_profile_result) {
+        SetError();
+      }
+    } else {
+      SetErrorAndLogNoDnnSupport();
+    }
+  }
+  return *this;
+}
+
+Stream &Stream::ThenConvolveWithAlgorithm(
+    const dnn::BatchDescriptor &input_descriptor,
     const DeviceMemory<float> &input_data,
     const dnn::FilterDescriptor &filter_descriptor,
     const DeviceMemory<float> &filter_data,
@@ -892,6 +923,39 @@ Stream &Stream::ThenConvolveBackwardDataWithScratch(
 
 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
     const dnn::FilterDescriptor &filter_descriptor,
+    const DeviceMemory<double> &filter_data,
+    const dnn::BatchDescriptor &output_descriptor,
+    DeviceMemory<double> backward_output_data,
+    const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const dnn::BatchDescriptor &input_descriptor,
+    DeviceMemory<double> *backward_input_data,
+    ScratchAllocator *scratch_allocator,
+    const dnn::AlgorithmConfig &algorithm_config,
+    dnn::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
+            PARAM(output_descriptor), PARAM(backward_output_data),
+            PARAM(convolution_descriptor), PARAM(input_descriptor),
+            PARAM(backward_input_data));
+
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      auto status = dnn->DoConvolveBackwardData(
+          this, filter_descriptor, filter_data, output_descriptor,
+          backward_output_data, convolution_descriptor, input_descriptor,
+          backward_input_data, scratch_allocator, algorithm_config,
+          output_profile_result);
+      if (!status && !output_profile_result) {
+        SetError();
+      }
+    } else {
+      SetErrorAndLogNoDnnSupport();
+    }
+  }
+  return *this;
+}
+
+Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
+    const dnn::FilterDescriptor &filter_descriptor,
     const DeviceMemory<float> &filter_data,
     const dnn::BatchDescriptor &output_descriptor,
     DeviceMemory<float> backward_output_data,
@@ -1028,6 +1092,39 @@ Stream &Stream::ThenConvolveBackwardFilterWithScratch(
 
 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
     const dnn::BatchDescriptor &input_descriptor,
+    const DeviceMemory<double> &input_data,
+    const dnn::BatchDescriptor &output_descriptor,
+    DeviceMemory<double> backward_output_data,
+    const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const dnn::FilterDescriptor &filter_descriptor,
+    DeviceMemory<double> *backward_filter_data,
+    ScratchAllocator *scratch_allocator,
+    const dnn::AlgorithmConfig &algorithm_config,
+    dnn::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+            PARAM(output_descriptor), PARAM(backward_output_data),
+            PARAM(convolution_descriptor), PARAM(filter_descriptor),
+            PARAM(backward_filter_data));
+
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      auto status = dnn->DoConvolveBackwardFilter(
+          this, input_descriptor, input_data, output_descriptor,
+          backward_output_data, convolution_descriptor, filter_descriptor,
+          backward_filter_data, scratch_allocator, algorithm_config,
+          output_profile_result);
+      if (!status && !output_profile_result) {
+        SetError();
+      }
+    } else {
+      SetErrorAndLogNoDnnSupport();
+    }
+  }
+  return *this;
+}
+
+Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
+    const dnn::BatchDescriptor &input_descriptor,
     const DeviceMemory<float> &input_data,
     const dnn::BatchDescriptor &output_descriptor,
     DeviceMemory<float> backward_output_data,
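
The Stream wrappers above follow the existing pattern for each new double overload: check ok(), fetch the DnnSupport plugin via parent_->AsDnn(), forward to the corresponding DoConvolve* call, and mark the stream as errored only when the call fails and no profile result was requested (presumably so a failed autotuning probe does not poison the stream). A caller therefore checks the stream after enqueuing, e.g. (hypothetical names, descriptors and buffers prepared elsewhere):

    stream->ThenConvolveBackwardDataWithAlgorithm(
        filter_desc, filters, output_desc, grad_output, conv_desc, input_desc,
        &grad_input, /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
        /*output_profile_result=*/nullptr);
    if (!stream->ok()) {
      LOG(ERROR) << "failed to enqueue float64 backward-data convolution";
    }
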
tensorflow/stream_executor/stream.h
index a2fb2ea..8cd0a0d 100644
@@ -360,6 +360,17 @@ class Stream {
 
   Stream &ThenConvolveWithAlgorithm(
       const dnn::BatchDescriptor &input_descriptor,
+      const DeviceMemory<double> &input_data,
+      const dnn::FilterDescriptor &filter_descriptor,
+      const DeviceMemory<double> &filter_data,
+      const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const dnn::BatchDescriptor &output_descriptor,
+      DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
+      const dnn::AlgorithmConfig &algorithm_config,
+      dnn::ProfileResult *output_profile_result);
+
+  Stream &ThenConvolveWithAlgorithm(
+      const dnn::BatchDescriptor &input_descriptor,
       const DeviceMemory<float> &input_data,
       const dnn::FilterDescriptor &filter_descriptor,
       const DeviceMemory<float> &filter_data,
@@ -478,6 +489,18 @@ class Stream {
 
   Stream &ThenConvolveBackwardDataWithAlgorithm(
       const dnn::FilterDescriptor &filter_descriptor,
+      const DeviceMemory<double> &filter_data,
+      const dnn::BatchDescriptor &output_descriptor,
+      DeviceMemory<double> backward_output_data,
+      const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const dnn::BatchDescriptor &input_descriptor,
+      DeviceMemory<double> *backward_input_data,
+      ScratchAllocator *scratch_allocator,
+      const dnn::AlgorithmConfig &algorithm_config,
+      dnn::ProfileResult *output_profile_result);
+
+  Stream &ThenConvolveBackwardDataWithAlgorithm(
+      const dnn::FilterDescriptor &filter_descriptor,
       const DeviceMemory<float> &filter_data,
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory<float> backward_output_data,
@@ -531,6 +554,18 @@ class Stream {
 
   Stream &ThenConvolveBackwardFilterWithAlgorithm(
       const dnn::BatchDescriptor &input_descriptor,
+      const DeviceMemory<double> &input_data,
+      const dnn::BatchDescriptor &output_descriptor,
+      DeviceMemory<double> backward_output_data,
+      const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const dnn::FilterDescriptor &filter_descriptor,
+      DeviceMemory<double> *backward_filter_data,
+      ScratchAllocator *scratch_allocator,
+      const dnn::AlgorithmConfig &algorithm_config,
+      dnn::ProfileResult *output_profile_result);
+
+  Stream &ThenConvolveBackwardFilterWithAlgorithm(
+      const dnn::BatchDescriptor &input_descriptor,
       const DeviceMemory<float> &input_data,
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory<float> backward_output_data,