[MLIR][Shape] Fix `shape.broadcast` to standard lowering
authorFrederik Gossen <frgossen@google.com>
Thu, 29 Apr 2021 08:07:20 +0000 (10:07 +0200)
committerFrederik Gossen <frgossen@google.com>
Thu, 29 Apr 2021 08:09:15 +0000 (10:09 +0200)
Differential Revision: https://reviews.llvm.org/D101456

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

index e0342f6..9e0020a 100644 (file)
@@ -155,17 +155,18 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
                        return lb.create<SubIOp>(indexTy, maxRank, v);
                      }));
 
-  rewriter.replaceOp(
-      op, lb.create<tensor::GenerateOp>(
-                getExtentTensorType(lb.getContext()), ValueRange{maxRank},
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value broadcastedDim = getBroadcastedDim(
-                      ImplicitLocOpBuilder(loc, b), transformed.shapes(),
-                      rankDiffs, args[0]);
-
-                  b.create<tensor::YieldOp>(loc, broadcastedDim);
-                })
-              ->getResults());
+  Value replacement = lb.create<tensor::GenerateOp>(
+      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value broadcastedDim =
+            getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
+                              transformed.shapes(), rankDiffs, args[0]);
+
+        b.create<tensor::YieldOp>(loc, broadcastedDim);
+      });
+  if (replacement.getType() != op.getType())
+    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
+  rewriter.replaceOp(op, replacement);
   return success();
 }
 
index 751f500..9800044 100644 (file)
@@ -593,6 +593,17 @@ func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
   return
 }
 
+// -----
+
+// CHECK-LABEL: @broadcast_to_known_rank
+func @broadcast_to_known_rank(%a : tensor<1xindex>, %b : tensor<3xindex>)
+    -> tensor<3xindex> {
+  // CHECK: %[[RES:.*]] = tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+  // CHECK: return %[[RES]] : tensor<3xindex>
+  %0 = shape.broadcast %a, %b : tensor<1xindex>, tensor<3xindex> -> tensor<3xindex>
+  return %0 : tensor<3xindex>
+}
+
 // -----
 
 // Lower `split_at`