{relay,topi}.reinterpret support (#3599)
authorAndrew Tulloch <andrew@tullo.ch>
Tue, 23 Jul 2019 21:43:27 +0000 (14:43 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 23 Jul 2019 21:43:27 +0000 (14:43 -0700)
= Motivation

It's useful to expose the tvm::reinterpret functionality to Relay/TOPI users, as
this allows them to build (fused) operators leveraging the bitwise
reinterpretation of an operator. An example is approximate transcendental
functions, which can be implemented similar to:

```.py
    def C(x):
        return relay.expr.const(x, "float32")

    def approx_exp(x):
        x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
        x = C(127.0) + x * C(1.44269504)
        xf = relay.floor(x)
        i = relay.cast(xf, "int32")
        x = x - xf
        Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
        exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
        exponent = relay.reinterpret(exponent, "float32")
        return exponent * Y

    def approx_sigmoid(x):
        # <2.0e-5 absolute error over [-5, 5]
        y = approx_exp(x)
        return y / (y + C(1.0))

    def approx_tanh(x):
        # <4.0e-5 absolute error over [-5, 5]
        x = x * C(2.0)
        y = approx_exp(x)
        return (y - C(1.0)) / (y + C(1.0))
```

See unit tests for implementations of these approximate transendentals.

12 files changed:
docs/api/python/topi.rst
docs/langref/relay_op.rst
python/tvm/relay/op/_transform.py
python/tvm/relay/op/transform.py
src/codegen/codegen_c.cc
src/relay/op/tensor/transform.cc
tests/python/relay/test_op_level3.py
tests/python/unittest/test_codegen_c_host.py
topi/include/topi/elemwise.h
topi/python/topi/math.py
topi/src/topi.cc
topi/tests/python/test_topi_transform.py

index 9ac8bb1..8f59e08 100644 (file)
@@ -40,6 +40,7 @@ List of operators
    topi.sigmoid
    topi.clip
    topi.cast
+   topi.reinterpret
    topi.transpose
    topi.flip
    topi.strided_slice
@@ -133,6 +134,7 @@ topi
 .. autofunction:: topi.sigmoid
 .. autofunction:: topi.clip
 .. autofunction:: topi.cast
+.. autofunction:: topi.reinterpret
 .. autofunction:: topi.transpose
 .. autofunction:: topi.flip
 .. autofunction:: topi.strided_slice
index dad5eb8..61c9b36 100644 (file)
@@ -114,6 +114,7 @@ This level enables additional math and transform operators.
    tvm.relay.full
    tvm.relay.full_like
    tvm.relay.cast
+   tvm.relay.reinterpret
    tvm.relay.split
    tvm.relay.arange
    tvm.relay.stack
@@ -263,6 +264,7 @@ Level 3 Definitions
 .. autofunction:: tvm.relay.full
 .. autofunction:: tvm.relay.full_like
 .. autofunction:: tvm.relay.cast
+.. autofunction:: tvm.relay.reinterpret
 .. autofunction:: tvm.relay.split
 .. autofunction:: tvm.relay.arange
 .. autofunction:: tvm.relay.stack
index 0749bbd..51e7615 100644 (file)
@@ -40,6 +40,7 @@ _reg.register_schedule("reverse", schedule_injective)
 _reg.register_schedule("repeat", schedule_broadcast)
 _reg.register_schedule("tile", schedule_broadcast)
 _reg.register_schedule("cast", schedule_injective)
+_reg.register_schedule("reinterpret", schedule_injective)
 _reg.register_schedule("strided_slice", schedule_injective)
 _reg.register_schedule("slice_like", schedule_injective)
 _reg.register_schedule("split", schedule_injective)
index 5137a9c..5d8d280 100644 (file)
@@ -40,6 +40,26 @@ def cast(data, dtype):
     return _relay_make.cast(data, dtype)
 
 
+def reinterpret(data, dtype):
+    """Reinterpret input tensor to data type.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    dtype: str
+        The target data type
+
+    Returns
+    -------
+    result : relay.Expr
+        The reinterpreted result.
+    """
+    from .. import _make as _relay_make
+    return _relay_make.reinterpret(data, dtype)
+
+
 def expand_dims(data, axis, num_newaxis=1):
     """Insert `num_newaxis` axises at the position given by `axis`.
 
index 19f7a27..bbd28ba 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -569,6 +569,13 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) {  // NOLINT(*)
     os << "(";
     this->PrintExpr(op->args[0], os);
     os << " == NULL)";
+  } else if (op->is_intrinsic(Call::reinterpret)) {
+    // generate (*( TYPE *)(&(ARG)))
+    os << "(*(";
+    this->PrintType(op->type, os);
+    os << " *)(&(";
+    this->PrintExpr(op->args[0], os);
+    os << ")))";
   } else {
     if (op->call_type == Call::Intrinsic ||
         op->call_type == Call::PureIntrinsic) {
index 5942488..0b501e2 100644 (file)
@@ -97,6 +97,37 @@ RELAY_REGISTER_OP("cast")
 .set_attr<TOpPattern>("TOpPattern", kElemWise)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
+Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
+                                 const Type& out_type, const Target& target) {
+  const CastAttrs* param = attrs.as<CastAttrs>();
+  CHECK(param != nullptr);
+  DataType dtype = param->dtype;
+  return {topi::reinterpret(inputs[0], dtype)};
+}
+
+Expr MakeReinterpret(Expr data, DataType dtype) {
+  auto attrs = make_node<CastAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("reinterpret");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) {
+  runtime::detail::unpack_call<Expr, 2>(MakeReinterpret, args, rv);
+});
+
+RELAY_REGISTER_OP("reinterpret")
+    .describe(R"code(Reinterpret the data into a new data type.
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .set_attrs_type_key("relay.attrs.CastAttrs")
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(3)
+    .add_type_rel("Reinterpret", CastRel)
+    .set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
 // relay.expand_dims
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
 
index da3de2b..01c0a12 100644 (file)
@@ -75,6 +75,7 @@ def test_cast():
     assert "dtype=" in yy.astext()
     assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
 
+
 def test_clip():
     a = relay.var("a", relay.TensorType((10, 4), "float32"))
     y = relay.clip(a, 1., 4.)
@@ -88,6 +89,69 @@ def test_clip():
     np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
 
+def test_reinterpret():
+    a = relay.var("a", relay.TensorType((1000, 4), "float32"))
+    y = relay.reinterpret(a, "int32")
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000, 4), "int32")
+
+    data = np.random.randn(1000, 4).astype('float32') * 1000
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+    ref_res = data.view("int32")
+    np.testing.assert_equal(op_res.asnumpy(), ref_res)
+
+
+def test_approximate_transcendental():
+    def C(x):
+        return relay.expr.const(x, "float32")
+
+    def approx_exp(x):
+        # An approximation derived from Opus,
+        # https://github.com/xiph/opus/blob/c1c247/celt/mathops.h#L147-L165
+        x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
+        x = C(127.0) + x * C(1.44269504)
+        xf = relay.floor(x)
+        i = relay.cast(xf, "int32")
+        x = x - xf
+        Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
+        exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
+        exponent = relay.reinterpret(exponent, "float32")
+        return exponent * Y
+
+    def approximate_sigmoid(x):
+        y = approx_exp(x)
+        return y / (y + C(1.0))
+
+    def approximate_tanh(x):
+        x = x * C(2.0)
+        y = approx_exp(x)
+        return (y - C(1.0)) / (y + C(1.0))
+
+    a = relay.var("a", relay.TensorType((1000,), "float32"))
+    y = approximate_sigmoid(a)
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000,), "float32")
+    data = np.linspace(-5, 5, 1000).astype("float32")
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+
+    def reference_sigmoid(x):
+        return np.exp(-np.logaddexp(0, -x))
+    np.testing.assert_allclose(op_res.asnumpy(), reference_sigmoid(data), atol=2e-5, rtol=1e-9)
+
+    y = approximate_tanh(a)
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000,), "float32")
+    data = np.linspace(-5, 5, 1000).astype("float32")
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+
+    def reference_tanh(x):
+        return np.tanh(x)
+    np.testing.assert_allclose(op_res.asnumpy(), reference_tanh(data), atol=4e-5, rtol=1e-9)
+
+
 def test_squeeze():
     def verify_squeeze(shape, dtype, axis):
         x = relay.var("x", relay.TensorType(shape, dtype))
index 5161c68..70b38e1 100644 (file)
@@ -95,6 +95,31 @@ def test_add_pipeline():
     with tvm.build_config(offset_factor=4):
         check_c()
 
+
+def test_reinterpret():
+    nn = 1024
+    n = tvm.convert(nn)
+    A = tvm.placeholder((n,), name='A', dtype="int32")
+    B = tvm.compute(A.shape, lambda *i: tvm.call_pure_intrin("float32", "reinterpret", A(*i)), name='B')
+    s = tvm.create_schedule(B.op)
+
+    def check_c():
+        mhost = tvm.build(s, [A, B], "c", name="reinterpret")
+        temp = util.tempdir()
+        path_dso = temp.relpath("temp.so")
+        mhost.export_library(path_dso)
+        m = tvm.module.load(path_dso)
+        fadd = m['reinterpret']
+        ctx = tvm.cpu(0)
+        n = nn
+        a = tvm.nd.array(np.random.randint(-2 ** 30, 2 ** 30, size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
+        fadd(a, b)
+        tvm.testing.assert_allclose(
+            b.asnumpy(), a.asnumpy().view('float32'))
+    check_c()
+
 if __name__ == "__main__":
     test_add()
     test_add_pipeline()
+    test_reinterpret()
index b6e6ada..000567e 100644 (file)
@@ -269,14 +269,34 @@ inline Tensor cast(const Tensor& x,
 }
 
 /*!
-* \brief Creates an operation that sum each element of a tensor
-*
-* \param xs The input tensor array
-* \param name The name of the operation
-* \param tag The tag to mark the operation
-*
-* \return A Tensor whose op member is the sum operation
-*/
+ * \brief Reinterpret each element of x to the given type.
+
+ * \param x The input tensor
+ * \param type The type to cast to
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the reinterpret operation
+ */
+inline Tensor reinterpret(const Tensor& x, Type type, std::string name = "tensor",
+                          std::string tag = kElementWise) {
+  return compute(x->shape,
+                 [&](const Array<Var>& i) {
+                   return tvm::ir::Call::make(type, "reinterpret", {x(i)},
+                                              tvm::ir::Call::PureIntrinsic);
+                 },
+                 name, tag);
+}
+
+/*!
+ * \brief Creates an operation that sum each element of a tensor
+ *
+ * \param xs The input tensor array
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the sum operation
+ */
 inline Tensor elemwise_sum(const Array<Tensor>& xs,
                            std::string name = "T_elemwise_sum",
                            std::string tag = kElementWise) {
index 406d489..87ac06c 100644 (file)
@@ -343,3 +343,21 @@ def cast(x, dtype):
         return tvm.compute(
             x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
     return tvm.make._cast(dtype, x)
+
+def reinterpret(x, dtype):
+    """Reinterpret input to specified data type.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    dtype : str
+        Data type.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return cpp.reinterpret(x, dtype)
index 44134d7..6c5a0b4 100644 (file)
@@ -193,6 +193,12 @@ TVM_REGISTER_GLOBAL("topi.cast")
   *rv = cast(args[0], args[1]);
   });
 
+
+TVM_REGISTER_GLOBAL("topi.reinterpret")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = reinterpret(args[0], args[1]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.elemwise_sum")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = elemwise_sum(args[0]);
index 7f2c73e..f069303 100644 (file)
@@ -45,6 +45,29 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
         check_device(device)
 
 
+def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
+    A = tvm.placeholder(shape=in_shape, name="A", dtype=in_dtype)
+    B = topi.reinterpret(A, out_dtype)
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_elemwise(B)
+        foo = tvm.build(s, [A, B], device, name="reinterpret")
+        data_npy = generator(in_shape).astype(in_dtype)
+        out_npy = data_npy.view(B.dtype)
+        data_nd = tvm.nd.array(data_npy, ctx)
+        out_nd = tvm.nd.array(np.empty(in_shape).astype(B.dtype), ctx)
+        foo(data_nd, out_nd)
+        np.testing.assert_equal(out_nd.asnumpy(), out_npy)
+
+    for device in get_all_backend():
+        check_device(device)
+
+
 def verify_transpose(in_shape, axes):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.transpose(A, axes)
@@ -434,6 +457,19 @@ def test_expand_dims():
     verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
 
 
+def test_reinterpret():
+    verify_reinterpret((1000,), "float32", "int32",
+                       lambda shape: np.random.randn(*shape) * 1000)
+    verify_reinterpret((1000,), "float16", "int16",
+                       lambda shape: np.random.randn(*shape) * 100)
+    verify_reinterpret((1000,), "int16", "uint16",
+                       lambda shape: np.random.randint(-1000, 1000, size=shape))
+    verify_reinterpret((1000,), "uint32", "int32",
+                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))
+    verify_reinterpret((1000,), "uint32", "int32",
+                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))
+
+
 def test_transpose():
     verify_transpose((3, 10, 2), (1, 0, 2))
     verify_transpose((3, 10, 5), (2, 0, 1))