From 04e816241fdc8c38674277d3fd6dbb086f72fc3e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 18 Jun 2019 00:56:10 +0800 Subject: [PATCH] [Relay][Pass] CanonicalizeCast (#3280) --- include/tvm/relay/transform.h | 7 ++ python/tvm/relay/transform.py | 10 ++ src/relay/backend/build_module.cc | 1 + src/relay/pass/canonicalize_cast.cc | 144 ++++++++++++++++++++++ tests/python/relay/test_pass_canonicalize_cast.py | 70 +++++++++++ 5 files changed, 232 insertions(+) create mode 100644 src/relay/pass/canonicalize_cast.cc create mode 100644 tests/python/relay/test_pass_canonicalize_cast.py diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index fb8ebbf..04b4e64 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -534,6 +534,13 @@ TVM_DLL Pass CanonicalizeOps(); */ TVM_DLL Pass AlterOpLayout(); +/*! + * \brief Canonicalize cast expressions to make operator fusion more efficient. + * + * \return The pass. + */ +TVM_DLL Pass CanonicalizeCast(); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index d7a7c26..3fae615 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -445,6 +445,16 @@ def PartialEvaluate(): """ return _transform.PartialEvaluate() +def CanonicalizeCast(): + """ + Canonicalize cast expressions to make operator fusion more efficient. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that canonicalizes cast expression. 
+ """ + return _transform.CanonicalizeCast() def _wrap_class_module_pass(pass_cls, pass_info): """Wrap a python class as function pass""" diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index e0014e9..3feb7e4 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode { pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeCast()); pass_seqs.push_back(transform::CanonicalizeOps()); // Alter layout transformation is only applied to homogeneous execution yet. diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc new file mode 100644 index 0000000..99f4a7f --- /dev/null +++ b/src/relay/pass/canonicalize_cast.cc @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file canonicalize_cast.cc + * \brief Canonicalize cast expressions to make operator fusion more efficient. 
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/transform.h>
+#include "pattern_util.h"
+#include "pass_util.h"
+
+namespace tvm {
+namespace relay {
+
+// This pass finds upcast that is referred by multiple elemwise/broadcast operators, and creates a
+// copy of it in each branch such that after fusion the previous function has output with fewer
+// bits.
+//
+// Consider the following example:
+// \code
+// def @main(x: int8) {
+//   %1 = cast(%x, f32)
+//   %2 = exp(%1)
+//   %3 = log(%1)
+//   (%2, %3)
+// }
+// \endcode
+//
+// We would like to prevent sharing of the cast expression such that operator fusion can produce
+// more efficient result as below.
+// \code
+// def @main(x: int8) {
+//   %1 = fn (%p1: i8) {
+//     exp(cast(%p1, f32))
+//   }
+//   %3 = %1(%x)
+//   %2 = fn (%p1: i8) {
+//     log(cast(%p1, f32))
+//   }
+//   %4 = %2(%x)
+//   (%3, %4)
+// }
+// \endcode
+class CastCanonicalizer : public ExprMutator {
+ public:
+  Expr VisitExpr_(const CallNode* call) {
+    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+
+    if (const OpNode* opnode = call->op.as<OpNode>()) {
+      auto pattern = fpattern[GetRef<Op>(opnode)];
+      if (pattern <= kBroadcast) {
+        Array<Expr> call_args = call->args;
+        bool unchanged = true;
+        for (size_t i = 0; i < call_args.size(); ++i) {
+          Expr arg = call_args[i];
+          Expr new_arg = GetNewCallArg(arg);
+          if (!arg.same_as(new_arg)) {
+            call_args.Set(i, new_arg);
+            unchanged = false;
+          }
+        }
+        if (unchanged) {
+          return GetRef<Expr>(call);
+        }
+        return CallNode::make(call->op, call_args, call->attrs, call->type_args);
+      }
+    }
+
+    Expr new_expr = ExprMutator::VisitExpr_(call);
+    return new_expr;
+  }
+
+ private:
+  std::unordered_map<const Node*, size_t> ref_counter_;
+
+  Expr GetNewCallArg(const Expr& e) {
+    // if e is an upcast and ref count > 1, create a copy; otherwise call the default visitor
+
+    static auto& cast = Op::Get("cast");
+    Expr new_expr = this->VisitExpr(e);
+
+    if (const CallNode* call = e.as<CallNode>()) {
+      if (call->op.same_as(cast)) {
+        auto attrs = call->attrs.as<CastAttrs>();
+        const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
+        CHECK(from_type);
+
+        if (from_type->dtype.bits() < attrs->dtype.bits()) {
+          if (++ref_counter_[call] > 1) {
+            const CallNode* new_call = new_expr.as<CallNode>();
+            CHECK(new_call);
+            CHECK(new_call->op.same_as(cast));
+            return CallNode::make(new_call->op, new_call->args, new_call->attrs,
+                                  new_call->type_args);
+          }
+        }
+      }
+    }
+    return new_expr;
+  }
+};
+
+Expr CanonicalizeCast(const Expr& e) {
+  return CastCanonicalizer().Mutate(e);
+}
+
+namespace transform {
+
+Pass CanonicalizeCast() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(CanonicalizeCast(f));
+  };
+  return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.CanonicalizeCast")
+.set_body_typed(CanonicalizeCast);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_canonicalize_cast.py b/tests/python/relay/test_pass_canonicalize_cast.py
new file mode 100644
index 0000000..04478e9
--- /dev/null
+++ b/tests/python/relay/test_pass_canonicalize_cast.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +import tvm +import tvm.relay as relay +import tvm.relay.module as _module +import tvm.relay.transform as _transform + + +def test_canonicalize_cast(): + def before(data, conv_weight, bias1, bias2): + x = relay.nn.conv2d(data, conv_weight, + channels=16, + kernel_size=(3, 3), + padding=(1, 1), + out_dtype="int8") + x1 = relay.cast(x, dtype="int32") + y1 = relay.add(x1, bias1) + y2 = relay.add(x1, bias2) + y = relay.add(y1, y2) + return relay.Function([data, conv_weight, bias1, bias2], y) + + def expected(data, conv_weight, bias1, bias2): + x = relay.nn.conv2d(data, conv_weight, + channels=16, + kernel_size=(3, 3), + padding=(1, 1), + out_dtype="int8") + x1 = relay.cast(x, dtype="int32") + x2 = relay.cast(x, dtype="int32") + y1 = relay.add(x1, bias1) + y2 = relay.add(x2, bias2) + y = relay.add(y1, y2) + return relay.Function([data, conv_weight, bias1, bias2], y) + + def check(shape): + data = relay.var("data", shape=shape, dtype="int8") + conv_weight = relay.var("weight") + bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32") + bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32") + y = before(data, conv_weight, bias1, bias2) + mod = _module.Module.from_expr(y) + seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(), + _transform.InferType()]) + with _transform.PassContext(opt_level=3): + mod = seq(mod) + y = mod[mod.entry_func.name_hint] + y_expected = expected(data, conv_weight, bias1, bias2) + y_expected = relay.ir_pass.infer_type(y_expected) + assert relay.ir_pass.alpha_equal(y, y_expected) + + check((1, 16, 7, 7)) + + +if __name__ == '__main__': + test_canonicalize_cast() -- 2.7.4