[mlir][tosa][fix] Add proper type checking trait for tosa mul
authorTatWai Chong <tatwai.chong@arm.com>
Fri, 21 Jul 2023 23:28:45 +0000 (16:28 -0700)
committerEric Kunze <eric.kunze@arm.com>
Fri, 21 Jul 2023 23:29:05 +0000 (23:29 +0000)
when operating integer type tensors, tosa elementwise multiplication
requires the element type of result to be a 32-bit integer rather
than the same type as inputs.

Change-Id: Ifd3d7ebd879be5c6b2c8e23aa6d7ef41f39c6d41

Reviewed By: mgehre-amd

Differential Revision: https://reviews.llvm.org/D154988

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Tosa/constant-op-fold.mlir
mlir/test/Dialect/Tosa/ops.mlir

index 4447247..555d9be 100644 (file)
@@ -15,7 +15,9 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -35,6 +37,49 @@ namespace tosa {
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
 } // namespace tosa
+
+namespace OpTrait {
+namespace tosa {
+
+// This trait verifies if the element type amoung operands and result
+// of multiplication match tosa specification.
+template <typename ConcreteType>
+class MulOperandsAndResultElementType
+    : public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    auto resElemType = getElementTypeOrSelf(op->getResult(0));
+
+    // In cases of floating point type, op requires the same element
+    // type for all operands and result.
+    if (llvm::isa<FloatType>(resElemType))
+      return impl::verifySameOperandsAndResultElementType(op);
+
+    if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
+      IntegerType lhsIntType =
+          getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
+      IntegerType rhsIntType =
+          getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
+      if (lhsIntType != rhsIntType)
+        return op->emitOpError(
+            "requires the same element type for all operands");
+
+      // Though the spec requires the element type of result to be i32, a more
+      // relaxed way is provided at dialect level for easier cooperating with
+      // other dialects.
+      if (lhsIntType.getWidth() > resIntType.getWidth())
+        return op->emitOpError("invalid data type size for operands or result");
+
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+} // namespace tosa
+} // namespace OpTrait
+
 } // namespace mlir
 
 #define GET_ATTRDEF_CLASSES
index 812db60..3e3c070 100644 (file)
@@ -747,12 +747,17 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
   );
 }
 
+def MulOperandsAndResultElementType :
+  NativeOpTrait<"MulOperandsAndResultElementType"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: mul
 //===----------------------------------------------------------------------===//
 def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
     Commutative,
-    SameOperandsAndResultElementType]> {
+    MulOperandsAndResultElementType]> {
   let summary = "Multiplication operator";
 
   let description = [{
index 1055d6f..29d57f2 100644 (file)
@@ -538,8 +538,10 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
 // CHECK-LABEL: @test_simple_i16
 func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: linalg.generic
+  // CHECK: arith.extsi
+  // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi16>
+  %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
 
   return
 }
index ec4d8bd..e4762de 100644 (file)
@@ -294,13 +294,13 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 // -----
 
 // CHECK-LABEL: @fold_mul_splat_i8
-func.func @fold_mul_splat_i8() -> tensor<10xi8> {
+func.func @fold_mul_splat_i8() -> tensor<10xi32> {
   %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
   %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
-  %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi8>}
+  %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
   // CHECK: return %[[THREE]]
-  return %mul : tensor<10xi8>
+  return %mul : tensor<10xi32>
 }
 
 // -----
index 0ad53bd..f0ff06a 100644 (file)
@@ -230,6 +230,13 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te
 }
 
 // -----
+// CHECK-LABEL: mul
+func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
+  %0 = "tosa.mul"(%arg0, %arg1)  { shift = 1 : i32 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
 // CHECK-LABEL: pow
 func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
   %0 = "tosa.pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>