inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tensor& shift,
std::string name = "ScaleShift", std::string tag = kBroadcast) {
return tvm::te::compute(
- x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(w); },
+ x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(c); },
name, tag);
}
inline Tensor scale_shift_nhwc(const Tensor& x, const Tensor& scale, const Tensor& shift,
std::string name = "ScaleShift", std::string tag = kBroadcast) {
return tvm::te::compute(
- x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(w); },
+ x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(c); },
name, tag);
}