IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
- std::string weight_layout;
+ std::string kernel_layout;
std::string out_layout;
DataType out_dtype;
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
- TVM_ATTR_FIELD(weight_layout).set_default("OIHW")
+ TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
}
};
+
+/*! \brief Attributes used in winograd weight transformation operators */
+struct Conv2DWinogradWeightTransformAttrs :
+ public tvm::AttrsNode<Conv2DWinogradWeightTransformAttrs> {
+ int tile_size;
+
+ TVM_DECLARE_ATTRS(Conv2DWinogradWeightTransformAttrs,
+ "relay.attrs.Conv2DWinogradWeightTransformAttrs") {
+ TVM_ATTR_FIELD(tile_size)
+ .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
+ }
+};
+
+/*! \brief Attributes used in convolution operators with winograd algorithm */
+struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
+ int tile_size;
+ Array<IndexExpr> strides;
+ Array<IndexExpr> padding;
+ Array<IndexExpr> dilation;
+ int groups;
+ IndexExpr channels;
+ Array<IndexExpr> kernel_size;
+ std::string data_layout;
+ std::string kernel_layout;
+ std::string out_layout;
+ DataType out_dtype;
+
+ TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") {
+ TVM_ATTR_FIELD(tile_size)
+ .describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
+ TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+ .describe("Specifies the strides of the convolution.");
+ TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
+ .describe("If padding is non-zero, then the input is implicitly zero-padded"
+ "on both sides for padding number of points");
+ TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
+ .describe("Specifies the dilation rate to use for dilated convolution.");
+ TVM_ATTR_FIELD(groups).set_default(1)
+ .describe("Controls the connections between inputs and outputs."
+ "At groups=1, all inputs are convolved to all outputs."
+ "At groups=2, the operation becomes equivalent to having two convolution"
+ "layers side by side, each seeing half the input channels, and producing"
+ "half the output channels, and both subsequently concatenated.");
+ TVM_ATTR_FIELD(channels)
+ .describe("The number of output channels in the convolution."
+ " If it is not set, inferred by shape of the weight.")
+ .set_default(NullValue<IndexExpr>());
+ TVM_ATTR_FIELD(kernel_size)
+ .describe("Specifies the dimensions of the convolution window.")
+ .set_default(NullValue<Array<IndexExpr> >());
+ TVM_ATTR_FIELD(data_layout).set_default("NCHW")
+ .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Convolution is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
+ .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout).set_default("")
+ .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Default to be same as input layout.");
+
+ // use 0 bits to indicate none.
+ TVM_ATTR_FIELD(out_dtype)
+ .set_default(NullValue<DataType>())
+ .describe("Output data type, set to explicit type under mixed precision setting");
+ }
+};
+
+
/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;
TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
- TVM_ATTR_FIELD(axis).set_default(-1)
- .describe("The axis to sum over when computing softmax.");
+ TVM_ATTR_FIELD(axis).set_default(-1)
+ .describe("The axis to sum over when computing softmax.");
}
};
Array<IndexExpr> dilation;
int groups;
std::string data_layout;
- std::string weight_layout;
+ std::string kernel_layout;
std::string out_layout;
DataType out_dtype;
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
- TVM_ATTR_FIELD(weight_layout).set_default("OIHW")
+ TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-argument
"""Convert an NNVM graph to Relay."""
import json
+import numpy
+
from tvm import relay, nd
from tvm.relay import op, expr, var
from tvm.relay.frontend.common import StrAttrsDict
from tvm.relay.frontend.nnvm_common import _rename
-import numpy
from .symbol import Symbol
from .compiler import graph_attr
from .graph import create as graph_create
dilation = attrs.get_int_tuple('dilation', (1, 1))
groups = attrs.get_int('groups', 1)
data_layout = attrs.get_str('layout', 'NCHW')
- weight_layout = attrs.get_str('kernel_layout', 'OIHW')
+ kernel_layout = attrs.get_str('kernel_layout', 'OIHW')
out_layout = ''
out_dtype = attrs.get_str('out_dtype', '')
dilation=dilation,
groups=groups,
data_layout=data_layout,
- weight_layout=weight_layout,
+ kernel_layout=kernel_layout,
out_layout=out_layout,
out_dtype=out_dtype)
dilation = attrs.get_int_tuple('dilation', (1, 1))
groups = attrs.get_int('groups', 1)
data_layout = attrs.get_str('layout', 'NCHW')
- weight_layout = attrs.get_str('kernel_layout', 'OIHW')
+ kernel_layout = attrs.get_str('kernel_layout', 'OIHW')
out_dtype = attrs.get_str('out_dtype', '')
out_conv2d = op.nn.conv2d_transpose(
dilation=dilation,
groups=groups,
data_layout=data_layout,
- weight_layout=weight_layout,
+ kernel_layout=kernel_layout,
out_dtype=out_dtype)
if use_bias:
else:
raise ValueError("Wrong bool format for key %s" % key)
- def get_string(self, key):
+ def get_str(self, key):
"""Get string from attr dict
Parameters
@reg.register_alter_op_layout("conv2d")
def alter_conv2d_layout(attrs, inputs, tinfos):
- return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos)
+ """Replace conv2d op with other layouts or algorithms"""
+ import nnvm.symbol as sym
+
+ # map relay op names to nnvm op names
+ sym.contrib_conv2d_winograd_without_weight_transform = \
+ sym.contrib.conv2d_winograd_without_weight_transform
+ sym.contrib_conv2d_winograd_weight_transform = \
+ sym.contrib.conv2d_winograd_weight_transform
+ sym.nn = sym
+
+ # map relay argument names to nnvm argument names
+ raw_reshape = sym.reshape
+ def _reshape(*args, **kwargs):
+ if "newshape" in kwargs:
+ kwargs['shape'] = kwargs.pop('newshape')
+ return raw_reshape(*args, **kwargs)
+ sym.reshape = _reshape
+
+ return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, sym)
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
dilation = attrs.get_int_tuple("dilation")
out_channel = attrs.get_int("channels")
groups = attrs.get_int("groups")
- layout = attrs.get_string("layout")
- out_layout = attrs.get_string("out_layout")
- out_dtype = attrs.get_string("out_dtype")
+ layout = attrs.get_str("layout")
+ out_layout = attrs.get_str("out_layout")
+ out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
if layout == "NCHW":
_, in_channel, _, _ = get_const_tuple(inputs[0].shape)
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
- layout = attrs.get_string("layout")
- out_dtype = attrs.get_string("out_dtype")
+ layout = attrs.get_str("layout")
+ out_dtype = attrs.get_str("out_dtype")
tile_size = attrs.get_int("tile_size")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Do not support dilate now"
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
- out_dtype = attrs.get_string("out_dtype")
+ out_dtype = attrs.get_str("out_dtype")
layout = attrs["layout"]
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
for field in fields:
yield field.name
+ def get_int_tuple(self, key):
+ """Get a python int tuple of a key
+
+ Parameters
+ ----------
+ key: str
+
+ Returns
+ -------
+ value: Tuple of int
+ """
+ return tuple(x.value for x in self.__getattr__(key))
+
+ def get_int(self, key):
+ """Get a python int value of a key
+
+ Parameters
+ ----------
+ key: str
+
+ Returns
+ -------
+ value: int
+ """
+ return self.__getattr__(key)
+
+ def get_str(self, key):
+ """Get a python int value of a key
+
+ Parameters
+ ----------
+ key: str
+
+ Returns
+ -------
+ value: int
+ """
+ return self.__getattr__(key)
+
def __getitem__(self, item):
return self.__getattr__(item)
return expr.bind(func, bind_dict)
-def optimize(func, params=None):
+def optimize(func, target, params=None):
"""Perform target invariant optimizations.
Parameters
func : tvm.relay.Function
The input to optimization.
+ target: :any:`tvm.target.Target`
+ The optimization target. Some optimization passes are target specific.
+
params : Optional[Dict[str, tvm.nd.NDArray]]
Input parameters to the graph that do not change
during inference time. used for constant folding.
func = ir_pass.infer_type(func)
func = ir_pass.canonicalize_ops(func)
func = ir_pass.infer_type(func)
- func = ir_pass.alter_op_layout(func)
+ with target:
+ func = ir_pass.alter_op_layout(func)
+
+ if cfg.pass_enabled("FoldConstant"):
+ func = ir_pass.fold_constant(func)
return func
cfg = BuildConfig.current
with tophub_context:
- func = optimize(func, params)
+ func = optimize(func, target, params)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
channel_axis = _get_channel_axis(data_layout, "conv2d")
if "kernel_layout" in attrs.attrs:
- weight_layout = attrs.get_str("kernel_layout")
+ kernel_layout = attrs.get_str("kernel_layout")
else:
- weight_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
+ kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
new_attrs = {}
new_attrs["channels"] = attrs.get_int("num_filter")
new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1))
new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout
- new_attrs["weight_layout"] = weight_layout
+ new_attrs["kernel_layout"] = kernel_layout
use_bias = not attrs.get_bool("no_bias", False)
res = _op.nn.conv2d(inputs[0], inputs[1], **new_attrs)
if use_bias:
channel_axis = _get_channel_axis(data_layout, "conv2d_transpose")
if "kernel_layout" in attrs.attrs:
- weight_layout = attrs.get_str("kernel_layout")
+ kernel_layout = attrs.get_str("kernel_layout")
else:
- weight_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
+ kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
new_attrs = {}
new_attrs["channels"] = attrs.get_int("num_filter")
new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1))
new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout
- new_attrs["weight_layout"] = weight_layout
+ new_attrs["kernel_layout"] = kernel_layout
use_bias = not attrs.get_bool("no_bias", False)
res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs)
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
- weight_layout = attrs.weight_layout
+ kernel_layout = attrs.kernel_layout
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if (out_dtype == "same" or out_dtype == "")
else out_dtype)
inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
elif layout == "NCHW" and \
- weight_layout == "OIHW" and \
+ kernel_layout == "OIHW" and \
get_const_int(inputs[1].shape[0]) == groups and \
get_const_int(inputs[1].shape[1]) == 1:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
elif layout == "NHWC" and \
- weight_layout == "HWOI" and\
+ kernel_layout == "HWOI" and\
get_const_int(inputs[1].shape[2]) == groups and \
get_const_int(inputs[1].shape[3]) == 1:
out = topi.nn.depthwise_conv2d_nhwc(
"""Schedule definition of conv2d"""
groups = attrs.groups
layout = attrs.data_layout
- kernel_layout = attrs.weight_layout
+ kernel_layout = attrs.kernel_layout
with target:
if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs)
@reg.register_alter_op_layout("nn.conv2d")
def alter_op_layout_conv2d(attrs, inputs, tinfos):
"""Alternate the layout of conv2d"""
- return None
+ from ... import op
+ return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
-# Upsampling
+# upsampling
reg.register_schedule("nn.upsampling", reg.schedule_injective)
def schedule_upsampling(_, outs, target):
"""Schedule definition of upsampling"""
return topi.generic.schedule_injective(outs)
# pad
reg.register_schedule("nn.pad", schedule_broadcast)
+
+# winograd related operators
+@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform")
+def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype, target):
+ """Compute definition of conv2d_winograd_without_weight_transform"""
+ # pylint: disable=assignment-from-no-return
+ padding = attrs.get_int_tuple("padding")
+ strides = attrs.get_int_tuple("strides")
+ dilation = attrs.get_int_tuple("dilation")
+ groups = attrs.get_int("groups")
+ data_layout = attrs.get_str("data_layout")
+ out_dtype = attrs.get_str("out_dtype")
+ tile_size = attrs.get_int("tile_size")
+ out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
+ assert dilation == (1, 1), "Do not support dilate now"
+ assert groups == 1, "Do not supoort arbitrary group number"
+
+ out = topi.nn.conv2d_winograd_without_weight_transform(
+ inputs[0], inputs[1], strides, padding, dilation, data_layout,
+ out_dtype, tile_size)
+
+ return [out]
+
+@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
+def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
+ """Schedule definition of conv2d_winograd_without_weight_transform"""
+ with target:
+ return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)
+
+reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
+ OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
+def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
+ """Compute definition of contrib_conv2d_winograd_weight_transform"""
+ out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size'))
+ return [out]
+
+@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
+def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
+ """Schedule definition of contrib_conv2d_winograd_weight_transform"""
+ with target:
+ return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
+
+reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
+ OpPattern.OUT_ELEMWISE_FUSABLE)
channels=None,
kernel_size=None,
data_layout="NCHW",
- weight_layout="OIHW",
+ kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""2D convolution.
In the default case, where the data_layout is `NCHW`
- and weight_layout is `OIHW`, conv2d takes in
+ and kernel_layout is `OIHW`, conv2d takes in
a data Tensor with shape `(batch_size, in_channels, height, width)`,
and a weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1])`
to produce an output Tensor with the following rule:
data_layout : str, optional
Layout of the input.
- weight_layout : str, optional
+ kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
"""
return _make.conv2d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
- weight_layout, out_layout, out_dtype)
+ kernel_layout, out_layout, out_dtype)
def conv2d_transpose(data,
channels=None,
kernel_size=None,
data_layout="NCHW",
- weight_layout="OIHW",
+ kernel_layout="OIHW",
output_padding=(0, 0),
out_dtype=""):
"""Two dimensional trnasposed convolution operator.
data_layout : str, optional
Layout of the input.
- weight_layout : str, optional
+ kernel_layout : str, optional
Layout of the weight.
output_padding : Tuple[int], optional
"""
return _make.conv2d_transpose(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
- weight_layout, output_padding, out_dtype)
+ kernel_layout, output_padding, out_dtype)
def softmax(data, axis=-1):
center,
scale)
return TupleWrapper(result, 3)
+
+
+def contrib_conv2d_winograd_without_weight_transform(data,
+ weight,
+ tile_size,
+ strides=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ groups=1,
+ channels=None,
+ kernel_size=None,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="",
+ out_dtype=""):
+ r"""2D convolution with winograd algorithm.
+
+ The basic parameters are the same as the ones in vanilla conv2d.
+ It assumes the weight is pre-transformed by nn.contrib_conv2d_winograd_weight_transform
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input data to the operator.
+
+ weight : tvm.relay.Expr
+ The weight expressions.
+
+ tile_size : int
+ The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
+
+ strides : tuple of int, optional
+ The strides of convoltution.
+
+ padding : tuple of int, optional
+ The padding of convolution on both sides of inputs before convolution.
+
+ dilation : tuple of int, optional
+ Specifies the dilation rate to be used for dilated convolution.
+
+ groups : int, optional
+ Number of groups for grouped convolution.
+
+ channels : int, optional
+ Number of output channels of this convolution.
+
+ kernel_size : tuple of int, optional
+ The spatial of the convolution kernel.
+
+ data_layout : str, optional
+ Layout of the input.
+
+ kernel_layout : str, optional
+ Layout of the weight.
+
+ out_layout : str, optional
+ Layout of the output, by default, out_layout is the same as data_layout
+
+ out_dtype : str, optional
+ Specifies the output data type for mixed precision conv2d.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+ return _make.contrib_conv2d_winograd_without_weight_transform(
+ data, weight, tile_size, strides, padding, dilation,
+ groups, channels, kernel_size, data_layout,
+ kernel_layout, out_layout, out_dtype)
+
+
+def contrib_conv2d_winograd_weight_transform(weight,
+ tile_size):
+ r"""Weight Transformation part for 2D convolution with winograd algorithm.
+
+ We separate this as a single op to enable pre-compute for inference.
+ Use this together with nn.contrib_conv2d_winograd_without_weight_transform
+
+ Parameters
+ ----------
+ weight : tvm.relay.Expr
+ The weight expressions.
+
+ tile_size : int
+ The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+ return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
@register_relay_attr_node
class Conv2DAttrs(Attrs):
- """Attribute of a Convolution Operator"""
+ """Attribute of nn.conv2d"""
+ pass
+
+@register_relay_attr_node
+class Conv2DWinogradAttrs(Attrs):
+ """Attribute of nn.contrib_conv2d_winograd_without_weight_transform"""
+ pass
+
+@register_relay_attr_node
+class Conv2DWinogradWeightTransformAttrs(Attrs):
+ """Attribute of nn.contrib_conv2d_winograd_weight_transform"""
pass
@register_relay_attr_node
class GlobalPool2DAttrs(Attrs):
- """Attribute of a Global 2D Pooling Operator"""
+ """Attribute of nn.global_pool"""
pass
from . import squeezenet
from . import vgg
from . import densenet
+
from .config import ctx_list
+from .init import create_workload
Layout() : Layout("__undef__") {} // NOLINT(*)
/*! \brief construct from a string */
- Layout(const char* str) : Layout(std::string(str)) {} // NOLINT(*)
+ Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
/*!
* \brief construct from a string.
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
- Layout(const std::string& layout) { // NOLINT(*)
- if (layout.length() != 0) {
- Parse(layout);
- } else {
- Parse("__undef__");
+ Layout(const std::string& name) { // NOLINT(*)
+ node_ = make_node<LayoutNode>();
+
+ std::vector<uint32_t> superdim_pos(kUniqueDim, -1);
+ std::vector<uint32_t> subdim_pos(kUniqueDim, -1);
+ std::vector<uint32_t> subdim_size(kUniqueDim, -1);
+ std::vector<char> layout_simplified;
+
+ if (name != "__undef__") { // parse layout string
+ int32_t factor = 0;
+ uint32_t curr = 0;
+ for (size_t i = 0; i < name.size(); ++i) {
+ const LayoutDim c = name.at(i);
+ if (IsSuperdim(c)) {
+ int pos = c - 'A';
+ CHECK_EQ(factor, 0) << "Invalid layout " << name
+ << ": invalid factor size " << factor
+ << " before dimension " << c;
+ CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << name
+ << ": duplicate dimension " << c;
+ superdim_pos[pos] = curr++;
+ layout_simplified.push_back(c);
+ } else if (IsSubdim(c)) {
+ int pos = c - 'a';
+ CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
+ << factor << " for dimension " << c;
+ CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << name
+ << ": duplicate dimension " << c;
+ CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << name
+ << ": duplicate dimension " << c;
+ subdim_pos[pos] = curr++;
+ subdim_size[pos] = factor;
+ layout_simplified.push_back(c);
+ factor = 0;
+ } else if (c >= '0' && c <= '9') {
+ CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
+ factor = factor * 10 + c - '0';
+ } else {
+ LOG(FATAL) << "Invalid layout " << name;
+ }
+ }
+ for (LayoutDim dim : layout_simplified) {
+ CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0)
+ << "Invalid layout " << name << ": missing axis "
+ << static_cast<char>(dim - 'a' + 'A');
+ }
+ }
+
+ LayoutNode *node = operator->();
+ node->name = name;
+
+ for (uint32_t i = 0; i < kUniqueDim; ++i) {
+ node->superdim_pos.push_back(superdim_pos[i]);
+ node->subdim_pos.push_back(subdim_pos[i]);
+ node->subdim_size.push_back(subdim_size[i]);
+ }
+ for (LayoutDim dim : layout_simplified) {
+ node->layout_simplified.push_back(dim);
}
}
const Array<Integer>& layout_simplified = operator->()->layout_simplified;
if (pos > ndim()) return Layout::Undef();
if (pos + len > ndim()) len = ndim() - pos;
- if (len == 0) return Layout::Undef();
std::ostringstream new_layout;
for (size_t i = pos; i < pos + len; ++i) {
if (IsSubdim(layout_simplified[i]->value)) {
}
using ContainerType = LayoutNode;
-
- private:
- void Parse(const std::string &layout) {
- node_ = make_node<LayoutNode>();
-
- std::vector<uint32_t> superdim_pos(kUniqueDim, -1);
- std::vector<uint32_t> subdim_pos(kUniqueDim, -1);
- std::vector<uint32_t> subdim_size(kUniqueDim, -1);
- std::vector<char> layout_simplified;
-
- if (layout != "__undef__") { // parse layout string
- int32_t factor = 0;
- uint32_t curr = 0;
- for (size_t i = 0; i < layout.size(); ++i) {
- const LayoutDim c = layout.at(i);
- if (IsSuperdim(c)) {
- int pos = c - 'A';
- CHECK_EQ(factor, 0) << "Invalid layout " << layout
- << ": invalid factor size " << factor
- << " before dimension " << c;
- CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << layout
- << ": duplicate dimension " << c;
- superdim_pos[pos] = curr++;
- layout_simplified.push_back(c);
- } else if (IsSubdim(c)) {
- int pos = c - 'a';
- CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
- << factor << " for dimension " << c;
- CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << layout
- << ": duplicate dimension " << c;
- CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << layout
- << ": duplicate dimension " << c;
- subdim_pos[pos] = curr++;
- subdim_size[pos] = factor;
- layout_simplified.push_back(c);
- factor = 0;
- } else if (c >= '0' && c <= '9') {
- CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
- factor = factor * 10 + c - '0';
- } else {
- LOG(FATAL) << "Invalid layout " << layout;
- }
- }
- CHECK(!layout_simplified.empty()) << "Invalid layout " << layout;
- for (LayoutDim dim : layout_simplified) {
- CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0)
- << "Invalid layout " << layout << ": missing axis "
- << static_cast<char>(dim - 'a' + 'A');
- }
- }
-
- LayoutNode *node = operator->();
- node->name = layout;
-
- for (uint32_t i = 0; i < kUniqueDim; ++i) {
- node->superdim_pos.push_back(superdim_pos[i]);
- node->subdim_pos.push_back(subdim_pos[i]);
- node->subdim_size.push_back(subdim_size[i]);
- }
- for (LayoutDim dim : layout_simplified) {
- node->layout_simplified.push_back(dim);
- }
- }
};
/*!
const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
- const Layout kernel_layout(param->weight_layout);
+ const Layout kernel_layout(param->kernel_layout);
CHECK(in_layout.Convertible(kNCHW))
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
- Layout out_layout(param->out_layout);
- if (!out_layout.defined()) out_layout = in_layout;
+ Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
CHECK(out_layout.Convertible(kNCHW))
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const T* params = attrs.as<T>();
- Layout out_layout(params->out_layout);
// We always make other operators to fit the layouts of convolution layers
// So this inference ignores all inputs
- return Array<Array<Layout> >{{params->data_layout, params->weight_layout},
- {out_layout.defined() ? out_layout : params->data_layout}};
+ return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
+ {params->out_layout == "" ?
+ params->data_layout : params->out_layout}};
}
// Positional relay function to create conv2d operator
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
- std::string weight_layout,
+ std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<Conv2DAttrs>();
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
- attrs->channels = channels;
- attrs->kernel_size = kernel_size;
+ attrs->channels = std::move(channels);
+ attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
- attrs->weight_layout = std::move(weight_layout);
+ attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv2d");
const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
- const Layout kernel_layout(param->weight_layout);
+ const Layout kernel_layout(param->kernel_layout);
CHECK(in_layout.Convertible(kNCHW))
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
- Layout out_layout(param->out_layout);
- if (!out_layout.defined()) out_layout = in_layout;
+ Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
CHECK(out_layout.Convertible(kNCHW))
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
- std::string weight_layout,
+ std::string kernel_layout,
Array<IndexExpr> output_padding,
DataType out_dtype) {
auto attrs = make_node<Conv2DTransposeAttrs>();
- attrs->channels = channels;
- attrs->kernel_size = kernel_size;
+ attrs->channels = std::move(channels);
+ attrs->kernel_size = std::move(kernel_size);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->output_padding = std::move(output_padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = std::move(data_layout);
- attrs->weight_layout = std::move(weight_layout);
+ attrs->kernel_layout = std::move(kernel_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv2d_transpose");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
Conv2DInferCorrectLayout<Conv2DTransposeAttrs>)
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
+
+// relay.nn.contrib_conv2d_winograd_without_weight_transform
+TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
+
+bool Conv2DWinogradRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 3);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) return false;
+ static const Layout kNCHW("NCHW");
+ static const Layout kOIHW("OIHW");
+
+ const Conv2DWinogradAttrs* param = attrs.as<Conv2DWinogradAttrs>();
+ CHECK(param != nullptr);
+ const Layout in_layout(param->data_layout);
+ const Layout kernel_layout(param->kernel_layout);
+ CHECK(in_layout.Convertible(kNCHW))
+ << "Conv only support input layouts that are convertible from NCHW."
+ << " But got " << in_layout;
+ CHECK(kernel_layout.Convertible(kOIHW))
+ << "Conv only support kernel layouts that are convertible from OIHW."
+ << " But got "<< kernel_layout;
+
+ Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+ CHECK(out_layout.Convertible(kNCHW))
+ << "Conv only support output layouts that are convertible from NCHW."
+ << " But got " << out_layout;
+
+ std::vector<IndexExpr> dshape_nchw = ConvertLayout(
+ data->shape, in_layout, kNCHW);
+
+ IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+ CHECK(param->kernel_size.defined() && param->channels.defined())
+ << "The kernel size and channels of a Conv must be set or infered by previous pass";
+
+ CHECK_EQ(param->kernel_size.size(), 2);
+ CHECK_EQ(param->dilation.size(), 2);
+
+ channels = param->channels;
+ dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+ dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+
+ // NOTE: Do not check weight shape here!
+ // Different backend requires different layout to compute
+ // the batch gemm stage in winograd efficiently, but we want to
+ // make this op work for all backends.
+ // So we accept all weight shapes, and assume the TOPI developers
+ // can handle this correctly in alter_op_layout.
+
+ // dilation
+ std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+
+ oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
+ oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
+ DataType out_dtype = param->out_dtype;
+ if (out_dtype.bits() == 0) {
+ out_dtype = data->dtype;
+ }
+ oshape = ConvertLayout(oshape, kNCHW, out_layout);
+ // assign output type
+ reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ return true;
+}
+
+
+// Positional relay function to create conv2d winograd operator
+// used by frontend FFI.
+Expr MakeConv2DWinograd(Expr data,
+ Expr weight,
+ int tile_size,
+ Array<IndexExpr> strides,
+ Array<IndexExpr> padding,
+ Array<IndexExpr> dilation,
+ int groups,
+ IndexExpr channels,
+ Array<IndexExpr> kernel_size,
+ std::string data_layout,
+ std::string kernel_layout,
+ std::string out_layout,
+ DataType out_dtype) {
+ auto attrs = make_node<Conv2DWinogradAttrs>();
+ attrs->tile_size = tile_size;
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->dilation = std::move(dilation);
+ attrs->groups = groups;
+ attrs->channels = channels;
+ attrs->kernel_size = std::move(kernel_size);
+ attrs->data_layout = std::move(data_layout);
+ attrs->kernel_layout = std::move(kernel_layout);
+ attrs->out_layout = std::move(out_layout);
+ attrs->out_dtype = std::move(out_dtype);
+ static const Op& op = Op::Get("nn.contrib_conv2d_winograd_without_weight_transform");
+ return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 13>(MakeConv2DWinograd, args, rv);
+ });
+
+
+RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
+.describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout.
+ This operator assumes the weight tensor is already pre-transformed by
+ nn.contrib_conv2d_winograd_weight_transform.
+
+- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
+- **weight**: Any shape
+ We do not check the shape for this input tensor. Since different backend
+ has different layout strategy.
+
+- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.Conv2DWinograd")
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(5)
+.add_type_rel("Conv2DWinograd", Conv2DWinogradRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ Conv2DInferCorrectLayout<Conv2DWinogradAttrs>);
+
+// relay.nn.contrib_conv2d_winograd_weight_transform
+TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
+
+bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 2);
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) return false;
+
+ const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
+ CHECK(param != nullptr);
+
+ CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
+
+ // each pad width element should be a pair of positive integers
+ std::vector<IndexExpr> oshape {
+ param->tile_size + data->shape[2] - 1,
+ param->tile_size + data->shape[3] - 1,
+ data->shape[0],
+ data->shape[1],
+ };
+
+ reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
+ data->dtype));
+ return true;
+}
+
+Expr MakeConv2DWinogradWeightTransform(Expr weight,
+ int tile_size) {
+ auto attrs = make_node<Conv2DWinogradWeightTransformAttrs>();
+ attrs->tile_size = tile_size;
+ static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
+ return CallNode::make(op, {weight}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+ runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
+ });
+
+
+RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
+.describe(R"code(Weight transformation of winograd fast convolution algorithm.
+
+Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
+weight transformation in advance.
+
+- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.Conv2DWinogradWeightTransformAttrs")
+.set_num_inputs(1)
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(5)
+.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
+
} // namespace relay
} // namespace tvm
}
if (!modified) {
new_e = CallNode::make(ref_call->op, new_args,
- ref_call->attrs, ref_call->type_args);
+ ref_call->attrs);
}
const CallNode *new_call = new_e.as<CallNode>();
// NOTE: discard the "const" qualifier
TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);
- // fill incomplete state and expand tuple
- for (auto new_arg : new_args) {
- auto push_back_one_arg = [&](Expr arg) {
- // We always expect LayoutAlternatedExpr.
- // This is used to convert the normal Expr to LayoutAlternatedExpr.
- if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
- inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
- normal_new_args.push_back(inp->value);
- } else {
- auto inode = make_node<LayoutAlternatedExprNode>();
- inode->value = arg;
- inode->memorizer = memorizer;
- inputs.push_back(LayoutAlternatedExpr(inode));
- normal_new_args.push_back(arg);
- }
- };
+ // fill incomplete state and flatten tuple
+ auto push_back_one_arg = [&inputs, memorizer](Expr arg) {
+ // We always expect LayoutAlternatedExpr.
+ // This is used to convert the normal Expr to LayoutAlternatedExpr.
+ if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
+ inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
+ return inp->value;
+ } else {
+ auto inode = make_node<LayoutAlternatedExprNode>();
+ inode->value = arg;
+ inode->memorizer = memorizer;
+ inputs.push_back(LayoutAlternatedExpr(inode));
+ return arg;
+ }
+ };
+ for (auto new_arg : new_args) {
+ // NOTE: do not support nested tuple
if (new_arg->is_type<TupleNode>()) {
Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
+ std::vector<Expr> fields;
for (auto x : tuple_new_arg->fields) {
- push_back_one_arg(x);
+ Expr tmp = push_back_one_arg(x);
+ fields.push_back(tmp);
}
+ normal_new_args.push_back(TupleNode::make(fields));
} else {
- push_back_one_arg(new_arg);
+ Expr tmp = push_back_one_arg(new_arg);
+ normal_new_args.push_back(tmp);
}
}
}
for (auto arg : ref_call->args) {
- if (arg->is_type<TupleNode>()) { // expand tuple
+ if (arg->is_type<TupleNode>()) { // flatten tuple
Tuple tuple_arg = Downcast<Tuple>(arg);
for (auto x : tuple_arg->fields) {
input_shapes.push_back(x->type_as<TensorTypeNode>()->shape);
// if (new_in != new_in2): insert transform (new_in -> new_in2)
Array<Expr> transformed_args;
- for (size_t i = 0; i < inputs.size(); ++i) {
- transformed_args.push_back(memorizer.Transform(new_call->args[i], new_in[i], new_in2[i]));
+ size_t pt = 0;
+ for (auto arg : new_call->args) {
+ if (arg->is_type<TupleNode>()) { // unflatten tuple
+ Tuple tuple_arg = Downcast<Tuple>(arg);
+ std::vector<Expr> transformed_tuple_arg;
+ for (auto arg_item : tuple_arg->fields) {
+ transformed_tuple_arg.push_back(
+ memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
+ pt++;
+ }
+ transformed_args.push_back(TupleNode::make(transformed_tuple_arg));
+ } else {
+ transformed_args.push_back(
+ memorizer.Transform(arg, new_in[pt], new_in2[pt]));
+ pt++;
+ }
}
+ CHECK_EQ(pt, inputs.size());
// state[node] = (old_out, new_out)
- CHECK(ref_call->checked_type_.defined())
- << "Call infer_type pass before alter_op_layout pass";
-
+ // (handle tuple output)
if (ref_call->checked_type()->is_type<TupleTypeNode>()) {
Expr tuple_output = CallNode::make(new_call->op, transformed_args,
- new_call->attrs, new_call->type_args);
+ new_call->attrs);
Array<Expr> fields;
for (size_t i = 0; i < new_out.size(); ++i) {
auto rnode = make_node<LayoutAlternatedExprNode>();
auto rnode = make_node<LayoutAlternatedExprNode>();
CHECK_EQ(new_out.size(), 1);
rnode->value = CallNode::make(new_call->op, transformed_args,
- new_call->attrs, new_call->type_args);
+ new_call->attrs);
rnode->old_layout = old_out[0];
rnode->new_layout = new_out[0];
rnode->memorizer = memorizer;
}
}
+// Limiations:
+// 1. the altered op should have the same number of arguments as the previous one
+// 2. do not support nested tuple arguments
TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
- const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW);
- const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW);
+ const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->kernel_layout, kOIHW);
+ const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->kernel_layout, kOIHW);
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
eq(attrs_a->data_layout, attrs_b->data_layout) &&
- eq(attrs_a->weight_layout, attrs_b->weight_layout) &&
+ eq(attrs_a->kernel_layout, attrs_b->kernel_layout) &&
eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
eq(shape_a[3], shape_b[3]);
auto channels = GetConv2DSuperChannelsDim(conv2d);
num_filters += channels;
}
- auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->weight_layout.find('O');
+ auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
CHECK_NE(index, std::string::npos);
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
MakeConstScalar(Int(32), num_filters));
new_attrs->groups = attrs->groups;
new_attrs->kernel_size = attrs->kernel_size;
new_attrs->data_layout = attrs->data_layout;
- new_attrs->weight_layout = attrs->weight_layout;
+ new_attrs->kernel_layout = attrs->kernel_layout;
new_attrs->out_layout = attrs->out_layout;
new_attrs->out_dtype = attrs->out_dtype;
new_attrs->channels = new_channels;
const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
- Layout weight_layout(param->weight_layout);
+ Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C');
int c_small_axis = data_layout.Indexof('c');
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
- bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
- if (weight_layout.Indexof('i') < 0 &&
+ bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
+ if (kernel_layout.Indexof('i') < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
data_axes = {c_big_axis};
const auto* param = ref_call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
Layout data_layout(param->data_layout);
- Layout weight_layout(param->weight_layout);
+ Layout kernel_layout(param->kernel_layout);
int c_big_axis = data_layout.Indexof('C');
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
- CHECK_EQ(weight_layout.Indexof('i'), -1);
+ CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK(sdata->axes.size() == 1 &&
c_big_axis == sdata->axes[0]->value);
- int big_oc_axis = weight_layout.Indexof('O');
- int big_ic_axis = weight_layout.Indexof('I');
+ int big_oc_axis = kernel_layout.Indexof('O');
+ int big_ic_axis = kernel_layout.Indexof('I');
// Check it must be depthwise or full conv2d.
- bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout);
+ bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr weight = new_args[1];
// match the ic_axis
if (is_depthwise_conv2d) {
Expr scale = ExpandBiasToMatchAxis(
- sdata->scale, weight_layout.ndim(), {big_oc_axis});
+ sdata->scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, scale);
} else {
Expr scale = ExpandBiasToMatchAxis(
- sdata->scale, weight_layout.ndim(), {big_ic_axis});
+ sdata->scale, kernel_layout.ndim(), {big_ic_axis});
weight = Multiply(weight, scale);
}
// return transformed conv2d
AxesSet Conv2DBackwardPrep(const Call& call, const Array<AxesSet>& in_axes) {
const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
- Layout out_layout(param->out_layout);
- if (!out_layout.defined()) {
- out_layout = Layout(param->data_layout);
- }
- Layout weight_layout(param->weight_layout);
+ Layout kernel_layout(param->kernel_layout);
+ Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C');
int c_small_axis = out_layout.Indexof('c');
//
// only handle depthwise or full conv2d.
// TODO(tvm-team) handle grouped conv by reshape + bcast
- bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
- if (weight_layout.Indexof('o') < 0 &&
- weight_layout.Indexof('i') < 0 &&
+ bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
+ if (kernel_layout.Indexof('o') < 0 &&
+ kernel_layout.Indexof('i') < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return {c_big_axis};
}
const auto* param = call->attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
- Layout out_layout(param->out_layout);
- if (!out_layout.defined()) {
- out_layout = Layout(param->data_layout);
- }
- Layout weight_layout(param->weight_layout);
+ Layout kernel_layout(param->kernel_layout);
+ Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
int c_big_axis = out_layout.Indexof('C');
CHECK_GE(c_big_axis, 0);
// For now, we only support simple pattern (no folded weight/data)
// TODO(tvm-team) support general data layout
- CHECK_EQ(weight_layout.Indexof('o'), -1);
- CHECK_EQ(weight_layout.Indexof('i'), -1);
+ CHECK_EQ(kernel_layout.Indexof('o'), -1);
+ CHECK_EQ(kernel_layout.Indexof('i'), -1);
CHECK(axes.size() == 1 &&
c_big_axis == axes[0]->value);
- int big_oc_axis = weight_layout.Indexof('O');
+ int big_oc_axis = kernel_layout.Indexof('O');
// Check it must be depthwise or full conv2d.
- bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, weight_layout);
+ bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
CHECK(param->groups == 1 || is_depthwise_conv2d);
Expr data = transformer->Transform(
call->args[1], NullValue<AxesSet>(), NullValue<Expr>());
// scale on input for deptwise.
Expr wscale = ExpandBiasToMatchAxis(
- scale, weight_layout.ndim(), {big_oc_axis});
+ scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, wscale);
return CallNode::make(
call->op, {data, weight}, call->attrs, call->type_args);
*/
inline bool IsDepthwiseConv2D(const Call& call,
const Conv2DAttrs* param,
- const Layout& weight_layout) {
+ const Layout& kernel_layout) {
static const Layout kOIHW("OIHW");
auto wshape = ConvertLayout(
call->args[1]->type_as<TensorTypeNode>()->shape,
- weight_layout, kOIHW);
+ kernel_layout, kOIHW);
return is_const_int(wshape[0], param->groups) &&
is_const_int(wshape[1], 1);
}
inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
auto param = call->attrs.as<Conv2DAttrs>();
auto tweight = call->args[1]->type_as<TensorTypeNode>();
- auto index = param->weight_layout.find('O');
+ auto index = param->kernel_layout.find('O');
CHECK_NE(index, std::string::npos);
auto channels = as_const_int(tweight->shape[index]);
return *channels;
padding=(1, 1),
channels=16,
data_layout="NCHW4n4c",
- weight_layout="OIHW4o4i",
+ kernel_layout="OIHW4o4i",
out_dtype="int32")
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
- new_attrs['weight_layout'] = 'OIHW16i'
+ new_attrs['kernel_layout'] = 'OIHW16i'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
- weight_layout="OIHW16i",
+ kernel_layout="OIHW16i",
data_layout="NCHW16c")
b = relay.expand_dims(bias, axis=1, num_newaxis=2)
b = relay.layout_transform(b, "CHW", "CHW16c")
y = relay.Function(free_vars(y), y)
return y
- @register_alter_op_layout("nn.conv2d", level=102)
+ @register_alter_op_layout("nn.conv2d", level=105)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
assert(alpha_equal(a, b))
+def test_alter_layout_scalar():
+ """Test alternating the layout of a conv2d.
+ The layout of broadcast operators and the weight should be changed accordingly.
+ """
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight = relay.var("weight")
+ y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+ y = relay.add(y, relay.const(1, "float32"))
+ y = relay.Function(free_vars(y), y)
+ return y
+
+ @register_alter_op_layout("nn.conv2d", level=106)
+ def alter_conv2d(attrs, inputs, tinfos):
+ data, weight = inputs
+ new_attrs = dict(attrs)
+ new_attrs['data_layout'] = 'NCHW16c'
+ return relay.nn.conv2d(data, weight, **new_attrs)
+
+ def expected():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ w = relay.var("weight")
+
+ y = relay.layout_transform(x, "NCHW", "NCHW16c")
+ y = relay.nn.conv2d(y, w,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NCHW16c")
+ y = relay.add(y, relay.const(1.0, "float32"))
+
+ y = relay.layout_transform(y, "NCHW16c", "NCHW")
+ y = relay.Function(free_vars(y), y)
+ return y
+
+ a = before()
+ a = infer_type(a)
+ a = canonicalize_ops(a)
+ a = infer_type(a)
+ a = alter_op_layout(a)
+ a = infer_type(a)
+
+ b = expected()
+ b = infer_type(b)
+
+ assert(alpha_equal(a, b))
+
+def test_alter_layout_concatenate():
+ """ """
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight1 = relay.var('weight1')
+ weight2 = relay.var('weight2')
+ y = relay.nn.conv2d(x, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1))
+ y1 = relay.nn.conv2d(y, weight2,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1))
+ ret = relay.concatenate([y, y1], axis=1)
+ y = relay.Function(free_vars(ret), ret)
+ return y
+
+ @register_alter_op_layout("nn.conv2d", level=107)
+ def alter_conv2d(attrs, inputs, tinfos):
+ data, weight = inputs
+ new_attrs = dict(attrs)
+ new_attrs['data_layout'] = 'NCHW16c'
+ return relay.nn.conv2d(data, weight, **new_attrs)
+
+ def expected():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight1 = relay.var('weight1')
+ weight2 = relay.var('weight2')
+ y = relay.layout_transform(x, "NCHW", "NCHW16c")
+ y = relay.nn.conv2d(y, weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NCHW16c")
+ y1 = relay.nn.conv2d(y, weight2,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout='NCHW16c')
+ ret = relay.concatenate([y, y1], axis=1)
+ ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
+ y = relay.Function(free_vars(ret), ret)
+ return y
+
+ a = before()
+ a = infer_type(a)
+ a = alter_op_layout(a)
+ a = infer_type(a)
+
+ b = expected()
+ b = infer_type(b)
+
+ assert(alpha_equal(a, b))
if __name__ == "__main__":
test_alter_op()
test_alter_layout_dual_path()
test_alter_layout_resnet()
test_alter_layout_broadcast_op()
+ test_alter_layout_scalar()
+ test_alter_layout_concatenate()
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
- weight_layout="HWIO",
+ kernel_layout="HWIO",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x, conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
- weight_layout="HWIO",
+ kernel_layout="HWIO",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
- weight_layout="HWIO",
+ kernel_layout="HWIO",
groups=channels,
padding=(1, 1))
y2 = relay.nn.conv2d(x,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
- weight_layout="HWIO",
+ kernel_layout="HWIO",
groups=channels,
padding=(1, 1))
z = relay.add(y1, y2)
##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["arm_cpu"])
-def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
- """Alter op layout for pre-computing kernel transformation"""
- import nnvm.symbol as sym
+def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
+ """Alter op layout for pre-computing kernel transformation
+
+ Parameters
+ ----------
+ attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+ Attributes of current convolution
+ inputs : nnvm.symbol or tvm.relay.Expr
+ Grouped input symbols
+ tinfos : list
+ Input shape and dtype
+ F: symbol
+ The context, can be either nnvm.sym or relay.op
+
+ Note
+ ----
+ Unlike other TOPI functions, this function operates on both graph level and operator level,
+ so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+ """
copy_inputs = [s for s in inputs]
new_attrs = {k: attrs[k] for k in attrs.keys()}
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
groups = attrs.get_int('groups')
- layout = attrs["layout"]
+ data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
+ layout = attrs[data_layout_key]
out_dtype = attrs["out_dtype"]
- out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype
+ if out_dtype == "" or out_dtype == "same":
+ out_dtype = tinfos[0].dtype
if layout != 'NCHW' or groups != 1:
return None
[new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg)
- return sym.conv2d(*copy_inputs, **new_attrs)
+ return F.nn.conv2d(*copy_inputs, **new_attrs)
else: # pre-compute weight transformation in winograd
if "-device=arm_cpu" in target.options:
tile_size = 4
tile_size = _pick_tile_size(tinfos[0], tinfos[1])
VC = cfg['tile_bna'].val
- weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
- weight = sym.reshape(weight,
- shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
- weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3])
+ weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
+ weight = F.reshape(weight,
+ newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
+ weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation,
- new_attrs['layout'], out_dtype, tile_size],
+ new_attrs[data_layout_key], out_dtype, tile_size],
conv2d_winograd_without_weight_transform)
dispatch_ctx.update(target, new_workload, cfg)
- return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
+ return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
##### REGISTER ALTER OP LAYOUT #####
@nn.conv2d_alter_layout.register(["cuda", "gpu"])
-def _alter_conv2d_layout(attrs, inputs, tinfos):
- """Alter op layout for pre-computing kernel transformation"""
+def _alter_conv2d_layout(attrs, inputs, tinfos, F):
+ """Alter op layout for pre-computing kernel transformation
+
+ Parameters
+ ----------
+ attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+ Attributes of current convolution
+ inputs : nnvm.symbol or tvm.relay.Expr
+ Grouped input symbols
+ tinfos : list
+ Input shape and dtype
+ F: symbol
+ The context, can be either nnvm.sym or relay.op
+
+ Note
+ ----
+ Unlike other TOPI functions, this function operates on both graph level and operator level,
+ so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+ """
if 'cudnn' in tvm.target.current_target().libs or 'miopen' in tvm.target.current_target().libs:
return None
- import nnvm.symbol as sym
copy_inputs = [s for s in inputs]
-
new_attrs = {k: attrs[k] for k in attrs.keys()}
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int('groups')
- layout = attrs["layout"]
+ data_layout_key = "data_layout" if "data_layout" in new_attrs else "layout"
+ layout = attrs[data_layout_key]
out_dtype = attrs["out_dtype"]
- out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype
+ if out_dtype == "" or out_dtype == "same":
+ out_dtype = tinfos[0].dtype
data, kernel = tinfos[0:2]
N, CI, H, W = get_const_tuple(data.shape)
if cfg.template_key == 'int8':
assert 'cuda' in target.keys
new_layout = 'NCHW4c'
- new_attrs['layout'] = new_layout
+ new_attrs[data_layout_key] = new_layout
new_attrs['out_layout'] = new_layout
new_attrs['kernel_layout'] = 'OIHW4o4i'
ic_block_factor = oc_block_factor = 4
conv2d
)
dispatch_ctx.update(target, new_workload, cfg)
- return sym.conv2d(*copy_inputs, **new_attrs)
+ return F.nn.conv2d(*copy_inputs, **new_attrs)
if attrs.get_int_tuple("dilation") != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
# pre-compute weight transformation in winograd
tile_size = _infer_tile_size(tinfos[0], tinfos[1])
- weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1],
- tile_size=tile_size)
- weight = sym.transpose(weight, axes=[0, 1, 3, 2])
+ weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
+ tile_size=tile_size)
+ weight = F.transpose(weight, axes=[0, 1, 3, 2])
copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size
conv2d_winograd_without_weight_transform
)
dispatch_ctx.update(target, new_workload, cfg)
- return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
+ return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
elif groups != CI:
workload = autotvm.task.args_to_workload(
[tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype],
if cfg.template_key == 'int8':
assert 'cuda' in target.keys
new_layout = 'NCHW4c'
- new_attrs['layout'] = new_layout
+ new_attrs[data_layout_key] = new_layout
new_attrs['out_layout'] = new_layout
new_attrs['kernel_layout'] = 'OIHW4o4i'
ic_block_factor = oc_block_factor = 4
group_conv2d_nchw
)
dispatch_ctx.update(target, new_workload, cfg)
- return sym.conv2d(*copy_inputs, **new_attrs)
+ return F.nn.conv2d(*copy_inputs, **new_attrs)
# do nothing for depthwise convolution
return None
from __future__ import absolute_import as _abs
+import warnings
import tvm
from .. import generic
return xi, thread_z, thread_y, thread_x
@conv2d_alter_layout.register(["intel_graphics"])
-def _alter_conv2d_layout(attrs, inputs, tinfos):
+def _alter_conv2d_layout(attrs, inputs, tinfos, F):
import nnvm.symbol as sym
+ if F != sym:
+ warnings.warn("Only support alter layout for intel graphics in NNVM now. "
+ "This pass is ignored in relay.")
+ return None
+
copy_inputs = [s for s in inputs]
data = tinfos[0]
##### REGISTER ALTER OP LAYOUT #####
@conv2d_alter_layout.register(["mali"])
-def _alter_conv2d_layout(attrs, inputs, tinfos):
+def _alter_conv2d_layout(attrs, inputs, tinfos, F):
try:
- return _alter_conv2d_layout_arm(attrs, inputs, tinfos)
+ return _alter_conv2d_layout_arm(attrs, inputs, tinfos, F)
except KeyError: # to filter out fallback opencl templates
return None
@tvm.target.generic_func
-def conv2d_alter_layout(attrs, inputs, tinfos):
+def conv2d_alter_layout(attrs, inputs, tinfos, F):
"""Change Conv2D layout.
Parameters
----------
- attrs : nnvm.top.AttrDict
+ attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
Attributes of current convolution
- inputs : nnvm.symbol
+ inputs : nnvm.symbol or tvm.relay.Expr
Grouped input symbols
tinfos : list
Input shape and dtype
+ F: symbol
+ The context, can be either nnvm.sym or relay.op
+
+ Note
+ ----
+ Unlike other TOPI functions, this function operates on both graph level and operator level,
+ so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
"""
# not to change by default
return None
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D schedule on x86"""
+import warnings
+
import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import deserialize_args
@conv2d_alter_layout.register("cpu")
-def _alter_conv2d_layout(attrs, inputs, tinfo):
+def _alter_conv2d_layout(attrs, inputs, tinfo, F):
import nnvm.symbol as sym
+ if F != sym:
+ warnings.warn("Only support alter layout for x86 in NNVM now. "
+ "This pass is ignored in relay.")
+ return None
+
copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()}
data, kernel = tinfo[0], tinfo[1]