def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
// TODO: ideally something closer to
// C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
- C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+ C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
}
ods_def<MatmulI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
- C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+ C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
}
ods_def<MatmulI32I32I32Op>
ods_def<MatvecI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+ x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
}
ods_def<MatvecI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+ x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
}
ods_def<MatvecI32I32I32Op>
ods_def<VecmatI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+ x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
}
ods_def<VecmatI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+ x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
}
ods_def<VecmatI32I32I32Op>
ods_def<DotI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
- C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+ C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
}
ods_def<DotI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
- C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+ C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
}
ods_def<DotI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
C(b, m, n) =
- std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+ std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
}
ods_def<BatchMatmulI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
C(b, m, n) =
- std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+ std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
}
// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32>
func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) {
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
- // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8>
+ // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32>
// CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
// CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
// CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
+ // CHECK-DAG: %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32>
+ // CHECK-DAG: %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32>
//
// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
// a later canonicalization fuses the add into vector.contract.
- // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]]
- // CHECK-SAME: vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
- // CHECK: %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
- // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>
+ // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0_32]], %[[V1_32]], %[[VEC_C0]]
+ // CHECK-SAME: vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32>
+ // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32>
// CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]}
// CHECK-SAME: vector<4x12xi32>, memref<4x12xi32>
linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>)