return rewriter.notifyMatchFailure(operation,
"All results must be a shaped type");
- // For now require no broadcasting. Consider making it support broadcasting
- // operations.
- Type uniqueInTy = operation->getOperand(0).getType();
- bool allInputTypesEqual =
- llvm::all_of(operation->getOperandTypes(),
- [&](Type operandTy) { return operandTy == uniqueInTy; });
- if (!allInputTypesEqual)
- return rewriter.notifyMatchFailure(operation,
- "All operands must have the same type");
- bool resultAndInputShapeEqual =
- llvm::all_of(operation->getResultTypes(), [&](Type resultTy) {
- return resultTy.cast<ShapedType>().getShape() == t0.getShape();
- });
-
- if (!resultAndInputShapeEqual)
- return rewriter.notifyMatchFailure(
- operation, "All results must have the same shape as the input");
+ assert(operation->getNumResults() == 1 &&
+ "All TOSA elementwise ops should only return a single result.");
// Construct the indexing maps needed for linalg.generic ops.
SmallVector<Type> bodyArgTypes;
auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
initTensors, [](Value v) { return getElementTypeOrSelf(v); }));
- // Supports only non-broadcasted operation. Shoudl consider update indexing
- // map to be multidimensional.
unsigned nloops = t0.getRank();
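+  // The loop count comes from the first operand's rank; all operands are
+  // assumed to share that rank, with broadcast dimensions having size 1.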
- AffineMap commonIndexingMap = rewriter.getMultiDimIdentityMap(nloops);
- SmallVector<AffineMap, 2> indexingMaps(
- operation->getNumOperands() + bodyResultTypes.size(), commonIndexingMap);
+ SmallVector<AffineMap, 2> indexingMaps;
+ indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
+
+  // Inputs may be broadcast, so build a separate indexing map per operand.
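+  // For example, a tensor<1x3xf32> operand of a rank-2 op is indexed through
+  // the map (d0, d1) -> (0, d1), pinning its unit dimension to index zero.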
+  for (Type type : operation->getOperandTypes()) {
+    auto shape = type.cast<ShapedType>().getShape();
+ SmallVector<AffineExpr, 4> dimExprs;
+ dimExprs.reserve(nloops);
+ for (unsigned i = 0; i < nloops; ++i) {
+      // If the dimension is 1, the input broadcasts along this loop; index it
+      // with the constant affine expression 0.
+ if (shape[i] == 1)
+ dimExprs.push_back(rewriter.getAffineConstantExpr(0));
+ else
+ dimExprs.push_back(rewriter.getAffineDimExpr(i));
+ }
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
+ /*symbolCount=*/0, dimExprs,
+ rewriter.getContext()));
+ }
+
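+  // Results are never broadcast, so every result uses the identity map.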
+ indexingMaps.append(operation->getNumResults(),
+ rewriter.getMultiDimIdentityMap(nloops));
bool didEncounterError = false;
auto linalgOp = rewriter.create<linalg::GenericOp>(
// RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -verify-diagnostics -o -| FileCheck %s
-// CHECK: #map = affine_map<() -> ()>
+// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
// CHECK-LABEL: @test_abs
func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor<f32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
// CHECK: ^bb0(%arg1: f32, %arg2: f32):
// CHECK: [[ELEMENT:%.+]] = absf %arg1
// CHECK: linalg.yield [[ELEMENT]] : f32
// -----
-// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> {
- // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
// CHECK: ^bb0(%arg1: f32, %arg2: f32):
// CHECK: [[ELEMENT:%.+]] = absf %arg1
// CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<1xf32>
- %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+ // CHECK: } -> tensor<2xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: return [[GENERIC]]
- return %0 : tensor<1xf32>
+ return %0 : tensor<2xf32>
}
// -----
-// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
- // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2] : tensor<1x2xf32>
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x2xf32>) outs([[INIT]] : tensor<1x2xf32>) {
+func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
// CHECK: ^bb0(%arg1: f32, %arg2: f32):
// CHECK: [[ELEMENT:%.+]] = absf %arg1
// CHECK: linalg.yield [[ELEMENT]] : f32
- // CHECK: } -> tensor<1x2xf32>
- %0 = "tosa.abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+ // CHECK: } -> tensor<2x3xf32>
+ %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK: return [[GENERIC]]
- return %0 : tensor<1x2xf32>
+ return %0 : tensor<2x3xf32>
}
// -----
-func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
- // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
+
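+// The tensor<1xf32> operand broadcasts along d0, so it is read through the
+// constant map (d0) -> (0), while the tensor<2xf32> operand uses the identity.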
+// CHECK-LABEL: @test_broadcast
+func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+ // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+ // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<2xf32>
%0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
-func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
- // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
- %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
- return %0 : tensor<1xf32>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
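+// Each operand broadcasts along a different dimension: %arg0 along d0 and
+// %arg1 along d1.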
+// CHECK-LABEL: @test_multibroadcast
+func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1x3xf32>, tensor<2x1xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+ // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+ // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+ // CHECK: linalg.yield [[ELEMENT]] : f32
+ // CHECK: } -> tensor<2x3xf32>
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+ return %0 : tensor<2x3xf32>
}
// -----