// Operator: concat
//===----------------------------------------------------------------------===//
def Tosa_ConcatOp : Tosa_Op<"concat", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
+ // InferTensorType bundles InferShapedTypeOpInterface with full result-type
+ // inference (shape AND element type), replacing the narrower
+ // inferReturnTypeComponents-only declaration above. It requires the op to
+ // supply isCompatibleReturnTypes, declared in extraClassDeclaration below.
+ InferTensorType,
Pure]> {
let summary = "Concatenates tensors along one dimension.";
);
let hasCanonicalizer = 1;
+
+ let extraClassDeclaration = [{
+ /// Returns true when two result types are compatible for this op;
+ /// Method used by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
return success();
}
+/// InferTypeOpInterface compatibility hook. Concat produces exactly one
+/// result, so two return-type lists are compatible only when both contain a
+/// single type and those types' shapes agree up to dynamic dimensions
+/// (verifyCompatibleShape treats a dynamic dim as matching anything).
+bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ if (l.size() != r.size() || l.size() != 1)
+ return false;
+ return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
if (outputShape[i] == ShapedType::kDynamic)
outputShape[i] = operandShape.getDimSize(i);
if (outputShape[i] != operandShape.getDimSize(i))
- return failure();
+ // Emit a diagnostic (when a location is available) instead of failing
+ // silently, so shape-inference callers can surface WHY the concat was
+ // rejected; the message text is pinned by the expected-error test below.
+ return emitOptionalError(location,
+ "Cannot concat tensors with different sizes"
+ " on the non-axis dimension ",
+ i);
}
hasRankedInput = true;
}
-
+ // NOTE(review): assumes at least one operand and that all operands share an
+ // element type — presumably guaranteed by the op's verifier/type
+ // constraints; confirm against the op definition.
+ Type inputType = operands.getType()[0].cast<TensorType>().getElementType();
if (!hasRankedInput) {
- inferredReturnShapes.push_back(ShapedTypeComponents());
+ // No ranked operand: result shape stays unranked, but the element type is
+ // still propagated (required now that the op uses InferTensorType).
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}
outputShape[axis] = concatDimSize;
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ // Ranked result: propagate both the inferred shape and the element type.
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
return %0 : tensor<1x27x27x16xi8>
}
+// -----
+// Shape inference must reject concatenation along axis 0 when a non-axis
+// dimension differs between operands (here dim 1: 1 vs 2), and the
+// diagnostic must name the offending dimension.
+func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
+ // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+ %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
// -----
-// CHECK-LABEL: @test_concat_failure
-func.func @test_concat_failure(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
- // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
- %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-
- return
-}
-
-// -----
-
// CHECK-LABEL: @test_padding_no_const
func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
// CHECK: "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>