[mlir][sparse] Move a few routines to CodegenUtils.
authorbixia1 <bixia@google.com>
Tue, 11 Oct 2022 20:22:14 +0000 (13:22 -0700)
committerbixia1 <bixia@google.com>
Tue, 11 Oct 2022 21:42:18 +0000 (14:42 -0700)
Move a few supporting routines for generating function calls to CodegenUtils so
that they can be used by the codegen path for sparse tensor file input and
output.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135691

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

index f56479d..aaeb625 100644 (file)
@@ -550,3 +550,36 @@ void mlir::sparse_tensor::translateIndicesArray(
   }
   assert(dstIndices.size() == dstRank);
 }
+
+FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
+                                               TypeRange resultType,
+                                               ValueRange operands,
+                                               EmitCInterface emitCInterface) {
+  MLIRContext *context = module.getContext();
+  auto result = SymbolRefAttr::get(context, name);
+  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+  if (!func) {
+    OpBuilder moduleBuilder(module.getBodyRegion());
+    func = moduleBuilder.create<func::FuncOp>(
+        module.getLoc(), name,
+        FunctionType::get(context, operands.getTypes(), resultType));
+    func.setPrivate();
+    if (static_cast<bool>(emitCInterface))
+      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
+                    UnitAttr::get(context));
+  }
+  return result;
+}
+
+func::CallOp mlir::sparse_tensor::createFuncCall(
+    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
+    ValueRange operands, EmitCInterface emitCInterface) {
+  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
+  FlatSymbolRefAttr fn =
+      getFunc(module, name, resultType, operands, emitCInterface);
+  return builder.create<func::CallOp>(loc, resultType, fn, operands);
+}
+
+Type mlir::sparse_tensor::getOpaquePointerType(OpBuilder &builder) {
+  return LLVM::LLVMPointerType::get(builder.getI8Type());
+}
index 7e8c5eb..9908060 100644 (file)
@@ -15,6 +15,8 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/ExecutionEngine/SparseTensor/Enums.h"
@@ -28,6 +30,10 @@ class Value;
 
 namespace sparse_tensor {
 
+/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
+/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
+enum class EmitCInterface : bool { Off = false, On = true };
+
 //===----------------------------------------------------------------------===//
 // SparseTensorLoopEmiter class, manages sparse tensors and helps to generate
 // loop structure to (co-iterate) sparse tensors.
@@ -225,6 +231,23 @@ void translateIndicesArray(OpBuilder &builder, Location loc,
                            ArrayRef<Value> dstShape,
                            SmallVectorImpl<Value> &dstIndices);
 
+/// Returns a function reference (first hit also inserts into module). Sets
+/// the "_emit_c_interface" on the function declaration when requested,
+/// so that LLVM lowering generates a wrapper function that takes care
+/// of ABI complications with passing in and returning MemRefs to C functions.
+FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType,
+                          ValueRange operands, EmitCInterface emitCInterface);
+
+/// Creates a `CallOp` to the function reference returned by `getFunc()` in
+/// the builder's module.
+func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name,
+                            TypeRange resultType, ValueRange operands,
+                            EmitCInterface emitCInterface);
+
+/// Returns the equivalent of `void*` for opaque arguments to the
+/// execution engine.
+Type getOpaquePointerType(OpBuilder &builder);
+
 //===----------------------------------------------------------------------===//
 // Inlined constant generators.
 //
index 47f42d5..00d4525 100644 (file)
@@ -20,8 +20,6 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -36,20 +34,10 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
-/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
-enum class EmitCInterface : bool { Off = false, On = true };
-
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Returns the equivalent of `void*` for opaque arguments to the
-/// execution engine.
-static Type getOpaquePointerType(OpBuilder &builder) {
-  return LLVM::LLVMPointerType::get(builder.getI8Type());
-}
-
 /// Maps each sparse tensor type to an opaque pointer.
 static Optional<Type> convertSparseTensorTypes(Type type) {
   if (getSparseTensorEncoding(type) != nullptr)
@@ -57,40 +45,6 @@ static Optional<Type> convertSparseTensorTypes(Type type) {
   return llvm::None;
 }
 
-/// Returns a function reference (first hit also inserts into module). Sets
-/// the "_emit_c_interface" on the function declaration when requested,
-/// so that LLVM lowering generates a wrapper function that takes care
-/// of ABI complications with passing in and returning MemRefs to C functions.
-static FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name,
-                                 TypeRange resultType, ValueRange operands,
-                                 EmitCInterface emitCInterface) {
-  MLIRContext *context = module.getContext();
-  auto result = SymbolRefAttr::get(context, name);
-  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
-  if (!func) {
-    OpBuilder moduleBuilder(module.getBodyRegion());
-    func = moduleBuilder.create<func::FuncOp>(
-        module.getLoc(), name,
-        FunctionType::get(context, operands.getTypes(), resultType));
-    func.setPrivate();
-    if (static_cast<bool>(emitCInterface))
-      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
-                    UnitAttr::get(context));
-  }
-  return result;
-}
-
-/// Creates a `CallOp` to the function reference returned by `getFunc()` in
-/// the builder's module.
-static func::CallOp createFuncCall(OpBuilder &builder, Location loc,
-                                   StringRef name, TypeRange resultType,
-                                   ValueRange operands,
-                                   EmitCInterface emitCInterface) {
-  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
-  auto fn = getFunc(module, name, resultType, operands, emitCInterface);
-  return builder.create<func::CallOp>(loc, resultType, fn, operands);
-}
-
 /// Replaces the `op` with  a `CallOp` to the function reference returned
 /// by `getFunc()`.
 static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,