Promote MemRefDescriptor to a pointer to struct when passing function boundaries...
authorNicolas Vasilache <ntv@google.com>
Fri, 27 Sep 2019 16:55:38 +0000 (09:55 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 27 Sep 2019 16:57:36 +0000 (09:57 -0700)
The strided MemRef RFC discusses a normalized descriptor and interaction with library calls (https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio).
Lowering of nested LLVM structs as value types does not play nicely with externally compiled C/C++ functions due to ABI issues.
Solving the ABI problem generally is a very complex problem and most likely involves taking
a dependence on clang that we do not want atm.

A simple workaround is to pass pointers to memref descriptors at function boundaries, which this CL implement.

PiperOrigin-RevId: 271591708

14 files changed:
mlir/examples/Linalg/Linalg3/Execution.cpp
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Conversion/StandardToLLVM/convert-argattrs.mlir
mlir/test/Conversion/StandardToLLVM/convert-funcs.mlir
mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Examples/Linalg/Linalg1.mlir
mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp

index 5122878..36f0e1c 100644 (file)
@@ -125,10 +125,8 @@ TEST_FUNC(execution) {
   auto A = allocateInit2DMemref(5, 3);
   auto B = allocateInit2DMemref(3, 2);
   auto C = allocateInit2DMemref(5, 2);
-  llvm::SmallVector<void *, 4> args;
-  args.push_back(&A);
-  args.push_back(&B);
-  args.push_back(&C);
+  auto *pA = &A, *pB = &B, *pC = &C;
+  llvm::SmallVector<void *, 3> args({&pA, &pB, &pC});
 
   // Invoke the JIT-compiled function with the arguments.  Note that, for API
   // uniformity reasons, it takes a list of type-erased pointers to arguments.
index 30aaa0d..b4cd87d 100644 (file)
@@ -62,6 +62,20 @@ public:
   /// Returns the LLVM dialect.
   LLVM::LLVMDialect *getDialect() { return llvmDialect; }
 
+  /// Promote the LLVM struct representation of all MemRef descriptors to stack
+  /// and use pointers to struct to avoid the complexity of the
+  /// platform-specific C/C++ ABI lowering related to struct argument passing.
+  SmallVector<Value *, 4> promoteMemRefDescriptors(Location loc,
+                                                   ArrayRef<Value *> opOperands,
+                                                   ArrayRef<Value *> operands,
+                                                   OpBuilder &builder);
+
+  /// Promote the LLVM struct representation of one MemRef descriptor to stack
+  /// and use pointer to struct to avoid the complexity of the platform-specific
+  /// C/C++ ABI lowering related to struct argument passing.
+  Value *promoteOneMemRefDescriptor(Location loc, Value *operand,
+                                    OpBuilder &builder);
+
 protected:
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
index 3cdf14c..8f5adcd 100644 (file)
@@ -244,6 +244,9 @@ public:
   void applySignatureConversion(Region *region,
                                 TypeConverter::SignatureConversion &conversion);
 
+  /// Replace all the uses of the block argument `from` with value `to`.
+  void replaceUsesOfBlockArgument(BlockArgument *from, Value *to);
+
   /// Clone the given operation without cloning its regions.
   Operation *cloneWithoutRegions(Operation *op);
   template <typename OpT> OpT cloneWithoutRegions(OpT op) {
index c53a52d..961727d 100644 (file)
@@ -49,6 +49,7 @@ static constexpr const char *cuModuleGetFunctionName = "mcuModuleGetFunction";
 static constexpr const char *cuLaunchKernelName = "mcuLaunchKernel";
 static constexpr const char *cuGetStreamHelperName = "mcuGetStreamHelper";
 static constexpr const char *cuStreamSynchronizeName = "mcuStreamSynchronize";
+static constexpr const char *kMcuMemHostRegisterPtr = "mcuMemHostRegisterPtr";
 
 static constexpr const char *kCubinGetterAnnotation = "nvvm.cubingetter";
 
@@ -216,6 +217,15 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
                            },
                            getCUResultType())));
   }
+  if (!module.lookupSymbol<FuncOp>(kMcuMemHostRegisterPtr)) {
+    module.push_back(FuncOp::create(loc, kMcuMemHostRegisterPtr,
+                                    builder.getFunctionType(
+                                        {
+                                            getPointerType(), /* void *ptr */
+                                            getInt32Type()    /* int32 flags*/
+                                        },
+                                        {})));
+  }
 }
 
 // Generates a parameters array to be used with a CUDA kernel launch call. The
@@ -229,22 +239,45 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
 Value *
 GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
                                                OpBuilder &builder) {
+  auto numKernelOperands = launchOp.getNumKernelOperands();
   Location loc = launchOp.getLoc();
   auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
                                               builder.getI32IntegerAttr(1));
+  // Provision twice as much for the `array` to allow up to one level of
+  // indirection for each argument.
   auto arraySize = builder.create<LLVM::ConstantOp>(
-      loc, getInt32Type(),
-      builder.getI32IntegerAttr(launchOp.getNumKernelOperands()));
+      loc, getInt32Type(), builder.getI32IntegerAttr(numKernelOperands));
   auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(),
                                               arraySize, /*alignment=*/0);
-  for (int idx = 0, e = launchOp.getNumKernelOperands(); idx < e; ++idx) {
+  for (unsigned idx = 0; idx < numKernelOperands; ++idx) {
     auto operand = launchOp.getKernelOperand(idx);
     auto llvmType = operand->getType().cast<LLVM::LLVMType>();
-    auto memLocation = builder.create<LLVM::AllocaOp>(
+    Value *memLocation = builder.create<LLVM::AllocaOp>(
         loc, llvmType.getPointerTo(), one, /*alignment=*/1);
     builder.create<LLVM::StoreOp>(loc, operand, memLocation);
     auto casted =
         builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+
+    // Assume all struct arguments come from MemRef. If this assumption does not
+    // hold anymore then we `launchOp` to lower from MemRefType and not after
+    // LLVMConversion has taken place and the MemRef information is lost.
+    // Extra level of indirection in the `array`:
+    //   the descriptor pointer is registered via @mcuMemHostRegisterPtr
+    if (llvmType.isStructTy()) {
+      auto registerFunc =
+          getModule().lookupSymbol<FuncOp>(kMcuMemHostRegisterPtr);
+      auto zero = builder.create<LLVM::ConstantOp>(
+          loc, getInt32Type(), builder.getI32IntegerAttr(0));
+      builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
+                                   builder.getSymbolRefAttr(registerFunc),
+                                   ArrayRef<Value *>{casted, zero});
+      Value *memLocation = builder.create<LLVM::AllocaOp>(
+          loc, getPointerPointerType(), one, /*alignment=*/1);
+      builder.create<LLVM::StoreOp>(loc, casted, memLocation);
+      casted =
+          builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
+    }
+
     auto index = builder.create<LLVM::ConstantOp>(
         loc, getInt32Type(), builder.getI32IntegerAttr(idx));
     auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array,
index 12142eb..ef90a2a 100644 (file)
@@ -276,12 +276,28 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto funcOp = cast<FuncOp>(op);
     FunctionType type = funcOp.getType();
-
-    // Convert the original function arguments.
+    SmallVector<Type, 4> argTypes;
+    argTypes.reserve(type.getNumInputs());
+    SmallVector<unsigned, 4> promotedArgIndices;
+    promotedArgIndices.reserve(type.getNumInputs());
+
+    // Convert the original function arguments. Struct arguments are promoted to
+    // pointer to struct arguments to allow calling external functions with
+    // various ABIs (e.g. compiled from C/C++ on platform X).
     TypeConverter::SignatureConversion result(type.getNumInputs());
-    for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-      if (failed(lowering.convertSignatureArg(i, type.getInput(i), result)))
+    for (auto en : llvm::enumerate(type.getInputs())) {
+      auto t = en.value();
+      auto converted = lowering.convertType(t);
+      if (!converted)
         return matchFailure();
+      if (t.isa<MemRefType>()) {
+        converted = converted.cast<LLVM::LLVMType>().getPointerTo();
+        promotedArgIndices.push_back(en.index());
+      }
+      argTypes.push_back(converted);
+    }
+    for (unsigned idx = 0, e = argTypes.size(); idx < e; ++idx)
+      result.addInputs(idx, argTypes[idx]);
 
     // Pack the result types into a struct.
     Type packedResult;
@@ -301,6 +317,18 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
 
     // Tell the rewriter to convert the region signature.
     rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+
+    // Insert loads from memref descriptor pointers in function bodies.
+    if (!newFuncOp.getBody().empty()) {
+      Block *firstBlock = &newFuncOp.getBody().front();
+      rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
+      for (unsigned idx : promotedArgIndices) {
+        BlockArgument *arg = firstBlock->getArgument(idx);
+        Value *loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
+        rewriter.replaceUsesOfBlockArgument(arg, loaded);
+      }
+    }
+
     rewriter.replaceOp(op, llvm::None);
     return matchSuccess();
   }
@@ -502,13 +530,6 @@ struct SelectOpLowering
     : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
   using Super::Super;
 };
-struct CallOpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::CallOp> {
-  using Super::Super;
-};
-struct CallIndirectOpLowering
-    : public OneToOneLLVMOpLowering<CallIndirectOp, LLVM::CallOp> {
-  using Super::Super;
-};
 struct ConstLLVMOpLowering
     : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
   using Super::Super;
@@ -623,6 +644,100 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
   }
 };
 
+// Helper structure which extracts the necessary information from CallOp-like
+// ops for the purpose of generating an LLVM::CallOp.
+struct FunctionInfo {
+  FunctionType type;
+  CallInterfaceCallable callable;
+};
+static FunctionInfo getFuncOp(ModuleOp module, CallOp op) {
+  return FunctionInfo{module.lookupSymbol<FuncOp>(op.getCallee()).getType(),
+                      SymbolRefAttr::get(op.getCallee(), op.getContext())};
+}
+static FunctionInfo getFuncOp(ModuleOp module, CallIndirectOp op) {
+  if (auto fAttr = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
+    return FunctionInfo{module.lookupSymbol<FuncOp>(fAttr.getValue()).getType(),
+                        fAttr};
+  // Else, this must be an SSA value of FunctionType type.
+  Value *fValue = op.getCallableForCallee().get<Value *>();
+  FunctionType fType = fValue->getType().cast<FunctionType>();
+  return FunctionInfo{fType, fValue};
+}
+template <typename CallOpType>
+static LLVM::CallOp
+createLLVMCall(FunctionInfo fInfo, ConversionPatternRewriter &rewriter,
+               Location loc, Type returnType, ArrayRef<Value *> operands) {
+  if (fInfo.callable.dyn_cast<Value *>())
+    return rewriter.create<LLVM::CallOp>(loc, returnType, operands);
+  auto fAttr = fInfo.callable.get<SymbolRefAttr>();
+  auto namedFAttr = rewriter.getNamedAttr("callee", fAttr);
+  return rewriter.create<LLVM::CallOp>(loc, returnType, operands,
+                                       ArrayRef<NamedAttribute>{namedFAttr});
+}
+
+// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
+// passes the pointer to the MemRef across function boundaries.
+template <typename CallOpType>
+struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
+  using LLVMLegalizationPattern<CallOpType>::LLVMLegalizationPattern;
+  using Super = CallOpInterfaceLowering<CallOpType>;
+  using Base = LLVMLegalizationPattern<CallOpType>;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<CallOpType> transformed(operands);
+    auto callOp = cast<CallOpType>(op);
+    auto module = op->getParentOfType<ModuleOp>();
+    FunctionInfo fInfo = getFuncOp(module, callOp);
+    auto functionType = fInfo.type;
+
+    // Pack the result types into a struct.
+    Type packedResult;
+    unsigned numResults = callOp.getNumResults();
+    if (numResults != 0) {
+      if (!(packedResult =
+                this->lowering.packFunctionResults(functionType.getResults())))
+        return this->matchFailure();
+    }
+
+    SmallVector<Value *, 4> opOperands(op->getOperands());
+    auto promoted = this->lowering.promoteMemRefDescriptors(
+        op->getLoc(), opOperands, operands, rewriter);
+    auto newOp = createLLVMCall<CallOpType>(fInfo, rewriter, op->getLoc(),
+                                            packedResult, promoted);
+
+    // If < 2 results, packingdid not do anything and we can just return.
+    if (numResults < 2) {
+      SmallVector<Value *, 4> results(newOp.getResults());
+      rewriter.replaceOp(op, results);
+      return this->matchSuccess();
+    }
+
+    // Otherwise, it had been converted to an operation producing a structure.
+    // Extract individual results from the structure and return them as list.
+    SmallVector<Value *, 4> results;
+    results.reserve(numResults);
+    for (unsigned i = 0; i < numResults; ++i) {
+      auto type = this->lowering.convertType(op->getResult(i)->getType());
+      results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+          op->getLoc(), type, newOp.getOperation()->getResult(0),
+          rewriter.getIndexArrayAttr(i)));
+    }
+    rewriter.replaceOp(op, results);
+
+    return this->matchSuccess();
+  }
+};
+
+struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
+  using Super::Super;
+};
+
+struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
+  using Super::Super;
+};
+
 // A `dealloc` is converted into a call to `free` on the underlying data buffer.
 // The memref descriptor being an SSA value, there is no need to clean it up
 // in any way.
@@ -1138,6 +1253,42 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
   return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
 }
 
+Value *LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc,
+                                                     Value *operand,
+                                                     OpBuilder &builder) {
+  auto *context = builder.getContext();
+  auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
+  auto indexType = IndexType::get(context);
+  // Alloca with proper alignment. We do not expect optimizations of this
+  // alloca op and so we omit allocating at the entry block.
+  auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo();
+  Value *one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
+                                                IntegerAttr::get(indexType, 1));
+  Value *allocated =
+      builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
+  // Store into the alloca'ed descriptor.
+  builder.create<LLVM::StoreOp>(loc, operand, allocated);
+  return allocated;
+}
+
+SmallVector<Value *, 4> LLVMTypeConverter::promoteMemRefDescriptors(
+    Location loc, ArrayRef<Value *> opOperands, ArrayRef<Value *> operands,
+    OpBuilder &builder) {
+  SmallVector<Value *, 4> promotedOperands;
+  promotedOperands.reserve(operands.size());
+  for (auto it : llvm::zip(opOperands, operands)) {
+    auto *operand = std::get<0>(it);
+    auto *llvmOperand = std::get<1>(it);
+    if (!operand->getType().isa<MemRefType>()) {
+      promotedOperands.push_back(operand);
+      continue;
+    }
+    promotedOperands.push_back(
+        promoteOneMemRefDescriptor(loc, llvmOperand, builder));
+  }
+  return promotedOperands;
+}
+
 /// Create an instance of LLVMTypeConverter in the given context.
 static std::unique_ptr<LLVMTypeConverter>
 makeStandardToLLVMTypeConverter(MLIRContext *context) {
index 1fdff7c..ff05802 100644 (file)
@@ -449,12 +449,16 @@ LogicalResult LaunchFuncOp::verify() {
            << getNumKernelOperands() << " kernel operands but expected "
            << numKernelFuncArgs;
   }
-  auto functionType = kernelFunc.getType();
-  for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
-    if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
-      return emitOpError("type of function argument ")
-             << i << " does not match";
-    }
-  }
+  // Due to the ordering of the current impl of lowering and LLVMLowering, type
+  // checks need to be temporarily disabled.
+  // TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
+  // to encode target module" has landed.
+  // auto functionType = kernelFunc.getType();
+  // for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
+  //   if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
+  //     return emitOpError("type of function argument ")
+  //            << i << " does not match";
+  //   }
+  // }
   return success();
 }
index 3fababb..cdd5be8 100644 (file)
@@ -618,6 +618,16 @@ void ConversionPatternRewriter::applySignatureConversion(
   impl->applySignatureConversion(region, conversion);
 }
 
+void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from,
+                                                           Value *to) {
+  for (auto &u : from->getUses()) {
+    if (u.getOwner() == to->getDefiningOp())
+      continue;
+    u.getOwner()->replaceUsesOfWith(from, to);
+  }
+  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+}
+
 /// Clone the given operation without cloning its regions.
 Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
   Operation *newOp = OpBuilder::cloneWithoutRegions(*op);
index 4c4b620..60bb3e1 100644 (file)
@@ -1,8 +1,24 @@
 // RUN: mlir-opt -lower-to-llvm %s | FileCheck %s
 
-
-// CHECK-LABEL: func @check_attributes(%arg0: !llvm<"{ float*, [2 x i64] }"> {dialect.a = true, dialect.b = 4 : i64}) {
+// CHECK-LABEL: func @check_attributes(%arg0: !llvm<"{ float*, [2 x i64] }*"> {dialect.a = true, dialect.b = 4 : i64}) {
+//  CHECK-NEXT:   llvm.load %arg0 : !llvm<"{ float*, [2 x i64] }*">
 func @check_attributes(%static: memref<10x20xf32> {dialect.a = true, dialect.b = 4 : i64 }) {
+  %c0 = constant 0 : index
+  %0 = load %static[%c0, %c0]: memref<10x20xf32>
+  return
+}
+
+// CHECK-LABEL: func @external_func(!llvm<"{ float*, [2 x i64] }*">)
+//       CHECK: func @call_external(%[[arg:.*]]: !llvm<"{ float*, [2 x i64] }*">) {
+//       CHECK:   %[[ld:.*]] = llvm.load %[[arg]] : !llvm<"{ float*, [2 x i64] }*">
+//       CHECK:   %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:   %[[alloca:.*]] = llvm.alloca %[[c1]] x !llvm<"{ float*, [2 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, [2 x i64] }*">
+//       CHECK:   llvm.store %[[ld]], %[[alloca]] : !llvm<"{ float*, [2 x i64] }*">
+//       CHECK:   call @external_func(%[[alloca]]) : (!llvm<"{ float*, [2 x i64] }*">) -> ()
+func @external_func(memref<10x20xf32>)
+
+func @call_external(%static: memref<10x20xf32>) {
+  call @external_func(%static) : (memref<10x20xf32>) -> ()
   return
 }
 
index f81443a..eeffb56 100644 (file)
@@ -1,3 +1,4 @@
+// RUN: mlir-opt -lower-to-llvm %s
 // RUN: mlir-opt -lower-to-llvm %s | FileCheck %s
 
 //CHECK: func @second_order_arg(!llvm<"void ()*">)
index 496a1ec..d150dff 100644 (file)
@@ -1,31 +1,34 @@
+// RUN: mlir-opt -lower-to-llvm %s
 // RUN: mlir-opt -lower-to-llvm %s | FileCheck %s
 
 
-// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, [2 x i64] }">, %arg1: !llvm<"{ float*, [2 x i64] }">, %arg2: !llvm<"{ float*, [2 x i64] }">)
+// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, [2 x i64] }*">, %arg1: !llvm<"{ float*, [2 x i64] }*">, %arg2: !llvm<"{ float*, [2 x i64] }*">)
 func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
   return
 }
 
-// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, [2 x i64] }">) -> !llvm<"{ float*, [2 x i64] }"> {
+// CHECK-LABEL: func @check_static_return(%arg0: !llvm<"{ float*, [2 x i64] }*">) -> !llvm<"{ float*, [2 x i64] }"> {
 func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
-// CHECK-NEXT:  llvm.return %arg0 : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.return %{{.*}} : !llvm<"{ float*, [2 x i64] }">
   return %static : memref<32x18xf32>
 }
 
 // CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float* }"> {
 func @zero_d_alloc() -> memref<f32> {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT:  %1 = llvm.mlir.constant(4 : index) : !llvm.i64
-// CHECK-NEXT:  %2 = llvm.mul %0, %1 : !llvm.i64
-// CHECK-NEXT:  %3 = llvm.call @malloc(%2) : (!llvm.i64) -> !llvm<"i8*">
-// CHECK-NEXT:  %4 = llvm.bitcast %3 : !llvm<"i8*"> to !llvm<"float*">
+// CHECK-NEXT:  llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
+// CHECK-NEXT:  llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
   %0 = alloc() : memref<f32>
   return %0 : memref<f32>
 }
 
-// CHECK-LABEL: func @zero_d_dealloc(%arg0: !llvm<"{ float* }">) {
+// CHECK-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm<"{ float* }*">) {
 func @zero_d_dealloc(%arg0: memref<f32>) {
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float* }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float* }*">
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float* }">
 // CHECK-NEXT:  %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
 // CHECK-NEXT:  llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
   dealloc %arg0 : memref<f32>
@@ -50,11 +53,12 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
   return %0 : memref<?x42x?xf32>
 }
 
-// CHECK-LABEL: func @mixed_dealloc(%arg0: !llvm<"{ float*, [3 x i64] }">) {
+// CHECK-LABEL: func @mixed_dealloc(%arg0: !llvm<"{ float*, [3 x i64] }*">) {
 func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
-// CHECK-NEXT:  %0 = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [3 x i64] }">
-// CHECK-NEXT:  %1 = llvm.bitcast %0 : !llvm<"float*"> to !llvm<"i8*">
-// CHECK-NEXT:  llvm.call @free(%1) : (!llvm<"i8*">) -> ()
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [3 x i64] }*">
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [3 x i64] }">
+// CHECK-NEXT:  %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
+// CHECK-NEXT:  llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
   dealloc %arg0 : memref<?x42x?xf32>
 // CHECK-NEXT:  llvm.return
   return
@@ -75,11 +79,12 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
   return %0 : memref<?x?xf32>
 }
 
-// CHECK-LABEL: func @dynamic_dealloc(%arg0: !llvm<"{ float*, [2 x i64] }">) {
+// CHECK-LABEL: func @dynamic_dealloc(%arg0: !llvm<"{ float*, [2 x i64] }*">) {
 func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
-// CHECK-NEXT:  %0 = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %1 = llvm.bitcast %0 : !llvm<"float*"> to !llvm<"i8*">
-// CHECK-NEXT:  llvm.call @free(%1) : (!llvm<"i8*">) -> ()
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
+// CHECK-NEXT:  llvm.call @free(%[[ptri8]]) : (!llvm<"i8*">) -> ()
   dealloc %arg0 : memref<?x?xf32>
   return
 }
@@ -97,32 +102,35 @@ func @static_alloc() -> memref<32x18xf32> {
  return %0 : memref<32x18xf32>
 }
 
-// CHECK-LABEL: func @static_dealloc(%arg0: !llvm<"{ float*, [2 x i64] }">) {
+// CHECK-LABEL: func @static_dealloc(%{{.*}}: !llvm<"{ float*, [2 x i64] }*">) {
 func @static_dealloc(%static: memref<10x8xf32>) {
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [2 x i64] }">
 // CHECK-NEXT:  %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"i8*">
 // CHECK-NEXT:  llvm.call @free(%[[bc]]) : (!llvm<"i8*">) -> ()
   dealloc %static : memref<10x8xf32>
   return
 }
 
-// CHECK-LABEL: func @zero_d_load(%arg0: !llvm<"{ float* }">) -> !llvm.float {
+// CHECK-LABEL: func @zero_d_load(%{{.*}}: !llvm<"{ float* }*">) -> !llvm.float {
 func @zero_d_load(%arg0: memref<f32>) -> f32 {
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float* }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float* }*">
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float* }">
 // CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][] : (!llvm<"float*">) -> !llvm<"float*">
-// CHECK-NEXT:  %2 = llvm.load %[[addr]] : !llvm<"float*">
+// CHECK-NEXT:  %{{.*}} = llvm.load %[[addr]] : !llvm<"float*">
   %0 = load %arg0[] : memref<f32>
   return %0 : f32
 }
 
 // CHECK-LABEL: func @static_load
 func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(10 : index) : !llvm.i64
-// CHECK-NEXT:  %1 = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT:  %2 = llvm.mul %arg1, %1 : !llvm.i64
-// CHECK-NEXT:  %3 = llvm.add %2, %arg2 : !llvm.i64
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%3] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.mlir.constant(10 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
 // CHECK-NEXT:  llvm.load %[[addr]] : !llvm<"float*">
   %0 = load %static[%i, %j] : memref<10x42xf32>
   return
@@ -130,33 +138,36 @@ func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
 
 // CHECK-LABEL: func @mixed_load
 func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT:  %1 = llvm.extractvalue %arg0[1, 1] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %2 = llvm.mul %arg1, %1 : !llvm.i64
-// CHECK-NEXT:  %3 = llvm.add %2, %arg2 : !llvm.i64
-// CHECK-NEXT:  %4 = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %5 = llvm.getelementptr %4[%3] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  %6 = llvm.load %5 : !llvm<"float*">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 1] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT:  llvm.load %[[addr]] : !llvm<"float*">
   %0 = load %mixed[%i, %j] : memref<42x?xf32>
   return
 }
 
 // CHECK-LABEL: func @dynamic_load
 func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK-NEXT:  %0 = llvm.extractvalue %arg0[1, 0] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %1 = llvm.extractvalue %arg0[1, 1] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %2 = llvm.mul %arg1, %1 : !llvm.i64
-// CHECK-NEXT:  %3 = llvm.add %2, %arg2 : !llvm.i64
-// CHECK-NEXT:  %4 = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %5 = llvm.getelementptr %4[%3] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  %6 = llvm.load %5 : !llvm<"float*">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 0] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 1] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT:  llvm.load %[[addr]] : !llvm<"float*">
   %0 = load %dynamic[%i, %j] : memref<?x?xf32>
   return
 }
 
-// CHECK-LABEL: func @zero_d_store(%arg0: !llvm<"{ float* }">, %arg1: !llvm.float) {
+// CHECK-LABEL: func @zero_d_store(%{{.*}}: !llvm<"{ float* }*">, %{{.*}}: !llvm.float) {
 func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float* }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float* }*">
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float* }">
 // CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][] : (!llvm<"float*">) -> !llvm<"float*">
 // CHECK-NEXT:  llvm.store %arg1, %[[addr]] : !llvm<"float*">
   store %arg1, %arg0[] : memref<f32>
@@ -165,118 +176,130 @@ func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
 
 // CHECK-LABEL: func @static_store
 func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(10 : index) : !llvm.i64
-// CHECK-NEXT:  %1 = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT:  %2 = llvm.mul %arg1, %1 : !llvm.i64
-// CHECK-NEXT:  %3 = llvm.add %2, %arg2 : !llvm.i64
-// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%3] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  llvm.store %arg3, %[[addr]] : !llvm<"float*">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.mlir.constant(10 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
   store %val, %static[%i, %j] : memref<10x42xf32>
   return
 }
 
 // CHECK-LABEL: func @dynamic_store
 func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT:  %0 = llvm.extractvalue %arg0[1, 0] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %1 = llvm.extractvalue %arg0[1, 1] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %2 = llvm.mul %arg1, %1 : !llvm.i64
-// CHECK-NEXT:  %3 = llvm.add %2, %arg2 : !llvm.i64
-// CHECK-NEXT:  %4 = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %5 = llvm.getelementptr %4[%3] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  llvm.store %arg3, %5 : !llvm<"float*">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 0] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 1] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
   store %val, %dynamic[%i, %j] : memref<?x?xf32>
   return
 }
 
 // CHECK-LABEL: func @mixed_store
 func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(42 : index) : !llvm.i64
-// CHECK-NEXT:  %1 = llvm.extractvalue %arg0[1, 1] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %2 = llvm.mul %arg1, %1 : !llvm.i64
-// CHECK-NEXT:  %3 = llvm.add %2, %arg2 : !llvm.i64
-// CHECK-NEXT:  %4 = llvm.extractvalue %arg0[0 : index] : !llvm<"{ float*, [2 x i64] }">
-// CHECK-NEXT:  %5 = llvm.getelementptr %4[%3] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
-// CHECK-NEXT:  llvm.store %arg3, %5 : !llvm<"float*">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 1] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// CHECK-NEXT:  %[[ptr:.*]] = llvm.extractvalue %[[ld]][0 : index] : !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm<"float*">
   store %val, %mixed[%i, %j] : memref<42x?xf32>
   return
 }
 
 // CHECK-LABEL: func @memref_cast_static_to_dynamic
 func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
-// CHECK-NEXT:  llvm.bitcast %arg0 : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
   %0 = memref_cast %static : memref<10x42xf32> to memref<?x?xf32>
   return
 }
 
 // CHECK-LABEL: func @memref_cast_static_to_mixed
 func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) {
-// CHECK-NEXT:  llvm.bitcast %arg0 : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
   %0 = memref_cast %static : memref<10x42xf32> to memref<?x42xf32>
   return
 }
 
 // CHECK-LABEL: func @memref_cast_dynamic_to_static
 func @memref_cast_dynamic_to_static(%dynamic : memref<?x?xf32>) {
-// CHECK-NEXT:  llvm.bitcast %arg0 : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
   %0 = memref_cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
   return
 }
 
 // CHECK-LABEL: func @memref_cast_dynamic_to_mixed
 func @memref_cast_dynamic_to_mixed(%dynamic : memref<?x?xf32>) {
-// CHECK-NEXT:  llvm.bitcast %arg0 : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
   %0 = memref_cast %dynamic : memref<?x?xf32> to memref<?x12xf32>
   return
 }
 
 // CHECK-LABEL: func @memref_cast_mixed_to_dynamic
 func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT:  llvm.bitcast %arg0 : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
   %0 = memref_cast %mixed : memref<42x?xf32> to memref<?x?xf32>
   return
 }
 
 // CHECK-LABEL: func @memref_cast_mixed_to_static
 func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT:  llvm.bitcast %arg0 : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
   %0 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32>
   return
 }
 
 // CHECK-LABEL: func @memref_cast_mixed_to_mixed
 func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
-// CHECK-NEXT:  llvm.bitcast %arg0 : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// CHECK-NEXT:  llvm.bitcast %[[ld]] : !llvm<"{ float*, [2 x i64] }"> to !llvm<"{ float*, [2 x i64] }">
   %0 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
   return
 }
 
-// CHECK-LABEL: func @mixed_memref_dim(%arg0: !llvm<"{ float*, [5 x i64] }">)
+// CHECK-LABEL: func @mixed_memref_dim(%{{.*}}: !llvm<"{ float*, [5 x i64] }*">)
 func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [5 x i64] }*">
+// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
   %0 = dim %mixed, 0 : memref<42x?x?x13x?xf32>
-// CHECK-NEXT:  %1 = llvm.extractvalue %arg0[1, 1] : !llvm<"{ float*, [5 x i64] }">
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 1] : !llvm<"{ float*, [5 x i64] }">
   %1 = dim %mixed, 1 : memref<42x?x?x13x?xf32>
-// CHECK-NEXT:  %2 = llvm.extractvalue %arg0[1, 2] : !llvm<"{ float*, [5 x i64] }">
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 2] : !llvm<"{ float*, [5 x i64] }">
   %2 = dim %mixed, 2 : memref<42x?x?x13x?xf32>
-// CHECK-NEXT:  %3 = llvm.mlir.constant(13 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mlir.constant(13 : index) : !llvm.i64
   %3 = dim %mixed, 3 : memref<42x?x?x13x?xf32>
-// CHECK-NEXT:  %4 = llvm.extractvalue %arg0[1, 4] : !llvm<"{ float*, [5 x i64] }">
+// CHECK-NEXT:  llvm.extractvalue %[[ld]][1, 4] : !llvm<"{ float*, [5 x i64] }">
   %4 = dim %mixed, 4 : memref<42x?x?x13x?xf32>
   return
 }
 
-// CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, [5 x i64] }">)
+// CHECK-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"{ float*, [5 x i64] }*">)
 func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(42 : index) : !llvm.i64
+// CHECK-NEXT:  %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, [5 x i64] }*">
+// CHECK-NEXT:  llvm.mlir.constant(42 : index) : !llvm.i64
   %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
-// CHECK-NEXT:  %1 = llvm.mlir.constant(32 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mlir.constant(32 : index) : !llvm.i64
   %1 = dim %static, 1 : memref<42x32x15x13x27xf32>
-// CHECK-NEXT:  %2 = llvm.mlir.constant(15 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mlir.constant(15 : index) : !llvm.i64
   %2 = dim %static, 2 : memref<42x32x15x13x27xf32>
-// CHECK-NEXT:  %3 = llvm.mlir.constant(13 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mlir.constant(13 : index) : !llvm.i64
   %3 = dim %static, 3 : memref<42x32x15x13x27xf32>
-// CHECK-NEXT:  %4 = llvm.mlir.constant(27 : index) : !llvm.i64
+// CHECK-NEXT:  llvm.mlir.constant(27 : index) : !llvm.i64
   %4 = dim %static, 4 : memref<42x32x15x13x27xf32>
   return
 }
index 3f62fd4..7bacabe 100644 (file)
@@ -1,7 +1,9 @@
+// RUN: mlir-opt %s -lower-to-llvm
 // RUN: mlir-opt %s -lower-to-llvm | FileCheck %s
 
 // CHECK-LABEL: func @address_space(
-//       CHECK:   %{{.*}}: !llvm<"{ float addrspace(7)*, [1 x i64] }">)
+//       CHECK:   %{{.*}}: !llvm<"{ float addrspace(7)*, [1 x i64] }*">)
+//       CHECK:   llvm.load %{{.*}} : !llvm<"{ float addrspace(7)*, [1 x i64] }*">
 func @address_space(%arg0 : memref<32xf32, (d0) -> (d0), 7>) {
   %0 = alloc() : memref<32xf32, (d0) -> (d0), 5>
   %1 = constant 7 : index
index 2a0be68..b7edf86 100644 (file)
@@ -155,13 +155,17 @@ func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
   return
 }
 
-func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
-  // expected-error@+1 {{type of function argument 0 does not match}}
-  "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
-      {kernel = @kernel_1}
-      : (index, index, index, index, index, index, f32) -> ()
-  return
-}
+// Due to the ordering of the current impl of lowering and LLVMLowering, type
+// checks need to be temporarily disabled.
+// TODO(ntv,zinenko,herhut): reactivate checks once "changing gpu.launchFunc
+// to encode target module" has landed.
+// func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
+//   // expected-err@+1 {{type of function argument 0 does not match}}
+//   "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
+//       {kernel = @kernel_1}
+//       : (index, index, index, index, index, index, f32) -> ()
+//   return
+// }
 
 // -----
 
index 03e16db..60ec505 100644 (file)
@@ -1,4 +1,5 @@
 // RUN: linalg1-opt %s | FileCheck %s
+// RUN: linalg1-opt %s -lower-linalg-to-llvm
 // RUN: linalg1-opt %s -lower-linalg-to-llvm | FileCheck %s -check-prefix=LLVM
 
 func @view_op(%arg0: memref<f32>, %arg1: memref<?xf32>, %arg2: memref<?x?xf32>) {
@@ -43,80 +44,82 @@ func @slice_op(%arg0: memref<?x?xf32>) {
 //       CHECK:  %[[r2:.*]] = linalg.range %{{.*}}:%[[N]]:%{{.*}} : !linalg.range
 //       CHECK:  %[[V:.*]] = linalg.view %{{.*}}[%[[r1]], %[[r2]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
 //       CHECK:  affine.for %{{.*}} = 0 to #map1(%{{.*}}) {
-//       CHECK:    affine.for %{{.*}} = 0 to #map1(%{{.*}}) {
-//       CHECK:      {{.*}} = linalg.slice %[[V]][%{{.*}}] {dim = 1} : !linalg.view<?x?xf32>, index
-//       CHECK:      %[[V2:.*]] = linalg.slice %[[V]][%{{.*}}] {dim = 0} : !linalg.view<?x?xf32>, index
-//       CHECK:      {{.*}} = linalg.slice %[[V2]][%{{.*}}] {dim = 0} : !linalg.view<?xf32>, index
+//       CHECK:   affine.for %{{.*}} = 0 to #map1(%{{.*}}) {
+//       CHECK:     {{.*}} = linalg.slice %[[V]][%{{.*}}] {dim = 1} : !linalg.view<?x?xf32>, index
+//       CHECK:     %[[V2:.*]] = linalg.slice %[[V]][%{{.*}}] {dim = 0} : !linalg.view<?x?xf32>, index
+//       CHECK:     {{.*}} = linalg.slice %[[V2]][%{{.*}}] {dim = 0} : !linalg.view<?xf32>, index
 
 func @rangeConversion(%arg0: index, %arg1: index, %arg2: index) {
   %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
   return
 }
 // LLVM-LABEL: @rangeConversion
-// LLVM-NEXT: %0 = llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.mlir.undef : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
 
 func @viewRangeConversion(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2: !linalg.range) {
   %0 = linalg.view %arg0[%arg1, %arg2] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
   return
 }
 // LLVM-LABEL: @viewRangeConversion
-// LLVM-NEXT: %0 = llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %1 = llvm.extractvalue %arg0[0] : !llvm<"{ float*, [2 x i64] }">
-// LLVM-NEXT: %2 = llvm.insertvalue %1, %0[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %3 = llvm.extractvalue %arg0[1, 1] : !llvm<"{ float*, [2 x i64] }">
-// LLVM-NEXT: %4 = llvm.mlir.constant(1 : index) : !llvm.i64
-// LLVM-NEXT: %5 = llvm.mul %4, %3 : !llvm.i64
-// LLVM-NEXT: %6 = llvm.mlir.constant(0 : index) : !llvm.i64
-// LLVM-NEXT: %7 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %8 = llvm.mul %7, %5 : !llvm.i64
-// LLVM-NEXT: %9 = llvm.add %6, %8 : !llvm.i64
-// LLVM-NEXT: %10 = llvm.extractvalue %arg2[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %11 = llvm.mul %10, %4 : !llvm.i64
-// LLVM-NEXT: %12 = llvm.add %9, %11 : !llvm.i64
-// LLVM-NEXT: %13 = llvm.insertvalue %12, %2[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %14 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %15 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %16 = llvm.sub %15, %14 : !llvm.i64
-// LLVM-NEXT: %17 = llvm.insertvalue %16, %13[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %18 = llvm.extractvalue %arg2[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %19 = llvm.extractvalue %arg2[1] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %20 = llvm.sub %19, %18 : !llvm.i64
-// LLVM-NEXT: %21 = llvm.insertvalue %20, %17[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %22 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %23 = llvm.mul %5, %22 : !llvm.i64
-// LLVM-NEXT: %24 = llvm.insertvalue %23, %21[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %25 = llvm.extractvalue %arg2[2] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %26 = llvm.mul %4, %25 : !llvm.i64
-// LLVM-NEXT: %27 = llvm.insertvalue %26, %24[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// LLVM-NEXT:  llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, [2 x i64] }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[1, 1] : !llvm<"{ float*, [2 x i64] }">
+// LLVM-NEXT:  llvm.mlir.constant(1 : index) : !llvm.i64
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.mlir.constant(0 : index) : !llvm.i64
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
 
 func @viewNonRangeConversion(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2: index) {
   %0 = linalg.view %arg0[%arg1, %arg2] : memref<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
   return
 }
 // LLVM-LABEL: @viewNonRangeConversion
-// LLVM-NEXT: %0 = llvm.mlir.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// LLVM-NEXT: %1 = llvm.extractvalue %arg0[0] : !llvm<"{ float*, [2 x i64] }">
-// LLVM-NEXT: %2 = llvm.insertvalue %1, %0[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// LLVM-NEXT: %3 = llvm.extractvalue %arg0[1, 1] : !llvm<"{ float*, [2 x i64] }">
-// LLVM-NEXT: %4 = llvm.mlir.constant(1 : index) : !llvm.i64
-// LLVM-NEXT: %5 = llvm.mul %4, %3 : !llvm.i64
-// LLVM-NEXT: %6 = llvm.mlir.constant(0 : index) : !llvm.i64
-// LLVM-NEXT: %7 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %8 = llvm.mul %7, %5 : !llvm.i64
-// LLVM-NEXT: %9 = llvm.add %6, %8 : !llvm.i64
-// LLVM-NEXT: %10 = llvm.mul %arg2, %4 : !llvm.i64
-// LLVM-NEXT: %11 = llvm.add %9, %10 : !llvm.i64
-// LLVM-NEXT: %12 = llvm.insertvalue %11, %2[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// LLVM-NEXT: %13 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %14 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %15 = llvm.sub %14, %13 : !llvm.i64
-// LLVM-NEXT: %16 = llvm.insertvalue %15, %12[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// LLVM-NEXT: %17 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %18 = llvm.mul %5, %17 : !llvm.i64
-// LLVM-NEXT: %19 = llvm.insertvalue %18, %16[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT:  llvm.load %{{.*}} : !llvm<"{ float*, [2 x i64] }*">
+// LLVM-NEXT:  llvm.mlir.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, [2 x i64] }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[1, 1] : !llvm<"{ float*, [2 x i64] }">
+// LLVM-NEXT:  llvm.mlir.constant(1 : index) : !llvm.i64
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.mlir.constant(0 : index) : !llvm.i64
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 
 func @sliceRangeConversion(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
   %0 = linalg.view %arg0[%arg1, %arg2] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
@@ -124,27 +127,28 @@ func @sliceRangeConversion(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2:
   return
 }
 // LLVM-LABEL: @sliceRangeConversion
-// LLVM:      %28 = llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %29 = llvm.extractvalue %27[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %30 = llvm.insertvalue %29, %28[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %31 = llvm.extractvalue %27[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %32 = llvm.extractvalue %arg3[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %33 = llvm.extractvalue %27[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %34 = llvm.mul %32, %33 : !llvm.i64
-// LLVM-NEXT: %35 = llvm.add %31, %34 : !llvm.i64
-// LLVM-NEXT: %36 = llvm.insertvalue %35, %30[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %37 = llvm.extractvalue %arg3[1] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %38 = llvm.extractvalue %arg3[0] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %39 = llvm.sub %37, %38 : !llvm.i64
-// LLVM-NEXT: %40 = llvm.extractvalue %27[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %41 = llvm.extractvalue %arg3[2] : !llvm<"{ i64, i64, i64 }">
-// LLVM-NEXT: %42 = llvm.mul %40, %41 : !llvm.i64
-// LLVM-NEXT: %43 = llvm.insertvalue %39, %36[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %44 = llvm.insertvalue %42, %43[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %45 = llvm.extractvalue %27[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %46 = llvm.extractvalue %27[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %47 = llvm.insertvalue %45, %44[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %48 = llvm.insertvalue %46, %47[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM:       llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM:       llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
+// LLVM-NEXT:  llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.extractvalue %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
 
 func @sliceNonRangeConversion2(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: index) {
   %0 = linalg.view %arg0[%arg1, %arg2] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
@@ -152,15 +156,15 @@ func @sliceNonRangeConversion2(%arg0: memref<?x?xf32>, %arg1: !linalg.range, %ar
   return
 }
 // LLVM-LABEL: @sliceNonRangeConversion2
-//      LLVM: %28 = llvm.mlir.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// LLVM-NEXT: %29 = llvm.extractvalue %27[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %30 = llvm.insertvalue %29, %28[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// LLVM-NEXT: %31 = llvm.extractvalue %27[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %32 = llvm.extractvalue %27[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %33 = llvm.mul %arg3, %32 : !llvm.i64
-// LLVM-NEXT: %34 = llvm.add %31, %33 : !llvm.i64
-// LLVM-NEXT: %35 = llvm.insertvalue %34, %30[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// LLVM-NEXT: %36 = llvm.extractvalue %27[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %37 = llvm.extractvalue %27[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
-// LLVM-NEXT: %38 = llvm.insertvalue %36, %35[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// LLVM-NEXT: %39 = llvm.insertvalue %37, %38[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+//      LLVM: llvm.mlir.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT: llvm.mul %{{.*}}arg3, %{{.*}} : !llvm.i64
+// LLVM-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
+// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT: llvm.extractvalue %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// LLVM-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
index c394662..8fc3757 100644 (file)
@@ -80,29 +80,40 @@ extern "C" int32_t mcuStreamSynchronize(void *stream) {
 
 /// Helper functions for writing mlir example code
 
-// A struct that corresponds to how MLIR represents unknown-length 1d memrefs.
-struct memref_t {
-  float *values;
-  intptr_t length;
+// A struct that corresponds to how MLIR represents unknown-sizes 1d memrefs.
+template <typename T, int N> struct MemRefType {
+  T *data;
+  int64_t sizes[N];
 };
 
 // Allows to register a pointer with the CUDA runtime. Helpful until
 // we have transfer functions implemented.
-extern "C" void mcuMemHostRegister(const memref_t arg, int32_t flags) {
+extern "C" void mcuMemHostRegister(const MemRefType<float, 1> *arg,
+                                   int32_t flags) {
   reportErrorIfAny(
-      cuMemHostRegister(arg.values, arg.length * sizeof(float), flags),
+      cuMemHostRegister(arg->data, arg->sizes[0] * sizeof(float), flags),
       "MemHostRegister");
+  for (int pos = 0; pos < arg->sizes[0]; pos++) {
+    arg->data[pos] = 1.23f;
+  }
+}
+
+// Allows to register a pointer with the CUDA runtime. Helpful until
+// we have transfer functions implemented.
+extern "C" void mcuMemHostRegisterPtr(void *ptr, int32_t flags) {
+  reportErrorIfAny(cuMemHostRegister(ptr, sizeof(void *), flags),
+                   "MemHostRegister");
 }
 
 /// Prints the given float array to stderr.
-extern "C" void mcuPrintFloat(const memref_t arg) {
-  if (arg.length == 0) {
+extern "C" void mcuPrintFloat(const MemRefType<float, 1> *arg) {
+  if (arg->sizes[0] == 0) {
     llvm::outs() << "[]\n";
     return;
   }
-  llvm::outs() << "[" << arg.values[0];
-  for (int pos = 1; pos < arg.length; pos++) {
-    llvm::outs() << ", " << arg.values[pos];
+  llvm::outs() << "[" << arg->data[0];
+  for (int pos = 1; pos < arg->sizes[0]; pos++) {
+    llvm::outs() << ", " << arg->data[pos];
   }
   llvm::outs() << "]\n";
 }