[Relay] add test for second order ad (#2754)
author雾雨魔理沙 <lolisa@marisa.moe>
Sat, 30 Mar 2019 00:02:32 +0000 (17:02 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Sat, 30 Mar 2019 00:02:32 +0000 (17:02 -0700)
* do second order

* add comment

* better name

* use tvm assert all close

* refire ci

python/tvm/relay/op/_tensor_grad.py
tests/python/relay/test_pass_gradient.py

index 173e97a..0e79629 100644 (file)
@@ -3,7 +3,7 @@
 from __future__ import absolute_import
 from ..expr import const
 from .op import register_gradient
-from .transform import collapse_sum_like, where
+from .transform import collapse_sum_like, broadcast_to_like, where
 from .tensor import exp, negative, power, less
 from .tensor import zeros_like, ones_like
 
@@ -77,3 +77,20 @@ def divide_grad(orig, grad):
     x, y = orig.args
     return [collapse_sum_like(grad / y, x),
             collapse_sum_like(- (grad * orig / y), y)]
+
+
+@register_gradient("zeros_like")
+def zeros_like_grad(orig, grad):
+    """Returns [0]"""
+    return [orig]
+
+@register_gradient("ones_like")
+def ones_like_grad(orig, grad):
+    """Returns [0]"""
+    return [zeros_like(orig.args[0])]
+
+@register_gradient("collapse_sum_like")
+def collapse_sum_like_grad(orig, grad):
+    """Returns [broadcast_to_like(grad, x), 0]"""
+    x, y = orig.args
+    return [broadcast_to_like(grad, x), zeros_like(y)]
index 400941f..690c82e 100644 (file)
@@ -20,8 +20,8 @@ def test_id():
     ex = create_executor()
     x = rand(dtype, *shape)
     forward, (grad,) = ex.evaluate(back_func)(x)
-    np.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
-    np.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
+    tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
 
 def test_add():
@@ -35,8 +35,8 @@ def test_add():
     ex = create_executor()
     x = rand(dtype, *shape)
     forward, (grad,) = ex.evaluate(back_func)(x)
-    np.testing.assert_allclose(forward.asnumpy(), 2 * x.asnumpy())
-    np.testing.assert_allclose(grad.asnumpy(), 2 * np.ones_like(x.asnumpy()))
+    tvm.testing.assert_allclose(forward.asnumpy(), 2 * x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), 2 * np.ones_like(x.asnumpy()))
 
 
 def test_temp_add():
@@ -51,8 +51,8 @@ def test_temp_add():
     ex = create_executor()
     x = rand(dtype, *shape)
     forward, (grad,) = ex.evaluate(back_func)(x)
-    np.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
-    np.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
+    tvm.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
 
 
 def test_sub():
@@ -66,8 +66,8 @@ def test_sub():
     ex = create_executor()
     x = rand(dtype, *shape)
     forward, (grad,) = ex.evaluate(back_func)(x)
-    np.testing.assert_allclose(forward.asnumpy(), np.zeros_like(x.asnumpy()))
-    np.testing.assert_allclose(grad.asnumpy(), np.zeros_like(x.asnumpy()))
+    tvm.testing.assert_allclose(forward.asnumpy(), np.zeros_like(x.asnumpy()))
+    tvm.testing.assert_allclose(grad.asnumpy(), np.zeros_like(x.asnumpy()))
 
 
 def test_broadcast_add():
@@ -90,11 +90,11 @@ def test_broadcast_add():
                                                                      relay.TupleType([t1, t2])]))
     ex = create_executor()
     forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd)
-    np.testing.assert_allclose(forward.asnumpy(), expected_forward)
-    np.testing.assert_allclose(grad_x.asnumpy(),
-                               np.ones_like(expected_forward).sum(axis=2, keepdims=True))
-    np.testing.assert_allclose(grad_y.asnumpy(),
-                               np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
+    tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
+    tvm.testing.assert_allclose(grad_x.asnumpy(),
+                                np.ones_like(expected_forward).sum(axis=2, keepdims=True))
+    tvm.testing.assert_allclose(grad_y.asnumpy(),
+                                np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
 
 
 def test_broadcast_subtract():
@@ -117,11 +117,11 @@ def test_broadcast_subtract():
                                                                      relay.TupleType([t1, t2])]))
     ex = create_executor()
     forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd)
-    np.testing.assert_allclose(forward.asnumpy(), expected_forward)
-    np.testing.assert_allclose(grad_x.asnumpy(),
-                               np.ones_like(expected_forward).sum(axis=2, keepdims=True))
-    np.testing.assert_allclose(grad_y.asnumpy(),
-                               -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
+    tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
+    tvm.testing.assert_allclose(grad_x.asnumpy(),
+                                np.ones_like(expected_forward).sum(axis=2, keepdims=True))
+    tvm.testing.assert_allclose(grad_y.asnumpy(),
+                                -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
 
 
 def test_tuple():
@@ -147,10 +147,10 @@ def test_tuple():
     expected_forward = x_np + y_np - z_np
     ex = create_executor()
     forward, (grad_x, grad_y, grad_z) = ex.evaluate(back_func)(x_nd, y_nd, z_nd)
-    np.testing.assert_allclose(forward.asnumpy(), expected_forward)
-    np.testing.assert_allclose(grad_x.asnumpy(), np.ones_like(grad_x.asnumpy()))
-    np.testing.assert_allclose(grad_y.asnumpy(), np.ones_like(grad_y.asnumpy()))
-    np.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
+    tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
+    tvm.testing.assert_allclose(grad_x.asnumpy(), np.ones_like(grad_x.asnumpy()))
+    tvm.testing.assert_allclose(grad_y.asnumpy(), np.ones_like(grad_y.asnumpy()))
+    tvm.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
 
 
 def test_pow():
@@ -168,8 +168,9 @@ def test_pow():
     i_nd = rand(dtype, *shape)
     ex = create_executor(mod=mod)
     forward, (grad_i,) = ex.evaluate(back_func)(i_nd)
-    np.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
-    np.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
+    tvm.testing.assert_allclose(forward.asnumpy(), 8 * i_nd.asnumpy())
+    tvm.testing.assert_allclose(grad_i.asnumpy(), 8 * np.ones_like(grad_i.asnumpy()))
+
 
 def test_ref():
     shape = (10, 10)
@@ -187,8 +188,28 @@ def test_ref():
     x_nd = rand(dtype, *shape)
     ex = create_executor()
     forward, (grad_x,) = ex.evaluate(back_func)(x_nd)
-    np.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
-    np.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
+    tvm.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
+    tvm.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
+
+
+def test_square_second_order():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    func = relay.Function([x], x * x)
+    back_func = relay.ir_pass.infer_type(gradient(func))
+    y = relay.var("y", t)
+    back_func_adjusted = relay.Function([y], relay.TupleGetItem(relay.TupleGetItem(back_func(y), 1), 0))
+    back_func_adjusted = relay.ir_pass.infer_type(back_func_adjusted)
+    back_back_func = relay.ir_pass.infer_type(gradient(back_func_adjusted))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    x_nd = rand(dtype, *shape)
+    ex = create_executor()
+    forward, (grad_x,) = ex.evaluate(back_back_func)(x_nd)
+    tvm.testing.assert_allclose(forward.asnumpy(), 2 * x_nd.asnumpy())
+    tvm.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
+
 
 if __name__ == "__main__":
     test_id()
@@ -200,3 +221,4 @@ if __name__ == "__main__":
     test_tuple()
     test_pow()
     test_ref()
+    test_square_second_order()