Use MemRefDescriptor in Vector-to-LLVM convresion
authorAlex Zinenko <zinenko@google.com>
Thu, 14 Nov 2019 17:05:11 +0000 (09:05 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 14 Nov 2019 17:05:42 +0000 (09:05 -0800)
Following up on the consolidation of MemRef descriptor conversion, update
Vector-to-LLVM conversion to use the helper class that abstracts away the
implementation details of the MemRef descriptor. This also makes the types of
the attributes in emitted llvm.insert/extractelement operations consistently
i64 instead of a mix of index and i64.

PiperOrigin-RevId: 280441451

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

index f40d2cf..f0bf3a4 100644 (file)
@@ -172,6 +172,9 @@ public:
   /// Builds IR inserting the pos-th stride into the descriptor
   void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
 
+  /// Returns the (LLVM) type this descriptor points to.
+  LLVM::LLVMType getElementType();
+
   /*implicit*/ operator Value *() { return value; }
 
 private:
index 570b6c4..e0edb0b 100644 (file)
@@ -344,6 +344,11 @@ void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
                                               builder.getI64ArrayAttr(pos));
 }
 
+LLVM::LLVMType MemRefDescriptor::getElementType() {
+  return value->getType().cast<LLVM::LLVMType>().getStructElementType(
+      LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+}
+
 namespace {
 // Base class for Standard to LLVM IR op conversions.  Matches the Op type
 // provided as template argument.  Carries a reference to the LLVM dialect in
index 21bcdc9..5bda8b3 100644 (file)
@@ -177,22 +177,17 @@ public:
         !targetMemRefType.hasStaticShape())
       return matchFailure();
 
-    Value *sourceMemRef = operands[0];
     auto llvmSourceDescriptorTy =
-        sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+        operands[0]->getType().dyn_cast<LLVM::LLVMType>();
     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
       return matchFailure();
+    MemRefDescriptor sourceMemRef(operands[0]);
 
     auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
                                       .dyn_cast_or_null<LLVM::LLVMType>();
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return matchFailure();
 
-    Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
-        LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
-    Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
-        LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
-
     int64_t offset;
     SmallVector<int64_t, 4> strides;
     auto successStrides =
@@ -214,55 +209,36 @@ public:
     auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
 
     // Create descriptor.
-    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+    Type llvmTargetElementTy = desc.getElementType();
     // Set allocated ptr.
-    Value *allocated = rewriter.create<LLVM::ExtractValueOp>(
-        loc, llvmSourceElementTy, sourceMemRef,
-        rewriter.getIndexArrayAttr(
-            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
     allocated =
         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        op->getLoc(), llvmTargetDescriptorTy, desc, allocated,
-        rewriter.getIndexArrayAttr(
-            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
-    // Set ptr.
-    Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
-        loc, llvmSourceElementTy, sourceMemRef,
-        rewriter.getIndexArrayAttr(
-            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    desc.setAllocatedPtr(rewriter, loc, allocated);
+    // Set aligned ptr.
+    Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
-        rewriter.getIndexArrayAttr(
-            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    desc.setAlignedPtr(rewriter, loc, ptr);
     // Fill offset 0.
     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        op->getLoc(), llvmTargetDescriptorTy, desc, zero,
-        rewriter.getIndexArrayAttr(
-            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    desc.setOffset(rewriter, loc, zero);
+
     // Fill size and stride descriptors in memref.
     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
       int64_t index = indexedSize.index();
       auto sizeAttr =
           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
-      desc = rewriter.create<LLVM::InsertValueOp>(
-          op->getLoc(), llvmTargetDescriptorTy, desc, size,
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+      desc.setSize(rewriter, loc, index, size);
       auto strideAttr =
           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
-      desc = rewriter.create<LLVM::InsertValueOp>(
-          op->getLoc(), llvmTargetDescriptorTy, desc, stride,
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+      desc.setStride(rewriter, loc, index, stride);
     }
 
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, {desc});
     return matchSuccess();
   }
 };
index ff07f52..6c5e807 100644 (file)
@@ -54,12 +54,11 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
 }
 // CHECK-LABEL: vector_type_cast
 //       CHECK:   llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-//       CHECK:   %[[allocated:.*]] = llvm.extractvalue {{.*}}[0 : index] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 //       CHECK:   %[[allocatedBit:.*]] = llvm.bitcast %[[allocated]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
-//       CHECK:   llvm.insertvalue %[[allocatedBit]], {{.*}}[0 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-//       CHECK:   %[[aligned:.*]] = llvm.extractvalue {{.*}}[1 : index] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.insertvalue %[[allocatedBit]], {{.*}}[0] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
+//       CHECK:   %[[aligned:.*]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 //       CHECK:   %[[alignedBit:.*]] = llvm.bitcast %[[aligned]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
-//       CHECK:   llvm.insertvalue %[[alignedBit]], {{.*}}[1 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
+//       CHECK:   llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
 //       CHECK:   llvm.mlir.constant(0 : index
-//       CHECK:   llvm.insertvalue {{.*}}[2 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
-
+//       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">