[MLIR][Shape] Lower `shape_of` for unranked tensors
author Frederik Gossen <frgossen@google.com>
Thu, 25 Jun 2020 08:50:02 +0000 (08:50 +0000)
committer Frederik Gossen <frgossen@google.com>
Thu, 25 Jun 2020 08:50:45 +0000 (08:50 +0000)
Lower `shape_of` for unranked tensors.
Materializes the shape in stack-allocated memory.

Differential Revision: https://reviews.llvm.org/D82196
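
In essence, the new pattern emits the following for an unranked operand (a
condensed sketch distilled from the updated test below; SSA names are
illustrative):

    %rank = rank %arg : tensor<*xf32>
    %mem = alloca(%rank) : memref<?xi64>
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    scf.for %i = %c0 to %rank step %c1 {
      %extent = dim %arg, %i : tensor<*xf32>
      %extent_int = index_cast %extent : index to i64
      store %extent_int, %mem[%i] : memref<?xi64>
    }
    %shape_int = tensor_load %mem : memref<?xi64>
    %shape = index_cast %shape_int : tensor<?xi64> to tensor<?xindex>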

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d61c8af..8440b9b 100644
@@ -1408,7 +1408,9 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
 
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "Value memrefOrTensor, int64_t index">
+              "Value memrefOrTensor, int64_t index">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value memrefOrTensor, Value index">
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index db7796d..adf046e 100644
@@ -70,6 +70,58 @@ ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
 }
 
 namespace {
+/// Converts `shape_of` to an `scf.for` loop for unranked tensors.
+class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  ShapeOfOp::Adaptor transformed(operands);
+  auto tensorVal = transformed.arg();
+  auto tensorTy = tensorVal.getType();
+
+  // For ranked tensors, `shape_of` lowers directly to `std`; that pattern can
+  // be found in the corresponding pass.
+  if (tensorTy.isa<RankedTensorType>())
+    return failure();
+
+  // Allocate stack memory.
+  auto loc = op.getLoc();
+  auto rankVal = rewriter.create<RankOp>(loc, tensorVal);
+  auto i64Ty = rewriter.getI64Type();
+  auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+
+  // Copy shape extents to stack-allocated memory.
+  auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
+  auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+  rewriter.create<scf::ForOp>(
+      loc, zeroVal, rankVal, oneVal, ValueRange(),
+      [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
+        auto dimVal = b.create<DimOp>(loc, tensorVal, iVal);
+        auto dimIntVal = b.create<IndexCastOp>(loc, dimVal, i64Ty);
+        b.create<StoreOp>(loc, dimIntVal, memVal, ValueRange({iVal}));
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Load extents into a tensor value.
+  auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
+  auto indexTy = rewriter.getIndexType();
+  auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
+                                           shapeTy);
+  return success();
+}
+
+namespace {
 struct ConvertShapeToSCFPass
     : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
   void runOnFunction() override;
@@ -79,19 +131,23 @@ struct ConvertShapeToSCFPass
 void ConvertShapeToSCFPass::runOnFunction() {
   MLIRContext &ctx = getContext();
 
+  // Populate conversion patterns.
   OwningRewritePatternList patterns;
   populateShapeToSCFConversionPatterns(patterns, &ctx);
 
+  // Set up target legality.
   ConversionTarget target(getContext());
   target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
-  target.addIllegalOp<ReduceOp>();
-  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+  target.addIllegalOp<ReduceOp, ShapeOfOp>();
+
+  // Apply conversion.
+  if (failed(applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();
 }
 
 void mlir::populateShapeToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ReduceOpConverter>(ctx);
+  patterns.insert<ReduceOpConverter, ShapeOfOpConverter>(ctx);
 }
 
 std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
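
As a usage sketch, a function like the one below (mirroring the updated test)
now has its `shape.shape_of`, marked illegal above, rewritten into the
rank/alloca/scf.for sequence sketched in the commit message when the pass
runs:

    func @shape_of_unranked(%arg : tensor<*xf32>) {
      %shape = shape.shape_of %arg : tensor<*xf32>
      return
    }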
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index ca4fe83..6e6ad47 100644
@@ -1273,8 +1273,13 @@ void DimOp::build(OpBuilder &builder, OperationState &result,
                   Value memrefOrTensor, int64_t index) {
   auto loc = result.location;
   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
+  build(builder, result, memrefOrTensor, indexValue);
+}
+
+void DimOp::build(OpBuilder &builder, OperationState &result,
+                  Value memrefOrTensor, Value index) {
   auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, memrefOrTensor, indexValue);
+  build(builder, result, indexTy, memrefOrTensor, index);
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
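
The new overload makes it convenient to build `dim` ops whose index is an
arbitrary SSA value, which the loop body above needs for its induction
variable. In IR terms (a sketch; names are illustrative):

    %c1 = constant 1 : index
    %d0 = dim %t, %c1 : tensor<?x?xf32>  // previously built via the int64_t overload
    %d1 = dim %t, %i : tensor<?x?xf32>   // now buildable from any index-typed value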
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index b52266c..1c21456 100644
@@ -26,3 +26,25 @@ func @shape_reduce(%shape : !shape.shape) -> !shape.size {
 // CHECK-NEXT:   scf.yield [[NEW_ACC]] : !shape.size
 // CHECK-NEXT: }
 // CHECK-NEXT: return [[RESULT]] : !shape.size
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK-DAG:   %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK-DAG:   %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK-DAG:   store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK:     }
+  // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  return
+}
+