From: Jacques Pienaar Date: Wed, 21 Dec 2022 17:49:18 +0000 (-0800) Subject: [mlir] Fix SameOperandsAndResultType to check encoding. X-Git-Tag: upstream/17.0.6~22941 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=829733af4ac2895543797443c82f1f1709472c4f;p=platform%2Fupstream%2Fllvm.git [mlir] Fix SameOperandsAndResultType to check encoding. Encoding was accidentally left out here even though it forms part of the type. This is small tightening step and I'll look at follow on to tighten more. Differential Revision: https://reviews.llvm.org/D140445 --- diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 219f1e2..d44d0b1 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -893,17 +893,30 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { auto type = op->getResult(0).getType(); auto elementType = getElementTypeOrSelf(type); + Attribute encoding = nullptr; + if (auto rankedType = dyn_cast(type)) + encoding = rankedType.getEncoding(); for (auto resultType : llvm::drop_begin(op->getResultTypes())) { if (getElementTypeOrSelf(resultType) != elementType || failed(verifyCompatibleShape(resultType, type))) return op->emitOpError() << "requires the same type for all operands and results"; + if (encoding) + if (auto rankedType = dyn_cast(resultType); + encoding != rankedType.getEncoding()) + return op->emitOpError() + << "requires the same encoding for all operands and results"; } for (auto opType : op->getOperandTypes()) { if (getElementTypeOrSelf(opType) != elementType || failed(verifyCompatibleShape(opType, type))) return op->emitOpError() << "requires the same type for all operands and results"; + if (encoding) + if (auto rankedType = dyn_cast(opType); + encoding != rankedType.getEncoding()) + return op->emitOpError() + << "requires the same encoding for all operands and results"; } return success(); } diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 80e0d4c..ddba117 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -174,6 +174,14 @@ func.func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor< // ----- +func.func @failedSameOperandAndResultType_encoding_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<10xf32>) { + // expected-error@+1 {{requires the same encoding for all operands and results}} + "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32, "enc"> + return +} + +// ----- + func.func @failedElementwiseMappable_different_rankedness(%arg0: tensor, %arg1: tensor<*xf32>) { // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor, tensor<*xf32>) -> tensor<*xf32>