From e07cca13123446d61576a2f54b9c5a680ff9f364 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Fri, 11 Jan 2019 11:19:56 -0800 Subject: [PATCH] Add backend checks for batch norm (#15955) Summary: Fixes #15826 Changelog: - Add backend checks in `batch_norm_cpu` and `batch_norm_cuda` - Modify check in `checkBackend` to pass on undefined tensors. Differential Revision: D13636410 Pulled By: soumith fbshipit-source-id: 3b1cfe5ca8b7c0346569077163503065e75c2659 --- aten/src/ATen/TensorUtils.cpp | 2 +- aten/src/ATen/native/Normalization.cpp | 2 ++ aten/src/ATen/native/cuda/Normalization.cuh | 8 ++++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index c1ee209..d8bbd26 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -196,7 +196,7 @@ void checkAllDefined(CheckedFrom c, ArrayRef ts) { void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) { AT_CHECK( - t.type().backend() == backend, + !t.defined() || t.type().backend() == backend, "Expected tensor to have ", toString(backend), " Backend, but got tensor with ", toString(t.type().backend()), " Backend ", "(while checking arguments for ", c, ")"); diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 85f517f..f5f4cfa 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -461,6 +461,8 @@ std::tuple batch_norm_update_stats_cpu( std::tuple batch_norm_cpu(const Tensor& self, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double eps) { + checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU); + return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] { if (!train) { return batch_norm_cpu_transform_input_template(self, weight, bias, {}, {}, running_mean, running_var, train, eps); diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index bf3e9a5..0a08129 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -395,6 +395,14 @@ std::tuple batch_norm_cuda_template(const Tensor& input_ const Tensor& running_mean_, const Tensor& running_var_, bool train, double momentum, double epsilon) { + TensorArg input_arg{ input_, "input", 1 }, + weight_arg{ weight_, "weight", 2 }, + bias_arg{ bias_, "bias", 3 }, + run_mean_arg{ running_mean_, "running_mean", 4 }, + run_var_arg{ running_var_, "running_var", 5 }; + CheckedFrom c = "batch_norm_cuda"; + checkAllSameGPU(c, {input_arg, weight_arg, bias_arg, run_mean_arg, run_var_arg}); + using accscalar_t = at::acc_type; int64_t n_input = input_.size(1); Tensor save_mean_; -- 2.7.4