// Operator: reduce_all
//===----------------------------------------------------------------------===//
def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce All operator";
let description = [{
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_any
//===----------------------------------------------------------------------===//
def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Any operator";
let description = [{
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_max
//===----------------------------------------------------------------------===//
def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Max operator";
let description = [{
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_min
//===----------------------------------------------------------------------===//
def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Min operator";
let description = [{
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_prod
//===----------------------------------------------------------------------===//
def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Prod operator";
let description = [{
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
// Operator: reduce_sum
//===----------------------------------------------------------------------===//
def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- Pure]> {
+ InferTensorType, Pure]> {
let summary = "Reduce Sum operator";
let description = [{
);
let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
return success();
}
-bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
- if (l.size() != r.size() || l.size() != 1)
- return false;
- if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))
- return false;
- return succeeded(verifyCompatibleShape(l[0], r[0]));
-}
-
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
}
static LogicalResult ReduceInferReturnTypes(
- ShapeAdaptor operandShape, IntegerAttr axis,
+ ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
if (!operandShape.hasRank()) {
- inferredReturnShapes.push_back(ShapedTypeComponents());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
operandShape.getDims(outputShape);
int64_t axisVal = axis.getValue().getSExtValue();
outputShape[axisVal] = 1;
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
+#define COMPATIBLE_RETURN_TYPES(OP) \
+ bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
+ if (l.size() != r.size() || l.size() != 1) \
+ return false; \
+ if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
+ return false; \
+ return succeeded(verifyCompatibleShape(l[0], r[0])); \
+ }
+
#define REDUCE_SHAPE_INFER(OP) \
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::std::optional<Location> location, \
ValueShapeRange operands, DictionaryAttr attributes, \
RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
- return ReduceInferReturnTypes(operands.getShape(0), \
+ Type inputType = \
+ operands.getType()[0].cast<TensorType>().getElementType(); \
+ return ReduceInferReturnTypes(operands.getShape(0), inputType, \
attributes.get("axis").cast<IntegerAttr>(), \
inferredReturnShapes); \
- }
+ } \
+ COMPATIBLE_RETURN_TYPES(OP)
REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
+COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
+#undef COMPATIBLE_RETURN_TYPES
static LogicalResult NAryInferReturnTypes(
const ValueShapeRange &operands,
%2 = "tosa.fully_connected"(%1, %0, %arg1) : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
return %2 : tensor<273x2xf32>
}
+
+// -----
+
+func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
+ %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_max' op inferred type(s) 'tensor<2x3x4x1xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x1xi32>'}}
+ %0 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_min' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x1x4x5xi32>'}}
+ %0 = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32>
+ return
+}
+
+// -----
+
+func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+ // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
+ %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
+ return
+}