From 92dc127ab347cbc4cb4b93db04109eea2a3c13e3 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 9 Aug 2019 05:24:47 -0700 Subject: [PATCH] Add support for vector ops in the LLVM dialect This CL is step 1/n towards building a simple, programmable and portable vector abstraction in MLIR that can go all the way down to generating assembly vector code via LLVM's opt and llc tools. This CL adds the 3 instructions `llvm.extractelement`, `llvm.insertelement` and `llvm.shufflevector` as documented in the LLVM LangRef "Vector Instructions" section. The "Experimental Vector Reduction Intrinsics" are left out for now and can be added in the future on a per-need basis. Appropriate roundtrip and LLVM Target tests are added. PiperOrigin-RevId: 262542095 --- mlir/include/mlir/LLVMIR/LLVMDialect.h | 3 + mlir/include/mlir/LLVMIR/LLVMOps.td | 42 +++++++++ mlir/lib/LLVMIR/IR/LLVMDialect.cpp | 151 +++++++++++++++++++++++++++++++++ mlir/test/LLVMIR/invalid.mlir | 24 ++++++ mlir/test/LLVMIR/roundtrip.mlir | 13 ++- mlir/test/Target/llvmir.mlir | 11 +++ 6 files changed, 243 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index 55479f2..ec7ea8f 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -67,6 +67,9 @@ public: /// Array type utilities. LLVMType getArrayElementType(); + /// Vector type utilities. + LLVMType getVectorElementType(); + /// Function type utilities. LLVMType getFunctionParamType(unsigned argIdx); unsigned getFunctionNumParams(); diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index 32b17d0..b626836 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -282,6 +282,18 @@ def LLVM_CallOp : LLVM_Op<"call">, let parser = [{ return parseCallOp(parser, result); }]; let printer = [{ printCallOp(p, *this); }]; } +def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$vector, + LLVM_Type:$position)> { + string llvmBuilder = [{ + $res = builder.CreateExtractElement($vector, $position); + }]; + let builders = [OpBuilder< + "Builder *b, OperationState *result, Value *vector, Value *position," + "ArrayRef attrs = {}">]; + let parser = [{ return parseExtractElementOp(parser, result); }]; + let printer = [{ printExtractElementOp(p, *this); }]; +} def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, Arguments<(ins LLVM_Type:$container, ArrayAttr:$position)> { @@ -291,6 +303,15 @@ def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, let parser = [{ return parseExtractValueOp(parser, result); }]; let printer = [{ printExtractValueOp(p, *this); }]; } +def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$vector, LLVM_Type:$value, + LLVM_Type:$position)> { + string llvmBuilder = [{ + $res = builder.CreateInsertElement($vector, $value, $position); + }]; + let parser = [{ return parseInsertElementOp(parser, result); }]; + let printer = [{ printInsertElementOp(p, *this); }]; +} def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, Arguments<(ins LLVM_Type:$container, LLVM_Type:$value, ArrayAttr:$position)> { @@ -301,6 +322,27 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, let parser = [{ return parseInsertValueOp(parser, result); }]; let printer = [{ printInsertValueOp(p, *this); }]; } +def LLVM_ShuffleVectorOp + : LLVM_OneResultOp<"shufflevector", [NoSideEffect]>, + Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, I32ArrayAttr:$mask)>, + LLVM_Builder< + "$res = builder.CreateShuffleVector($v1, $v2, extractPosition($mask));"> { + let builders = [OpBuilder< + "Builder *b, OperationState *result, Value *v1, Value *v2, " + "ArrayAttr mask, ArrayRef attrs = {}">]; + let verifier = [{ + auto wrappedVectorType1 = v1()->getType().cast(); + auto wrappedVectorType2 = v2()->getType().cast(); + if (!wrappedVectorType2.getUnderlyingType()->isVectorTy()) + return emitOpError("expected LLVM IR Dialect vector type for operand #2"); + if (wrappedVectorType1.getVectorElementType() != + wrappedVectorType2.getVectorElementType()) + return emitOpError("expected matching LLVM IR Dialect element types"); + return success(); + }]; + let parser = [{ return parseShuffleVectorOp(parser, result); }]; + let printer = [{ printShuffleVectorOp(p, *this); }]; +} // Misc operations. def LLVM_SelectOp diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 30c9eb5..c50a14f 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -410,6 +410,52 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { } //===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::ExtractElementOp. +//===----------------------------------------------------------------------===// +// Expects vector to be of wrapped LLVM vector type and position to be of +// wrapped LLVM i32 type. +void LLVM::ExtractElementOp::build(Builder *b, OperationState *result, + Value *vector, Value *position, + ArrayRef attrs) { + auto wrappedVectorType = vector->getType().cast(); + auto llvmType = wrappedVectorType.getVectorElementType(); + build(b, result, llvmType, vector, position); + result->addAttributes(attrs); +} + +static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp &op) { + *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.position(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.vector()->getType(); +} + +// ::= `llvm.extractelement` ssa-use `, ` ssa-use +// attribute-dict? `:` type +static ParseResult parseExtractElementOp(OpAsmParser *parser, + OperationState *result) { + llvm::SMLoc loc; + OpAsmParser::OperandType vector, position; + auto *llvmDialect = parser->getBuilder() + .getContext() + ->getRegisteredDialect(); + Type type, i32Type = LLVMType::getInt32Ty(llvmDialect); + if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) || + parser->parseComma() || parser->parseOperand(position) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(vector, type, result->operands) || + parser->resolveOperand(position, i32Type, result->operands)) + return failure(); + auto wrappedVectorType = type.dyn_cast(); + if (!wrappedVectorType || + !wrappedVectorType.getUnderlyingType()->isVectorTy()) + return parser->emitError( + loc, "expected LLVM IR dialect vector type for operand #1"); + result->addTypes(wrappedVectorType.getVectorElementType()); + return success(); +} + +//===----------------------------------------------------------------------===// // Printing/parsing for LLVM::ExtractValueOp. //===----------------------------------------------------------------------===// @@ -502,6 +548,52 @@ static ParseResult parseExtractValueOp(OpAsmParser *parser, } //===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::InsertElementOp. +//===----------------------------------------------------------------------===// + +static void printInsertElementOp(OpAsmPrinter *p, InsertElementOp &op) { + *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.value() + << ", " << *op.position(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.vector()->getType(); +} + +// ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use +// attribute-dict? `:` type +static ParseResult parseInsertElementOp(OpAsmParser *parser, + OperationState *result) { + llvm::SMLoc loc; + OpAsmParser::OperandType vector, value, position; + auto *llvmDialect = parser->getBuilder() + .getContext() + ->getRegisteredDialect(); + Type vectorType, i32Type = LLVMType::getInt32Ty(llvmDialect); + if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) || + parser->parseComma() || parser->parseOperand(value) || + parser->parseComma() || parser->parseOperand(position) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(vectorType)) + return failure(); + + auto wrappedVectorType = vectorType.dyn_cast(); + if (!wrappedVectorType || + !wrappedVectorType.getUnderlyingType()->isVectorTy()) + return parser->emitError( + loc, "expected LLVM IR dialect vector type for operand #1"); + auto valueType = wrappedVectorType.getVectorElementType(); + if (!valueType) + return failure(); + + if (parser->resolveOperand(vector, vectorType, result->operands) || + parser->resolveOperand(value, valueType, result->operands) || + parser->resolveOperand(position, i32Type, result->operands)) + return failure(); + + result->addTypes(vectorType); + return success(); +} + +//===----------------------------------------------------------------------===// // Printing/parsing for LLVM::InsertValueOp. //===----------------------------------------------------------------------===// @@ -784,6 +876,60 @@ static LogicalResult verify(GlobalOp op) { } //===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::ShuffleVectorOp. +//===----------------------------------------------------------------------===// +// Expects vector to be of wrapped LLVM vector type and position to be of +// wrapped LLVM i32 type. +void LLVM::ShuffleVectorOp::build(Builder *b, OperationState *result, Value *v1, + Value *v2, ArrayAttr mask, + ArrayRef attrs) { + auto wrappedContainerType1 = v1->getType().cast(); + auto vType = LLVMType::getVectorTy( + wrappedContainerType1.getVectorElementType(), mask.size()); + build(b, result, vType, v1, v2, mask); + result->addAttributes(attrs); +} + +static void printShuffleVectorOp(OpAsmPrinter *p, ShuffleVectorOp &op) { + *p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " " + << op.mask(); + p->printOptionalAttrDict(op.getAttrs(), {"mask"}); + *p << " : " << op.v1()->getType() << ", " << op.v2()->getType(); +} + +// ::= `llvm.shufflevector` ssa-use `, ` ssa-use +// `[` integer-literal (`,` integer-literal)* `]` +// attribute-dict? `:` type +static ParseResult parseShuffleVectorOp(OpAsmParser *parser, + OperationState *result) { + llvm::SMLoc loc; + SmallVector attrs; + OpAsmParser::OperandType v1, v2; + Attribute maskAttr; + Type typeV1, typeV2; + if (parser->getCurrentLocation(&loc) || parser->parseOperand(v1) || + parser->parseComma() || parser->parseOperand(v2) || + parser->parseAttribute(maskAttr, "mask", attrs) || + parser->parseOptionalAttributeDict(attrs) || + parser->parseColonType(typeV1) || parser->parseComma() || + parser->parseType(typeV2) || + parser->resolveOperand(v1, typeV1, result->operands) || + parser->resolveOperand(v2, typeV2, result->operands)) + return failure(); + auto wrappedContainerType1 = typeV1.dyn_cast(); + if (!wrappedContainerType1 || + !wrappedContainerType1.getUnderlyingType()->isVectorTy()) + return parser->emitError( + loc, "expected LLVM IR dialect vector type for operand #1"); + auto vType = + LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(), + maskAttr.cast().size()); + result->attributes = attrs; + result->addTypes(vType); + return success(); +} + +//===----------------------------------------------------------------------===// // Builder, printer and verifier for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// @@ -1055,6 +1201,11 @@ LLVMType LLVMType::getArrayElementType() { return get(getContext(), getUnderlyingType()->getArrayElementType()); } +/// Vector type utilities. +LLVMType LLVMType::getVectorElementType() { + return get(getContext(), getUnderlyingType()->getVectorElementType()); +} + /// Function type utilities. LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx)); diff --git a/mlir/test/LLVMIR/invalid.mlir b/mlir/test/LLVMIR/invalid.mlir index 801c3e2..111d90a 100644 --- a/mlir/test/LLVMIR/invalid.mlir +++ b/mlir/test/LLVMIR/invalid.mlir @@ -224,3 +224,27 @@ func @extractvalue_wrong_nesting() { // expected-error@+1 {{expected wrapped LLVM IR structure/array type}} llvm.extractvalue %b[0,0] : !llvm<"{i32}"> } + +// ----- + +// CHECK-LABEL: @invalid_vector_type_1 +func @invalid_vector_type_1(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) { + // expected-error@+1 {{expected LLVM IR dialect vector type for operand #1}} + %0 = llvm.extractelement %arg2, %arg1 : !llvm.float +} + +// ----- + +// CHECK-LABEL: @invalid_vector_type_2 +func @invalid_vector_type_2(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) { + // expected-error@+1 {{expected LLVM IR dialect vector type for operand #1}} + %0 = llvm.insertelement %arg2, %arg2, %arg1 : !llvm.float +} + +// ----- + +// CHECK-LABEL: @invalid_vector_type_3 +func @invalid_vector_type_3(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) { + // expected-error@+1 {{expected LLVM IR dialect vector type for operand #1}} + %0 = llvm.shufflevector %arg2, %arg2 [0 : i32, 0 : i32, 0 : i32, 0 : i32, 7 : i32] : !llvm.float, !llvm.float +} diff --git a/mlir/test/LLVMIR/roundtrip.mlir b/mlir/test/LLVMIR/roundtrip.mlir index be89407..4348b7b 100644 --- a/mlir/test/LLVMIR/roundtrip.mlir +++ b/mlir/test/LLVMIR/roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | FileCheck %s +// RUN: mlir-opt %s | mlir-opt | FileCheck %s // CHECK-LABEL: func @ops(%arg0: !llvm.i32, %arg1: !llvm.float) func @ops(%arg0 : !llvm.i32, %arg1 : !llvm.float) { @@ -167,3 +167,14 @@ func @casts(%arg0: !llvm.i32, %arg1: !llvm.i64, %arg2: !llvm<"<4 x i32>">, %5 = llvm.trunc %arg3 : !llvm<"<4 x i64>"> to !llvm<"<4 x i56>"> llvm.return } + +// CHECK-LABEL: @vect +func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) { +// CHECK-NEXT: = llvm.extractelement {{.*}} : !llvm<"<4 x float>"> + %0 = llvm.extractelement %arg0, %arg1 : !llvm<"<4 x float>"> +// CHECK-NEXT: = llvm.insertelement {{.*}} : !llvm<"<4 x float>"> + %1 = llvm.insertelement %arg0, %arg2, %arg1 : !llvm<"<4 x float>"> +// CHECK-NEXT: = llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 7 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> + %2 = llvm.shufflevector %arg0, %arg0 [0 : i32, 0 : i32, 0 : i32, 0 : i32, 7 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> + return +} diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir index 9b6e0a8..ab714cd 100644 --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -862,3 +862,14 @@ func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) { %13 = llvm.fcmp "uno" %arg0, %arg1 : !llvm.float llvm.return } + +// CHECK-LABEL: @vect +func @vect(%arg0: !llvm<"<4 x float>">, %arg1: !llvm.i32, %arg2: !llvm.float) { + // CHECK-NEXT: extractelement <4 x float> {{.*}}, i32 {{.*}} + // CHECK-NEXT: insertelement <4 x float> {{.*}}, float %2, i32 {{.*}} + // CHECK-NEXT: shufflevector <4 x float> {{.*}}, <4 x float> {{.*}}, <5 x i32> + %0 = llvm.extractelement %arg0, %arg1 : !llvm<"<4 x float>"> + %1 = llvm.insertelement %arg0, %arg2, %arg1 : !llvm<"<4 x float>"> + %2 = llvm.shufflevector %arg0, %arg0 [0 : i32, 0 : i32, 0 : i32, 0 : i32, 7 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> + llvm.return +} -- 2.7.4