shape_func_op_(Op::Get("vm.shape_func")),
alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
alloc_storage_op_(Op::Get("memory.alloc_storage")),
- cast_op_(Op::Get("cast")) {}
+ cast_op_(Op::Get("cast")),
+ ndarray_size_op_(Op::Get("ndarray_size")) {}
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
return EvaluateShapeOf(res, origin_args, call->attrs);
}
+ if (call->op == ndarray_size_op_) {
+ return EvaluateNdarraySize(res, origin_args, call->attrs);
+ }
+
// We should think about potentially constant evaluation over these ops too.
if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ ||
call->op == alloc_storage_op_) {
const Op& alloc_tensor_op_;
const Op& alloc_storage_op_;
const Op& cast_op_;
+ const Op& ndarray_size_op_;
// Convert value to expression.
Expr ObjectToExpr(const ObjectRef& value) {
CHECK(param != nullptr);
tvm::Array<IndexExpr> ishape;
- if (const ConstantNode* op = input.as<ConstantNode>()) {
- ishape = op->tensor_type()->shape;
- } else if (input->checked_type_.defined()) {
- ishape = input->checked_type().as<TensorTypeNode>()->shape;
+ if (auto opt = GetConstantShape(input)) {
+ ishape = opt.value();
} else {
return expr;
}
shape = Constant(ndarray);
}
+ return CastValue(shape, param->dtype);
+ }
+
+ // Evaluate a call to the ndarray_size operator for tensors with constant
+ // shapes.
+ Expr EvaluateNdarraySize(Expr expr, Array<Expr> args, Attrs attrs) {
+ Expr input = args[0];
+ const auto* param = attrs.as<NdarraySizeAttrs>();
+ CHECK(param != nullptr);
+
+ tvm::Array<IndexExpr> ishape;
+ if (auto opt = GetConstantShape(input)) {
+ ishape = opt.value();
+ } else {
+ return expr;
+ }
+
+ // Get the constant size
+ DLContext ctx;
+ ctx.device_type = kDLCPU;
+ ctx.device_id = 0;
+ runtime::NDArray value;
+ DLDataType cdtype = DataType::Int(32);
+ value = runtime::NDArray::Empty({1}, cdtype, ctx);
+ int32_t* data = static_cast<int32_t*>(value->data);
+ if (ishape.size() == 0) {
+ *data = 0;
+ } else {
+ *data = 1;
+ using ::tvm::tir::IntImmNode;
+ for (size_t i = 0; i < ishape.size(); ++i) {
+ if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
+ *data *= dim->value;
+ } else {
+ return expr;
+ }
+ }
+ }
+
+ Constant size = Downcast<Constant>(ObjectToExpr(value));
+ return CastValue(size, param->dtype);
+ }
+
+ Expr CastValue(const Expr& value, DataType dtype) {
// Cast the constant into correct dtype
auto cast_attrs = make_object<CastAttrs>();
- cast_attrs->dtype = param->dtype;
- Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {});
+ cast_attrs->dtype = dtype;
+ Expr ret = Call(cast_op_, {value}, Attrs(cast_attrs), {});
return ConstEvaluate(ret);
}
+
+ Optional<tvm::Array<IndexExpr>> GetConstantShape(const Expr& input) {
+ tvm::Array<IndexExpr> ishape;
+ if (const ConstantNode* op = input.as<ConstantNode>()) {
+ ishape = op->tensor_type()->shape;
+ } else if (input->checked_type_.defined()) {
+ ishape = input->checked_type().as<TensorTypeNode>()->shape;
+ } else {
+ return Optional<tvm::Array<IndexExpr>>(nullptr);
+ }
+
+ return Optional<tvm::Array<IndexExpr>>(ishape);
+ }
};
Expr FoldConstant(const Expr& expr, const IRModule& mod) {
assert tvm.ir.structural_equal(zz, zexpected)
+def test_fold_ndarray_size():
+ c_shape = (8, 9, 10)
+ def before(dtype):
+ x = relay.var("x", shape=c_shape, dtype="float32")
+ y = relay.var("y", shape=c_shape, dtype="float32")
+ z = relay.ndarray_size(x + y, dtype)
+ return relay.Function([x, y], z)
+
+ def expected(dtype):
+ x = relay.var("x", shape=c_shape, dtype="float32")
+ y = relay.var("y", shape=c_shape, dtype="float32")
+ z = relay.const([np.size(np.zeros(c_shape))], dtype=dtype)
+ func = relay.Function([x, y], z)
+ return func
+
+ for dtype in ["int32", "float32"]:
+ zz = run_opt_pass(before(dtype), transform.FoldConstant())
+ zexpected = run_opt_pass(expected(dtype), transform.InferType())
+ assert tvm.ir.structural_equal(zz, zexpected)
+
+
def test_fold_full():
c_shape = (8, 9, 10)
def before():
test_fold_shape_of()
test_fold_full()
test_fold_batch_norm()
+ test_fold_ndarray_size()