From ec7790e355f00bf5b0764d223087dc0f63176122 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 30 Aug 2019 21:30:18 -0700 Subject: [PATCH] [QNN] Concat - Refactoring to C++ (#3819) --- include/tvm/relay/qnn/attrs.h | 30 ++++++- python/tvm/relay/qnn/op/qnn.py | 55 ++++-------- src/relay/qnn/op/concatenate.cc | 134 +++++++++++++++++++++++++++++ src/relay/qnn/util.h | 19 ++++ tests/python/relay/test_qnn_concatenate.py | 4 - 5 files changed, 197 insertions(+), 45 deletions(-) create mode 100644 src/relay/qnn/op/concatenate.cc diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 1ebdeaa..b8d775c 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -88,7 +88,7 @@ struct DequantizeAttrs : public tvm::AttrsNode { int32_t input_zero_point; double input_scale; - TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") { + TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") { TVM_ATTR_FIELD(input_zero_point) .describe("The zero_point for the input tensor of this op."); @@ -97,6 +97,34 @@ struct DequantizeAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in QNN concatenate operator */ +struct QnnConcatenateAttrs : public tvm::AttrsNode { + Array input_scales; + Array input_zero_points; + double output_scale; + int32_t output_zero_point; + int axis; + + TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") { + TVM_ATTR_FIELD(input_scales) + .describe("The list of scales of input quantized tensors."); + + TVM_ATTR_FIELD(input_zero_points) + .describe("The list of zero points of input quantized tensors."); + + TVM_ATTR_FIELD(output_zero_point) + .describe("The zero_point for the output tensor."); + + TVM_ATTR_FIELD(output_scale) + .describe("The scale for the output tensor."); + + TVM_ATTR_FIELD(axis) + .describe("The axis at which the input arrays are concatenated." + "Should lie in range `[-ndim, ndim)`.") + .set_default(0); + } +}; // struct QnnConcatenateAttrs + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index b153cd5..7eb0408 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -18,7 +18,8 @@ """QNN dialect operators.""" from __future__ import absolute_import as _abs -from tvm import relay +from tvm.expr import FloatImm, IntImm +from tvm.relay.expr import Tuple from . import _make def requantize(data, @@ -134,6 +135,8 @@ def dequantize(data, return _make.dequantize(data, input_scale, input_zero_point) + + def concatenate(data, input_scales, input_zero_points, @@ -169,42 +172,14 @@ def concatenate(data, """ data = list(data) - requantized_exprs = list(data) - - # Find the dtype of the input expr. This is required for the requantize op. Since, this is - # concatenate op, the dtype of the input is same as dtype of the output. - mod = relay.Module.from_expr(data[0]) - mod = relay.transform.InferType()(mod) - entry = mod["main"] - data0 = entry if isinstance(data[0], relay.Function) else entry.body - in_dtype = data0.checked_type.dtype - - # First check if all the input qnn params match. If yes, we can call concatenate first, followed - # by a requantize. - if all(scale == input_scales[0] for scale in input_scales)\ - and all(zero_point == input_zero_points[0] for zero_point in input_zero_points): - out = relay.concatenate(tuple(data), axis) - input_scale = input_scales[0] - input_zero_point = input_zero_points[0] - if input_scale != output_scale or input_zero_point != output_zero_point: - out = requantize(data=out, - input_scale=input_scales[0], - input_zero_point=input_zero_points[0], - output_scale=output_scale, - output_zero_point=output_zero_point, - out_dtype=in_dtype) - return out - - # If the output qnn params do not match the input qnn params, we can call requantize on the - # input expr first, followed by a concatenate on the requantized input exprs. - for idx, quantized_expr in enumerate(data): - input_scale = input_scales[idx] - input_zero_point = input_zero_points[idx] - if input_scale != output_scale or input_zero_point != output_zero_point: - requantized_exprs[idx] = requantize(data=quantized_expr, - input_scale=input_scale, - input_zero_point=input_zero_point, - output_scale=output_scale, - output_zero_point=output_zero_point, - out_dtype=in_dtype) - return relay.concatenate(tuple(requantized_exprs), axis) + if not data: + raise ValueError("relay.concatenate requires data to be non-empty.") + if not isinstance(axis, int): + raise ValueError("For now, we only support integer axis") + + return _make.concatenate(Tuple(data), + [FloatImm("float64", x) for x in input_scales], + [IntImm("int32", x) for x in input_zero_points], + output_scale, + output_zero_point, + axis) diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc new file mode 100644 index 0000000..e87eaa1 --- /dev/null +++ b/src/relay/qnn/op/concatenate.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/qnn/op/concatenate.cc + * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis. + */ + +#include +#include +#include +#include +#include "../../op/tensor/transform.h" +#include "../../pass/pattern_util.h" +#include "../util.h" + +namespace tvm { +namespace relay { +namespace qnn { + +TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs); + +Expr MakeQnnConcatenate(Expr data, Array input_scales, + Array input_zero_points, double output_scale, + int32_t output_zero_point, int axis) { + auto attrs = make_node(); + attrs->input_scales = std::move(input_scales); + attrs->input_zero_points = std::move(input_zero_points); + attrs->output_scale = output_scale; + attrs->output_zero_point = output_zero_point; + attrs->axis = axis; + static const Op& op = Op::Get("qnn.concatenate"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +/* + * \brief Canonicalizes the QNN concatenate op. + * \param attrs The QNN concatenate attrs. + * \param new_args The new mutated args to the call node. + * \param arg_types The types of input and output. + * \return The sequence of Relay ops for concatenate op. + */ +Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, + const Array& arg_types) { + // Get the attrs. + CHECK_EQ(new_args.size(), 1); + auto& data = new_args[0]; + const auto* concatenate_attrs = attrs.as(); + CHECK(concatenate_attrs != nullptr); + auto input_scales = concatenate_attrs->input_scales; + auto input_zero_points = concatenate_attrs->input_zero_points; + auto output_scale = concatenate_attrs->output_scale; + auto output_zero_point = concatenate_attrs->output_zero_point; + + // Get the input dtype and shape. + CHECK_GE(arg_types.size(), 1); + auto tuple_type = arg_types[0].as(); + CHECK(tuple_type != nullptr); + + // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in + // the start, we can insert requantize at the end if and only if all the input tensors have same + // qnn params. This can be done in future. + + // If the output qnn params do not match the input qnn params, we can call requantize on the input + // expr first, followed by a concatenate on the requantized input exprs. + + auto tuple_data = data.as(); + CHECK(tuple_data != nullptr); + + int idx = 0; + Array requantized_exprs; + for (auto quantized_expr : tuple_data->fields) { + // Get the input scale for the idx quantized input tensor. + auto input_scale_expr = input_scales[idx].as(); + CHECK(input_scale_expr != nullptr); + auto input_scale = input_scale_expr->value; + + // Get the zero point for the idx quantized input tensor. + auto input_zero_point_expr = input_zero_points[idx].as(); + CHECK(input_zero_point_expr != nullptr); + auto input_zero_point = input_zero_point_expr->value; + + // Check if output and input qnn params are same. If not, requantize. + if (input_scale != output_scale || input_zero_point != output_zero_point) { + // Get the input shape and dtype. + auto tensor_type = tuple_type->fields[idx].as(); + auto input_dtype = tensor_type->dtype; + auto input_shape = tensor_type->shape; + + // Requantize the input. + auto requantized_expr = Requantize(quantized_expr, input_shape, input_scale, input_zero_point, + output_scale, output_zero_point, input_dtype); + requantized_exprs.push_back(requantized_expr); + } else { + requantized_exprs.push_back(quantized_expr); + } + idx++; + } + return MakeConcatenate(TupleNode::make(requantized_exprs), concatenate_attrs->axis); +} + +RELAY_REGISTER_OP("qnn.concatenate") +.describe(R"code(Concatenate the quantized input tensors along the given axis. +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.QnnConcatenateAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The tensor to concatenate.") +.set_support_level(11) +.add_type_rel("QnnConcatenate", ConcatenateRel) +.set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize); + +TVM_REGISTER_API("relay.qnn.op._make.concatenate") +.set_body_typed(MakeQnnConcatenate); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 1ada7ec..208e04f 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -28,6 +28,8 @@ #include #include #include +#include +#include namespace tvm { namespace relay { @@ -67,6 +69,23 @@ static inline const int32_t GetQmax(const DataType& dtype) { } } +Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, + const Array& input_shape, const DataType& out_dtype); + +static inline Expr Requantize(const Expr& data, const Array& input_shape, + double input_scale, int32_t input_zero_point, double output_scale, + int32_t output_zero_point, const DataType& out_dtype, + const std::string& rounding = "TONEAREST") { + auto attrs = make_node(); + attrs->input_scale = std::move(input_scale); + attrs->input_zero_point = std::move(input_zero_point); + attrs->output_scale = std::move(output_scale); + attrs->output_zero_point = std::move(output_zero_point); + attrs->rounding = std::move(rounding); + attrs->out_dtype = std::move(out_dtype); + return RequantizeLower(data, attrs.operator->(), input_shape, out_dtype); +} + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_qnn_concatenate.py b/tests/python/relay/test_qnn_concatenate.py index 0a5f909..b24e1a0 100644 --- a/tests/python/relay/test_qnn_concatenate.py +++ b/tests/python/relay/test_qnn_concatenate.py @@ -39,7 +39,6 @@ def test_same_io_qnn_params(): axis=axis) func = relay.Function([x, y], z) - assert func.astext().count('requantize') == 0 mod = relay.Module.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -68,7 +67,6 @@ def test_different_io_qnn_params(): axis=axis) func = relay.Function([x, y], z) - assert func.astext().count('requantize') == 2 mod = relay.Module.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -97,7 +95,6 @@ def test_few_same_io_qnn_params(): axis=axis) func = relay.Function([x, y], z) - assert func.astext().count('requantize') == 1 mod = relay.Module.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] @@ -126,7 +123,6 @@ def test_same_i_qnn_params(): axis=axis) func = relay.Function([x, y], z) - assert func.astext().count('requantize') == 1 mod = relay.Module.from_expr(func) mod = relay.qnn.transform.CanonicalizeOps()(mod) func = mod["main"] -- 2.7.4