return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
@reg.register_legalize("nn.conv2d")
-def legalize_conv2d(attrs, inputs, arg_dtypes):
- """Legalize conv2d"""
- from ... import op
- return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
+def legalize_conv2d(attrs, inputs, types):
+ """Legalize conv2d op.
+
+ Parameters
+ ----------
+ attrs : tvm.attrs.Attrs
+ Attributes of current convolution
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized
+ types : list of types
+ List of input and output types
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The legalized expr
+ """
+ return topi.nn.conv2d_legalize(attrs, inputs, types)
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
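Under the new signature, a hook receives every argument type followed by the call's own checked type, so a rewrite can key off the output dtype without a dedicated attribute. A minimal sketch of a user-registered hook, assuming register_legalize is importable from tvm.relay.op as in the tests below; the choice of nn.dense and the cast-based rewrite are illustrative, not part of this patch:

from tvm import relay
from tvm.relay.op import register_legalize

@register_legalize("nn.dense", level=100)  # level chosen arbitrarily
def _legalize_dense(attrs, inputs, types):
    # types = [data_type, weight_type, out_type]; the output type is last.
    data_type, weight_type, out_type = types
    if data_type.dtype == "float32":
        return None  # returning None keeps the original call
    data, weight = inputs
    # Illustrative rewrite: compute in float32, cast back to the output dtype.
    dense = relay.nn.dense(relay.cast(data, "float32"),
                           relay.cast(weight, "float32"))
    return relay.cast(dense, out_type.dtype)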
Expr new_e;
bool modified = false;
if (fop_legalize.count(op)) {
- tvm::Array<tvm::relay::Type> arg_types;
+ // Collect the input and output types to pass on to the Legalize API.
+ tvm::Array<tvm::relay::Type> types;
for (auto& expr : ref_call->args) {
- arg_types.push_back(expr->checked_type());
+ types.push_back(expr->checked_type());
}
- Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
+ types.push_back(ref_call->checked_type());
+
+ // Transform the op by calling the registered legalize function.
+ Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);
+
+ // Check whether the transformation succeeded. If not, fall back to the original ref_call->op.
if (legalized_value.defined()) {
new_e = legalized_value;
modified = true;
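Seen from Python, the array built above always has len(inputs) + 1 entries, with the call's checked type last, and a hook that returns None leaves `modified` false so the original call is kept. A small probe sketched against the test helpers shown further down (relay.Module and transform.Legalize are assumed to be this TVM version's names):

import tvm
from tvm import relay
from tvm.relay.op import register_legalize

seen = []

@register_legalize("nn.conv2d", level=90)  # level is illustrative
def _probe_conv2d(attrs, inputs, types):
    seen.append(len(types))
    return None  # no rewrite; the pass keeps the original call

data = relay.var("data", shape=(1, 3, 8, 8))
weight = relay.var("weight", shape=(16, 3, 3, 3))
func = relay.Function([data, weight],
                      relay.nn.conv2d(data, weight, kernel_size=(3, 3)))
mod = relay.Module.from_expr(func)
relay.transform.Legalize()(mod)
assert seen == [3]  # two argument types plus one output type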
Expr DequantizeLegalize(const Attrs& attrs,
const Array<Expr>& new_args,
- const Array<tvm::relay::Type>& arg_types) {
+ const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
CHECK(dequantize_attrs != nullptr);
- CHECK_EQ(arg_types.size(), 1);
+ CHECK_EQ(types.size(), 2);
return DequantizeLower(data, dequantize_attrs);
}
Expr QuantizeLegalize(const Attrs& attrs,
const Array<Expr>& new_args,
- const Array<tvm::relay::Type>& arg_types) {
+ const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& data = new_args[0];
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
CHECK(quantize_attrs != nullptr);
- CHECK_EQ(arg_types.size(), 1);
+ CHECK_EQ(types.size(), 2);
return QuantizeLower(data, quantize_attrs);
}
* 7) Cast to the out_dtype.
*/
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
- const Array<IndexExpr>& input_shape) {
+ const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
double double_multiplier = param->input_scale / param->output_scale;
// Choose high precision datatype to be int64. This is for avoiding overflow
auto shifted_int64_t = Add(output_zp, scaled_int64_t);
// 7) Clip to the out_dtype min/max.
- auto q_min = GetQmin(param->out_dtype);
- auto q_max = GetQmax(param->out_dtype);
+ auto q_min = GetQmin(out_dtype);
+ auto q_max = GetQmax(out_dtype);
auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
- return Cast(clipped_t, param->out_dtype);
+ return Cast(clipped_t, out_dtype);
}
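A worked instance of the arithmetic above, with illustrative values (plain Python, matching the multiply, add, clip, and cast sequence in RequantizeLower):

# Requantize q_input from scale 0.5 / zero point 1 to scale 0.25 / zero point 0.
input_scale, output_scale = 0.5, 0.25
zp_input, zp_output = 1, 0
q_input = 10
multiplier = input_scale / output_scale                # 2.0
q_output = zp_output + multiplier * (q_input - zp_input)
q_output = max(-128, min(127, int(round(q_output))))   # clip and cast to int8
assert q_output == 18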
/*
- * Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
+ * Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
*/
Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
- const Array<tvm::relay::Type>& arg_types) {
+ const Array<tvm::relay::Type>& types) {
CHECK_EQ(new_args.size(), 1);
auto& quantized_data = new_args[0];
const auto* param = attrs.as<RequantizeAttrs>();
CHECK(param != nullptr);
// Find input shape.
- CHECK_EQ(arg_types.size(), 1);
- auto input_dtype = arg_types[0];
- auto input_tensor_type = input_dtype.as<TensorTypeNode>();
- CHECK(input_tensor_type != nullptr) << "Type information missing."
- << " Please run infer_type pass.";
- Array<IndexExpr> input_shape = input_tensor_type->shape;
+ CHECK_EQ(types.size(), 2);
+ auto in_type = types[0];
+ auto in_tensor_type = in_type.as<TensorTypeNode>();
+ CHECK(in_tensor_type != nullptr) << "Type information missing."
+ << " Please run infer_type pass.";
+ Array<IndexExpr> input_shape = in_tensor_type->shape;
+
+ // Find the output dtype.
+ auto out_type = types[1];
+ auto out_tensor_type = out_type.as<TensorTypeNode>();
+ CHECK(out_tensor_type != nullptr) << "Type information missing."
+ << " Please run infer_type pass.";
+ auto out_dtype = out_tensor_type->dtype;
// Check rounding validity.
CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
<< "QNN requantize supports two rounding modes - UPWARD and "
<< "TONEAREST";
- return RequantizeLower(quantized_data, param, input_shape);
+ return RequantizeLower(quantized_data, param, input_shape, out_dtype);
}
/*
tensor. For the output tensor, we are provided with output scale and zero
point. The computation looks like this
-Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
+Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.RequantizeAttrs")
return y
@register_legalize("nn.conv2d", level=100)
- def legalize_conv2d(attrs, inputs, arg_types):
+ def legalize_conv2d(attrs, inputs, types):
data, weight = inputs
weight = relay.multiply(weight, relay.const(2.0, "float32"))
return relay.nn.conv2d(data, weight, **attrs)
called = [False]
@register_legalize("nn.global_max_pool2d", level=101)
- def legalize_conv2d(attrs, inputs, arg_types):
+ def legalize_global_max_pool2d(attrs, inputs, types):
called[0] = True
return None
return func
@register_legalize("concatenate", level=100)
- def legalize_concatenate(attrs, inputs, arg_types):
+ def legalize_concatenate(attrs, inputs, types):
# Check that the correct multi-input case is handled.
assert len(inputs) == 1
assert isinstance(inputs[0], tvm.relay.expr.Tuple)
- assert len(arg_types) == 1
- assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
+ assert len(types) == 2
+ assert isinstance(types[0], tvm.relay.ty.TupleType)
+ assert isinstance(types[1], tvm.relay.ty.TensorType)
return None
def expected():
return func
@register_legalize("nn.conv2d", level=101)
- def legalize_conv2d(attrs, inputs, arg_types):
+ def legalize_conv2d(attrs, inputs, types):
from topi.arm_cpu.conv2d import _conv2d_legalize
- return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
+ return _conv2d_legalize(attrs, inputs, types)
a = before()
b = run_opt_pass(a, transform.Legalize())
"""Conv2D schedule for ARM CPU"""
from __future__ import absolute_import as _abs
-import warnings
+import logging
import tvm
from tvm import autotvm
+from tvm import relay
import tvm.contrib.nnpack
from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
+logger = logging.getLogger('topi')
+
@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
"""TOPI compute callback for conv2d
if layout != 'NCHW':
return None
if dilation != (1, 1):
- warnings.warn("Does not support weight pre-transform for dilated convolution.")
+ logger.warning("Does not support weight pre-transform for dilated convolution.")
return None
data, kernel = tinfos[0:2]
return None
@conv2d_legalize.register("arm_cpu")
-def _conv2d_legalize(attrs, inputs, arg_types, F):
- if F.__name__ != 'tvm.relay.op':
- return None
+def _conv2d_legalize(attrs, inputs, types):
+ """Legalizes Conv2D op.
+
+ Parameters
+ ----------
+ attrs : tvm.attrs.Attrs
+ Attributes of current convolution
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized
+ types : list of types
+ List of input and output types
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The legalized expr
+ """
+
if attrs['data_layout'] == 'NHWC':
data, kernel = inputs
if attrs['kernel_layout'] == 'HWIO':
# Handle HWIO layout. This is common in TF graph.
- kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
+ kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
elif attrs['kernel_layout'] == 'HWOI':
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
- kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
+ kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
elif attrs['kernel_layout'] != 'OIHW':
return None
- warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
- + "fallback to NCHW. This can result in performance degradation.")
+ logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+ + "fallback to NCHW. This can result in performance degradation.")
- # Set new attrs for the tranposed conv.
+ # Set new attrs for the transposed conv.
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['data_layout'] = 'NCHW'
new_attrs['kernel_layout'] = 'OIHW'
# Convert from NHWC to NCHW.
- data = F.transpose(data, axes=(0, 3, 1, 2))
- conv = F.nn.conv2d(data, kernel, **new_attrs)
+ data = relay.transpose(data, axes=(0, 3, 1, 2))
+ conv = relay.nn.conv2d(data, kernel, **new_attrs)
# Convert back to original NHWC layout.
- out = F.transpose(conv, axes=(0, 2, 3, 1))
+ out = relay.transpose(conv, axes=(0, 2, 3, 1))
return out
return None
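To exercise this fallback end to end, the Legalize pass can be run on an NHWC convolution under an arm_cpu target. A sketch following the test conventions above; the target string and the relay.Module API are assumptions about this TVM version:

import tvm
from tvm import relay
from tvm.relay import transform

data = relay.var("data", shape=(1, 56, 56, 64))
weight = relay.var("weight", shape=(3, 3, 64, 64))
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3),
                       data_layout="NHWC", kernel_layout="HWIO")
mod = relay.Module.from_expr(relay.Function([data, weight], conv))
with tvm.target.create("llvm -device=arm_cpu"):
    mod = transform.Legalize()(mod)
print(mod)  # the conv2d now runs in NCHW/OIHW between two transposes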
@tvm.target.generic_func
-def conv2d_legalize(attrs, inputs, arg_dtypes, F):
+def conv2d_legalize(attrs, inputs, types):
"""Legalizes Conv2D op.
+
Parameters
----------
- attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+ attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
- The args of the Relay expr to be legalized.
- arg_dtypes : list of types
- List of types of input arguments
- F: symbol
- The context, can be either nnvm.sym or relay.op
- Note
- ----
- Unlike other TOPI functions, this function operates on both graph level and operator level,
- so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+ The args of the Relay expr to be legalized
+ types : list of types
+ List of input and output types
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The legalized expr
"""
# not to change by default
return None