[RELAY][Convert Layout] Specify additional layouts in convert layout pass (#5422)
authorlhutton1 <35535092+lhutton1@users.noreply.github.com>
Wed, 13 May 2020 18:08:32 +0000 (19:08 +0100)
committerGitHub <noreply@github.com>
Wed, 13 May 2020 18:08:32 +0000 (11:08 -0700)
* [RELAY] Specify additional layouts in convert layout pass

* This patch means that you can specify an additional layout, rather than using the layout chosen by default during conversion.
* This is specifically useful for external codegen when a 3rd party library needs to target a specific kernel layout for example.

Change-Id: I3ef9cf45ead574801870a38af9768f93e29aab10

* Use mapping of op name to list of desired layouts

Change-Id: Ibd691a3cb93e73a394f36112668ad52a84c7d5a2

* Fix issue with code block

Change-Id: Ibb4e38c05ad4312b7dea845be699b8d5d57e0a94

* Address comments, Improve tutorial

Change-Id: Ib824eead329d551c338234de3b2d814693afd0ec

* Fix linting

Change-Id: Ie9e1891f590b3a7496a56ff8362cdda9d4b5fa75

* Test uses NCHW default layout. Unrelated issue with NHWC.

Change-Id: I1c16f0db73db56f5e9536db3fe5eb2624c3b595c

* Fix mistake in tutorial

Change-Id: I944041245d27af262dc96f1cd8117f1f19272062

* Address multiple comments

Change-Id: If33a1e34acd8fc37d1c7797ee189a6448a392672

* Improve tutorial

Change-Id: Ib04142c94c7958ab5067947d2ff4c84354e3d0c5

* Fix Clang-format

Change-Id: Ieff39e3f0817d22579c68b3287e972a3b0fcfbc8

docs/dev/convert_layout.rst
include/tvm/relay/op_attr_types.h
include/tvm/relay/transform.h
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/qnn/op/layout_conversions.py
python/tvm/relay/transform/transform.py
src/relay/transforms/convert_layout.cc
tests/python/relay/test_pass_convert_op_layout.py

index 7345c15..ee5350c 100644 (file)
@@ -92,7 +92,7 @@ These steps happen for each operator in sequence, where ConvertLayout pass keeps
 .. code-block:: python
 
     @reg.register_convert_op_layout("nn.conv2d")
-    def convert_conv2d(attrs, inputs, tinfos, desired_layout):
+    def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
         """Convert Layout pass registration for conv2d op.
 
         Parameters
@@ -103,8 +103,9 @@ These steps happen for each operator in sequence, where ConvertLayout pass keeps
             The args of the Relay expr to be legalized
         tinfos : list of types
             List of input and output types
-        desired_layout : str
-            The desired layout
+        desired_layouts : list of layout strings
+                List of layouts defining our desired
+                layout for the data and kernel inputs respectively.
 
         Returns
         -------
@@ -113,19 +114,30 @@ These steps happen for each operator in sequence, where ConvertLayout pass keeps
         """
 
         from tvm import relay
-        data_layout = attrs['data_layout']
-        kernel_layout = attrs['kernel_layout']
         data, weight = inputs
-        assert desired_layout == 'NCHW', \
-                "Currently only transformation to NCHW layout is supported."
-        if desired_layout == 'NCHW':
-            new_attrs = dict(attrs)
-            new_attrs['data_layout'] = desired_layout
-            new_attrs['kernel_layout'] = 'OIHW'
+        new_attrs = dict(attrs)
+
+        # We expect 2 desired layouts to be specified, one for the data and one for the kernel.
+        assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs"
+
+        # Use the first entry in desired layouts which specifies the data layout.
+        # The expected ordering of layouts for this operator is defined by this function.
+        desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
+
+        assert desired_data_layout != "default", "Data layout cannot be default"
+
+        new_attrs['data_layout'] = desired_data_layout
+
+        if desired_data_layout == 'NCHW':
+            if desired_kernel_layout != 'default':
+                new_attrs['kernel_layout'] = desired_kernel_layout
+            else:
+                new_attrs['kernel_layout'] = 'OIHW'
             # Actual insertion of layout transforms is taken care internally
             # by ConvertLayout pass.
             return relay.nn.conv2d(data, weight, **new_attrs)
-        return None
+
+        raise ValueError('Layout %s is not yet supported' % desired_data_layout)
 
 
 **FInferCorrectLayout - Layout inference** - Currently, this attribute is exposed only in C++. This function takes original input layouts and the new input layouts (passed from the previous operator or from the python callback for layout alteration), and infers the final data layouts. Layout inference is called for each operator. The usage might vary for different operator categories. For layout agnostic operators, we just want to return the new data layouts in this function. For lightly-layout and heavily-layout sensitive operators, we can change the operator attributes (like axis for concatenate, pad_width for pad) so that we can adapt to the new data layout, preventing insertion of layout transforms. Let's look at a couple of examples to understand this better.
@@ -218,6 +230,8 @@ Second example is for a lightly-layout sensitive operator - batch normalization.
 
 ConvertLayout pass is extremely easy to use. The pass is not a part of default relay.build pipeline. The intended usage is to call it between the framework-to-relay parser and relay.build module call.
 
+In order to specify the layouts to convert to, we create a mapping of heavily-layout sensitive operators to a list of the desired layouts for that operator. The first example below specifies data layout, we allow the kernel layout to be automatically converted to one that is supported by TVM (for that particular data layout and operator). This is specified by the use of the "default" keyword. The second example shows how we could have also converted to a specific kernel layout of our choosing. It's worth noting that the following examples will convert to the same layouts i.e. `{'nn.conv2d': ['NCHW', 'default']} == {'nn.conv2d': ['NCHW', 'OIHW']}`
+
 .. code-block:: python
 
     # TFlite framework to Relay parser - Default layout is NHWC
@@ -225,10 +239,13 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
                                              shape_dict=shape_dict,
                                              dtype_dict=dtype_dict)
 
+    # We assume our model's heavily-layout sensitive operators only consist of nn.conv2d
+    desired_layouts = {'nn.conv2d': ['NCHW', 'default']}
+
     # Convert the layout to NCHW
     # RemoveUnunsedFunctions is used to clean up the graph.
     seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
-                                      relay.transform.ConvertLayout('NCHW')])
+                                    relay.transform.ConvertLayout(desired_layouts)])
     with relay.transform.PassContext(opt_level=3):
         mod = seq(mod)
 
@@ -236,6 +253,15 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
     with relay.build_config(opt_level=3):
          graph, lib, params = relay.build(mod, target, params=params)
 
+
+.. code-block:: python
+
+    desired_layouts = {'nn.conv2d': ['NCHW', 'OIHW']}
+    pass = relay.transform.ConvertLayout(desired_layouts)
+
+
+The ordering of the layouts is defined by the implementation of `register_convert_op_layout("OPNAME")`, you can refer to the docstring which should explicitly state the expected layout. In the examples above it's [data_layout, kernel_layout].
+
 Current implementation has support for almost all the operators commonly used in image classification models. However, if one encounters too many data layout transforms in the graph, it is highly likely that there is an operator whose layouts need special handling as described in Section 3. Some pull requests that can help in such a situation are
 
 - Layout inference for `Batch Norm <https://github.com/apache/incubator-tvm/pull/4600>`_ - Batch normalization falls into the category of lightly-sensitive operator. The PR shows how to handle the layout inference for batch norm.
index b3e70f5..acd4a03 100644 (file)
@@ -152,12 +152,14 @@ using FTVMAlterOpLayout =
  * \param inputs The input symbols of the original node.
  * \param tinfos An array of placeholders, use for getting the inferred shape
  *               and dtype of the inputs.
- * \param desired_layout The desired layout.
+ * \param desired_layouts Specify an array of desired layouts for each input.
+ *                        For example a conv2d op: Array("NHWC", "OHWI"), this
+ *                        specifies the desired layout for data then kernel.
  * \return new_expr The modified expression.
  */
 using FTVMConvertOpLayout = runtime::TypedPackedFunc<Expr(
     const Attrs& attrs, const Array<Expr>& args, const Array<te::Tensor>& tinfos,
-    const std::string& desired_layout)>;
+    const Array<String>& desired_layouts)>;
 /*!
  * \brief Legalizes an expression with another expression. This function will be
  *  invoked in Legalize pass. It is a target-dependent pass.
index 461276b..9a8ca84 100644 (file)
@@ -281,10 +281,12 @@ TVM_DLL Pass AlterOpLayout();
  * layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout
  * using the InferCorrectLayout infrastructure.
  *
- * \param desired_layout The desired layout.
+ * \param desired_layouts Specify mapping of op_name to array of desired layouts for each input.
+ *                        For example: Map("nn.conv2d", Array("NHWC", "OHWI")),
+ *                        this specifies the desired layout for data then kernel for nn.conv2d.
  * \return The pass.
  */
-TVM_DLL Pass ConvertLayout(const std::string& desired_layout);
+TVM_DLL Pass ConvertLayout(const Map<std::string, Array<String>>& desired_layouts);
 
 /*!
  * \brief Legalizes an expr with another expression.
index ad8c654..9a9bfe0 100644 (file)
@@ -118,7 +118,7 @@ def legalize_conv2d(attrs, inputs, types):
     return topi.nn.conv2d_legalize(attrs, inputs, types)
 
 @reg.register_convert_op_layout("nn.conv2d")
-def convert_conv2d(attrs, inputs, tinfos, desired_layout):
+def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
     """Convert Layout pass registration for conv2d op.
 
     Parameters
@@ -129,8 +129,9 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
         The args of the Relay expr to be legalized
     tinfos : list of types
         List of input and output types
-    desired_layout : str
-        The desired layout
+    desired_layouts : list of layout strings
+        List of layouts defining our desired
+        layout for the data and kernel inputs respectively.
 
     Returns
     -------
@@ -141,11 +142,20 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
     from tvm import relay
     data, weight = inputs
     new_attrs = dict(attrs)
-    new_attrs['data_layout'] = desired_layout
-    if desired_layout == 'NCHW':
+    assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs"
+    desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+    new_attrs['data_layout'] = desired_data_layout
+
+    if desired_kernel_layout != "default":
+        new_attrs['kernel_layout'] = desired_kernel_layout
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    # Handle default kernel layouts
+    if desired_data_layout == 'NCHW':
         new_attrs['kernel_layout'] = 'OIHW'
         return relay.nn.conv2d(data, weight, **new_attrs)
-    elif desired_layout == 'NHWC':
+    elif desired_data_layout == 'NHWC':
         # Check for depthwise convolution.
         if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape,
                                attrs['kernel_layout'], attrs['groups']):
@@ -153,9 +163,8 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
         else:
             new_attrs['kernel_layout'] = 'HWIO'
         return relay.nn.conv2d(data, weight, **new_attrs)
-    else:
-        assert "Layout %s is not yet supported." % (desired_layout)
-    return None
+
+    raise ValueError("Layout %s is not yet supported." % desired_data_layout)
 
 
 # conv2d_transpose
@@ -193,7 +202,7 @@ def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
     return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)
 
 @reg.register_convert_op_layout("nn.conv3d")
-def convert_conv3d(attrs, inputs, tinfos, desired_layout):
+def convert_conv3d(attrs, inputs, tinfos, desired_layouts):
     """Convert Layout pass registration for conv3d op.
 
     Parameters
@@ -204,8 +213,9 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layout):
         The args of the Relay expr to be legalized
     tinfos : list of types
         List of input and output types
-    desired_layout : str
-        The desired layout
+    desired_layouts : list of layout strings
+        List of layouts defining our desired
+        layout for the data and kernel inputs respectively.
 
     Returns
     -------
@@ -216,16 +226,25 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layout):
     from tvm import relay
     data, weight = inputs
     new_attrs = dict(attrs)
-    new_attrs['data_layout'] = desired_layout
-    if desired_layout == 'NCDHW':
+    assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs"
+    desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+    new_attrs['data_layout'] = desired_data_layout
+
+    if desired_kernel_layout != "default":
+        new_attrs['kernel_layout'] = desired_kernel_layout
+        return relay.nn.conv3d(data, weight, **new_attrs)
+
+    # Handle default kernel layouts
+    if desired_data_layout == 'NCDHW':
         new_attrs['kernel_layout'] = 'OIDHW'
         return relay.nn.conv3d(data, weight, **new_attrs)
-    elif desired_layout == "NDHWC":
+    elif desired_data_layout == "NDHWC":
         new_attrs['kernel_layout'] = 'DHWIO'
         return relay.nn.conv3d(data, weight, **new_attrs)
-    else:
-        assert "Layout %s is not yet supported" % desired_layout
-    return None
+
+    raise ValueError("Layout %s is not yet supported" % desired_data_layout)
+
 
 # conv3d_winograd related operators
 reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
index f5850b8..caa4c56 100644 (file)
@@ -22,7 +22,7 @@ from tvm.relay.op import op as reg
 
 
 @reg.register_convert_op_layout("qnn.conv2d")
-def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout):
+def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts):
     """Convert Layout pass registration for QNN conv2d op.
 
     Parameters
@@ -33,8 +33,9 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout):
         The args of the Relay expr to be legalized
     tinfos : list of types
         List of input and output types
-    desired_layout : str
-        The desired layout
+    desired_layouts : list of layout strings
+        List of layouts defining our desired
+        layout for the data and kernel inputs respectively.
 
     Returns
     -------
@@ -43,11 +44,18 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout):
     """
     # pylint: disable=import-outside-toplevel
     from tvm import relay
-    assert desired_layout == 'NCHW', \
-            "Currently only transformation to NCHW layout is supported."
-    if desired_layout == 'NCHW':
-        new_attrs = dict(attrs)
-        new_attrs['data_layout'] = desired_layout
-        new_attrs['kernel_layout'] = 'OIHW'
+    assert len(desired_layouts) == 2, "A desired layout is expected for both of qnn.conv2d's inputs"
+    desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+
+    new_attrs = dict(attrs)
+    new_attrs['data_layout'] = desired_data_layout
+
+    if desired_data_layout == 'NCHW':
+        if desired_kernel_layout != "default":
+            new_attrs['kernel_layout'] = desired_kernel_layout
+        else:
+            new_attrs['kernel_layout'] = 'OIHW'
         return relay.qnn.op.conv2d(*inputs, **new_attrs)
-    return None
+
+    raise ValueError('Layout %s is not yet supported' % desired_data_layout)
index 647e999..c58c679 100644 (file)
@@ -324,7 +324,7 @@ def AlterOpLayout():
     return _ffi_api.AlterOpLayout()
 
 
-def ConvertLayout(desired_layout):
+def ConvertLayout(desired_layouts):
     """ Given a dest layout, this pass transforms the expr such that most of the ops input data
     layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms,
     one at the start and one at the end.
@@ -341,15 +341,18 @@ def ConvertLayout(desired_layout):
 
     Parameters
     ----------
-    desired_layout : str
-      The desired layout for the transformed expr.
+    desired_layouts : map of op_name to list of layouts
+        Specify a mapping of operator names to a list of layouts to convert to, in the order
+        defined by the operator. An example for nn.conv2d could be: {"nn.conv2d", ["NHWC", "OHWI]},
+        where the first item in the list specifies the data layout and the second specifies the
+        kernel layout.
 
     Returns
     -------
     pass: FunctionPass
       The pass.
     """
-    return _ffi_api.ConvertLayout(desired_layout)
+    return _ffi_api.ConvertLayout(desired_layouts)
 
 
 def Legalize(legalize_map_attr_name="FTVMLegalize"):
index f43c8f6..7d42125 100644 (file)
@@ -51,13 +51,15 @@ class ConvertTransformMemorizerNode : public TransformMemorizerNode {
  public:
   /*!
    * \brief Initializes the desired_layout.
-   * \param desired_layout The desired layout.
+   * \param desired_layouts Specify mapping of op_name to array of desired layouts for each input.
+   *                        For example: Map("nn.conv2d", Array("NHWC", "OHWI")),
+   *                        this specifies the desired layout for data then kernel for nn.conv2d.
    */
-  explicit ConvertTransformMemorizerNode(const std::string& desired_layout)
-      : desired_layout_(desired_layout) {}
+  explicit ConvertTransformMemorizerNode(Map<std::string, Array<String>> desired_layouts)
+      : desired_layouts_(std::move(desired_layouts)) {}
 
-  /*! \brief The desired layout for the Convert Layout pass */
-  std::string desired_layout_;
+  /*! \brief A mapping of op_name to array of desired layouts for each input. */
+  Map<std::string, Array<String>> desired_layouts_;
 };
 
 /*!
@@ -91,8 +93,14 @@ class ConvertTransformMemorizer : public TransformMemorizer {
         auto ttype = expr->type_as<TensorTypeNode>();
         tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
       }
+
+      auto desired_layouts = operator->()->desired_layouts_;
+      if (desired_layouts.find(op->name) == desired_layouts.end()) {
+        LOG(FATAL) << "Desired layout(s) not specified for op: " << op->name;
+      }
+      Array<String> op_desired_layouts = desired_layouts.at(op->name);
       Expr altered_value =
-          fconvert_layout[op](ref_call->attrs, new_args, tinfos, operator->()->desired_layout_);
+          fconvert_layout[op](ref_call->attrs, new_args, tinfos, op_desired_layouts);
       if (altered_value.defined()) {
         new_e = altered_value;
         modified = true;
@@ -115,9 +123,9 @@ class ConvertTransformMemorizer : public TransformMemorizer {
  * 1. The altered op should have the same number of arguments as the previous one.
  * 2. Do not support nested tuple arguments.
  */
-Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) {
+Expr ConvertLayout(const Expr& expr, const Map<std::string, Array<String>>& desired_layouts) {
   ConvertTransformMemorizer transformMemorizer(
-      make_object<ConvertTransformMemorizerNode>(desired_layout));
+      make_object<ConvertTransformMemorizerNode>(desired_layouts));
   auto fcontext = [&](const Call& call) -> ObjectRef { return transformMemorizer; };
 
   return ForwardRewrite(expr, LayoutRewriter<ConvertTransformMemorizer>, fcontext);
@@ -127,10 +135,10 @@ Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) {
 
 namespace transform {
 
-Pass ConvertLayout(const std::string& desired_layout) {
+Pass ConvertLayout(const Map<std::string, Array<String>>& desired_layouts) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
+        return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layouts));
       };
   return CreateFunctionPass(pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
 }
index c5a7b0e..f3cdbfc 100644 (file)
@@ -49,7 +49,7 @@ def test_no_convert_layout():
         return before()
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -84,7 +84,7 @@ def test_conv_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -129,7 +129,7 @@ def test_conv_bias_pool_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -177,7 +177,7 @@ def test_conv_concat_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -232,7 +232,7 @@ def test_dual_path_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -256,7 +256,7 @@ def test_bn_convert_layout():
         return relay.Function(analysis.free_vars(y), y)
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
 
     # Check that there is only 1 NHWC to NCHW transform.
     has_lt = list()
@@ -312,7 +312,7 @@ def test_resnet_convert_layout():
         return relay.Function(analysis.free_vars(y), y)
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -344,7 +344,7 @@ def test_scalar_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -392,7 +392,7 @@ def test_conv_bn_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -448,7 +448,7 @@ def test_qnn_conv_requantize_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -526,7 +526,7 @@ def test_qnn_conv_concat_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -606,7 +606,132 @@ def test_qnn_conv_add_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_conv_convert_kernel_layout():
+    """ Check that convolution kernel layout is correctly transformed. """
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout='NHWC', kernel_layout='HWIO')
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        w = relay.var("weight", shape=(3, 3, 64, 64))
+        w = relay.layout_transform(w, 'HWIO', 'OHWI')
+        y = relay.nn.conv2d(x, w,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC',
+                            kernel_layout='OHWI')
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NHWC', 'OHWI']}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_default_keyword():
+    """ Check that the default keyword selects correct TVM default layout. """
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 3, 3, 64))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout='NCHW', kernel_layout='OHWI')
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        w = relay.var("weight", shape=(64, 3, 3, 64))
+        w = relay.layout_transform(w, 'OHWI', 'OIHW')
+        y = relay.nn.conv2d(x, w,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NCHW',
+                            kernel_layout='OIHW')
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_different_ops_convert_layout():
+    """ Check convert layout correctly supports converting the layout of
+    different ops in the same graph.
+    """
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var("weight1", shape=(64, 3, 3, 64))
+        weight2 = relay.var("weight2", shape=(64, 3, 3, 64), dtype='int8')
+        out = relay.nn.conv2d(x, weight1,
+                              channels=64,
+                              kernel_size=(3, 3),
+                              padding=(1, 1),
+                              data_layout='NCHW',
+                              kernel_layout='OHWI')
+        out = relay.cast(out, 'int8')
+        out = relay.qnn.op.conv2d(out, weight2,
+                                  relay.const(1, 'int32'),
+                                  relay.const(1, 'int32'),
+                                  relay.const(1, 'float32'),
+                                  relay.const(1, 'float32'),
+                                  channels=64,
+                                  kernel_size=(3, 3),
+                                  padding=(1, 1),
+                                  data_layout='NCHW',
+                                  kernel_layout='OHWI')
+        out = relay.Function(analysis.free_vars(out), out)
+        return out
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var("weight1", shape=(64, 3, 3, 64))
+        weight2 = relay.var("weight2", shape=(64, 3, 3, 64), dtype='int8')
+        x = relay.layout_transform(x, 'NCHW', 'NHWC')
+        weight1 = relay.layout_transform(weight1, 'OHWI', 'HWIO')
+        out = relay.nn.conv2d(x, weight1,
+                              channels=64,
+                              kernel_size=(3, 3),
+                              padding=(1, 1),
+                              data_layout='NHWC',
+                              kernel_layout='HWIO')
+        out = relay.cast(out, 'int8')
+        out = relay.layout_transform(out, 'NHWC', 'NCHW')
+        weight2 = relay.layout_transform(weight2, 'OHWI', 'OIHW')
+        out = relay.qnn.op.conv2d(out, weight2,
+                                  relay.const(1, 'int32'),
+                                  relay.const(1, 'int32'),
+                                  relay.const(1, 'float32'),
+                                  relay.const(1, 'float32'),
+                                  channels=64,
+                                  kernel_size=(3, 3),
+                                  padding=(1, 1),
+                                  data_layout='NCHW',
+                                  kernel_layout='OIHW')
+        out = relay.Function(analysis.free_vars(out), out)
+        return out
+
+    a = before()
+    desired_layouts = {'nn.conv2d': ['NHWC', 'HWIO'],
+                       'qnn.conv2d': ['NCHW', 'OIHW']}
+    a = run_opt_pass(a, transform.ConvertLayout(desired_layouts))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -625,3 +750,6 @@ if __name__ == "__main__":
     test_qnn_conv_requantize_convert_layout()
     test_qnn_conv_concat_convert_layout()
     test_qnn_conv_add_convert_layout()
+    test_conv_convert_kernel_layout()
+    test_default_keyword()
+    test_different_ops_convert_layout()