From 21bb63893e8557df7a7ab690ad98cb5979099186 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Thu, 25 Feb 2021 17:20:25 -0800 Subject: [PATCH] [MLIR][linalg] Make integer matmul ops cast before multiplying Right now they multiply before casting which means they would frequently overflow. There are various reasonable ways to do this, but until we have robust op description infra, this is a simple and safe default. More careful treatments are likely to be hardware specific, as well (e.g. using an i8*i8->i16 mul instruction). Reviewed By: nicolasvasilache, mravishankar Differential Revision: https://reviews.llvm.org/D97505 --- .../Linalg/IR/LinalgNamedStructuredOpsSpec.tc | 20 ++++++++++---------- mlir/test/Dialect/Linalg/vectorization.mlir | 11 ++++++----- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc index 3ebbb48..338cc6e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -15,13 +15,13 @@ implements_interface : def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) { // TODO: ideally something closer to // C(m, n) += cast(A(m, k)) * cast(B(k, n)) - C(m, n) = std_addi(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n)))); + C(m, n) = std_addi(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n)))); } ods_def implements_interface : def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) { - C(m, n) = std_addi(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n)))); + C(m, n) = std_addi(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n)))); } ods_def @@ -39,13 +39,13 @@ def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) { ods_def implements_interface : def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) { - x(m) = std_addi(x(m), std_sexti32(std_muli(A(m, n), 
y(n)))); + x(m) = std_addi(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n)))); } ods_def implements_interface : def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) { - x(m) = std_addi(x(m), std_sexti32(std_muli(A(m, n), y(n)))); + x(m) = std_addi(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n)))); } ods_def @@ -63,13 +63,13 @@ def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) { ods_def implements_interface : def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) { - x(n) = std_addi(x(n), std_sexti32(std_muli(y(m), A(m, n)))); + x(n) = std_addi(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n)))); } ods_def implements_interface : def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) { - x(n) = std_addi(x(n), std_sexti32(std_muli(y(m), A(m, n)))); + x(n) = std_addi(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n)))); } ods_def @@ -87,13 +87,13 @@ def dot(A: f32(M), B: f32(M)) -> (C: f32()) { ods_def implements_interface : def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) { - C() = std_addi(C(), std_sexti32(std_muli(A(m), B(m)))); + C() = std_addi(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m)))); } ods_def implements_interface : def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) { - C() = std_addi(C(), std_sexti32(std_muli(A(m), B(m)))); + C() = std_addi(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m)))); } ods_def @@ -112,14 +112,14 @@ ods_def implements_interface : def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) { C(b, m, n) = - std_addi(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n)))); + std_addi(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n)))); } ods_def implements_interface : def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) { C(b, m, n) = - std_addi(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n)))); + std_addi(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, 
n)))); } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index 13d2e18..436d513 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -373,17 +373,18 @@ func @matmul_tensors( // CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32> func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) { // CHECK-DAG: %[[C0:.*]] = constant 0 : index - // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8> + // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32> // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8> // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8> // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32> + // CHECK-DAG: %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32> + // CHECK-DAG: %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32> // // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp. // a later canonicalization fuses the add into vector.contract. 
- // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]] - // CHECK-SAME: vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8> - // CHECK: %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32> - // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32> + // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0_32]], %[[V1_32]], %[[VEC_C0]] + // CHECK-SAME: vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32> + // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32> // CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} // CHECK-SAME: vector<4x12xi32>, memref<4x12xi32> linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>) -- 2.7.4