return lb.create<SubIOp>(indexTy, maxRank, v);
}));
- rewriter.replaceOp(
- op, lb.create<tensor::GenerateOp>(
- getExtentTensorType(lb.getContext()), ValueRange{maxRank},
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value broadcastedDim = getBroadcastedDim(
- ImplicitLocOpBuilder(loc, b), transformed.shapes(),
- rankDiffs, args[0]);
-
- b.create<tensor::YieldOp>(loc, broadcastedDim);
- })
- ->getResults());
+ Value replacement = lb.create<tensor::GenerateOp>(
+ getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value broadcastedDim =
+ getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
+ transformed.shapes(), rankDiffs, args[0]);
+
+ b.create<tensor::YieldOp>(loc, broadcastedDim);
+ });
+ if (replacement.getType() != op.getType())
+ replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
+ rewriter.replaceOp(op, replacement);
return success();
}
return
}
+// ----
+
+// CHECK-LABEL: @broadcast_to_known_rank
+func @broadcast_to_known_rank(%a : tensor<1xindex>, %b : tensor<3xindex>)
+ -> tensor<3xindex> {
+ // CHECK: %[[RES:.*]] = tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+ // CHECK: return %[[RES]] : tensor<3xindex>
+ %0 = shape.broadcast %a, %b : tensor<1xindex>, tensor<3xindex> -> tensor<3xindex>
+ return %0 : tensor<3xindex>
+}
+
// -----
// Lower `split_at`