from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
-from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
+from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal
from .transform import (
    broadcast_to_like,
    collapse_sum_like,
    return [backward_data, backward_weight]
+@register_gradient("max")
+def max_grad(orig, grad):
+    """Returns the gradient of max"""
+    # Only support axis=0, since broadcasting orig back to x's shape behaves
+    # incorrectly for other axes (the reduced shape no longer lines up with
+    # x's trailing dimensions).
+    x, axis = orig.args[0], orig.attrs.axis
+    assert axis is not None and len(axis) == 1 and int(axis[0]) == 0
+    orig = broadcast_to_like(orig, x)
+    grad = broadcast_to_like(grad, x)
+    # 1.0 where an element equals the reduced maximum, 0.0 elsewhere; only
+    # those positions receive the upstream gradient.
+    indicators = cast_like(equal(orig, x), grad)
+    return [indicators * grad]
+
+
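For intuition, the indicator trick used in max_grad is easy to reproduce outside of Relay. The following is a minimal NumPy sketch (illustration only, not part of the patch; the array names are made up) of the same rule for a 2-D input reduced over axis 0: the upstream gradient flows only to the positions that equal the reduced maximum.

    import numpy as np

    x = np.array([[1.0, 5.0],
                  [3.0, 2.0]])             # input, shape (2, 2)
    y = x.max(axis=0)                      # forward max over axis 0 -> [3., 5.]
    grad_y = np.ones_like(y)               # upstream gradient w.r.t. y
    # Compare x against y broadcast back to x's shape: 1.0 at argmax positions.
    indicators = (x == y).astype(x.dtype)  # [[0., 1.], [1., 0.]]
    grad_x = indicators * grad_y           # gradient w.r.t. x, same masking as max_grad

Note that, like max_grad, this gives every tied maximum a full copy of the upstream gradient rather than splitting it among the ties.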
@register_gradient("nn.softmax")
def softmax_grad(orig, grad):
"""Gradient of softmax"""
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pytest
from tvm import relay
from tvm.relay.testing import check_grad
    verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
+def test_max_grad():
+    s = (5, 10)
+    t = relay.TensorType(s)
+    x = relay.var("x", t)
+    axis = 0
+    z = relay.max(x, axis)
+
+    fwd_func = relay.Function([x], z)
+    # Loose rtol: the finite-difference estimate is unreliable where two
+    # entries nearly tie for the column maximum.
+    check_grad(fwd_func, eps=1e-7, rtol=1)
+
if __name__ == "__main__":
-    test_sum_grad()
+    pytest.main([__file__])
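check_grad (from tvm.relay.testing) compares the gradient Relay derives for the function against a numerical estimate. Below is a rough standalone sketch of that kind of comparison for the max case, assuming a central-difference estimator; the names and details are illustrative, not TVM's actual implementation.

    import numpy as np

    def numeric_grad(f, x, eps=1e-7):
        # Central-difference estimate of d(sum(f(x)))/dx, one element at a time.
        grad = np.zeros_like(x)
        for idx in np.ndindex(x.shape):
            xp, xm = x.copy(), x.copy()
            xp[idx] += eps
            xm[idx] -= eps
            grad[idx] = (f(xp).sum() - f(xm).sum()) / (2 * eps)
        return grad

    rng = np.random.default_rng(0)
    x = rng.uniform(size=(5, 10))
    approx = numeric_grad(lambda a: a.max(axis=0), x)
    exact = (x == x.max(axis=0)).astype(x.dtype)   # the rule max_grad implements
    assert np.allclose(approx, exact, atol=1e-3)

The estimate degrades when two entries in a column differ by less than eps, which is why a numerical check of max needs relatively loose tolerances.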