From 451b1ff37639e8a6194940fbf5ffdc23b8b4af28 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 25 Oct 2022 17:05:40 +0200 Subject: [PATCH] [mlir] Add lower-to-loops tests for linalg.map/reduce/transpose. Differential Revision: https://reviews.llvm.org/D136691 --- .../lower-to-loops-using-interface.mlir | 81 ++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir index 5193649..f9dd94e 100644 --- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir @@ -159,3 +159,84 @@ func.func @pool_strides_and_dilation(%arg0 : memref, %arg1 : memref // CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK: %[[T10:.+]] = arith.maxf %[[T9]], %[[T8]] // CHECK: memref.store %[[T10]], %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] + +// ----- + +func.func @map(%lhs: memref<64xf32>, + %rhs: memref<64xf32>, %out: memref<64xf32>) { + linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>) + outs(%out : memref<64xf32>) + (%in: f32, %in_0: f32) { + %0 = arith.addf %in, %in_0 : f32 + linalg.yield %0 : f32 + } + return +} +// CHECK-LABEL: func.func @map( +// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<64xf32>, +// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<64xf32>, +// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<64xf32>) { + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index + +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] { +// CHECK: %[[LHS_ELEM:.*]] = memref.load %[[LHS]][%[[I]]] +// CHECK: %[[RHS_ELEM:.*]] = memref.load %[[RHS]][%[[I]]] +// CHECK: %[[ADD:.*]] = arith.addf %[[LHS_ELEM]], %[[RHS_ELEM]] +// CHECK: memref.store %[[ADD]], %[[OUT]][%[[I]]] + +// ----- + +func.func @transpose(%arg0: memref<16x32x64xf32>, + %arg1: memref<32x64x16xf32>) { + linalg.transpose ins(%arg0 : memref<16x32x64xf32>) + outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0] + return +} +// CHECK-LABEL: func.func @transpose( +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>, +// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<32x64x16xf32>) + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index + +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C16]] step %[[C1]] { +// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C32]] step %[[C1]] { +// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C64]] step %[[C1]] { +// CHECK: %[[ELEM:.*]] = memref.load %[[OUT]][%[[J]], %[[K]], %[[I]]] +// CHECK: memref.store %[[ELEM]], %[[OUT]][%[[J]], %[[K]], %[[I]]] + +// ----- + +func.func @reduce(%arg0: memref<16x32x64xf32>, + %arg1: memref<16x64xf32>) { + linalg.reduce ins(%arg0 : memref<16x32x64xf32>) + outs(%arg1 : memref<16x64xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %0 = arith.addf %in, %init : f32 + linalg.yield %0 : f32 + } + return +} +// CHECK-LABEL: func.func @reduce( +// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>, +// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<16x64xf32> + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index + +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C16]] step %[[C1]] { +// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C32]] step %[[C1]] { +// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C64]] step %[[C1]] { +// CHECK: %[[IN_ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[J]], %[[K]]] +// CHECK: %[[OUT_ELEM:.*]] = memref.load %[[OUT]][%[[I]], %[[K]]] +// CHECK: %[[ADD:.*]] = arith.addf %[[IN_ELEM]], %[[OUT_ELEM]] +// CHECK: memref.store %[[ADD]], %[[OUT]][%[[I]], %[[K]]] -- 2.7.4