From 29a3d3a66f31235eb644b38d9e03c156fa5fde7f Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 12 Mar 2020 16:29:06 -0700 Subject: [PATCH] [Bugfix][IR][ATTRS] Fix AttrEqual for Array and StrMap, double (#5054) - Use fuzzy comparison for double. - Removed the hack for BatchNormAttrs and DictAttr. Also removed a warning from text printer printing. --- include/tvm/ir/attrs.h | 7 ++- src/ir/attrs.cc | 29 +++++++---- src/printer/doc.cc | 10 ++-- src/printer/doc.h | 5 ++ src/printer/meta_data.h | 2 +- src/relay/analysis/alpha_equal.cc | 88 +++++++++++++++------------------- tests/python/unittest/test_ir_attrs.py | 15 +++++- 7 files changed, 92 insertions(+), 64 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 899db08..4413fc3 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -143,8 +143,13 @@ class AttrsEqualHandler; class AttrsEqual { public: bool operator()(const double& lhs, const double& rhs) const { - return lhs == rhs; + // fuzzy float pt comparison + constexpr double atol = 1e-9; + if (lhs == rhs) return true; + double diff = lhs - rhs; + return diff > -atol && diff < atol; } + bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 4c4c997..868fec6 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -79,7 +79,8 @@ using namespace tir; // Equal handler. bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; - if (!lhs.defined() || !rhs.defined()) return false; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; return this->VisitAttr(lhs, rhs); } @@ -96,22 +97,25 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& ot bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; + } else { + return false; } - return false; } bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; + } else { + return false; } - return false; } bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; + } else { + return false; } - return false; } bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) { @@ -120,8 +124,10 @@ bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) for (size_t i = 0; i < lhs->data.size(); ++i) { if (!Equal(lhs->data[i], rhs->data[i])) return false; } + return true; + } else { + return false; } - return true; } bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) { @@ -132,8 +138,10 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other if (it == rhs->data.end()) return false; if (!Equal(kv.second, it->second)) return false; } + return true; + } else { + return false; } - return true; } #define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \ @@ -340,8 +348,13 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const { } TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Attrs()->ListFieldInfo(); +.set_body_typed([](Attrs attrs) { + return attrs->ListFieldInfo(); +}); + +TVM_REGISTER_GLOBAL("ir.AttrsEqual") +.set_body_typed([](ObjectRef lhs, ObjectRef rhs) { + return AttrsEqual()(lhs, rhs); }); } // namespace tvm diff --git a/src/printer/doc.cc b/src/printer/doc.cc index c5595db..ee260f4 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -40,9 +40,6 @@ class DocTextNode : public DocAtomNode { explicit DocTextNode(std::string str_val) : str(str_val) { - if (str.find_first_of("\t\n") != str.npos) { - LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; - } } static constexpr const char* _type_key = "printer.DocText"; @@ -54,6 +51,9 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode); class DocText : public DocAtom { public: explicit DocText(std::string str) { + if (str.find_first_of("\t\n") != str.npos) { + LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; + } data_ = runtime::make_object(str); } @@ -125,6 +125,10 @@ Doc Doc::Text(std::string text) { return Doc() << DocText(text); } +Doc Doc::RawText(std::string text) { + return Doc() << DocAtom(runtime::make_object(text)); +} + Doc Doc::Indent(int indent, Doc doc) { for (size_t i = 0; i < doc.stream_.size(); ++i) { if (auto* line = doc.stream_[i].as()) { diff --git a/src/printer/doc.h b/src/printer/doc.h index 34a284b..7d8d72e 100644 --- a/src/printer/doc.h +++ b/src/printer/doc.h @@ -111,6 +111,11 @@ class Doc { */ static Doc Text(std::string value); /*! + * \brief Create a doc that represents raw text(can have new lines) + * \return The created doc. + */ + static Doc RawText(std::string value); + /*! * \brief Create a doc that represents a new line. * \return The created doc. */ diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index 6c300fd..d390692 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -121,7 +121,7 @@ class TextMetaDataContext { */ Doc GetMetaSection() const { if (meta_data_.size() == 0) return Doc(); - return Doc::Text( + return Doc::RawText( SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); } diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc index 8a07a19..726ccbb 100644 --- a/src/relay/analysis/alpha_equal.cc +++ b/src/relay/analysis/alpha_equal.cc @@ -30,6 +30,8 @@ #include #include #include "../../ir/attr_functor.h" + + namespace tvm { namespace relay { @@ -50,37 +52,7 @@ class AlphaEqualHandler: * \return The comparison result. */ bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) { - if (!lhs.defined() || !rhs.defined()) return false; - if (lhs.same_as(rhs)) return true; - if (lhs->IsInstance() || rhs->IsInstance()) { - if (!rhs->IsInstance() || !lhs->IsInstance()) return false; - return TypeEqual(Downcast(lhs), Downcast(rhs)); - } - if (lhs->IsInstance() || rhs->IsInstance()) { - if (!rhs->IsInstance() || !lhs->IsInstance()) return false; - return ExprEqual(Downcast(lhs), Downcast(rhs)); - } - if (const auto lhsm = lhs.as()) { - auto rhsm = rhs.as(); - if (!rhsm) return false; - if (lhsm->functions.size() != rhsm->functions.size()) return false; - for (const auto& p : lhsm->functions) { - if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false; - } - if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; - for (const auto& p : lhsm->type_definitions) { - if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) || - !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) { - return false; - } - } - return true; - } - return AttrEqual(lhs, rhs); - } - - bool DoubleEqual(double l, double r) { - return true; + return VisitAttr(lhs, rhs); } /*! * Check equality of two attributes. @@ -90,25 +62,7 @@ class AlphaEqualHandler: */ bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) { auto compute = [&]() { - if (&lhs == &rhs) return true; - if (auto lhsd = lhs.as()) { - auto rhsd = rhs.as(); - if (!rhsd) return false; - if (lhsd->dict.size() != rhsd->dict.size()) return false; - for (const auto& k : lhsd->dict) { - if (!Equal(k.second, rhsd->dict[k.first])) return false; - } - return true; - } - if (auto lhsbn = lhs.as()) { - auto rhsbn = rhs.as(); - if (!rhsbn) return false; - return (lhsbn->axis == rhsbn->axis) - && DoubleEqual(lhsbn->epsilon, rhsbn->epsilon) - && (lhsbn->center == rhsbn->center) - && (lhsbn->scale == rhsbn->scale); - } - return AttrsEqualHandler::Equal(lhs, rhs); + return VisitAttr(lhs, rhs); }; return Compare(compute(), lhs, rhs); } @@ -164,6 +118,40 @@ class AlphaEqualHandler: } protected: + // So that the new definition of equality in relay can be handled directly. + // Specifically, if a DictAttr contains a value defined by a relay AST. + // We want to able to recursively check the equality in the attr defined by the relay AST. + bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final { + if (lhs.same_as(rhs)) return true; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; + if (lhs->IsInstance() || rhs->IsInstance()) { + if (!rhs->IsInstance() || !lhs->IsInstance()) return false; + return TypeEqual(Downcast(lhs), Downcast(rhs)); + } + if (lhs->IsInstance() || rhs->IsInstance()) { + if (!rhs->IsInstance() || !lhs->IsInstance()) return false; + return ExprEqual(Downcast(lhs), Downcast(rhs)); + } + if (const auto lhsm = lhs.as()) { + auto rhsm = rhs.as(); + if (!rhsm) return false; + if (lhsm->functions.size() != rhsm->functions.size()) return false; + for (const auto& p : lhsm->functions) { + if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false; + } + if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; + for (const auto& p : lhsm->type_definitions) { + if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) || + !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) { + return false; + } + } + return true; + } + // Fall back to the object equal case. + return AttrsEqualHandler::VisitAttr(lhs, rhs); + } /*! * \brief Check if data type equals each other. * \param lhs The left hand operand. diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py index a2be2b7..f4148ca 100644 --- a/tests/python/unittest/test_ir_attrs.py +++ b/tests/python/unittest/test_ir_attrs.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te +import tvm.ir._ffi_api def test_make_attrs(): try: @@ -50,6 +50,19 @@ def test_dict_attrs(): assert len(dattr.items()) == 4 +def test_attrs_equal(): + attr_equal = tvm.ir._ffi_api.AttrsEqual + dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20]) + dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1) + dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None) + assert attr_equal(dattr0, dattr1) + assert not attr_equal(dattr0, dattr2) + assert not attr_equal({"x": 1}, tvm.runtime.convert(1)) + assert not attr_equal([1, 2], tvm.runtime.convert(1)) + + + if __name__ == "__main__": test_make_attrs() test_dict_attrs() + test_attrs_equal() -- 2.7.4