From 85b46314c047532e5d38b3c417cb65014c97033e Mon Sep 17 00:00:00 2001
From: Smit Hinsu <hinsu@google.com>
Date: Tue, 8 Oct 2019 19:36:37 -0700
Subject: [PATCH] Allow dynamic but ranked types in ops with
 SameOperandsAndResultShape and SameOperandsAndResultType traits

Currently, the SameOperandsAndResultShape trait allows operands to have
tensor<*xf32> and tensor<2xf32> types but does not allow tensor<?xf32> and
tensor<10xf32>. Dimensions are now considered compatible if at least one
of them is dynamic or both are equal.

Also, use the updated shape compatibility helper function in the
TensorCastOp::areCastCompatible method.

PiperOrigin-RevId: 273658336
---
 mlir/include/mlir/IR/TypeUtilities.h |  7 +++++++
 mlir/lib/Dialect/StandardOps/Ops.cpp | 19 +------------------
 mlir/lib/IR/Operation.cpp            | 33 ++++++++-------------------------
 mlir/lib/IR/TypeUtilities.cpp        | 31 +++++++++++++++++++++++++++++++
 mlir/test/IR/invalid-ops.mlir        |  8 --------
 mlir/test/IR/traits.mlir             | 21 ++++++++++++++++++++-
 mlir/test/lib/TestDialect/TestOps.td |  6 ++++++
 7 files changed, 73 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h
index ce0169f..49d57e8 100644
--- a/mlir/include/mlir/IR/TypeUtilities.h
+++ b/mlir/include/mlir/IR/TypeUtilities.h
@@ -52,6 +52,13 @@ SmallVector<Type, 10> getFlattenedTypes(TupleType t);
 /// dialect and typeData.
 bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData);
 
+/// Returns success if the given two types have compatible shape. That is,
+/// they are both scalars (not shaped), or they are both shaped types and at
+/// least one is unranked or they have compatible dimensions. Dimensions are
+/// compatible if at least one is dynamic or both are equal. The element type
+/// does not matter.
+LogicalResult verifyCompatibleShape(Type type1, Type type2);
+
 //===----------------------------------------------------------------------===//
 // Utility Iterators
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index f71fde7..5cbdb67 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2215,24 +2215,7 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
   if (aT.getElementType() != bT.getElementType())
     return false;
 
-  // If the either are unranked, then the cast is valid.
-  auto aRType = aT.dyn_cast<RankedTensorType>();
-  auto bRType = bT.dyn_cast<RankedTensorType>();
-  if (!aRType || !bRType)
-    return true;
-
-  // If they are both ranked, they have to have the same rank, and any specified
-  // dimensions must match.
-  if (aRType.getRank() != bRType.getRank())
-    return false;
-
-  for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) {
-    int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i);
-    if (aDim != -1 && bDim != -1 && aDim != bDim)
-      return false;
-  }
-
-  return true;
+  return succeeded(verifyCompatibleShape(aT, bT));
 }
 
 OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {

diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 23983bc..adf38ca 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -748,33 +748,13 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
   return success();
 }
 
-/// Returns success if the given two types have the same shape. That is,
-/// they are both scalars (not shaped), or they are both shaped types and at
-/// least one is unranked or they have the same shape. The element type does not
-/// matter.
-static LogicalResult verifyShapeMatch(Type type1, Type type2) {
-  auto sType1 = type1.dyn_cast<ShapedType>();
-  auto sType2 = type2.dyn_cast<ShapedType>();
-
-  // Either both or neither type should be shaped.
-  if (!sType1)
-    return success(!sType2);
-  if (!sType2)
-    return failure();
-
-  if (!sType1.hasRank() || !sType2.hasRank())
-    return success();
-
-  return success(sType1.getShape() == sType2.getShape());
-}
-
 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
   if (failed(verifyAtLeastNOperands(op, 1)))
     return failure();
 
   auto type = op->getOperand(0)->getType();
   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
-    if (failed(verifyShapeMatch(opType, type)))
+    if (failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError() << "requires the same shape for all operands";
   }
   return success();
@@ -787,12 +767,12 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
 
   auto type = op->getOperand(0)->getType();
   for (auto resultType : op->getResultTypes()) {
-    if (failed(verifyShapeMatch(resultType, type)))
+    if (failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
              << "requires the same shape for all operands and results";
   }
   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
-    if (failed(verifyShapeMatch(opType, type)))
+    if (failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError()
              << "requires the same shape for all operands and results";
   }
@@ -843,13 +823,16 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
     return failure();
 
   auto type = op->getResult(0)->getType();
+  auto elementType = getElementTypeOrSelf(type);
   for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
-    if (resultType != type)
+    if (getElementTypeOrSelf(resultType) != elementType ||
+        failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
   }
   for (auto opType : op->getOperandTypes()) {
-    if (opType != type)
+    if (getElementTypeOrSelf(opType) != elementType ||
+        failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
   }

diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index 95895af..a963a8d 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -61,6 +61,37 @@ bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
   return false;
 }
 
+/// Returns success if the given two types have compatible shape. That is,
+/// they are both scalars (not shaped), or they are both shaped types and at
+/// least one is unranked or they have compatible dimensions. Dimensions are
+/// compatible if at least one is dynamic or both are equal. The element type
+/// does not matter.
+LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
+  auto sType1 = type1.dyn_cast<ShapedType>();
+  auto sType2 = type2.dyn_cast<ShapedType>();
+
+  // Either both or neither type should be shaped.
+  if (!sType1)
+    return success(!sType2);
+  if (!sType2)
+    return failure();
+
+  if (!sType1.hasRank() || !sType2.hasRank())
+    return success();
+
+  if (sType1.getRank() != sType2.getRank())
+    return failure();
+
+  for (const auto &dims : llvm::zip(sType1.getShape(), sType2.getShape())) {
+    int64_t dim1 = std::get<0>(dims);
+    int64_t dim2 = std::get<1>(dims);
+    if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
+        dim1 != dim2)
+      return failure();
+  }
+  return success();
+}
+
 OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
     : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}

diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index d28200b..be44a6b 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -297,14 +297,6 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
 
 // -----
 
-func @func_with_ops(tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) {
-^bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
-  // expected-error@+1 {{requires the same shape for all operands and results}}
-  %r = "std.select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
-}
-
-// -----
-
 func @test_vector.transfer_read(memref<?x?xf32>) {
 ^bb0(%arg0: memref<?x?xf32>):
   %c3 = constant 3 : index

diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 926547c..6c7fddb 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -113,10 +113,11 @@ func @failedSameOperandShape_no_operands() {
 // -----
 
 // CHECK: succeededSameOperandAndResultShape
-func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
+func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>, %t1d: tensor<?xf32>) {
   %0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
   %1 = "test.same_operand_and_result_shape"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   %2 = "test.same_operand_and_result_shape"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
+  %3 = "test.same_operand_and_result_shape"(%t1, %t1d) : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
   return
 }
 
@@ -143,6 +144,24 @@ func @failedSameOperandAndResultShape_no_operands(%t1: tensor<1xf32>) {
 
 // -----
 
+// CHECK: succeededSameOperandAndResultType
+func @succeededSameOperandAndResultType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>, %t1d: tensor<?xf32>) {
+  %0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %1 = "test.same_operand_and_result_type"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %2 = "test.same_operand_and_result_type"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
+  %3 = "test.same_operand_and_result_type"(%t1, %t1d) : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
+  return
+}
+
+// -----
+
+func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<20xf32>) {
+  // expected-error@+1 {{requires the same type for all operands and results}}
+  %0 = "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xf32>
+}
+
+// -----
+
 func @failedHasParent_wrong_parent() {
   "some.op"() ({
     // expected-error@+1 {{'test.child' op expects parent op 'test.parent'}}

diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index dd620de..68bae70 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -257,6 +257,12 @@ def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
   let results = (outs Variadic<AnyType>);
 }
 
+def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
+     [SameOperandsAndResultType]> {
+  let arguments = (ins Variadic<AnyType>);
+  let results = (outs Variadic<AnyType>);
+}
+
 def ArgAndResHaveFixedElementTypesOp : TEST_Op<"arg_and_res_have_fixed_element_types",
     [PredOpTrait<"fixed type combination",
-- 
2.7.4
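
For reviewers, a minimal sketch of how a dialect-specific verifier could
reuse the new helper under the same contract the relaxed traits use above:
element types must match exactly, shapes only need to be compatible.
MyBinaryOp and its lhs()/rhs() accessors are hypothetical; verifyCompatibleShape
and getElementTypeOrSelf are the functions touched in this patch, and the
Value * accessor style matches MLIR at this revision:

  #include "mlir/IR/TypeUtilities.h"

  using namespace mlir;

  // Hypothetical verifier: tensor<?xf32> verifies against tensor<10xf32>,
  // but mismatched element types or conflicting static dims are rejected.
  static LogicalResult verify(MyBinaryOp op) {
    Type lhs = op.lhs()->getType();
    Type rhs = op.rhs()->getType();
    if (getElementTypeOrSelf(lhs) != getElementTypeOrSelf(rhs))
      return op.emitOpError("requires matching element types");
    // Equal, dynamic, or unranked dimensions pass; rank mismatches and
    // conflicting static dimensions fail.
    return verifyCompatibleShape(lhs, rhs);
  }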
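
The dimension rule itself is small enough to check in isolation. A
self-contained sketch with no MLIR dependency, where kDynamic stands in for
the -1 sentinel that ShapedType::isDynamic tests at this revision:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Stand-in for MLIR's dynamic-dimension sentinel (-1 at this revision).
  constexpr int64_t kDynamic = -1;

  // Ranked shapes are compatible iff ranks match and each dimension pair
  // is equal or has at least one dynamic side -- the same loop as
  // verifyCompatibleShape in the patch.
  static bool compatibleShapes(const std::vector<int64_t> &a,
                               const std::vector<int64_t> &b) {
    if (a.size() != b.size())
      return false;
    for (size_t i = 0, e = a.size(); i != e; ++i)
      if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i])
        return false;
    return true;
  }

  int main() {
    assert(compatibleShapes({kDynamic}, {10}));    // ? vs 10: now allowed
    assert(!compatibleShapes({2}, {10}));          // static mismatch: rejected
    assert(!compatibleShapes({kDynamic}, {2, 2})); // rank mismatch: rejected
    return 0;
  }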