%res: Tensor = aten::clone(%output, %none)
return (%res)
)IR";
+
+const auto linalg_norm_ord_scalar = R"JIT(
+ def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int):
+ return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
+)JIT";
+
+const auto linalg_norm_ord_str = R"JIT(
+ def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
+ return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
+)JIT";
testStaticRuntime(embedding_bag_byte_prepack_script, {a});
testStaticRuntime(embedding_bag_byte_prepack_script, {a},{b});
}
+
+TEST(StaticRuntime, IndividualOps_LinalgNorm_ScalarOrd) {
+ auto a = at::randn({2, 3});
+ auto dim = std::vector<int64_t>({1});
+ auto dtype = at::ScalarType::Float;
+
+ std::vector<IValue> args0{a, 4, dim, true, dtype};
+ testStaticRuntime(linalg_norm_ord_scalar, args0);
+
+ auto b = at::randn({4, 5});
+ std::vector<IValue> args1{b, 4, dim, true, dtype};
+ testStaticRuntime(linalg_norm_ord_scalar, args0, args1);
+}
+
+TEST(StaticRuntime, IndividualOps_LinalgNorm_StringOrd) {
+ auto a = at::randn({2, 3});
+ auto dim = std::vector<int64_t>({0, 1});
+ auto dtype = at::ScalarType::Float;
+
+ std::vector<IValue> args0{a, "fro", dim, true, dtype};
+ testStaticRuntime(linalg_norm_ord_str, args0);
+
+ auto b = at::randn({4, 5});
+ std::vector<IValue> args1{b, "fro", dim, true, dtype};
+ testStaticRuntime(linalg_norm_ord_str, args0, args1);
+}
};
});
+REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor")) &&
+ !n->matches(torch::schema(
+ "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& input = p_node->Input(0).toTensor();
+ const auto dim = p_node->Input(2).toIntVector();
+ const auto keepdim = p_node->Input(3).toBool();
+ const auto dtype = p_node->Input(4).toOptional<c10::ScalarType>();
+
+ if (p_node->Output(0).isNone()) {
+ if (p_node->Input(1).isScalar()) {
+ p_node->Output(0) = at::native::linalg_norm(
+ input,
+ p_node->Input(1).toOptional<at::Scalar>(),
+ dim,
+ keepdim,
+ dtype);
+ } else {
+ p_node->Output(0) = at::native::linalg_norm(
+ input, p_node->Input(1).toStringView(), dim, keepdim, dtype);
+ }
+ return;
+ }
+
+ auto& output = p_node->Output(0).toTensor();
+ fastResizeToZero(output);
+
+ if (p_node->Input(1).isScalar()) {
+ at::native::linalg_norm_out(
+ input,
+ p_node->Input(1).toOptional<at::Scalar>(),
+ dim,
+ keepdim,
+ dtype,
+ output);
+ } else {
+ at::native::linalg_norm_out(
+ input, p_node->Input(1).toStringRef(), dim, keepdim, dtype, output);
+ }
+ };
+});
+
namespace {
void check_cat_no_zero_dim(const std::vector<at::Tensor>& tensors) {