From cbfec02007775d96139d8a1b9d9f8a44fcede31c Mon Sep 17 00:00:00 2001 From: Don Jang Date: Thu, 26 Aug 2021 12:58:05 -0700 Subject: [PATCH] [Static Runtime] Add native op for aten::expand_as (#64024) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64024 `aten::expand_as` creates a view of the input tensor. This change adds its native op implementation for the static runtime. Test Plan: - Added `StaticRuntime.IndividualOps_ExpandAs` Reviewed By: hlu1 Differential Revision: D30546851 fbshipit-source-id: e53483048af890bc41b6192a1ab0c5ba0ee2bdc0 --- benchmarks/static_runtime/test_scripts.h | 6 ++++++ benchmarks/static_runtime/test_static_runtime.cc | 11 +++++++++++ torch/csrc/jit/runtime/static/native_ops.cpp | 16 ++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h index 90f93b2..ecdd491 100644 --- a/benchmarks/static_runtime/test_scripts.h +++ b/benchmarks/static_runtime/test_scripts.h @@ -349,6 +349,12 @@ const std::string embedding_bag_max_last_offset = R"JIT( return torch.embedding_bag(a, b, c, False, 2, False, None, True) )JIT"; +const auto expand_as_script = R"JIT( + def forward(self, input: Tensor, other:Tensor): + a = input.expand_as(other) + return a.clone() +)JIT"; + const auto sign_tensor = R"JIT( def forward(self, input: Tensor): return torch.sign(input).clone() diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index f6ec677..4441b7d 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -610,6 +610,17 @@ TEST(StaticRuntime, IndividualOps_Detach) { testStaticRuntime(detach_script_1, args, args2); } +TEST(StaticRuntime, IndividualOps_ExpandAs) { + auto a = at::randn({3,1}); + auto b = at::randn({3,2}); + auto c = at::randn({4,1}); + auto d = at::randn({4,2}); + std::vector<IValue> args{a, b}; + std::vector<IValue> args2{c, d}; + 
testStaticRuntime(expand_as_script, args); + testStaticRuntime(expand_as_script, args, args2); +} + TEST(StaticRuntime, IndividualOps_Full) { auto dtype = at::ScalarType::Int; auto cpu = at::Device(DeviceType::CPU); diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 61a6554..7a1558d 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -371,6 +371,22 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( }); REGISTER_NATIVE_OPERATOR_FUNCTOR( + aten::expand_as, + aten_expand_as, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& self = p_node->Input(0).toTensor(); + const auto& other = p_node->Input(1).toTensor(); + p_node->Output(0) = self.expand(other.sizes()); + }; + }); + +REGISTER_NATIVE_OPERATOR_FUNCTOR( prim::isinstance, prim_isinstance, [](Node* n) -> SROperator { -- 2.7.4