return success();
}
-/// Checks if two ShapedTypes are the same, ignoring the element type.
-static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
- if (a.getTypeID() != b.getTypeID())
- return false;
- if (!a.hasRank())
- return !b.hasRank();
- return a.getShape() == b.getShape();
-}
-
LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
auto isMappableType = [](Type type) {
return type.isa<VectorType, TensorType>();
return op->emitOpError(
"if an operand is non-scalar, then all results must be non-scalar");
- auto mustMatchType = operandMappableTypes[0].cast<ShapedType>();
- for (auto type :
- llvm::concat<Type>(resultMappableTypes, operandMappableTypes)) {
- if (!areSameShapedTypeIgnoringElementType(type.cast<ShapedType>(),
- mustMatchType)) {
- return op->emitOpError() << "all non-scalar operands/results must have "
- "the same shape and base type: found "
- << type << " and " << mustMatchType;
- }
+ SmallVector<Type, 4> types = llvm::to_vector<2>(
+ llvm::concat<Type>(operandMappableTypes, resultMappableTypes));
+ TypeID expectedBaseTy = types.front().getTypeID();
+ if (!llvm::all_of(types,
+ [&](Type t) { return t.getTypeID() == expectedBaseTy; }) ||
+ failed(verifyCompatibleShapes(types))) {
+ return op->emitOpError() << "all non-scalar operands/results must have the "
+ "same shape and base type";
}
return success();
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
- // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xi64>' and 'tensor<index>'}}
+ // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type}}
%0 = index_cast %arg0 : tensor<index> to tensor<2xi64>
return %0 : tensor<2xi64>
}
func @func_with_ops() {
^bb0:
%c = constant dense<0> : vector<42 x i32>
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xi32>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%r = "std.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1>
}
func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<42xi32>' and 'vector<12xi1>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
}
func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<42xi32>' and 'tensor<12xi1>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
}
// -----
func @cmpf_result_shape_mismatch(%a : vector<42xf32>) {
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xf32>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%r = "std.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1>
}
// -----
func @fpext_vec(%arg0 : vector<2xf16>) {
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%0 = fpext %arg0 : vector<2xf16> to vector<3xf32>
return
}
// -----
func @fptrunc_vec(%arg0 : vector<2xf16>) {
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32>
return
}
func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<20xf32>) {
// expected-error@+1 {{requires the same type for all operands and results}}
"test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xf32>
+ return
}
// -----
func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<*xf32>' and 'tensor<?xf32>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>
+ return
}
// -----
func @failedElementwiseMappable_different_rank(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) {
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<?x?xf32>' and 'tensor<?xf32>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return
}
// -----
-func @failedElementwiseMappable_different_shape(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<5xf32>' and 'tensor<?xf32>'}}
- %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+func @elementwiseMappable_dynamic_shapes(%arg0: tensor<?xf32>,
+ %arg1: tensor<5xf32>) {
+ %0 = "test.elementwise_mappable"(%arg0, %arg1) :
+ (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+ return
}
// -----
func @failedElementwiseMappable_different_base_type(%arg0: vector<2xf32>, %arg1: tensor<2xf32>) {
- // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xf32>' and 'vector<2xf32>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
%0 = "test.elementwise_mappable"(%arg0, %arg1) : (vector<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return
}