[QNN][Relay] Calling Dialect passes from inside Relay Build API. (#3971)
authorAnimesh Jain <anijain@umich.edu>
Wed, 2 Oct 2019 22:39:54 +0000 (15:39 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Wed, 2 Oct 2019 22:39:54 +0000 (15:39 -0700)
include/tvm/relay/op.h
include/tvm/relay/qnn/transform.h [new file with mode: 0644]
src/relay/backend/build_module.cc
src/relay/ir/op.cc
src/relay/pass/legalize.cc
src/relay/qnn/pass/legalize.cc [new file with mode: 0644]
tests/python/relay/test_op_qnn_conv2d.py
tests/python/relay/test_op_qnn_dequantize.py
tests/python/relay/test_op_qnn_quantize.py
tests/python/relay/test_op_qnn_requantize.py

index e4c9649..0a6d372 100644 (file)
@@ -154,6 +154,12 @@ class Op : public relay::Expr {
   template <typename ValueType>
   inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
   /*!
+   * \brief Checks if an attr is present in the registry.
+   * \param attr_name The name of the attribute.
+   * \return bool True if the attr is present.
+   */
+  inline static bool HasAttr(const std::string& attr_name);
+  /*!
    * \brief Get an Op for a given operator name.
    *  Will raise an error if the op has not been registered.
    * \param op_name Name of the operator.
@@ -171,6 +177,12 @@ class Op : public relay::Expr {
    * \return reference to GenericOpMap
    */
   TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
+  /*!
+   * \brief Checks if the key is present in the registry
+   * \param key The attribute key
+   * \return bool True if the key is present
+   */
+  TVM_DLL static const bool HasGenericAttr(const std::string& key);
 };
 
 /*! \brief Helper structure to register operators */
@@ -393,6 +405,10 @@ inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
   return OpMap<ValueType>(Op::GetGenericAttr(key));
 }
 
+inline bool Op::HasAttr(const std::string& key) {
+  return Op::HasGenericAttr(key);
+}
+
 inline OpNode* OpRegistry::get() {
   return const_cast<OpNode*>(op_.operator->());
 }
diff --git a/include/tvm/relay/qnn/transform.h b/include/tvm/relay/qnn/transform.h
new file mode 100644 (file)
index 0000000..10cd19a
--- /dev/null
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/qnn/transform.h
+ *
+ * This file implements a pass manager for QNN ops using Relay Pass manager.
+ */
+#ifndef TVM_RELAY_QNN_TRANSFORM_H_
+#define TVM_RELAY_QNN_TRANSFORM_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+using relay::transform::Pass;
+
+namespace qnn {
+namespace transform {
+
+/*!
+ * \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First,
+ * converts/Lowers an expression containing QNN ops to an expression containing only core Relay ops.
+ * Each QNN op is lowered to a sequence of exisiting Relay ops. This is a target-independent pass.
+ * One can register the lowering/transformation function for this op using FTVMQnnCanonicalize
+ * attr_name for FTVMLegalize op attribute. Second, as opposed to Relay Legalize, this one legalizes
+ * only QNN ops. One can register a transformation/legalization function for an op by using the
+ * FTVMQnnLegalize attr_name for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize
+ * gives us separation of concerns, leading to a better software practice. The legalization can be
+ * configured to happen per target.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass Legalize();
+
+}  // namespace transform
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_QNN_TRANSFORM_H_
index ef3ab72..20e760f 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/runtime/vm.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/transform.h>
+#include <tvm/relay/qnn/transform.h>
 #include <memory>
 
 #include "utils.h"
@@ -286,6 +287,15 @@ class RelayBuildModule : public runtime::ModuleNode {
       const TargetsMap& targets,
       const std::unordered_map<std::string, runtime::NDArray>& params) {
     Array<Pass> pass_seqs;
+
+    // Run all dialect legalization passes.
+    pass_seqs.push_back(relay::qnn::transform::Legalize());
+
+    // Legalize pass is restricted to homogeneous execution for now.
+    if (targets.size() == 1) {
+      pass_seqs.push_back(transform::Legalize());
+    }
+
     pass_seqs.push_back(transform::SimplifyInference());
     PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
       Expr expr = args[0];
@@ -309,11 +319,6 @@ class RelayBuildModule : public runtime::ModuleNode {
     pass_seqs.push_back(transform::CanonicalizeCast());
     pass_seqs.push_back(transform::CanonicalizeOps());
 
-    // Legalize pass is restricted to homogeneous execution for now.
-    if (targets.size() == 1) {
-      pass_seqs.push_back(transform::Legalize());
-    }
-
     // Alter layout transformation is only applied to homogeneous execution yet.
     if (targets.size() == 1) {
       pass_seqs.push_back(transform::AlterOpLayout());
index 76b56ae..d098863 100644 (file)
@@ -84,6 +84,17 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
   return *it->second.get();
 }
 
+// Check if a key is present in the registry.
+const bool Op::HasGenericAttr(const std::string& key) {
+  OpManager* mgr = OpManager::Global();
+  std::lock_guard<std::mutex> lock(mgr->mutex);
+  auto it = mgr->attr.find(key);
+  if (it == mgr->attr.end()) {
+    return false;
+  }
+  return true;
+}
+
 void OpRegistry::UpdateAttr(const std::string& key,
                             TVMRetValue value,
                             int plevel) {
index 07b1d81..f57d910 100644 (file)
@@ -46,32 +46,40 @@ class Legalizer : public ExprMutator {
     Expr new_e = ExprMutator::VisitExpr_(call_node);
     Call new_call = Downcast<Call>(new_e);
 
+    // Check if the string is registered in the OpRegistry.
+    if (!Op::HasAttr(legalize_map_attr_name_)) {
+      return new_e;
+    }
+
     // Collect the registered legalize function.
     auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
-    Op op = Downcast<Op>(call_node->op);
-
-    if (fop_legalize.count(op)) {
-      // Collect the new_args.
-      tvm::Array<Expr> call_args = new_call->args;
-
-      // Collect input and output dtypes to pass on to Legalize API.
-      tvm::Array<tvm::relay::Type> types;
-      for (auto arg : call_node->args) {
-        types.push_back(arg->checked_type());
-      }
-      types.push_back(call_node->checked_type());
-
-      // Transform the op by calling the registered legalize function.
-      Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
-
-      // Reassign new_e if the transformation succeeded.
-      if (legalized_value.defined()) {
-        // Check that the returned Expr from legalize is CallNode.
-        const CallNode* legalized_call_node = legalized_value.as<CallNode>();
-        CHECK(legalized_call_node)
-            << "Can only replace the original operator with another call node";
-
-        new_e = legalized_value;
+    auto call_op = call_node->op;
+    if (call_op.as<OpNode>()) {
+      Op op = Downcast<Op>(call_node->op);
+
+      if (fop_legalize.count(op)) {
+        // Collect the new_args.
+        tvm::Array<Expr> call_args = new_call->args;
+
+        // Collect input and output dtypes to pass on to Legalize API.
+        tvm::Array<tvm::relay::Type> types;
+        for (auto arg : call_node->args) {
+          types.push_back(arg->checked_type());
+        }
+        types.push_back(call_node->checked_type());
+
+        // Transform the op by calling the registered legalize function.
+        Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
+
+        // Reassign new_e if the transformation succeeded.
+        if (legalized_value.defined()) {
+          // Check that the returned Expr from legalize is CallNode.
+          const CallNode* legalized_call_node = legalized_value.as<CallNode>();
+          CHECK(legalized_call_node)
+              << "Can only replace the original operator with another call node";
+
+          new_e = legalized_value;
+        }
       }
     }
 
@@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
       [=](Function f, Module m, PassContext pc) {
         return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
       };
-  return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
+  return CreateFunctionPass(pass_func, 0, "Legalize", {ir::StringImm::make("InferType")});
 }
 
 TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
diff --git a/src/relay/qnn/pass/legalize.cc b/src/relay/qnn/pass/legalize.cc
new file mode 100644 (file)
index 0000000..07864ad
--- /dev/null
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/qnn/pass/legalize.cc
+ * \brief The Legalize wrapper for QNN.
+ */
+
+#include <tvm/relay/qnn/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+namespace transform {
+
+Pass Legalize() {
+  Array<Pass> pass_seqs;
+  pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
+  pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
+  relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
+  return seq;
+}
+
+TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize);
+
+}  // namespace transform
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
index dd4ad8d..c8e479d 100644 (file)
@@ -77,7 +77,6 @@ def get_qnn_func(data,
 
     mod = relay.Function(relay.analysis.free_vars(func), func)
     mod = relay.Module.from_expr(mod)
-    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     return mod
 
 def get_funcs(data_shape,
index 76b61ae..5125865 100644 (file)
@@ -31,7 +31,6 @@ def test_dequantize_op():
                                                    input_zero_point=input_zero_point)
         mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
         mod = relay.Module.from_expr(mod)
-        mod = relay.qnn.transform.CanonicalizeOps()(mod)
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(mod, "llvm", params=None)
             rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
index 15319a7..9805db5 100644 (file)
@@ -31,7 +31,6 @@ def test_quantize_op():
                                                  output_zero_point=output_zero_point,out_dtype=out_dtype)
         mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
         mod = relay.Module.from_expr(mod)
-        mod = relay.qnn.transform.CanonicalizeOps()(mod)
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(mod, "llvm", params=None)
             rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
index 131500c..18e2f30 100644 (file)
@@ -49,7 +49,6 @@ def test_requantize():
 
         mod = relay.Function(relay.analysis.free_vars(mod), mod)
         mod = relay.Module.from_expr(mod)
-        mod = relay.qnn.transform.CanonicalizeOps()(mod)
         return mod
 
     def same_scale_test():