[MLIR][TOSA] Tosa elementwise broadcasting
authorRob Suderman <rob.suderman@gmail.com>
Sat, 6 Feb 2021 00:45:36 +0000 (16:45 -0800)
committerRob Suderman <rob.suderman@gmail.com>
Wed, 10 Feb 2021 23:28:18 +0000 (15:28 -0800)
Added support for broadcasting size-1 dimensions for TOSA elemtnwise
operations.

Differential Revision: https://reviews.llvm.org/D96190

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index 8ecd15d..fcc5a52 100644 (file)
@@ -152,23 +152,8 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
     return rewriter.notifyMatchFailure(operation,
                                        "All results must be a shaped type");
 
-  // For now require no broadcasting. Consider making it support broadcasting
-  // operations.
-  Type uniqueInTy = operation->getOperand(0).getType();
-  bool allInputTypesEqual =
-      llvm::all_of(operation->getOperandTypes(),
-                   [&](Type operandTy) { return operandTy == uniqueInTy; });
-  if (!allInputTypesEqual)
-    return rewriter.notifyMatchFailure(operation,
-                                       "All operands must have the same type");
-  bool resultAndInputShapeEqual =
-      llvm::all_of(operation->getResultTypes(), [&](Type resultTy) {
-        return resultTy.cast<ShapedType>().getShape() == t0.getShape();
-      });
-
-  if (!resultAndInputShapeEqual)
-    return rewriter.notifyMatchFailure(
-        operation, "All results must have the same shape as the input");
+  assert(operation->getNumResults() == 1 &&
+         "All TOSA elementwise ops should only return a single result.");
 
   // Construct the indexing maps needed for linalg.generic ops.
   SmallVector<Type> bodyArgTypes;
@@ -194,12 +179,30 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
       initTensors, [](Value v) { return getElementTypeOrSelf(v); }));
 
-  // Supports only non-broadcasted operation. Shoudl consider update indexing
-  // map to be multidimensional.
   unsigned nloops = t0.getRank();
-  AffineMap commonIndexingMap = rewriter.getMultiDimIdentityMap(nloops);
-  SmallVector<AffineMap, 2> indexingMaps(
-      operation->getNumOperands() + bodyResultTypes.size(), commonIndexingMap);
+  SmallVector<AffineMap, 2> indexingMaps;
+  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
+
+  // Input indexing maps may be broadcasted.
+  for (Type types : operation->getOperandTypes()) {
+    auto shape = types.cast<ShapedType>().getShape();
+    SmallVector<AffineExpr, 4> dimExprs;
+    dimExprs.reserve(nloops);
+    for (unsigned i = 0; i < nloops; ++i) {
+      // If the dimension is one we can broadcast the input with a constant
+      // affine expression.
+      if (shape[i] == 1)
+        dimExprs.push_back(rewriter.getAffineConstantExpr(0));
+      else
+        dimExprs.push_back(rewriter.getAffineDimExpr(i));
+    }
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
+                                          /*symbolCount=*/0, dimExprs,
+                                          rewriter.getContext()));
+  }
+
+  indexingMaps.append(operation->getNumResults(),
+                      rewriter.getMultiDimIdentityMap(nloops));
 
   bool didEncounterError = false;
   auto linalgOp = rewriter.create<linalg::GenericOp>(
index e416246..8963544 100644 (file)
@@ -1,11 +1,11 @@
 // RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -verify-diagnostics -o -| FileCheck %s
 
-// CHECK: #map = affine_map<() -> ()>
+// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
 
 // CHECK-LABEL: @test_abs
 func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [] : tensor<f32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%arg0 : tensor<f32>) outs([[INIT]] : tensor<f32>) {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32):
   // CHECK:   [[ELEMENT:%.+]] = absf %arg1
   // CHECK:   linalg.yield [[ELEMENT]] : f32
@@ -19,54 +19,73 @@ func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
 
 // -----
 
-// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1xf32>) -> tensor<1xf32> {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] : tensor<1xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs([[INIT]] : tensor<1xf32>) {
+func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32):
   // CHECK:   [[ELEMENT:%.+]] = absf %arg1
   // CHECK:   linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<1xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  // CHECK: } -> tensor<2xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
 
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<1xf32>
+  return %0 : tensor<2xf32>
 }
 
 // -----
 
-// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @test_abs
-func @test_abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2] : tensor<1x2xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<1x2xf32>) outs([[INIT]] : tensor<1x2xf32>) {
+func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32):
   // CHECK:   [[ELEMENT:%.+]] = absf %arg1
   // CHECK:   linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<1x2xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
+  // CHECK: } -> tensor<2x3xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
 
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<1x2xf32>
+  return %0 : tensor<2x3xf32>
 }
 
 // -----
 
-func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @test_broadcast
+func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+  // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+  // CHECK:   [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+  // CHECK:   linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<2xf32>
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
   return %0 : tensor<2xf32>
 }
 
 // -----
 
-func @test_add(%arg0: tensor<1xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
-  // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
-  return %0 : tensor<1xf32>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @test_multibroadcast
+func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1x3xf32>, tensor<2x1xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+  // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
+  // CHECK:   [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
+  // CHECK:   linalg.yield [[ELEMENT]] : f32
+  // CHECK: } -> tensor<2x3xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
 }
 
 // -----