Concentrate memref descriptor manipulation logic in one place
authorAlex Zinenko <zinenko@google.com>
Thu, 14 Nov 2019 08:48:41 +0000 (00:48 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 14 Nov 2019 08:49:12 +0000 (00:49 -0800)
Memref descriptor is becoming increasingly complex. Memrefs are manipulated by
multiple standard instructions, each of which has a non-trivial lowering to the
LLVM dialect. This leads to verbose code that manipulates the descriptors
exposing the internals of insert/extractelement opreations. Implement a wrapper
class that contains a memref descriptor and provides semantically named methods
that build the primitive IR operations instead.

PiperOrigin-RevId: 280371225

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

index 791a237..0641a6b 100644 (file)
@@ -235,6 +235,125 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
     : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
 
 namespace {
+/// Helper class to produce LLVM dialect operations extracting or inserting
+/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
+/// The Value may be null, in which case none of the operations are valid.
+class MemRefDescriptor {
+public:
+  /// Construct a helper for the given descriptor value.
+  explicit MemRefDescriptor(Value *descriptor) : value(descriptor) {
+    if (value) {
+      structType = value->getType().cast<LLVM::LLVMType>();
+      indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
+          LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
+    }
+  }
+
+  /// Builds IR creating an `undef` value of the descriptor type.
+  static MemRefDescriptor undef(OpBuilder &builder, Location loc,
+                                Type descriptorType) {
+    Value *descriptor = builder.create<LLVM::UndefOp>(
+        loc, descriptorType.cast<LLVM::LLVMType>());
+    return MemRefDescriptor(descriptor);
+  }
+
+  /// Builds IR extracting the allocated pointer from the descriptor.
+  Value *allocatedPtr(OpBuilder &builder, Location loc) {
+    return extractPtr(builder, loc,
+                      LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
+  }
+
+  /// Builds IR inserting the allocated pointer into the descriptor.
+  void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) {
+    setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
+           ptr);
+  }
+
+  /// Builds IR extracting the aligned pointer from the descriptor.
+  Value *alignedPtr(OpBuilder &builder, Location loc) {
+    return extractPtr(builder, loc,
+                      LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+  }
+
+  /// Builds IR inserting the aligned pointer into the descriptor.
+  void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) {
+    setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
+           ptr);
+  }
+
+  /// Builds IR extracting the offset from the descriptor.
+  Value *offset(OpBuilder &builder, Location loc) {
+    return builder.create<LLVM::ExtractValueOp>(
+        loc, indexType, value,
+        builder.getI64ArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+  }
+
+  /// Builds IR inserting the offset into the descriptor.
+  void setOffset(OpBuilder &builder, Location loc, Value *offset) {
+    value = builder.create<LLVM::InsertValueOp>(
+        loc, structType, value, offset,
+        builder.getI64ArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+  }
+
+  /// Builds IR extracting the pos-th size from the descriptor.
+  Value *size(OpBuilder &builder, Location loc, unsigned pos) {
+    return builder.create<LLVM::ExtractValueOp>(
+        loc, indexType, value,
+        builder.getI64ArrayAttr(
+            {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
+  }
+
+  /// Builds IR inserting the pos-th size into the descriptor
+  void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) {
+    value = builder.create<LLVM::InsertValueOp>(
+        loc, structType, value, size,
+        builder.getI64ArrayAttr(
+            {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
+  }
+
+  /// Builds IR extracting the pos-th size from the descriptor.
+  Value *stride(OpBuilder &builder, Location loc, unsigned pos) {
+    return builder.create<LLVM::ExtractValueOp>(
+        loc, indexType, value,
+        builder.getI64ArrayAttr(
+            {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
+  }
+
+  /// Builds IR inserting the pos-th stride into the descriptor
+  void setStride(OpBuilder &builder, Location loc, unsigned pos,
+                 Value *stride) {
+    value = builder.create<LLVM::InsertValueOp>(
+        loc, structType, value, stride,
+        builder.getI64ArrayAttr(
+            {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
+  }
+
+  /*implicit*/ operator Value *() { return value; }
+
+private:
+  Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) {
+    Type type = structType.getStructElementType(pos);
+    return builder.create<LLVM::ExtractValueOp>(loc, type, value,
+                                                builder.getI64ArrayAttr(pos));
+  }
+
+  void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) {
+    value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
+                                                builder.getI64ArrayAttr(pos));
+  }
+
+  // Cached descriptor type.
+  LLVM::LLVMType structType;
+
+  // Cached index type.
+  LLVM::LLVMType indexType;
+
+  // Actual descriptor.
+  Value *value;
+};
+
 // Base class for Standard to LLVM IR op conversions.  Matches the Op type
 // provided as template argument.  Carries a reference to the LLVM dialect in
 // case it is necessary for rewriters.
@@ -278,29 +397,6 @@ public:
     return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
   }
 
-  // Extract allocated data pointer value from a value representing a memref.
-  static Value *
-  extractAllocatedMemRefElementPtr(ConversionPatternRewriter &builder,
-                                   Location loc, Value *memref,
-                                   Type elementTypePtr) {
-    return builder.create<LLVM::ExtractValueOp>(
-        loc, elementTypePtr, memref,
-        builder.getI64ArrayAttr(
-            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
-  }
-
-  // Extract properly aligned data pointer value from a value representing a
-  // memref.
-  static Value *
-  extractAlignedMemRefElementPtr(ConversionPatternRewriter &builder,
-                                 Location loc, Value *memref,
-                                 Type elementTypePtr) {
-    return builder.create<LLVM::ExtractValueOp>(
-        loc, elementTypePtr, memref,
-        builder.getI64ArrayAttr(
-            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
-  }
-
 protected:
   LLVM::LLVMDialect &dialect;
 };
@@ -786,14 +882,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 
     // Create the MemRef descriptor.
     auto structType = lowering.convertType(type);
-    Value *memRefDescriptor =
-        rewriter.create<LLVM::UndefOp>(loc, structType, ArrayRef<Value *>{});
-
+    auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
     // Field 1: Allocated pointer, used for malloc/free.
-    memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
-        loc, structType, memRefDescriptor, bitcastAllocated,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
+
     // Field 2: Actual aligned pointer to payload.
     Value *bitcastAligned = bitcastAllocated;
     if (align) {
@@ -808,20 +900,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
       bitcastAligned = rewriter.create<LLVM::BitcastOp>(
           loc, elementPtrType, ArrayRef<Value *>(aligned));
     }
-    memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
-        loc, structType, memRefDescriptor, bitcastAligned,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned);
+
     // Field 3: Offset in aligned pointer.
-    memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
-        loc, structType, memRefDescriptor,
-        createIndexConstant(rewriter, loc, offset),
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    memRefDescriptor.setOffset(rewriter, loc,
+                               createIndexConstant(rewriter, loc, offset));
 
     if (type.getRank() == 0)
       // No size/stride descriptor in memref, return the descriptor value.
-      return rewriter.replaceOp(op, memRefDescriptor);
+      return rewriter.replaceOp(op, {memRefDescriptor});
 
     // Fields 4 and 5: Sizes and strides of the strided MemRef.
     // Store all sizes in the descriptor. Only dynamic sizes are passed in as
@@ -846,18 +933,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
     // Fill size and stride descriptors in memref.
     for (auto indexedSize : llvm::enumerate(sizes)) {
       int64_t index = indexedSize.index();
-      memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
-          loc, structType, memRefDescriptor, indexedSize.value(),
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
-      memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
-          loc, structType, memRefDescriptor, strideValues[index],
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+      memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
+      memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
     }
 
     // Return the final value of the descriptor.
-    rewriter.replaceOp(op, memRefDescriptor);
+    rewriter.replaceOp(op, {memRefDescriptor});
   }
 };
 
@@ -947,13 +1028,10 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
                                         /*isVarArg=*/false));
     }
 
-    auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
-    Type elementPtrType = type.getStructElementType(
-        LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
-    Value *bufferPtr = extractAllocatedMemRefElementPtr(
-        rewriter, op->getLoc(), transformed.memref(), elementPtrType);
+    MemRefDescriptor memref(transformed.memref());
     Value *casted = rewriter.create<LLVM::BitcastOp>(
-        op->getLoc(), getVoidPtrType(), bufferPtr);
+        op->getLoc(), getVoidPtrType(),
+        memref.allocatedPtr(rewriter, op->getLoc()));
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
     return matchSuccess();
@@ -1003,10 +1081,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
     int64_t index = dimOp.getIndex();
     // Extract dynamic size from the memref descriptor.
     if (ShapedType::isDynamic(shape[index]))
-      rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
-          op, getIndexType(), transformed.memrefOrTensor(),
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+      rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
+                                  .size(rewriter, op->getLoc(), index)});
     else
       // Use constant for static size.
       rewriter.replaceOp(
@@ -1058,34 +1134,21 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   // This is a strided getElementPtr variant that linearizes subscripts as:
   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
   Value *getStridedElementPtr(Location loc, Type elementTypePtr,
-                              Value *memRefDescriptor,
-                              ArrayRef<Value *> indices,
+                              Value *descriptor, ArrayRef<Value *> indices,
                               ArrayRef<int64_t> strides, int64_t offset,
                               ConversionPatternRewriter &rewriter) const {
-    auto indexTy = this->getIndexType();
-    Value *base = this->extractAlignedMemRefElementPtr(
-        rewriter, loc, memRefDescriptor, elementTypePtr);
-    Value *offsetValue =
-        offset == MemRefType::getDynamicStrideOrOffset()
-            ? rewriter.create<LLVM::ExtractValueOp>(
-                  loc, indexTy, memRefDescriptor,
-                  rewriter.getI64ArrayAttr(
-                      LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
-            : this->createIndexConstant(rewriter, loc, offset);
+    MemRefDescriptor memRefDescriptor(descriptor);
+
+    Value *base = memRefDescriptor.alignedPtr(rewriter, loc);
+    Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
+                             ? memRefDescriptor.offset(rewriter, loc)
+                             : this->createIndexConstant(rewriter, loc, offset);
+
     for (int i = 0, e = indices.size(); i < e; ++i) {
-      Value *stride;
-      if (strides[i] != MemRefType::getDynamicStrideOrOffset()) {
-        // Use static stride.
-        auto attr =
-            rewriter.getIntegerAttr(rewriter.getIndexType(), strides[i]);
-        stride = rewriter.create<LLVM::ConstantOp>(loc, indexTy, attr);
-      } else {
-        // Use dynamic stride.
-        stride = rewriter.create<LLVM::ExtractValueOp>(
-            loc, indexTy, memRefDescriptor,
-            rewriter.getI64ArrayAttr(
-                {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
-      }
+      Value *stride =
+          strides[i] == MemRefType::getDynamicStrideOrOffset()
+              ? memRefDescriptor.stride(rewriter, loc, i)
+              : this->createIndexConstant(rewriter, loc, strides[i]);
       Value *additionalOffset =
           rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
       offsetValue =
@@ -1452,74 +1515,45 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
       return matchFailure();
 
     // Create the descriptor.
-    Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+    MemRefDescriptor sourceMemRef(adaptor.source());
+    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
 
     // Copy the buffer pointer from the old descriptor to the new one.
-    Value *sourceDescriptor = adaptor.source();
-    Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
-        loc, sourceElementTy.getPointerTo(), sourceDescriptor,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
     Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc, targetElementTy.getPointerTo(), extracted);
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        loc, desc, bitcastPtr,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
-    extracted = rewriter.create<LLVM::ExtractValueOp>(
-        loc, sourceElementTy.getPointerTo(), sourceDescriptor,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
+
+    extracted = sourceMemRef.alignedPtr(rewriter, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc, targetElementTy.getPointerTo(), extracted);
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        loc, desc, bitcastPtr,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
     // Extract strides needed to compute offset.
     SmallVector<Value *, 4> strideValues;
     strideValues.reserve(viewMemRefType.getRank());
-    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
-      strideValues.push_back(rewriter.create<LLVM::ExtractValueOp>(
-          loc, getIndexType(), sourceDescriptor,
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})));
-    }
+    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
+      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
 
     // Offset.
-    Value *baseOffset = rewriter.create<LLVM::ExtractValueOp>(
-        loc, getIndexType(), sourceDescriptor,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    Value *baseOffset = sourceMemRef.offset(rewriter, loc);
     for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
       Value *min = adaptor.offsets()[i];
       baseOffset = rewriter.create<LLVM::AddOp>(
           loc, baseOffset,
           rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
     }
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        loc, desc, baseOffset,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    targetMemRef.setOffset(rewriter, loc, baseOffset);
 
     // Update sizes and strides.
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
-      // Update size.
-      desc = rewriter.create<LLVM::InsertValueOp>(
-          loc, desc, adaptor.sizes()[i],
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
-      // Update stride.
-      desc = rewriter.create<LLVM::InsertValueOp>(
-          loc, desc,
-          rewriter.create<LLVM::MulOp>(loc, adaptor.strides()[i],
-                                       strideValues[i]),
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+      targetMemRef.setSize(rewriter, loc, i, adaptor.sizes()[i]);
+      targetMemRef.setStride(rewriter, loc, i,
+                             rewriter.create<LLVM::MulOp>(
+                                 loc, adaptor.strides()[i], strideValues[i]));
     }
 
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, {targetMemRef});
     return matchSuccess();
   }
 };
@@ -1571,10 +1605,6 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
     auto loc = op->getLoc();
     auto viewOp = cast<ViewOp>(op);
     ViewOpOperandAdaptor adaptor(operands);
-    auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
-    auto sourceElementTy =
-        lowering.convertType(sourceMemRefType.getElementType())
-            .dyn_cast<LLVM::LLVMType>();
 
     auto viewMemRefType = viewOp.getType();
     auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
@@ -1593,32 +1623,20 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
              matchFailure();
 
     // Create the descriptor.
-    Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+    MemRefDescriptor sourceMemRef(adaptor.source());
+    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
 
     // Field 1: Copy the allocated pointer, used for malloc/free.
-    Value *sourceDescriptor = adaptor.source();
-    Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
-        loc, sourceElementTy.getPointerTo(), sourceDescriptor,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
     Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc, targetElementTy.getPointerTo(), extracted);
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        loc, desc, bitcastPtr,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
     // Field 2: Copy the actual aligned pointer to payload.
-    extracted = rewriter.create<LLVM::ExtractValueOp>(
-        loc, sourceElementTy.getPointerTo(), sourceDescriptor,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    extracted = sourceMemRef.alignedPtr(rewriter, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc, targetElementTy.getPointerTo(), extracted);
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        loc, desc, bitcastPtr,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
     // Field 3: Copy the offset in aligned pointer.
     unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
@@ -1630,14 +1648,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
                             ? createIndexConstant(rewriter, loc, offset)
                             // TODO(ntv): better adaptor.
                             : sizeAndOffsetOperands.back();
-    desc = rewriter.create<LLVM::InsertValueOp>(
-        loc, desc, baseOffset,
-        rewriter.getI64ArrayAttr(
-            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    targetMemRef.setOffset(rewriter, loc, baseOffset);
 
     // Early exit for 0-D corner case.
     if (viewMemRefType.getRank() == 0)
-      return rewriter.replaceOp(op, desc), matchSuccess();
+      return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
 
     // Fields 4 and 5: Update sizes and strides.
     if (strides.back() != 1)
@@ -1648,20 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
       // Update size.
       Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
                             sizeAndOffsetOperands, i);
-      desc = rewriter.create<LLVM::InsertValueOp>(
-          loc, desc, size,
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
+      targetMemRef.setSize(rewriter, loc, i, size);
       // Update stride.
       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
-      desc = rewriter.create<LLVM::InsertValueOp>(
-          loc, desc, stride,
-          rewriter.getI64ArrayAttr(
-              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+      targetMemRef.setStride(rewriter, loc, i, stride);
       nextSize = size;
     }
 
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, {targetMemRef});
     return matchSuccess();
   }
 };