From a971d519327bbaf9a00b5bf5ec7ce385f6235e85 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Tue, 29 Nov 2022 11:38:15 -0800 Subject: [PATCH] [mlir][tensor] Enhance the verifier of pack and unpack op. The outer_dims_perm must be a permutation or empty. Reviewed By: chelini Differential Revision: https://reviews.llvm.org/D138936 --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 ++ mlir/test/Dialect/Tensor/invalid.mlir | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 2910191..23bfb23 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3170,6 +3170,8 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return op->emitError("invalid inner_dims_pos vector"); if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank)) return op->emitError("invalid outer_dims_perm vector"); + if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank) + return op->emitError("outer_dims_perm must be a permutation or empty"); // Tiling factors must be less than or equal to the input rank for pack (or // output rank for unpack), and must match the number of `inner_dims_pos`. diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index af26648..36c4dfe 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -587,6 +587,22 @@ func.func @unpack_invalid_out_of_bound_outer_perm(%input: tensor<256x128xf32>, % // ----- +func.func @pack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<16x4x32x16xf32> { + // expected-error@+1 {{outer_dims_perm must be a permutation or empty}} + %0 = tensor.pack %source outer_dims_perm = [0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<16x4x32x16xf32> + return %0 : tensor<16x4x32x16xf32> +} + +// ----- + +func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: tensor<16x4x32x16xf32>) -> tensor<128x256xf32> { + // expected-error@+1 {{outer_dims_perm must be a permutation or empty}} + %0 = tensor.unpack %dest outer_dims_perm = [1] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<16x4x32x16xf32> -> tensor<128x256xf32> + return %0 : tensor<128x256xf32> +} + +// ----- + func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> { // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}} %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32> -- 2.7.4