PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
Interpreter(IRModule mod, DLContext context, Target target)
- : mod_(mod),
- context_(context),
- target_(target),
- debug_op_(Op::Get("debug")),
- shape_of_op_(Op::Get("shape_of")) {
+ : mod_(mod), context_(context), target_(target), debug_op_(Op::Get("debug")) {
engine_ = CompileEngine::Global();
}
Array<Shape> out_shapes;
auto ret_type = func->body->checked_type();
- bool is_dyn = IsDynamic(func->checked_type());
- if (call_node->op == shape_of_op_) {
- // The output shape of shape_of must be static since Relay doesn't support
- // dynamic rank tensors.
- is_dyn = false;
- }
+ bool is_dyn = IsDynamic(ret_type);
if (is_dyn) {
CHECK(func->HasNonzeroAttr(attr::kPrimitive));
CompileEngine engine_;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
- const Op& shape_of_op_;
};
TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, DLContext context, Target target) {
assert result.asnumpy().shape == ref_out_shape, \
"Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))
-def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
+def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
layout, static_boxes, static_box_indices_shape, ref_out_shape):
mod = tvm.IRModule()
dtype = "float32"
static_data_shape=(1, 256, 232, 232),
ref_out_shape=(1, 256, 234, 234))
+def verify_any_ndarray_size(data_np_shape):
+ v = relay.var("v", shape=any_dims(len(data_np_shape)), dtype='float32')
+ n = relay.ndarray_size(v, dtype='int32')
+ mod = tvm.IRModule()
+ mod['main'] = relay.Function([v], n)
+ np_data = np.zeros(data_np_shape, dtype='float32')
+ ref_res = np.size(np_data)
+
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(np_data)
+ tvm.testing.assert_allclose(result.asnumpy(), ref_res)
+
+def test_any_ndarray_size():
+ verify_any_ndarray_size((2,))
+ verify_any_ndarray_size((2, 2))
+ verify_any_ndarray_size((1, 2, 3, 4))
+
if __name__ == "__main__":
test_any_full()
test_any_full_like()
test_mixed_input_type()
test_any_crop_and_resize()
test_any_mirror_pad()
-
+ test_any_ndarray_size()