[Static Runtime] Add native op for aten::expand_as (#64024)
author: Don Jang <djang@fb.com>
Thu, 26 Aug 2021 19:58:05 +0000 (12:58 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 20:05:53 +0000 (13:05 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64024

`aten::expand_as` creates a view of the input tensor. This change adds its native op implementation for the static runtime.

Test Plan: - Added `StaticRuntime.IndividualOps_ExpandAs`

Reviewed By: hlu1

Differential Revision: D30546851

fbshipit-source-id: e53483048af890bc41b6192a1ab0c5ba0ee2bdc0

benchmarks/static_runtime/test_scripts.h
benchmarks/static_runtime/test_static_runtime.cc
torch/csrc/jit/runtime/static/native_ops.cpp

index 90f93b2..ecdd491 100644 (file)
@@ -349,6 +349,12 @@ const std::string embedding_bag_max_last_offset = R"JIT(
       return torch.embedding_bag(a, b, c, False, 2, False, None, True)
 )JIT";
 
+const auto expand_as_script = R"JIT(
+  def forward(self, input: Tensor, other: Tensor):
+      a = input.expand_as(other)
+      return a.clone()
+)JIT"; // clone() so the graph returns an owning tensor rather than a view of the input
+
 const auto sign_tensor = R"JIT(
   def forward(self, input: Tensor):
       return torch.sign(input).clone()
index f6ec677..4441b7d 100644 (file)
@@ -610,6 +610,17 @@ TEST(StaticRuntime, IndividualOps_Detach) {
   testStaticRuntime(detach_script_1, args, args2);
 }
 
+TEST(StaticRuntime, IndividualOps_ExpandAs) {
+  auto a = at::randn({3,1});  // self: size-1 dim will be broadcast
+  auto b = at::randn({3,2});  // other: target shape {3,2}
+  auto c = at::randn({4,1});  // second arg set with different sizes
+  auto d = at::randn({4,2});
+  std::vector<IValue> args{a, b};
+  std::vector<IValue> args2{c, d};
+  testStaticRuntime(expand_as_script, args);  // single-shape run
+  testStaticRuntime(expand_as_script, args, args2);  // re-run with new shapes — presumably exercises output resizing across iterations; confirm against testStaticRuntime
+}
+
 TEST(StaticRuntime, IndividualOps_Full) {
   auto dtype = at::ScalarType::Int;
   auto cpu = at::Device(DeviceType::CPU);
index 61a6554..7a1558d 100644 (file)
@@ -371,6 +371,22 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
     });
 
 REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::expand_as,
+    aten_expand_as,
+    [](Node* n) -> SROperator {  // native (non-out-variant) op: output is a view of the input (per this change's summary)
+      if (!n->matches(torch::schema(
+              "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
+        LogAndDumpSchema(n);  // unexpected schema: log it and decline to provide an SROperator
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& self = p_node->Input(0).toTensor();
+        const auto& other = p_node->Input(1).toTensor();
+        p_node->Output(0) = self.expand(other.sizes());  // expand_as(self, other) == self.expand(other.sizes()); only other's sizes are used, not its values
+      };
+    });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
     prim::isinstance,
     prim_isinstance,
     [](Node* n) -> SROperator {