disable batch mm if we have mutable ops (#14771)
authorMichael Suo <suo@fb.com>
Tue, 4 Dec 2018 22:28:10 +0000 (14:28 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 22:34:57 +0000 (14:34 -0800)
Summary:
Just to be safe, disable batch mm for mutable ops. We don't lose much for doing this, and we can go back at a calmer time to re-enable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14771

Reviewed By: eellison

Differential Revision: D13327641

Pulled By: suo

fbshipit-source-id: 96611e21ed3cb8492a2cd040f7d33fb58c52bd5e

torch/csrc/jit/passes/batch_mm.cpp

index 4b992f8..3ee75de 100644 (file)
@@ -391,7 +391,23 @@ void BatchMMSide(Block * block, const AliasDb& alias_db) {
 
 }
 
+bool hasMutableOperators(Block* block) {
+  for (auto n : block->nodes()) {
+    if (n->kind().is_aten() && n->schema().is_mutable())
+      return true;
+    for (auto b : n->blocks()) {
+      if (hasMutableOperators(b))
+        return true;
+    }
+  }
+  return false;
+}
+
 void BatchMM(std::shared_ptr<Graph>& graph) {
+  if (hasMutableOperators(graph->block())) {
+    // TODO(suo): make BatchMM mutability-safe
+    return;
+  }
   const auto alias_db = AliasAnalysis(graph);
   BatchMMTreeReduce(graph->block());
   BatchMMSide(graph->block(), alias_db);