}
+bool hasMutableOperators(Block* block) {
+ for (auto n : block->nodes()) {
+ if (n->kind().is_aten() && n->schema().is_mutable())
+ return true;
+ for (auto b : n->blocks()) {
+ if (hasMutableOperators(b))
+ return true;
+ }
+ }
+ return false;
+}
+
void BatchMM(std::shared_ptr<Graph>& graph) {
+ if (hasMutableOperators(graph->block())) {
+ // TODO(suo): make BatchMM mutability-safe
+ return;
+ }
const auto alias_db = AliasAnalysis(graph);
BatchMMTreeReduce(graph->block());
BatchMMSide(graph->block(), alias_db);