[MLIR][linalg] Make integer matmul ops cast before multiplying
authorGeoffrey Martin-Noble <gcmn@google.com>
Fri, 26 Feb 2021 01:20:25 +0000 (17:20 -0800)
committerGeoffrey Martin-Noble <gcmn@google.com>
Fri, 26 Feb 2021 16:36:31 +0000 (08:36 -0800)
Right now they multiply before casting which means they would frequently
overflow. There are various reasonable ways to do this, but until we
have robust op description infra, this is a simple and safe default. More
careful treatments are likely to be hardware specific, as well (e.g.
using an i8*i8->i16 mul instruction).

Reviewed By: nicolasvasilache, mravishankar

Differential Revision: https://reviews.llvm.org/D97505

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/test/Dialect/Linalg/vectorization.mlir

index 3ebbb48..338cc6e 100644 (file)
@@ -15,13 +15,13 @@ implements_interface<LinalgContractionOpInterface> :
 def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
   // TODO: ideally something closer to
   //   C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
-  C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+  C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
 }
 
 ods_def<MatmulI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
-  C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+  C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
 }
 
 ods_def<MatmulI32I32I32Op>
@@ -39,13 +39,13 @@ def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
 ods_def<MatvecI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+  x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
 }
 
 ods_def<MatvecI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+  x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
 }
 
 ods_def<MatvecI32I32I32Op>
@@ -63,13 +63,13 @@ def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
 ods_def<VecmatI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+  x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
 }
 
 ods_def<VecmatI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+  x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
 }
 
 ods_def<VecmatI32I32I32Op>
@@ -87,13 +87,13 @@ def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
 ods_def<DotI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
-  C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+  C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
 }
 
 ods_def<DotI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
-  C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+  C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
 }
 
 ods_def<DotI32I32I32Op>
@@ -112,14 +112,14 @@ ods_def<BatchMatmulI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
   C(b, m, n) =
-      std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+      std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
 }
 
 ods_def<BatchMatmulI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
   C(b, m, n) =
-      std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+      std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
 }
 
 
index 13d2e18..436d513 100644 (file)
@@ -373,17 +373,18 @@ func @matmul_tensors(
 //  CHECK-SAME:  %[[ARG2:[a-z0-9]+]]: memref<4x12xi32>
 func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) {
   //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
-  //   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8>
+  //   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32>
   //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
   //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
   //   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
+  //   CHECK-DAG:   %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32>
+  //   CHECK-DAG:   %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32>
   //
   // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
   // a later canonicalization fuses the add into vector.contract.
-  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]]
-  //  CHECK-SAME:     vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
-  //       CHECK:   %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
-  //       CHECK:   %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>
+  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0_32]], %[[V1_32]], %[[VEC_C0]]
+  //  CHECK-SAME:     vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32>
+  //       CHECK:   %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32>
   //       CHECK:   vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]}
   //  CHECK-SAME:     vector<4x12xi32>, memref<4x12xi32>
   linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>)