From: Wanchao Liang Date: Fri, 12 Apr 2019 21:24:37 +0000 (-0700) Subject: JIT Layernorm fusion (#18266) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~248 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a3d3008e7377002c4d2015d73f6e047dbf21545c;p=platform%2Fupstream%2Fpytorch.git JIT Layernorm fusion (#18266) Summary: Partially fuse layer_norm by decomposing layer_norm into the batchnorm kernel that computes the stats, and then fusing the affine operations after the reduce operations, this is similar to the batchnorm fusion that apaszke did, it also only works in inference mode now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18266 Differential Revision: D14879877 Pulled By: wanchaol fbshipit-source-id: 0197d8f2a17ec438d3e53f4c411d759c1ae81efe --- diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 9f45947..f19398b 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -428,43 +428,56 @@ class TestFuser(JitTestCase): @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @skipIfRocm - def test_fuse_batch_norm(self): - + def test_fuse_decompose_normalization(self): class ResLike(torch.jit.ScriptModule): - def __init__(self, optimize=True): + def __init__(self, norm_module, optimize=True): super(ResLike, self).__init__(optimize) - self.bn = nn.BatchNorm2d(16) + self.nm = norm_module @torch.jit.script_method def forward(self, x, y): - return y + torch.relu(self.bn(x)) - - model = ResLike().cuda() - model_noopt = ResLike(optimize=False).cuda() - model_noopt.load_state_dict(model.state_dict()) - x = torch.randn(2, 16, 8, 8, device='cuda') - y = torch.randn(2, 16, 8, 8, device='cuda') - # FIXME: We need differentiation for CNNs for this optimization to trigger - with torch.no_grad(): - out = model(x, y) - graph = model.graph_for(x, y) - rep = str(graph) - - out_noopt = model_noopt(x, y) - rep_noopt = str(model_noopt.graph_for(x, y)) - self.assertEqual(out, out_noopt, prec=3e-5) - - # Check that batch_norm has really been decomposed - self.assertIn('aten::batch_norm_update_stats', rep) - self.assertNotIn('aten::batch_norm(', rep) - self.assertIn('aten::batch_norm(', rep_noopt) - - # Make sure the fusion group is big, and contains aten::sqrt, which could - # originate only from decomposing batch_norm in this case - fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup'] - self.assertEqual(len(fusion_groups), 1) - fused_graph = fusion_groups[0].g('Subgraph') - self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes())) + return y + torch.relu(self.nm(x)) + + def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph): + model = ResLike(nm).cuda() + model_noopt = ResLike(nm, optimize=False).cuda() + model_noopt.load_state_dict(model.state_dict()) + x = torch.randn(2, 16, 8, 8, device='cuda') + y = torch.randn(2, 16, 8, 8, device='cuda') + + # FIXME: We need differentiation for CNNs for this optimization to trigger + with torch.no_grad(): + out = model(x, y) + graph = model.graph_for(x, y) + rep = str(graph) + + out_noopt = model_noopt(x, y) + rep_noopt = str(model_noopt.graph_for(x, y)) + self.assertEqual(out, out_noopt, prec=3e-5) + + # Check that normalization op has really been decomposed + for node_in_graph in in_opt_graph: + self.assertIn(node_in_graph, rep) + + for node_not_in_graph in not_in_opt_graph: + self.assertNotIn(node_not_in_graph, rep) + self.assertIn(node_not_in_graph, rep_noopt) + + fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup'] + self.assertEqual(len(fusion_groups), 1) + fused_graph = str(fusion_groups[0].g('Subgraph')) + for node_in_fusegraph in in_fusegraph: + self.assertIn(node_in_fusegraph, fused_graph) + + # test for batchnorm decompose + bm = nn.BatchNorm2d(16) + test_norm_decompose(bm, ['aten::batch_norm_update_stats'], + ['aten::batch_norm('], ['aten::sqrt']) + + # test for layernorm decompose + lm = nn.LayerNorm(8) + test_norm_decompose(lm, ['aten::batch_norm_stats'], + ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::addcmul']) @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index e7d5f94..20a3373 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -109,7 +109,8 @@ bool isSimpleMap(Node* node) { return false; } for (Value* input : node->inputs()) { - if (input->type()->isSubtypeOf(TensorType::get()) || input->type()->isSubtypeOf(FloatType::get())) { + if (input->type()->isSubtypeOf(TensorType::get()) || + input->type()->isSubtypeOf(FloatType::get())) { continue; } if (input->node()->kind() != prim::Constant) { @@ -133,6 +134,23 @@ RegisterOperators reg_bn_unsqueeze({Operator( }; })}); +RegisterOperators reg_ln_view({Operator( + "aten::_ncf_view(Tensor self, int[] input_shape, int normalized_ndim) -> Tensor", + [](const Node* node) { + return [](Stack& stack) { + const int64_t normalized_ndim = pop(stack).toInt(); + auto input_shape = pop(stack).toIntListRef(); + auto self = pop(stack).toTensor(); + const int64_t input_ndim = input_shape.size(); + c10::SmallVector sizes(input_ndim, 1); + for (int i = 0; i < input_ndim - normalized_ndim; ++i) { + sizes.at(i) = input_shape[i]; + } + push(stack, self.reshape(sizes)); + return 0; + }; + })}); + // Yes, no, or no value if we can't tell c10::optional isDefined(Value* tensor) { if (tensor->type()->isSubtypeOf(TensorType::get())) { @@ -144,16 +162,20 @@ c10::optional isDefined(Value* tensor) { return {}; } -bool isFusableBatchNorm(Node* batch_norm) { - if (!batch_norm->matches( - "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) { - return false; +bool isFusableNorm(Node* normalize_op) { + static const OperatorSet decomposable_normalization_ops = { + "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", + "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor", + }; + + if (decomposable_normalization_ops.find(normalize_op)) { + // If we can't determine if weight and bias is defined statically there's + // really no point in decomposing normalization into simpler ops, since it + // won't get fused into a single kernel. + return isDefined(normalize_op->namedInput(attr::weight)).has_value() && + isDefined(normalize_op->namedInput(attr::bias)).has_value(); } - // If we can't determine if weight and bias is defined statically there's - // really no point in decomposing batch norm into simpler ops, since it won't - // get fused into a single kernel. - return isDefined(batch_norm->namedInput(attr::weight)).has_value() && - isDefined(batch_norm->namedInput(attr::bias)).has_value(); + return false; } Value* broadcastSizes(at::ArrayRef sizes) { @@ -187,7 +209,7 @@ struct GraphFuser { } bool isFusable(Node* node) { - return isFusableMap(node) || isFusableBatchNorm(node); + return isFusableMap(node) || isFusableNorm(node); } bool isFusableMap(Node* node) { @@ -249,13 +271,35 @@ struct GraphFuser { return *n->g(attr::Subgraph); } - void decomposeBatchNorm(Node* batch_norm) { - static std::shared_ptr bn_graph; - static std::once_flag flag; + Value* decomposeCommonNormalization( + Node* normalization_op, + const char* source, + const std::string& method_name, + const std::vector& inputs) { + std::shared_ptr nm_graph; + std::once_flag flag; std::call_once( flag, - [](std::shared_ptr* graph_ptr) { - static const char* source = R"SCRIPT( + [](std::shared_ptr* graph_ptr, + const char* source, + const std::string& method_name) { + script::CompilationUnit cu; + cu.define(source, script::nativeResolver, nullptr); + *graph_ptr = cu.get_function(method_name).graph(); + }, + &nm_graph, + source, + method_name); + + AT_ASSERT(isFusableNorm(normalization_op)); + WithInsertPoint insert_guard{normalization_op}; + Value* new_output = + SubgraphUtils::inlineGraph(nm_graph, inputs, normalization_op).at(0); + return new_output; + } + + void decomposeNormalizationOps(Node* normalization_op) { + static const char* bm_source = R"SCRIPT( def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor: if training: norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum) @@ -264,41 +308,74 @@ struct GraphFuser { norm_var = torch._unwrap_optional(running_var) norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim()) norm_var = torch._ncf_unsqueeze(norm_var, input.dim()) - norm_invstd = 1 / (eps + torch.sqrt(norm_var)) + norm_invstd = 1 / (torch.sqrt(norm_var + eps)) return ((input - norm_mean) * norm_invstd) )SCRIPT"; - script::CompilationUnit cu; - cu.define(source, script::nativeResolver, nullptr); - *graph_ptr = cu.get_function("batch_norm").graph(); - }, - &bn_graph); - - AT_ASSERT(isFusableBatchNorm(batch_norm)); - WithInsertPoint insert_guard{batch_norm}; - Value* input = batch_norm->namedInput(attr::input); - Value* input_dim = graph_->insert(aten::dim, {input}); - std::vector inputs{input, - batch_norm->namedInput(attr::running_mean), - batch_norm->namedInput(attr::running_var), - batch_norm->namedInput(attr::training), - batch_norm->namedInput(attr::momentum), - batch_norm->namedInput(attr::eps)}; - Value* new_output = - SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0); - auto weight = batch_norm->namedInput(attr::weight); - auto bias = batch_norm->namedInput(attr::bias); - if (isDefined(weight).value()) { - Value* expanded_weight = - graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim}); - new_output = graph_->insert(aten::mul, {new_output, expanded_weight}); - } - if (isDefined(bias).value()) { - Value* expanded_bias = - graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim}); - new_output = graph_->insert(aten::add, {new_output, expanded_bias}); + static const char* lm_source = R"SCRIPT( + def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor: + input_ndim = input.dim() + normalized_ndim = len(normalized_shape) + n = 1 + for i in range(input_ndim - normalized_ndim): + n *= input.size(i) + input_reshape = input.contiguous().view(1, n, -1) + mean, invstd = torch.batch_norm_stats(input_reshape, eps) + input_shape = input.size() + mean = torch._ncf_view(mean, input_shape, normalized_ndim) + invstd = torch._ncf_view(invstd, input_shape, normalized_ndim) + + return (input - mean) * invstd + )SCRIPT"; + Value* input = normalization_op->namedInput(attr::input); + if (normalization_op->kind() == aten::batch_norm) { + Value* input_dim = graph_->insert(aten::dim, {input}); + std::vector inputs{ + input, + normalization_op->namedInput(attr::running_mean), + normalization_op->namedInput(attr::running_var), + normalization_op->namedInput(attr::training), + normalization_op->namedInput(attr::momentum), + normalization_op->namedInput(attr::eps)}; + + Value* new_output = decomposeCommonNormalization( + normalization_op, bm_source, "batch_norm", inputs); + auto weight = normalization_op->namedInput(attr::weight); + auto bias = normalization_op->namedInput(attr::bias); + if (isDefined(weight).value()) { + Value* expanded_weight = + graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim}); + new_output = graph_->insert(aten::mul, {new_output, expanded_weight}); + } + if (isDefined(bias).value()) { + Value* expanded_bias = + graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim}); + new_output = graph_->insert(aten::add, {new_output, expanded_bias}); + } + normalization_op->output()->replaceAllUsesWith(new_output); + normalization_op->destroy(); + + } else if (normalization_op->kind() == aten::layer_norm) { + std::vector inputs{ + input, + normalization_op->namedInput(attr::normalized_shape), + normalization_op->namedInput(attr::eps), + normalization_op->namedInput(attr::cudnn_enable)}; + Value* new_output = decomposeCommonNormalization( + normalization_op, lm_source, "layer_norm", inputs); + auto weight = normalization_op->namedInput(attr::weight); + auto bias = normalization_op->namedInput(attr::bias); + auto weight_defined = isDefined(weight).value(); + auto bias_defined = isDefined(bias).value(); + if (weight_defined && bias_defined) { + new_output = graph_->insert(aten::addcmul, {bias, new_output, weight}); + } else if (weight_defined) { + new_output = graph_->insert(aten::mul, {new_output, weight}); + } else if (bias_defined) { + new_output = graph_->insert(aten::add, {new_output, bias}); + } + normalization_op->output()->replaceAllUsesWith(new_output); + normalization_op->destroy(); } - batch_norm->output()->replaceAllUsesWith(new_output); - batch_norm->destroy(); } void mergeFusionGroups(Node* consumer_group, Node* producer_group) { @@ -390,9 +467,10 @@ struct GraphFuser { group->insertInput(tensor_insert_idx, input); tensor_insert_idx++; } else if ( - (input->type()->isSubtypeOf(FloatType::get()) && input->node()->kind() != prim::Constant) || - (n->kind() == aten::_grad_sum_to_size && - input->type()->isSubtypeOf(ListType::ofInts()))) { + (input->type()->isSubtypeOf(FloatType::get()) && + input->node()->kind() != prim::Constant) || + (n->kind() == aten::_grad_sum_to_size && + input->type()->isSubtypeOf(ListType::ofInts()))) { auto in_group = subgraph.addInput(); in_group->setType(input->type()); inputs_map[input] = in_group; @@ -453,12 +531,6 @@ struct GraphFuser { return group; } - // TODO: remove this and use WithInsertPoint instead - void insertAt(Node** insertion_point, Node* n) { - n->insertAfter(*insertion_point); - *insertion_point = n; - } - at::optional tryFuse(Node* consumer, Value* producer) { // this handles cases where producer can be moved _into_ the fusion group of // consumer. @@ -506,13 +578,16 @@ struct GraphFuser { group = createSingletonFusionGroup(consumer); } if (producer->node()->matches( + "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor") || + producer->node()->matches( "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) { - // We don't do any fusions in here, but simply decompose the batch norm - // into a kernel that computes the stats + pointwise ops which will be + // We don't do any fusions in here, but simply decompose the normalization + // ops into a kernel that computes the stats + pointwise ops which will be // considered in this fusion next. - decomposeBatchNorm(producer->node()); + decomposeNormalizationOps(producer->node()); return group; } + if (producer->node()->kind() == prim::FusionGroup) { mergeFusionGroups(group, producer->node()); return group;