Standardize all VectorOps class names to be prefixed by Vector - NFC
authorNicolas Vasilache <ntv@google.com>
Mon, 18 Nov 2019 18:38:35 +0000 (10:38 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 18 Nov 2019 18:39:07 +0000 (10:39 -0800)
This improves consistency and will concretely avoid collisions between VectorExtractElementOp and ExtractElementOp when they are included in the same transforms / rewrites.

PiperOrigin-RevId: 281101588

mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Conversion/VectorConversions/VectorToLLVM.cpp
mlir/lib/Dialect/VectorOps/VectorOps.cpp

index 125ecac..9399cc8 100644 (file)
@@ -49,7 +49,7 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def ExtractElementOp :
+def VectorExtractElementOp :
   Vector_Op<"extractelement", [NoSideEffect,
      PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
@@ -66,14 +66,17 @@ def ExtractElementOp :
       %2 = vector.extractelement %0[3, 3, 3]: vector<4x8x16xf32>
     ```
   }];
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState &result, Value *source, ArrayRef<int32_t>">];
   let extraClassDeclaration = [{
+    static StringRef getPositionAttrName() { return "position"; }
     VectorType getVectorType() {
       return vector()->getType().cast<VectorType>();
     }
   }];
 }
 
-def OuterProductOp :
+def VectorOuterProductOp :
   Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
     Results<(outs AnyVector)> {
index 4c9c5fe..cd01666 100644 (file)
@@ -15,9 +15,9 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/Conversion/VectorConversions/VectorConversions.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorConversions/VectorConversions.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/IR/Attributes.h"
@@ -49,19 +49,19 @@ static LLVM::LLVMType getPtrToElementType(T containerType,
       .getPointerTo();
 }
 
-class ExtractElementOpConversion : public LLVMOpLowering {
+class VectorExtractElementOpConversion : public LLVMOpLowering {
 public:
-  explicit ExtractElementOpConversion(MLIRContext *context,
-                                      LLVMTypeConverter &typeConverter)
-      : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
-                       typeConverter) {}
+  explicit VectorExtractElementOpConversion(MLIRContext *context,
+                                            LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::VectorExtractElementOp::getOperationName(),
+                       context, typeConverter) {}
 
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
-    auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
-    auto extractOp = cast<vector::ExtractElementOp>(op);
+    auto adaptor = vector::VectorExtractElementOpOperandAdaptor(operands);
+    auto extractOp = cast<vector::VectorExtractElementOp>(op);
     auto vectorType = extractOp.vector()->getType().cast<VectorType>();
     auto resultType = extractOp.getResult()->getType();
     auto llvmResultType = lowering.convertType(resultType);
@@ -103,25 +103,25 @@ public:
   }
 };
 
-class OuterProductOpConversion : public LLVMOpLowering {
+class VectorOuterProductOpConversion : public LLVMOpLowering {
 public:
-  explicit OuterProductOpConversion(MLIRContext *context,
-                                    LLVMTypeConverter &typeConverter)
-      : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
-                       typeConverter) {}
+  explicit VectorOuterProductOpConversion(MLIRContext *context,
+                                          LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::VectorOuterProductOp::getOperationName(),
+                       context, typeConverter) {}
 
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
-    auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
+    auto adaptor = vector::VectorOuterProductOpOperandAdaptor(operands);
     auto *ctx = op->getContext();
     auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
     auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
     auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
     auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
     auto llvmArrayOfVectType = lowering.convertType(
-        cast<vector::OuterProductOp>(op).getResult()->getType());
+        cast<vector::VectorOuterProductOp>(op).getResult()->getType());
     Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
     Value *a = adaptor.lhs(), *b = adaptor.rhs();
     Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
@@ -246,8 +246,8 @@ public:
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<ExtractElementOpConversion, OuterProductOpConversion,
-                  VectorTypeCastOpConversion>(
+  patterns.insert<VectorExtractElementOpConversion,
+                  VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
       converter.getDialect()->getContext(), converter);
 }
 
index 215e92d..c1244e2 100644 (file)
@@ -44,17 +44,34 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
 }
 
 //===----------------------------------------------------------------------===//
-// ExtractElementOp
+// VectorExtractElementOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, ExtractElementOp op) {
+static Type inferExtractOpResultType(VectorType vectorType,
+                                     ArrayAttr position) {
+  if (static_cast<int64_t>(position.size()) == vectorType.getRank())
+    return vectorType.getElementType();
+  return VectorType::get(vectorType.getShape().drop_front(position.size()),
+                         vectorType.getElementType());
+}
+
+void VectorExtractElementOp::build(Builder *builder, OperationState &result,
+                                   Value *source, ArrayRef<int32_t> position) {
+  result.addOperands(source);
+  auto positionAttr = builder->getI32ArrayAttr(position);
+  result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
+                                           positionAttr));
+  result.addAttribute(getPositionAttrName(), positionAttr);
+}
+
+static void print(OpAsmPrinter &p, VectorExtractElementOp op) {
   p << op.getOperationName() << " " << *op.vector() << op.position();
   p.printOptionalAttrDict(op.getAttrs(), {"position"});
   p << " : " << op.vector()->getType();
 }
 
-static ParseResult parseExtractElementOp(OpAsmParser &parser,
-                                         OperationState &result) {
+static ParseResult parseVectorExtractElementOp(OpAsmParser &parser,
+                                               OperationState &result) {
   llvm::SMLoc attributeLoc, typeLoc;
   SmallVector<NamedAttribute, 4> attrs;
   OpAsmParser::OperandType vector;
@@ -77,19 +94,13 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
         attributeLoc,
         "expected position attribute of rank smaller than vector");
 
-  Type resType =
-      (static_cast<int64_t>(positionAttr.size()) == vectorType.getRank())
-          ? vectorType.getElementType()
-          : VectorType::get(
-                vectorType.getShape().drop_front(positionAttr.size()),
-                vectorType.getElementType());
-
+  Type resType = inferExtractOpResultType(vectorType, positionAttr);
   result.attributes = attrs;
   return failure(parser.resolveOperand(vector, type, result.operands) ||
                  parser.addTypeToList(resType, result.types));
 }
 
-static LogicalResult verify(ExtractElementOp op) {
+static LogicalResult verify(VectorExtractElementOp op) {
   auto positionAttr = op.position().getValue();
   if (positionAttr.empty())
     return op.emitOpError("expected non-empty position attribute");
@@ -107,19 +118,20 @@ static LogicalResult verify(ExtractElementOp op) {
   }
   return success();
 }
+
 //===----------------------------------------------------------------------===//
-// OuterProductOp
+// VectorOuterProductOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, OuterProductOp op) {
+static void print(OpAsmPrinter &p, VectorOuterProductOp op) {
   p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
   if (llvm::size(op.acc()) > 0)
     p << ", " << **op.acc().begin();
   p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
 }
 
-static ParseResult parseOuterProductOp(OpAsmParser &parser,
-                                       OperationState &result) {
+static ParseResult parseVectorOuterProductOp(OpAsmParser &parser,
+                                             OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
   Type tLHS, tRHS;
   if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
@@ -142,7 +154,7 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
       parser.addTypeToList(resType, result.types));
 }
 
-static LogicalResult verify(OuterProductOp op) {
+static LogicalResult verify(VectorOuterProductOp op) {
   VectorType vLHS = op.getOperandVectorTypeLHS(),
              vRHS = op.getOperandVectorTypeRHS(),
              vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType();