[Relay][Pass] CanonicalizeCast (#3280)
authorWuwei Lin <vincentl13x@gmail.com>
Mon, 17 Jun 2019 16:56:10 +0000 (00:56 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 17 Jun 2019 16:56:10 +0000 (09:56 -0700)
include/tvm/relay/transform.h
python/tvm/relay/transform.py
src/relay/backend/build_module.cc
src/relay/pass/canonicalize_cast.cc [new file with mode: 0644]
tests/python/relay/test_pass_canonicalize_cast.py [new file with mode: 0644]

index fb8ebbf..04b4e64 100644 (file)
@@ -534,6 +534,13 @@ TVM_DLL Pass CanonicalizeOps();
  */
 TVM_DLL Pass AlterOpLayout();
 
+/*!
+ * \brief Canonicalize cast expressions to make operator fusion more efficient.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CanonicalizeCast();
+
 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm
index d7a7c26..3fae615 100644 (file)
@@ -445,6 +445,16 @@ def PartialEvaluate():
     """
     return _transform.PartialEvaluate()
 
+def CanonicalizeCast():
+    """
+    Canonicalize cast expressions to make operator fusion more efficient.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that canonicalizes cast expression.
+    """
+    return _transform.CanonicalizeCast()
 
 def _wrap_class_module_pass(pass_cls, pass_info):
     """Wrap a python class as function pass"""
index e0014e9..3feb7e4 100644 (file)
@@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     pass_seqs.push_back(transform::CombineParallelConv2D(3));
     pass_seqs.push_back(transform::FoldConstant());
     pass_seqs.push_back(transform::FoldScaleAxis());
+    pass_seqs.push_back(transform::CanonicalizeCast());
     pass_seqs.push_back(transform::CanonicalizeOps());
 
     // Alter layout transformation is only applied to homogeneous execution yet.
diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc
new file mode 100644 (file)
index 0000000..99f4a7f
--- /dev/null
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file canonicalize_cast.cc
+ * \brief Canonicalize cast expressions to make operator fusion more efficient.
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
+#include "pattern_util.h"
+#include "pass_util.h"
+
+namespace tvm {
+namespace relay {
+
+// This pass finds upcast that is referred by multiple elemwise/broadcast operators, and creates a
+// copy of it in each branch such that after fusion the previous function have output with fewer
+// bits.
+//
+// Consider the following example:
+// \code
+// def @main(x: int8) {
+//   %1 = cast(%x, f32)
+//   %2 = exp(%1)
+//   %3 = log(%1)
+//   (%3, 4)
+// }
+// \endcode
+//
+// We would like to prevent sharing of the cast expression such that operator fusion can produce
+// more efficient result as below.
+// \code
+// def @main(x: int8) {
+//   %1 = fn (%p1: i8) {
+//     exp(cast(%p1, f32)
+//   }
+//   %3 = %1(%x)
+//   %2 = fn (%p1: i8) {
+//     log(cast(%p1, f32)
+//   }
+//   %4 = %2(%x)
+//   (%3, 4)
+// }
+// \endcode
+class CastCanonicalizer : public ExprMutator {
+ public:
+  Expr VisitExpr_(const CallNode* call) {
+    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+
+    if (const OpNode* opnode = call->op.as<OpNode>()) {
+      auto pattern = fpattern[GetRef<Op>(opnode)];
+      if (pattern <= kBroadcast) {
+        Array<Expr> call_args = call->args;
+        bool unchanged = true;
+        for (size_t i = 0; i < call_args.size(); ++i) {
+          Expr arg = call_args[i];
+          Expr new_arg = GetNewCallArg(arg);
+          if (!arg.same_as(new_arg)) {
+            call_args.Set(i, new_arg);
+            unchanged = false;
+          }
+        }
+        if (unchanged) {
+          return GetRef<Expr>(call);
+        }
+        return CallNode::make(call->op, call_args, call->attrs, call->type_args);
+      }
+    }
+
+    Expr new_expr = ExprMutator::VisitExpr_(call);
+    return new_expr;
+  }
+
+ private:
+  std::unordered_map<const Node*, size_t> ref_counter_;
+
+  Expr GetNewCallArg(const Expr& e) {
+    // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor
+
+    static auto& cast = Op::Get("cast");
+    Expr new_expr = this->VisitExpr(e);
+
+    if (const CallNode* call = e.as<CallNode>()) {
+      if (call->op.same_as(cast)) {
+        auto attrs = call->attrs.as<CastAttrs>();
+        const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
+        CHECK(from_type);
+
+        if (from_type->dtype.bits() < attrs->dtype.bits()) {
+          if (++ref_counter_[call] > 1) {
+            const CallNode* new_call = new_expr.as<CallNode>();
+            CHECK(new_call);
+            CHECK(new_call->op.same_as(cast));
+            return CallNode::make(new_call->op, new_call->args, new_call->attrs,
+                 new_call->type_args);
+          }
+        }
+      }
+    }
+    return new_expr;
+  }
+};
+
+Expr CanonicalizeCast(const Expr& e) {
+  return CastCanonicalizer().Mutate(e);
+}
+
+namespace transform {
+
+Pass CanonicalizeCast() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(CanonicalizeCast(f));
+  };
+  return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.CanonicalizeCast")
+.set_body_typed(CanonicalizeCast);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_canonicalize_cast.py b/tests/python/relay/test_pass_canonicalize_cast.py
new file mode 100644 (file)
index 0000000..04478e9
--- /dev/null
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.relay as relay
+import tvm.relay.module as _module
+import tvm.relay.transform as _transform
+
+
+def test_canonicalize_cast():
+    def before(data, conv_weight, bias1, bias2):
+        x = relay.nn.conv2d(data, conv_weight,
+                          channels=16,
+                          kernel_size=(3, 3),
+                          padding=(1, 1),
+                          out_dtype="int8")
+        x1 = relay.cast(x, dtype="int32")
+        y1 = relay.add(x1, bias1)
+        y2 = relay.add(x1, bias2)
+        y = relay.add(y1, y2)
+        return relay.Function([data, conv_weight, bias1, bias2], y)
+
+    def expected(data, conv_weight, bias1, bias2):
+        x = relay.nn.conv2d(data, conv_weight,
+                          channels=16,
+                          kernel_size=(3, 3),
+                          padding=(1, 1),
+                          out_dtype="int8")
+        x1 = relay.cast(x, dtype="int32")
+        x2 = relay.cast(x, dtype="int32")
+        y1 = relay.add(x1, bias1)
+        y2 = relay.add(x2, bias2)
+        y = relay.add(y1, y2)
+        return relay.Function([data, conv_weight, bias1, bias2], y)
+
+    def check(shape):
+        data = relay.var("data", shape=shape, dtype="int8")
+        conv_weight = relay.var("weight")
+        bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32")
+        bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
+        y = before(data, conv_weight, bias1, bias2)
+        mod = _module.Module.from_expr(y)
+        seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+                                     _transform.InferType()])
+        with _transform.PassContext(opt_level=3):
+            mod = seq(mod)
+        y = mod[mod.entry_func.name_hint]
+        y_expected = expected(data, conv_weight, bias1, bias2)
+        y_expected = relay.ir_pass.infer_type(y_expected)
+        assert relay.ir_pass.alpha_equal(y, y_expected)
+
+    check((1, 16, 7, 7))
+
+
+if __name__ == '__main__':
+    test_canonicalize_cast()