#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
} // namespace tosa
+
+namespace OpTrait {
+namespace tosa {
+
+// This trait verifies if the element type amoung operands and result
+// of multiplication match tosa specification.
+template <typename ConcreteType>
+class MulOperandsAndResultElementType
+ : public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ auto resElemType = getElementTypeOrSelf(op->getResult(0));
+
+ // In cases of floating point type, op requires the same element
+ // type for all operands and result.
+ if (llvm::isa<FloatType>(resElemType))
+ return impl::verifySameOperandsAndResultElementType(op);
+
+ if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
+ IntegerType lhsIntType =
+ getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
+ IntegerType rhsIntType =
+ getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
+ if (lhsIntType != rhsIntType)
+ return op->emitOpError(
+ "requires the same element type for all operands");
+
+ // Though the spec requires the element type of result to be i32, a more
+ // relaxed way is provided at dialect level for easier cooperating with
+ // other dialects.
+ if (lhsIntType.getWidth() > resIntType.getWidth())
+ return op->emitOpError("invalid data type size for operands or result");
+
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+} // namespace tosa
+} // namespace OpTrait
+
} // namespace mlir
#define GET_ATTRDEF_CLASSES
);
}
+def MulOperandsAndResultElementType :
+ NativeOpTrait<"MulOperandsAndResultElementType"> {
+ let cppNamespace = "mlir::OpTrait::tosa";
+}
+
//===----------------------------------------------------------------------===//
// Operator: mul
//===----------------------------------------------------------------------===//
def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
Commutative,
- SameOperandsAndResultElementType]> {
+ MulOperandsAndResultElementType]> {
let summary = "Multiplication operator";
let description = [{
// -----
// CHECK-LABEL: @fold_mul_splat_i8
-func.func @fold_mul_splat_i8() -> tensor<10xi8> {
+func.func @fold_mul_splat_i8() -> tensor<10xi32> {
%one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
%two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
- %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8>
- // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi8>}
+ %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
+ // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
// CHECK: return %[[THREE]]
- return %mul : tensor<10xi8>
+ return %mul : tensor<10xi32>
}
// -----
}
// -----
+// CHECK-LABEL: mul
+func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
+ %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
// CHECK-LABEL: pow
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
%0 = "tosa.pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>