[mlir] Fix SameOperandsAndResultType to check encoding.
authorJacques Pienaar <jpienaar@google.com>
Wed, 21 Dec 2022 17:49:18 +0000 (09:49 -0800)
committerJacques Pienaar <jpienaar@google.com>
Wed, 21 Dec 2022 17:49:18 +0000 (09:49 -0800)
Encoding was accidentally left out here even though it forms part of the type.
This is small tightening step and I'll look at follow on to tighten more.

Differential Revision: https://reviews.llvm.org/D140445

mlir/lib/IR/Operation.cpp
mlir/test/IR/traits.mlir

index 219f1e2..d44d0b1 100644 (file)
@@ -893,17 +893,30 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
 
   auto type = op->getResult(0).getType();
   auto elementType = getElementTypeOrSelf(type);
+  Attribute encoding = nullptr;
+  if (auto rankedType = dyn_cast<RankedTensorType>(type))
+    encoding = rankedType.getEncoding();
   for (auto resultType : llvm::drop_begin(op->getResultTypes())) {
     if (getElementTypeOrSelf(resultType) != elementType ||
         failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
+    if (encoding)
+      if (auto rankedType = dyn_cast<RankedTensorType>(resultType);
+          encoding != rankedType.getEncoding())
+        return op->emitOpError()
+               << "requires the same encoding for all operands and results";
   }
   for (auto opType : op->getOperandTypes()) {
     if (getElementTypeOrSelf(opType) != elementType ||
         failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
+    if (encoding)
+      if (auto rankedType = dyn_cast<RankedTensorType>(opType);
+          encoding != rankedType.getEncoding())
+        return op->emitOpError()
+               << "requires the same encoding for all operands and results";
   }
   return success();
 }
index 80e0d4c..ddba117 100644 (file)
@@ -174,6 +174,14 @@ func.func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<
 
 // -----
 
+func.func @failedSameOperandAndResultType_encoding_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<10xf32>) {
+  // expected-error@+1 {{requires the same encoding for all operands and results}}
+  "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32, "enc">
+  return
+}
+
+// -----
+
 func.func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
   // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>