From 2d252a0f5cc2b320fd57f2eb9313a11bef74e793 Mon Sep 17 00:00:00 2001 From: bixia1 Date: Tue, 11 Oct 2022 13:22:14 -0700 Subject: [PATCH] [mlir][sparse] Move a few routines to CodegenUtils. Move a few supporting routines for generating function calls to CodegenUtils so that they can be used by the codegen path for sparse tensor file input and output. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135691 --- .../SparseTensor/Transforms/CodegenUtils.cpp | 33 ++++++++++++++++ .../Dialect/SparseTensor/Transforms/CodegenUtils.h | 23 +++++++++++ .../Transforms/SparseTensorConversion.cpp | 46 ---------------------- 3 files changed, 56 insertions(+), 46 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index f56479d..aaeb625 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -550,3 +550,36 @@ void mlir::sparse_tensor::translateIndicesArray( } assert(dstIndices.size() == dstRank); } + +FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name, + TypeRange resultType, + ValueRange operands, + EmitCInterface emitCInterface) { + MLIRContext *context = module.getContext(); + auto result = SymbolRefAttr::get(context, name); + auto func = module.lookupSymbol(result.getAttr()); + if (!func) { + OpBuilder moduleBuilder(module.getBodyRegion()); + func = moduleBuilder.create( + module.getLoc(), name, + FunctionType::get(context, operands.getTypes(), resultType)); + func.setPrivate(); + if (static_cast(emitCInterface)) + func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), + UnitAttr::get(context)); + } + return result; +} + +func::CallOp mlir::sparse_tensor::createFuncCall( + OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, + ValueRange operands, EmitCInterface emitCInterface) { + auto module = builder.getBlock()->getParentOp()->getParentOfType(); + FlatSymbolRefAttr fn = + getFunc(module, name, resultType, operands, emitCInterface); + return builder.create(loc, resultType, fn, operands); +} + +Type mlir::sparse_tensor::getOpaquePointerType(OpBuilder &builder) { + return LLVM::LLVMPointerType::get(builder.getI8Type()); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 7e8c5eb..9908060 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -15,6 +15,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/ExecutionEngine/SparseTensor/Enums.h" @@ -28,6 +30,10 @@ class Value; namespace sparse_tensor { +/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`, +/// `createFuncCall()`, and `replaceOpWithFuncCall()`. +enum class EmitCInterface : bool { Off = false, On = true }; + //===----------------------------------------------------------------------===// // SparseTensorLoopEmiter class, manages sparse tensors and helps to generate // loop structure to (co-iterate) sparse tensors. @@ -225,6 +231,23 @@ void translateIndicesArray(OpBuilder &builder, Location loc, ArrayRef dstShape, SmallVectorImpl &dstIndices); +/// Returns a function reference (first hit also inserts into module). Sets +/// the "_emit_c_interface" on the function declaration when requested, +/// so that LLVM lowering generates a wrapper function that takes care +/// of ABI complications with passing in and returning MemRefs to C functions. +FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, + ValueRange operands, EmitCInterface emitCInterface); + +/// Creates a `CallOp` to the function reference returned by `getFunc()` in +/// the builder's module. +func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, + TypeRange resultType, ValueRange operands, + EmitCInterface emitCInterface); + +/// Returns the equivalent of `void*` for opaque arguments to the +/// execution engine. +Type getOpaquePointerType(OpBuilder &builder); + //===----------------------------------------------------------------------===// // Inlined constant generators. // diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 47f42d5..00d4525 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -20,8 +20,6 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -36,20 +34,10 @@ using namespace mlir::sparse_tensor; namespace { -/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`, -/// `createFuncCall()`, and `replaceOpWithFuncCall()`. -enum class EmitCInterface : bool { Off = false, On = true }; - //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// -/// Returns the equivalent of `void*` for opaque arguments to the -/// execution engine. -static Type getOpaquePointerType(OpBuilder &builder) { - return LLVM::LLVMPointerType::get(builder.getI8Type()); -} - /// Maps each sparse tensor type to an opaque pointer. static Optional convertSparseTensorTypes(Type type) { if (getSparseTensorEncoding(type) != nullptr) @@ -57,40 +45,6 @@ static Optional convertSparseTensorTypes(Type type) { return llvm::None; } -/// Returns a function reference (first hit also inserts into module). Sets -/// the "_emit_c_interface" on the function declaration when requested, -/// so that LLVM lowering generates a wrapper function that takes care -/// of ABI complications with passing in and returning MemRefs to C functions. -static FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, - TypeRange resultType, ValueRange operands, - EmitCInterface emitCInterface) { - MLIRContext *context = module.getContext(); - auto result = SymbolRefAttr::get(context, name); - auto func = module.lookupSymbol(result.getAttr()); - if (!func) { - OpBuilder moduleBuilder(module.getBodyRegion()); - func = moduleBuilder.create( - module.getLoc(), name, - FunctionType::get(context, operands.getTypes(), resultType)); - func.setPrivate(); - if (static_cast(emitCInterface)) - func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), - UnitAttr::get(context)); - } - return result; -} - -/// Creates a `CallOp` to the function reference returned by `getFunc()` in -/// the builder's module. -static func::CallOp createFuncCall(OpBuilder &builder, Location loc, - StringRef name, TypeRange resultType, - ValueRange operands, - EmitCInterface emitCInterface) { - auto module = builder.getBlock()->getParentOp()->getParentOfType(); - auto fn = getFunc(module, name, resultType, operands, emitCInterface); - return builder.create(loc, resultType, fn, operands); -} - /// Replaces the `op` with a `CallOp` to the function reference returned /// by `getFunc()`. static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op, -- 2.7.4