[mlir] Handle unknown ops in dynamic_tensor_from_elements bufferization
authorSean Silva <silvasean@google.com>
Tue, 15 Dec 2020 02:22:58 +0000 (18:22 -0800)
committerSean Silva <silvasean@google.com>
Tue, 15 Dec 2020 20:50:56 +0000 (12:50 -0800)
Due to how the conversion infra works, the "clone" call that this
pattern was using required all the cloned ops to be immediately
legalized as part of this dialect conversion invocation.

That was previously working due to a couple factors:

- In the test case, there was scf.if, which we happen to mark as legal
  as part of marking the entire SCF dialect as legal for the scf.parallel
  we generate here.

- Originally, this test case had std.extract_element in the body, which
  we happened to have a pattern for in this pass. After I migrated that to
  `tensor.extract` (which removed the tensor.extract bufferization from
  here), I hacked this up to use `std.dim` which we still have patterns
  for in this pass.

This patch updates the test case to use a truly opaque op `test.source`
that properly stresses this aspect of the pattern.

(this also removes a stray dependency on the `tensor` dialect that I
must have left behind as part of my hacking this pass up when migrating
to `tensor.extract`)

Differential Revision: https://reviews.llvm.org/D93262

mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/bufferize.mlir

index 6691355..a84934b 100644 (file)
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -70,18 +69,29 @@ public:
       upperBounds.push_back(upperBound);
     }
 
-    // Generate tensor elements with a parallel loop.
-    rewriter.create<scf::ParallelOp>(
-        loc, lowerBounds, upperBounds, steps,
-        [&](OpBuilder &b, Location loc, ValueRange ivs) {
-          BlockAndValueMapping mapping;
-          mapping.map(op.body().getArguments(), ivs);
-          for (auto &nestedOp : op.getBody()->without_terminator())
-            b.clone(nestedOp, mapping);
-          auto yieldOp = cast<YieldOp>(op.getBody()->getTerminator());
-          b.create<StoreOp>(loc, mapping.lookup(yieldOp.value()), result, ivs);
-          b.create<scf::YieldOp>(loc);
-        });
+    // Generate tensor elements with a parallel loop that stores into
+    // each element of the resulting memref.
+    //
+    // This is a bit tricky. We cannot simply clone the ops because when an op
+    // is cloned, it must be legalized. However, we want to allow arbitrary ops
+    // in the body that we don't necessarily have legalization patterns for as
+    // part of this dialect conversion invocation.
+    //
+    // To accomplish this, we use mergeBlockBefore to "move" this op's body
+    // into the scf.parallel's body.
+    auto parallel =
+        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
+    Block *parallelBody = parallel.getBody();
+    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
+                              parallelBody->getArguments());
+    // Replace the inlined yield op with a store op. The scf.parallel's builder
+    // already populated an scf.yield at the end, so we don't need to worry
+    // about creating that.
+    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
+    rewriter.setInsertionPointAfter(elementYield);
+    rewriter.replaceOpWithNewOp<StoreOp>(elementYield,
+                                         elementYield->getOperands()[0], result,
+                                         parallelBody->getArguments());
 
     rewriter.replaceOp(op, {result});
     return success();
@@ -168,7 +178,6 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
 
     target.addLegalDialect<StandardOpsDialect>();
     target.addLegalDialect<scf::SCFDialect>();
-    target.addLegalDialect<tensor::TensorDialect>();
 
     populateStdBufferizePatterns(context, typeConverter, patterns);
     target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
index 75ff2a9..8ae10cc 100644 (file)
@@ -123,20 +123,20 @@ func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
   return %0 : tensor<2xindex>
 }
 
-// The dynamic_tensor_from_elements op clones each op in its body.
-// Make sure that regions nested within such ops are recursively converted.
-// CHECK-LABEL: func @recursively_convert_cloned_regions
-func @recursively_convert_cloned_regions(%arg0: tensor<*xf32>, %arg1: index, %arg2: i1) -> tensor<?xindex> {
-  %tensor = dynamic_tensor_from_elements %arg1 {
+// The dynamic_tensor_from_elements op needs to put its body into the
+// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
+// the body because that would require the cloned ops to be legalized
+// immediately, which is usually not possible since they might be from various
+// other dialects.
+//
+// CHECK-LABEL: func @unknown_ops_in_body
+func @unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
+  // CHECK-NOT: dynamic_tensor_from_elements
+  %tensor = dynamic_tensor_from_elements %arg0 {
   ^bb0(%iv: index):
-    %48 = scf.if %arg2 -> (index) {
-      scf.yield %iv : index
-    } else {
-      // CHECK-NOT: dim{{.*}}tensor
-      %50 = dim %arg0, %iv : tensor<*xf32>
-      scf.yield %50 : index
-    }
-    yield %48 : index
+    // CHECK: test.source
+    %0 = "test.source"() : () -> index
+    yield %0 : index
   } : tensor<?xindex>
   return %tensor : tensor<?xindex>
 }