@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
- def test_fuse_batch_norm(self):
-
+ def test_fuse_decompose_normalization(self):
class ResLike(torch.jit.ScriptModule):
- def __init__(self, optimize=True):
+ def __init__(self, norm_module, optimize=True):
super(ResLike, self).__init__(optimize)
- self.bn = nn.BatchNorm2d(16)
+ self.nm = norm_module
@torch.jit.script_method
def forward(self, x, y):
- return y + torch.relu(self.bn(x))
-
- model = ResLike().cuda()
- model_noopt = ResLike(optimize=False).cuda()
- model_noopt.load_state_dict(model.state_dict())
- x = torch.randn(2, 16, 8, 8, device='cuda')
- y = torch.randn(2, 16, 8, 8, device='cuda')
- # FIXME: We need differentiation for CNNs for this optimization to trigger
- with torch.no_grad():
- out = model(x, y)
- graph = model.graph_for(x, y)
- rep = str(graph)
-
- out_noopt = model_noopt(x, y)
- rep_noopt = str(model_noopt.graph_for(x, y))
- self.assertEqual(out, out_noopt, prec=3e-5)
-
- # Check that batch_norm has really been decomposed
- self.assertIn('aten::batch_norm_update_stats', rep)
- self.assertNotIn('aten::batch_norm(', rep)
- self.assertIn('aten::batch_norm(', rep_noopt)
-
- # Make sure the fusion group is big, and contains aten::sqrt, which could
- # originate only from decomposing batch_norm in this case
- fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
- self.assertEqual(len(fusion_groups), 1)
- fused_graph = fusion_groups[0].g('Subgraph')
- self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes()))
+ return y + torch.relu(self.nm(x))
+
+ def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
+ model = ResLike(nm).cuda()
+ model_noopt = ResLike(nm, optimize=False).cuda()
+ model_noopt.load_state_dict(model.state_dict())
+ x = torch.randn(2, 16, 8, 8, device='cuda')
+ y = torch.randn(2, 16, 8, 8, device='cuda')
+
+ # FIXME: We need differentiation for CNNs for this optimization to trigger
+ with torch.no_grad():
+ out = model(x, y)
+ graph = model.graph_for(x, y)
+ rep = str(graph)
+
+ out_noopt = model_noopt(x, y)
+ rep_noopt = str(model_noopt.graph_for(x, y))
+ self.assertEqual(out, out_noopt, prec=3e-5)
+
+ # Check that normalization op has really been decomposed
+ for node_in_graph in in_opt_graph:
+ self.assertIn(node_in_graph, rep)
+
+ for node_not_in_graph in not_in_opt_graph:
+ self.assertNotIn(node_not_in_graph, rep)
+ self.assertIn(node_not_in_graph, rep_noopt)
+
+ fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
+ self.assertEqual(len(fusion_groups), 1)
+ fused_graph = str(fusion_groups[0].g('Subgraph'))
+ for node_in_fusegraph in in_fusegraph:
+ self.assertIn(node_in_fusegraph, fused_graph)
+
+ # test for batchnorm decompose
+ bm = nn.BatchNorm2d(16)
+ test_norm_decompose(bm, ['aten::batch_norm_update_stats'],
+ ['aten::batch_norm('], ['aten::sqrt'])
+
+ # test for layernorm decompose
+ lm = nn.LayerNorm(8)
+ test_norm_decompose(lm, ['aten::batch_norm_stats'],
+ ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::addcmul'])
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
return false;
}
for (Value* input : node->inputs()) {
- if (input->type()->isSubtypeOf(TensorType::get()) || input->type()->isSubtypeOf(FloatType::get())) {
+ if (input->type()->isSubtypeOf(TensorType::get()) ||
+ input->type()->isSubtypeOf(FloatType::get())) {
continue;
}
if (input->node()->kind() != prim::Constant) {
};
})});
+RegisterOperators reg_ln_view({Operator(
+ "aten::_ncf_view(Tensor self, int[] input_shape, int normalized_ndim) -> Tensor",
+ [](const Node* node) {
+ return [](Stack& stack) {
+ const int64_t normalized_ndim = pop(stack).toInt();
+ auto input_shape = pop(stack).toIntListRef();
+ auto self = pop(stack).toTensor();
+ const int64_t input_ndim = input_shape.size();
+ c10::SmallVector<int64_t, 8> sizes(input_ndim, 1);
+ for (int i = 0; i < input_ndim - normalized_ndim; ++i) {
+ sizes.at(i) = input_shape[i];
+ }
+ push(stack, self.reshape(sizes));
+ return 0;
+ };
+ })});
+
// Yes, no, or no value if we can't tell
c10::optional<bool> isDefined(Value* tensor) {
if (tensor->type()->isSubtypeOf(TensorType::get())) {
return {};
}
-bool isFusableBatchNorm(Node* batch_norm) {
- if (!batch_norm->matches(
- "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
- return false;
+bool isFusableNorm(Node* normalize_op) {
+ static const OperatorSet decomposable_normalization_ops = {
+ "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
+ "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
+ };
+
+ if (decomposable_normalization_ops.find(normalize_op)) {
+ // If we can't determine if weight and bias is defined statically there's
+ // really no point in decomposing normalization into simpler ops, since it
+ // won't get fused into a single kernel.
+ return isDefined(normalize_op->namedInput(attr::weight)).has_value() &&
+ isDefined(normalize_op->namedInput(attr::bias)).has_value();
}
- // If we can't determine if weight and bias is defined statically there's
- // really no point in decomposing batch norm into simpler ops, since it won't
- // get fused into a single kernel.
- return isDefined(batch_norm->namedInput(attr::weight)).has_value() &&
- isDefined(batch_norm->namedInput(attr::bias)).has_value();
+ return false;
}
Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
}
bool isFusable(Node* node) {
- return isFusableMap(node) || isFusableBatchNorm(node);
+ return isFusableMap(node) || isFusableNorm(node);
}
bool isFusableMap(Node* node) {
return *n->g(attr::Subgraph);
}
- void decomposeBatchNorm(Node* batch_norm) {
- static std::shared_ptr<Graph> bn_graph;
- static std::once_flag flag;
+ Value* decomposeCommonNormalization(
+ Node* normalization_op,
+ const char* source,
+ const std::string& method_name,
+ const std::vector<Value*>& inputs) {
+ std::shared_ptr<Graph> nm_graph;
+ std::once_flag flag;
std::call_once(
flag,
- [](std::shared_ptr<Graph>* graph_ptr) {
- static const char* source = R"SCRIPT(
+ [](std::shared_ptr<Graph>* graph_ptr,
+ const char* source,
+ const std::string& method_name) {
+ script::CompilationUnit cu;
+ cu.define(source, script::nativeResolver, nullptr);
+ *graph_ptr = cu.get_function(method_name).graph();
+ },
+ &nm_graph,
+ source,
+ method_name);
+
+ AT_ASSERT(isFusableNorm(normalization_op));
+ WithInsertPoint insert_guard{normalization_op};
+ Value* new_output =
+ SubgraphUtils::inlineGraph(nm_graph, inputs, normalization_op).at(0);
+ return new_output;
+ }
+
+ void decomposeNormalizationOps(Node* normalization_op) {
+ static const char* bm_source = R"SCRIPT(
def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
if training:
norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
norm_var = torch._unwrap_optional(running_var)
norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
- norm_invstd = 1 / (eps + torch.sqrt(norm_var))
+ norm_invstd = 1 / (torch.sqrt(norm_var + eps))
return ((input - norm_mean) * norm_invstd)
)SCRIPT";
- script::CompilationUnit cu;
- cu.define(source, script::nativeResolver, nullptr);
- *graph_ptr = cu.get_function("batch_norm").graph();
- },
- &bn_graph);
-
- AT_ASSERT(isFusableBatchNorm(batch_norm));
- WithInsertPoint insert_guard{batch_norm};
- Value* input = batch_norm->namedInput(attr::input);
- Value* input_dim = graph_->insert(aten::dim, {input});
- std::vector<Value*> inputs{input,
- batch_norm->namedInput(attr::running_mean),
- batch_norm->namedInput(attr::running_var),
- batch_norm->namedInput(attr::training),
- batch_norm->namedInput(attr::momentum),
- batch_norm->namedInput(attr::eps)};
- Value* new_output =
- SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0);
- auto weight = batch_norm->namedInput(attr::weight);
- auto bias = batch_norm->namedInput(attr::bias);
- if (isDefined(weight).value()) {
- Value* expanded_weight =
- graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
- new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
- }
- if (isDefined(bias).value()) {
- Value* expanded_bias =
- graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
- new_output = graph_->insert(aten::add, {new_output, expanded_bias});
+ static const char* lm_source = R"SCRIPT(
+ def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor:
+ input_ndim = input.dim()
+ normalized_ndim = len(normalized_shape)
+ n = 1
+ for i in range(input_ndim - normalized_ndim):
+ n *= input.size(i)
+ input_reshape = input.contiguous().view(1, n, -1)
+ mean, invstd = torch.batch_norm_stats(input_reshape, eps)
+ input_shape = input.size()
+ mean = torch._ncf_view(mean, input_shape, normalized_ndim)
+ invstd = torch._ncf_view(invstd, input_shape, normalized_ndim)
+
+ return (input - mean) * invstd
+ )SCRIPT";
+ Value* input = normalization_op->namedInput(attr::input);
+ if (normalization_op->kind() == aten::batch_norm) {
+ Value* input_dim = graph_->insert(aten::dim, {input});
+ std::vector<Value*> inputs{
+ input,
+ normalization_op->namedInput(attr::running_mean),
+ normalization_op->namedInput(attr::running_var),
+ normalization_op->namedInput(attr::training),
+ normalization_op->namedInput(attr::momentum),
+ normalization_op->namedInput(attr::eps)};
+
+ Value* new_output = decomposeCommonNormalization(
+ normalization_op, bm_source, "batch_norm", inputs);
+ auto weight = normalization_op->namedInput(attr::weight);
+ auto bias = normalization_op->namedInput(attr::bias);
+ if (isDefined(weight).value()) {
+ Value* expanded_weight =
+ graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
+ new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
+ }
+ if (isDefined(bias).value()) {
+ Value* expanded_bias =
+ graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
+ new_output = graph_->insert(aten::add, {new_output, expanded_bias});
+ }
+ normalization_op->output()->replaceAllUsesWith(new_output);
+ normalization_op->destroy();
+
+ } else if (normalization_op->kind() == aten::layer_norm) {
+ std::vector<Value*> inputs{
+ input,
+ normalization_op->namedInput(attr::normalized_shape),
+ normalization_op->namedInput(attr::eps),
+ normalization_op->namedInput(attr::cudnn_enable)};
+ Value* new_output = decomposeCommonNormalization(
+ normalization_op, lm_source, "layer_norm", inputs);
+ auto weight = normalization_op->namedInput(attr::weight);
+ auto bias = normalization_op->namedInput(attr::bias);
+ auto weight_defined = isDefined(weight).value();
+ auto bias_defined = isDefined(bias).value();
+ if (weight_defined && bias_defined) {
+ new_output = graph_->insert(aten::addcmul, {bias, new_output, weight});
+ } else if (weight_defined) {
+ new_output = graph_->insert(aten::mul, {new_output, weight});
+ } else if (bias_defined) {
+ new_output = graph_->insert(aten::add, {new_output, bias});
+ }
+ normalization_op->output()->replaceAllUsesWith(new_output);
+ normalization_op->destroy();
}
- batch_norm->output()->replaceAllUsesWith(new_output);
- batch_norm->destroy();
}
void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
group->insertInput(tensor_insert_idx, input);
tensor_insert_idx++;
} else if (
- (input->type()->isSubtypeOf(FloatType::get()) && input->node()->kind() != prim::Constant) ||
- (n->kind() == aten::_grad_sum_to_size &&
- input->type()->isSubtypeOf(ListType::ofInts()))) {
+ (input->type()->isSubtypeOf(FloatType::get()) &&
+ input->node()->kind() != prim::Constant) ||
+ (n->kind() == aten::_grad_sum_to_size &&
+ input->type()->isSubtypeOf(ListType::ofInts()))) {
auto in_group = subgraph.addInput();
in_group->setType(input->type());
inputs_map[input] = in_group;
return group;
}
- // TODO: remove this and use WithInsertPoint instead
- void insertAt(Node** insertion_point, Node* n) {
- n->insertAfter(*insertion_point);
- *insertion_point = n;
- }
-
at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
// this handles cases where producer can be moved _into_ the fusion group of
// consumer.
group = createSingletonFusionGroup(consumer);
}
if (producer->node()->matches(
+ "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor") ||
+ producer->node()->matches(
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
- // We don't do any fusions in here, but simply decompose the batch norm
- // into a kernel that computes the stats + pointwise ops which will be
+ // We don't do any fusions in here, but simply decompose the normalization
+ // ops into a kernel that computes the stats + pointwise ops which will be
// considered in this fusion next.
- decomposeBatchNorm(producer->node());
+ decomposeNormalizationOps(producer->node());
return group;
}
+
if (producer->node()->kind() == prim::FusionGroup) {
mergeFusionGroups(group, producer->node());
return group;