Add ShapePattern and DataTypePattern (#5760)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Thu, 11 Jun 2020 14:25:09 +0000 (07:25 -0700)
committerGitHub <noreply@github.com>
Thu, 11 Jun 2020 14:25:09 +0000 (23:25 +0900)
include/tvm/relay/dataflow_pattern.h
include/tvm/relay/dataflow_pattern_functor.h
python/tvm/relay/dataflow_pattern/__init__.py
src/relay/ir/dataflow_matcher.cc
src/relay/ir/dataflow_pattern.cc
src/relay/ir/dataflow_pattern_functor.cc
src/relay/ir/indexed_graph.cc
tests/python/relay/test_dataflow_pattern.py

index 80a5d6f..11ac7e3 100644 (file)
@@ -309,6 +309,64 @@ class TypePattern : public DFPattern {
   TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
 };
 
+class ShapePattern;
+/*!
+ * \brief Pattern for Shapes.
+ */
+class ShapePatternNode : public DFPatternNode {
+ public:
+  /*! \brief The pattern. */
+  DFPattern pattern;
+  /*! \brief The type to match */
+  Array<PrimExpr> shape;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("shape", &shape);
+  }
+
+  static constexpr const char* _type_key = "relay.dataflow_pattern.ShapePattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
+ */
+class ShapePattern : public DFPattern {
+ public:
+  TVM_DLL ShapePattern(DFPattern pattern, Array<PrimExpr> type);
+  TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode);
+};
+
+class DataTypePattern;
+/*!
+ * \brief Pattern for Types.
+ */
+class DataTypePatternNode : public DFPatternNode {
+ public:
+  /*! \brief The pattern. */
+  DFPattern pattern;
+  /*! \brief The type to match */
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pattern", &pattern);
+    v->Visit("dtype", &dtype);
+  }
+
+  static constexpr const char* _type_key = "relay.dataflow_pattern.DataTypePattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
+ */
+class DataTypePattern : public DFPattern {
+ public:
+  TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype);
+  TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode);
+};
+
 class AttrPattern;
 /*!
  * \brief Pattern for Attributes.
index a1140ae..98c81c9 100644 (file)
@@ -84,8 +84,10 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
   virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
                             Args... args) DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
@@ -106,13 +108,15 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
-    RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
     return vtable;
   }
@@ -130,13 +134,15 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
   void VisitDFPattern_(const AltPatternNode* op) override;
   void VisitDFPattern_(const AttrPatternNode* op) override;
   void VisitDFPattern_(const CallPatternNode* op) override;
+  void VisitDFPattern_(const ConstantPatternNode* op) override;
+  void VisitDFPattern_(const DataTypePatternNode* op) override;
   void VisitDFPattern_(const DominatorPatternNode* op) override;
   void VisitDFPattern_(const ExprPatternNode* op) override;
+  void VisitDFPattern_(const ShapePatternNode* op) override;
   void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
   void VisitDFPattern_(const TuplePatternNode* op) override;
   void VisitDFPattern_(const TypePatternNode* op) override;
   void VisitDFPattern_(const VarPatternNode* op) override;
-  void VisitDFPattern_(const ConstantPatternNode* op) override;
   void VisitDFPattern_(const WildcardPatternNode* op) override;
 
  protected:
index e6a1a5e..915842c 100644 (file)
@@ -97,6 +97,38 @@ class DFPattern(Node):
         """
         return has_type(ttype, self)
 
+    def has_dtype(self, dtype: str):
+        """
+        Add a type constraint to this pattern
+
+        Parameters
+        ----------
+        dtype: str
+            The dtype to match
+
+        Returns
+        -------
+        result: tvm.relay.dataflow_pattern.DFPattern
+            The resulting DataTypePattern
+        """
+        return has_dtype(dtype, self)
+
+    def has_shape(self, shape: List[tvm.ir.PrimExpr]):
+        """
+        Add a type constraint to this pattern
+
+        Parameters
+        ----------
+        shape: List[tvm.ir.PrimExpr]
+            The shape to match
+
+        Returns
+        -------
+        result: tvm.relay.dataflow_pattern.DFPattern
+            The resulting ShapePattern
+        """
+        return has_shape(shape, self)
+
     def match(self, expr: Expr) -> bool:
         """
         Match this pattern to an expression
@@ -293,18 +325,18 @@ def wildcard() -> "DFPattern":
     return WildcardPattern()
 
 
-def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern":
+def has_type(ttype: tvm.ir.type.Type, pattern: "DFPattern" = None) -> "DFPattern":
     """
     Syntatic sugar for creating a TypePattern
 
     Parameters
     ----------
-    pattern: tvm.relay.dataflow_pattern.DFPattern
-        The pattern that needs type annotation
-
     ttype: tvm.ir.type.Type
         The type to match
 
+    pattern: tvm.relay.dataflow_pattern.DFPattern
+        The pattern that needs type annotation
+
     Returns
     -------
     result: tvm.relay.dataflow_pattern.DFPattern
@@ -315,6 +347,50 @@ def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern":
     return TypePattern(pattern, ttype)
 
 
+def has_dtype(dtype: str, pattern: "DFPattern" = None) -> "DFPattern":
+    """
+    Syntatic sugar for creating a DataTypePattern
+
+    Parameters
+    ----------
+    dtype: str
+        The dtype to match
+
+    pattern: tvm.relay.dataflow_pattern.DFPattern
+        The pattern that needs type annotation
+
+    Returns
+    -------
+    result: tvm.relay.dataflow_pattern.DFPattern
+        The resulting DataTypePattern
+    """
+    if pattern is None:
+        pattern = wildcard()
+    return DataTypePattern(pattern, dtype)
+
+
+def has_shape(shape: List[tvm.ir.PrimExpr], pattern: "DFPattern" = None) -> "DFPattern":
+    """
+    Syntatic sugar for creating a ShapePattern
+
+    Parameters
+    ----------
+    shape: List[tvm.ir.PrimExpr]
+        The shape to match
+
+    pattern: tvm.relay.dataflow_pattern.DFPattern
+        The pattern that needs type annotation
+
+    Returns
+    -------
+    result: tvm.relay.dataflow_pattern.DFPattern
+        The resulting ShapePattern
+    """
+    if pattern is None:
+        pattern = wildcard()
+    return ShapePattern(pattern, shape)
+
+
 def has_attr(attrs, pattern=None) -> "DFPattern":
     """
     Syntatic sugar for creating an AttrPattern
@@ -514,7 +590,7 @@ class WildcardPattern(DFPattern):
 
 @register_df_node
 class TypePattern(DFPattern):
-    """Get index-th item from a TuplePattern.
+    """A pattern that matches another pattern with a certain type annotation.
 
     Parameters
     ----------
@@ -530,6 +606,40 @@ class TypePattern(DFPattern):
 
 
 @register_df_node
+class DataTypePattern(DFPattern):
+    """A pattern that matches another pattern with certain data type
+
+    Parameters
+    ----------
+    pattern: tvm.relay.dataflow_pattern.DFPattern
+        The input pattern that needs type annotation.
+
+    dtype: str
+        The dtype to match.
+    """
+
+    def __init__(self, pattern: "DFPattern", dtype: str):
+        self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype)
+
+
+@register_df_node
+class ShapePattern(DFPattern):
+    """A pattern that matches another pattern with a certain tensor shape
+
+    Parameters
+    ----------
+    pattern: tvm.relay.dataflow_pattern.DFPattern
+        The input pattern that needs type annotation.
+
+    shape: List[tvm.ir.PrimExpr]
+        The shape to match.
+    """
+
+    def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]):
+        self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape)
+
+
+@register_df_node
 class AttrPattern(DFPattern):
     """Get match an expression with a certain attributes.
     Currently only supports Op Attributes, not call Attributes.
index e9543e3..d01a1e7 100644 (file)
@@ -50,13 +50,15 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
   bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
 
   void ClearMap(size_t watermark);
@@ -393,6 +395,22 @@ bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& ex
   return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
 }
 
+bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) {
+  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+    return (StructuralEqual()(op->shape, tensor_type->shape)) && VisitDFPattern(op->pattern, expr);
+  }
+  return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) {
+  auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+  if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+    return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr);
+  }
+  return false;
+}
+
 bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
   bool matches = false;
   if (const auto* var_node = expr.as<VarNode>()) {
index 2809131..4664e5f 100644 (file)
@@ -187,6 +187,46 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")";
     });
 
+ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) {
+  ObjectPtr<ShapePatternNode> n = make_object<ShapePatternNode>();
+  n->pattern = std::move(pattern);
+  n->shape = std::move(shape);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ShapePatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern")
+    .set_body_typed([](DFPattern pattern, Array<PrimExpr> shape) {
+      return ShapePattern(pattern, shape);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ShapePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const ShapePatternNode*>(ref.get());
+      p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")";
+    });
+
+DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) {
+  ObjectPtr<DataTypePatternNode> n = make_object<DataTypePatternNode>();
+  n->pattern = std::move(pattern);
+  n->dtype = std::move(dtype);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(DataTypePatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern")
+    .set_body_typed([](DFPattern pattern, DataType dtype) {
+      return DataTypePattern(pattern, dtype);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<DataTypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const DataTypePatternNode*>(ref.get());
+      p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")";
+    });
+
 AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
   ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
   n->pattern = std::move(pattern);
index ee44bcb..7e9f828 100644 (file)
@@ -49,6 +49,11 @@ void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) {
     VisitDFPattern(arg);
   }
 }
+
+void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) {
+  VisitDFPattern(op->pattern);
+}
+
 void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {
   VisitDFPattern(op->parent);
   VisitDFPattern(op->path);
@@ -57,6 +62,8 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {
 
 void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}
 
+void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); }
+
 void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
   VisitDFPattern(op->tuple);
 }
index 0d4b90d..456bf02 100644 (file)
@@ -246,6 +246,13 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
         VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
       }
     }
+
+    void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {}
+
+    void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }
+
     void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override {
       VisitDFPattern(op->parent, graph_.node_map_[GetRef<DFPattern>(op)]);
       VisitDFPattern(op->path, graph_.node_map_[GetRef<DFPattern>(op)]);
@@ -254,6 +261,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
 
     void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}
 
+    void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }
+
     void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override {
       VisitDFPattern(op->tuple, graph_.node_map_[GetRef<DFPattern>(op)]);
     }
@@ -270,8 +281,6 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
 
     void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {}
 
-    void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {}
-
     void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {}
   };
   return Annotator(Creator().CreateGraph(pattern)).Annotate();
index 8d67db5..9727e53 100644 (file)
@@ -94,6 +94,20 @@ def test_TypePattern():
     assert ty_pat.type == ttype
 
 
+def test_DataTypePattern():
+    dtype = "float16"
+    pattern = has_dtype(dtype)
+    assert isinstance(pattern, DataTypePattern)
+    assert pattern.dtype == dtype
+
+
+def test_ShapePattern():
+    shape = [10, 10]
+    pattern = has_shape(shape)
+    assert isinstance(pattern, ShapePattern)
+    assert tvm.ir.structural_equal(pattern.shape, shape)
+
+
 def test_AttrPattern():
     op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
     assert isinstance(op, AttrPattern)
@@ -265,6 +279,30 @@ def test_no_match_type():
     assert not ty_pat.match(x)
 
 
+def test_match_dtype():
+    x = relay.var('x', shape=(10, 10), dtype="float32")
+    ty_pat = has_dtype("float32")
+    assert ty_pat.match(x)
+
+
+def test_no_match_dtype():
+    x = relay.var('x', shape=(10, 10), dtype="int32")
+    ty_pat = has_dtype("float32")
+    assert not ty_pat.match(x)
+
+
+def test_match_shape():
+    x = relay.var('x', shape=(10, 10), dtype="float32")
+    ty_pat = has_shape((10, 10))
+    assert ty_pat.match(x)
+
+
+def test_no_match_shape():
+    x = relay.var('x', shape=(10, 10), dtype="int32")
+    ty_pat = has_shape((10, 5))
+    assert not ty_pat.match(x)
+
+
 def test_match_op_attr():
     op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
     op_pat = op(wildcard(), wildcard())
@@ -500,6 +538,54 @@ def test_not_match_dominator():
     assert not diamond.match(out)
 
 
+def test_match_typed_dominator():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32")
+    reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Classic Diamond
+    inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32"))
+    weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32"))
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+def test_no_match_typed_dominator():
+    # Classic Diamond
+    inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32"))
+    weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32"))
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32")
+    reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 1, 10, 10])
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Check
+    assert not diamond.match(out)
+
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float16")
+    reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Check
+    assert not diamond.match(out)
+
+
 def test_rewrite():
     x = relay.var('x')
     y = relay.var('y')
@@ -1222,6 +1308,8 @@ if __name__ == "__main__":
     test_TupleGetItemPattern()
     test_AltPattern()
     test_TypePattern()
+    test_DataTypePattern()
+    test_ShapePattern()
     test_AttrPattern()
     test_match_op()
     test_no_match_op()
@@ -1237,6 +1325,10 @@ if __name__ == "__main__":
     test_no_match_tuple()
     test_match_type()
     test_no_match_type()
+    test_match_dtype()
+    test_no_match_dtype()
+    test_match_shape()
+    test_no_match_shape()
     test_match_op_attr()
     test_no_match_op_attr()
     test_match_func_attr()