# These ops all reduce over their inputs, so they share the reduce schedule.
_reg.register_schedule("mean", _schedule_reduce)
_reg.register_schedule("variance", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)
batch_size = take(shape, const(0, dtype='int32'), axis=0)
grad = grad / batch_size.astype('float32')
return [-grad * y / x, -grad * log(x)]
+
+
@register_gradient("nn.cross_entropy_with_logits")
def cross_entropy_with_logits_grad(orig, grad):
    """Gradient of nn.cross_entropy_with_logits.

    The forward op computes -sum(x * y) / batch_size, so the partials are
    -y / batch_size w.r.t. x and -x / batch_size w.r.t. y, each scaled by
    the incoming gradient.
    """
    x, y = orig.args
    batch_size = take(shape_of(x), const(0, dtype='int32'), axis=0)
    scaled_grad = grad / batch_size.astype('float32')
    return [-scaled_grad * y, -scaled_grad * x]
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    """Compute definition of nn.cross_entropy: -sum(log(x) * y) / batch_size."""
    predictions, targets = inputs
    batch_size = predictions.shape[0]
    return [-topi.sum(topi.log(predictions) * targets) / batch_size]
+
+
# Register the fusion pattern for the new op.
reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
+
+
@reg.register_compute("nn.cross_entropy_with_logits")
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
    """Compute definition of nn.cross_entropy_with_logits: -sum(x * y) / batch_size."""
    predictions, targets = inputs
    batch_size = predictions.shape[0]
    return [-topi.sum(predictions * targets) / batch_size]
The computed result.
"""
return _make.cross_entropy(predictions, targets)
+
+
def cross_entropy_with_logits(predictions, targets):
    """CrossEntropy with logits.

    Computes cross entropy where the predictions are given as logits.

    Parameters
    ----------
    predictions : tvm.relay.Expr
        The predictions.

    targets : tvm.relay.Expr
        The targets.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.cross_entropy_with_logits(predictions, targets)
return true;
}
-// Positional relay function to create batch_matmul operator used by frontend FFI.
+// Positional relay function to create cross_entropy operator used by frontend FFI.
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
static const Op& op = Op::Get("nn.cross_entropy");
return CallNode::make(op, {predictions, targets}, Attrs(), {});
.add_type_rel("CrossEntropy", CrossEntropyRel);
+// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
+Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
+ static const Op& op = Op::Get("nn.cross_entropy_with_logits");
+ return CallNode::make(op, {predictions, targets}, Attrs(), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.cross_entropy_with_logits")
+.set_body_typed(MakeCrossEntropyWithLogits);
+
+
+RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
+.describe(R"code(
+Computes cross entropy given predictions and targets.
+Accept logits.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("x", "1D Tensor", "Predictions.")
+.add_argument("y", "1D Tensor", "Targets.")
+.set_support_level(10)
+.add_type_rel("CrossEntropy", CrossEntropyRel);
+
+
} // namespace relay
} // namespace tvm
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from tvm import relay
from tvm.relay.testing import check_grad
def test_cross_entropy_grad():
    # Batch size > 1 exercises the 1/batch_size scaling in the gradient.
    x = relay.var("x", shape=(2, 5))
    y = relay.var("y", shape=(2, 5))
    # mean=1 keeps sampled inputs positive so log(x) inside the op stays finite.
    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)),
               eps=0.01, scale=0.1, mean=1)
def test_cross_entropy_with_logits_grad():
    # Numerically checks the registered gradient against finite differences.
    logits = relay.var("x", shape=(2, 5))
    labels = relay.var("y", shape=(2, 5))
    fn = relay.Function([logits, labels],
                        relay.op.nn.cross_entropy_with_logits(logits, labels))
    check_grad(fn, eps=0.01, scale=0.1, mean=1)
+
+
if __name__ == "__main__":
    # Run every test in this file, not just the ones listed by hand.
    pytest.main([__file__])