improve deep equality check in alias annotation test (#15031)
authorMichael Suo <suo@fb.com>
Tue, 11 Dec 2018 19:44:27 +0000 (11:44 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 11 Dec 2018 20:14:00 +0000 (12:14 -0800)
Summary:
Previously we were returning true if either IValue wasn't a tensor, which…is bad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15031

Differential Revision: D13409759

Pulled By: suo

fbshipit-source-id: f8bdcd05d334c1276ce46f55812065d358c1ff5d

aten/src/ATen/core/ivalue.cpp
aten/src/ATen/core/ivalue.h
torch/csrc/jit/passes/utils/check_alias_annotation.cpp

index ac06360..73aaf3c 100644 (file)
@@ -80,4 +80,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
 
 #undef TORCH_FORALL_TAGS
 
+void IValue::dump() const {
+  std::cout << *this << "\n";
+}
+
 } // namespace c10
index 2191cc0..1c76d61 100644 (file)
@@ -134,6 +134,8 @@ struct CAFFE2_API IValue final {
     return *this;
   }
 
+  void dump() const;
+
   bool isAliasOf(const IValue& rhs) const {
     if (this->tag != rhs.tag) {
       // Trivially don't alias if the type is different
index 1602ac2..f81837d 100644 (file)
@@ -56,12 +56,19 @@ Stack deepCopy(const Stack& stack) {
 }
 
 bool deepEquals(const IValue& lhs, const IValue& rhs) {
-  // only check tensors for now
-  if (!lhs.isTensor() || !rhs.isTensor()) {
+  if (lhs.isInt() && rhs.isInt()) {
+    return lhs.toInt() == rhs.toInt();
+  } else if (lhs.isDouble() && rhs.isDouble()) {
+    return lhs.toDouble() == rhs.toDouble();
+  } else if (lhs.isNone() && rhs.isNone()) {
     return true;
+  } else if (lhs.isIntList() && rhs.isIntList()) {
+    return lhs.toIntList()->elements() == rhs.toIntList()->elements();
+  } else if (lhs.isTensor() && rhs.isTensor()) {
+    return lhs.toTensor().equal(rhs.toTensor());
   }
 
-  return lhs.toTensor().equal(rhs.toTensor());
+  throw std::runtime_error("Deep equals not implemented for type");
 }
 
 struct AliasAndIValue {
@@ -70,8 +77,8 @@ struct AliasAndIValue {
       const IValue& iValue)
       : aliasInfo(aliasInfo), iValue(iValue) {}
 
-  const c10::optional<at::AliasInfo>& aliasInfo;
-  const IValue& iValue;
+  const c10::optional<at::AliasInfo> aliasInfo;
+  const IValue iValue;
 };
 
 // No inputs should alias each other