From ea0638886d288a00450659134207446c06899d7e Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Wed, 8 Apr 2020 14:20:16 -0700 Subject: [PATCH] [BUGFIX][IR] Fix String SEqual (#5275) * fix String SEqual * retrigger ci --- src/node/container.cc | 4 ++-- tests/python/relay/test_ir_structural_equal_hash.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/node/container.cc b/src/node/container.cc index 0cd093e..8fff151 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -43,8 +43,8 @@ struct StringObjTrait { SEqualReducer equal) { if (lhs == rhs) return true; if (lhs->size != rhs->size) return false; - if (lhs->data != rhs->data) return true; - return std::memcmp(lhs->data, rhs->data, lhs->size) != 0; + if (lhs->data == rhs->data) return true; + return std::memcmp(lhs->data, rhs->data, lhs->size) == 0; } }; diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index 5295e17..271960e 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -356,7 +356,7 @@ def test_function_attr(): p00 = relay.subtract(z00, w01) q00 = relay.multiply(p00, w02) func0 = relay.Function([x0, w00, w01, w02], q00) - func0 = func0.with_attr("FuncName", tvm.tir.StringImm("a")) + func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a")) x1 = relay.var('x1', shape=(10, 10)) w10 = relay.var('w10', shape=(10, 10)) @@ -366,7 +366,7 @@ def test_function_attr(): p10 = relay.subtract(z10, w11) q10 = relay.multiply(p10, w12) func1 = relay.Function([x1, w10, w11, w12], q10) - func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b")) + func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b")) assert not consistent_equal(func0, func1) @@ -698,7 +698,7 @@ def test_fn_attribute(): d = relay.var('d', shape=(10, 10)) add_1 = relay.add(c, d) add_1_fn = relay.Function([c, d], add_1) - add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test")) + add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.runtime.container.String("test")) add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType()) assert not consistent_equal(add_1_fn, add_fn) -- 2.7.4