return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
}
+/// Returns the equivalent of `void*` for opaque arguments to the
+/// execution engine.
+static Type getOpaquePointerType(PatternRewriter &rewriter) {
+ return LLVM::LLVMPointerType::get(rewriter.getI8Type());
+}
+
/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
/// of ABI complications with passing in and returning MemRefs to C functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
TypeRange resultType, ValueRange operands,
- bool emitCInterface = false) {
+ bool emitCInterface) {
MLIRContext *context = op->getContext();
auto module = op->getParentOfType<ModuleOp>();
auto result = SymbolRefAttr::get(context, name);
return result;
}
+/// Creates a `CallOp` to the function reference returned by `getFunc()`.
+static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
+ TypeRange resultType, ValueRange operands,
+ bool emitCInterface = false) {
+ auto fn = getFunc(op, name, resultType, operands, emitCInterface);
+ return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
+}
+
+/// Replaces the `op` with a `CallOp` to the function reference returned
+/// by `getFunc()`.
+static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
+ StringRef name, TypeRange resultType,
+ ValueRange operands,
+ bool emitCInterface = false) {
+ auto fn = getFunc(op, name, resultType, operands, emitCInterface);
+ return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
+}
+
/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
SparseTensorEncodingAttr &enc, Value src,
if (AffineMap p = enc.getDimOrdering())
idx = p.getPermutedPosition(idx);
// Generate the call.
- Location loc = op->getLoc();
StringRef name = "sparseDimSize";
- SmallVector<Value, 2> params;
- params.push_back(src);
- params.push_back(constantIndex(rewriter, loc, idx));
+ SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
Type iTp = rewriter.getIndexType();
- auto fn = getFunc(op, name, iTp, params);
- return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
+ return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
}
/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
ArrayRef<Value> params) {
- Location loc = op->getLoc();
StringRef name = "newSparseTensor";
- Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
- auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
- auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
+ Type pTp = getOpaquePointerType(rewriter);
+ auto call = createFuncCall(rewriter, op, name, pTp, params,
+ /*emitCInterface=*/true);
return call.getResult(0);
}
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &sizes, Location loc,
Value src) {
- ShapedType stp = src.getType().cast<ShapedType>();
- for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+ unsigned rank = src.getType().cast<ShapedType>().getRank();
+ for (unsigned i = 0; i < rank; i++)
sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}
SmallVector<Value, 4> &sizes, Operation *op,
SparseTensorEncodingAttr &enc, ShapedType stp,
Value src) {
+ Location loc = op->getLoc();
auto shape = stp.getShape();
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
if (shape[i] == ShapedType::kDynamicSize)
sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
else
- sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
+ sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}
/// Generates an uninitialized temporary buffer of the given size and
}
params.push_back(genBuffer(rewriter, loc, rev));
// Secondary and primary types encoding.
- ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
+ Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
- params.push_back(
- constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
- // User action and pointer.
- Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
- if (!ptr)
- ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
+ params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
+ // User action.
params.push_back(constantAction(rewriter, loc, action));
+ // Payload pointer.
+ if (!ptr)
+ ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
params.push_back(ptr);
}
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
Type eltType, Value ptr, Value val, Value ind,
Value perm) {
- Location loc = op->getLoc();
StringRef name;
if (eltType.isF64())
name = "addEltF64";
name = "addEltI8";
else
llvm_unreachable("Unknown element type");
- SmallVector<Value, 8> params;
- params.push_back(ptr);
- params.push_back(val);
- params.push_back(ind);
- params.push_back(perm);
- Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
- auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
- rewriter.create<CallOp>(loc, pTp, fn, params);
+ SmallVector<Value, 4> params{ptr, val, ind, perm};
+ Type pTp = getOpaquePointerType(rewriter);
+ createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
}
/// Generates a call to `iter->getNext()`. If there is a next element,
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
Value iter, Value ind, Value elemPtr) {
- Location loc = op->getLoc();
Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
StringRef name;
if (elemTp.isF64())
name = "getNextI8";
else
llvm_unreachable("Unknown element type");
- SmallVector<Value, 3> params;
- params.push_back(iter);
- params.push_back(ind);
- params.push_back(elemPtr);
+ SmallVector<Value, 3> params{iter, ind, elemPtr};
Type i1 = rewriter.getI1Type();
- auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
- auto call = rewriter.create<CallOp>(loc, i1, fn, params);
+ auto call = createFuncCall(rewriter, op, name, i1, params,
+ /*emitCInterface=*/true);
return call.getResult(0);
}
}
Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
Value zero = constantZero(rewriter, loc, elemTp);
- rewriter.create<linalg::FillOp>(loc, zero, mem).result();
+ rewriter.create<linalg::FillOp>(loc, zero, mem);
return mem;
}
matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef name = "delSparseTensor";
- TypeRange none;
- auto fn = getFunc(op, name, none, adaptor.getOperands());
- rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
+ TypeRange noTp;
+ createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
rewriter.eraseOp(op);
return success();
}
name = "sparsePointers8";
else
return failure();
- auto fn = getFunc(op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
- rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+ replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
return success();
}
};
name = "sparseIndices8";
else
return failure();
- auto fn = getFunc(op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
- rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+ replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
return success();
}
};
name = "sparseValuesI8";
else
return failure();
- auto fn = getFunc(op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
- rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+ replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
return success();
}
};
// Finalize any pending insertions.
StringRef name = "endInsert";
TypeRange noTp;
- auto fn = getFunc(op, name, noTp, adaptor.getOperands());
- rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands());
+ createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
}
rewriter.replaceOp(op, adaptor.getOperands());
return success();
else
llvm_unreachable("Unknown element type");
TypeRange noTp;
- auto fn =
- getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
- rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
+ replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+ /*emitCInterface=*/true);
return success();
}
};