[Bugfix][IR][ATTRS] Fix AttrEqual for Array and StrMap, double (#5054)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 12 Mar 2020 23:29:06 +0000 (16:29 -0700)
committerGitHub <noreply@github.com>
Thu, 12 Mar 2020 23:29:06 +0000 (16:29 -0700)
- Use fuzzy comparison for double.
- Removed the hack for BatchNormAttrs and DictAttr.

Also removed a warning from text printer printing.

include/tvm/ir/attrs.h
src/ir/attrs.cc
src/printer/doc.cc
src/printer/doc.h
src/printer/meta_data.h
src/relay/analysis/alpha_equal.cc
tests/python/unittest/test_ir_attrs.py

index 899db08..4413fc3 100644 (file)
@@ -143,8 +143,13 @@ class AttrsEqualHandler;
 class AttrsEqual {
  public:
   bool operator()(const double& lhs, const double& rhs) const {
-    return lhs == rhs;
+    // fuzzy float pt comparison
+    constexpr double atol = 1e-9;
+    if (lhs == rhs) return true;
+    double diff = lhs - rhs;
+    return diff > -atol && diff < atol;
   }
+
   bool operator()(const int64_t& lhs, const int64_t& rhs) const {
     return lhs == rhs;
   }
index 4c4c997..868fec6 100644 (file)
@@ -79,7 +79,8 @@ using namespace tir;
 // Equal handler.
 bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
   if (lhs.same_as(rhs)) return true;
-  if (!lhs.defined() || !rhs.defined()) return false;
+  if (!lhs.defined() && rhs.defined()) return false;
+  if (!rhs.defined() && lhs.defined()) return false;
   return this->VisitAttr(lhs, rhs);
 }
 
@@ -96,22 +97,25 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& ot
 bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<IntImmNode>()) {
     return lhs->value == rhs->value;
+  } else {
+    return false;
   }
-  return false;
 }
 
 bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<FloatImmNode>()) {
     return lhs->value == rhs->value;
+  } else {
+    return false;
   }
-  return false;
 }
 
 bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<StringImmNode>()) {
     return lhs->value == rhs->value;
+  } else {
+    return false;
   }
-  return false;
 }
 
 bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
@@ -120,8 +124,10 @@ bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other)
     for (size_t i = 0; i < lhs->data.size(); ++i) {
       if (!Equal(lhs->data[i], rhs->data[i])) return false;
     }
+    return true;
+  } else {
+    return false;
   }
-  return true;
 }
 
 bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
@@ -132,8 +138,10 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other
       if (it == rhs->data.end()) return false;
       if (!Equal(kv.second, it->second)) return false;
     }
+    return true;
+  } else {
+    return false;
   }
-  return true;
 }
 
 #define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName)                          \
@@ -340,8 +348,13 @@ bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
 }
 
 TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = args[0].operator Attrs()->ListFieldInfo();
+.set_body_typed([](Attrs attrs) {
+  return attrs->ListFieldInfo();
+});
+
+TVM_REGISTER_GLOBAL("ir.AttrsEqual")
+.set_body_typed([](ObjectRef lhs, ObjectRef rhs) {
+  return AttrsEqual()(lhs, rhs);
 });
 
 }  // namespace tvm
index c5595db..ee260f4 100644 (file)
@@ -40,9 +40,6 @@ class DocTextNode : public DocAtomNode {
 
   explicit DocTextNode(std::string str_val)
       : str(str_val) {
-    if (str.find_first_of("\t\n") != str.npos) {
-      LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
-    }
   }
 
   static constexpr const char* _type_key = "printer.DocText";
@@ -54,6 +51,9 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode);
 class DocText : public DocAtom {
  public:
   explicit DocText(std::string str) {
+    if (str.find_first_of("\t\n") != str.npos) {
+      LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
+    }
     data_ = runtime::make_object<DocTextNode>(str);
   }
 
@@ -125,6 +125,10 @@ Doc Doc::Text(std::string text) {
   return Doc() << DocText(text);
 }
 
+Doc Doc::RawText(std::string text) {
+  return Doc() << DocAtom(runtime::make_object<DocTextNode>(text));
+}
+
 Doc Doc::Indent(int indent, Doc doc) {
   for (size_t i = 0; i < doc.stream_.size(); ++i) {
     if (auto* line = doc.stream_[i].as<DocLineNode>()) {
index 34a284b..7d8d72e 100644 (file)
@@ -111,6 +111,11 @@ class Doc {
    */
   static Doc Text(std::string value);
   /*!
+   * \brief Create a doc that represents raw text(can have new lines)
+   * \return The created doc.
+   */
+  static Doc RawText(std::string value);
+  /*!
    * \brief Create a doc that represents a new line.
    * \return The created doc.
    */
index 6c300fd..d390692 100644 (file)
@@ -121,7 +121,7 @@ class TextMetaDataContext {
    */
   Doc GetMetaSection() const {
     if (meta_data_.size() == 0) return Doc();
-    return Doc::Text(
+    return Doc::RawText(
         SaveJSON(Map<std::string, ObjectRef>(meta_data_.begin(), meta_data_.end())));
   }
 
index 8a07a19..726ccbb 100644 (file)
@@ -30,6 +30,8 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/attrs/nn.h>
 #include "../../ir/attr_functor.h"
+
+
 namespace tvm {
 namespace relay {
 
@@ -50,37 +52,7 @@ class AlphaEqualHandler:
    * \return The comparison result.
    */
   bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
-    if (!lhs.defined() || !rhs.defined()) return false;
-    if (lhs.same_as(rhs)) return true;
-    if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
-      if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
-      return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
-    }
-    if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
-      if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
-      return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
-    }
-    if (const auto lhsm = lhs.as<IRModuleNode>()) {
-      auto rhsm = rhs.as<IRModuleNode>();
-      if (!rhsm) return false;
-      if (lhsm->functions.size() != rhsm->functions.size()) return false;
-      for (const auto& p : lhsm->functions) {
-        if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
-      }
-      if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
-      for (const auto& p : lhsm->type_definitions) {
-        if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
-            !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
-          return false;
-        }
-      }
-      return true;
-    }
-    return AttrEqual(lhs, rhs);
-  }
-
-  bool DoubleEqual(double l, double r) {
-    return true;
+    return VisitAttr(lhs, rhs);
   }
   /*!
    * Check equality of two attributes.
@@ -90,25 +62,7 @@ class AlphaEqualHandler:
    */
   bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
     auto compute = [&]() {
-      if (&lhs == &rhs) return true;
-      if (auto lhsd = lhs.as<DictAttrsNode>()) {
-        auto rhsd = rhs.as<DictAttrsNode>();
-        if (!rhsd) return false;
-        if (lhsd->dict.size() != rhsd->dict.size()) return false;
-        for (const auto& k : lhsd->dict) {
-          if (!Equal(k.second, rhsd->dict[k.first])) return false;
-        }
-        return true;
-      }
-      if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
-        auto rhsbn = rhs.as<BatchNormAttrs>();
-        if (!rhsbn) return false;
-        return (lhsbn->axis == rhsbn->axis)
-          && DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
-          && (lhsbn->center == rhsbn->center)
-          && (lhsbn->scale == rhsbn->scale);
-      }
-      return AttrsEqualHandler::Equal(lhs, rhs);
+      return VisitAttr(lhs, rhs);
     };
     return Compare(compute(), lhs, rhs);
   }
@@ -164,6 +118,40 @@ class AlphaEqualHandler:
   }
 
  protected:
+  // So that the new definition of equality in relay can be handled directly.
+  // Specifically, if a DictAttr contains a value defined by a relay AST.
+  // We want to able to recursively check the equality in the attr defined by the relay AST.
+  bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final {
+    if (lhs.same_as(rhs)) return true;
+    if (!lhs.defined() && rhs.defined()) return false;
+    if (!rhs.defined() && lhs.defined()) return false;
+    if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
+      if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
+      return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
+    }
+    if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
+      if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
+      return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
+    }
+    if (const auto lhsm = lhs.as<IRModuleNode>()) {
+      auto rhsm = rhs.as<IRModuleNode>();
+      if (!rhsm) return false;
+      if (lhsm->functions.size() != rhsm->functions.size()) return false;
+      for (const auto& p : lhsm->functions) {
+        if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
+      }
+      if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
+      for (const auto& p : lhsm->type_definitions) {
+        if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
+            !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
+          return false;
+        }
+      }
+      return true;
+    }
+    // Fall back to the object equal case.
+    return AttrsEqualHandler::VisitAttr(lhs, rhs);
+  }
   /*!
    * \brief Check if data type equals each other.
    * \param lhs The left hand operand.
index a2be2b7..f4148ca 100644 (file)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
+import tvm.ir._ffi_api
 
 def test_make_attrs():
     try:
@@ -50,6 +50,19 @@ def test_dict_attrs():
     assert len(dattr.items()) == 4
 
 
+def test_attrs_equal():
+    attr_equal = tvm.ir._ffi_api.AttrsEqual
+    dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20])
+    dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1)
+    dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None)
+    assert attr_equal(dattr0, dattr1)
+    assert not attr_equal(dattr0, dattr2)
+    assert not attr_equal({"x": 1}, tvm.runtime.convert(1))
+    assert not attr_equal([1, 2], tvm.runtime.convert(1))
+
+
+
 if __name__ == "__main__":
     test_make_attrs()
     test_dict_attrs()
+    test_attrs_equal()