From: Nicolas Vasilache Date: Mon, 18 Nov 2019 18:38:35 +0000 (-0800) Subject: Standardize all VectorOps class names to be prefixed by Vector - NFC X-Git-Tag: llvmorg-11-init~1466^2~311 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9732bb533cceeebed80c19be799a79853018fa09;p=platform%2Fupstream%2Fllvm.git Standardize all VectorOps class names to be prefixed by Vector - NFC This improves consistency and will concretely avoid collisions between VectorExtractElementOp and ExtractElementOp when they are included in the same transforms / rewrites. PiperOrigin-RevId: 281101588 --- diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 125ecac5..9399cc8 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -49,7 +49,7 @@ class Vector_Op traits = []> : let parser = [{ return ::parse$cppClass(parser, result); }]; } -def ExtractElementOp : +def VectorExtractElementOp : Vector_Op<"extractelement", [NoSideEffect, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, @@ -66,14 +66,17 @@ def ExtractElementOp : %2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32> ``` }]; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value *source, ArrayRef">]; let extraClassDeclaration = [{ + static StringRef getPositionAttrName() { return "position"; } VectorType getVectorType() { return vector()->getType().cast(); } }]; } -def OuterProductOp : +def VectorOuterProductOp : Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic:$acc)>, Results<(outs AnyVector)> { diff --git a/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp index 4c9c5fe..cd01666 100644 --- a/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp @@ -15,9 +15,9 @@ // limitations under the License. // ============================================================================= -#include "mlir/Conversion/VectorConversions/VectorConversions.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/VectorConversions/VectorConversions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Attributes.h" @@ -49,19 +49,19 @@ static LLVM::LLVMType getPtrToElementType(T containerType, .getPointerTo(); } -class ExtractElementOpConversion : public LLVMOpLowering { +class VectorExtractElementOpConversion : public LLVMOpLowering { public: - explicit ExtractElementOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, - typeConverter) {} + explicit VectorExtractElementOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::VectorExtractElementOp::getOperationName(), + context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); - auto extractOp = cast(op); + auto adaptor = vector::VectorExtractElementOpOperandAdaptor(operands); + auto extractOp = cast(op); auto vectorType = extractOp.vector()->getType().cast(); auto resultType = extractOp.getResult()->getType(); auto llvmResultType = lowering.convertType(resultType); @@ -103,25 +103,25 @@ public: } }; -class OuterProductOpConversion : public LLVMOpLowering { +class VectorOuterProductOpConversion : public LLVMOpLowering { public: - explicit OuterProductOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, - typeConverter) {} + explicit VectorOuterProductOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::VectorOuterProductOp::getOperationName(), + context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto adaptor = vector::OuterProductOpOperandAdaptor(operands); + auto adaptor = vector::VectorOuterProductOpOperandAdaptor(operands); auto *ctx = op->getContext(); auto vLHS = adaptor.lhs()->getType().cast(); auto vRHS = adaptor.rhs()->getType().cast(); auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); auto llvmArrayOfVectType = lowering.convertType( - cast(op).getResult()->getType()); + cast(op).getResult()->getType()); Value *desc = rewriter.create(loc, llvmArrayOfVectType); Value *a = adaptor.lhs(), *b = adaptor.rhs(); Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); @@ -246,8 +246,8 @@ public: /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert( + patterns.insert( converter.getDialect()->getContext(), converter); } diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 215e92d..c1244e2 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -44,17 +44,34 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context) } //===----------------------------------------------------------------------===// -// ExtractElementOp +// VectorExtractElementOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, ExtractElementOp op) { +static Type inferExtractOpResultType(VectorType vectorType, + ArrayAttr position) { + if (static_cast(position.size()) == vectorType.getRank()) + return vectorType.getElementType(); + return VectorType::get(vectorType.getShape().drop_front(position.size()), + vectorType.getElementType()); +} + +void VectorExtractElementOp::build(Builder *builder, OperationState &result, + Value *source, ArrayRef position) { + result.addOperands(source); + auto positionAttr = builder->getI32ArrayAttr(position); + result.addTypes(inferExtractOpResultType(source->getType().cast(), + positionAttr)); + result.addAttribute(getPositionAttrName(), positionAttr); +} + +static void print(OpAsmPrinter &p, VectorExtractElementOp op) { p << op.getOperationName() << " " << *op.vector() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {"position"}); p << " : " << op.vector()->getType(); } -static ParseResult parseExtractElementOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseVectorExtractElementOp(OpAsmParser &parser, + OperationState &result) { llvm::SMLoc attributeLoc, typeLoc; SmallVector attrs; OpAsmParser::OperandType vector; @@ -77,19 +94,13 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser, attributeLoc, "expected position attribute of rank smaller than vector"); - Type resType = - (static_cast(positionAttr.size()) == vectorType.getRank()) - ? vectorType.getElementType() - : VectorType::get( - vectorType.getShape().drop_front(positionAttr.size()), - vectorType.getElementType()); - + Type resType = inferExtractOpResultType(vectorType, positionAttr); result.attributes = attrs; return failure(parser.resolveOperand(vector, type, result.operands) || parser.addTypeToList(resType, result.types)); } -static LogicalResult verify(ExtractElementOp op) { +static LogicalResult verify(VectorExtractElementOp op) { auto positionAttr = op.position().getValue(); if (positionAttr.empty()) return op.emitOpError("expected non-empty position attribute"); @@ -107,19 +118,20 @@ static LogicalResult verify(ExtractElementOp op) { } return success(); } + //===----------------------------------------------------------------------===// -// OuterProductOp +// VectorOuterProductOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, OuterProductOp op) { +static void print(OpAsmPrinter &p, VectorOuterProductOp op) { p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs(); if (llvm::size(op.acc()) > 0) p << ", " << **op.acc().begin(); p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType(); } -static ParseResult parseOuterProductOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseVectorOuterProductOp(OpAsmParser &parser, + OperationState &result) { SmallVector operandsInfo; Type tLHS, tRHS; if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) || @@ -142,7 +154,7 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser, parser.addTypeToList(resType, result.types)); } -static LogicalResult verify(OuterProductOp op) { +static LogicalResult verify(VectorOuterProductOp op) { VectorType vLHS = op.getOperandVectorTypeLHS(), vRHS = op.getOperandVectorTypeRHS(), vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();