[Relay][QNN] Moving Conv, Dense, Concatenate InferTypes to header for sharing. (...
authorAnimesh Jain <anijain@umich.edu>
Fri, 30 Aug 2019 16:12:03 +0000 (09:12 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Fri, 30 Aug 2019 16:12:03 +0000 (09:12 -0700)
src/relay/op/nn/convolution.cc
src/relay/op/nn/convolution.h [new file with mode: 0644]
src/relay/op/nn/nn.cc
src/relay/op/nn/nn.h [new file with mode: 0644]
src/relay/op/tensor/transform.cc
src/relay/op/tensor/transform.h [new file with mode: 0644]

index 5eb54a1..2f59fb9 100644 (file)
@@ -29,6 +29,7 @@
 #include <vector>
 
 #include "../../pass/alter_op_layout.h"
+#include "convolution.h"
 
 namespace tvm {
 namespace relay {
@@ -36,111 +37,6 @@ namespace relay {
 // relay.nn.conv2d
 TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
 
-bool Conv2DRel(const Array<Type>& types,
-               int num_inputs,
-               const Attrs& attrs,
-               const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-  static const Layout kNCHW("NCHW");
-  static const Layout kOIHW("OIHW");
-
-  const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
-  CHECK(param != nullptr);
-  const Layout in_layout(param->data_layout);
-  const Layout kernel_layout(param->kernel_layout);
-
-  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
-  CHECK(trans_in_layout.defined())
-    << "Conv only support input layouts that are convertible from NCHW."
-    << " But got " << in_layout;
-
-  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
-  CHECK(trans_kernel_layout.defined())
-    << "Conv only support kernel layouts that are convertible from OIHW."
-    << " But got "<< kernel_layout;
-
-  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
-  CHECK(trans_out_layout.defined())
-      << "Conv only support output layouts that are convertible from NCHW."
-      << " But got " << out_layout;
-
-  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
-
-  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
-  // infer weight if the kernel_size and channels are defined
-  if (param->kernel_size.defined() && param->channels.defined()) {
-    CHECK_EQ(param->kernel_size.size(), 2);
-    CHECK_EQ(param->dilation.size(), 2);
-    Array<IndexExpr> wshape;
-
-    if (tvm::ir::Equal(param->channels, param->groups)) {
-      // infer weight's shape for depthwise convolution
-      wshape = {
-         {dshape_nchw[1],
-          param->groups / dshape_nchw[1],
-          param->kernel_size[0],
-          param->kernel_size[1]}};
-    } else {
-      wshape = {
-         {param->channels,
-          dshape_nchw[1] / param->groups,
-          param->kernel_size[0],
-          param->kernel_size[1]}};
-    }
-
-    wshape = trans_kernel_layout.BackwardShape(wshape);
-    channels = param->channels;
-    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
-    DataType weight_dtype = data->dtype;
-    if (weight != nullptr) {
-      weight_dtype = weight->dtype;
-    }
-    // assign result to reporter
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
-  } else {
-    // use weight to infer the conv shape.
-    if (weight == nullptr) return false;
-    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
-    if (param->kernel_size.defined()) {
-      CHECK_EQ(param->kernel_size.size(), 2);
-      // check the size
-      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
-            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
-          << "Conv2D: shape of weight is inconsistent with kernel_size, "
-          << " kernel_size=" << param->kernel_size
-          << " wshape=" << wshape;
-    }
-    if (param->channels.defined()) {
-      CHECK(reporter->AssertEQ(param->channels, wshape[0]))
-          << "Conv2D: shape of weight is inconsistent with channels, "
-          << " channels=" << param->channels
-          << " wshape=" << wshape;
-    }
-    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
-    channels = wshape[0];
-    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
-  }
-  // dilation
-  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
-
-  oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
-  oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  oshape = trans_out_layout.BackwardShape(oshape);
-  // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
-  return true;
-}
-
 template<typename T>
 Array<Array<Layout> > Conv2DInferCorrectLayout(
     const Attrs& attrs,
@@ -208,7 +104,7 @@ with the layer input to produce a tensor of outputs.
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
-.add_type_rel("Conv2D", Conv2DRel)
+.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
 
 
@@ -770,7 +666,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(10)
-.add_type_rel("Conv2D", Conv2DRel)
+.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
         Conv2DInferCorrectLayout<Conv2DAttrs>);
 
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
new file mode 100644 (file)
index 0000000..fb58447
--- /dev/null
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file src/relay/op/nn/convolution.h
+ * \brief Properties def of convlution operator for sharing.
+ */
+#ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_
+#define TVM_RELAY_OP_NN_CONVOLUTION_H_
+
+#include <string>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+template <typename AttrType>
+bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
+  const AttrType* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
+  CHECK(trans_in_layout.defined())
+      << "Conv only support input layouts that are convertible from NCHW."
+      << " But got " << in_layout;
+
+  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
+  CHECK(trans_kernel_layout.defined())
+      << "Conv only support kernel layouts that are convertible from OIHW."
+      << " But got " << kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
+  CHECK(trans_out_layout.defined())
+      << "Conv only support output layouts that are convertible from NCHW."
+      << " But got " << out_layout;
+
+  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 2);
+    CHECK_EQ(param->dilation.size(), 2);
+    Array<IndexExpr> wshape;
+
+    if (tvm::ir::Equal(param->channels, param->groups)) {
+      // infer weight's shape for depthwise convolution
+      wshape = {{dshape_nchw[1], param->groups / dshape_nchw[1], param->kernel_size[0],
+                 param->kernel_size[1]}};
+    } else {
+      wshape = {{param->channels, dshape_nchw[1] / param->groups, param->kernel_size[0],
+                 param->kernel_size[1]}};
+    }
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
+    channels = param->channels;
+    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+    DataType weight_dtype = data->dtype;
+    if (weight != nullptr) {
+      weight_dtype = weight->dtype;
+    }
+    // assign result to reporter
+    reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 2);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
+            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
+          << "Conv2D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[0]))
+          << "Conv2D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels << " wshape=" << wshape;
+    }
+    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
+    channels = wshape[0];
+    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
+  }
+  // dilation
+  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+
+  oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
+  oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  // assign output type
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_NN_CONVOLUTION_H_
index 2c03bba..42a0f01 100644 (file)
@@ -35,6 +35,7 @@
 #include "../type_relations.h"
 #include "../../pass/alter_op_layout.h"
 #include "../op_common.h"
+#include "nn.h"
 
 namespace tvm {
 namespace relay {
@@ -102,45 +103,6 @@ RELAY_REGISTER_OP("nn.bias_add")
 // relay.nn.dense
 TVM_REGISTER_NODE_TYPE(DenseAttrs);
 
-
-bool DenseRel(const Array<Type>& types,
-              int num_inputs,
-              const Attrs& attrs,
-              const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  const DenseAttrs* param = attrs.as<DenseAttrs>();
-  CHECK(param != nullptr);
-
-  CHECK(static_cast<int>(data->shape.size()) != 0);
-
-  Array<tvm::Expr> oshape = data->shape;
-  if (param->units.defined()) {
-    Array<tvm::Expr> dshape = data->shape;
-    // validate the weight shape is proper if defined
-    // Assign weight type
-    Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
-    oshape.Set((oshape.size() - 1), param->units);
-  } else {
-    if (weight == nullptr) return false;
-    Array<tvm::Expr> wshape = weight->shape;
-    oshape.Set((oshape.size() - 1), wshape[0]);
-  }
-
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
-  return true;
-}
-
-
 // Positional relay function to create dense operator used by frontend FFI.
 Expr MakeDense(Expr data,
                Expr weight,
@@ -171,7 +133,7 @@ RELAY_REGISTER_OP("nn.dense")
 .add_argument("data", "nD Tensor", "Input data.")
 .add_argument("weight", "2D Tensor", "Weight matrix.")
 .set_support_level(1)
-.add_type_rel("Dense", DenseRel);
+.add_type_rel("Dense", DenseRel<DenseAttrs>);
 
 // relay.leaky_relu
 TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
new file mode 100644 (file)
index 0000000..2c65d25
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file src/relay/op/nn/nn.h
+ * \brief Properties def of nn operators for sharing.
+ */
+#ifndef TVM_RELAY_OP_NN_NN_H_
+#define TVM_RELAY_OP_NN_NN_H_
+
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+template <typename AttrType>
+bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+              const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const AttrType* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+
+  CHECK(static_cast<int>(data->shape.size()) != 0);
+
+  Array<tvm::Expr> oshape = data->shape;
+  if (param->units.defined()) {
+    Array<tvm::Expr> dshape = data->shape;
+    // validate the weight shape is proper if defined
+    // Assign weight type
+    Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
+    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+    oshape.Set((oshape.size() - 1), param->units);
+  } else {
+    if (weight == nullptr) return false;
+    Array<tvm::Expr> wshape = weight->shape;
+    oshape.Set((oshape.size() - 1), wshape[0]);
+  }
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  // assign output type
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_NN_NN_H_
index b39c282..c3975c3 100644 (file)
@@ -37,6 +37,7 @@
 #include "../op_common.h"
 #include "../../../arithmetic/compute_expr.h"
 #include "../../pass/alter_op_layout.h"
+#include "transform.h"
 
 namespace tvm {
 namespace relay {
@@ -210,86 +211,6 @@ RELAY_REGISTER_OP("expand_dims")
 // relay.concatenate
 TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
 
-bool ConcatenateRel(const Array<Type>& types,
-                    int num_inputs,
-                    const Attrs& attrs,
-                    const TypeReporter& reporter) {
-  // types: [data, result]
-  CHECK_EQ(types.size(), 2);
-  /* If we receive a tuple we can continue, if we receive
-   * anything but an incomplete type we should signal an
-   * error.
-  */
-  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
-  if (tensor_tuple == nullptr) {
-    throw relay::Error(
-        RELAY_ERROR(
-          "concatenate requires a tuple of tensors as the first argument, found "
-        << PrettyPrint(types[0])));
-  } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
-    return false;
-  }
-
-  const auto* param = attrs.as<ConcatenateAttrs>();
-  if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
-    return false;
-  }
-  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
-  // Sanity check: ndim and dtype.
-  const int ndim = static_cast<int>(first->shape.size());
-  const DataType dtype = first->dtype;
-
-  for (const Type& ele : tensor_tuple->fields) {
-    if (ele.as<IncompleteTypeNode>()) {
-      return false;
-    }
-
-    const auto& e = Downcast<TensorType>(ele);
-
-    int e_ndim = static_cast<int>(e->shape.size());
-    const DataType& e_dtype = e->dtype;
-    if (e_ndim != ndim) {
-      throw relay::Error("relay.concatenate requires all tensors have the same ndim");
-    }
-    if (e_dtype != dtype) {
-      throw relay::Error("relay.concatenate requires all tensors have the same dtype");
-    }
-  }
-  // Sanity check: axis
-  int axis = param->axis;
-  if (!(-ndim <= axis && axis < ndim)) {
-    throw relay::Error(RELAY_ERROR(
-      "concatenate only accepts `axis` in [-ndim, ndim)" <<
-      ", but got axis = " << axis <<
-      ", and ndim = " << ndim));
-  }
-  axis = axis < 0 ? ndim + axis : axis;
-  // Calculate shape
-  std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
-  IndexExpr &concat_dim = oshape[axis];
-  bool has_any = false;
-  if (concat_dim.as<Any>()) {
-    has_any = true;
-  } else {
-    for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
-      const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
-      if (e->shape[axis].as<Any>()) {
-        has_any = true;
-        break;
-      }
-      concat_dim += e->shape[axis];
-    }
-  }
-
-  if (has_any) {
-    concat_dim = Any::make();
-  }
-
-  auto rtype = TensorTypeNode::make(oshape, dtype);
-  reporter->Assign(types[1], rtype);
-  return true;
-}
-
 Array<Tensor> ConcatenateCompute(const Attrs& attrs,
                           const Array<Tensor>& inputs,
                           const Type& out_type,
@@ -358,7 +279,7 @@ RELAY_REGISTER_OP("concatenate")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input list of tensors.")
 .set_support_level(1)
-.add_type_rel("Concatenate", ConcatenateRel)
+.add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
 .set_attr<FTVMCompute>("FTVMCompute", ConcatenateCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
new file mode 100644 (file)
index 0000000..3a4d50b
--- /dev/null
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file src/relay/op/tensor/transform.h
+ * \brief Transform op attributes that can be shared among Relay and its dialects.
+ */
+#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
+#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
+
+#include <vector>
+#include <algorithm>
+#include <limits>
+#include <string>
+#include <unordered_set>
+#include <utility>
+
+namespace tvm {
+namespace relay {
+
+template <typename AttrType>
+bool ConcatenateRel(const Array<Type>& types,
+                    int num_inputs,
+                    const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  /* If we receive a tuple we can continue, if we receive
+   * anything but an incomplete type we should signal an
+   * error.
+  */
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    throw relay::Error(
+        RELAY_ERROR(
+          "concatenate requires a tuple of tensors as the first argument, found "
+        << PrettyPrint(types[0])));
+  } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
+    return false;
+  }
+
+  const auto* param = attrs.as<AttrType>();
+  if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
+    return false;
+  }
+  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
+  // Sanity check: ndim and dtype.
+  const int ndim = static_cast<int>(first->shape.size());
+  const DataType dtype = first->dtype;
+
+  for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>()) {
+      return false;
+    }
+
+    const auto& e = Downcast<TensorType>(ele);
+
+    int e_ndim = static_cast<int>(e->shape.size());
+    const DataType& e_dtype = e->dtype;
+    if (e_ndim != ndim) {
+      throw relay::Error("relay.concatenate requires all tensors have the same ndim");
+    }
+    if (e_dtype != dtype) {
+      throw relay::Error("relay.concatenate requires all tensors have the same dtype");
+    }
+  }
+  // Sanity check: axis
+  int axis = param->axis;
+  if (!(-ndim <= axis && axis < ndim)) {
+    throw relay::Error(RELAY_ERROR(
+      "concatenate only accepts `axis` in [-ndim, ndim)" <<
+      ", but got axis = " << axis <<
+      ", and ndim = " << ndim));
+  }
+  axis = axis < 0 ? ndim + axis : axis;
+  // Calculate shape
+  std::vector<IndexExpr> oshape(first->shape.begin(), first->shape.end());
+  IndexExpr &concat_dim = oshape[axis];
+  bool has_any = false;
+  if (concat_dim.as<Any>()) {
+    has_any = true;
+  } else {
+    for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
+      const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
+      if (e->shape[axis].as<Any>()) {
+        has_any = true;
+        break;
+      }
+      concat_dim += e->shape[axis];
+    }
+  }
+
+  if (has_any) {
+    concat_dim = Any::make();
+  }
+
+  auto rtype = TensorTypeNode::make(oshape, dtype);
+  reporter->Assign(types[1], rtype);
+  return true;
+}
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_OP_TENSOR_TRANSFORM_H_