from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
-# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, arg_dtypes):
- return None
+ """Legalize conv2d"""
+ from ... import op
+ return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# specific language governing permissions and limitations
# under the License.
"""Test legalize pass"""
+import numpy as np
import tvm
from tvm import relay
+from tvm.contrib import graph_runtime
from tvm.relay.op import register_legalize
from tvm.relay import transform, analysis
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+def test_legalize_arm_layout_functional():
+ """Test if the legalized conversion yields same result as original"""
+ def get_output(func, data_val, parameters):
+ with relay.build_config(opt_level=0):
+ graph, lib, params = relay.build(func, target='llvm', params=parameters)
+ m = graph_runtime.create(graph, lib, tvm.cpu())
+ m.set_input("data", data_val)
+ m.set_input(**params)
+ m.run()
+ out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
+ return out
+
+ def before():
+ n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
+ data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
+ kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
+ y = relay.nn.conv2d(data, kernel,
+ kernel_size=(kh, kw),
+ channels=oc,
+ padding=(1, 1),
+ dilation=(1, 1),
+ data_layout='NHWC',
+ kernel_layout='HWIO',
+ out_dtype='float32')
+ func = relay.Function([data, kernel], y)
+ return func
+
+ @register_legalize("nn.conv2d", level=101)
+ def legalize_conv2d(attrs, inputs, arg_types):
+ from topi.arm_cpu.conv2d import _conv2d_legalize
+ return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
+
+ a = before()
+ b = run_opt_pass(a, transform.Legalize())
+ assert b.astext().count('transpose') == 3
+
+ wdata = np.random.rand(3, 3, 16, 32) * 10
+ parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
+ data_val = np.random.rand(1, 224, 224, 16).astype('float32')
+ ref_out = get_output(a, data_val, parameters)
+ legalized_out = get_output(b, data_val, parameters)
+ np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
+
if __name__ == "__main__":
test_legalize()
test_legalize_none()
test_legalize_multi_input()
+ test_legalize_arm_layout_functional()
conv2d_winograd_without_weight_transform, \
conv2d_winograd_nnpack_without_weight_transform, \
depthwise_conv2d_nchw
+from ..nn import conv2d_legalize
from ..nn.util import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
# currently we only have contrib_spatial_pack and direct template
# add more schedule templates.
return None
+
+@conv2d_legalize.register("arm_cpu")
+def _conv2d_legalize(attrs, inputs, arg_types, F):
+ if F.__name__ != 'tvm.relay.op':
+ return None
+ if attrs['data_layout'] == 'NHWC':
+ data, kernel = inputs
+ if attrs['kernel_layout'] == 'HWIO':
+ # Handle HWIO layout. This is common in TF graph.
+ kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
+ elif attrs['kernel_layout'] == 'HWOI':
+ # Handle HWOI layout. This is common in TF depthwise conv2d graph.
+ kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
+ elif attrs['kernel_layout'] != 'OIHW':
+ return None
+
+ warnings.warn("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to "
+ + "fallback to NCHW. This can result in performance degradation.")
+ # Set new attrs for the tranposed conv.
+ new_attrs = {k: attrs[k] for k in attrs.keys()}
+ new_attrs['data_layout'] = 'NCHW'
+ new_attrs['kernel_layout'] = 'OIHW'
+
+ # Convert from NHWC to NCHW.
+ data = F.transpose(data, axes=(0, 3, 1, 2))
+ conv = F.nn.conv2d(data, kernel, **new_attrs)
+ # Convert back to original NHWC layout.
+ out = F.transpose(conv, axes=(0, 2, 3, 1))
+ return out
+ return None
@tvm.target.generic_func
+def conv2d_legalize(attrs, inputs, arg_dtypes, F):
+ """Legalizes Conv2D op.
+ Parameters
+ ----------
+ attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+ Attributes of current convolution
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized.
+ arg_dtypes : list of types
+ List of types of input arguments
+ F: symbol
+ The context, can be either nnvm.sym or relay.op
+ Note
+ ----
+ Unlike other TOPI functions, this function operates on both graph level and operator level,
+ so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+ """
+ # not to change by default
+ return None
+
+
+@tvm.target.generic_func
def conv2d_alter_layout(attrs, inputs, tinfos, F):
"""Change Conv2D layout.