template <typename ValueType>
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
/*!
+ * \brief Checks if an attr is present in the registry.
+ * \param attr_name The name of the attribute.
+ * \return bool True if the attr is present.
+ */
+ inline static bool HasAttr(const std::string& attr_name);
+ /*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
* \param op_name Name of the operator.
* \return reference to GenericOpMap
*/
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
+ /*!
+ * \brief Checks if the key is present in the registry
+ * \param key The attribute key
+ * \return bool True if the key is present
+ */
+ TVM_DLL static const bool HasGenericAttr(const std::string& key);
};
/*! \brief Helper structure to register operators */
return OpMap<ValueType>(Op::GetGenericAttr(key));
}
+inline bool Op::HasAttr(const std::string& key) {
+ return Op::HasGenericAttr(key);
+}
+
inline OpNode* OpRegistry::get() {
return const_cast<OpNode*>(op_.operator->());
}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/qnn/transform.h
+ *
+ * This file implements a pass manager for QNN ops using Relay Pass manager.
+ */
+#ifndef TVM_RELAY_QNN_TRANSFORM_H_
+#define TVM_RELAY_QNN_TRANSFORM_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+using relay::transform::Pass;
+
+namespace qnn {
+namespace transform {
+
+/*!
+ * \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First,
+ * converts/Lowers an expression containing QNN ops to an expression containing only core Relay ops.
+ * Each QNN op is lowered to a sequence of exisiting Relay ops. This is a target-independent pass.
+ * One can register the lowering/transformation function for this op using FTVMQnnCanonicalize
+ * attr_name for FTVMLegalize op attribute. Second, as opposed to Relay Legalize, this one legalizes
+ * only QNN ops. One can register a transformation/legalization function for an op by using the
+ * FTVMQnnLegalize attr_name for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize
+ * gives us separation of concerns, leading to a better software practice. The legalization can be
+ * configured to happen per target.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass Legalize();
+
+} // namespace transform
+
+} // namespace qnn
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_QNN_TRANSFORM_H_
#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
+#include <tvm/relay/qnn/transform.h>
#include <memory>
#include "utils.h"
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
Array<Pass> pass_seqs;
+
+ // Run all dialect legalization passes.
+ pass_seqs.push_back(relay::qnn::transform::Legalize());
+
+ // Legalize pass is restricted to homogeneous execution for now.
+ if (targets.size() == 1) {
+ pass_seqs.push_back(transform::Legalize());
+ }
+
pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());
- // Legalize pass is restricted to homogeneous execution for now.
- if (targets.size() == 1) {
- pass_seqs.push_back(transform::Legalize());
- }
-
// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
return *it->second.get();
}
+// Check if a key is present in the registry.
+const bool Op::HasGenericAttr(const std::string& key) {
+ OpManager* mgr = OpManager::Global();
+ std::lock_guard<std::mutex> lock(mgr->mutex);
+ auto it = mgr->attr.find(key);
+ if (it == mgr->attr.end()) {
+ return false;
+ }
+ return true;
+}
+
void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
Expr new_e = ExprMutator::VisitExpr_(call_node);
Call new_call = Downcast<Call>(new_e);
+ // Check if the string is registered in the OpRegistry.
+ if (!Op::HasAttr(legalize_map_attr_name_)) {
+ return new_e;
+ }
+
// Collect the registered legalize function.
auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
- Op op = Downcast<Op>(call_node->op);
-
- if (fop_legalize.count(op)) {
- // Collect the new_args.
- tvm::Array<Expr> call_args = new_call->args;
-
- // Collect input and output dtypes to pass on to Legalize API.
- tvm::Array<tvm::relay::Type> types;
- for (auto arg : call_node->args) {
- types.push_back(arg->checked_type());
- }
- types.push_back(call_node->checked_type());
-
- // Transform the op by calling the registered legalize function.
- Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
-
- // Reassign new_e if the transformation succeeded.
- if (legalized_value.defined()) {
- // Check that the returned Expr from legalize is CallNode.
- const CallNode* legalized_call_node = legalized_value.as<CallNode>();
- CHECK(legalized_call_node)
- << "Can only replace the original operator with another call node";
-
- new_e = legalized_value;
+ auto call_op = call_node->op;
+ if (call_op.as<OpNode>()) {
+ Op op = Downcast<Op>(call_node->op);
+
+ if (fop_legalize.count(op)) {
+ // Collect the new_args.
+ tvm::Array<Expr> call_args = new_call->args;
+
+ // Collect input and output dtypes to pass on to Legalize API.
+ tvm::Array<tvm::relay::Type> types;
+ for (auto arg : call_node->args) {
+ types.push_back(arg->checked_type());
+ }
+ types.push_back(call_node->checked_type());
+
+ // Transform the op by calling the registered legalize function.
+ Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
+
+ // Reassign new_e if the transformation succeeded.
+ if (legalized_value.defined()) {
+ // Check that the returned Expr from legalize is CallNode.
+ const CallNode* legalized_call_node = legalized_value.as<CallNode>();
+ CHECK(legalized_call_node)
+ << "Can only replace the original operator with another call node";
+
+ new_e = legalized_value;
+ }
}
}
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
- return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
+ return CreateFunctionPass(pass_func, 0, "Legalize", {ir::StringImm::make("InferType")});
}
TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/qnn/pass/legalize.cc
+ * \brief The Legalize wrapper for QNN.
+ */
+
+#include <tvm/relay/qnn/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+namespace transform {
+
+Pass Legalize() {
+ Array<Pass> pass_seqs;
+ pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
+ pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
+ relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
+ return seq;
+}
+
+TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize);
+
+} // namespace transform
+
+} // namespace qnn
+} // namespace relay
+} // namespace tvm
mod = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(mod)
- mod = relay.qnn.transform.CanonicalizeOps()(mod)
return mod
def get_funcs(data_shape,
input_zero_point=input_zero_point)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
- mod = relay.qnn.transform.CanonicalizeOps()(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
output_zero_point=output_zero_point,out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
- mod = relay.qnn.transform.CanonicalizeOps()(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
mod = relay.Function(relay.analysis.free_vars(mod), mod)
mod = relay.Module.from_expr(mod)
- mod = relay.qnn.transform.CanonicalizeOps()(mod)
return mod
def same_scale_test():