tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
```
+## Unary Operations
+
+### 'exp' operation
+
+Syntax:
+
+``` {.ebnf}
+operation ::= ssa-id `=` `exp` ssa-use `:` type
+```
+
+Examples:
+
+```mlir {.mlir}
+// Scalar natural exponential.
+%a = exp %b : f64
+
+// SIMD vector element-wise natural exponential.
+%f = exp %g : vector<4xf32>
+
+// Tensor element-wise natural exponential.
+%x = exp %y : tensor<4x?xf8>
+```
+
+The `exp` operation takes one operand and returns one result of the same type.
+This type may be a float scalar type, a vector whose element type is float, or a
+tensor of floats. It has no standard attributes.
+
## Arithmetic Operations
Basic arithmetic in MLIR is specified by standard operations described in this
section.
-TODO: "sub" etc. Let's not get excited about filling this out yet, we can define
-these on demand. We should be highly informed by and learn from the operations
-supported by HLO and LLVM.
-
### 'addi' operation
Syntax:
// SIMD vector element-wise addition, e.g. for Intel SSE.
%f = addi %g, %h : vector<4xi32>
-// Tensor element-wise addition, analogous to HLO's add operation.
+// Tensor element-wise addition.
%x = addi %y, %z : tensor<4x?xi8>
```
// SIMD vector addition, e.g. for Intel SSE.
%f = addf %g, %h : vector<4xf32>
-// Tensor addition, analogous to HLO's add operation.
+// Tensor addition.
%x = addf %y, %z : tensor<4x?xbf16>
```
// SIMD pointwise vector multiplication, e.g. for Intel SSE.
%f = mulf %g, %h : vector<4xf32>
-// Tensor pointwise multiplication, analogous to HLO's pointwise multiply operation.
+// Tensor pointwise multiplication.
%x = mulf %y, %z : tensor<4x?xbf16>
```
let hasFolder = 1;
}
+// Base class for unary ops. Requires single operand and result. Individual
+// classes will have `operand` accessor.
+class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
+ Op<Std_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
+ let results = (outs AnyType);
+ let printer = [{
+ return printStandardUnaryOp(this->getOperation(), p);
+ }];
+}
+
+class UnaryOpSameOperandAndResultType<string mnemonic, list<OpTrait> traits = []> :
+ UnaryOp<mnemonic, !listconcat(traits, [SameOperandsAndResultType])> {
+ let parser = [{
+ return impl::parseOneResultSameOperandTypeOp(parser, result);
+ }];
+}
+
+class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
+ UnaryOpSameOperandAndResultType<mnemonic, traits>,
+ Arguments<(ins FloatLike:$operand)>;
+
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
let hasFolder = 1;
}
+def ExpOp : FloatUnaryOp<"exp"> {
+ let summary = "base-e exponential of the specified value";
+}
+
def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
let summary = "element extract operation";
let description = [{
return res;
}
+template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
+ static_assert(
+ std::is_base_of<
+ typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
+ SourceOp>::value,
+ "wrong operand count");
+};
+
+template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
+ static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
+ "expected a single operand");
+};
+
+template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
+ OpCountValidator<SourceOp, OpCount>();
+}
+
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
-// Ops for binary ops with one result. This supports higher-dimensional vector
+// Ops for N-ary ops with one result. This supports higher-dimensional vector
// types.
-template <typename SourceOp, typename TargetOp>
-struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+template <typename SourceOp, typename TargetOp, unsigned OpCount>
+struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
- using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
+ using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
- static_assert(
- std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
- "expected binary op");
+ ValidateOpCount<SourceOp, OpCount>();
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
SourceOp>::value,
- "expected single result op");
+ "expected same operands and result type");
// Cannot convert ops if their operands are not of LLVM type.
for (Value *operand : operands) {
arraySizes.push_back(llvmTy.getArrayNumElements());
llvmTy = llvmTy.getArrayElementType();
}
- assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
+ assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type");
auto llvmVectorTy = llvmTy;
// Iteratively extract a position coordinates with basis `arraySize` from a
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
- Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmVectorTy, operands[0], position);
- Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
- loc, llvmVectorTy, operands[1], position);
+ SmallVector<Value *, OpCount> extractedOperands;
+ for (unsigned i = 0; i < OpCount; ++i) {
+ extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmVectorTy, operands[i], position));
+ }
Value *newVal = rewriter.create<TargetOp>(
- loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
- op->getAttrs());
+ loc, llvmVectorTy, extractedOperands, op->getAttrs());
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
newVal, position);
}
}
};
+template <typename SourceOp, typename TargetOp>
+using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
+template <typename SourceOp, typename TargetOp>
+using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
+
// Specific lowerings.
// FIXME: this should be tablegen'ed.
+struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::exp> {
+ using Super::Super;
+};
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
using Super::Super;
};
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// FIXME: this should be tablegen'ed
+ // clang-format off
patterns.insert<
- AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
- BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
- CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
- DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
- DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
- MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
- RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
- SelectOpLowering, SIToFPLowering, FPExtLowering, FPTruncLowering,
- SignExtendIOpLowering, SplatOpLowering, StoreOpLowering, SubFOpLowering,
- SubIOpLowering, TruncateIOpLowering, XOrOpLowering,
+ AddFOpLowering,
+ AddIOpLowering,
+ AllocOpLowering,
+ AndOpLowering,
+ BranchOpLowering,
+ CallIndirectOpLowering,
+ CallOpLowering,
+ CmpFOpLowering,
+ CmpIOpLowering,
+ CondBranchOpLowering,
+ ConstLLVMOpLowering,
+ DeallocOpLowering,
+ DimOpLowering,
+ DivFOpLowering,
+ DivISOpLowering,
+ DivIUOpLowering,
+ ExpOpLowering,
+ FPExtLowering,
+ FPTruncLowering,
+ FuncOpConversion,
+ IndexCastOpLowering,
+ LoadOpLowering,
+ MemRefCastOpLowering,
+ MulFOpLowering,
+ MulIOpLowering,
+ OrOpLowering,
+ RemFOpLowering,
+ RemISOpLowering,
+ RemIUOpLowering,
+ ReturnOpLowering,
+ SIToFPLowering,
+ SelectOpLowering,
+ SignExtendIOpLowering,
+ SplatOpLowering,
+ StoreOpLowering,
+ SubFOpLowering,
+ SubIOpLowering,
+ TruncateIOpLowering,
+ XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
+ // clang-format on
}
// Convert types using the stored LLVM IR module.
// StandardOpsDialect
//===----------------------------------------------------------------------===//
+/// A custom unary operation printer that omits the "std." prefix from the
+/// operation names.
+static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
+ assert(op->getNumOperands() == 1 && "unary op should have one operand");
+ assert(op->getNumResults() == 1 && "unary op should have one result");
+
+ const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
+ << *op->getOperand(0);
+ p.printOptionalAttrDict(op->getAttrs());
+ p << " : " << op->getOperand(0)->getType();
+}
+
/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
return;
}
- p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
+ const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
<< *op->getOperand(0) << ", " << *op->getOperand(1);
p.printOptionalAttrDict(op->getAttrs());
/// A custom cast operation printer that omits the "std." prefix from the
/// operation names.
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
- p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
+ const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
<< *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
<< op->getResult(0)->getType();
}