From a5f0b237be7f2d50efdccd8d6a95edd05c8fd52f Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Fri, 21 Jul 2023 16:28:45 -0700 Subject: [PATCH] [mlir][tosa][fix] Add proper type checking trait for tosa mul when operating integer type tensors, tosa elementwise multiplication requires the element type of result to be a 32-bit integer rather than the same type as inputs. Change-Id: Ifd3d7ebd879be5c6b2c8e23aa6d7ef41f39c6d41 Reviewed By: mgehre-amd Differential Revision: https://reviews.llvm.org/D154988 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 45 ++++++++++++++++++++++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 7 +++- .../Conversion/TosaToLinalg/tosa-to-linalg.mlir | 4 +- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 8 ++-- mlir/test/Dialect/Tosa/ops.mlir | 7 ++++ 5 files changed, 65 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index 4447247..555d9be 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -15,7 +15,9 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Traits.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -35,6 +37,49 @@ namespace tosa { #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" } // namespace tosa + +namespace OpTrait { +namespace tosa { + +// This trait verifies if the element type amoung operands and result +// of multiplication match tosa specification. +template +class MulOperandsAndResultElementType + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + auto resElemType = getElementTypeOrSelf(op->getResult(0)); + + // In cases of floating point type, op requires the same element + // type for all operands and result. + if (llvm::isa(resElemType)) + return impl::verifySameOperandsAndResultElementType(op); + + if (auto resIntType = resElemType.dyn_cast()) { + IntegerType lhsIntType = + getElementTypeOrSelf(op->getOperand(0)).cast(); + IntegerType rhsIntType = + getElementTypeOrSelf(op->getOperand(1)).cast(); + if (lhsIntType != rhsIntType) + return op->emitOpError( + "requires the same element type for all operands"); + + // Though the spec requires the element type of result to be i32, a more + // relaxed way is provided at dialect level for easier cooperating with + // other dialects. + if (lhsIntType.getWidth() > resIntType.getWidth()) + return op->emitOpError("invalid data type size for operands or result"); + + return success(); + } + + return failure(); + } +}; + +} // namespace tosa +} // namespace OpTrait + } // namespace mlir #define GET_ATTRDEF_CLASSES diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 812db60..3e3c070 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -747,12 +747,17 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [ ); } +def MulOperandsAndResultElementType : + NativeOpTrait<"MulOperandsAndResultElementType"> { + let cppNamespace = "mlir::OpTrait::tosa"; +} + //===----------------------------------------------------------------------===// // Operator: mul //===----------------------------------------------------------------------===// def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [ Commutative, - SameOperandsAndResultElementType]> { + MulOperandsAndResultElementType]> { let summary = "Multiplication operator"; let description = [{ diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 1055d6f..29d57f2 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -538,8 +538,10 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () { // CHECK-LABEL: @test_simple_i16 func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () { // CHECK: linalg.generic + // CHECK: arith.extsi + // CHECK: arith.extsi // CHECK: arith.muli - %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi16> + %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32> return } diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index ec4d8bd..e4762de 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -294,13 +294,13 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor) -> tensor { // ----- // CHECK-LABEL: @fold_mul_splat_i8 -func.func @fold_mul_splat_i8() -> tensor<10xi8> { +func.func @fold_mul_splat_i8() -> tensor<10xi32> { %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8> %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8> - %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi8>} + %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>} // CHECK: return %[[THREE]] - return %mul : tensor<10xi8> + return %mul : tensor<10xi32> } // ----- diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 0ad53bd..f0ff06a 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -230,6 +230,13 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te } // ----- +// CHECK-LABEL: mul +func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> { + %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16> + return %0 : tensor<13x21x3xi16> +} + +// ----- // CHECK-LABEL: pow func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> { %0 = "tosa.pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> -- 2.7.4