return [reshape_like(grad, orig.args[0])]
+@register_gradient("dyn.reshape")
+def dyn_reshape_grad(orig, grad):
+ """Gradient of dyn_reshape.
+
+ Returns a two-element list, one entry per argument of the original call:
+ `grad` passed through reshape_like against orig.args[0], and zeros_like of
+ orig.args[1]. NOTE(review): orig.args[1] is presumably the target-shape
+ input, which receives no real gradient -- confirm against the op definition.
+ """
+ return [reshape_like(grad, orig.args[0]), zeros_like(orig.args[1])]
+
+
+@register_gradient("shape_of")
+def shape_of_grad(orig, grad):
+ """Gradient of shape_of.
+
+ Returns a single-element list holding zeros_like(orig.args[0]); the incoming
+ `grad` is not used.
+ """
+ return [zeros_like(orig.args[0])]
+
+
@register_gradient("cast")
def cast_grad(orig, grad):
x = orig.args[0]
import tvm.relay.op as op
from tvm.relay import Prelude
-
from . import mlp
from . import resnet
from . import resnet_3d
from .py_converter import to_python, run_as_python
from ..transform import gradient
+
def run_opt_pass(expr, opt_pass, import_prelude=False):
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype)
-def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0):
+def check_grad(func,
+ inputs=None,
+ test_inputs=None,
+ eps=1e-6,
+ atol=1e-5,
+ rtol=1e-3,
+ scale=None,
+ mean=0):
"""Perform numerical gradient checking given a relay function.
Compare analytical gradients to numerical gradients derived from two-sided approximation. Note
Optional user-provided input parameters to use. If not given, will generate random normal
inputs scaled to be close to the chosen epsilon value to avoid numerical precision loss.
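+ test_inputs: List[np.array]
+ The subset of `inputs` whose gradients are checked. Useful in cases where some inputs
+ are not differentiable, such as symbolic inputs to dynamic ops. Each entry must be the
+ same array object that appears in `inputs` (entries are matched by identity, not by
+ value). If not given, all inputs are tested.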
+
eps: float
The epsilon value to use for computing numerical gradient approximation.
# Generate random inputs on the same scale as epsilon to avoid numerical precision loss.
inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params]
+ if test_inputs is None:
+ test_inputs = inputs
+
for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
_, grads = intrp.evaluate(bwd_func)(*inputs)
grads = [grad.asnumpy().astype("float64") for grad in grads]
+ # Throw out gradients we aren't testing
+ if inputs != test_inputs:
+ tmp = []
+ # find the gradient that corresponds to every test input
+ for test_input in test_inputs:
+ for i, grad in enumerate(grads):
+ if inputs[i] is test_input:
+ tmp.append(grad)
+ break
+ grads = tmp
+
# Get numeric gradients for each dimension of each param, using two-sided approximation.
approx_grads = []
- for x in inputs:
+ for x in test_inputs:
approx_grad = np.zeros(x.shape)
for i in np.ndindex(*x.shape):
x_i = x[i]
x[i] = x_i
approx_grad[i] = np.sum((fwd_plus - fwd_minus) / (2 * eps))
approx_grads.append(approx_grad)
-
# Compare gradients by checking that relative difference is below tolerance.
for grad, approx_grad in zip(grads, approx_grads):
np.testing.assert_allclose(grad, approx_grad, atol=atol, rtol=rtol)
def count_ops(expr):
"""count number of times a given op is called in the graph"""
class OpCounter(tvm.relay.ExprVisitor):
+ """Expression visitor that tallies calls by `call.op.name`.
+
+ Counts are accumulated in `self.node_counter`, a collections.Counter that
+ `count` creates before each traversal.
+ """
def visit_call(self, call):
+ # NOTE(review): assumes any call with an `op` attribute also exposes
+ # `op.name` -- confirm for calls whose op is not a named operator.
if hasattr(call, 'op'):
self.node_counter[call.op.name] += 1
return super().visit_call(call)
+
def count(self, expr):
+ # NOTE(review): node_set is assigned but never read within this block.
self.node_set = {}
self.node_counter = collections.Counter()
self.visit(expr)
return self.node_counter
+
return OpCounter().count(expr)
func = relay.Function([x, y], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+ x_data = np.ones(shape).astype("float32")
ref_res = np.reshape(x_data, oshape)
+ check_grad(run_infer_type(func),
+ inputs=[x_data, np.array(newshape).astype("int64")],
+ test_inputs=[x_data], eps=1e-3)
verify_func(func, [x_data, np.array(newshape).astype("int64")], ref_res)
verify_reshape((2, 3, 4), (8, 3), (8, 3))
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=newshape).astype("float32")
ref_res = np.reshape(x_data, oshape)
+ check_grad(run_infer_type(func),
+ inputs=[x_data, y_data], eps=1e-3)
verify_func(func, [x_data, y_data], ref_res)
verify_reshape((2, 3, 4), (8, 3), (8, 3))
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
func = relay.Function([x, r], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.tile(x_data, reps=reps)
+ reps_data = np.array(reps).astype("float32")
verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)
verify_tile((2, 3, 4), (3, 2, 1))
verify_tile((2, 3, 4), (1, 2))
test_dyn_reshape()
test_dyn_shape_reshape()
test_dyn_tile()
- test_dyn_zeros_ones()
\ No newline at end of file
+ test_dyn_zeros_ones()