*/
TVM_DLL Pass AlterOpLayout();
+/*!
+ * \brief Canonicalize cast expressions to make operator fusion more efficient.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CanonicalizeCast();
+
} // namespace transform
} // namespace relay
} // namespace tvm
"""
return _transform.PartialEvaluate()
+def CanonicalizeCast():
+    """
+    Canonicalize cast expressions to make operator fusion more efficient.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that canonicalizes cast expressions.
+    """
+    return _transform.CanonicalizeCast()
def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
+ pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
// Alter layout transformation is only applied to homogeneous execution yet.
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file canonicalize_cast.cc
+ * \brief Canonicalize cast expressions to make operator fusion more efficient.
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
+#include "pattern_util.h"
+#include "pass_util.h"
+
+namespace tvm {
+namespace relay {
+
+// This pass finds upcast that is referred by multiple elemwise/broadcast operators, and creates a
+// copy of it in each branch such that after fusion the previous function have output with fewer
+// bits.
+//
+// Consider the following example:
+// \code
+// def @main(x: int8) {
+// %1 = cast(%x, f32)
+// %2 = exp(%1)
+// %3 = log(%1)
+//   (%2, %3)
+// }
+// \endcode
+//
+// We would like to prevent sharing of the cast expression such that operator fusion can produce
+// more efficient result as below.
+// \code
+// def @main(x: int8) {
+//   %1 = fn (%p1: i8) {
+//     exp(cast(%p1, f32))
+//   }
+//   %3 = %1(%x)
+//   %2 = fn (%p1: i8) {
+//     log(cast(%p1, f32))
+//   }
+//   %4 = %2(%x)
+//   (%3, %4)
+// }
+// \endcode
+class CastCanonicalizer : public ExprMutator {
+ public:
+  Expr VisitExpr_(const CallNode* call) {
+    // Fusion-pattern attribute of each operator; only elemwise/broadcast
+    // consumers can absorb a duplicated cast during fusion.
+    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+
+    if (const OpNode* opnode = call->op.as<OpNode>()) {
+      auto pattern = fpattern[GetRef<Op>(opnode)];
+      if (pattern <= kBroadcast) {
+        Array<Expr> call_args = call->args;
+        bool unchanged = true;
+        for (size_t i = 0; i < call_args.size(); ++i) {
+          Expr arg = call_args[i];
+          // May return a fresh copy of a shared upcast argument.
+          Expr new_arg = GetNewCallArg(arg);
+          if (!arg.same_as(new_arg)) {
+            call_args.Set(i, new_arg);
+            unchanged = false;
+          }
+        }
+        // Avoid allocating a new CallNode when no argument changed.
+        if (unchanged) {
+          return GetRef<Expr>(call);
+        }
+        return CallNode::make(call->op, call_args, call->attrs, call->type_args);
+      }
+    }
+
+    // Non-operator calls and non-fusable operators use the default mutation.
+    Expr new_expr = ExprMutator::VisitExpr_(call);
+    return new_expr;
+  }
+
+ private:
+  // Times each original upcast CallNode has been seen as an argument so far;
+  // the first use keeps the (memoized) mutation result, later uses get copies.
+  std::unordered_map<const Node*, size_t> ref_counter_;
+
+  Expr GetNewCallArg(const Expr& e) {
+    // if e is an upcast and ref count > 1, create a copy; otherwise call the default visitor
+
+    static auto& cast = Op::Get("cast");
+    // Visit first so the memoized result exists before we decide to copy it.
+    Expr new_expr = this->VisitExpr(e);
+
+    if (const CallNode* call = e.as<CallNode>()) {
+      if (call->op.same_as(cast)) {
+        auto attrs = call->attrs.as<CastAttrs>();
+        const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
+        CHECK(from_type);
+
+        // Only duplicate upcasts (bit width grows): copying them lets the
+        // producer's fused function output the narrower type.
+        if (from_type->dtype.bits() < attrs->dtype.bits()) {
+          if (++ref_counter_[call] > 1) {
+            const CallNode* new_call = new_expr.as<CallNode>();
+            CHECK(new_call);
+            CHECK(new_call->op.same_as(cast));
+            // Fresh CallNode: structurally equal to, but not shared with,
+            // the memoized cast.
+            return CallNode::make(new_call->op, new_call->args, new_call->attrs,
+                                  new_call->type_args);
+          }
+        }
+      }
+    }
+    return new_expr;
+  }
+};
+
+// Expression-level entry point: duplicate shared upcasts so each
+// elemwise/broadcast consumer gets its own copy.
+Expr CanonicalizeCast(const Expr& e) {
+  return CastCanonicalizer().Mutate(e);
+}
+
+namespace transform {
+
+// Module-level function pass wrapper (opt level 3). Declares a dependency on
+// InferType because GetNewCallArg reads the checked type of cast arguments.
+Pass CanonicalizeCast() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(CanonicalizeCast(f));
+    };
+  return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
+                            {ir::StringImm::make("InferType")});
+}
+
+// Exposes the pass to Python as relay._transform.CanonicalizeCast.
+TVM_REGISTER_API("relay._transform.CanonicalizeCast")
+.set_body_typed(CanonicalizeCast);
+
+}  // namespace transform
+
+} // namespace relay
+} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.relay as relay
+import tvm.relay.module as _module
+import tvm.relay.transform as _transform
+
+
+def test_canonicalize_cast():
+    """CanonicalizeCast should duplicate an upcast shared by two adds."""
+
+    def before(data, conv_weight, bias1, bias2):
+        # int8 conv output upcast once to int32; the cast is shared by both adds.
+        x = relay.nn.conv2d(data, conv_weight,
+                            channels=16,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            out_dtype="int8")
+        x1 = relay.cast(x, dtype="int32")
+        y1 = relay.add(x1, bias1)
+        y2 = relay.add(x1, bias2)
+        y = relay.add(y1, y2)
+        return relay.Function([data, conv_weight, bias1, bias2], y)
+
+    def expected(data, conv_weight, bias1, bias2):
+        # After the pass each add consumes its own copy of the upcast.
+        x = relay.nn.conv2d(data, conv_weight,
+                            channels=16,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            out_dtype="int8")
+        x1 = relay.cast(x, dtype="int32")
+        x2 = relay.cast(x, dtype="int32")
+        y1 = relay.add(x1, bias1)
+        y2 = relay.add(x2, bias2)
+        y = relay.add(y1, y2)
+        return relay.Function([data, conv_weight, bias1, bias2], y)
+
+    def check(shape):
+        # Build the module, run InferType -> CanonicalizeCast -> InferType,
+        # then compare against the expected function up to alpha-equivalence.
+        data = relay.var("data", shape=shape, dtype="int8")
+        conv_weight = relay.var("weight")
+        bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32")
+        bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
+        y = before(data, conv_weight, bias1, bias2)
+        mod = _module.Module.from_expr(y)
+        seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+                                     _transform.InferType()])
+        # opt_level=3 is required since the pass itself is registered at level 3.
+        with _transform.PassContext(opt_level=3):
+            mod = seq(mod)
+        y = mod[mod.entry_func.name_hint]
+        y_expected = expected(data, conv_weight, bias1, bias2)
+        y_expected = relay.ir_pass.infer_type(y_expected)
+        assert relay.ir_pass.alpha_equal(y, y_expected)
+
+    check((1, 16, 7, 7))
+
+
+# Allow running this test file directly.
+if __name__ == '__main__':
+    test_canonicalize_cast()