[RELAY] Fix segfault in pretty print when ObjectRef is null (#5681)
authorlhutton1 <35535092+lhutton1@users.noreply.github.com>
Fri, 29 May 2020 14:51:30 +0000 (15:51 +0100)
committerGitHub <noreply@github.com>
Fri, 29 May 2020 14:51:30 +0000 (07:51 -0700)
* [RELAY] Fix segfault in pretty print when ObjectRef is null

Encountered when pretty printing module with function attribute equal to NullValue<ObjectRef>().

Change-Id: I2e7b304859f03038730ba9c3b9db41ebd3e1fbb5

* Add test case

Change-Id: I579b20da3f5d49054823392be80aaf78a055f596

src/printer/relay_text_printer.cc
tests/python/relay/test_ir_text_printer.py

index 076339d..ad16f86 100644 (file)
@@ -91,7 +91,8 @@ Doc RelayTextPrinter::PrintScope(const ObjectRef& node) {
 }
 
 Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
-  if (node->IsInstance<BaseFuncNode>() && !node->IsInstance<relay::FunctionNode>()) {
+  if (node.defined() && node->IsInstance<BaseFuncNode>() &&
+      !node->IsInstance<relay::FunctionNode>()) {
     // Temporarily skip non-relay functions.
     // TODO(tvm-team) enhance the code to work for all functions
   } else if (node.as<ExprNode>()) {
@@ -105,8 +106,8 @@ Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
 }
 
 Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
-  bool is_non_relay_func =
-      node->IsInstance<BaseFuncNode>() && !node->IsInstance<relay::FunctionNode>();
+  bool is_non_relay_func = node.defined() && node->IsInstance<BaseFuncNode>() &&
+                           !node->IsInstance<relay::FunctionNode>();
   if (node.as<ExprNode>() && !is_non_relay_func) {
     return PrintExpr(Downcast<Expr>(node), meta, try_inline);
   } else if (node.as<TypeNode>()) {
index 61dbca3..2a88c0c 100644 (file)
@@ -240,6 +240,15 @@ def @main[A]() -> fn (A, List[A]) -> List[A] {
     assert main_def_str.strip() in mod_str
 
 
+def test_null_attribute():
+    x = relay.var("x")
+    y = relay.var("y")
+    z = relay.Function([x], y)
+    z = z.with_attr("TestAttribute", None)
+    txt = astext(z)
+    assert "TestAttribute=(nullptr)" in txt
+
+
 if __name__ == "__main__":
     do_print[0] = True
     test_lstm()
@@ -262,3 +271,4 @@ if __name__ == "__main__":
     test_variable_name()
     test_call_node_order()
     test_unapplied_constructor()
+    test_null_attribute()