From ac2cf07195b5833a888dc6878a9a3cb377ef59ac Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 21 Oct 2020 13:42:29 -0700 Subject: [PATCH] [spirv] Fix legalize standard to spir-v for transfer ops Forward missing attributes when creating the new transfer op otherwise the builder would use default values. Differential Revision: https://reviews.llvm.org/D89907 --- .../Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp | 6 ++++-- mlir/test/Conversion/StandardToSPIRV/legalization.mlir | 11 ++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp index a2e608d..1cf3a32 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -67,7 +67,8 @@ void LoadOpOfSubViewFolder::replaceOp( vector::TransferReadOp loadOp, SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( - loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices); + loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices, + loadOp.permutation_map(), loadOp.padding(), loadOp.maskedAttr()); } template <> @@ -84,7 +85,8 @@ void StoreOpOfSubViewFolder::replaceOp( ArrayRef sourceIndices, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(), - sourceIndices); + sourceIndices, tranferWriteOp.permutation_map(), + tranferWriteOp.maskedAttr()); } } // namespace diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir index acbda35..c5c5961 100644 --- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir @@ -67,16 +67,17 @@ func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : // CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index func @fold_static_stride_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<4xf32> { // CHECK-NOT: subview + // CHECK: [[F1:%.*]] = constant 1.000000e+00 : f32 // CHECK: [[C2:%.*]] = constant 2 : index // CHECK: [[C3:%.*]] = constant 3 : index // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index - // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %f0 = constant 0.0 : f32 + // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}, [[F1]] {masked = [false]} + %f1 = constant 1.0 : f32 %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> - %1 = vector.transfer_read %0[%arg3, %arg4], %f0 : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32> + %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {masked = [false]} : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32> return %1 : vector<4xf32> } @@ -90,9 +91,9 @@ func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index - // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} + // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} {masked = [false]} %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> - vector.transfer_write %arg5, %0[%arg3, %arg4] : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]> + vector.transfer_write %arg5, %0[%arg3, %arg4] {masked = [false]} : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]> return } -- 2.7.4