};
});
-REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROperator {
- if (!n->matches(torch::schema(
- "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) {
- LogAndDumpSchema(n);
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- // ignore Input(5): `bool cudnn_enable=True`
- const auto& input = p_node->Input(0).toTensor();
- const auto normalized_shape = p_node->Input(1).toIntVector();
- auto weight_opt = p_node->Input(2).toOptional<at::Tensor>();
- auto bias_opt = p_node->Input(3).toOptional<at::Tensor>();
- float eps = p_node->Input(4).toDouble();
-
- c10::MaybeOwned<at::Tensor> weight_maybe_owned =
- at::borrow_from_optional_tensor(weight_opt);
- const at::Tensor& weight = *weight_maybe_owned;
- c10::MaybeOwned<at::Tensor> bias_maybe_owned =
- at::borrow_from_optional_tensor(bias_opt);
- const at::Tensor& bias = *bias_maybe_owned;
-
- auto M_N = at::native::_check_layer_norm_inputs(
- input, normalized_shape, weight, bias);
- auto M = M_N.first;
- auto N = M_N.second;
- auto X = input.expect_contiguous();
- auto gamma = weight.expect_contiguous();
- auto beta = bias.expect_contiguous();
-
- if (p_node->Output(0).isNone()) {
- p_node->Output(0) = at::native::empty_like(
- *X,
- c10::nullopt /* dtype */,
- c10::nullopt /* layout */,
- c10::nullopt /* device */,
- c10::nullopt /* pin_memory */,
- at::MemoryFormat::Contiguous);
- } else {
- at::native::resize_(
- p_node->Output(0).toTensor(), X->sizes(), c10::nullopt);
- }
- at::Tensor& output = p_node->Output(0).toTensor();
- at::Tensor mean = create_empty_from({M}, *X);
- at::Tensor rstd = create_empty_from({M}, *X);
+REGISTER_OPERATOR_FUNCTOR(
+ static_runtime::layer_norm,
+ aten_layer_norm,
+ [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor,Tensor,Tensor)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
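+ // Returning a null SROperator makes Static Runtime fall back to the
+ // JIT interpreter for this node.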
+ return [](ProcessedNode* p_node) {
+ // ignore Input(5): `bool cudnn_enable=True`
+ const auto& input = p_node->Input(0).toTensor();
+ const auto normalized_shape = p_node->Input(1).toIntVector();
+ auto weight_opt = p_node->Input(2).toOptional<at::Tensor>();
+ auto bias_opt = p_node->Input(3).toOptional<at::Tensor>();
+ float eps = p_node->Input(4).toDouble();
+
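+ // borrow_from_optional_tensor() borrows the tensor when the optional is
+ // set and otherwise yields an undefined tensor, avoiding a refcount bump.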
+ c10::MaybeOwned<at::Tensor> weight_maybe_owned =
+ at::borrow_from_optional_tensor(weight_opt);
+ const at::Tensor& weight = *weight_maybe_owned;
+ c10::MaybeOwned<at::Tensor> bias_maybe_owned =
+ at::borrow_from_optional_tensor(bias_opt);
+ const at::Tensor& bias = *bias_maybe_owned;
+
+ auto M_N = at::native::_check_layer_norm_inputs(
+ input, normalized_shape, weight, bias);
+ auto M = M_N.first;
+ auto N = M_N.second;
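+ // expect_contiguous() borrows the tensor if it is already contiguous
+ // and otherwise owns a freshly materialized contiguous copy.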
+ auto X = input.expect_contiguous();
+ auto gamma = weight.expect_contiguous();
+ auto beta = bias.expect_contiguous();
- at::native::layer_norm_cpu_out(
- output, mean, rstd, input, normalized_shape, *gamma, *beta, eps, M, N);
- };
-});
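+ // Static Runtime out-variant convention: allocate each output on the
+ // first invocation, then resize and reuse the same buffer on later runs.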
+ if (p_node->Output(0).isNone()) {
+ p_node->Output(0) = at::native::empty_like(
+ *X,
+ c10::nullopt /* dtype */,
+ c10::nullopt /* layout */,
+ c10::nullopt /* device */,
+ c10::nullopt /* pin_memory */,
+ at::MemoryFormat::Contiguous);
+ } else {
+ at::native::resize_(
+ p_node->Output(0).toTensor(), X->sizes(), c10::nullopt);
+ }
+ if (p_node->Output(1).isNone()) {
+ p_node->Output(1) = create_empty_from({M}, *X);
+ } else {
+ at::native::resize_(p_node->Output(1).toTensor(), {M}, c10::nullopt);
+ }
+ if (p_node->Output(2).isNone()) {
+ p_node->Output(2) = create_empty_from({M}, *X);
+ } else {
+ at::native::resize_(p_node->Output(2).toTensor(), {M}, c10::nullopt);
+ }
+ at::Tensor& output = p_node->Output(0).toTensor();
+ at::Tensor mean = p_node->Output(1).toTensor();
+ at::Tensor rstd = p_node->Output(2).toTensor();
+ at::native::layer_norm_cpu_out(
+ output,
+ mean,
+ rstd,
+ input,
+ normalized_shape,
+ *gamma,
+ *beta,
+ eps,
+ M,
+ N);
+ };
+ });
REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
@@ ... @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m)
m.def(
"static_runtime::to_copy.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
m.def(
"static_runtime::to_copy.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
+ m.def(torch::schema(
+ "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)",
+ c10::AliasAnalysisKind::PURE_FUNCTION));
}
bool HasInplaceOp(std::shared_ptr<Graph>& graph, const AliasDb& alias_db) {
@@ ... @@
#endif
}
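+// Rewrite aten::layer_norm nodes into static_runtime::layer_norm. The
+// variant exposes mean and rstd as extra outputs, so Static Runtime's
+// memory planner can manage those buffers instead of the kernel
+// reallocating them on every iteration (see the removed
+// create_empty_from({M}, *X) calls above).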
+void EnableStaticRuntimeLayerNorm(std::shared_ptr<torch::jit::Graph>& graph) {
+ const c10::Symbol static_runtime_layer_norm_symbol =
+ c10::Symbol::fromQualString("static_runtime::layer_norm");
+ auto nodes = graph->nodes();
+ std::vector<std::pair<Node*, Node*>> replacement;
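+ // Collect (old, new) pairs first; destroying nodes while iterating over
+ // graph->nodes() would invalidate the iterator.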
+ for (auto it = nodes.begin(); it != nodes.end(); ++it) {
+ Node* old_node = *it;
+ if (!old_node->matches(torch::schema(
+ "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) {
+ continue;
+ }
+ TORCH_CHECK(old_node->outputs().size() == 1);
+ auto* new_node = graph->create(
+ static_runtime_layer_norm_symbol,
+ /*layer_norm*/ 1 + /*mean*/ 1 + /*rstd*/ 1);
+ new_node->insertBefore(old_node);
+ for (auto* input : old_node->inputs()) {
+ new_node->addInput(input);
+ }
+ replacement.emplace_back(old_node, new_node);
+ }
+ for (const auto& p : replacement) {
+ auto* old_node = p.first;
+ auto* new_node = p.second;
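+ // Only output 0 (the normalized tensor) takes over the old node's uses;
+ // outputs 1 and 2 (mean, rstd) remain unused in the graph but are now
+ // visible to the memory planner.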
+ new_node->output(0)->copyMetadata(old_node->output(0));
+ old_node->output(0)->replaceAllUsesWith(new_node->output(0));
+ old_node->destroy();
+ }
+}
+
} // namespace jit
} // namespace torch