[mlir][Vector] Add a folder for vector.broadcast
authorHanhan Wang <hanchung@google.com>
Thu, 17 Sep 2020 15:54:16 +0000 (08:54 -0700)
committerHanhan Wang <hanchung@google.com>
Thu, 17 Sep 2020 15:54:51 +0000 (08:54 -0700)
Fold the operation if the source is a scalar constant or splat constant.

Update transform-patterns-matmul-to-vector.mlir because the broadcast ops are folded in the conversion.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D87703

mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
mlir/test/Dialect/Vector/canonicalize.mlir

index 3cb1265..04aa18c 100644 (file)
@@ -270,6 +270,7 @@ def Vector_BroadcastOp :
     }
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
+  let hasFolder = 1;
 }
 
 def Vector_ShuffleOp :
index c2b6f31..c2cfaa5 100644 (file)
@@ -929,6 +929,17 @@ static LogicalResult verify(BroadcastOp op) {
   return success();
 }
 
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0])
+    return {};
+  auto vectorType = getVectorType();
+  if (operands[0].getType().isIntOrIndexOrFloat())
+    return DenseElementsAttr::get(vectorType, operands[0]);
+  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
+    return DenseElementsAttr::get(vectorType, attr.getSplatValue());
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleOp
 //===----------------------------------------------------------------------===//
index 83e9461..683aeb2 100644 (file)
@@ -13,13 +13,8 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 }
 
 // CHECK-LABEL:func @matmul
-//      CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 //      CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
-//
-//      CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
 //      CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
-//
-//      CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
 //      CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
 //
 //      CHECK: linalg.copy
index 1b1362f..9c36f76 100644 (file)
@@ -385,3 +385,28 @@ func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf
   %2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: broadcast_folding1
+//       CHECK: %[[CST:.*]] = constant dense<42> : vector<4xi32>
+//   CHECK-NOT: vector.broadcast
+//       CHECK: return %[[CST]]
+func @broadcast_folding1() -> vector<4xi32> {
+  %0 = constant 42 : i32
+  %1 = vector.broadcast %0 : i32 to vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_folding2
+//       CHECK: %[[CST:.*]] = constant dense<42> : vector<4x16xi32>
+//   CHECK-NOT: vector.broadcast
+//       CHECK: return %[[CST]]
+func @broadcast_folding2() -> vector<4x16xi32> {
+  %0 = constant 42 : i32
+  %1 = vector.broadcast %0 : i32 to vector<16xi32>
+  %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
+  return %2 : vector<4x16xi32>
+}