[Relay][Training] Add gradient for cast (#3894)
author雾雨魔理沙 <lolisa@marisa.moe>
Sun, 8 Sep 2019 03:11:47 +0000 (20:11 -0700)
committerWuwei Lin <wuwei@apache.org>
Sun, 8 Sep 2019 03:11:47 +0000 (23:11 -0400)
save

fix

fix grad

python/tvm/relay/op/_tensor_grad.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/relay/test_op_grad_level3.py

index 08624e1..0cd2efb 100644 (file)
@@ -29,6 +29,7 @@ from .tensor import cos, exp, less, negative, ones_like, power, sin, zeros_like
 from .transform import (
     broadcast_to_like,
     collapse_sum_like,
+    cast_like,
     reshape,
     reshape_like,
     strided_slice,
@@ -296,6 +297,12 @@ def reshape_grad(orig, grad):
     return [reshape_like(grad, orig.args[0])]
 
 
+@register_gradient("cast")
+def cast_grad(orig, grad):
+    x = orig.args[0]
+    return [cast_like(grad, x)]
+
+
 @register_gradient("nn.batch_flatten")
 def batch_flatten_grad(orig, grad):
     """Returns grad reshaped to data dims"""
index 7f29e85..5dddfc6 100644 (file)
@@ -43,6 +43,7 @@ _reg.register_schedule("reverse", schedule_injective)
 _reg.register_schedule("repeat", schedule_broadcast)
 _reg.register_schedule("tile", schedule_broadcast)
 _reg.register_schedule("cast", schedule_injective)
+_reg.register_schedule("cast_like", schedule_injective)
 _reg.register_schedule("reinterpret", schedule_injective)
 _reg.register_schedule("strided_slice", schedule_injective)
 _reg.register_schedule("slice_like", schedule_injective)
index 38ce653..7f921d0 100644 (file)
@@ -40,6 +40,23 @@ def cast(data, dtype):
     return _relay_make.cast(data, dtype)
 
 
+def cast_like(data, dtype_like):
+    """Cast input tensor to data type of another tensor.
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+    dtype_like: relay.Expr
+        The tensor to cast to.
+    Returns
+    -------
+    result : relay.Expr
+        The casted result.
+    """
+    from .. import _make as _relay_make
+    return _relay_make.cast_like(data, dtype_like)
+
+
 def reinterpret(data, dtype):
     """Reinterpret input tensor to data type.
 
index ed09516..459c27a 100644 (file)
@@ -98,6 +98,63 @@ RELAY_REGISTER_OP("cast")
 .set_attr<TOpPattern>("TOpPattern", kElemWise)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
+
+// relay.cast_like
+bool CastLikeRel(const Array<Type>& types,
+                 int num_inputs,
+                 const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "cast: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* dtype_like = types[1].as<TensorTypeNode>();
+  if (dtype_like == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "cast: expect input type to be TensorType but get "
+        << types[1];
+    return false;
+  }
+  reporter->Assign(types[2], TensorTypeNode::make(data->shape, dtype_like->dtype));
+  return true;
+}
+
+
+Array<Tensor> CastLikeCompute(const Attrs& attrs,
+                              const Array<Tensor>& inputs,
+                              const Type& out_type,
+                              const Target& target) {
+  return { topi::cast(inputs[0], inputs[1]->dtype) };
+}
+
+
+Expr MakeCastLike(Expr data,
+                  Expr dtype_like) {
+  static const Op& op = Op::Get("cast_like");
+  return CallNode::make(op, {data, dtype_like}, Attrs(), {});
+}
+
+
+TVM_REGISTER_API("relay._make.cast_like")
+.set_body_typed(MakeCastLike);
+
+RELAY_REGISTER_OP("cast_like")
+.describe(R"code(Cast the data into the type of another tensor.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("dtype_like", "Tensor", "The tensor to cast to.")
+.set_support_level(3)
+.add_type_rel("CastLike", CastLikeRel)
+.set_attr<FTVMCompute>("FTVMCompute", CastLikeCompute)
+.set_attr<TOpPattern>("TOpPattern", kElemWise)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+
 Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
                                  const Type& out_type, const Target& target) {
   const CastAttrs* param = attrs.as<CastAttrs>();
index cc57361..430c3dd 100644 (file)
@@ -58,5 +58,10 @@ def test_negative_grad():
     check_grad(fwd_func)
 
 
+def test_cast_grad():
+    data = relay.var("data", relay.TensorType((10, 4), "float32"))
+    fwd_func = relay.Function([data], relay.cast(data, "float64"))
+    check_grad(fwd_func)
+
 if __name__ == "__main__":
     pytest.main()