* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
CHECK_EQ(call_node->args.size(), 1U)
<< "Only allow function with a single tuple input";
}
+
+ // Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is
+ // Int32. Following code ensures the same for the output as well.
+ // TODO(@icemelon): Support recursive tuple
+ Type call_node_type = call_node->checked_type();
+ if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
+ call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype);
+ } else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
+ std::vector<Type> new_fields;
+ for (auto field : tuple_t->fields) {
+ if (const auto* tt = field.as<TensorTypeNode>()) {
+ new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype));
+ } else {
+ new_fields.push_back(field);
+ }
+ }
+ call_node_type = TupleTypeNode::make(new_fields);
+ }
+
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
Operation(), 0));
} else {
outputs = fcompute[op](call_node->attrs, inputs,
- call_node->checked_type(), target_);
+ call_node_type, target_);
}
int op_pattern = fpattern[op];
relay.build(relay.Module.from_expr(f), 'llvm')
+def test_compile_full():
+ # Shape calculations can happen in int64. The test checks that full operator
+ # can handle when shapes are not int32
+ shape = (tvm.expr.IntImm('int32', 1),
+ tvm.expr.IntImm('int64', 16),
+ tvm.expr.IntImm('int64', 16),
+ tvm.expr.IntImm('int32', 64))
+ output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
+ f = relay.Function([], output)
+ mod = relay.Module.from_expr(f)
+ mod = relay.qnn.transform.CanonicalizeOps()(mod)
+ relay.build(mod, 'llvm')
+
+
if __name__ == "__main__":
test_compile_engine()
test_compile_placeholder_bypass()
test_compile_injective_with_tuple()
test_compile_tuple_dup()
+ test_compile_full()