[Lang] Layout in TVM node system (#2509)
authorYizhi Liu <liuyizhi@apache.org>
Thu, 28 Feb 2019 04:26:22 +0000 (20:26 -0800)
committerLianmin Zheng <mercy_zheng@sjtu.edu.cn>
Thu, 28 Feb 2019 04:26:22 +0000 (12:26 +0800)
* move layout.h & layout.cc from relay to tvm

* change ConvertLayout in relay to bijectiveLayout->Forward/backward

* add first test case

* add LayoutAxis

* add LayoutAxis struct and compiles

* simplify BijectiveLayout rule consturct

* polish func name for Layout, move impl to .cc, remove Layout::defined(), add defined() checker

* partially add layout py support

* add layout test cases

* add doc for tvm.layout & tvm.bijective_layout

* fix lint

* fix lint

* fix layout name generation bug

* fix layout typo

* address comments and add topi.layout_transform

* layout.h->data_layout.h, test_lang_layout.py->test_lang_data_layout.py

29 files changed:
docs/api/python/topi.rst
include/tvm/data_layout.h [new file with mode: 0644]
nnvm/src/top/nn/nn.cc
python/tvm/api.py
python/tvm/tensor.py
src/api/api_lang.cc
src/lang/data_layout.cc [new file with mode: 0644]
src/relay/op/debug.cc
src/relay/op/image/resize.cc
src/relay/op/layout.cc [deleted file]
src/relay/op/layout.h [deleted file]
src/relay/op/nn/convolution.cc
src/relay/op/nn/nn.cc
src/relay/op/nn/pad.cc
src/relay/op/nn/pooling.cc
src/relay/op/nn/upsampling.cc
src/relay/op/tensor/transform.cc
src/relay/pass/alter_op_layout.cc
src/relay/pass/alter_op_layout.h
src/relay/pass/combine_parallel_conv2d.cc
src/relay/pass/fold_scale_axis.cc
src/relay/pass/mac_count.cc
src/relay/pass/pattern_util.h
tests/python/unittest/test_lang_data_layout.py [new file with mode: 0644]
topi/include/topi/nn.h
topi/include/topi/transform.h
topi/python/topi/transform.py
topi/src/topi.cc
topi/tests/python/test_topi_transform.py

index ec5d600dab2bb6f83418dcf94d430d060f511a2c..9680adc1231b7a539221b305bbe0aeb4bb521611 100644 (file)
@@ -68,6 +68,7 @@ List of operators
    topi.greater_equal
    topi.less_equal
    topi.arange
+   topi.layout_transform
    topi.image.resize
 
 
@@ -125,6 +126,7 @@ topi
 .. autofunction:: topi.greater
 .. autofunction:: topi.less
 .. autofunction:: topi.arange
+.. autofunction:: topi.layout_transform
 
 topi.nn
 ~~~~~~~
diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h
new file mode 100644 (file)
index 0000000..99aebc3
--- /dev/null
@@ -0,0 +1,335 @@
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tvm/data_layout.h
+ * \brief Layout expression to describe the data organization of a tensor.
+ *  And BijectiveLayout to mapping two data layouts between each other.
+ */
+#ifndef TVM_DATA_LAYOUT_H_
+#define TVM_DATA_LAYOUT_H_
+
+#include <tvm/base.h>
+#include <tvm/expr.h>
+
+#include <string>
+#include <sstream>
+#include <vector>
+#include <utility>
+#include <algorithm>
+
+#include "ir_operator.h"
+
+namespace tvm {
+
+class LayoutAxis {
+ public:
+  static const LayoutAxis& Get(const char name);
+
+  // Get the singleton LayoutAxis using itvar->var->name_hint
+  static const LayoutAxis& Get(const IterVar& itvar);
+
+  // Get the singleton LayoutAxis using name[0] (size of name must be 1).
+  static const LayoutAxis& make(const std::string& name);
+
+  inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
+  inline std::string name() const { return std::string(1, name_); }
+
+  // if current axis is primal, switch the axis to its subordinate one,
+  // else switch to the primal.
+  inline const LayoutAxis& ToDual() const {
+    if (name_ >= 'A' && name_ <= 'Z') {
+      return LayoutAxis::Get(name_ - 'A' + 'a');
+    } else {
+      return LayoutAxis::Get(name_ - 'a' + 'A');
+    }
+  }
+
+  // return the primal axis. If it is already primal, return itself.
+  const LayoutAxis& ToPrimal() const {
+    return IsPrimal() ? *this : ToDual();
+  }
+
+  // return the subordinate axis. If it is already subordinate, return itself.
+  const LayoutAxis& ToSubordinate() const {
+    return IsPrimal() ? ToDual() : *this;
+  }
+
+  inline bool operator==(const LayoutAxis& rhs) const {
+    return name_ == rhs.name_;
+  }
+
+  friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
+    os << l.name();
+    return os;
+  }
+
+ private:
+  static const LayoutAxis UPPER_CASE[];
+  static const LayoutAxis LOWER_CASE[];
+  LayoutAxis(const LayoutAxis&);
+  LayoutAxis& operator=(const LayoutAxis&);
+  explicit LayoutAxis(const char name) : name_(name) {}
+
+  const char name_;
+};
+
+class Layout;
+// Internal node container Buffer
+class LayoutNode : public Node {
+ public:
+  /*! \brief string representation of layout */
+  std::string name;
+  /*! \brief specify each axis of the layout,
+   *   in which the variable name is the name of the axis.
+   *   The IterVar's extent indicates the size of the axis,
+   *   it is a variable for a primal axis, but a constant for a subordinate axis.
+   */
+  Array<IterVar> axes;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("axes", &axes);
+  }
+
+  TVM_DLL static Layout make(const std::string& layout);
+
+  static constexpr const char* _type_key = "Layout";
+  TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
+};
+
+/*!
+ * \brief Layout is to describe how data is organized within an N-dimention tensor.
+ *  It is composed of upper cases, lower cases and numbers,
+ *  where upper case indicates a primal axis and
+ *  the corresponding lower case with factor size indicates the subordinate axis.
+ *  For example, NCHW16c can describe a 5-D tensor of
+ *  [batch_size, channel, height, width, channel_block].
+ *  Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
+ */
+class Layout : public NodeRef {
+ public:
+  explicit Layout(NodePtr<Node> n) : NodeRef(n) {}
+
+  /*! \brief default constructor */
+  Layout() = default;
+
+  explicit Layout(const Array<IterVar>& axes);
+
+  /*! \brief construct from a string */
+  Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
+
+  /*!
+   * \brief construct from a string.
+   * \param name input in layout convention:
+   *        upper case indicates a dimension and
+   *        the corresponding lower case with factor size
+   *        indicates the split dimension.
+   *        return undefined layout if "__undef__" is passed.
+   */
+  Layout(const std::string& name); // NOLINT(*)
+
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  const LayoutNode* operator->() const {
+    return static_cast<const LayoutNode*>(node_.get());
+  }
+
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  LayoutNode* operator->() {
+    return static_cast<LayoutNode*>(node_.get());
+  }
+
+  /*!
+   * \brief Return an undefined layout.
+   * \return a (global) undefined layout.
+   */
+  static const Layout& Undef() {
+    static Layout undef;
+    return undef;
+  }
+
+  /*!
+   * \brief Returns a sub-layout which is the portion of the object
+   *        that starts at dimension \p pos and spans \p len dimensions
+   *        (or until the end of the layout, whichever comes first).
+   * \param pos The start position.
+   * \param len The length of the sub-layout.
+   * \return A newly constructed Layout object.
+   */
+  Layout SubLayout(size_t pos, size_t len) const;
+
+  /*!
+   * \brief Split \p axis by \p size and put the sub-axis to position \p target_pos.
+   * \param axis The source axis to be split. It must be a primal-axis;
+   * \param target_pos The target position of the newly split subordinate-axis.
+   * \param factor size of the sub-dimension.
+   * \return A newly constructed Layout object.
+   */
+  Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const;
+
+
+  /*! \return number of dimensions */
+  inline size_t ndim() const {
+    if (!defined()) return 0;
+    return operator->()->axes.size();
+  }
+
+  /*! \return number of super dimensions */
+  inline size_t ndim_primal() const {
+    if (!defined()) return 0;
+    size_t ct = 0;
+    for (auto x : operator->()->axes) {
+      if (LayoutAxis::Get(x).IsPrimal()) {
+        ct++;
+      }
+    }
+    return ct;
+  }
+
+  /*!
+   * \brief return the index of the input axis.
+   *        If it is not found in the layout or the layout is undefined,
+   *        return -1.
+   * \param axis the input axis.
+   * \return the index or -1 if not found.
+   */
+  inline int32_t IndexOf(const LayoutAxis& axis) const {
+    if (!this->defined()) return -1;
+    const auto axes = operator->()->axes;
+    for (size_t i = 0; i < axes.size(); ++i) {
+      if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
+    }
+    return -1;
+  }
+
+  /*!
+   * \brief Get the factor size of the subordinate axis.
+   * \param axis the input primal-axis or subordinate-axis.
+   * \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis),
+   *         or the size of \p axis itself (if \p axis is a subordinate-axis).
+   *         Return -1 if \p axis is not in the layout the layout is undefined.
+   */
+  int32_t FactorOf(const LayoutAxis& axis) const;
+
+  /*!
+   * \brief Whether the layout contains an axis.
+   * \param axis axis to be checked.
+   * \return Whether the layout contains the axis.
+   */
+  bool Contains(const LayoutAxis& axis) const {
+    if (!defined()) return false;
+    for (const IterVar var : operator->()->axes) {
+      if (var->var.get()->name_hint == axis.name()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  const LayoutAxis& operator[](int32_t i) const {
+    CHECK(defined()) << "Try to access axis from an undefined layout.";
+    int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
+    CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
+    const IterVar axis = operator->()->axes[index];
+    return LayoutAxis::Get(axis);
+  }
+
+  /*! \return the string description of the layout */
+  inline std::string name() const {
+    if (!defined()) return "__undef__";
+    return operator->()->name;
+  }
+
+  /*!
+   * \brief Whether the two layouts are equal.
+   * \param rhs Another layout.
+   * \return whether the two layouts are equal.
+   */
+  inline bool Equals(const Layout &rhs) const {
+    return name() == rhs.name();
+  }
+
+  /*!
+   * \brief allow output string of layout to ostream
+   * \param os the output stream
+   * \param l the layout
+   * \return the ostream
+   */
+  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
+    os << l.name();
+    return os;
+  }
+
+  using ContainerType = LayoutNode;
+};
+
+class BijectiveLayout;
+// Internal node container BijectiveLayout
+class BijectiveLayoutNode : public Node {
+ public:
+  /*! \brief Describes how source axes can be mapped to the destination axes,
+   *   e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
+   */
+  Array<Expr> forward_rule;
+  /*! \brief Describes how destination axes can be mapped to the source axes */
+  Array<Expr> backward_rule;
+
+  /*! \brief The source layout */
+  Layout src_layout;
+  /*! \brief The destination layout */
+  Layout dst_layout;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("src_layout", &src_layout);
+    v->Visit("dst_layout", &dst_layout);
+    v->Visit("forward_rule", &forward_rule);
+    v->Visit("backward_rule", &backward_rule);
+  }
+
+  static constexpr const char* _type_key = "BijectiveLayout";
+  TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node);
+
+  TVM_DLL static BijectiveLayout make(const Layout& src_layout,
+                                      const Layout& dst_layout);
+};
+
+/*! \brief Bijective function mapping for data layout transformation.
+ *   Given two Layout, BijectiveLayout build and store the mapping rules,
+ *   provides API to transform N-dimention tensor from the source indices (i0, i1, …, im)
+ *   to the destination indices (j0, j1, … jm).
+ */
+class BijectiveLayout : public NodeRef {
+ public:
+  BijectiveLayout() = default;
+  explicit BijectiveLayout(NodePtr<Node> n) : NodeRef(n) {}
+
+  // Given the source shape, infer the destination shape.
+  TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const;
+  // Given the destination shape, recover the source shape.
+  TVM_DLL Array<Expr> BackwardShape(const Array<Expr>& dst_shape) const;
+  // Given the destination indices, infer the destination indices.
+  TVM_DLL Array<Expr> ForwardIndex(const Array<Expr>& index) const;
+  // Given the destination indices, recover the source indices.
+  TVM_DLL Array<Expr> BackwardIndex(const Array<Expr>& dst_index) const;
+
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const BijectiveLayoutNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = BijectiveLayoutNode;
+};
+
+inline const BijectiveLayoutNode* BijectiveLayout::operator->() const {
+  return static_cast<const BijectiveLayoutNode*>(node_.get());
+}
+
+}  // namespace tvm
+
+#endif  // TVM_DATA_LAYOUT_H_
index e301f167ff1d9e74cc534fa4c7b2f1e8da28a678..694f0d54a0e4be1a8f8ff5b8fbfd0dbb226b0e09 100644 (file)
@@ -674,42 +674,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
                     const Array<Tensor>& inputs,
                     const Array<Tensor>& outputs) {
     const LayoutTransformParam& param = nnvm::get<LayoutTransformParam>(attrs.parsed);
-
-    Layout src_layout(param.src_layout);
-    Layout dst_layout(param.dst_layout);
-
-    if (src_layout == dst_layout) {
-      return Array<Tensor>{ inputs[0] };
-    } else if (!src_layout.defined() || !dst_layout.defined()) {
-      LOG(FATAL) << "cannot convert from/to undefined layout";
-    }
-
-    CHECK(src_layout.convertible(dst_layout)) << "cannot convert from " << param.src_layout
-                                                << " to " << param.dst_layout;
-
-    return Array<Tensor> {
-      topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array<Var>& dst_indices) {
-        std::vector<Expr> dst_to_src_indices;
-        for (Layout::LayoutDim src_axis : src_layout) {
-          int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
-          int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
-          int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
-          int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
-
-          Expr src_index(dst_indices[dst_major_pos]);
-          if (dst_minor_pos >= 0) {
-            CHECK_GT(dst_factor, 0);
-            src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
-          }
-          if (Layout::is_superdim(src_axis) && src_factor > 0) {
-            src_index = src_index / src_factor;
-          } else if (Layout::is_subdim(src_axis) && src_factor > 0) {
-            src_index = src_index % src_factor;
-          }
-          dst_to_src_indices.push_back(src_index);
-        }
-        return Array<Expr>(dst_to_src_indices);
-      })
+    return Array<Tensor>{
+      topi::layout_transform(inputs[0], param.src_layout, param.dst_layout)
     };
 })
 .set_support_level(1);
index 514490ae83ea08bbc9de78d96bbc574dc7e746cb..7b81f863f6b0e068610ae9c3394a98ef25610fcf 100644 (file)
@@ -515,7 +515,7 @@ def decl_buffer(shape,
                 scope="",
                 data_alignment=-1,
                 offset_factor=0):
-    """Decleare a new symbolic buffer.
+    """Declare a new symbolic buffer.
 
     Normally buffer is created automatically during lower and build.
     This is only needed if user want to specify their own buffer layout.
@@ -587,6 +587,49 @@ def decl_buffer(shape,
         data, dtype, shape, strides, elem_offset, name, scope,
         data_alignment, offset_factor)
 
+def layout(layout_str):
+    """Create a layout node from a string.
+
+    Parameters
+    ----------
+    layout_str : str
+        A layout representation is composed of upper cases, lower cases and numbers,
+        where upper case indicates a primal axis and
+        the corresponding lower case with factor size indicates the subordinate axis.
+        For example, NCHW16c can describe a 5-D tensor of
+        [batch_size, channel, height, width, channel_block].
+        Here subordinate axis channel_block=16 is the factor size of
+        the primal axis C (channel).
+
+    Returns
+    -------
+    layout : Layout
+        The created layout
+    """
+    return _api_internal._Layout(layout_str)
+
+def bijective_layout(src_layout, dst_layout):
+    """Create a bijective layout mapping.
+
+    Parameters
+    ----------
+    src_layout : str or Layout
+        source layout.
+
+    dst_layout : str or Layout
+        destination layout.
+
+    Returns
+    -------
+    bijective_layout : BijectiveLayout
+        The created bijective layout
+    """
+    if isinstance(src_layout, str):
+        src_layout = layout(src_layout)
+    if isinstance(dst_layout, str):
+        dst_layout = layout(dst_layout)
+    return _api_internal._BijectiveLayout(src_layout, dst_layout)
+
 def _IterVar(dom, name, iter_type, thread_tag=''):
     """Internal function to create IterVar
 
index 6e7a2b357a967c0cf80563c2f89729e9312a67dc..ce8f16d6a309e0fe92635544ce9286c7124eb2df 100644 (file)
@@ -185,3 +185,142 @@ class HybridOp(Operation):
     def axis(self):
         """Represent axis of IterVar, also defined when it is a HybridOp"""
         return self.__getattr__("axis")
+
+
+@register_node
+class Layout(NodeBase):
+    """Layout is composed of upper cases, lower cases and numbers,
+    where upper case indicates a primal axis and
+    the corresponding lower case with factor size indicates the subordinate axis.
+    For example, NCHW16c can describe a 5-D tensor of
+    [batch_size, channel, height, width, channel_block].
+    Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
+
+    Do not construct directly, use :any:`layout` instead.
+    See the documentation of :any:`layout` for more details.
+
+    See Also
+    --------
+    layout : Declare a layout
+    """
+    def __str__(self):
+        return self.name
+
+    def __repr__(self):
+        return "Layout(" + self.name + ")"
+
+    def __len__(self):
+        return _api_internal._LayoutNdim(self)
+
+    def __contains__(self, axis):
+        return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name
+
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError("Layout index out of range")
+        return _api_internal._LayoutGetItem(self, index)
+
+    def index_of(self, axis):
+        """Get the index of an axis
+
+        Parameters
+        ----------
+        axis : str
+            The axis name, need to be [a-z,A-Z]
+
+        Returns
+        -------
+        index : int
+            The index of the axis, -1 if not found.
+        """
+        return _api_internal._LayoutIndexOf(self, axis)
+
+    def factor_of(self, axis):
+        """Get the factor size of the subordinate axis.
+
+        Parameters
+        ----------
+        axis : str
+            The axis name, need to be [a-z,A-Z]
+
+        Returns
+        -------
+        factor : int
+            the size of the subordinate-axis of axis (if axis is a primal-axis),
+            or the size of axis itself (if axis is a subordinate-axis).
+            Return -1 if axis is not in the layout.
+        """
+        return _api_internal._LayoutFactorOf(self, axis)
+
+
+@register_node
+class BijectiveLayout(NodeBase):
+    """Bijective mapping for two layouts (src-layout and dst-layout).
+    It provides shape and index conversion between each other.
+
+    Do not construct directly, use :any:`bijective_layout` instead.
+    See the documentation of :any:`bijective_layout` for more details.
+
+    See Also
+    --------
+    bijective_layout : Declare a bijective layout converter
+    """
+    def forward_index(self, index):
+        """Given the indices of the src-layout, infer the dst index.
+
+        Parameters
+        ----------
+        index: Array of Expr
+            The indices in src-layout.
+
+        Returns
+        -------
+        dst_index: Array of Expr
+            The inferred indices in dst-layout.
+        """
+        return _api_internal._BijectiveLayoutForwardIndex(self, index)
+
+    def backward_index(self, index):
+        """Given the indices of the dst-layout, infer the src index.
+
+        Parameters
+        ----------
+        index: Array of Expr
+            The indices in dst-layout.
+
+        Returns
+        -------
+        src_index: Array of Expr
+            The inferred indices in src-layout.
+        """
+        return _api_internal._BijectiveLayoutBackwardIndex(self, index)
+
+    def forward_shape(self, shape):
+        """Given the shape of the src-layout, infer the dst shape.
+
+        Parameters
+        ----------
+        shape: Array of Expr
+            The shape in src-layout.
+
+        Returns
+        -------
+        dst_shape: Array of Expr
+            The inferred shape in dst-layout.
+        """
+        return _api_internal._BijectiveLayoutForwardShape(self, shape)
+
+    def backward_shape(self, shape):
+        """Given the shape of the dst-layout, infer the src shape.
+
+        Parameters
+        ----------
+        shape: Array of Expr
+            The shape in dst-layout.
+
+        Returns
+        -------
+        src_shape: Array of Expr
+            The inferred shape in src-layout.
+        """
+        return _api_internal._BijectiveLayoutBackwardShape(self, shape)
index e30111e938bd117981f2a3ac15549c7b38561cc7..50f81644b0b516f0db09773c49fe05d85ac2bbc0 100644 (file)
@@ -11,6 +11,7 @@
 #include <tvm/schedule.h>
 #include <tvm/api_registry.h>
 #include <tvm/build_module.h>
+#include <tvm/data_layout.h>
 
 namespace tvm {
 
@@ -224,6 +225,63 @@ TVM_REGISTER_API("_BufferVStore")
         .vstore(args[1], args[2]);
   });
 
+TVM_REGISTER_API("_Layout")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = LayoutNode::make(args[0]);
+  });
+
+TVM_REGISTER_API("_LayoutIndexOf")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+  *ret = args[0].operator Layout()
+      .IndexOf(LayoutAxis::make(args[1]));
+});
+
+TVM_REGISTER_API("_LayoutFactorOf")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+  *ret = args[0].operator Layout()
+      .FactorOf(LayoutAxis::make(args[1]));
+});
+
+TVM_REGISTER_API("_LayoutNdim")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+  *ret = static_cast<int64_t>(args[0].operator Layout().ndim());
+});
+
+TVM_REGISTER_API("_LayoutGetItem")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+  const LayoutAxis& axis = args[0].operator Layout()[args[1]];
+  *ret = axis.name();
+});
+
+TVM_REGISTER_API("_BijectiveLayout")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = BijectiveLayoutNode::make(args[0], args[1]);
+  });
+
+TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = args[0].operator BijectiveLayout()
+        .ForwardIndex(args[1]);
+  });
+
+TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = args[0].operator BijectiveLayout()
+        .BackwardIndex(args[1]);
+  });
+
+TVM_REGISTER_API("_BijectiveLayoutForwardShape")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = args[0].operator BijectiveLayout()
+        .ForwardShape(args[1]);
+  });
+
+TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = args[0].operator BijectiveLayout()
+        .BackwardShape(args[1]);
+  });
+
 TVM_REGISTER_API("_Tensor")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
     *ret = TensorNode::make(args[0],
diff --git a/src/lang/data_layout.cc b/src/lang/data_layout.cc
new file mode 100644 (file)
index 0000000..900a580
--- /dev/null
@@ -0,0 +1,322 @@
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file src/lang/data_layout.cc
+ * \brief Data Layout expression.
+ */
+#include <tvm/data_layout.h>
+#include <tvm/ir_pass.h>
+
+namespace tvm {
+
+TVM_REGISTER_NODE_TYPE(LayoutNode);
+TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode);
+
+const LayoutAxis LayoutAxis::UPPER_CASE[] = {
+  LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'),
+  LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'),
+  LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'),
+  LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'),
+  LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'),
+  LayoutAxis('Z')
+};
+
+const LayoutAxis LayoutAxis::LOWER_CASE[] = {
+  LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'),
+  LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'),
+  LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'),
+  LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'),
+  LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'),
+  LayoutAxis('z')
+};
+
+const LayoutAxis& LayoutAxis::Get(const char name) {
+  CHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z'))
+    << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z.";
+  return (name >= 'A' && name <= 'Z') ?
+         LayoutAxis::UPPER_CASE[name-'A'] :
+         LayoutAxis::LOWER_CASE[name-'a'];
+}
+
+const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) {
+  const std::string axis = itvar->var.get()->name_hint;
+  CHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis;
+  return LayoutAxis::Get(axis[0]);
+}
+
+const LayoutAxis& LayoutAxis::make(const std::string& name) {
+  CHECK_EQ(name.length(), 1) << "Invalid axis " << name;
+  return LayoutAxis::Get(name[0]);
+}
+
+Layout::Layout(const Array<IterVar>& axes) {
+  node_ = make_node<LayoutNode>();
+  LayoutNode *node = operator->();
+  node->axes = axes;
+  std::ostringstream repr;
+  for (const IterVar& axis : axes) {
+    if (const auto* factor = axis->dom->extent.as<IntImm>()) {
+      CHECK_GT(factor->value, 0);
+      repr << factor->value;
+    }
+    CHECK_EQ(axis->var.get()->name_hint.size(), 1) << "Invalid layout axis "
+                                                   << axis->var.get()->name_hint;
+    char c = axis->var.get()->name_hint[0];
+    CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c;
+    repr << axis->var.get()->name_hint;
+  }
+  node->name = repr.str();
+}
+
+Layout::Layout(const std::string& name) { // NOLINT(*)
+  if (name.empty() || name == "__undef__") return;
+
+  node_ = make_node<LayoutNode>();
+  LayoutNode *node = operator->();
+  node->name = name;
+
+  // parse layout string
+  int32_t factor = 0;
+  for (char c : name) {
+    if (c >= 'A' && c <= 'Z') {
+      CHECK_EQ(factor, 0) << "Invalid layout " << name
+                          << ": invalid factor size " << factor
+                          << " before dimension " << c;
+      std::string shape_name("_shape");
+      shape_name.insert(0, 1, c);
+      IterVar axis = IterVarNode::make(Range(Expr(0), Var(shape_name)),
+                                       Var(std::string(1, c)), kDataPar);
+      node->axes.push_back(axis);
+    } else if (c >= 'a' && c <= 'z') {
+      CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
+                          << factor << " for dimension " << c;
+      IterVar axis = IterVarNode::make(Range(Expr(0), Expr(factor)),
+                                       Var(std::string(1, c)), kDataPar);
+      node->axes.push_back(axis);
+      factor = 0;
+    } else if (c >= '0' && c <= '9') {
+      CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
+      factor = factor * 10 + c - '0';
+    } else {
+      LOG(FATAL) << "Invalid layout " << name;
+    }
+  }
+
+  // validate layout
+  std::vector<bool> exist_axis(256, false);
+  for (const IterVar& v : node->axes) {
+    auto axis_str = v->var.get()->name_hint;
+    CHECK_EQ(axis_str.size(), 1);
+    char axis = axis_str[0];
+    CHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
+    CHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis;
+    exist_axis[axis] = true;
+  }
+  for (const IterVar& v : node->axes) {
+    char axis = v->var.get()->name_hint[0];
+    if (axis >= 'a' && axis <= 'z') {
+      CHECK(exist_axis[axis-'a'+'A']) << "Invalid layout " << name << ": missing axis "
+                                      << axis - 'a' + 'A';
+    }
+  }
+}
+
+Layout LayoutNode::make(const std::string& layout) {
+  return Layout(layout);
+}
+
+Layout Layout::SubLayout(size_t pos, size_t len) const {
+  if (!defined() || pos > ndim()) return Layout::Undef();
+  if (pos + len > ndim()) len = ndim() - pos;
+  Array<IterVar> new_layout;
+  const auto axes = operator->()->axes;
+  for (size_t i = pos; i < pos + len; ++i) {
+    new_layout.push_back(axes[i]);
+  }
+  return Layout(new_layout);
+}
+
+Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const {
+  if (!defined()) return Layout::Undef();
+  const std::string& name = operator->()->name;
+  const auto axes = operator->()->axes;
+  CHECK(target_pos <= this->ndim()) << "Invalid split position "
+                                    << target_pos << " for layout " << name;
+  CHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis;
+  CHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name;
+  CHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis
+                                                << " has already been split in " << name;
+  CHECK(factor > 0) << "Invalid split size " << factor;
+  Array<IterVar> new_layout;
+  for (size_t i = 0; i <= this->ndim(); ++i) {
+    if (i == target_pos) {
+      new_layout.push_back(IterVarNode::make(Range(Expr(0), Expr(factor)),
+                                             Var(axis.ToSubordinate().name()), kDataPar));
+    }
+    if (i == this->ndim()) break;
+    new_layout.push_back(axes[i]);
+  }
+  return Layout(new_layout);
+}
+
+int32_t Layout::FactorOf(const LayoutAxis& axis) const {
+  if (!defined()) return -1;
+  const LayoutAxis& sub = axis.ToSubordinate();
+  if (!this->defined()) return -1;
+  for (const IterVar& itvar : operator->()->axes) {
+    if (sub == LayoutAxis::Get(itvar)) {
+      const auto* factor = itvar->dom->extent.as<IntImm>();
+      CHECK(factor);
+      return factor->value;
+    }
+  }
+  return -1;
+}
+
+inline bool GetStoreRule(Array<Expr>* rule,
+                         const Layout& src_layout,
+                         const Layout& dst_layout) {
+  for (size_t i = 0; i < dst_layout.ndim(); ++i) {
+    const auto& store_axis = dst_layout[i];
+    const IterVar& store_axis_impl = dst_layout->axes[i];
+    Expr store(0);
+
+    for (size_t j = 0; j < src_layout.ndim(); ++j) {
+      const auto& orig_axis = src_layout[j];
+      const IterVar& orig_axis_impl = src_layout->axes[j];
+      if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
+        if (orig_axis.IsPrimal()) {
+          Expr orig_var = orig_axis_impl->var;
+          const int32_t factor = src_layout.FactorOf(orig_axis);
+          if (factor > 0) {
+            orig_var = orig_var * Expr(factor);
+          }
+          store = store + orig_var;
+        } else {
+          store = store + orig_axis_impl->var;
+        }
+      }
+    }
+    if (is_zero(store)) {
+      // Not convertible
+      return false;
+    }
+
+    if (store_axis.IsPrimal()) {
+      const int32_t factor = dst_layout.FactorOf(store_axis);
+      if (factor > 0) {
+        store = store / Expr(factor);
+      }
+    } else {
+      store = store % store_axis_impl->dom->extent;
+    }
+
+    rule->push_back(store);
+  }
+  return true;
+}
+
+inline Array<Expr> TransformIndex(const Array<Expr>& src_index,
+                                  const Array<IterVar>& src_axis,
+                                  const Array<Expr>& transform_rule) {
+  Array<Expr> result;
+  std::unordered_map<const Variable*, Expr> bind_map;
+  for (size_t i = 0; i < src_index.size(); ++i) {
+    bind_map[src_axis[i]->var.get()] = src_index[i];
+  }
+  for (Expr rule : transform_rule) {
+    result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
+  }
+  return result;
+}
+
+Array<Expr> BijectiveLayout::ForwardIndex(const Array<Expr>& src_index) const {
+  CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
+  const BijectiveLayoutNode* self = operator->();
+  CHECK_EQ(src_index.size(), self->src_layout->axes.size())
+    << "Input mismatch with layout " << self->src_layout;
+  return TransformIndex(src_index, self->src_layout->axes, self->forward_rule);
+}
+
+
+Array<Expr> BijectiveLayout::BackwardIndex(const Array<Expr>& dst_index) const {
+  CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
+  const BijectiveLayoutNode* self = operator->();
+  CHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
+    << "Output mismatch with layout " << self->dst_layout;
+  return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule);
+}
+
+inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
+                                  const Array<IterVar>& src_axis,
+                                  const Array<IterVar>& target_axis,
+                                  const Array<Expr>& transform_rule) {
+  CHECK_EQ(src_shape.size(), src_axis.size());
+  // bind variables for original axes
+  // for major-axis, bind the corresponding size
+  // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
+  // e.g., (C * 16 + c) / 32
+  std::unordered_map<const Variable*, Expr> bind_map;
+  for (size_t i = 0; i < src_shape.size(); ++i) {
+    Expr orig_shape = src_shape[i];
+    IterVar orig_axis = src_axis[i];
+    if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
+      if (orig_shape.defined()) {
+        const auto* orig_shape_const = orig_shape.as<IntImm>();
+        const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
+        CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
+          << "Input shape mismatch at index " << i << ". Expected "
+          << orig_axis->dom->extent << ", get " << orig_shape;
+      }
+      bind_map[orig_axis->var.get()] = Expr(0);
+    } else {
+      bind_map[orig_axis->var.get()] = orig_shape;
+    }
+  }
+  // infer the target shape,
+  // for major-axis, use the forward/backward_rule directly,
+  // for minor-axis, simply use the extent.
+  Array<Expr> result;
+  CHECK_EQ(transform_rule.size(), target_axis.size());
+  for (size_t i = 0; i < transform_rule.size(); ++i) {
+    Expr rule = transform_rule[i];
+    IterVar axis = target_axis[i];
+    if (!LayoutAxis::Get(axis).IsPrimal()) {
+      result.push_back(axis->dom->extent);
+    } else {
+      result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
+    }
+  }
+  return result;
+}
+
+Array<Expr> BijectiveLayout::ForwardShape(const Array<Expr>& shape) const {
+  CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
+  const BijectiveLayoutNode* self = operator->();
+  return TransformShape(shape, self->src_layout->axes,
+                        self->dst_layout->axes, self->forward_rule);
+}
+
+Array<Expr> BijectiveLayout::BackwardShape(const Array<Expr>& shape) const {
+  CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
+  const BijectiveLayoutNode* self = operator->();
+  return TransformShape(shape, self->dst_layout->axes,
+                        self->src_layout->axes, self->backward_rule);
+}
+
+BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
+                                          const Layout& dst_layout) {
+  auto n = make_node<BijectiveLayoutNode>();
+
+  n->src_layout = src_layout;
+  n->dst_layout = dst_layout;
+
+  if (!GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) {
+    // not convertible
+    return BijectiveLayout();
+  }
+  CHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
+
+  return BijectiveLayout(n);
+}
+
+}  // namespace tvm
index 4c9b0a5ca83eef325756fc612d1b149fa2f86650..4a5a7a86f1eaf62ef89123f6b81358b21a9fd250 100644 (file)
@@ -4,13 +4,13 @@
  * \brief Property def of nn operators.
  */
 
+#include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/debug.h>
 #include <topi/elemwise.h>
 #include <vector>
 #include "./type_relations.h"
 #include "./op_common.h"
-#include "./layout.h"
 
 namespace tvm {
 namespace relay {
index e6efcb8ce4597964013c685f1de2796f0266a160..d92e380fa9cc07adc23bd1850a94bfdfb1a3e1d5 100644 (file)
@@ -3,11 +3,11 @@
  * \file resize.cc
  * \brief Image operators
  */
+#include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/image.h>
 #include <topi/elemwise.h>
 #include <topi/image/resize.h>
-#include "../layout.h"
 #include "../op_common.h"
 
 namespace tvm {
@@ -28,17 +28,18 @@ bool ResizeRel(const Array<Type>& types,
   const ResizeAttrs* param = attrs.as<ResizeAttrs>();
   CHECK(param != nullptr);
   const Layout in_layout(param->layout);
-  CHECK(in_layout.Convertible(kNCHW))
+  auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
+  CHECK(layout_converter.defined())
     << "Resize only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
 
-  auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
-  oshape[2] = param->size[0];
-  oshape[3] = param->size[1];
+  auto oshape = layout_converter.ForwardShape(data->shape);
+  oshape.Set(2, param->size[0]);
+  oshape.Set(3, param->size[1]);
 
   // assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
+                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
                                         data->dtype));
   return true;
 }
diff --git a/src/relay/op/layout.cc b/src/relay/op/layout.cc
deleted file mode 100644 (file)
index 98fea55..0000000
+++ /dev/null
@@ -1,80 +0,0 @@
-/*!
- *  Copyright (c) 2018 by Contributors
- * \file src/relay/op/layout.cc
- * \brief Layout expression.
- */
-
-#include "layout.h"
-
-namespace tvm {
-namespace relay {
-
-TVM_REGISTER_NODE_TYPE(LayoutNode);
-
-std::vector<IndexExpr> ConvertLayout(
-    std::vector<IndexExpr> src,
-    const Layout& src_layout,
-    const Layout& dst_layout) {
-  CHECK_EQ(src_layout.ndim(), src.size());
-  if (src_layout == dst_layout) {
-    return src;
-  } else if (!src_layout.defined()) {
-    LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
-  } else if (!dst_layout.defined()) {
-    LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
-  }
-
-  CHECK(src_layout.Convertible(dst_layout))
-    << "cannot convert from "
-    << src_layout << " to " << dst_layout;
-
-  std::vector<IndexExpr> dst(dst_layout.ndim());
-  for (size_t i = 0; i < src_layout.ndim(); ++i) {
-    Layout::LayoutDim src_dim = src_layout[i];
-    if (Layout::IsSuperdim(src_dim)) {
-      int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_dim));
-      int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_dim));
-      int src_minor_pos = src_layout.Indexof(Layout::ToSubdim(src_dim));
-      int src_factor = src_layout.Subsizeof(src_dim);
-      int dst_factor = dst_layout.Subsizeof(src_dim);
-      IndexExpr src_dim_size = src[i];
-
-      if (src_minor_pos >= 0) {
-        CHECK(is_const_int(src[src_minor_pos], src_factor))
-          << "src shape " << Array<IndexExpr>(src)
-          << " does not agree with layout "
-          << src_layout;
-        src_dim_size *= src_factor;
-      }
-      dst[dst_major_pos] = src_dim_size;
-      if (dst_minor_pos >= 0) {
-        CHECK_GT(dst_factor, 0);
-        if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
-          CHECK_LE(dst_factor, const_src_dim_size[0])
-            << "Converting " << Array<IndexExpr>(src)
-            << " from " << src_layout
-            << " to " << dst_layout
-            << ": cannot split dimension size of "
-            << src_dim_size << " by " << dst_factor;
-        }
-        dst[dst_major_pos] /= dst_factor;
-        dst[dst_minor_pos] = dst_factor;
-      }
-    }
-  }
-  return dst;
-}
-
-std::vector<IndexExpr> ConvertLayout(
-    const Array<IndexExpr>& src,
-    const Layout& src_layout,
-    const Layout& dst_layout) {
-  std::vector<IndexExpr> ret(src.size());
-  for (size_t i = 0; i < src.size(); ++i) {
-    ret[i] = src[i];
-  }
-  return ConvertLayout(ret, src_layout, dst_layout);
-}
-
-}  // namespace relay
-}  // namespace tvm
diff --git a/src/relay/op/layout.h b/src/relay/op/layout.h
deleted file mode 100644 (file)
index 09cf3a9..0000000
+++ /dev/null
@@ -1,432 +0,0 @@
-/*!
- *  Copyright (c) 2018 by Contributors
- * \file relay/op/layout.h
- * \brief Layout expression.
- *
- *  This file is adapted from its nnvm counterpart and will keep involving
- *  to the new layout system
- *
- *  The layout is composed of upper cases, lower cases and numbers,
- *  where upper case indicates a (super-)dimension and
- *  the corresponding lower case with factor size indicates the split (sub-)dimension.
- *  For example, NCHW16c can describe a 5-D tensor of
- *  [batch_size, channel, height, width, channel_block].
- *  Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
- */
-#ifndef TVM_RELAY_OP_LAYOUT_H_
-#define TVM_RELAY_OP_LAYOUT_H_
-
-#include <tvm/base.h>
-#include <tvm/expr.h>
-#include <tvm/relay/base.h>
-
-#include <string>
-#include <sstream>
-#include <vector>
-#include <utility>
-#include <algorithm>
-
-namespace tvm {
-namespace relay {
-
-class LayoutNode : public Node {
- public:
-  std::string name;
-  Array<Integer> superdim_pos;
-  Array<Integer> subdim_pos;
-  Array<Integer> subdim_size;
-  Array<Integer> layout_simplified;
-
-  void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("name", &name);
-    v->Visit("superdim_pos", &superdim_pos);
-    v->Visit("subdim_pos", &subdim_pos);
-    v->Visit("subdim_size", &subdim_size);
-    v->Visit("layout_simplified", &layout_simplified);
-  }
-
-  static constexpr const char* _type_key = "Layout";
-  TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
-};
-
-class Layout : public NodeRef {
- public:
-  using LayoutDim = char;
-  static constexpr uint32_t kUniqueDim = 26;
-
-  explicit Layout(NodePtr<Node> n) : NodeRef(n) {}
-
-  /*! \brief default constructor */
-  Layout() : Layout("__undef__") {} // NOLINT(*)
-
-  /*! \brief construct from a string */
-  Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
-
-  /*!
-   * \brief construct from a string.
-   * \param layout input in layout convention:
-   *        upper case indicates a dimension and
-   *        the corresponding lower case with factor size
-   *        indicates the split dimension.
-   *        return undefined layout if "__undef__" is passed.
-   */
-  Layout(const std::string& name) { // NOLINT(*)
-    node_ = make_node<LayoutNode>();
-
-    std::vector<int32_t> superdim_pos(kUniqueDim, -1);
-    std::vector<int32_t> subdim_pos(kUniqueDim, -1);
-    std::vector<int32_t> subdim_size(kUniqueDim, -1);
-    std::vector<char> layout_simplified;
-
-    if (name != "__undef__") {  // parse layout string
-      int32_t factor = 0;
-      uint32_t curr = 0;
-      for (size_t i = 0; i < name.size(); ++i) {
-        const LayoutDim c = name.at(i);
-        if (IsSuperdim(c)) {
-          int pos = c - 'A';
-          CHECK_EQ(factor, 0) << "Invalid layout " << name
-                              << ": invalid factor size " << factor
-                              << " before dimension " << c;
-          CHECK_EQ(superdim_pos[pos], -1) << "Invalid layout " << name
-                                          << ": duplicate dimension " << c;
-          superdim_pos[pos] = curr++;
-          layout_simplified.push_back(c);
-        } else if (IsSubdim(c)) {
-          int pos = c - 'a';
-          CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
-                              << factor << " for dimension " << c;
-          CHECK_EQ(subdim_pos[pos], -1) << "Invalid layout " << name
-                                        << ": duplicate dimension " << c;
-          CHECK_EQ(subdim_size[pos], -1) << "Invalid layout " << name
-                                         << ": duplicate dimension " << c;
-          subdim_pos[pos] = curr++;
-          subdim_size[pos] = factor;
-          layout_simplified.push_back(c);
-          factor = 0;
-        } else if (c >= '0' && c <= '9') {
-          CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
-          factor = factor * 10 + c - '0';
-        } else {
-          LOG(FATAL) << "Invalid layout " << name;
-        }
-      }
-      for (LayoutDim dim : layout_simplified) {
-        CHECK(IsSuperdim(dim) || superdim_pos[dim-'a'] >= 0)
-          << "Invalid layout " << name << ": missing axis "
-          << static_cast<char>(dim - 'a' + 'A');
-      }
-    }
-
-    LayoutNode *node = operator->();
-    node->name = name;
-
-    for (uint32_t i = 0; i < kUniqueDim; ++i) {
-      node->superdim_pos.push_back(superdim_pos[i]);
-      node->subdim_pos.push_back(subdim_pos[i]);
-      node->subdim_size.push_back(subdim_size[i]);
-    }
-    for (LayoutDim dim : layout_simplified) {
-      node->layout_simplified.push_back(dim);
-    }
-  }
-
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  const LayoutNode* operator->() const {
-    return static_cast<const LayoutNode*>(node_.get());
-  }
-
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  LayoutNode* operator->() {
-    return static_cast<LayoutNode*>(node_.get());
-  }
-
-  /*!
-   * \brief Check whether a given dimension is a super-dimension.
-   * \param dim input dimension
-   * \return Whether a given dimension is a super-dimension.
-   */
-  static bool IsSuperdim(LayoutDim dim) {
-    return dim >= 'A' && dim <= 'Z';
-  }
-
-  /*!
-   * \brief Check whether a given dimension is a sub-dimension.
-   * \param dim input dimension
-   * \return Whether a given dimension is a sub-dimension.
-   */
-  static bool IsSubdim(LayoutDim dim) {
-    return dim >= 'a' && dim <= 'z';
-  }
-
-  /*!
-   * \brief Convert a given dimension to super-dimension.
-   * \param dim input dimension
-   * \return The converted description.
-   */
-  static LayoutDim ToSuperdim(LayoutDim dim) {
-    if (IsSubdim(dim)) {
-      return dim - 'a' + 'A';
-    }
-    return dim;
-  }
-
-  /*!
-   * \brief Convert a given dimension to sub-dimension.
-   * \param dim input dimension
-   * \return The converted description.
-   */
-  static LayoutDim ToSubdim(LayoutDim dim) {
-    if (IsSuperdim(dim)) {
-      return dim - 'A' + 'a';
-    }
-    return dim;
-  }
-
-  /*!
- * \brief Return an undefined layout.
- * \return a (global) undefined layout.
- */
-  static const Layout& Undef() {
-    static Layout undef;
-    return undef;
-  }
-
-  /*!
-   * \brief Two layouts are convertible only if
-   *        they have same set of super-dimensions.
-   *        e.g., NCHW, NCHW16c, NHWC are convertible between each other,
-   *        but NCHW, CHW, OIHW are not.
-   * \param dst the target layout
-   * \return Whether can be converted to dst layout.
-   */
-  bool Convertible(const Layout &dst) const {
-    const LayoutNode *n = operator->();
-    if (!this->defined() || !dst.defined()) return false;
-    for (size_t i = 0; i < kUniqueDim; ++i) {
-      if ((n->superdim_pos[i]->value >= 0 && dst->superdim_pos[i]->value < 0) ||
-          (n->superdim_pos[i]->value < 0 && dst->superdim_pos[i]->value >= 0)) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  /*!
-   * \brief Returns a sublayout which is the portion of the object
-   *        that starts at dimension \p pos and spans \p len dimensions
-   *        (or until the end of the layout, whichever comes first).
-   * \param pos The start position.
-   * \param len The length of the sub-layout.
-   * \return A newly constructed Layout object.
-   */
-  Layout Sublayout(size_t pos, size_t len) const {
-    const Array<Integer>& layout_simplified = operator->()->layout_simplified;
-    if (pos > ndim()) return Layout::Undef();
-    if (pos + len > ndim()) len = ndim() - pos;
-    std::ostringstream new_layout;
-    for (size_t i = pos; i < pos + len; ++i) {
-      if (IsSubdim(layout_simplified[i]->value)) {
-        auto block_size = this->Subsizeof(layout_simplified[i]->value);
-        CHECK_GT(block_size, 0);
-        new_layout << block_size;
-      }
-      new_layout << static_cast<char>(layout_simplified[i]->value);
-    }
-    return Layout(new_layout.str());
-  }
-
-  /*! \return A newly constructed reversed Layout object. */
-  Layout Reverse() const {
-    const Array<Integer>& layout_simplified = operator->()->layout_simplified;
-    if (!this->defined()) return Layout::Undef();
-    std::ostringstream new_layout;
-    for (int64_t i = this->ndim() - 1; i >= 0; --i) {
-      if (IsSubdim(layout_simplified[i]->value)) {
-        auto block_size = this->Subsizeof(layout_simplified[i]->value);
-        CHECK_GT(block_size, 0);
-        new_layout << block_size;
-      }
-      new_layout << layout_simplified[i]->value;
-    }
-    return Layout(new_layout.str());
-  }
-
-  /*!
-   * \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
-   * \param dim The source dimension to be split. It must be a super-dimension.
-   * \param target_pos The target position of the newly split sub-dimension.
-   * \param size size of the sub-dimension.
-   * \return A newly constructed Layout object.
-   */
-  Layout Split(LayoutDim dim, size_t target_pos, uint32_t size) const {
-    const std::string &name = operator->()->name;
-    CHECK(target_pos <= this->ndim()) << "Invalid split position "
-                                      << target_pos << " for layout " << name;
-    CHECK(IsSuperdim(dim)) << "Cannot split a sub-dimension " << dim;
-    CHECK(this->Contains(dim)) << "Axis " << dim << " does not exist in " << name;
-    CHECK(!this->Contains(ToSubdim(dim))) << "Dimension " << dim
-                                           << " has already been split in "
-                                           << name;
-    CHECK(size > 0) << "Invalid split size " << size;
-    std::ostringstream new_layout;
-    for (size_t i = 0; i <= this->ndim(); ++i) {
-      if (i == target_pos) {
-        new_layout << size << Layout::ToSubdim(dim);
-      }
-      if (i == this->ndim()) break;
-      new_layout << this->at(i);
-    }
-    Layout x(new_layout.str());
-    return x;
-  }
-
-
-  /*! \return number of dimensions */
-  size_t ndim() const {
-    return operator->()->layout_simplified.size();
-  }
-
-  /*! \return number of super dimensions */
-  size_t ndim_super() const {
-    size_t ct = 0;
-    for (auto x : operator->()->layout_simplified) {
-      if (IsSuperdim(x))
-        ct++;
-    }
-    return ct;
-  }
-
-  /*!
-   * \brief The description of the \p i-th dimension.
-   *        If it is a sub-dimension, the size will be returned as well,
-   *        e.g., 16c. Otherwise a single character is returned, e.g., C.
-   * \param i The position
-   * \return the description of the dimension.
-   */
-  std::string at(size_t i) const {
-    const Array<Integer>& layout_simplified = operator->()->layout_simplified;
-    CHECK_LT(i, this->ndim()) << "position " << i
-                              << " exceeds ndim=" << this->ndim();
-    std::ostringstream repr;
-    if (IsSubdim(layout_simplified[i]->value)) {
-      auto factor = Subsizeof(layout_simplified[i]->value);
-      CHECK_GT(factor, 0);
-      repr << factor;
-    }
-    repr << static_cast<char>(layout_simplified[i]->value);
-    return repr.str();
-  }
-
-  /*!
-   * \brief return the index of the input dimension.
-   *        If it is not found in the layout or the layout is undefined,
-   *        return -1.
-   * \param dim the input dimension.
-   * \return the index or -1 if not found.
-   */
-  int32_t Indexof(LayoutDim dim) const {
-    if (!this->defined()) return -1;
-    else if (IsSuperdim(dim)) return operator->()->superdim_pos[dim - 'A']->value;
-    else if (IsSubdim(dim)) return operator->()->subdim_pos[dim - 'a']->value;
-    return -1;
-  }
-
-  /*!
-   * \param dim the input super-dimension or sub-dimension.
-   * \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
-   *         or the size of \p dim itself (if \p dim is a sub-dimension).
-   *         Return -1 if \p dim is not in the layout or the layout is undefined.
-   */
-  int64_t Subsizeof(LayoutDim dim) const {
-    CHECK(IsSuperdim(dim) || IsSubdim(dim)) << "Invalid dim " << dim;
-    if (!this->defined() || !this->Contains(ToSubdim(dim))) {
-      return -1;
-    }
-    int idx = ToSubdim(dim) - 'a';
-    return operator->()->subdim_size[idx]->value;
-  }
-
-  /*!
-   * \brief Whether the layout contains a dimension.
-   * \param dim dimension to be checked.
-   * \return Whether the layout contains the dimension.
-   */
-  bool Contains(LayoutDim dim) const {
-    if (IsSuperdim(dim)) {
-      return operator->()->superdim_pos[dim-'A']->value >= 0;
-    } else if (IsSubdim(dim)) {
-      return operator->()->subdim_pos[dim-'a']->value >= 0;
-    }
-    return false;
-  }
-
-  LayoutDim operator[](size_t i) const {
-    return operator->()->layout_simplified[i];
-  }
-
-  /*! \return whether the layout is defined */
-  bool defined() const {
-    return operator->()->name != "__undef__";
-  }
-  /*! \return the string description of the layout */
-  const std::string& name() const {
-    return operator->()->name;
-  }
-
-  /*!
-   * \brief Whether the two layouts are equal.
-   * \param rhs Another layout.
-   * \return whether the two layouts are equal.
-   */
-  bool Equals(const Layout &rhs) const {
-    return operator->()->name == rhs->name;
-  }
-
-  /*!
- * \brief allow output string of layout to ostream
- * \param os the output stream
- * \param l the layout
- * \return the ostream
- */
-  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
-    os << l.name();
-    return os;
-  }
-
-  using ContainerType = LayoutNode;
-};
-
-/*!
- * \brief Convert shape in src_layout to shape in dst_layout
- * \param src original shape
- * \param src_layout layout of original shape
- * \param dst_layout target layout
- * \return shape in target layout
- */
-std::vector<IndexExpr> ConvertLayout(
-    std::vector<IndexExpr> src,
-    const Layout& src_layout,
-    const Layout& dst_layout);
-
-/*!
- * \brief Convert shape in src_layout to shape in dst_layout
- * \param src original shape
- * \param src_layout layout of original shape
- * \param dst_layout target layout
- * \return shape in target layout
- */
-std::vector<IndexExpr> ConvertLayout(
-    const Array<IndexExpr>& src,
-    const Layout& src_layout,
-    const Layout& dst_layout);
-}  // namespace relay
-}  // namespace tvm
-
-#endif  // TVM_RELAY_OP_LAYOUT_H_
index e05b24d967bce9ab5a2b5ae907ad7cc89bcf3d78..963257a14961126308140684edcb58d5d8ba13d9 100644 (file)
@@ -3,12 +3,12 @@
  * \file convolution.cc
  * \brief Convolution operators
  */
+#include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <vector>
 
 #include "../../pass/alter_op_layout.h"
-#include "../layout.h"
 
 namespace tvm {
 namespace relay {
@@ -31,32 +31,36 @@ bool Conv2DRel(const Array<Type>& types,
   CHECK(param != nullptr);
   const Layout in_layout(param->data_layout);
   const Layout kernel_layout(param->kernel_layout);
-  CHECK(in_layout.Convertible(kNCHW))
+
+  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
+  CHECK(trans_in_layout.defined())
     << "Conv only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
-  CHECK(kernel_layout.Convertible(kOIHW))
+
+  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
+  CHECK(trans_kernel_layout.defined())
     << "Conv only support kernel layouts that are convertible from OIHW."
     << " But got "<< kernel_layout;
 
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  CHECK(out_layout.Convertible(kNCHW))
+  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
+  CHECK(trans_out_layout.defined())
       << "Conv only support output layouts that are convertible from NCHW."
       << " But got " << out_layout;
 
-  std::vector<IndexExpr> dshape_nchw = ConvertLayout(
-      data->shape, in_layout, kNCHW);
+  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
 
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
   // infer weight if the kernel_size and channels are defined
   if (param->kernel_size.defined() && param->channels.defined()) {
     CHECK_EQ(param->kernel_size.size(), 2);
     CHECK_EQ(param->dilation.size(), 2);
-    std::vector<IndexExpr> wshape(
+    Array<IndexExpr> wshape(
        {param->channels,
          dshape_nchw[1] / param->groups,
          param->kernel_size[0],
          param->kernel_size[1]});
-    wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
+    wshape = trans_kernel_layout.BackwardShape(wshape);
     channels = param->channels;
     dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
@@ -65,7 +69,7 @@ bool Conv2DRel(const Array<Type>& types,
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
-    auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
     if (param->kernel_size.defined()) {
       CHECK_EQ(param->kernel_size.size(), 2);
       // check the size
@@ -73,13 +77,13 @@ bool Conv2DRel(const Array<Type>& types,
             reporter->AssertEQ(param->kernel_size[1], wshape[3]))
           << "Conv2D: shape of weight is inconsistent with kernel_size, "
           << " kernel_size=" << param->kernel_size
-          << " wshape=" << Array<IndexExpr>(wshape);
+          << " wshape=" << wshape;
     }
     if (param->channels.defined()) {
       CHECK(reporter->AssertEQ(param->channels, wshape[0]))
           << "Conv2D: shape of weight is inconsistent with channels, "
           << " channels=" << param->channels
-          << " wshape=" << Array<IndexExpr>(wshape);
+          << " wshape=" << wshape;
     }
     CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
     channels = wshape[0];
@@ -87,15 +91,15 @@ bool Conv2DRel(const Array<Type>& types,
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
   }
   // dilation
-  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
 
-  oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
-  oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
+  oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
+  oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
   }
-  oshape = ConvertLayout(oshape, kNCHW, out_layout);
+  oshape = trans_out_layout.BackwardShape(oshape);
   // assign output type
   reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
   return true;
@@ -193,33 +197,38 @@ bool Conv2DTransposeRel(const Array<Type>& types,
   CHECK(param != nullptr);
   const Layout in_layout(param->data_layout);
   const Layout kernel_layout(param->kernel_layout);
-  CHECK(in_layout.Convertible(kNCHW))
+
+  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
+  CHECK(trans_in_layout.defined())
     << "Conv only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
-  CHECK(kernel_layout.Convertible(kOIHW))
+
+  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
+  CHECK(trans_kernel_layout.defined())
     << "Conv only support kernel layouts that are convertible from OIHW."
     << " But got "<< kernel_layout;
 
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  CHECK(out_layout.Convertible(kNCHW))
+  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
+  CHECK(trans_out_layout.defined())
     << "Conv only support output layouts that are convertible from NCHW."
     << " But got " << out_layout;
 
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
 
-  auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
+  auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
 
   // infer weight if the kernel_size and channels are defined
   if (param->kernel_size.defined() && param->channels.defined()) {
     CHECK_EQ(param->kernel_size.size(), 2);
     CHECK_EQ(param->dilation.size(), 2);
 
-    std::vector<IndexExpr> wshape({dshape_nchw[1],
-                                   param->channels / param->groups,
-                                   param->kernel_size[0],
-                                   param->kernel_size[1]});
+    Array<IndexExpr> wshape({dshape_nchw[1],
+                             param->channels / param->groups,
+                             param->kernel_size[0],
+                             param->kernel_size[1]});
 
-    wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
+    wshape = trans_kernel_layout.BackwardShape(wshape);
     dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
     channels = param->channels;
@@ -229,7 +238,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
-    auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
     if (param->kernel_size.defined()) {
       CHECK_EQ(param->kernel_size.size(), 2);
       // check the size
@@ -251,17 +260,17 @@ bool Conv2DTransposeRel(const Array<Type>& types,
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
   }
   // dilation
-  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
-  oshape[2] = (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
-               2 * param->padding[0] + param->output_padding[0]);
-  oshape[3] = (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
-               2 * param->padding[1] + param->output_padding[1]);
+  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
+                 2 * param->padding[0] + param->output_padding[0]));
+  oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
+                 2 * param->padding[1] + param->output_padding[1]));
 
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
   }
-  oshape = ConvertLayout(oshape, kNCHW, out_layout);
+  oshape = trans_out_layout.BackwardShape(oshape);
   reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
   return true;
 }
@@ -349,20 +358,24 @@ bool Conv2DWinogradRel(const Array<Type>& types,
   CHECK(param != nullptr);
   const Layout in_layout(param->data_layout);
   const Layout kernel_layout(param->kernel_layout);
-  CHECK(in_layout.Convertible(kNCHW))
+
+  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
+  CHECK(trans_in_layout.defined())
     << "Conv only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
-  CHECK(kernel_layout.Convertible(kOIHW))
+
+  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
+  CHECK(trans_kernel_layout.defined())
     << "Conv only support kernel layouts that are convertible from OIHW."
     << " But got "<< kernel_layout;
 
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  CHECK(out_layout.Convertible(kNCHW))
+  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
+  CHECK(trans_out_layout.defined())
       << "Conv only support output layouts that are convertible from NCHW."
       << " But got " << out_layout;
 
-  std::vector<IndexExpr> dshape_nchw = ConvertLayout(
-      data->shape, in_layout, kNCHW);
+  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
 
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
 
@@ -384,15 +397,15 @@ bool Conv2DWinogradRel(const Array<Type>& types,
   // can handle this correctly in alter_op_layout.
 
   // dilation
-  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
 
-  oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
-  oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
+  oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
+  oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
   }
-  oshape = ConvertLayout(oshape, kNCHW, out_layout);
+  oshape = trans_out_layout.BackwardShape(oshape);
   // assign output type
   reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
   return true;
index 7ed43d0df01980c231ed832df375cfb455749581..9ab841cf42868eedeec973bcd4c491e45687f18b 100644 (file)
@@ -4,6 +4,7 @@
  * \brief Property def of nn operators.
  */
 
+#include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/image.h>
@@ -14,7 +15,6 @@
 #include "../type_relations.h"
 #include "../../pass/alter_op_layout.h"
 #include "../op_common.h"
-#include "../layout.h"
 
 namespace tvm {
 namespace relay {
index dc99f05f4d2d9ef7d793748bbcf2c30b33b9a81e..c24203cebdb334b0c2c525fe4d666112547e1817 100644 (file)
@@ -3,12 +3,12 @@
  * \file pad.cc
  * \brief Implementation of operator pad
  */
+#include <tvm/data_layout.h>
 #include <tvm/ir_operator.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <topi/nn.h>
 #include <vector>
-#include "../layout.h"
 #include "../op_common.h"
 
 namespace tvm {
index 8fd33e1f3cdca6e8f99e5374a55ec49bd157dec5..23704693732b843f124f60b2a08d6acccdbc8c0d 100644 (file)
@@ -3,12 +3,12 @@
  * \file pooling.cc
  * \brief Pooling operators
  */
+#include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/attrs/nn.h>
 #include <topi/nn/pooling.h>
 #include <vector>
-#include "../layout.h"
 #include "../../pass/alter_op_layout.h"
 
 namespace tvm {
@@ -32,14 +32,15 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
 
     Layout raw_layout(params->layout);
     Layout input = new_in_layouts[0];
-    if (input.Indexof('W') == raw_layout.Indexof('W') &&
-        input.Indexof('H') == raw_layout.Indexof('H') &&
-        !input.Contains('w') && !input.Contains('h')) {
+    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
+    input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
+        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
       params->layout = input.name();  // modify self to follow the input layout
     }
   }
 
-  return Array<Array<Layout> >{{params->layout}, {params->layout}};
+  Layout inferred_layout(params->layout);
+  return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
 }
 
 template <typename AttrType>
@@ -59,13 +60,13 @@ bool Pool2DRel(const Array<Type>& types,
   CHECK(param != nullptr);
 
   Layout layout(param->layout);
-  CHECK(layout.Contains('H') && layout.Contains('W') &&
-        !layout.Contains('h') && !layout.Contains('w'))
+  CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
+        !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
     << "Invalid layout " << layout
     << ". Pool2D layout must have H and W, which cannot be split";
 
-  const auto hidx = layout.Indexof('H');
-  const auto widx = layout.Indexof('W');
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
 
   IndexExpr pad_h, pad_w;
   if (param->padding.size() == 1) {
@@ -125,6 +126,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
                             const Array<Tensor>& inputs,
                             const Type& out_type,
                             const Target& target) {
+  static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<AttrType>();
   CHECK(param != nullptr);
   auto pool_size = param->pool_size;
@@ -132,10 +134,13 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
   auto padding = param->padding;
   auto ceil_mode = param->ceil_mode;
   Layout layout(param->layout);
-  CHECK(layout.Convertible(Layout("NCHW")))
+
+  CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
       << "max_pool2d currently only supports layouts that are convertible from NCHW";
-  CHECK_EQ(layout.Indexof('h'), -1) << "max_pool2d does not support input split on height";
-  CHECK_EQ(layout.Indexof('w'), -1) << "max_pool2d does not support input split on width";
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
+      << "max_pool2d does not support input split on height";
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
+      << "max_pool2d does not support input split on width";
 
   CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
       << "Pool2D only support 4-D input (e.g., NCHW)"
@@ -271,13 +276,13 @@ bool GlobalPool2DRel(const Array<Type>& types,
   CHECK(param != nullptr);
 
   Layout layout(param->layout);
-  CHECK(layout.Contains('H') && layout.Contains('W') &&
-        !layout.Contains('h') && !layout.Contains('w'))
+  CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) &&
+        !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
     << "Invalid layout " << layout
     << ". Pool2D layout must have H and W, which cannot be split";
 
-  const auto hidx = layout.Indexof('H');
-  const auto widx = layout.Indexof('W');
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
   Array<IndexExpr> oshape(dshape);
   oshape.Set(hidx, 1);
   oshape.Set(widx, 1);
@@ -293,14 +298,15 @@ Array<Tensor> GlobalPool2DCompute(const Attrs& attrs,
                                   const Array<Tensor>& inputs,
                                   const Type& out_type,
                                   const Target& target) {
+  static const Layout kNCHW("NCHW");
   const auto* param = attrs.as<GlobalPool2DAttrs>();
   CHECK(param != nullptr);
   Layout layout(param->layout);
-  CHECK(layout.Convertible(Layout("NCHW")))
+  CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
     << "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
-  CHECK_EQ(layout.Indexof('h'), -1)
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
     << "global_avg_pool2d does not support input split on height";
-  CHECK_EQ(layout.Indexof('w'), -1)
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
     << "global_avg_pool2d does not support input split on width";
 
   CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
index d386437ae15b8f144e393b3f3d2317a6ff57b04d..48a7a04ebb8aff7df39396f3723921a199bb99eb 100644 (file)
@@ -3,6 +3,7 @@
  * \file upsampling.cc
  * \brief upsampling operator
  */
+#include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/op_attr_types.h>
@@ -11,7 +12,6 @@
 #include <topi/nn/upsampling.h>
 #include <vector>
 #include "../op_common.h"
-#include "../layout.h"
 
 namespace tvm {
 namespace relay {
@@ -31,18 +31,20 @@ bool UpSamplingRel(const Array<Type>& types,
   const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
   CHECK(param != nullptr);
   const Layout in_layout(param->layout);
-  CHECK(in_layout.Convertible(kNCHW))
+
+  auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
+  CHECK(layout_converter.defined())
     << "UpSampling only support input layouts that are convertible from NCHW."
     << " But got " << in_layout;
 
-  auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
+  auto oshape = layout_converter.ForwardShape(data->shape);
 
-  oshape[2] = oshape[2] * param->scale;
-  oshape[3] = oshape[3] * param->scale;
+  oshape.Set(2, oshape[2] * param->scale);
+  oshape.Set(3, oshape[3] * param->scale);
 
   // assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
+                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
                                         data->dtype));
   return true;
 }
index 48c97b91dfda1c0cebfbe36ff5dc25642da85c75..df23b22512e393d6b5e61a5771be5791da284ada 100644 (file)
@@ -7,6 +7,7 @@
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/ir_operator.h>
 #include <tvm/ir.h>
+#include <tvm/data_layout.h>
 #include <topi/transform.h>
 #include <topi/elemwise.h>
 #include <topi/broadcast.h>
@@ -16,7 +17,6 @@
 #include "../op_common.h"
 #include "../../../arithmetic/compute_expr.h"
 #include "../../pass/alter_op_layout.h"
-#include "../layout.h"
 
 namespace tvm {
 namespace relay {
@@ -218,7 +218,7 @@ Array<Array<Layout>> ConcatenateLayout(
 
   Layout ret;
   if (new_in_layouts.defined()) {  // this function is called after some operators are alternated.
-    Layout::LayoutDim concate_dim = old_in_layouts[0][axis];
+    const auto& concate_dim = old_in_layouts[0][axis];
     for (size_t i = 0; i < new_in_layouts.size(); ++i) {
       if (new_in_layouts[i].ndim() > axis &&
           new_in_layouts[i][axis] == concate_dim) {
@@ -234,7 +234,7 @@ Array<Array<Layout>> ConcatenateLayout(
       }
     }
 
-    if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) {
+    if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
       return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
     }
   }
@@ -1682,46 +1682,10 @@ Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
                                      const Array<Tensor>& inputs,
                                      const Type& out_type,
                                      const Target& target) {
-  const LayoutTransformAttrs *param = attrs.as<LayoutTransformAttrs>();
+  const auto* param = attrs.as<LayoutTransformAttrs>();
   CHECK(param != nullptr);
-
-  Layout src_layout(param->src_layout);
-  Layout dst_layout(param->dst_layout);
-
-  if (src_layout.Equals(dst_layout)) {
-    return Array<Tensor>{ inputs[0] };
-  }
-
-  CHECK(src_layout.defined() && dst_layout.defined())
-    << "cannot convert from/to undefined layout";
-  CHECK(src_layout.Convertible(dst_layout))
-    << "cannot convert from " << param->src_layout << " to " << param->dst_layout;
-
-  const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout);
-  return Array<Tensor> {
-      topi::layout_transform(inputs[0], out_shape, [&](const Array<tvm::Var>& dst_indices) {
-        std::vector<tvm::Expr> dst_to_src_indices;
-        for (size_t i = 0; i < src_layout.ndim(); ++i) {
-          Layout::LayoutDim src_axis = src_layout[i];
-          int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis));
-          int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis));
-          int32_t src_factor = static_cast<int32_t>(src_layout.Subsizeof(src_axis));
-          int32_t dst_factor = static_cast<int32_t>(dst_layout.Subsizeof(src_axis));
-
-          tvm::Expr src_index(dst_indices[dst_major_pos]);
-          if (dst_minor_pos >= 0) {
-            CHECK_GT(dst_factor, 0);
-            src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
-          }
-          if (Layout::IsSuperdim(src_axis) && src_factor > 0) {
-            src_index = src_index / src_factor;
-          } else if (Layout::IsSubdim(src_axis) && src_factor > 0) {
-            src_index = src_index % src_factor;
-          }
-          dst_to_src_indices.push_back(src_index);
-        }
-        return Array<tvm::Expr>(dst_to_src_indices);
-      })
+  return Array<Tensor>{
+    topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)
   };
 }
 
@@ -1738,10 +1702,12 @@ bool LayoutTransformRel(const Array<Type>& types,
 
   CHECK(src_layout.defined() && dst_layout.defined())
     << "cannot convert from/to undefined layout";
-  CHECK(src_layout.Convertible(dst_layout))
+
+  auto layout_converter = BijectiveLayoutNode::make(src_layout, dst_layout);
+  CHECK(layout_converter.defined())
     << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
 
-  const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout);
+  const auto& out_shape = layout_converter.ForwardShape(data->shape);
   reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
   return true;
 }
index 6d988eb2bcdf388e7f77de4755a4b0a2cb30f440..fe624a6489c1f627967483918ba4d70061fac211 100644 (file)
@@ -26,7 +26,7 @@ Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
   if (src_layout.Equals(dst_layout)) { return raw; }
   CHECK(src_layout.defined() && dst_layout.defined())
     << "Cannot insert layout transform because there are undefined layouts";
-  CHECK(src_layout.Convertible(dst_layout))
+  CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined())
     << "Cannot insert layout transform because there are inconvertible layouts: "
     << src_layout << " v.s. " << dst_layout;
   static auto &transform_op = Op::Get("layout_transform");
index fcb7b379a0ec141f3928278d32f0a9ef8909466b..93d9ee52f6873d95ddc23849fea418029f6cc345 100644 (file)
@@ -9,10 +9,9 @@
 #ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
 #define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
 
+#include <tvm/data_layout.h>
 #include <tvm/relay/expr.h>
 
-#include "../op/layout.h"
-
 namespace tvm {
 namespace relay {
 
@@ -78,9 +77,9 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
 
     if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
       layouts.Set(undef_idx,
-                  layouts[defined_idx].Sublayout(
-                      old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
-                      old_in_shapes[undef_idx].size()));
+                  layouts[defined_idx].SubLayout(
+                  old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
+                  old_in_shapes[undef_idx].size()));
       return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
     } else {
       // only know the tensor with smaller dimensions,
@@ -90,21 +89,22 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
     }
   } else {
     // try to broadcast the tensors to the larger dimension
-    int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1;
+    int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
     int small_idx = 1 - large_idx;
     Layout ret = layouts[large_idx];
 
     // extract common part
     size_t i = layouts[large_idx].ndim();
     for (; i != 0; --i) {
-      auto dim = layouts[large_idx][i-1];
-      if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) {
+      const auto& axis = layouts[large_idx][i-1];
+      if (!layouts[small_idx].Contains(axis.ToPrimal())) {
         break;
       }
     }
 
-    Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i);
-    if (!layouts[small_idx].Convertible(common_part)) {  // fail
+    Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i);
+    if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) {
+      // not convertible
       return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
     }
 
index cd2d29e80048173ff7800235ad1d95f13b86ac8a..44b239919ce272586e48a32fef0d28d151df0f44 100644 (file)
@@ -91,8 +91,10 @@ class BranchGroupFinder : private ExprVisitor {
     CHECK(attrs_b);
     const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
     const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
-    const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->kernel_layout, kOIHW);
-    const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->kernel_layout, kOIHW);
+    const auto shape_a = BijectiveLayoutNode::make(
+      Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
+    const auto shape_b = BijectiveLayoutNode::make(
+      Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
 
     return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
            eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
index 270965886ab91b3c619958f7ed44f054a8faade0..044cc4e5d9c95b973df27a13d5651a89a74b1ae5 100644 (file)
@@ -6,12 +6,12 @@
  * \brief Fold axis scaling into weights of
  *  conv/dense operators.
  */
+#include <tvm/data_layout.h>
 #include <tvm/relay/pass.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/expr_functor.h>
 #include "pattern_util.h"
 #include "pass_util.h"
-#include "../op/layout.h"
 
 
 namespace tvm {
@@ -435,8 +435,8 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
   CHECK(param != nullptr);
   Layout data_layout(param->data_layout);
   Layout kernel_layout(param->kernel_layout);
-  int c_big_axis = data_layout.Indexof('C');
-  int c_small_axis = data_layout.Indexof('c');
+  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
+  int c_small_axis = data_layout.IndexOf(LayoutAxis::Get('c'));
 
   CHECK_GE(c_big_axis, 0);
   Message none = NullValue<Message>();
@@ -449,7 +449,7 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  if (kernel_layout.Indexof('i') < 0 &&
+  if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
       c_small_axis < 0 &&
       (param->groups == 1 || is_depthwise_conv2d)) {
     data_axes = {c_big_axis};
@@ -473,15 +473,15 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
   CHECK(param != nullptr);
   Layout data_layout(param->data_layout);
   Layout kernel_layout(param->kernel_layout);
-  int c_big_axis = data_layout.Indexof('C');
+  int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
   CHECK_GE(c_big_axis, 0);
   // For now, we only support simple pattern (no folded weight/data)
   // TODO(tvm-team) support general data layout
-  CHECK_EQ(kernel_layout.Indexof('i'), -1);
+  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
   CHECK(sdata->axes.size() == 1 &&
         c_big_axis == sdata->axes[0]->value);
-  int big_oc_axis = kernel_layout.Indexof('O');
-  int big_ic_axis = kernel_layout.Indexof('I');
+  int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
+  int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
 
   // Check it must be depthwise or full conv2d.
   bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
@@ -857,8 +857,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
   CHECK(param != nullptr);
   Layout kernel_layout(param->kernel_layout);
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  int c_big_axis = out_layout.Indexof('C');
-  int c_small_axis = out_layout.Indexof('c');
+  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
+  int c_small_axis = out_layout.IndexOf(LayoutAxis::Get('c'));
 
   CHECK_GE(c_big_axis, 0);
   // For now, we only support simple pattern (no folded weight/data)
@@ -869,8 +869,8 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  if (kernel_layout.Indexof('o') < 0 &&
-      kernel_layout.Indexof('i') < 0 &&
+  if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
+  kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
       c_small_axis < 0 &&
       (param->groups == 1 || is_depthwise_conv2d)) {
     return MessageNode::make({c_big_axis}, false);
@@ -891,16 +891,16 @@ Expr Conv2DBackwardTransform(const Call& call,
   CHECK(param != nullptr);
   Layout kernel_layout(param->kernel_layout);
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  int c_big_axis = out_layout.Indexof('C');
+  int c_big_axis = out_layout.IndexOf(LayoutAxis::Get('C'));
   CHECK_GE(c_big_axis, 0);
   // For now, we only support simple pattern (no folded weight/data)
   // TODO(tvm-team) support general data layout
-  CHECK_EQ(kernel_layout.Indexof('o'), -1);
-  CHECK_EQ(kernel_layout.Indexof('i'), -1);
+  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
+  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
   CHECK(message->axes.size() == 1 &&
         c_big_axis == message->axes[0]->value);
 
-  int big_oc_axis = kernel_layout.Indexof('O');
+  int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
   // Check it must be depthwise or full conv2d.
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
   CHECK(param->groups == 1 || is_depthwise_conv2d);
index 500312117c5b56335c9c7ce9796b3c5419d660ee..e801cdc37d12f930467b0bab244b2a16837c8b7c 100644 (file)
@@ -11,7 +11,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/expr_functor.h>
-#include "../op/layout.h"
+#include <tvm/data_layout.h>
 
 namespace tvm {
 namespace relay {
@@ -51,8 +51,8 @@ int64_t ConvMacCount(const Call& call_node) {
   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
   Array<IndexExpr> data_shape = data_type->shape;
   std::string data_layout = conv_2d_attr->data_layout;
-  int32_t C_ind = Layout(data_layout).Indexof('C');
-  int32_t c_ind = Layout(data_layout).Indexof('c');
+  int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
+  int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
   CHECK(C_ind != -1)
       << "There is no input channel dimension.";
   int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
index 08fc017f41eb347914049caac20c33c504b8c2f4..0644c26c6bccfb964a06551e7590f67c06bb544f 100644 (file)
@@ -8,13 +8,13 @@
 #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
 #define TVM_RELAY_PASS_PATTERN_UTIL_H_
 
+#include <tvm/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/attrs/nn.h>
 #include <string>
-#include "../op/layout.h"
 
 
 namespace tvm {
@@ -155,9 +155,8 @@ inline bool IsDepthwiseConv2D(const Call& call,
                               const Conv2DAttrs* param,
                               const Layout& kernel_layout) {
   static const Layout kOIHW("OIHW");
-  auto wshape = ConvertLayout(
-      call->args[1]->type_as<TensorTypeNode>()->shape,
-      kernel_layout, kOIHW);
+  const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
+  auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
   return is_const_int(wshape[0], param->groups) &&
       is_const_int(wshape[1], 1);
 }
diff --git a/tests/python/unittest/test_lang_data_layout.py b/tests/python/unittest/test_lang_data_layout.py
new file mode 100644 (file)
index 0000000..73d626e
--- /dev/null
@@ -0,0 +1,65 @@
+"""Test layout and bijective-layout node"""
+
+import tvm
+from topi.util import get_const_tuple
+
+def test_layout():
+    layout = tvm.layout("NCHW16c")
+    assert layout is not None
+    assert isinstance(layout, tvm.tensor.Layout)
+
+    assert layout.factor_of("c") == 16
+    assert layout.factor_of("C") == 16
+    assert layout.factor_of("N") == -1
+
+    assert layout.index_of("N") == 0
+    assert layout.index_of("C") == 1
+    assert layout.index_of("H") == 2
+    assert layout.index_of("W") == 3
+    assert layout.index_of("c") == 4
+    assert layout.index_of("O") == -1
+
+    assert "N" in layout
+    assert "C" in layout
+    assert "H" in layout
+    assert "W" in layout
+    assert "c" in layout
+    assert "O" not in layout
+
+    assert layout[0] == "N"
+    assert layout[1] == "C"
+    assert layout[2] == "H"
+    assert layout[3] == "W"
+    assert layout[4] == "c"
+    assert layout[-1] == "c"
+
+def test_bilayout_convertible():
+    # not convertible
+    assert tvm.bijective_layout("NCHW", "ABCD") is None
+    # convertible
+    assert tvm.bijective_layout("NCHW", "NCHW16c") is not None
+
+def test_bilayout_shape():
+    bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
+    assert isinstance(bilayout, tvm.tensor.BijectiveLayout)
+
+    dst_shape = bilayout.forward_shape((1, 32, 7, 7))
+    assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
+
+    src_shape = bilayout.backward_shape(dst_shape)
+    assert get_const_tuple(src_shape) == (1, 32, 7, 7)
+
+def test_bilayout_index():
+    bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
+
+    dst_index = bilayout.forward_index([0, 18, 6, 6])
+    assert get_const_tuple(dst_index) == (0, 1, 6, 6, 2)
+
+    src_index = bilayout.backward_index([0, 1, 6, 6, 2])
+    assert get_const_tuple(src_index) == (0, 18, 6, 6)
+
+if __name__ == "__main__":
+    test_layout()
+    test_bilayout_convertible()
+    test_bilayout_shape()
+    test_bilayout_index()
index 5f0b758c6424a2c8cb8a4baab25af007736f4516..00c3f999853d23308332c15c96280022818b1d89 100644 (file)
@@ -450,28 +450,5 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
   return tvm::compute(output_shape, l, name, tag);
 }
 
-using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>;
-
-/*!
- * \brief Transform the layout according to the mapping function \p to_src_indices.
- * \param src the source input.
- * \param dst_shape the output shape.
- * \param to_src_indices the mapping function from input index to output index.
- * \param name output tensor name.
- * \param tag output tensor tag.
- * \return A tensor with shape \p dst_shape.
- */
-inline Tensor layout_transform(const Tensor& src,
-                               const Array<Expr>& dst_shape,
-                               const FLayoutIndicesTransform& to_src_indices,
-                               const std::string name = "layout_transform",
-                               const std::string tag = kInjective) {
-  auto src_shape = src->shape;
-  return compute(
-  dst_shape, [&](const Array<Var>& dst_indices) {
-    return src(to_src_indices(dst_indices));
-  }, name, tag);
-}
-
 }  // namespace topi
 #endif  // TOPI_NN_H_
index e399b8c6978ca0e48a33d2403bd6a33509d7aefe..24ebe5de4a2053b4ec7703fb4a6b8a72eb8086aa 100644 (file)
@@ -16,6 +16,7 @@
 #include "topi/detail/ravel_unravel.h"
 #include "topi/detail/constant_utils.h"
 #include "tvm/tvm.h"
+#include "tvm/data_layout.h"
 
 namespace topi {
 using namespace tvm;
@@ -882,5 +883,43 @@ inline Tensor arange(const Expr start,
   }, name, tag);
 }
 
+/*!
+ * \brief Transform the layout according to \p src_layout and \p dst_layout
+ * \param src the source input.
+ * \param src_layout the source layout.
+ * \param dst_layout the destination layout.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return A tensor with shape in \p dst_layout
+ */
+inline Tensor layout_transform(const Tensor& src,
+                               const std::string& src_layout,
+                               const std::string& dst_layout,
+                               const std::string name = "layout_transform",
+                               const std::string tag = kInjective) {
+  Layout src_layout_struct = LayoutNode::make(src_layout);
+  Layout dst_layout_struct = LayoutNode::make(dst_layout);
+
+  if (src_layout_struct.Equals(dst_layout_struct)) {
+    return src;
+  }
+
+  CHECK(src_layout_struct.defined() && dst_layout_struct.defined())
+    << "cannot convert from/to undefined layout";
+
+  auto layout_converter = BijectiveLayoutNode::make(src_layout_struct, dst_layout_struct);
+  CHECK(layout_converter.defined())
+    << "cannot convert from " << src_layout << " to " << dst_layout;
+
+  Array<Expr> dst_shape = layout_converter.ForwardShape(src->shape);
+
+  return compute(
+    dst_shape, [&](const Array<Var>& dst_indices) {
+      Array<Expr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
+      Array<Expr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
+      return src(src_indices);
+  }, name, tag);
+}
+
 }  // namespace topi
 #endif  // TOPI_TRANSFORM_H_
index 2fb20162a5a7205c65ab5b6b3defb3f81e15d41d..e3ab0b364c651c6210738cef0e3ec9b8a85ab322 100644 (file)
@@ -318,3 +318,20 @@ def arange(start, stop=None, step=1, dtype="float32"):
         stop = start
         start = 0
     return cpp.arange(start, stop, step, dtype)
+
+
+def layout_transform(array, src_layout, dst_layout):
+    """Transform the layout according to src_layout and dst_layout
+
+    Parameters
+    ----------
+    array : tvm.Tensor
+        The source array.
+
+    src_layout : str
+        the source layout.
+
+    dst_layout : str
+        the destination layout.
+    """
+    return cpp.layout_transform(array, src_layout, dst_layout)
index e3fec08cb491429ee619486e9ccfe5989f38e648..aac2d1653c783df8661a4cfc7e932b3001babee9 100644 (file)
@@ -272,6 +272,11 @@ TVM_REGISTER_GLOBAL("topi.split")
   }
   });
 
+TVM_REGISTER_GLOBAL("topi.layout_transform")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = layout_transform(args[0], args[1], args[2]);
+});
+
 TVM_REGISTER_GLOBAL("topi.take")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   if (args.size() == 2) {
index dad527e3951f60a2ed04d8e9a256a6555857aa9a..31e37d4d26f21d0e2e767e9bdb9cd35a023573ae 100644 (file)
@@ -449,6 +449,34 @@ def test_arange():
     verify_arange(20, 1, -1.5)
 
 
+def test_layout_transform():
+    in_shape = (1, 32, 8, 8)
+    A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
+    B = topi.layout_transform(A, "NCHW", "NCHW16c")
+
+    input = np.random.uniform(size=in_shape).astype(A.dtype)
+    output = np.transpose(input, axes=(0, 2, 3, 1))
+    output = np.reshape(output, newshape=(1, 8, 8, 2, 16))
+    output = np.transpose(output, axes=(0, 3, 1, 2, 4))
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        tvm_input = tvm.nd.array(input, ctx)
+        tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=B.dtype)
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(B)
+        f = tvm.build(s, [A, B], device, name="layout_transform")
+        f(tvm_input, tvm_output)
+        tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
+
+    for backend in get_all_backend():
+        check_device(backend)
+
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -462,3 +490,4 @@ if __name__ == "__main__":
     test_take()
     test_gather_nd()
     test_arange()
+    test_layout_transform()