[Relay][Compile_engine] Int64 shape handling for outputs. (#4031)
authorAnimesh Jain <anijain@umich.edu>
Mon, 30 Sep 2019 17:06:35 +0000 (10:06 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Mon, 30 Sep 2019 17:06:35 +0000 (10:06 -0700)
src/relay/backend/compile_engine.cc
tests/python/relay/test_backend_compile_engine.py

index c88703e..a75cdb2 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -219,6 +219,25 @@ class ScheduleGetter :
       CHECK_EQ(call_node->args.size(), 1U)
           << "Only allow function with a single tuple input";
     }
+
+    // Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is
+    // Int32. Following code ensures the same for the output as well.
+    // TODO(@icemelon): Support recursive tuple
+    Type call_node_type = call_node->checked_type();
+    if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
+      call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype);
+    } else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
+      std::vector<Type> new_fields;
+      for (auto field : tuple_t->fields) {
+        if (const auto* tt = field.as<TensorTypeNode>()) {
+          new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype));
+        } else {
+          new_fields.push_back(field);
+        }
+      }
+      call_node_type = TupleTypeNode::make(new_fields);
+    }
+
     CHECK(call_node->op.as<OpNode>())
         << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
@@ -232,7 +251,7 @@ class ScheduleGetter :
                                          Operation(), 0));
     } else {
       outputs = fcompute[op](call_node->attrs, inputs,
-                             call_node->checked_type(), target_);
+                             call_node_type, target_);
     }
 
     int op_pattern = fpattern[op];
index ea16a8d..b1f41a4 100644 (file)
@@ -79,8 +79,23 @@ def test_compile_tuple_dup():
     relay.build(relay.Module.from_expr(f), 'llvm')
 
 
+def test_compile_full():
+    # Shape calculations can happen in int64. The test checks that full operator
+    # can handle when shapes are not int32
+    shape = (tvm.expr.IntImm('int32', 1),
+             tvm.expr.IntImm('int64', 16),
+             tvm.expr.IntImm('int64', 16),
+             tvm.expr.IntImm('int32', 64))
+    output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
+    f = relay.Function([], output)
+    mod = relay.Module.from_expr(f)
+    mod = relay.qnn.transform.CanonicalizeOps()(mod)
+    relay.build(mod, 'llvm')
+
+
 if __name__ == "__main__":
     test_compile_engine()
     test_compile_placeholder_bypass()
     test_compile_injective_with_tuple()
     test_compile_tuple_dup()
+    test_compile_full()