From 764ad3b3fafbf57ca916715625fffb7df5dbeb92 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 11 May 2021 13:40:03 -0700 Subject: [PATCH] [mlir][tosa] Tosa elementwise broadcasting had some minor bugs Updated tests to include broadcast of left and right. Includes bypass if in-type and out-type match shape (no broadcasting). Differential Revision: https://reviews.llvm.org/D102276 --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 18 +++++++++++++----- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 4bf2dc7..3718a56b 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -522,8 +522,12 @@ static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, PatternRewriter &rewriter) { auto loc = operation->getLoc(); + + assert(operation->getNumResults() == 1 && + "All TOSA elementwise ops should only return a single result."); + auto results = operation->getResults(); - auto resultTy = operation->getOperand(0).getType().dyn_cast(); + auto resultTy = operation->getResult(0).getType().dyn_cast(); if (!resultTy) return rewriter.notifyMatchFailure(operation, @@ -531,9 +535,6 @@ elementwiseMatchAndRewriteHelper(Operation *operation, unsigned rank = resultTy.getRank(); - assert(operation->getNumResults() == 1 && - "All TOSA elementwise ops should only return a single result."); - // Construct the indexing maps needed for linalg.generic ops. SmallVector bodyArgTypes; @@ -565,11 +566,18 @@ elementwiseMatchAndRewriteHelper(Operation *operation, // Input indexing maps may be broadcasted. for (Value operand : operation->getOperands()) { ShapedType type = operand.getType().cast(); + + if (type.getShape() == resultTy.getShape()) { + operands.push_back(operand); + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); + continue; + } + SmallVector newShape; SmallVector affineExprs; newShape.reserve(type.getRank()); for (auto it : llvm::enumerate(type.getShape())) { - if (it.value() != 1) { + if (it.value() == resultTy.getDimSize(it.index())) { newShape.push_back(it.value()); affineExprs.push_back( mlir::getAffineDimExpr(it.index(), rewriter.getContext())); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index dbd4f90..4916c70 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -73,6 +73,24 @@ func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32 // ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> ()> + +// CHECK-LABEL: @test_broadcast_swapped_args +func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32> + // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg1 + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor) outs([[INIT]] : tensor<2xf32>) { + // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): + // CHECK: [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32 + // CHECK: linalg.yield [[ELEMENT]] : f32 + // CHECK: } -> tensor<2xf32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)> -- 2.7.4