Fix the shift column for scale_shift_nchw and scale_shift_nhwc in C topi (#5679)
authortobe <tobeg3oogle@gmail.com>
Wed, 27 May 2020 15:59:02 +0000 (23:59 +0800)
committerGitHub <noreply@github.com>
Wed, 27 May 2020 15:59:02 +0000 (08:59 -0700)
topi/include/topi/nn/mapping.h

index d4a3a47..2bf3314 100644 (file)
@@ -48,7 +48,7 @@ using namespace tvm::te;
 inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tensor& shift,
                                std::string name = "ScaleShift", std::string tag = kBroadcast) {
   return tvm::te::compute(
-      x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(w); },
+      x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(c); },
       name, tag);
 }
 
@@ -66,7 +66,7 @@ inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tenso
 inline Tensor scale_shift_nhwc(const Tensor& x, const Tensor& scale, const Tensor& shift,
                                std::string name = "ScaleShift", std::string tag = kBroadcast) {
   return tvm::te::compute(
-      x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(w); },
+      x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(c); },
       name, tag);
 }