Use MemRefDescriptor in Linalg-to-LLVM conversion
authorAlex Zinenko <zinenko@google.com>
Thu, 14 Nov 2019 16:03:39 +0000 (08:03 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 14 Nov 2019 16:04:10 +0000 (08:04 -0800)
Following up on the consolidation of MemRef descriptor conversion, update
Linalg-to-LLVM conversion to use the helper class that abstracts away the
implementation details of the MemRef descriptor. This required MemRefDescriptor
to become publicly visible. Since this conversion is heavily EDSC-based,
introduce locally an additional wrapper that uses builder and location pointed
to by the EDSC context while emitting descriptor manipulation operations.

PiperOrigin-RevId: 280429228

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp

index 0fb973b..f40d2cf 100644 (file)
@@ -133,6 +133,61 @@ private:
   LLVM::LLVMType unwrap(Type type);
 };
 
+/// Helper class to produce LLVM dialect operations extracting or inserting
+/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
+/// The Value may be null, in which case none of the operations are valid.
+class MemRefDescriptor {
+public:
+  /// Construct a helper for the given descriptor value.
+  explicit MemRefDescriptor(Value *descriptor);
+  /// Builds IR creating an `undef` value of the descriptor type.
+  static MemRefDescriptor undef(OpBuilder &builder, Location loc,
+                                Type descriptorType);
+  /// Builds IR extracting the allocated pointer from the descriptor.
+  Value *allocatedPtr(OpBuilder &builder, Location loc);
+  /// Builds IR inserting the allocated pointer into the descriptor.
+  void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr);
+
+  /// Builds IR extracting the aligned pointer from the descriptor.
+  Value *alignedPtr(OpBuilder &builder, Location loc);
+
+  /// Builds IR inserting the aligned pointer into the descriptor.
+  void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr);
+
+  /// Builds IR extracting the offset from the descriptor.
+  Value *offset(OpBuilder &builder, Location loc);
+
+  /// Builds IR inserting the offset into the descriptor.
+  void setOffset(OpBuilder &builder, Location loc, Value *offset);
+
+  /// Builds IR extracting the pos-th size from the descriptor.
+  Value *size(OpBuilder &builder, Location loc, unsigned pos);
+
+  /// Builds IR inserting the pos-th size into the descriptor
+  void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size);
+
+  /// Builds IR extracting the pos-th size from the descriptor.
+  Value *stride(OpBuilder &builder, Location loc, unsigned pos);
+
+  /// Builds IR inserting the pos-th stride into the descriptor
+  void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
+
+  /*implicit*/ operator Value *() { return value; }
+
+private:
+  Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
+  void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
+
+  // Cached descriptor type.
+  Type structType;
+
+  // Cached index type.
+  Type indexType;
+
+  // Actual descriptor.
+  Value *value;
+};
+
 /// Base class for operation conversions targeting the LLVM IR dialect. Provides
 /// conversion patterns with an access to the containing LLVMLowering for the
 /// purpose of type conversions.
index 0641a6b..570b6c4 100644 (file)
@@ -234,126 +234,117 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
                                PatternBenefit benefit)
     : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
 
-namespace {
-/// Helper class to produce LLVM dialect operations extracting or inserting
-/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
-/// The Value may be null, in which case none of the operations are valid.
-class MemRefDescriptor {
-public:
-  /// Construct a helper for the given descriptor value.
-  explicit MemRefDescriptor(Value *descriptor) : value(descriptor) {
-    if (value) {
-      structType = value->getType().cast<LLVM::LLVMType>();
-      indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
-          LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
-    }
-  }
-
-  /// Builds IR creating an `undef` value of the descriptor type.
-  static MemRefDescriptor undef(OpBuilder &builder, Location loc,
-                                Type descriptorType) {
-    Value *descriptor = builder.create<LLVM::UndefOp>(
-        loc, descriptorType.cast<LLVM::LLVMType>());
-    return MemRefDescriptor(descriptor);
-  }
-
-  /// Builds IR extracting the allocated pointer from the descriptor.
-  Value *allocatedPtr(OpBuilder &builder, Location loc) {
-    return extractPtr(builder, loc,
-                      LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
-  }
-
-  /// Builds IR inserting the allocated pointer into the descriptor.
-  void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) {
-    setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
-           ptr);
-  }
-
-  /// Builds IR extracting the aligned pointer from the descriptor.
-  Value *alignedPtr(OpBuilder &builder, Location loc) {
-    return extractPtr(builder, loc,
-                      LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+/*============================================================================*/
+/* MemRefDescriptor implementation                                            */
+/*============================================================================*/
+
+/// Construct a helper for the given descriptor value.
+MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) {
+  if (value) {
+    structType = value->getType().cast<LLVM::LLVMType>();
+    indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
+        LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
   }
+}
 
-  /// Builds IR inserting the aligned pointer into the descriptor.
-  void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) {
-    setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
-           ptr);
-  }
+/// Builds IR creating an `undef` value of the descriptor type.
+MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
+                                         Type descriptorType) {
+  Value *descriptor =
+      builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+  return MemRefDescriptor(descriptor);
+}
 
-  /// Builds IR extracting the offset from the descriptor.
-  Value *offset(OpBuilder &builder, Location loc) {
-    return builder.create<LLVM::ExtractValueOp>(
-        loc, indexType, value,
-        builder.getI64ArrayAttr(
-            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
-  }
+/// Builds IR extracting the allocated pointer from the descriptor.
+Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
+  return extractPtr(builder, loc,
+                    LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
+}
 
-  /// Builds IR inserting the offset into the descriptor.
-  void setOffset(OpBuilder &builder, Location loc, Value *offset) {
-    value = builder.create<LLVM::InsertValueOp>(
-        loc, structType, value, offset,
-        builder.getI64ArrayAttr(
-            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
-  }
+/// Builds IR inserting the allocated pointer into the descriptor.
+void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
+                                       Value *ptr) {
+  setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
+         ptr);
+}
 
-  /// Builds IR extracting the pos-th size from the descriptor.
-  Value *size(OpBuilder &builder, Location loc, unsigned pos) {
-    return builder.create<LLVM::ExtractValueOp>(
-        loc, indexType, value,
-        builder.getI64ArrayAttr(
-            {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
-  }
+/// Builds IR extracting the aligned pointer from the descriptor.
+Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
+  return extractPtr(builder, loc,
+                    LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+}
 
-  /// Builds IR inserting the pos-th size into the descriptor
-  void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) {
-    value = builder.create<LLVM::InsertValueOp>(
-        loc, structType, value, size,
-        builder.getI64ArrayAttr(
-            {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
-  }
+/// Builds IR inserting the aligned pointer into the descriptor.
+void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
+                                     Value *ptr) {
+  setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
+         ptr);
+}
 
-  /// Builds IR extracting the pos-th size from the descriptor.
-  Value *stride(OpBuilder &builder, Location loc, unsigned pos) {
-    return builder.create<LLVM::ExtractValueOp>(
-        loc, indexType, value,
-        builder.getI64ArrayAttr(
-            {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
-  }
+/// Builds IR extracting the offset from the descriptor.
+Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
+  return builder.create<LLVM::ExtractValueOp>(
+      loc, indexType, value,
+      builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+}
 
-  /// Builds IR inserting the pos-th stride into the descriptor
-  void setStride(OpBuilder &builder, Location loc, unsigned pos,
-                 Value *stride) {
-    value = builder.create<LLVM::InsertValueOp>(
-        loc, structType, value, stride,
-        builder.getI64ArrayAttr(
-            {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
-  }
+/// Builds IR inserting the offset into the descriptor.
+void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
+                                 Value *offset) {
+  value = builder.create<LLVM::InsertValueOp>(
+      loc, structType, value, offset,
+      builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+}
 
-  /*implicit*/ operator Value *() { return value; }
+/// Builds IR extracting the pos-th size from the descriptor.
+Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
+  return builder.create<LLVM::ExtractValueOp>(
+      loc, indexType, value,
+      builder.getI64ArrayAttr(
+          {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
+}
 
-private:
-  Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) {
-    Type type = structType.getStructElementType(pos);
-    return builder.create<LLVM::ExtractValueOp>(loc, type, value,
-                                                builder.getI64ArrayAttr(pos));
-  }
+/// Builds IR inserting the pos-th size into the descriptor
+void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
+                               Value *size) {
+  value = builder.create<LLVM::InsertValueOp>(
+      loc, structType, value, size,
+      builder.getI64ArrayAttr(
+          {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
+}
 
-  void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) {
-    value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
-                                                builder.getI64ArrayAttr(pos));
-  }
+/// Builds IR extracting the pos-th size from the descriptor.
+Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
+                                unsigned pos) {
+  return builder.create<LLVM::ExtractValueOp>(
+      loc, indexType, value,
+      builder.getI64ArrayAttr(
+          {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
+}
 
-  // Cached descriptor type.
-  LLVM::LLVMType structType;
+/// Builds IR inserting the pos-th stride into the descriptor
+void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
+                                 Value *stride) {
+  value = builder.create<LLVM::InsertValueOp>(
+      loc, structType, value, stride,
+      builder.getI64ArrayAttr(
+          {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
+}
 
-  // Cached index type.
-  LLVM::LLVMType indexType;
+Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc,
+                                    unsigned pos) {
+  Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
+  return builder.create<LLVM::ExtractValueOp>(loc, type, value,
+                                              builder.getI64ArrayAttr(pos));
+}
 
-  // Actual descriptor.
-  Value *value;
-};
+void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
+                              Value *ptr) {
+  value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
+                                              builder.getI64ArrayAttr(pos));
+}
 
+namespace {
 // Base class for Standard to LLVM IR op conversions.  Matches the Op type
 // provided as template argument.  Carries a reference to the LLVM dialect in
 // case it is necessary for rewriters.
index 9d03953..61614aa 100644 (file)
@@ -128,33 +128,33 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
 }
 
 namespace {
-/// Factor out the common information for all view conversions:
-///   1. common types in (standard and LLVM dialects)
-///   2. `pos` method
-///   3. view descriptor construction `desc`.
+/// EDSC-compatible wrapper for MemRefDescriptor.
 class BaseViewConversionHelper {
 public:
-  BaseViewConversionHelper(Location loc, MemRefType memRefType,
-                           ConversionPatternRewriter &rewriter,
-                           LLVMTypeConverter &lowering)
-      : zeroDMemRef(memRefType.getRank() == 0),
-        elementTy(getPtrToElementType(memRefType, lowering)),
-        int64Ty(
-            lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
-        desc(nullptr), rewriter(rewriter) {
-    assert(isStrided(memRefType) && "expected strided memref type");
-    viewDescriptorTy = lowering.convertType(memRefType).cast<LLVMType>();
-    desc = rewriter.create<LLVM::UndefOp>(loc, viewDescriptorTy);
-  }
-
-  ArrayAttr pos(ArrayRef<int64_t> values) const {
-    return rewriter.getI64ArrayAttr(values);
-  };
-
-  bool zeroDMemRef;
-  LLVMType elementTy, int64Ty, viewDescriptorTy;
-  Value *desc;
-  ConversionPatternRewriter &rewriter;
+  BaseViewConversionHelper(Type type)
+      : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
+
+  BaseViewConversionHelper(Value *v) : d(v) {}
+
+  /// Wrappers around MemRefDescriptor that use EDSC builder and location.
+  Value *allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
+  void setAllocatedPtr(Value *v) { d.setAllocatedPtr(rewriter(), loc(), v); }
+  Value *alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
+  void setAlignedPtr(Value *v) { d.setAlignedPtr(rewriter(), loc(), v); }
+  Value *offset() { return d.offset(rewriter(), loc()); }
+  void setOffset(Value *v) { d.setOffset(rewriter(), loc(), v); }
+  Value *size(unsigned i) { return d.size(rewriter(), loc(), i); }
+  void setSize(unsigned i, Value *v) { d.setSize(rewriter(), loc(), i, v); }
+  Value *stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
+  void setStride(unsigned i, Value *v) { d.setStride(rewriter(), loc(), i, v); }
+
+  operator Value *() { return d; }
+
+private:
+  OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
+  Location loc() { return ScopedContext::getLocation(); }
+
+  MemRefDescriptor d;
 };
 } // namespace
 
@@ -200,53 +200,46 @@ public:
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    edsc::ScopedContext context(rewriter, op->getLoc());
     SliceOpOperandAdaptor adaptor(operands);
-    Value *baseDesc = adaptor.view();
+    BaseViewConversionHelper baseDesc(adaptor.view());
 
     auto sliceOp = cast<SliceOp>(op);
     auto memRefType = sliceOp.getBaseViewType();
+    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
+                       .cast<LLVM::LLVMType>();
 
-    BaseViewConversionHelper helper(op->getLoc(), sliceOp.getViewType(),
-                                    rewriter, lowering);
-    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
-    Value *desc = helper.desc;
-
-    edsc::ScopedContext context(rewriter, op->getLoc());
+    BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
 
     // TODO(ntv): extract sizes and emit asserts.
     SmallVector<Value *, 4> strides(memRefType.getRank());
     for (int i = 0, e = memRefType.getRank(); i < e; ++i)
-      strides[i] = extractvalue(
-          int64Ty, baseDesc,
-          helper.pos({LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+      strides[i] = baseDesc.stride(i);
+
+    auto pos = [&rewriter](ArrayRef<int64_t> values) {
+      return rewriter.getI64ArrayAttr(values);
+    };
 
     // Compute base offset.
-    Value *baseOffset = extractvalue(
-        int64Ty, baseDesc,
-        helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    Value *baseOffset = baseDesc.offset();
     for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
       Value *indexing = adaptor.indexings()[i];
       Value *min = indexing;
       if (sliceOp.indexing(i)->getType().isa<RangeType>())
-        min = extractvalue(int64Ty, indexing, helper.pos(0));
+        min = extractvalue(int64Ty, indexing, pos(0));
       baseOffset = add(baseOffset, mul(min, strides[i]));
     }
 
     // Insert the base and aligned pointers.
-    auto ptrPos =
-        helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
-    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
-    ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
-    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
+    desc.setAllocatedPtr(baseDesc.allocatedPtr());
+    desc.setAlignedPtr(baseDesc.alignedPtr());
 
     // Insert base offset.
-    desc = insertvalue(
-        desc, baseOffset,
-        helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    desc.setOffset(baseOffset);
 
     // Corner case, no sizes or strides: early return the descriptor.
-    if (helper.zeroDMemRef)
-      return rewriter.replaceOp(op, desc), matchSuccess();
+    if (sliceOp.getViewType().getRank() == 0)
+      return rewriter.replaceOp(op, {desc}), matchSuccess();
 
     Value *zero =
         constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
@@ -258,12 +251,11 @@ public:
       if (indexing->getType().isa<RangeType>()) {
         int rank = en.index();
         Value *rangeDescriptor = adaptor.indexings()[rank];
-        Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
-        Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
-        Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
-        Value *baseSize = extractvalue(
-            int64Ty, baseDesc,
-            helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, rank}));
+        Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+        Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+        Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+        Value *baseSize = baseDesc.size(rank);
+
         // Bound upper by base view upper bound.
         max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                           baseSize);
@@ -272,19 +264,13 @@ public:
         size =
             llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
         Value *stride = mul(strides[rank], step);
-        desc = insertvalue(
-            desc, size,
-            helper.pos(
-                {LLVMTypeConverter::kSizePosInMemRefDescriptor, numNewDims}));
-        desc = insertvalue(
-            desc, stride,
-            helper.pos(
-                {LLVMTypeConverter::kStridePosInMemRefDescriptor, numNewDims}));
+        desc.setSize(numNewDims, size);
+        desc.setStride(numNewDims, stride);
         ++numNewDims;
       }
     }
 
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, {desc});
     return matchSuccess();
   }
 };
@@ -306,56 +292,35 @@ public:
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     // Initialize the common boilerplate and alloca at the top of the FuncOp.
+    edsc::ScopedContext context(rewriter, op->getLoc());
     TransposeOpOperandAdaptor adaptor(operands);
-    Value *baseDesc = adaptor.view();
+    BaseViewConversionHelper baseDesc(adaptor.view());
 
     auto transposeOp = cast<TransposeOp>(op);
     // No permutation, early exit.
     if (transposeOp.permutation().isIdentity())
-      return rewriter.replaceOp(op, baseDesc), matchSuccess();
+      return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
 
-    BaseViewConversionHelper helper(op->getLoc(), transposeOp.getViewType(),
-                                    rewriter, lowering);
-    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
-    Value *desc = helper.desc;
+    BaseViewConversionHelper desc(
+        lowering.convertType(transposeOp.getViewType()));
 
-    edsc::ScopedContext context(rewriter, op->getLoc());
     // Copy the base and aligned pointers from the old descriptor to the new
     // one.
-    ArrayAttr ptrPos =
-        helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
-    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
-    ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
-    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
+    desc.setAllocatedPtr(baseDesc.allocatedPtr());
+    desc.setAlignedPtr(baseDesc.alignedPtr());
 
     // Copy the offset pointer from the old descriptor to the new one.
-    ArrayAttr offPos =
-        helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
-    desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);
+    desc.setOffset(baseDesc.offset());
 
     // Iterate over the dimensions and apply size/stride permutation.
     for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
       int sourcePos = en.index();
       int targetPos = en.value().cast<AffineDimExpr>().getPosition();
-      Value *size = extractvalue(
-          int64Ty, baseDesc,
-          helper.pos(
-              {LLVMTypeConverter::kSizePosInMemRefDescriptor, sourcePos}));
-      desc =
-          insertvalue(desc, size,
-                      helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor,
-                                  targetPos}));
-      Value *stride = extractvalue(
-          int64Ty, baseDesc,
-          helper.pos(
-              {LLVMTypeConverter::kStridePosInMemRefDescriptor, sourcePos}));
-      desc = insertvalue(
-          desc, stride,
-          helper.pos(
-              {LLVMTypeConverter::kStridePosInMemRefDescriptor, targetPos}));
+      desc.setSize(targetPos, baseDesc.size(sourcePos));
+      desc.setStride(targetPos, baseDesc.stride(sourcePos));
     }
 
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, {desc});
     return matchSuccess();
   }
 };