From 53a9d4f3121f634d92336ea7d26122fccc7c1418 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Tue, 4 Dec 2018 14:28:10 -0800 Subject: [PATCH] disable batch mm if we have mutable ops (#14771) Summary: Just to be safe, disable batch mm for mutable ops. We don't lose much for doing this, and we can go back at a calmer time to re-enable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14771 Reviewed By: eellison Differential Revision: D13327641 Pulled By: suo fbshipit-source-id: 96611e21ed3cb8492a2cd040f7d33fb58c52bd5e --- torch/csrc/jit/passes/batch_mm.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index 4b992f8..3ee75de 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -391,7 +391,23 @@ void BatchMMSide(Block * block, const AliasDb& alias_db) { } +bool hasMutableOperators(Block* block) { + for (auto n : block->nodes()) { + if (n->kind().is_aten() && n->schema().is_mutable()) + return true; + for (auto b : n->blocks()) { + if (hasMutableOperators(b)) + return true; + } + } + return false; +} + void BatchMM(std::shared_ptr& graph) { + if (hasMutableOperators(graph->block())) { + // TODO(suo): make BatchMM mutability-safe + return; + } const auto alias_db = AliasAnalysis(graph); BatchMMTreeReduce(graph->block()); BatchMMSide(graph->block(), alias_db); -- 2.7.4