From 7d71dd8bd7b9b6412568a6860161566071588055 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 4 Oct 2019 17:24:55 -0700 Subject: [PATCH] [Relay][Training] Add gradient for Crossentropy (#3925) * save save redo max test save address comment fix * address comment * increase rtol * address review comment --- python/tvm/relay/op/_reduce.py | 1 + python/tvm/relay/op/_tensor_grad.py | 23 +++++++++++++- python/tvm/relay/op/nn/_nn.py | 9 ++++++ python/tvm/relay/op/nn/nn.py | 19 ++++++++++++ python/tvm/relay/testing/__init__.py | 16 +++++++--- src/relay/op/nn/nn.cc | 49 ++++++++++++++++++++++++++++++ tests/python/relay/test_op_grad_level10.py | 28 +++++++++++++++++ tests/python/relay/test_op_grad_level4.py | 4 +-- tests/python/relay/test_op_level5.py | 2 +- 9 files changed, 143 insertions(+), 8 deletions(-) create mode 100644 tests/python/relay/test_op_grad_level10.py diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py index b6c05b1..f6b699f 100644 --- a/python/tvm/relay/op/_reduce.py +++ b/python/tvm/relay/op/_reduce.py @@ -36,3 +36,4 @@ _reg.register_schedule("min", _schedule_reduce) _reg.register_schedule("prod", _schedule_reduce) _reg.register_schedule("mean", _schedule_reduce) _reg.register_schedule("variance", _schedule_reduce) +_reg.register_schedule("nn.cross_entropy", _schedule_reduce) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 5fe2dd4..fe22f45 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -25,7 +25,18 @@ from ..expr import Tuple, TupleGetItem, const from . import nn as _nn from .op import register_gradient from .reduce import sum as _sum -from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like, equal +from .tensor import ( + cos, + exp, + less, + negative, + ones_like, + power, + sin, + zeros_like, + equal, + shape_of, + log) from .transform import ( broadcast_to_like, collapse_sum_like, @@ -33,6 +44,7 @@ from .transform import ( reshape, reshape_like, strided_slice, + take, tile, transpose, where, @@ -353,3 +365,12 @@ def sum_grad(orig, grad): """Returns grad broadcasted to data dims""" data = orig.args[0] return [broadcast_to_like(grad, data)] + + +@register_gradient("nn.cross_entropy") +def cross_entropy_grad(orig, grad): + x, y = orig.args + shape = shape_of(x) + batch_size = take(shape, const(0, dtype='int32'), axis=0) + grad = grad / batch_size.astype('float32') + return [-grad * y / x, -grad * log(x)] diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 7059467..8c09390 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -745,3 +745,12 @@ def schedule_bitserial_dense(attrs, outputs, target): reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) + + +reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE) + + +@reg.register_compute("nn.cross_entropy") +def compute_cross_entropy(attrs, inputs, out_dtype, target): + x, y = inputs + return [-topi.sum(topi.log(x) * y) / x.shape[0]] diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 43ae06f..31c1006 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1758,3 +1758,22 @@ def bitserial_dense(data, """ return _make.bitserial_dense(data, weight, units, data_bits, weight_bits, pack_dtype, out_dtype, unipolar) + + +def cross_entropy(predictions, targets): + """CrossEntropy without logits. + + Parameters + ---------- + predictions : tvm.relay.Expr + The predictions. + + targets : tvm.relay.Expr + The targets. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.cross_entropy(predictions, targets) diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 481f35e..84b30d3 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -56,11 +56,11 @@ def run_infer_type(expr): return run_opt_pass(expr, transform.InferType()) -def _np_randn_from_type(t, scale=1): - return (scale * np.random.randn(*(int(d) for d in t.shape))).astype(t.dtype) +def _np_randn_from_type(t, scale=1, mean=0): + return (mean + (scale * np.random.randn(*(int(d) for d in t.shape)))).astype(t.dtype) -def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3): +def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3, scale=None, mean=0): """Perform numerical gradient checking given a relay function. Compare analytical gradients to numerical gradients derived from two-sided approximation. Note @@ -86,15 +86,23 @@ def check_grad(func, inputs=None, eps=1e-6, atol=1e-5, rtol=1e-3): The relative tolerance on difference between numerical and analytical gradients. Note that this needs to be scaled appropriately relative to the chosen eps. + scale: float + The standard deviation of the inputs. + + mean: float + The mean of the inputs. """ fwd_func = run_infer_type(func) bwd_func = run_infer_type(gradient(fwd_func)) + if scale is None: + scale = 10 * eps + if inputs is None: params = fwd_func.params # Generate random inputs on the same scale as epsilon to avoid numerical precision loss. - inputs = [_np_randn_from_type(x.checked_type, scale=(10 * eps)) for x in params] + inputs = [_np_randn_from_type(x.checked_type, scale=scale, mean=mean) for x in params] for target, ctx in ctx_list(): intrp = relay.create_executor(ctx=ctx, target=target) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 9115c69..a875ffc 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -817,5 +817,54 @@ are data in batch. .add_type_rel("BatchMatmul", BatchMatmulRel); +// relay.nn.cross_entropy +bool CrossEntropyRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* x = types[0].as(); + const auto* y = types[1].as(); + if (x == nullptr || y == nullptr) return false; + CHECK(x->shape.size() == 2 && y->shape.size() == 2) + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; + CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; + CHECK(reporter->AssertEQ(x->shape[1], y->shape[1])) + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; + // assign output type + reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype)); + return true; +} + +// Positional relay function to create batch_matmul operator used by frontend FFI. +Expr MakeCrossEntropy(Expr predictions, Expr targets) { + static const Op& op = Op::Get("nn.cross_entropy"); + return CallNode::make(op, {predictions, targets}, Attrs(), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.cross_entropy") +.set_body_typed(MakeCrossEntropy); + + +RELAY_REGISTER_OP("nn.cross_entropy") +.describe(R"code( +Computes cross entropy given predictions and targets. +Do log on the data - do not accept logits. +)code" TVM_ADD_FILELINE) +.set_num_inputs(2) +.add_argument("x", "1D Tensor", "Predictions.") +.add_argument("y", "1D Tensor", "Targets.") +.set_support_level(10) +.add_type_rel("CrossEntropy", CrossEntropyRel); + + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py new file mode 100644 index 0000000..2592d18 --- /dev/null +++ b/tests/python/relay/test_op_grad_level10.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm import relay +from tvm.relay.testing import check_grad + + +def test_cross_entropy_grad(): + x = relay.var("x", shape=(1, 5)) + y = relay.var("y", shape=(1, 5)) + check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1) + + +if __name__ == "__main__": + test_cross_entropy_grad() diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py index 3c799b8..f8d6c3a 100644 --- a/tests/python/relay/test_op_grad_level4.py +++ b/tests/python/relay/test_op_grad_level4.py @@ -32,14 +32,14 @@ def test_sum_grad(): def test_max_grad(): - s = (5, 10) + s = (10, 10) t = relay.TensorType(s) x = relay.var("x", t) axis = 0 z = relay.max(x, axis) fwd_func = relay.Function([x], z) - check_grad(fwd_func, eps=1e-7, rtol=1) + check_grad(fwd_func, scale=1e-3) if __name__ == "__main__": diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 8c10735..fb5dbcc 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -67,7 +67,7 @@ def test_resize(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4) for method in ["bilinear", "nearest_neighbor"]: for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout) -- 2.7.4