[mlir][Vector] Vectorize integer matmuls
authorBenjamin Kramer <benny.kra@googlemail.com>
Wed, 22 Jul 2020 16:18:50 +0000 (18:18 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Wed, 22 Jul 2020 17:39:56 +0000 (19:39 +0200)
The underlying infrastructure supports this already, just add the
pattern matching for linalg.generic.

Differential Revision: https://reviews.llvm.org/D84335

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir

index d923ea1..8e5da6a 100644 (file)
@@ -52,9 +52,17 @@ static bool hasMultiplyAddBody(Region &r) {
   auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
   auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
   auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
+  auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
+  auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
+  auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
+  auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
   return pattern1.match(&r.front().back()) ||
          pattern2.match(&r.front().back()) ||
-         pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
+         pattern3.match(&r.front().back()) ||
+         pattern4.match(&r.front().back()) ||
+         pattern5.match(&r.front().back()) ||
+         pattern6.match(&r.front().back()) ||
+         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
 }
 
 // TODO: Should be Tablegen'd from a single source that generates the op itself.
index 9eedc31..819b3b7 100644 (file)
@@ -118,6 +118,23 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
 //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
 
+func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
+                                 %C: memref<8x32xi32>) {
+  linalg.generic #matmul_trait %A, %B, %C {
+    ^bb(%a: i32, %b: i32, %c: i32) :
+      %d = muli %a, %b: i32
+      %e = addi %c, %d: i32
+      linalg.yield %e : i32
+  } : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32>
+  return
+}
+// CHECK-LABEL: func @vectorization_test_integer
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+//       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
+//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
+
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
   linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :