}
assert(dstIndices.size() == dstRank);
}
+
+FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
+ TypeRange resultType,
+ ValueRange operands,
+ EmitCInterface emitCInterface) {
+ MLIRContext *context = module.getContext();
+ auto result = SymbolRefAttr::get(context, name);
+ auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+ if (!func) {
+ OpBuilder moduleBuilder(module.getBodyRegion());
+ func = moduleBuilder.create<func::FuncOp>(
+ module.getLoc(), name,
+ FunctionType::get(context, operands.getTypes(), resultType));
+ func.setPrivate();
+ if (static_cast<bool>(emitCInterface))
+ func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
+ UnitAttr::get(context));
+ }
+ return result;
+}
+
+func::CallOp mlir::sparse_tensor::createFuncCall(
+ OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
+ ValueRange operands, EmitCInterface emitCInterface) {
+ auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
+ FlatSymbolRefAttr fn =
+ getFunc(module, name, resultType, operands, emitCInterface);
+ return builder.create<func::CallOp>(loc, resultType, fn, operands);
+}
+
+Type mlir::sparse_tensor::getOpaquePointerType(OpBuilder &builder) {
+ return LLVM::LLVMPointerType::get(builder.getI8Type());
+}
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/ExecutionEngine/SparseTensor/Enums.h"
namespace sparse_tensor {
+/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
+/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
+enum class EmitCInterface : bool { Off = false, On = true };
+
//===----------------------------------------------------------------------===//
// SparseTensorLoopEmiter class, manages sparse tensors and helps to generate
// loop structure to (co-iterate) sparse tensors.
ArrayRef<Value> dstShape,
SmallVectorImpl<Value> &dstIndices);
+/// Returns a function reference (first hit also inserts into module). Sets
+/// the "_emit_c_interface" on the function declaration when requested,
+/// so that LLVM lowering generates a wrapper function that takes care
+/// of ABI complications with passing in and returning MemRefs to C functions.
+FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType,
+ ValueRange operands, EmitCInterface emitCInterface);
+
+/// Creates a `CallOp` to the function reference returned by `getFunc()` in
+/// the builder's module.
+func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name,
+ TypeRange resultType, ValueRange operands,
+ EmitCInterface emitCInterface);
+
+/// Returns the equivalent of `void*` for opaque arguments to the
+/// execution engine.
+Type getOpaquePointerType(OpBuilder &builder);
+
//===----------------------------------------------------------------------===//
// Inlined constant generators.
//
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
namespace {
-/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
-/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
-enum class EmitCInterface : bool { Off = false, On = true };
-
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
-/// Returns the equivalent of `void*` for opaque arguments to the
-/// execution engine.
-static Type getOpaquePointerType(OpBuilder &builder) {
- return LLVM::LLVMPointerType::get(builder.getI8Type());
-}
-
/// Maps each sparse tensor type to an opaque pointer.
static Optional<Type> convertSparseTensorTypes(Type type) {
if (getSparseTensorEncoding(type) != nullptr)
return llvm::None;
}
-/// Returns a function reference (first hit also inserts into module). Sets
-/// the "_emit_c_interface" on the function declaration when requested,
-/// so that LLVM lowering generates a wrapper function that takes care
-/// of ABI complications with passing in and returning MemRefs to C functions.
-static FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name,
- TypeRange resultType, ValueRange operands,
- EmitCInterface emitCInterface) {
- MLIRContext *context = module.getContext();
- auto result = SymbolRefAttr::get(context, name);
- auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
- if (!func) {
- OpBuilder moduleBuilder(module.getBodyRegion());
- func = moduleBuilder.create<func::FuncOp>(
- module.getLoc(), name,
- FunctionType::get(context, operands.getTypes(), resultType));
- func.setPrivate();
- if (static_cast<bool>(emitCInterface))
- func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
- UnitAttr::get(context));
- }
- return result;
-}
-
-/// Creates a `CallOp` to the function reference returned by `getFunc()` in
-/// the builder's module.
-static func::CallOp createFuncCall(OpBuilder &builder, Location loc,
- StringRef name, TypeRange resultType,
- ValueRange operands,
- EmitCInterface emitCInterface) {
- auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
- auto fn = getFunc(module, name, resultType, operands, emitCInterface);
- return builder.create<func::CallOp>(loc, resultType, fn, operands);
-}
-
/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,