From: Frederik Gossen Date: Mon, 15 Mar 2021 08:47:00 +0000 (+0100) Subject: [MLIR] Allow compatible shapes in `Elementwise` operations X-Git-Tag: llvmorg-14-init~12388 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2a71f95767490f5ac65d42bf55ad571e6fbd1123;p=platform%2Fupstream%2Fllvm.git [MLIR] Allow compatible shapes in `Elementwise` operations Differential Revision: https://reviews.llvm.org/D98186 --- diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 0614a7b..f427e10 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1051,15 +1051,6 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { return success(); } -/// Checks if two ShapedTypes are the same, ignoring the element type. -static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) { - if (a.getTypeID() != b.getTypeID()) - return false; - if (!a.hasRank()) - return !b.hasRank(); - return a.getShape() == b.getShape(); -} - LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { auto isMappableType = [](Type type) { return type.isa(); @@ -1088,15 +1079,14 @@ LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { return op->emitOpError( "if an operand is non-scalar, then all results must be non-scalar"); - auto mustMatchType = operandMappableTypes[0].cast(); - for (auto type : - llvm::concat(resultMappableTypes, operandMappableTypes)) { - if (!areSameShapedTypeIgnoringElementType(type.cast(), - mustMatchType)) { - return op->emitOpError() << "all non-scalar operands/results must have " - "the same shape and base type: found " - << type << " and " << mustMatchType; - } + SmallVector types = llvm::to_vector<2>( + llvm::concat(operandMappableTypes, resultMappableTypes)); + TypeID expectedBaseTy = types.front().getTypeID(); + if (!llvm::all_of(types, + [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || + failed(verifyCompatibleShapes(types))) { + return op->emitOpError() << "all non-scalar operands/results must have the " + "same shape and base type"; } return success(); diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir index 9b986e5..c6f11a8 100644 --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -split-input-file %s -verify-diagnostics func @test_index_cast_shape_error(%arg0 : tensor) -> tensor<2xi64> { - // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xi64>' and 'tensor'}} + // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type}} %0 = index_cast %arg0 : tensor to tensor<2xi64> return %0 : tensor<2xi64> } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index b5ee968..797c1d4 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -236,7 +236,7 @@ func @func_with_ops(i32, i32) { func @func_with_ops() { ^bb0: %c = constant dense<0> : vector<42 x i32> - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xi32>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %r = "std.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1> } @@ -269,7 +269,7 @@ func @func_with_ops(i1, i32, i64) { func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) { ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>): - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<42xi32>' and 'vector<12xi1>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32> } @@ -277,7 +277,7 @@ func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) { func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) { ^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>): - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<42xi32>' and 'tensor<12xi1>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> } @@ -514,7 +514,7 @@ func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 { // ----- func @cmpf_result_shape_mismatch(%a : vector<42xf32>) { - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xf32>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %r = "std.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1> } @@ -614,7 +614,7 @@ func @fpext_f32_to_i32(%arg0 : f32) { // ----- func @fpext_vec(%arg0 : vector<2xf16>) { - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %0 = fpext %arg0 : vector<2xf16> to vector<3xf32> return } @@ -686,7 +686,7 @@ func @fptrunc_f32_to_i32(%arg0 : f32) { // ----- func @fptrunc_vec(%arg0 : vector<2xf16>) { - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32> return } diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 858f601..dc9e510 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -169,33 +169,38 @@ func @succeededSameOperandAndResultType(%t10x10 : tensor<10x10xf32>, %t1: tensor func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<20xf32>) { // expected-error@+1 {{requires the same type for all operands and results}} "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xf32> + return } // ----- func @failedElementwiseMappable_different_rankedness(%arg0: tensor, %arg1: tensor<*xf32>) { - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<*xf32>' and 'tensor'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor, tensor<*xf32>) -> tensor<*xf32> + return } // ----- func @failedElementwiseMappable_different_rank(%arg0: tensor, %arg1: tensor) { - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor' and 'tensor'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor, tensor) -> tensor + return } // ----- -func @failedElementwiseMappable_different_shape(%arg0: tensor, %arg1: tensor<5xf32>) { - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<5xf32>' and 'tensor'}} - %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor, tensor<5xf32>) -> tensor +func @elementwiseMappable_dynamic_shapes(%arg0: tensor, + %arg1: tensor<5xf32>) { + %0 = "test.elementwise_mappable"(%arg0, %arg1) : + (tensor, tensor<5xf32>) -> tensor + return } // ----- func @failedElementwiseMappable_different_base_type(%arg0: vector<2xf32>, %arg1: tensor<2xf32>) { - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xf32>' and 'vector<2xf32>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} %0 = "test.elementwise_mappable"(%arg0, %arg1) : (vector<2xf32>, tensor<2xf32>) -> tensor<2xf32> return }