From 001d601ac4fb1ee02d4bb3990f2f5a8afacd4932 Mon Sep 17 00:00:00 2001 From: Javier Setoain Date: Wed, 5 May 2021 09:38:50 +0200 Subject: [PATCH] [mlir][ArmSVE] Add basic arithmetic operations While we figure out how to best add Standard support for scalable vectors, these instructions provide a workaround for basic arithmetic between scalable vectors. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D100837 --- mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td | 55 ++++++++++++++++++++++ .../ArmSVE/Transforms/LegalizeForLLVMExport.cpp | 47 ++++++++++++++---- mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir | 34 +++++++++++++ mlir/test/Dialect/ArmSVE/roundtrip.mlir | 20 ++++++++ mlir/test/Target/LLVMIR/arm-sve.mlir | 24 ++++++++++ 5 files changed, 170 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td index 4e75a56..33c60ba 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -122,6 +122,42 @@ class ArmSVE_IntrBinaryOverloadedOp traits=*/traits, /*int numResults=*/1>; +class ScalableFOp traits = []> : + ArmSVE_Op])> { + let summary = op_description # " for scalable vectors of floats"; + let description = [{ + The `arm_sve.}] # mnemonic # [{` operations takes two scalable vectors and + returns one scalable vector with the result of the }] # op_description # [{. + }]; + let arguments = (ins + ScalableVectorOf<[AnyFloat]>:$src1, + ScalableVectorOf<[AnyFloat]>:$src2 + ); + let results = (outs ScalableVectorOf<[AnyFloat]>:$dst); + let assemblyFormat = + "$src1 `,` $src2 attr-dict `:` type($src1)"; +} + +class ScalableIOp traits = []> : + ArmSVE_Op])> { + let summary = op_description # " for scalable vectors of integers"; + let description = [{ + The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and + returns one scalable vector with the result of the }] # op_description # [{. + }]; + let arguments = (ins + ScalableVectorOf<[I8, I16, I32, I64]>:$src1, + ScalableVectorOf<[I8, I16, I32, I64]>:$src2 + ); + let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst); + let assemblyFormat = + "$src1 `,` $src2 attr-dict `:` type($src1)"; +} + def SdotOp : ArmSVE_Op<"sdot", [NoSideEffect, AllTypesMatch<["src1", "src2"]>, @@ -266,6 +302,25 @@ def VectorScaleOp : ArmSVE_Op<"vector_scale", "attr-dict `:` type($res)"; } + +def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>; + +def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>; + +def ScalableSubIOp : ScalableIOp<"subi", "subtraction">; + +def ScalableSubFOp : ScalableFOp<"subf", "subtraction">; + +def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>; + +def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>; + +def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">; + +def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">; + +def ScalableDivFOp : ScalableFOp<"divf", "division">; + def UmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"ummla">, Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index b0197cb..b258f2a 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -84,6 +84,38 @@ using UmmlaOpLowering = OneToOneConvertToLLVMPattern; using VectorScaleOpLowering = OneToOneConvertToLLVMPattern; +static void +populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns) { + // clang-format off + patterns.add, + OneToOneConvertToLLVMPattern, + OneToOneConvertToLLVMPattern, + OneToOneConvertToLLVMPattern, + OneToOneConvertToLLVMPattern, + OneToOneConvertToLLVMPattern, + OneToOneConvertToLLVMPattern, + OneToOneConvertToLLVMPattern, + OneToOneConvertToLLVMPattern + >(converter); + // clang-format on +} + +static void +configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) { + // clang-format off + target.addIllegalOp(); + // clang-format on +} + /// Populate the given list with patterns that convert from ArmSVE to LLVM. void mlir::populateArmSVELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { @@ -106,20 +138,14 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns( UmmlaOpLowering, VectorScaleOpLowering>(converter); // clang-format on + populateBasicSVEArithmeticExportPatterns(converter, patterns); } void mlir::configureArmSVELegalizeForExportTarget( LLVMConversionTarget &target) { - target.addLegalOp(); - target.addIllegalOp(); - target.addLegalOp(); - target.addIllegalOp(); - target.addLegalOp(); - target.addIllegalOp(); - target.addLegalOp(); - target.addIllegalOp(); - target.addLegalOp(); - target.addIllegalOp(); + target.addLegalOp(); + target.addIllegalOp(); auto hasScalableVectorType = [](TypeRange types) { for (Type type : types) if (type.isa()) @@ -135,4 +161,5 @@ void mlir::configureArmSVELegalizeForExportTarget( return !hasScalableVectorType(op->getOperandTypes()) && !hasScalableVectorType(op->getResultTypes()); }); + configureBasicSVEArithmeticLegalizations(target); } diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir index 247b53e..f81196f 100644 --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -40,6 +40,40 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, return %0 : !arm_sve.vector<4xi32> } +func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>, + %b: !arm_sve.vector<4xi32>, + %c: !arm_sve.vector<4xi32>, + %d: !arm_sve.vector<4xi32>, + %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { + // CHECK: llvm.mul {{.*}}: !llvm.vec + %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32> + // CHECK: llvm.add {{.*}}: !llvm.vec + %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32> + // CHECK: llvm.sub {{.*}}: !llvm.vec + %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32> + // CHECK: llvm.sdiv {{.*}}: !llvm.vec + %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32> + // CHECK: llvm.udiv {{.*}}: !llvm.vec + %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32> + return %3 : !arm_sve.vector<4xi32> +} + +func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>, + %b: !arm_sve.vector<4xf32>, + %c: !arm_sve.vector<4xf32>, + %d: !arm_sve.vector<4xf32>, + %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> { + // CHECK: llvm.fmul {{.*}}: !llvm.vec + %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32> + // CHECK: llvm.fadd {{.*}}: !llvm.vec + %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32> + // CHECK: llvm.fsub {{.*}}: !llvm.vec + %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32> + // CHECK: llvm.fdiv {{.*}}: !llvm.vec + %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32> + return %3 : !arm_sve.vector<4xf32> +} + func @get_vector_scale() -> index { // CHECK: arm_sve.vscale %0 = arm_sve.vector_scale : index diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir index 8834ef8..44cc2fa 100644 --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -36,6 +36,26 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, return %0 : !arm_sve.vector<4xi32> } +func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>, + %b: !arm_sve.vector<4xi32>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { + // CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32> + %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32> + // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32> + %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32> + return %1 : !arm_sve.vector<4xi32> +} + +func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>, + %b: !arm_sve.vector<4xf32>, + %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> { + // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32> + %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32> + // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32> + %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32> + return %1 : !arm_sve.vector<4xf32> +} + func @get_vector_scale() -> index { // CHECK: arm_sve.vector_scale : index %0 = arm_sve.vector_scale : index diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir index 46fc884..71d4b0a 100644 --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -48,6 +48,30 @@ llvm.func @arm_sve_ummla(%arg0: !llvm.vec, llvm.return %0 : !llvm.vec } +// CHECK-LABEL: define @arm_sve_arithi +llvm.func @arm_sve_arithi(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec { + // CHECK: mul + %0 = llvm.mul %arg0, %arg1 : !llvm.vec + // CHECK: add + %1 = llvm.add %0, %arg2 : !llvm.vec + llvm.return %1 : !llvm.vec +} + +// CHECK-LABEL: define @arm_sve_arithf +llvm.func @arm_sve_arithf(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec { + // CHECK: fmul + %0 = llvm.fmul %arg0, %arg1 : !llvm.vec + // CHECK: fadd + %1 = llvm.fadd %0, %arg2 : !llvm.vec + llvm.return %1 : !llvm.vec +} + // CHECK-LABEL: define i64 @get_vector_scale() llvm.func @get_vector_scale() -> i64 { // CHECK: call i64 @llvm.vscale.i64() -- 2.7.4