TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
};
+class ShapePattern;
+/*!
+ * \brief Pattern for Shapes.
+ */
+class ShapePatternNode : public DFPatternNode {
+ public:
+ /*! \brief The pattern. */
+ DFPattern pattern;
+ /*! \brief The type to match */
+ Array<PrimExpr> shape;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("pattern", &pattern);
+ v->Visit("shape", &shape);
+ }
+
+ static constexpr const char* _type_key = "relay.dataflow_pattern.ShapePattern";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
+ */
+class ShapePattern : public DFPattern {
+ public:
+ TVM_DLL ShapePattern(DFPattern pattern, Array<PrimExpr> type);
+ TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode);
+};
+
+class DataTypePattern;
+/*!
+ * \brief Pattern for Types.
+ */
+class DataTypePatternNode : public DFPatternNode {
+ public:
+ /*! \brief The pattern. */
+ DFPattern pattern;
+ /*! \brief The type to match */
+ DataType dtype;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("pattern", &pattern);
+ v->Visit("dtype", &dtype);
+ }
+
+ static constexpr const char* _type_key = "relay.dataflow_pattern.DataTypePattern";
+ TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern which matches a type in another pattern
+ */
+class DataTypePattern : public DFPattern {
+ public:
+ TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype);
+ TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode);
+};
+
class AttrPattern;
/*!
* \brief Pattern for Attributes.
virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+ RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
+ RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+ RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
- RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
return vtable;
}
void VisitDFPattern_(const AltPatternNode* op) override;
void VisitDFPattern_(const AttrPatternNode* op) override;
void VisitDFPattern_(const CallPatternNode* op) override;
+ void VisitDFPattern_(const ConstantPatternNode* op) override;
+ void VisitDFPattern_(const DataTypePatternNode* op) override;
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
+ void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
- void VisitDFPattern_(const ConstantPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
protected:
"""
return has_type(ttype, self)
+ def has_dtype(self, dtype: str):
+ """
+ Add a type constraint to this pattern
+
+ Parameters
+ ----------
+ dtype: str
+ The dtype to match
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting DataTypePattern
+ """
+ return has_dtype(dtype, self)
+
+ def has_shape(self, shape: List[tvm.ir.PrimExpr]):
+ """
+ Add a type constraint to this pattern
+
+ Parameters
+ ----------
+ shape: List[tvm.ir.PrimExpr]
+ The shape to match
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting ShapePattern
+ """
+ return has_shape(shape, self)
+
def match(self, expr: Expr) -> bool:
"""
Match this pattern to an expression
return WildcardPattern()
-def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern":
+def has_type(ttype: tvm.ir.type.Type, pattern: "DFPattern" = None) -> "DFPattern":
"""
Syntatic sugar for creating a TypePattern
Parameters
----------
- pattern: tvm.relay.dataflow_pattern.DFPattern
- The pattern that needs type annotation
-
ttype: tvm.ir.type.Type
The type to match
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The pattern that needs type annotation
+
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
return TypePattern(pattern, ttype)
+def has_dtype(dtype: str, pattern: "DFPattern" = None) -> "DFPattern":
+ """
+ Syntatic sugar for creating a DataTypePattern
+
+ Parameters
+ ----------
+ dtype: str
+ The dtype to match
+
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The pattern that needs type annotation
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting DataTypePattern
+ """
+ if pattern is None:
+ pattern = wildcard()
+ return DataTypePattern(pattern, dtype)
+
+
+def has_shape(shape: List[tvm.ir.PrimExpr], pattern: "DFPattern" = None) -> "DFPattern":
+ """
+ Syntatic sugar for creating a ShapePattern
+
+ Parameters
+ ----------
+ shape: List[tvm.ir.PrimExpr]
+ The shape to match
+
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The pattern that needs type annotation
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting ShapePattern
+ """
+ if pattern is None:
+ pattern = wildcard()
+ return ShapePattern(pattern, shape)
+
+
def has_attr(attrs, pattern=None) -> "DFPattern":
"""
Syntatic sugar for creating an AttrPattern
@register_df_node
class TypePattern(DFPattern):
- """Get index-th item from a TuplePattern.
+ """A pattern that matches another pattern with a certain type annotation.
Parameters
----------
@register_df_node
+class DataTypePattern(DFPattern):
+ """A pattern that matches another pattern with certain data type
+
+ Parameters
+ ----------
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The input pattern that needs type annotation.
+
+ dtype: str
+ The dtype to match.
+ """
+
+ def __init__(self, pattern: "DFPattern", dtype: str):
+ self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype)
+
+
+@register_df_node
+class ShapePattern(DFPattern):
+ """A pattern that matches another pattern with a certain tensor shape
+
+ Parameters
+ ----------
+ pattern: tvm.relay.dataflow_pattern.DFPattern
+ The input pattern that needs type annotation.
+
+ shape: List[tvm.ir.PrimExpr]
+ The shape to match.
+ """
+
+ def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]):
+ self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape)
+
+
+@register_df_node
class AttrPattern(DFPattern):
"""Get match an expression with a certain attributes.
Currently only supports Op Attributes, not call Attributes.
bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
void ClearMap(size_t watermark);
return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
}
+bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) {
+ auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+ if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+ return (StructuralEqual()(op->shape, tensor_type->shape)) && VisitDFPattern(op->pattern, expr);
+ }
+ return false;
+}
+
+bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) {
+ auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
+ if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+ return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr);
+ }
+ return false;
+}
+
bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* var_node = expr.as<VarNode>()) {
p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")";
});
+ShapePattern::ShapePattern(DFPattern pattern, Array<PrimExpr> shape) {
+ ObjectPtr<ShapePatternNode> n = make_object<ShapePatternNode>();
+ n->pattern = std::move(pattern);
+ n->shape = std::move(shape);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ShapePatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern")
+ .set_body_typed([](DFPattern pattern, Array<PrimExpr> shape) {
+ return ShapePattern(pattern, shape);
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<ShapePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const ShapePatternNode*>(ref.get());
+ p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")";
+ });
+
+DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) {
+ ObjectPtr<DataTypePatternNode> n = make_object<DataTypePatternNode>();
+ n->pattern = std::move(pattern);
+ n->dtype = std::move(dtype);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(DataTypePatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern")
+ .set_body_typed([](DFPattern pattern, DataType dtype) {
+ return DataTypePattern(pattern, dtype);
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<DataTypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const DataTypePatternNode*>(ref.get());
+ p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")";
+ });
+
AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
n->pattern = std::move(pattern);
VisitDFPattern(arg);
}
}
+
+void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) {
+ VisitDFPattern(op->pattern);
+}
+
void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {
VisitDFPattern(op->parent);
VisitDFPattern(op->path);
void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}
+void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); }
+
void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
VisitDFPattern(op->tuple);
}
VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
}
}
+
+ void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {}
+
+ void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override {
+ VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
+ }
+
void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->parent, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->path, graph_.node_map_[GetRef<DFPattern>(op)]);
void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}
+ void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override {
+ VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
+ }
+
void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->tuple, graph_.node_map_[GetRef<DFPattern>(op)]);
}
void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {}
- void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {}
-
void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {}
};
return Annotator(Creator().CreateGraph(pattern)).Annotate();
assert ty_pat.type == ttype
+def test_DataTypePattern():
+ dtype = "float16"
+ pattern = has_dtype(dtype)
+ assert isinstance(pattern, DataTypePattern)
+ assert pattern.dtype == dtype
+
+
+def test_ShapePattern():
+ shape = [10, 10]
+ pattern = has_shape(shape)
+ assert isinstance(pattern, ShapePattern)
+ assert tvm.ir.structural_equal(pattern.shape, shape)
+
+
def test_AttrPattern():
op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
assert isinstance(op, AttrPattern)
assert not ty_pat.match(x)
+def test_match_dtype():
+ x = relay.var('x', shape=(10, 10), dtype="float32")
+ ty_pat = has_dtype("float32")
+ assert ty_pat.match(x)
+
+
+def test_no_match_dtype():
+ x = relay.var('x', shape=(10, 10), dtype="int32")
+ ty_pat = has_dtype("float32")
+ assert not ty_pat.match(x)
+
+
+def test_match_shape():
+ x = relay.var('x', shape=(10, 10), dtype="float32")
+ ty_pat = has_shape((10, 10))
+ assert ty_pat.match(x)
+
+
+def test_no_match_shape():
+ x = relay.var('x', shape=(10, 10), dtype="int32")
+ ty_pat = has_shape((10, 5))
+ assert not ty_pat.match(x)
+
+
def test_match_op_attr():
op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
op_pat = op(wildcard(), wildcard())
assert not diamond.match(out)
+def test_match_typed_dominator():
+ # Pattern
+ is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32")
+ reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
+ diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+ # Classic Diamond
+ inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32"))
+ weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32"))
+ conv2d = relay.op.nn.conv2d(inp, weight)
+ relu = relay.op.nn.relu(conv2d)
+ relu = relay.op.nn.relu(relu)
+ leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+ out = relu + leaky_relu
+
+ # Check
+ assert diamond.match(out)
+
+def test_no_match_typed_dominator():
+ # Classic Diamond
+ inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32"))
+ weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32"))
+ conv2d = relay.op.nn.conv2d(inp, weight)
+ relu = relay.op.nn.relu(conv2d)
+ relu = relay.op.nn.relu(relu)
+ leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+ out = relu + leaky_relu
+
+ # Pattern
+ is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32")
+ reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 1, 10, 10])
+ diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+ # Check
+ assert not diamond.match(out)
+
+ # Pattern
+ is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float16")
+ reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10])
+ diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+ # Check
+ assert not diamond.match(out)
+
+
def test_rewrite():
x = relay.var('x')
y = relay.var('y')
test_TupleGetItemPattern()
test_AltPattern()
test_TypePattern()
+ test_DataTypePattern()
+ test_ShapePattern()
test_AttrPattern()
test_match_op()
test_no_match_op()
test_no_match_tuple()
test_match_type()
test_no_match_type()
+ test_match_dtype()
+ test_no_match_dtype()
+ test_match_shape()
+ test_no_match_shape()
test_match_op_attr()
test_no_match_op_attr()
test_match_func_attr()