From 2ec7caa073307b546aeea7787238ad4960898354 Mon Sep 17 00:00:00 2001 From: handar423 <47707767+handar423@users.noreply.github.com> Date: Sat, 6 Jun 2020 11:40:17 +0800 Subject: [PATCH] fix small bug about dense_grad (#5695) --- python/tvm/relay/op/_tensor_grad.py | 7 ++++--- tests/python/relay/test_op_grad_level2.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 8ba1020..61488f1 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -472,9 +472,10 @@ def bias_add_grad(orig, grad): def dense_grad(orig, grad): """Returns [grad' @ weight, data @ grad']""" data, weight = orig.args - return [collapse_sum_like(transpose(grad) * weight, data), - collapse_sum_like(data * transpose(grad), weight)] - + return [collapse_sum_like(_nn.dense(grad, transpose(weight), + units=weight.checked_type.shape[1]), data), + collapse_sum_like(_nn.dense(transpose(grad), transpose(data), + units=data.checked_type.shape[1]), weight)] @register_gradient("reshape") def reshape_grad(orig, grad): diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 2b5a1c2..d898451 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -162,6 +162,7 @@ def verify_dense_grad(d_shape, w_shape): def test_dense_grad(): verify_dense_grad((1, 8), (16, 8)) verify_dense_grad((1, 4), (3, 4)) + verify_dense_grad((5, 4), (3, 4)) def verify_batch_flatten_grad(d_shape): -- 2.7.4