[Relay][Dyn] Add dynamic reshape grad (#6080)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Fri, 17 Jul 2020 22:39:57 +0000 (15:39 -0700)
committerGitHub <noreply@github.com>
Fri, 17 Jul 2020 22:39:57 +0000 (15:39 -0700)
* add dynamic rehape grad

* fix lint

* fix unit tests, warning

python/tvm/relay/op/_tensor_grad.py
python/tvm/relay/testing/__init__.py
tests/python/relay/dyn/test_dynamic_op_level3.py

index 849d0a3..3e87f60 100644 (file)
@@ -514,6 +514,18 @@ def reshape_grad(orig, grad):
     return [reshape_like(grad, orig.args[0])]
 
 
+@register_gradient("dyn.reshape")
+def dyn_reshape_grad(orig, grad):
+    """Gradient of dyn_reshape"""
+    return [reshape_like(grad, orig.args[0]), zeros_like(orig.args[1])]
+
+
+@register_gradient("shape_of")
+def shape_of_grad(orig, grad):
+    """Gradient of shape_of"""
+    return [zeros_like(orig.args[0])]
+
+
 @register_gradient("cast")
 def cast_grad(orig, grad):
     x = orig.args[0]
index a53e9d7..0204e5b 100644 (file)
@@ -26,7 +26,6 @@ import tvm.relay as relay
 import tvm.relay.op as op
 from tvm.relay import Prelude
 
-
 from . import mlp
 from . import resnet
 from . import resnet_3d
@@ -47,6 +46,7 @@ from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
 from .py_converter import to_python, run_as_python
 from ..transform import gradient
 
+
 def run_opt_pass(expr, opt_pass, import_prelude=False):
     assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
@@ -65,7 +65,14 @@ def _np_randn_from_type(t, scale=1, mean=0):
     return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)
 
 
-def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
+def check_grad(func,
+               inputs=None,
+               test_inputs=None,
+               eps=1e-6,
+               atol=1e-5,
+               rtol=1e-3,
+               scale=None,
+               mean=0):
     """Perform numerical gradient checking given a relay function.
 
     Compare analytical gradients to numerical gradients derived from two-sided approximation. Note
@@ -80,6 +87,11 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
         Optional user-provided input parameters to use. If not given, will generate random normal
         inputs scaled to be close to the chosen epsilon value to avoid numerical precision loss.
 
+    test_inputs: List[np.array]
+        The inputs to test for gradient matching. Useful in cases where some inputs are not
+        differentiable, such as symbolic inputs to dynamic ops. If not given, all inputs are
+        tested.
+
     eps: float
         The epsilon value to use for computing numerical gradient approximation.
 
@@ -109,6 +121,9 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
         # Generate random inputs on the same scale as epsilon to avoid numerical precision loss.
         inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params]
 
+    if test_inputs is None:
+        test_inputs = inputs
+
     for target, ctx in ctx_list():
         intrp = relay.create_executor(ctx=ctx, target=target)
 
@@ -116,9 +131,20 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
         _, grads = intrp.evaluate(bwd_func)(*inputs)
         grads = [grad.asnumpy().astype("float64") for grad in grads]
 
+        # Throw out gradients we aren't testing
+        if inputs != test_inputs:
+            tmp = []
+            # find the gradient that corresponds to every test input
+            for test_input in test_inputs:
+                for i, grad in enumerate(grads):
+                    if inputs[i] is test_input:
+                        tmp.append(grad)
+                        break
+            grads = tmp
+
         # Get numeric gradients for each dimension of each param, using two-sided approximation.
         approx_grads = []
-        for x in inputs:
+        for x in test_inputs:
             approx_grad = np.zeros(x.shape)
             for i in np.ndindex(*x.shape):
                 x_i = x[i]
@@ -129,7 +155,6 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, me
                 x[i] = x_i
                 approx_grad[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps))
             approx_grads.append(approx_grad)
-
         # Compare gradients by checking that relative difference is below tolerance.
         for grad, approx_grad in zip(grads, approx_grads):
             np.testing.assert_allclose(grad, approx_grad, atol=atol, rtol=rtol)
@@ -142,13 +167,16 @@ def rand(dtype, *shape):
 def count_ops(expr):
     """count number of times a given op is called in the graph"""
     class OpCounter(tvm.relay.ExprVisitor):
+        """OpCounter"""
         def visit_call(self, call):
             if hasattr(call, 'op'):
                 self.node_counter[call.op.name] += 1
             return super().visit_call(call)
+
         def count(self, expr):
             self.node_set = {}
             self.node_counter = collections.Counter()
             self.visit(expr)
             return self.node_counter
+
     return OpCounter().count(expr)
index e63f9b8..ff98c48 100644 (file)
@@ -44,7 +44,11 @@ def test_dyn_reshape():
 
         func = relay.Function([x, y], z)
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+        x_data = np.ones(shape).astype("float32")
         ref_res = np.reshape(x_data, oshape)
+        check_grad(run_infer_type(func),
+                   inputs=[x_data, np.array(newshape).astype("int64")],
+                   test_inputs=[x_data], eps=1e-3)
         verify_func(func, [x_data, np.array(newshape).astype("int64")], ref_res)
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
@@ -66,6 +70,8 @@ def test_dyn_shape_reshape():
         x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
         y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
         ref_res = np.reshape(x_data, oshape)
+        check_grad(run_infer_type(func),
+                   inputs=[x_data, y_data], eps=1e-3)
         verify_func(func, [x_data, y_data], ref_res)
     verify_reshape((2, 3, 4), (8, 3), (8, 3))
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
@@ -79,6 +85,7 @@ def test_dyn_tile():
         func = relay.Function([x, r], z)
         x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
         ref_res = np.tile(x_data, reps=reps)
+        reps_data = np.array(reps).astype("float32")
         verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)
     verify_tile((2, 3, 4), (3, 2, 1))
     verify_tile((2, 3, 4), (1, 2))
@@ -111,4 +118,4 @@ if __name__ == "__main__":
     test_dyn_reshape()
     test_dyn_shape_reshape()
     test_dyn_tile()
-    test_dyn_zeros_ones()
\ No newline at end of file
+    test_dyn_zeros_ones()