[mlir][shape] Use memref of index in shape lowering
authorStephan Herhut <herhut@google.com>
Thu, 30 Jul 2020 12:38:12 +0000 (14:38 +0200)
committerStephan Herhut <herhut@google.com>
Thu, 30 Jul 2020 13:12:43 +0000 (15:12 +0200)
Now that we can have a memref of index type, we no longer need to materialize shapes in i64 and then index_cast.

Differential Revision: https://reviews.llvm.org/D84938

mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir

index 0101c9e7fdc01da34253379312b484938c37a4a8..a6c667f5641c39376a981e92b060d9b6ab129270 100644 (file)
@@ -186,8 +186,8 @@ ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
   // Allocate stack memory.
   auto loc = op.getLoc();
   Value rank = rewriter.create<mlir::RankOp>(loc, arg);
-  Type i64Ty = rewriter.getI64Type();
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  Type indexTy = rewriter.getIndexType();
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
   Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
 
   // Copy shape extents to stack-allocated memory.
@@ -197,15 +197,12 @@ ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
       loc, zero, rank, one, llvm::None,
       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
         Value dim = rewriter.create<DimOp>(loc, arg, iv);
-        Value dimInt = rewriter.create<IndexCastOp>(loc, dim, i64Ty);
-        rewriter.create<StoreOp>(loc, dimInt, mem, ValueRange{iv});
+        rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
         rewriter.create<scf::YieldOp>(loc);
       });
 
   // Load extents to tensor value.
-  Value extentTensorInt = rewriter.create<TensorLoadOp>(loc, mem);
-  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), extentTensorInt,
-                                           op.getType());
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
   return success();
 }
 
index 97d2bce5a0948cea5227f00fbf5aded5274dfc6a..768a627208b8e1e0936ab07819b35cbe090c51b6 100644 (file)
@@ -40,16 +40,14 @@ func @shape_of(%arg : tensor<*xf32>) {
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
 func @shape_of_unranked(%arg : tensor<*xf32>) {
   // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
   // CHECK:   %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK:   %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
-  // CHECK:   store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK:   store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
   // CHECK: }
-  // CHECK: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
-  // CHECK: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
   %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   return
 }