[Relay] cross_entropy_with_logits and its gradient (#4075)
author 雾雨魔理沙 <lolisa@marisa.moe>
Fri, 25 Oct 2019 14:54:29 +0000 (07:54 -0700)
committer Thierry Moreau <moreau@uw.edu>
Fri, 25 Oct 2019 14:54:29 +0000 (17:54 +0300)
* save

* lint

python/tvm/relay/op/_reduce.py
python/tvm/relay/op/_tensor_grad.py
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/nn/nn.py
src/relay/op/nn/nn.cc
tests/python/relay/test_op_grad_level10.py

index f6b699f..845ec4b 100644 (file)
@@ -37,3 +37,4 @@ _reg.register_schedule("prod", _schedule_reduce)
 _reg.register_schedule("mean", _schedule_reduce)
 _reg.register_schedule("variance", _schedule_reduce)
 _reg.register_schedule("nn.cross_entropy", _schedule_reduce)
+_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
index 1c94162..d55cad7 100644 (file)
@@ -449,3 +449,12 @@ def cross_entropy_grad(orig, grad):
     batch_size = take(shape, const(0, dtype='int32'), axis=0)
     grad = grad / batch_size.astype('float32')
     return [-grad * y / x, -grad * log(x)]
+
+
+@register_gradient("nn.cross_entropy_with_logits")
+def cross_entropy_with_logits_grad(orig, grad):
+    x, y = orig.args
+    shape = shape_of(x)
+    batch_size = take(shape, const(0, dtype='int32'), axis=0)
+    grad = grad / batch_size.astype('float32')
+    return [-grad * y, -grad * x]
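
The gradient above follows directly from the op's definition, loss(x, y) = -sum(x * y) / batch_size: the loss is linear in both inputs, so d(loss)/dx = -y / batch_size and d(loss)/dy = -x / batch_size, each multiplied by the incoming gradient. A standalone NumPy cross-check of that rule (an illustrative sketch, not part of the patch):

    import numpy as np

    def loss(x, y):
        # Mirrors the registered compute: -sum(x * y) / batch_size.
        return -np.sum(x * y) / x.shape[0]

    x, y = np.random.rand(2, 5), np.random.rand(2, 5)
    dx = -y / x.shape[0]                    # analytic gradient w.r.t. x
    eps = 1e-6
    xp = x.copy(); xp[0, 0] += eps          # perturb a single element
    numeric = (loss(xp, y) - loss(x, y)) / eps
    assert abs(numeric - dx[0, 0]) < 1e-4   # exact up to float rounding, since loss is linear in x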
index 0043ffa..5786c22 100644 (file)
@@ -770,3 +770,12 @@ reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
 def compute_cross_entropy(attrs, inputs, out_dtype, target):
     x, y = inputs
     return [-topi.sum(topi.log(x) * y) / x.shape[0]]
+
+
+reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
+
+
+@reg.register_compute("nn.cross_entropy_with_logits")
+def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
+    x, y = inputs
+    return [-topi.sum(x * y) / x.shape[0]]
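
The compute is identical to compute_cross_entropy above it except that no topi.log is applied: predictions are expected to already be log-probabilities. In NumPy terms the two ops relate as follows (an illustrative sketch, not part of the patch):

    import numpy as np

    p = np.array([[0.7, 0.2, 0.1]])   # probabilities
    t = np.array([[1.0, 0.0, 0.0]])   # one-hot targets
    logp = np.log(p)                  # the caller applies the log

    ce        = -np.sum(np.log(p) * t) / p.shape[0]  # nn.cross_entropy(p, t)
    ce_logits = -np.sum(logp * t) / p.shape[0]       # nn.cross_entropy_with_logits(logp, t)
    assert np.isclose(ce, ce_logits)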
index 9ddb3ec..1f289d1 100644 (file)
@@ -1807,3 +1807,22 @@ def cross_entropy(predictions, targets):
       The computed result.
     """
     return _make.cross_entropy(predictions, targets)
+
+
+def cross_entropy_with_logits(predictions, targets):
+    """Cross entropy with logits.
+
+    Parameters
+    ----------
+    predictions : tvm.relay.Expr
+      The predictions (logits).
+
+    targets : tvm.relay.Expr
+      The targets.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+      The computed result.
+    """
+    return _make.cross_entropy_with_logits(predictions, targets)
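
A minimal end-to-end sketch of the new Python API (assumes a TVM build containing this patch; relay.create_executor is the interpreter entry point of this era):

    import numpy as np
    from tvm import relay

    x = relay.var("x", shape=(2, 5))
    y = relay.var("y", shape=(2, 5))
    func = relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y))

    ex = relay.create_executor()
    out = ex.evaluate(func)(np.random.rand(2, 5).astype("float32"),
                            np.random.rand(2, 5).astype("float32"))
    print(out)  # scalar loss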
index dd1b4e5..416a0d7 100644 (file)
@@ -910,7 +910,7 @@ bool CrossEntropyRel(const Array<Type>& types,
   return true;
 }
 
-// Positional relay function to create batch_matmul operator used by frontend FFI.
+// Positional relay function to create cross_entropy operator used by frontend FFI.
 Expr MakeCrossEntropy(Expr predictions, Expr targets) {
   static const Op& op = Op::Get("nn.cross_entropy");
   return CallNode::make(op, {predictions, targets}, Attrs(), {});
@@ -933,5 +933,28 @@ Do log on the data - do not accept logits.
 .add_type_rel("CrossEntropy", CrossEntropyRel);
 
 
+// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
+Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
+  static const Op& op = Op::Get("nn.cross_entropy_with_logits");
+  return CallNode::make(op, {predictions, targets}, Attrs(), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.cross_entropy_with_logits")
+.set_body_typed(MakeCrossEntropyWithLogits);
+
+
+RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
+.describe(R"code(
+Computes cross entropy given predictions and targets.
+Accepts logits - no log is applied to the data.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("x", "1D Tensor", "Predictions.")
+.add_argument("y", "1D Tensor", "Targets.")
+.set_support_level(10)
+.add_type_rel("CrossEntropy", CrossEntropyRel);
+
+
 }  // namespace relay
 }  // namespace tvm
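
The registration reuses CrossEntropyRel, since both variants have the same typing: two tensors of shape (batch, category) in, a scalar out. Written out, the new op computes

    \text{loss}(x, y) = -\frac{1}{B} \sum_{i=1}^{B} \sum_{j=1}^{C} x_{ij}\, y_{ij}

where B is the batch size and x is already in log space, which is why the backward pass above needs nothing beyond the 1/B scaling.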
index 2592d18..7aa9e0b 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
+
 from tvm import relay
 from tvm.relay.testing import check_grad
 
 
 def test_cross_entropy_grad():
-    x = relay.var("x", shape=(1, 5))
-    y = relay.var("y", shape=(1, 5))
+    x = relay.var("x", shape=(2, 5))
+    y = relay.var("y", shape=(2, 5))
     check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)
 
 
+def test_cross_entropy_with_logits_grad():
+    x = relay.var("x", shape=(2, 5))
+    y = relay.var("y", shape=(2, 5))
+    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
+
+
 if __name__ == "__main__":
-    test_cross_entropy_grad()
+    pytest.main([__file__])
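
check_grad compares the registered analytic gradient against numeric finite differences on random inputs. Conceptually it does something like the following (a simplified sketch, not TVM's actual implementation):

    import numpy as np

    def numeric_grad(f, x, eps=0.01):
        # Central finite differences over every element of x.
        g = np.zeros_like(x)
        for idx in np.ndindex(x.shape):
            xp, xm = x.copy(), x.copy()
            xp[idx] += eps
            xm[idx] -= eps
            g[idx] = (f(xp) - f(xm)) / (2 * eps)
        return g

The eps=0.01 and scale=0.1 arguments keep the perturbation large relative to float32 noise but small relative to the inputs; mean=1 shifts the random inputs away from zero so that log(x) in the nn.cross_entropy test stays well-defined.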