Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64024
`aten::expand_as` creates a view of the input tensor. This change adds its native op implementation for the static runtime.
Test Plan: - Added `StaticRuntime.IndividualOps_ExpandAs`
Reviewed By: hlu1
Differential Revision: D30546851
fbshipit-source-id: e53483048af890bc41b6192a1ab0c5ba0ee2bdc0
return torch.embedding_bag(a, b, c, False, 2, False, None, True)
)JIT";
+// TorchScript graph used by the ExpandAs test: expands `input` to the
+// shape of `other`, then clones so the value returned from the graph is
+// an owning tensor rather than a view of the graph input.
+const auto expand_as_script = R"JIT(
+ def forward(self, input: Tensor, other:Tensor):
+ a = input.expand_as(other)
+ return a.clone()
+)JIT";
+
const auto sign_tensor = R"JIT(
def forward(self, input: Tensor):
return torch.sign(input).clone()
testStaticRuntime(detach_script_1, args, args2);
}
+// Exercises the static runtime's native aten::expand_as implementation.
+// Two argument sets with different shapes are passed so the runtime is
+// also run a second time on inputs of a new size.
+TEST(StaticRuntime, IndividualOps_ExpandAs) {
+ auto a = at::randn({3,1}); // input with a size-1 dim, expandable to b's shape
+ auto b = at::randn({3,2}); // target shape for expand_as
+ auto c = at::randn({4,1}); // second input pair with different sizes
+ auto d = at::randn({4,2});
+ std::vector<IValue> args{a, b};
+ std::vector<IValue> args2{c, d};
+ testStaticRuntime(expand_as_script, args);
+ testStaticRuntime(expand_as_script, args, args2);
+}
+
TEST(StaticRuntime, IndividualOps_Full) {
auto dtype = at::ScalarType::Int;
auto cpu = at::Device(DeviceType::CPU);
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ aten::expand_as,
+ aten_expand_as,
+ [](Node* n) -> SROperator {
+ // Only accept the canonical schema; anything else is logged and the
+ // functor declines (nullptr) so the node falls back to the default path.
+ if (!n->matches(torch::schema(
+ "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& self = p_node->Input(0).toTensor();
+ const auto& other = p_node->Input(1).toTensor();
+ // expand_as(self, other) == self.expand(other.sizes()); the result is
+ // a view of `self` (no copy), consistent with the `(a)` alias
+ // annotation in the schema above.
+ p_node->Output(0) = self.expand(other.sizes());
+ };
+ });
+
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::isinstance,
prim_isinstance,
[](Node* n) -> SROperator {