typedef void *mlir_func_t;
/// Opaque C type for mlir::Attribute.
typedef const void *mlir_attr_t;
-/// Opaque C type for mlir::edsc::MLIREmiter.
-typedef void *edsc_mlir_emitter_t;
-/// Opaque C type for mlir::edsc::Expr.
-typedef void *edsc_expr_t;
-/// Opaque C type for mlir::edsc::MaxExpr.
-typedef void *edsc_max_expr_t;
-/// Opaque C type for mlir::edsc::MinExpr.
-typedef void *edsc_min_expr_t;
-/// Opaque C type for mlir::edsc::Stmt.
-typedef void *edsc_stmt_t;
-/// Opaque C type for mlir::edsc::Block.
-typedef void *edsc_block_t;
/// Simple C lists for non-owning mlir Opaque C types.
/// Recommended usage is construction from the `data()` and `size()` of a scoped
} mlir_type_list_t;
typedef struct {
- edsc_expr_t *exprs;
- uint64_t n;
-} edsc_expr_list_t;
-
-typedef struct {
- edsc_stmt_t *stmts;
- uint64_t n;
-} edsc_stmt_list_t;
-
-typedef struct {
- edsc_expr_t base;
- edsc_expr_list_t indices;
-} edsc_indexed_t;
-
-typedef struct {
- edsc_indexed_t *list;
- uint64_t n;
-} edsc_indexed_list_t;
-
-typedef struct {
- edsc_block_t *list;
- uint64_t n;
-} edsc_block_list_t;
-
-typedef struct {
const char *name;
mlir_attr_t value;
} mlir_named_attr_t;
/// Returns the arity of `function`.
unsigned getFunctionArity(mlir_func_t function);
-/// Returns a new opaque mlir::edsc::Expr that is bound into `emitter` with a
-/// constant of the specified type.
-edsc_expr_t bindConstantBF16(edsc_mlir_emitter_t emitter, double value);
-edsc_expr_t bindConstantF16(edsc_mlir_emitter_t emitter, float value);
-edsc_expr_t bindConstantF32(edsc_mlir_emitter_t emitter, float value);
-edsc_expr_t bindConstantF64(edsc_mlir_emitter_t emitter, double value);
-edsc_expr_t bindConstantInt(edsc_mlir_emitter_t emitter, int64_t value,
- unsigned bitwidth);
-edsc_expr_t bindConstantIndex(edsc_mlir_emitter_t emitter, int64_t value);
-edsc_expr_t bindConstantFunction(edsc_mlir_emitter_t emitter,
- mlir_func_t function);
-
/// Returns the rank of the `function` argument at position `pos`.
/// If the argument is of MemRefType, this returns the rank of the MemRef.
/// Otherwise returns `0`.
/// Returns an opaque mlir::Type of the `function` argument at position `pos`.
mlir_type_t getTypeOfFunctionArgument(mlir_func_t function, unsigned pos);
-/// Returns an opaque mlir::edsc::Expr that has been bound to the `pos` argument
-/// of `function`.
-edsc_expr_t bindFunctionArgument(edsc_mlir_emitter_t emitter,
- mlir_func_t function, unsigned pos);
-
-/// Fills the preallocated list `result` with opaque mlir::edsc::Expr that have
-/// been bound to each argument of `function`.
-///
-/// Prerequisites:
-/// - `result` must have been preallocated with space for exactly the number
-/// of arguments of `function`.
-void bindFunctionArguments(edsc_mlir_emitter_t emitter, mlir_func_t function,
- edsc_expr_list_t *result);
-
-/// Returns the rank of `boundMemRef`. This API function is provided to more
-/// easily compose with `bindFunctionArgument`. A similar function could be
-/// provided for an mlir_type_t of type MemRefType but it is expected that users
-/// of this API either:
-/// 1. construct the MemRefType explicitly, in which case they already have
-/// access to the rank and shape of the MemRefType;
-/// 2. access MemRefs via mlir_function_t *values* in which case they would
-/// pass edsc_expr_t bound to an edsc_emitter_t.
-///
-/// Prerequisites:
-/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
-/// in `emitter`.
-unsigned getBoundMemRefRank(edsc_mlir_emitter_t emitter,
- edsc_expr_t boundMemRef);
-
-/// Fills the preallocated list `result` with opaque mlir::edsc::Expr that have
-/// been bound to each dimension of `boundMemRef`.
-///
-/// Prerequisites:
-/// - `result` must have been preallocated with space for exactly the rank of
-/// `boundMemRef`;
-/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
-/// in `emitter`. This is because symbolic MemRef shapes require an SSAValue
-/// that can only be recovered from `emitter`.
-void bindMemRefShape(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
- edsc_expr_list_t *result);
-
-/// Fills the preallocated lists `resultLbs`, `resultUbs` and `resultSteps` with
-/// opaque mlir::edsc::Expr that have been bound to proper values to traverse
-/// each dimension of `memRefType`.
-/// At the moment:
-/// - `resultsLbs` are always bound to the constant index `0`;
-/// - `resultsUbs` are always bound to the shape of `memRefType`;
-/// - `resultsSteps` are always bound to the constant index `1`.
-/// In the future, this will allow non-contiguous MemRef views.
-///
-/// Prerequisites:
-/// - `resultLbs`, `resultUbs` and `resultSteps` must have each been
-/// preallocated with space for exactly the rank of `boundMemRef`;
-/// - `boundMemRef` must be an opaque edsc_expr_t that has alreay been bound
-/// in `emitter`. This is because symbolic MemRef shapes require an SSAValue
-/// that can only be recovered from `emitter`.
-void bindMemRefView(edsc_mlir_emitter_t emitter, edsc_expr_t boundMemRef,
- edsc_expr_list_t *resultLbs, edsc_expr_list_t *resultUbs,
- edsc_expr_list_t *resultSteps);
-
-/// Returns an opaque expression for an mlir::edsc::Expr.
-edsc_expr_t makeBindable(mlir_type_t type);
-
-/// Returns an opaque expression for an mlir::edsc::Stmt.
-edsc_stmt_t makeStmt(edsc_expr_t e);
-
-/// Returns an opaque expression for an mlir::edsc::Indexed.
-edsc_indexed_t makeIndexed(edsc_expr_t expr);
-
-/// Returns an opaque expression that will emit an abstract operation identified
-/// by its name.
-edsc_expr_t Op(mlir_context_t context, const char *name, mlir_type_t resultType,
- edsc_expr_list_t arguments, edsc_block_list_t successors,
- mlir_named_attr_list_t attrs);
-
-/// Returns an opaque expression that will emit an mlir::LoadOp.
-edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices);
-
-/// Returns an opaque statement for an mlir::StoreOp.
-edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
- edsc_expr_list_t indices);
-
-/// Returns an opaque statement for an mlir::SelectOp.
-edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs);
-
-/// Returns an opaque constant integer expression of the specified type. The
-/// type may be i* or index.
-edsc_expr_t ConstantInteger(mlir_type_t type, int64_t value);
-
-/// Returns an opaque statement for an mlir::ReturnOp.
-edsc_stmt_t Return(edsc_expr_list_t values);
-
-/// Returns an opaque expression for an mlir::edsc::StmtBlock containing the
-/// given list of statements.
-edsc_block_t Block(edsc_expr_list_t arguments, edsc_stmt_list_t enclosedStmts);
-
-/// Set the body of the block to the given statements and return the block.
-edsc_block_t BlockSetBody(edsc_block_t, edsc_stmt_list_t stmts);
-
-/// Returns an opaque statement branching to `destination` and passing
-/// `arguments` as block arguments.
-edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments);
-
-/// Returns an opaque statement that redirects the control flow to either
-/// `trueDestination` or `falseDestination` depending on whether the
-/// `condition` expression is true or false. The caller may pass expressions
-/// as arguments to the destination blocks using `trueArguments` and
-/// `falseArguments`, respectively.
-edsc_stmt_t CondBranch(edsc_expr_t condition, edsc_block_t trueDestination,
- edsc_expr_list_t trueArguments,
- edsc_block_t falseDestination,
- edsc_expr_list_t falseArguments);
-
-/// Returns an opaque statement for an mlir::AffineForOp with `enclosedStmts`
-/// nested below it.
-edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
- edsc_expr_t step, edsc_stmt_list_t enclosedStmts);
-
-/// Returns an opaque statement for a perfectly nested set of mlir::AffineForOp
-/// with `enclosedStmts` nested below it.
-edsc_stmt_t ForNest(edsc_expr_list_t iv, edsc_expr_list_t lb,
- edsc_expr_list_t ub, edsc_expr_list_t step,
- edsc_stmt_list_t enclosedStmts);
-
-/// Returns an opaque 'max' expression that can be used only inside a for loop.
-edsc_max_expr_t Max(edsc_expr_list_t args);
-
-/// Returns an opaque 'min' expression that can be used only inside a for loop.
-edsc_min_expr_t Min(edsc_expr_list_t args);
-
-/// Returns an opaque statement for an mlir::AffineForOp with the lower bound
-/// `max(lbs)` and the upper bound `min(ubs)`, and with `enclosedStmts` nested
-/// below it.
-edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
- edsc_expr_t step, edsc_stmt_list_t enclosedStmts);
-
-/// Returns an opaque expression for the corresponding Binary operation.
-edsc_expr_t Add(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t Sub(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t Mul(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t Div(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t Rem(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t LT(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t LE(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t GT(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t GE(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t EQ(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t NE(edsc_expr_t e1, edsc_expr_t e2);
-
-edsc_expr_t FloorDiv(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t CeilDiv(edsc_expr_t e1, edsc_expr_t e2);
-
-edsc_expr_t And(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t Or(edsc_expr_t e1, edsc_expr_t e2);
-edsc_expr_t Negate(edsc_expr_t e);
-
-edsc_expr_t Call0(edsc_expr_t callee, edsc_expr_list_t args);
-edsc_expr_t Call1(edsc_expr_t callee, mlir_type_t result,
- edsc_expr_list_t args);
-
#ifdef __cplusplus
} // end extern "C"
#endif
+++ /dev/null
-//===- Types.h - MLIR EDSC Type System --------------------------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// Provides a simple value-based type system to implement an EDSC that
-// simplifies emitting MLIR and future MLIR dialects. Most of this should be
-// auto-generated in the future.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_EDSC_TYPES_H_
-#define MLIR_EDSC_TYPES_H_
-
-#include "mlir-c/Core.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Support/LLVM.h"
-
-#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/Casting.h"
-
-namespace mlir {
-
-class MLIRContext;
-class FuncBuilder;
-
-namespace edsc {
-namespace detail {
-
-struct ExprStorage;
-struct StmtStorage;
-struct StmtBlockStorage;
-
-} // namespace detail
-
-class StmtBlock;
-
-/// EDSC Types closely mirror the core MLIR and uses an abstraction similar to
-/// AffineExpr:
-/// 1. a set of composable structs;
-/// 2. with by-value semantics everywhere and operator overloading
-/// 3. with an underlying pointer to impl as payload.
-/// The vast majority of this code should be TableGen'd in the future which
-/// would allow us to automatically emit an EDSC for any IR dialect we are
-/// interested in. In turn this makes any IR dialect fully programmable in a
-/// declarative fashion.
-///
-/// The main differences with the AffineExpr design are as follows:
-/// 1. this type-system is an empty shell to which we can lazily bind Value*
-/// at the moment of emitting MLIR;
-/// 2. the data structures are BumpPointer allocated in a global
-/// `ScopedEDSCContext` with scoped lifetime. This allows avoiding to
-/// pass and store an extra Context pointer around and keeps users honest:
-/// *this is absolutely not meant to escape a local scope*.
-///
-/// The decision of slicing the underlying IR types into Bindable and
-/// NonBindable types is flexible and influences programmability.
-enum class ExprKind {
- FIRST_BINDABLE_EXPR = 100,
- Unbound = FIRST_BINDABLE_EXPR,
- LAST_BINDABLE_EXPR = Unbound,
- FIRST_NON_BINDABLE_EXPR = 200,
- Unary = FIRST_NON_BINDABLE_EXPR,
- Binary,
- Ternary,
- Variadic,
- FIRST_STMT_BLOCK_LIKE_EXPR = 600,
- For = FIRST_STMT_BLOCK_LIKE_EXPR,
- LAST_STMT_BLOCK_LIKE_EXPR = For,
- LAST_NON_BINDABLE_EXPR = LAST_STMT_BLOCK_LIKE_EXPR,
-};
-
-/// Scoped context holding a BumpPtrAllocator.
-/// Creating such an object injects a new allocator in Expr::globalAllocator.
-/// At the moment we can have only have one such context.
-///
-/// Usage:
-///
-/// ```c++
-/// MLFunctionBuilder *b = ...;
-/// Location someLocation = ...;
-/// Value *zeroValue = ...;
-/// Value *oneValue = ...;
-///
-/// ScopedEDSCContext raiiContext;
-/// Constant zero, one;
-/// Value *val = MLIREmitter(b)
-/// .bind(zero, zeroValue)
-/// .bind(one, oneValue)
-/// .emit(someLocation, zero + one);
-/// ```
-///
-/// will emit MLIR resembling:
-///
-/// ```mlir
-/// %2 = add(%c0, %c1) : index
-/// ```
-///
-/// The point of the EDSC is to synthesize arbitrarily more complex patterns in
-/// a declarative fashion. For example, clipping for guaranteed in-bounds access
-/// can be written:
-///
-/// ```c++
-/// auto expr = select(expr < 0, 0, select(expr < size, expr, size - 1));
-/// Value *val = MLIREmitter(b).bind(...).emit(loc, expr);
-/// ```
-struct ScopedEDSCContext {
- ScopedEDSCContext();
- ~ScopedEDSCContext();
- llvm::BumpPtrAllocator allocator;
-};
-
-struct Expr {
-public:
- using ImplType = detail::ExprStorage;
-
- /// Returns the scoped BumpPtrAllocator. This must be done in the context of a
- /// unique `ScopedEDSCContext` declared in an RAII fashion in some enclosing
- /// scope.
- static llvm::BumpPtrAllocator *&globalAllocator() {
- static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
- return allocator;
- }
-
- explicit Expr(Type type);
- /* implicit */ Expr(ImplType *storage) : storage(storage) {}
- explicit Expr(edsc_expr_t expr)
- : storage(reinterpret_cast<ImplType *>(expr)) {}
- operator edsc_expr_t() { return edsc_expr_t{storage}; }
-
- Expr(const Expr &other) = default;
- Expr &operator=(const Expr &other) = default;
- Expr(StringRef name, Type resultType, ArrayRef<Expr> operands,
- ArrayRef<NamedAttribute> atts = {});
-
- template <typename U> bool isa() const;
- template <typename U> U dyn_cast() const;
- template <typename U> U cast() const;
-
- /// Returns `true` if this expression builds the MLIR operation specified as
- /// the template argument. Unlike `isa`, this does not imply we can cast
- /// this Expr to the given type.
- template <typename U> bool is_op() const;
-
- /// Returns the classification for this type.
- ExprKind getKind() const;
- unsigned getId() const;
- StringRef getName() const;
-
- /// Returns the types of the values this expression produces.
- ArrayRef<Type> getResultTypes() const;
-
- /// Returns the list of expressions used as arguments of this expression.
- ArrayRef<Expr> getProperArguments() const;
-
- /// Returns the list of lists of expressions used as arguments of successors
- /// of this expression (i.e., arguments passed to destination basic blocks in
- /// terminator statements).
- SmallVector<ArrayRef<Expr>, 4> getSuccessorArguments() const;
-
- /// Returns the list of expressions used as arguments of the `index`-th
- /// successor of this expression.
- ArrayRef<Expr> getSuccessorArguments(int index) const;
-
- /// Returns the list of argument groups (includes the proper argument group,
- /// followed by successor/block argument groups).
- SmallVector<ArrayRef<Expr>, 4> getAllArgumentGroups() const;
-
- /// Returns the list of attributes of this expression.
- ArrayRef<NamedAttribute> getAttributes() const;
-
- /// Returns the attribute with the given name, if any.
- Attribute getAttribute(StringRef name) const;
-
- /// Returns the list of successors (StmtBlocks) of this expression.
- ArrayRef<StmtBlock> getSuccessors() const;
-
- /// Build the IR corresponding to this expression.
- SmallVector<Value *, 4>
- build(FuncBuilder &b, const llvm::DenseMap<Expr, Value *> &ssaBindings,
- const llvm::DenseMap<StmtBlock, Block *> &blockBindings) const;
-
- void print(raw_ostream &os) const;
- void dump() const;
- std::string str() const;
-
- /// For debugging purposes.
- const void *getStoragePtr() const { return storage; }
-
- /// Explicit conversion to bool. Useful in conjunction with dyn_cast.
- explicit operator bool() const { return storage != nullptr; }
-
- friend ::llvm::hash_code hash_value(Expr arg);
-
-protected:
- friend struct detail::ExprStorage;
- ImplType *storage;
-
- static void resetIds() { newId() = 0; }
- static unsigned &newId();
-};
-
-struct Bindable : public Expr {
- Bindable() = delete;
- Bindable(Expr expr) : Expr(expr) {
- assert(expr.isa<Bindable>() && "expected Bindable");
- }
- Bindable(const Bindable &) = default;
- Bindable &operator=(const Bindable &) = default;
- explicit Bindable(const edsc_expr_t &expr) : Expr(expr) {}
- operator edsc_expr_t() { return edsc_expr_t{storage}; }
-
-private:
- friend class Expr;
- friend struct ScopedEDSCContext;
-};
-
-struct UnaryExpr : public Expr {
- friend class Expr;
-
- UnaryExpr(StringRef name, Expr expr);
- Expr getExpr() const;
-
- template <typename T> static UnaryExpr make(Expr expr) {
- return UnaryExpr(T::getOperationName(), expr);
- }
-
-protected:
- UnaryExpr(Expr::ImplType *ptr) : Expr(ptr) {
- assert(!ptr || isa<UnaryExpr>() && "expected UnaryExpr");
- }
-};
-
-struct BinaryExpr : public Expr {
- friend class Expr;
- BinaryExpr(StringRef name, Type result, Expr lhs, Expr rhs,
- ArrayRef<NamedAttribute> attrs = {});
- Expr getLHS() const;
- Expr getRHS() const;
-
- template <typename T>
- static BinaryExpr make(Type result, Expr lhs, Expr rhs,
- ArrayRef<NamedAttribute> attrs = {}) {
- return BinaryExpr(T::getOperationName(), result, lhs, rhs, attrs);
- }
-
-protected:
- BinaryExpr(Expr::ImplType *ptr) : Expr(ptr) {
- assert(!ptr || isa<BinaryExpr>() && "expected BinaryExpr");
- }
-};
-
-struct TernaryExpr : public Expr {
- friend class Expr;
- TernaryExpr(StringRef name, Expr cond, Expr lhs, Expr rhs);
- Expr getCond() const;
- Expr getLHS() const;
- Expr getRHS() const;
-
- template <typename T> static TernaryExpr make(Expr cond, Expr lhs, Expr rhs) {
- return TernaryExpr(T::getOperationName(), cond, lhs, rhs);
- }
-
-protected:
- TernaryExpr(Expr::ImplType *ptr) : Expr(ptr) {
- assert(!ptr || isa<TernaryExpr>() && "expected TernaryExpr");
- }
-};
-
-struct VariadicExpr : public Expr {
- friend class Expr;
- VariadicExpr(StringRef name, llvm::ArrayRef<Expr> exprs,
- llvm::ArrayRef<Type> types = {},
- llvm::ArrayRef<NamedAttribute> attrs = {},
- llvm::ArrayRef<StmtBlock> succ = {});
- llvm::ArrayRef<Expr> getExprs() const;
- llvm::ArrayRef<Type> getTypes() const;
- llvm::ArrayRef<StmtBlock> getSuccessors() const;
-
- template <typename T>
- static VariadicExpr make(llvm::ArrayRef<Expr> exprs,
- llvm::ArrayRef<Type> types = {},
- llvm::ArrayRef<NamedAttribute> attrs = {},
- llvm::ArrayRef<StmtBlock> succ = {}) {
- return VariadicExpr(T::getOperationName(), exprs, types, attrs, succ);
- }
-
-protected:
- VariadicExpr(Expr::ImplType *ptr) : Expr(ptr) {
- assert(!ptr || isa<VariadicExpr>() && "expected VariadicExpr");
- }
-};
-
-struct StmtBlockLikeExpr : public Expr {
- friend class Expr;
- StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
- llvm::ArrayRef<Type> types = {});
-
-protected:
- StmtBlockLikeExpr(Expr::ImplType *ptr) : Expr(ptr) {
- assert(!ptr || isa<StmtBlockLikeExpr>() && "expected StmtBlockLikeExpr");
- }
-};
-
-/// A Stmt represent a unit of liaison betweeb a Bindable `lhs`, an Expr `rhs`
-/// and a list of `enclosingStmts`. This essentially allows giving a name and a
-/// scoping to objects of type `Expr` so they can be reused once bound to an
-/// Value*. This enables writing generators such as:
-///
-/// ```mlir
-/// Stmt scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
-/// tmpAlloc = alloc(tmpMemRefType);
-/// vectorView = vector.type_cast(tmpAlloc, vectorMemRefType),
-/// vectorValue = load(vectorView, zero),
-/// tmpDealloc = dealloc(tmpAlloc)});
-/// emitter.emitStmts({tmpAlloc, vectorView, vectorValue, tmpDealloc});
-/// ```
-///
-/// A Stmt can be declared with either:
-/// 1. default initialization (e.g. `Stmt foo;`) in which case all of its `lhs`,
-/// `rhs` and `enclosingStmts` are unbound;
-/// 2. initialization from an Expr without a Bindable `lhs`
-/// (e.g. store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs)), in which
-/// case the `lhs` is unbound;
-/// 3. an assignment operator to a `lhs` Stmt that is bound implicitly:
-/// (e.g. vectorValue = load(vectorView, zero)).
-///
-/// Only ExprKind::StmtBlockLikeExpr have `enclosedStmts`, these comprise:
-/// 1. `affine.for`-loops for which the `lhs` binds to the induction variable,
-/// `rhs`
-/// binds to an Expr of kind `ExprKind::For` with lower-bound, upper-bound and
-/// step respectively.
-// TODO(zinenko): this StmtBlockLikeExpr should be retired in favor of Expr
-// that can have a list of Blocks they contain, similarly to the core MLIR
-struct Stmt {
- using ImplType = detail::StmtStorage;
- friend class Expr;
- Stmt() : storage(nullptr) {}
- explicit Stmt(ImplType *storage) : storage(storage) {}
- Stmt(const Stmt &other) = default;
- Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
- Stmt(const Bindable &lhs, const Expr &rhs,
- llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
-
- explicit operator Expr() const { return getLHS(); }
- Stmt &operator=(const Expr &expr);
- Stmt &set(const Stmt &other) {
- this->storage = other.storage;
- return *this;
- }
- Stmt &operator=(const Stmt &other) = delete;
- explicit Stmt(edsc_stmt_t stmt)
- : storage(reinterpret_cast<ImplType *>(stmt)) {}
- operator edsc_stmt_t() { return edsc_stmt_t{storage}; }
-
- /// For debugging purposes.
- const ImplType *getStoragePtr() const { return storage; }
-
- void print(raw_ostream &os, llvm::Twine indent = "") const;
- void dump() const;
- std::string str() const;
-
- Expr getLHS() const;
- Expr getRHS() const;
- llvm::ArrayRef<Stmt> getEnclosedStmts() const;
-
-protected:
- ImplType *storage;
-};
-
-/// StmtBlock is a an addressable list of statements.
-///
-/// This enables writing complex generators such as:
-///
-/// ```mlir
-/// Stmt scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
-/// Stmt block = Block({
-/// tmpAlloc = alloc(tmpMemRefType),
-/// vectorView = vector.type_cast(tmpAlloc, vectorMemRefType),
-/// For(ivs, lbs, ubs, steps, {
-/// scalarValue = load(scalarMemRef,
-/// accessInfo.clippedScalarAccessExprs), store(scalarValue, tmpAlloc,
-/// accessInfo.tmpAccessExprs),
-/// }),
-/// vectorValue = load(vectorView, zero),
-/// tmpDealloc = dealloc(tmpAlloc.getLHS())});
-/// emitter.emitBlock(block);
-/// ```
-struct StmtBlock {
-public:
- using ImplType = detail::StmtBlockStorage;
-
- StmtBlock() : storage(nullptr) {}
- explicit StmtBlock(ImplType *st) : storage(st) {}
- explicit StmtBlock(edsc_block_t st)
- : storage(reinterpret_cast<ImplType *>(st)) {}
- StmtBlock(const StmtBlock &other) = default;
- StmtBlock(llvm::ArrayRef<Stmt> stmts);
- StmtBlock(llvm::ArrayRef<Bindable> args, llvm::ArrayRef<Stmt> stmts = {});
-
- llvm::ArrayRef<Bindable> getArguments() const;
- llvm::ArrayRef<Type> getArgumentTypes() const;
- llvm::ArrayRef<Stmt> getBody() const;
- uint64_t getId() const;
-
- void print(llvm::raw_ostream &os, Twine indent) const;
- std::string str() const;
-
- operator edsc_block_t() { return edsc_block_t{storage}; }
-
- /// Reset the body of this block with the given list of statements.
- StmtBlock &operator=(llvm::ArrayRef<Stmt> stmts);
- void set(llvm::ArrayRef<Stmt> stmts) { *this = stmts; }
-
- ImplType *getStoragePtr() const { return storage; }
-
-private:
- ImplType *storage;
-};
-
-/// These operator build new expressions from the given expressions. Some of
-/// them are unconventional, which mandated extracting them to a separate
-/// namespace. The indended use is as follows.
-///
-/// using namespace edsc;
-/// Expr e1, e2, condition
-/// {
-/// using namespace edsc::op;
-/// condition = !(e1 && e2); // this is a negation expression
-/// }
-/// if (!condition) // this is a nullity check
-/// reportError();
-///
-namespace op {
-/// Creates the BinaryExpr corresponding to the operator.
-Expr operator+(Expr lhs, Expr rhs);
-Expr operator-(Expr lhs, Expr rhs);
-Expr operator*(Expr lhs, Expr rhs);
-Expr operator/(Expr lhs, Expr rhs);
-Expr operator%(Expr lhs, Expr rhs);
-/// In particular operator==, operator!= return a new Expr and *not* a bool.
-Expr operator==(Expr lhs, Expr rhs);
-Expr operator!=(Expr lhs, Expr rhs);
-Expr operator<(Expr lhs, Expr rhs);
-Expr operator<=(Expr lhs, Expr rhs);
-Expr operator>(Expr lhs, Expr rhs);
-Expr operator>=(Expr lhs, Expr rhs);
-/// NB: Unlike boolean && and || these do not short-circuit.
-Expr operator&&(Expr lhs, Expr rhs);
-Expr operator||(Expr lhs, Expr rhs);
-Expr operator!(Expr expr);
-
-inline Expr operator+(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() + rhs.getLHS();
-}
-inline Expr operator-(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() - rhs.getLHS();
-}
-inline Expr operator*(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() * rhs.getLHS();
-}
-
-inline Expr operator<(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() < rhs.getLHS();
-}
-inline Expr operator<=(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() <= rhs.getLHS();
-}
-inline Expr operator>(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() > rhs.getLHS();
-}
-inline Expr operator>=(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() >= rhs.getLHS();
-}
-inline Expr operator&&(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() && rhs.getLHS();
-}
-inline Expr operator||(Stmt lhs, Stmt rhs) {
- return lhs.getLHS() || rhs.getLHS();
-}
-inline Expr operator!(Stmt stmt) { return !stmt.getLHS(); }
-} // end namespace op
-
-Expr floorDiv(Expr lhs, Expr rhs);
-Expr ceilDiv(Expr lhs, Expr rhs);
-
-template <typename U> bool Expr::isa() const {
- auto kind = getKind();
- if (std::is_same<U, Bindable>::value) {
- return kind >= ExprKind::FIRST_BINDABLE_EXPR &&
- kind <= ExprKind::LAST_BINDABLE_EXPR;
- }
- if (std::is_same<U, UnaryExpr>::value) {
- return kind == ExprKind::Unary;
- }
- if (std::is_same<U, BinaryExpr>::value) {
- return kind == ExprKind::Binary;
- }
- if (std::is_same<U, TernaryExpr>::value) {
- return kind == ExprKind::Ternary;
- }
- if (std::is_same<U, VariadicExpr>::value) {
- return kind == ExprKind::Variadic;
- }
- if (std::is_same<U, StmtBlockLikeExpr>::value) {
- return kind >= ExprKind::FIRST_STMT_BLOCK_LIKE_EXPR &&
- kind <= ExprKind::LAST_STMT_BLOCK_LIKE_EXPR;
- }
- return false;
-}
-
-template <typename U> U Expr::dyn_cast() const {
- if (isa<U>()) {
- return U(storage);
- }
- return U((Expr::ImplType *)(nullptr));
-}
-template <typename U> U Expr::cast() const {
- assert(isa<U>());
- return U(storage);
-}
-
-template <typename U> bool Expr::is_op() const {
- return U::getOperationName() == getName();
-}
-
-/// Make Expr hashable.
-inline ::llvm::hash_code hash_value(Expr arg) {
- return ::llvm::hash_value(arg.storage);
-}
-
-inline ::llvm::hash_code hash_value(StmtBlock arg) {
- return ::llvm::hash_value(arg.getStoragePtr());
-}
-
-raw_ostream &operator<<(raw_ostream &os, const Expr &expr);
-raw_ostream &operator<<(raw_ostream &os, const Stmt &stmt);
-
-} // namespace edsc
-} // namespace mlir
-
-namespace llvm {
-
-// Expr hash just like pointers
-template <> struct DenseMapInfo<mlir::edsc::Expr> {
- static mlir::edsc::Expr getEmptyKey() {
- auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
- return mlir::edsc::Expr(static_cast<mlir::edsc::Expr::ImplType *>(pointer));
- }
- static mlir::edsc::Expr getTombstoneKey() {
- auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
- return mlir::edsc::Expr(static_cast<mlir::edsc::Expr::ImplType *>(pointer));
- }
- static unsigned getHashValue(mlir::edsc::Expr val) {
- return mlir::edsc::hash_value(val);
- }
- static bool isEqual(mlir::edsc::Expr LHS, mlir::edsc::Expr RHS) {
- return LHS.getStoragePtr() == RHS.getStoragePtr();
- }
-};
-
-// StmtBlock hash just like pointers
-template <> struct DenseMapInfo<mlir::edsc::StmtBlock> {
- static mlir::edsc::StmtBlock getEmptyKey() {
- auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
- return mlir::edsc::StmtBlock(
- static_cast<mlir::edsc::StmtBlock::ImplType *>(pointer));
- }
- static mlir::edsc::StmtBlock getTombstoneKey() {
- auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
- return mlir::edsc::StmtBlock(
- static_cast<mlir::edsc::StmtBlock::ImplType *>(pointer));
- }
- static unsigned getHashValue(mlir::edsc::StmtBlock val) {
- return mlir::edsc::hash_value(val);
- }
- static bool isEqual(mlir::edsc::StmtBlock LHS, mlir::edsc::StmtBlock RHS) {
- return LHS.getStoragePtr() == RHS.getStoragePtr();
- }
-};
-
-} // namespace llvm
-
-namespace mlir {
-namespace edsc {
-
-/// Free function sugar.
-///
-/// Since bindings are hashed by the underlying pointer address, we need to be
-/// sure to construct new elements in a vector. We cannot just use
-/// `llvm::SmallVector<Expr, 8> dims(n);` directly because a single
-/// `Expr` will be default constructed and copied everywhere in the vector.
-/// Hilarity ensues when trying to bind `Expr` multiple times.
-llvm::SmallVector<Expr, 8> makeNewExprs(unsigned n, Type type);
-template <typename IterTy>
-llvm::SmallVector<Expr, 8> copyExprs(IterTy begin, IterTy end) {
- return llvm::SmallVector<Expr, 8>(begin, end);
-}
-inline llvm::SmallVector<Expr, 8> copyExprs(llvm::ArrayRef<Expr> exprs) {
- return llvm::SmallVector<Expr, 8>(exprs.begin(), exprs.end());
-}
-
-Expr alloc(llvm::ArrayRef<Expr> sizes, Type memrefType);
-inline Expr alloc(Type memrefType) { return alloc({}, memrefType); }
-Expr dealloc(Expr memref);
-
-Expr load(Expr m, llvm::ArrayRef<Expr> indices = {});
-inline Expr load(Stmt m, llvm::ArrayRef<Expr> indices = {}) {
- return load(m.getLHS(), indices);
-}
-Expr store(Expr val, Expr m, llvm::ArrayRef<Expr> indices = {});
-inline Expr store(Stmt val, Expr m, llvm::ArrayRef<Expr> indices = {}) {
- return store(val.getLHS(), m, indices);
-}
-Expr select(Expr cond, Expr lhs, Expr rhs);
-Expr vector_type_cast(Expr memrefExpr, Type memrefType);
-Expr constantInteger(Type t, int64_t value);
-Expr call(Expr func, Type result, llvm::ArrayRef<Expr> args);
-Expr call(Expr func, llvm::ArrayRef<Expr> args);
-
-Stmt Return(ArrayRef<Expr> values = {});
-Stmt Branch(StmtBlock destination, ArrayRef<Expr> args = {});
-Stmt CondBranch(Expr condition, StmtBlock trueDestination,
- ArrayRef<Expr> trueArgs, StmtBlock falseDestination,
- ArrayRef<Expr> falseArgs);
-Stmt CondBranch(Expr condition, StmtBlock trueDestination,
- StmtBlock falseDestination);
-
-Stmt For(Expr lb, Expr ub, Expr step, llvm::ArrayRef<Stmt> enclosedStmts);
-Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step,
- llvm::ArrayRef<Stmt> enclosedStmts);
-Stmt For(llvm::ArrayRef<Expr> indices, llvm::ArrayRef<Expr> lbs,
- llvm::ArrayRef<Expr> ubs, llvm::ArrayRef<Expr> steps,
- llvm::ArrayRef<Stmt> enclosedStmts);
-
-/// Define a 'affine.for' loop from with multi-valued bounds.
-///
-/// for max(lbs...) to min(ubs...) {}
-///
-Stmt MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs, ArrayRef<Expr> ubs,
- Expr step, ArrayRef<Stmt> enclosedStmts);
-
-/// Define an MLIR Block and bind its arguments to `args`. The types of block
-/// arguments are those of `args`, each of which must have exactly one result
-/// type. The body of the block may be empty and can be reset later.
-StmtBlock block(llvm::ArrayRef<Bindable> args, llvm::ArrayRef<Stmt> stmts);
-/// Define an MLIR Block without arguments. The body of the block can be empty
-/// and can be reset later.
-inline StmtBlock block(llvm::ArrayRef<Stmt> stmts) { return block({}, stmts); }
-
-/// This helper class exists purely for sugaring purposes and allows writing
-/// expressions such as:
-///
-/// ```mlir
-/// Indexed A(...), B(...), C(...);
-/// For(ivs, zeros, shapeA, ones, {
-/// C[ivs] = A[ivs] + B[ivs]
-/// });
-/// ```
-struct Indexed {
- Indexed(Expr e) : base(e), indices() {}
-
- /// Returns a new `Indexed`. As a consequence, an Indexed with attached
- /// indices can never be reused unless it is captured (e.g. via a Stmt).
- /// This is consistent with SSA behavior in MLIR but also allows for some
- /// minimal state and sugaring.
- Indexed operator()(llvm::ArrayRef<Expr> indices = {});
-
- /// Returns a new `Stmt`.
- /// Emits a `store` and clears the attached indices.
- Stmt operator=(Expr expr); // NOLINT: unconventional-assing-operator
-
- /// Implicit conversion.
- /// Emits a `load`.
- operator Expr() { return load(base, indices); }
-
- /// Operator overloadings.
- Expr operator+(Expr e) {
- using op::operator+;
- return load(base, indices) + e;
- }
- Expr operator-(Expr e) {
- using op::operator-;
- return load(base, indices) - e;
- }
- Expr operator*(Expr e) {
- using op::operator*;
- return load(base, indices) * e;
- }
-
-private:
- Expr base;
- llvm::SmallVector<Expr, 8> indices;
-};
-
-struct MaxExpr {
-public:
- explicit MaxExpr(llvm::ArrayRef<Expr> arguments);
- explicit MaxExpr(edsc_max_expr_t st)
- : storage(reinterpret_cast<detail::ExprStorage *>(st)) {}
- llvm::ArrayRef<Expr> getArguments() const;
-
- operator edsc_max_expr_t() { return storage; }
-
-private:
- detail::ExprStorage *storage;
-};
-
-struct MinExpr {
-public:
- explicit MinExpr(llvm::ArrayRef<Expr> arguments);
- explicit MinExpr(edsc_min_expr_t st)
- : storage(reinterpret_cast<detail::ExprStorage *>(st)) {}
- llvm::ArrayRef<Expr> getArguments() const;
-
- operator edsc_min_expr_t() { return storage; }
-
-private:
- detail::ExprStorage *storage;
-};
-
-Stmt For(const Bindable &idx, MaxExpr lb, MinExpr ub, Expr step,
- llvm::ArrayRef<Stmt> enclosedStmts);
-Stmt For(llvm::ArrayRef<Expr> idxs, llvm::ArrayRef<MaxExpr> lbs,
- llvm::ArrayRef<MinExpr> ubs, llvm::ArrayRef<Expr> steps,
- llvm::ArrayRef<Stmt> enclosedStmts);
-
-inline MaxExpr Max(llvm::ArrayRef<Expr> args) { return MaxExpr(args); }
-inline MinExpr Min(llvm::ArrayRef<Expr> args) { return MinExpr(args); }
-
-} // namespace edsc
-} // namespace mlir
-
-#endif // MLIR_EDSC_TYPES_H_
--- /dev/null
+//===- Types.cpp - Implementations of MLIR Core C APIs --------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir-c/Core.h"
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/StringSwitch.h"
+
+using namespace mlir;
+
+mlir_type_t makeScalarType(mlir_context_t context, const char *name,
+ unsigned bitwidth) {
+ mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);
+ mlir_type_t res =
+ llvm::StringSwitch<mlir_type_t>(name)
+ .Case("bf16",
+ mlir_type_t{mlir::FloatType::getBF16(c).getAsOpaquePointer()})
+ .Case("f16",
+ mlir_type_t{mlir::FloatType::getF16(c).getAsOpaquePointer()})
+ .Case("f32",
+ mlir_type_t{mlir::FloatType::getF32(c).getAsOpaquePointer()})
+ .Case("f64",
+ mlir_type_t{mlir::FloatType::getF64(c).getAsOpaquePointer()})
+ .Case("index",
+ mlir_type_t{mlir::IndexType::get(c).getAsOpaquePointer()})
+ .Case("i",
+ mlir_type_t{
+ mlir::IntegerType::get(bitwidth, c).getAsOpaquePointer()})
+ .Default(mlir_type_t{nullptr});
+ if (!res) {
+ llvm_unreachable("Invalid type specifier");
+ }
+ return res;
+}
+
+mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
+ int64_list_t sizes) {
+ auto t = mlir::MemRefType::get(
+ llvm::ArrayRef<int64_t>(sizes.values, sizes.n),
+ mlir::Type::getFromOpaquePointer(elemType),
+ {mlir::AffineMap::getMultiDimIdentityMap(
+ sizes.n, reinterpret_cast<mlir::MLIRContext *>(context))},
+ 0);
+ return mlir_type_t{t.getAsOpaquePointer()};
+}
+
+mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
+ mlir_type_list_t outputs) {
+ llvm::SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n);
+ for (unsigned i = 0; i < inputs.n; ++i) {
+ ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]);
+ }
+ for (unsigned i = 0; i < outputs.n; ++i) {
+ outs[i] = mlir::Type::getFromOpaquePointer(outputs.types[i]);
+ }
+ auto ft = mlir::FunctionType::get(
+ ins, outs, reinterpret_cast<mlir::MLIRContext *>(context));
+ return mlir_type_t{ft.getAsOpaquePointer()};
+}
+
+mlir_type_t makeIndexType(mlir_context_t context) {
+ auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
+ auto type = mlir::IndexType::get(ctx);
+ return mlir_type_t{type.getAsOpaquePointer()};
+}
+
+mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value) {
+ auto ty = Type::getFromOpaquePointer(reinterpret_cast<const void *>(type));
+ auto attr = IntegerAttr::get(ty, value);
+ return mlir_attr_t{attr.getAsOpaquePointer()};
+}
+
+mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) {
+ auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
+ auto attr = BoolAttr::get(value, ctx);
+ return mlir_attr_t{attr.getAsOpaquePointer()};
+}
+
+unsigned getFunctionArity(mlir_func_t function) {
+ auto *f = reinterpret_cast<mlir::Function *>(function);
+ return f->getNumArguments();
+}
+++ /dev/null
-//===- Types.h - MLIR EDSC Type System Implementation -----------*- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/EDSC/Types.h"
-#include "mlir-c/Core.h"
-#include "mlir/AffineOps/AffineOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineExprVisitor.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/StandardOps/Ops.h"
-#include "mlir/Support/STLExtras.h"
-#include "mlir/VectorOps/VectorOps.h"
-
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/StringSwitch.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/raw_ostream.h"
-#include <memory>
-
-using llvm::errs;
-using llvm::Twine;
-
-using namespace mlir;
-using namespace mlir::edsc;
-using namespace mlir::edsc::detail;
-
-namespace mlir {
-namespace edsc {
-namespace detail {
-
-template <typename T> ArrayRef<T> copyIntoExprAllocator(ArrayRef<T> elements) {
- if (elements.empty()) {
- return {};
- }
- auto storage = Expr::globalAllocator()->Allocate<T>(elements.size());
- std::uninitialized_copy(elements.begin(), elements.end(), storage);
- return llvm::makeArrayRef(storage, elements.size());
-}
-
-struct ExprStorage {
- // Note: this structure is similar to OperationState, but stores lists in a
- // EDSC bump allocator.
- ExprKind kind;
- unsigned id;
-
- StringRef opName;
-
- // Exprs can contain multiple groups of operands separated by null
- // expressions. Two null expressions in a row identify an empty group.
- ArrayRef<Expr> operands;
-
- ArrayRef<Type> resultTypes;
- ArrayRef<NamedAttribute> attributes;
- ArrayRef<StmtBlock> successors;
-
- ExprStorage(ExprKind kind, StringRef name, ArrayRef<Type> results,
- ArrayRef<Expr> children, ArrayRef<NamedAttribute> attrs,
- ArrayRef<StmtBlock> succ = {}, unsigned exprId = Expr::newId())
- : kind(kind), id(exprId) {
- operands = copyIntoExprAllocator(children);
- resultTypes = copyIntoExprAllocator(results);
- attributes = copyIntoExprAllocator(attrs);
- successors = copyIntoExprAllocator(succ);
- if (!name.empty()) {
- auto nameStorage = Expr::globalAllocator()->Allocate<char>(name.size());
- std::uninitialized_copy(name.begin(), name.end(), nameStorage);
- opName = StringRef(nameStorage, name.size());
- }
- }
-};
-
-struct StmtStorage {
- StmtStorage(Bindable lhs, Expr rhs, llvm::ArrayRef<Stmt> enclosedStmts)
- : lhs(lhs), rhs(rhs), enclosedStmts(enclosedStmts) {}
- Bindable lhs;
- Expr rhs;
- ArrayRef<Stmt> enclosedStmts;
-};
-
-struct StmtBlockStorage {
- StmtBlockStorage(ArrayRef<Bindable> args, ArrayRef<Type> argTypes,
- ArrayRef<Stmt> stmts) {
- id = nextId();
- arguments = copyIntoExprAllocator(args);
- argumentTypes = copyIntoExprAllocator(argTypes);
- statements = copyIntoExprAllocator(stmts);
- }
-
- void replaceStmts(ArrayRef<Stmt> stmts) {
- Expr::globalAllocator()->Deallocate(statements.data(), statements.size());
- statements = copyIntoExprAllocator(stmts);
- }
-
- static uint64_t &nextId() {
- static thread_local uint64_t next = 0;
- return ++next;
- }
- static void resetIds() { nextId() = 0; }
-
- uint64_t id;
- ArrayRef<Bindable> arguments;
- ArrayRef<Type> argumentTypes;
- ArrayRef<Stmt> statements;
-};
-
-} // namespace detail
-} // namespace edsc
-} // namespace mlir
-
-mlir::edsc::ScopedEDSCContext::ScopedEDSCContext() {
- Expr::globalAllocator() = &allocator;
- Bindable::resetIds();
- StmtBlockStorage::resetIds();
-}
-
-mlir::edsc::ScopedEDSCContext::~ScopedEDSCContext() {
- Expr::globalAllocator() = nullptr;
-}
-
-mlir::edsc::Expr::Expr(Type type) {
- // Initialize with placement new.
- storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
- new (storage) detail::ExprStorage(ExprKind::Unbound, "", {type}, {}, {});
-}
-
-ExprKind mlir::edsc::Expr::getKind() const { return storage->kind; }
-
-unsigned mlir::edsc::Expr::getId() const {
- return static_cast<ImplType *>(storage)->id;
-}
-
-unsigned &mlir::edsc::Expr::newId() {
- static thread_local unsigned id = 0;
- return ++id;
-}
-
-ArrayRef<Type> mlir::edsc::Expr::getResultTypes() const {
- return storage->resultTypes;
-}
-
-ArrayRef<NamedAttribute> mlir::edsc::Expr::getAttributes() const {
- return storage->attributes;
-}
-
-Attribute mlir::edsc::Expr::getAttribute(StringRef name) const {
- for (const auto &namedAttr : getAttributes())
- if (namedAttr.first.is(name))
- return namedAttr.second;
- return {};
-}
-
-ArrayRef<StmtBlock> mlir::edsc::Expr::getSuccessors() const {
- return storage->successors;
-}
-
-StringRef mlir::edsc::Expr::getName() const {
- return static_cast<ImplType *>(storage)->opName;
-}
-
-SmallVector<Value *, 4>
-buildExprs(ArrayRef<Expr> exprs, FuncBuilder &b,
- const llvm::DenseMap<Expr, Value *> &ssaBindings,
- const llvm::DenseMap<StmtBlock, mlir::Block *> &blockBindings) {
- SmallVector<Value *, 4> values;
- values.reserve(exprs.size());
- for (auto child : exprs) {
- auto subResults = child.build(b, ssaBindings, blockBindings);
- assert(subResults.size() == 1 &&
- "expected single-result expression as operand");
- values.push_back(subResults.front());
- }
- return values;
-}
-
-SmallVector<Value *, 4>
-Expr::build(FuncBuilder &b, const llvm::DenseMap<Expr, Value *> &ssaBindings,
- const llvm::DenseMap<StmtBlock, Block *> &blockBindings) const {
- auto it = ssaBindings.find(*this);
- if (it != ssaBindings.end())
- return {it->second};
-
- SmallVector<Value *, 4> operandValues =
- buildExprs(getProperArguments(), b, ssaBindings, blockBindings);
-
- // Special case for emitting composed affine.applies.
- // FIXME: this should not be a special case, instead, define composed form as
- // canonical for the affine.apply operator and expose a generic createAndFold
- // operation on builder that canonicalizes all operations that we emit here.
- if (is_op<AffineApplyOp>()) {
- auto affInstr = makeComposedAffineApply(
- &b, b.getUnknownLoc(),
- getAttribute("map").cast<AffineMapAttr>().getValue(), operandValues);
- return {affInstr.getResult()};
- }
-
- auto state = OperationState(b.getContext(), b.getUnknownLoc(), getName());
- state.addOperands(operandValues);
- state.addTypes(getResultTypes());
- for (const auto &attr : getAttributes())
- state.addAttribute(attr.first, attr.second);
-
- auto successors = getSuccessors();
- auto successorArgs = getSuccessorArguments();
- assert(successors.size() == successorArgs.size() &&
- "expected all successors to have a corresponding operand group");
- for (int i = 0, e = successors.size(); i < e; ++i) {
- StmtBlock block = successors[i];
- assert(blockBindings.count(block) != 0 && "successor block does not exist");
- state.addSuccessor(
- blockBindings.lookup(block),
- buildExprs(successorArgs[i], b, ssaBindings, blockBindings));
- }
-
- Operation *op = b.createOperation(state);
- return llvm::to_vector<4>(op->getResults());
-}
-
-static AffineExpr createOperandAffineExpr(Expr e, int64_t position,
- MLIRContext *context) {
- if (e.is_op<ConstantOp>()) {
- int64_t cst =
- e.getAttribute("value").cast<IntegerAttr>().getValue().getSExtValue();
- return getAffineConstantExpr(cst, context);
- }
- return getAffineDimExpr(position, context);
-}
-
-static Expr createBinaryIndexExpr(
- Expr lhs, Expr rhs,
- std::function<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
- assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
- "only single-result exprs are supported in operators");
- auto thisType = lhs.getResultTypes().front();
- auto thatType = rhs.getResultTypes().front();
- assert(thisType == thatType && "cannot mix types in operators");
- assert(thisType.isIndex() && "expected exprs of index type");
- MLIRContext *context = thisType.getContext();
- auto lhsAff = createOperandAffineExpr(lhs, 0, context);
- auto rhsAff = createOperandAffineExpr(rhs, 1, context);
- auto map = AffineMap::get(2, 0, {affCombiner(lhsAff, rhsAff)}, {});
- auto attr = AffineMapAttr::get(map);
- auto attrId = Identifier::get("map", context);
- auto namedAttr = NamedAttribute{attrId, attr};
- return VariadicExpr("affine.apply", {lhs, rhs}, {IndexType::get(context)},
- {namedAttr});
-}
-
-// Create a binary expression between the two arguments emitting `IOp` if
-// arguments are integers or vectors/tensors thereof, `FOp` if arguments are
-// floating-point or vectors/tensors thereof, and `AffineApplyOp` with an
-// expression produced by `affCombiner` if arguments are of the index type.
-// Die on unsupported types.
-template <typename IOp, typename FOp>
-static Expr createBinaryExpr(
- Expr lhs, Expr rhs,
- std::function<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
- assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
- "only single-result exprs are supported in operators");
- auto thisType = lhs.getResultTypes().front();
- auto thatType = rhs.getResultTypes().front();
- assert(thisType == thatType && "cannot mix types in operators");
- if (thisType.isIndex()) {
- return createBinaryIndexExpr(lhs, rhs, affCombiner);
- } else if (thisType.isa<IntegerType>()) {
- return BinaryExpr::make<IOp>(thisType, lhs, rhs);
- } else if (thisType.isa<FloatType>()) {
- return BinaryExpr::make<FOp>(thisType, lhs, rhs);
- } else if (auto aggregateType = thisType.dyn_cast<VectorOrTensorType>()) {
- if (aggregateType.getElementType().isa<IntegerType>())
- return BinaryExpr::make<IOp>(thisType, lhs, rhs);
- else if (aggregateType.getElementType().isa<FloatType>())
- return BinaryExpr::make<FOp>(thisType, lhs, rhs);
- }
-
- llvm_unreachable("failed to create an Expr");
-}
-
-Expr mlir::edsc::op::operator+(Expr lhs, Expr rhs) {
- return createBinaryExpr<AddIOp, AddFOp>(
- lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
-}
-Expr mlir::edsc::op::operator-(Expr lhs, Expr rhs) {
- return createBinaryExpr<SubIOp, SubFOp>(
- lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
-}
-Expr mlir::edsc::op::operator*(Expr lhs, Expr rhs) {
- return createBinaryExpr<MulIOp, MulFOp>(
- lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
-}
-Expr mlir::edsc::op::operator/(Expr lhs, Expr rhs) {
- return createBinaryExpr<DivISOp, DivFOp>(
- lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
- llvm_unreachable("only exprs of non-index type support operator/");
- });
-}
-Expr mlir::edsc::op::operator%(Expr lhs, Expr rhs) {
- return createBinaryExpr<RemISOp, RemFOp>(
- lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
-}
-
-Expr mlir::edsc::floorDiv(Expr lhs, Expr rhs) {
- return createBinaryIndexExpr(
- lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
-}
-Expr mlir::edsc::ceilDiv(Expr lhs, Expr rhs) {
- return createBinaryIndexExpr(
- lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
-}
-
-static Expr createComparisonExpr(CmpIPredicate predicate, Expr lhs, Expr rhs) {
- assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
- "only single-result exprs are supported in operators");
- auto lhsType = lhs.getResultTypes().front();
- auto rhsType = rhs.getResultTypes().front();
- assert(lhsType == rhsType && "cannot mix types in operators");
- assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) &&
- "only integer comparisons are supported");
-
- MLIRContext *context = lhsType.getContext();
- auto attr = IntegerAttr::get(IndexType::get(context),
- static_cast<int64_t>(predicate));
- auto attrId = Identifier::get(CmpIOp::getPredicateAttrName(), context);
- auto namedAttr = NamedAttribute{attrId, attr};
-
- return BinaryExpr::make<CmpIOp>(IntegerType::get(1, context), lhs, rhs,
- {namedAttr});
-}
-
-Expr mlir::edsc::op::operator==(Expr lhs, Expr rhs) {
- return createComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
-}
-Expr mlir::edsc::op::operator!=(Expr lhs, Expr rhs) {
- return createComparisonExpr(CmpIPredicate::NE, lhs, rhs);
-}
-Expr mlir::edsc::op::operator<(Expr lhs, Expr rhs) {
- // TODO(ntv,zinenko): signed by default, how about unsigned?
- return createComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
-}
-Expr mlir::edsc::op::operator<=(Expr lhs, Expr rhs) {
- return createComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
-}
-Expr mlir::edsc::op::operator>(Expr lhs, Expr rhs) {
- return createComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
-}
-Expr mlir::edsc::op::operator>=(Expr lhs, Expr rhs) {
- return createComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
-}
-
-Expr mlir::edsc::op::operator&&(Expr lhs, Expr rhs) {
- assert(lhs.getResultTypes().size() == 1 && rhs.getResultTypes().size() == 1 &&
- "expected single-result exprs");
- auto thisType = lhs.getResultTypes().front();
- auto thatType = rhs.getResultTypes().front();
- assert(thisType.isInteger(1) && thatType.isInteger(1) &&
- "logical And expects i1");
- return BinaryExpr::make<MulIOp>(thisType, lhs, rhs);
-}
-Expr mlir::edsc::op::operator||(Expr lhs, Expr rhs) {
- // There is not support for bitwise operations, so we emulate logical 'or'
- // lhs || rhs
- // as
- // !(!lhs && !rhs).
- using namespace edsc::op;
- return !(!lhs && !rhs);
-}
-Expr mlir::edsc::op::operator!(Expr expr) {
- assert(expr.getResultTypes().size() == 1 && "expected single-result exprs");
- auto thisType = expr.getResultTypes().front();
- assert(thisType.isInteger(1) && "logical Not expects i1");
- MLIRContext *context = thisType.getContext();
-
- // Create constant 1 expression.s
- auto attr = IntegerAttr::get(thisType, 1);
- auto attrId = Identifier::get("value", context);
- auto namedAttr = NamedAttribute{attrId, attr};
- auto cstOne = VariadicExpr("std.constant", {}, thisType, {namedAttr});
-
- // Emulate negation as (1 - x) : i1
- return cstOne - expr;
-}
-
-llvm::SmallVector<Expr, 8> mlir::edsc::makeNewExprs(unsigned n, Type type) {
- llvm::SmallVector<Expr, 8> res;
- res.reserve(n);
- for (auto i = 0; i < n; ++i) {
- res.push_back(Expr(type));
- }
- return res;
-}
-
-template <typename Target, size_t N, typename Source>
-SmallVector<Target, N> convertCList(Source list) {
- SmallVector<Target, N> result;
- result.reserve(list.n);
- for (unsigned i = 0; i < list.n; ++i) {
- result.push_back(Target(list.list[i]));
- }
- return result;
-}
-
-SmallVector<StmtBlock, 4> makeBlocks(edsc_block_list_t list) {
- return convertCList<StmtBlock, 4>(list);
-}
-
-static llvm::SmallVector<Expr, 8> makeExprs(edsc_expr_list_t exprList) {
- llvm::SmallVector<Expr, 8> exprs;
- exprs.reserve(exprList.n);
- for (unsigned i = 0; i < exprList.n; ++i) {
- exprs.push_back(Expr(exprList.exprs[i]));
- }
- return exprs;
-}
-
-static void fillStmts(edsc_stmt_list_t enclosedStmts,
- llvm::SmallVector<Stmt, 8> *stmts) {
- stmts->reserve(enclosedStmts.n);
- for (unsigned i = 0; i < enclosedStmts.n; ++i) {
- stmts->push_back(Stmt(enclosedStmts.stmts[i]));
- }
-}
-
-edsc_expr_t Op(mlir_context_t context, const char *name, mlir_type_t resultType,
- edsc_expr_list_t arguments, edsc_block_list_t successors,
- mlir_named_attr_list_t attrs) {
- mlir::MLIRContext *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
-
- auto blocks = makeBlocks(successors);
-
- SmallVector<NamedAttribute, 4> attributes;
- attributes.reserve(attrs.n);
- for (int i = 0; i < attrs.n; ++i) {
- auto attribute = Attribute::getFromOpaquePointer(
- reinterpret_cast<const void *>(attrs.list[i].value));
- auto name = Identifier::get(attrs.list[i].name, ctx);
- attributes.emplace_back(name, attribute);
- }
-
- return VariadicExpr(
- name, makeExprs(arguments),
- Type::getFromOpaquePointer(reinterpret_cast<const void *>(resultType)),
- attributes, blocks);
-}
-
-Expr mlir::edsc::alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
- return VariadicExpr::make<AllocOp>(sizes, memrefType);
-}
-
-Expr mlir::edsc::dealloc(Expr memref) {
- return UnaryExpr::make<DeallocOp>(memref);
-}
-
-Stmt mlir::edsc::For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
- assert(lb.getResultTypes().size() == 1 && "expected single-result bounds");
- auto type = lb.getResultTypes().front();
- Expr idx(type);
- return For(Bindable(idx), lb, ub, step, stmts);
-}
-
-Stmt mlir::edsc::For(const Bindable &idx, Expr lb, Expr ub, Expr step,
- ArrayRef<Stmt> stmts) {
- assert(lb);
- assert(ub);
- assert(step);
- // Use a null expression as a sentinel between lower and upper bound
- // expressions in the list of children.
- return Stmt(
- idx, StmtBlockLikeExpr(ExprKind::For, {lb, nullptr, ub, nullptr, step}),
- stmts);
-}
-
-template <typename LB, typename UB>
-Stmt forNestImpl(ArrayRef<Expr> indices, ArrayRef<LB> lbs, ArrayRef<UB> ubs,
- ArrayRef<Expr> steps, ArrayRef<Stmt> enclosedStmts) {
- assert(!indices.empty());
- assert(indices.size() == lbs.size());
- assert(indices.size() == ubs.size());
- assert(indices.size() == steps.size());
- Expr iv = indices.back();
- Stmt curStmt =
- For(Bindable(iv), lbs.back(), ubs.back(), steps.back(), enclosedStmts);
- for (int64_t i = indices.size() - 2; i >= 0; --i) {
- Expr iiv = indices[i];
- curStmt.set(For(Bindable(iiv), lbs[i], ubs[i], steps[i],
- llvm::ArrayRef<Stmt>{&curStmt, 1}));
- }
- return curStmt;
-}
-
-Stmt mlir::edsc::For(ArrayRef<Expr> indices, ArrayRef<Expr> lbs,
- ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
- ArrayRef<Stmt> enclosedStmts) {
- return forNestImpl(indices, lbs, ubs, steps, enclosedStmts);
-}
-
-Stmt mlir::edsc::For(const Bindable &idx, MaxExpr lb, MinExpr ub, Expr step,
- llvm::ArrayRef<Stmt> enclosedStmts) {
- return MaxMinFor(idx, lb.getArguments(), ub.getArguments(), step,
- enclosedStmts);
-}
-
-Stmt mlir::edsc::For(llvm::ArrayRef<Expr> idxs, llvm::ArrayRef<MaxExpr> lbs,
- llvm::ArrayRef<MinExpr> ubs, llvm::ArrayRef<Expr> steps,
- llvm::ArrayRef<Stmt> enclosedStmts) {
- return forNestImpl(idxs, lbs, ubs, steps, enclosedStmts);
-}
-
-Stmt mlir::edsc::MaxMinFor(const Bindable &idx, ArrayRef<Expr> lbs,
- ArrayRef<Expr> ubs, Expr step,
- ArrayRef<Stmt> enclosedStmts) {
- assert(!lbs.empty() && "'affine.for' loop must have lower bounds");
- assert(!ubs.empty() && "'affine.for' loop must have upper bounds");
-
- // Use a null expression as a sentinel between lower and upper bound
- // expressions in the list of children.
- SmallVector<Expr, 8> exprs;
- exprs.insert(exprs.end(), lbs.begin(), lbs.end());
- exprs.push_back(nullptr);
- exprs.insert(exprs.end(), ubs.begin(), ubs.end());
- exprs.push_back(nullptr);
- exprs.push_back(step);
-
- return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, exprs), enclosedStmts);
-}
-
-edsc_max_expr_t Max(edsc_expr_list_t args) {
- return mlir::edsc::Max(makeExprs(args));
-}
-
-edsc_min_expr_t Min(edsc_expr_list_t args) {
- return mlir::edsc::Min(makeExprs(args));
-}
-
-edsc_stmt_t For(edsc_expr_t iv, edsc_expr_t lb, edsc_expr_t ub,
- edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
- llvm::SmallVector<Stmt, 8> stmts;
- fillStmts(enclosedStmts, &stmts);
- return Stmt(
- For(Expr(iv).cast<Bindable>(), Expr(lb), Expr(ub), Expr(step), stmts));
-}
-
-edsc_stmt_t ForNest(edsc_expr_list_t ivs, edsc_expr_list_t lbs,
- edsc_expr_list_t ubs, edsc_expr_list_t steps,
- edsc_stmt_list_t enclosedStmts) {
- llvm::SmallVector<Stmt, 8> stmts;
- fillStmts(enclosedStmts, &stmts);
- return Stmt(For(makeExprs(ivs), makeExprs(lbs), makeExprs(ubs),
- makeExprs(steps), stmts));
-}
-
-edsc_stmt_t MaxMinFor(edsc_expr_t iv, edsc_max_expr_t lb, edsc_min_expr_t ub,
- edsc_expr_t step, edsc_stmt_list_t enclosedStmts) {
- llvm::SmallVector<Stmt, 8> stmts;
- fillStmts(enclosedStmts, &stmts);
- return Stmt(For(Expr(iv).cast<Bindable>(), MaxExpr(lb), MinExpr(ub),
- Expr(step), stmts));
-}
-
-StmtBlock mlir::edsc::block(ArrayRef<Bindable> args, ArrayRef<Stmt> stmts) {
- return StmtBlock(args, stmts);
-}
-
-edsc_block_t Block(edsc_expr_list_t arguments, edsc_stmt_list_t enclosedStmts) {
- llvm::SmallVector<Stmt, 8> stmts;
- fillStmts(enclosedStmts, &stmts);
-
- llvm::SmallVector<Bindable, 8> args;
- for (uint64_t i = 0; i < arguments.n; ++i)
- args.emplace_back(Expr(arguments.exprs[i]));
-
- return StmtBlock(args, stmts);
-}
-
-edsc_block_t BlockSetBody(edsc_block_t block, edsc_stmt_list_t stmts) {
- llvm::SmallVector<Stmt, 8> body;
- fillStmts(stmts, &body);
- StmtBlock(block).set(body);
- return block;
-}
-
-Expr mlir::edsc::load(Expr m, ArrayRef<Expr> indices) {
- assert(m.getResultTypes().size() == 1 && "expected single-result expr");
- auto type = m.getResultTypes().front().dyn_cast<MemRefType>();
- assert(type && "expected memref type");
-
- SmallVector<Expr, 8> exprs;
- exprs.push_back(m);
- exprs.append(indices.begin(), indices.end());
- return VariadicExpr::make<LoadOp>(exprs, {type.getElementType()});
-}
-
-edsc_expr_t Load(edsc_indexed_t indexed, edsc_expr_list_t indices) {
- Indexed i(Expr(indexed.base).cast<Bindable>());
- auto exprs = makeExprs(indices);
- Expr res = i(exprs);
- return res;
-}
-
-Expr mlir::edsc::store(Expr val, Expr m, ArrayRef<Expr> indices) {
- SmallVector<Expr, 8> exprs;
- exprs.push_back(val);
- exprs.push_back(m);
- exprs.append(indices.begin(), indices.end());
- return VariadicExpr::make<StoreOp>(exprs);
-}
-
-edsc_stmt_t Store(edsc_expr_t value, edsc_indexed_t indexed,
- edsc_expr_list_t indices) {
- Indexed i(Expr(indexed.base).cast<Bindable>());
- auto exprs = makeExprs(indices);
- Indexed loc = i(exprs);
- return Stmt(loc = Expr(value));
-}
-
-Expr mlir::edsc::select(Expr cond, Expr lhs, Expr rhs) {
- return TernaryExpr::make<SelectOp>(cond, lhs, rhs);
-}
-
-edsc_expr_t Select(edsc_expr_t cond, edsc_expr_t lhs, edsc_expr_t rhs) {
- return select(Expr(cond), Expr(lhs), Expr(rhs));
-}
-
-Expr mlir::edsc::vector_type_cast(Expr memrefExpr, Type memrefType) {
- return VariadicExpr::make<VectorTypeCastOp>({memrefExpr}, {memrefType});
-}
-
-Expr mlir::edsc::constantInteger(Type t, int64_t value) {
- assert((t.isa<IndexType>() || t.isa<IntegerType>()) &&
- "expected integer or index type");
- MLIRContext *ctx = t.getContext();
- auto attr = IntegerAttr::get(t, value);
- auto attrName = Identifier::get("value", ctx);
- auto namedAttr = NamedAttribute{attrName, attr};
- return VariadicExpr::make<ConstantOp>({}, t, namedAttr);
-}
-
-edsc_expr_t ConstantInteger(mlir_type_t type, int64_t value) {
- auto t = Type::getFromOpaquePointer(reinterpret_cast<const void *>(type));
- return mlir::edsc::constantInteger(t, value);
-}
-
-Expr mlir::edsc::call(Expr func, Type result, ArrayRef<Expr> args) {
- auto exprs = llvm::to_vector<8>(args);
- exprs.insert(exprs.begin(), func);
- return VariadicExpr::make<CallIndirectOp>(exprs, result);
-}
-
-Expr mlir::edsc::call(Expr func, ArrayRef<Expr> args) {
- auto exprs = llvm::to_vector<8>(args);
- exprs.insert(exprs.begin(), func);
- return VariadicExpr::make<CallIndirectOp>(exprs, {});
-}
-
-Stmt mlir::edsc::Return(ArrayRef<Expr> values) {
- return VariadicExpr::make<ReturnOp>(values);
-}
-
-edsc_stmt_t Return(edsc_expr_list_t values) {
- return Stmt(Return(makeExprs(values)));
-}
-
-Stmt mlir::edsc::Branch(StmtBlock destination, ArrayRef<Expr> args) {
- SmallVector<Expr, 4> arguments;
- arguments.push_back(nullptr);
- arguments.insert(arguments.end(), args.begin(), args.end());
- return VariadicExpr::make<BranchOp>(arguments, {}, {}, {destination});
-}
-
-Stmt mlir::edsc::CondBranch(Expr condition, StmtBlock trueDestination,
- ArrayRef<Expr> trueArgs, StmtBlock falseDestination,
- ArrayRef<Expr> falseArgs) {
- SmallVector<Expr, 8> arguments;
- arguments.push_back(condition);
- arguments.push_back(nullptr);
- arguments.append(trueArgs.begin(), trueArgs.end());
- arguments.push_back(nullptr);
- arguments.append(falseArgs.begin(), falseArgs.end());
- return VariadicExpr::make<CondBranchOp>(arguments, {}, {},
- {trueDestination, falseDestination});
-}
-
-Stmt mlir::edsc::CondBranch(Expr condition, StmtBlock trueDestination,
- StmtBlock falseDestination) {
- return CondBranch(condition, trueDestination, {}, falseDestination, {});
-}
-
-static raw_ostream &printBinaryExpr(raw_ostream &os, BinaryExpr e,
- StringRef infix) {
- os << '(' << e.getLHS() << ' ' << infix << ' ' << e.getRHS() << ')';
- return os;
-}
-
-// Get the operator spelling for pretty-printing the infix form of a
-// comparison operator.
-static StringRef getCmpIPredicateInfix(const mlir::edsc::Expr &e) {
- Attribute predicate = e.getAttribute(CmpIOp::getPredicateAttrName());
- assert(predicate && "expected a predicate in a comparison expr");
-
- switch (static_cast<CmpIPredicate>(
- predicate.cast<IntegerAttr>().getValue().getSExtValue())) {
- case CmpIPredicate::EQ:
- return "==";
- case CmpIPredicate::NE:
- return "!=";
- case CmpIPredicate::SGT:
- case CmpIPredicate::UGT:
- return ">";
- case CmpIPredicate::SLT:
- case CmpIPredicate::ULT:
- return "<";
- case CmpIPredicate::SGE:
- case CmpIPredicate::UGE:
- return ">=";
- case CmpIPredicate::SLE:
- case CmpIPredicate::ULE:
- return "<=";
- default:
- llvm_unreachable("unknown predicate");
- }
- return "";
-}
-
-static void printAffineExpr(raw_ostream &os, AffineExpr expr,
- ArrayRef<Expr> dims, ArrayRef<Expr> symbols) {
- struct Visitor : public AffineExprVisitor<Visitor> {
- Visitor(raw_ostream &ostream, ArrayRef<Expr> dimExprs,
- ArrayRef<Expr> symExprs)
- : os(ostream), dims(dimExprs), symbols(symExprs) {}
- raw_ostream &os;
- ArrayRef<Expr> dims;
- ArrayRef<Expr> symbols;
-
- void visitDimExpr(AffineDimExpr dimExpr) {
- os << dims[dimExpr.getPosition()];
- }
-
- void visitSymbolExpr(AffineSymbolExpr symbolExpr) {
- os << symbols[symbolExpr.getPosition()];
- }
-
- void visitConstantExpr(AffineConstantExpr constExpr) {
- os << constExpr.getValue();
- }
-
- void visitBinaryExpr(AffineBinaryOpExpr expr, StringRef infix) {
- visit(expr.getLHS());
- os << infix;
- visit(expr.getRHS());
- }
-
- void visitAddExpr(AffineBinaryOpExpr binOp) {
- visitBinaryExpr(binOp, " + ");
- }
-
- void visitMulExpr(AffineBinaryOpExpr binOp) {
- visitBinaryExpr(binOp, " * ");
- }
-
- void visitModExpr(AffineBinaryOpExpr binOp) {
- visitBinaryExpr(binOp, " % ");
- }
-
- void visitCeilDivExpr(AffineBinaryOpExpr binOp) {
- visitBinaryExpr(binOp, " ceildiv ");
- }
-
- void visitFloorDivExpr(AffineBinaryOpExpr binOp) {
- visitBinaryExpr(binOp, " floordiv ");
- }
- };
-
- Visitor(os, dims, symbols).visit(expr);
-}
-
-static void printAffineMap(raw_ostream &os, AffineMap map,
- ArrayRef<Expr> operands) {
- auto dims = operands.take_front(map.getNumDims());
- auto symbols = operands.drop_front(map.getNumDims());
- assert(map.getNumResults() == 1 &&
- "only 1-result maps are currently supported");
- printAffineExpr(os, map.getResult(0), dims, symbols);
-}
-
-void printAffineApply(raw_ostream &os, mlir::edsc::Expr e) {
- Attribute mapAttr;
- for (const auto &namedAttr : e.getAttributes()) {
- if (namedAttr.first.is("map")) {
- mapAttr = namedAttr.second;
- break;
- }
- }
- assert(mapAttr && "expected a map in an affine apply expr");
-
- printAffineMap(os, mapAttr.cast<AffineMapAttr>().getValue(),
- e.getProperArguments());
-}
-
-edsc_stmt_t Branch(edsc_block_t destination, edsc_expr_list_t arguments) {
- auto args = makeExprs(arguments);
- return mlir::edsc::Branch(StmtBlock(destination), args);
-}
-
-edsc_stmt_t CondBranch(edsc_expr_t condition, edsc_block_t trueDestination,
- edsc_expr_list_t trueArguments,
- edsc_block_t falseDestination,
- edsc_expr_list_t falseArguments) {
- auto trueArgs = makeExprs(trueArguments);
- auto falseArgs = makeExprs(falseArguments);
- return mlir::edsc::CondBranch(Expr(condition), StmtBlock(trueDestination),
- trueArgs, StmtBlock(falseDestination),
- falseArgs);
-}
-
-// If `blockArgs` is not empty, print it as a comma-separated parenthesized
-// list, otherwise print nothing.
-void printOptionalBlockArgs(ArrayRef<Expr> blockArgs, llvm::raw_ostream &os) {
- if (!blockArgs.empty())
- os << '(';
- interleaveComma(blockArgs, os);
- if (!blockArgs.empty())
- os << ")";
-}
-
-void mlir::edsc::Expr::print(raw_ostream &os) const {
- if (auto unbound = this->dyn_cast<Bindable>()) {
- os << "$" << unbound.getId();
- return;
- }
-
- // Handle known binary ops with pretty infix forms.
- if (auto binExpr = this->dyn_cast<BinaryExpr>()) {
- StringRef infix;
- if (binExpr.is_op<AddIOp>() || binExpr.is_op<AddFOp>())
- infix = "+";
- else if (binExpr.is_op<SubIOp>() || binExpr.is_op<SubFOp>())
- infix = "-";
- else if (binExpr.is_op<MulIOp>() || binExpr.is_op<MulFOp>())
- infix = binExpr.getResultTypes().front().isInteger(1) ? "&&" : "*";
- else if (binExpr.is_op<DivISOp>() || binExpr.is_op<DivIUOp>() ||
- binExpr.is_op<DivFOp>())
- infix = "/";
- else if (binExpr.is_op<RemISOp>() || binExpr.is_op<RemIUOp>() ||
- binExpr.is_op<RemFOp>())
- infix = "%";
- else if (binExpr.is_op<CmpIOp>())
- infix = getCmpIPredicateInfix(*this);
-
- if (!infix.empty()) {
- printBinaryExpr(os, binExpr, infix);
- return;
- }
- }
-
- // Handle known variadic ops with pretty forms.
- if (auto narExpr = this->dyn_cast<VariadicExpr>()) {
- if (narExpr.is_op<LoadOp>()) {
- os << narExpr.getName() << '(' << getProperArguments().front() << '[';
- interleaveComma(getProperArguments().drop_front(), os);
- os << "])";
- return;
- }
- if (narExpr.is_op<StoreOp>()) {
- os << narExpr.getName() << '(' << getProperArguments().front() << ", "
- << getProperArguments()[1] << '[';
- interleaveComma(getProperArguments().drop_front(2), os);
- os << "])";
- return;
- }
- if (narExpr.is_op<AffineApplyOp>()) {
- os << '(';
- printAffineApply(os, *this);
- os << ')';
- return;
- }
- if (narExpr.is_op<CallIndirectOp>()) {
- os << '@' << getProperArguments().front() << '(';
- interleaveComma(getProperArguments().drop_front(), os);
- os << ')';
- return;
- }
- if (narExpr.is_op<BranchOp>()) {
- os << "br ^bb" << narExpr.getSuccessors().front().getId();
- printOptionalBlockArgs(getSuccessorArguments(0), os);
- return;
- }
- if (narExpr.is_op<CondBranchOp>()) {
- os << "cond_br(" << getProperArguments()[0] << ", ^bb"
- << getSuccessors().front().getId();
- printOptionalBlockArgs(getSuccessorArguments(0), os);
- os << ", ^bb" << getSuccessors().back().getId();
- printOptionalBlockArgs(getSuccessorArguments(1), os);
- os << ')';
- return;
- }
- }
-
- // Special case for integer constants that are printed as is. Use
- // sign-extended result for everything but i1 (booleans).
- if (this->is_op<ConstantIndexOp>() || this->is_op<ConstantIntOp>()) {
- assert(getAttribute("value"));
- APInt value = getAttribute("value").cast<IntegerAttr>().getValue();
- if (value.getBitWidth() == 1)
- os << value.getZExtValue();
- else
- os << value;
- return;
- }
-
- // Handle all other types of ops with a more generic printing form.
- if (this->isa<UnaryExpr>() || this->isa<BinaryExpr>() ||
- this->isa<TernaryExpr>() || this->isa<VariadicExpr>()) {
- os << (getName().empty() ? "##unknown##" : getName()) << '(';
- interleaveComma(getProperArguments(), os);
- auto successors = getSuccessors();
- if (!successors.empty()) {
- os << '[';
- interleave(
- llvm::zip(successors, getSuccessorArguments()),
- [&os](const std::tuple<const StmtBlock &, const ArrayRef<Expr> &>
- &pair) {
- const auto &block = std::get<0>(pair);
- ArrayRef<Expr> operands = std::get<1>(pair);
- os << "^bb" << block.getId();
- if (!operands.empty()) {
- os << '(';
- interleaveComma(operands, os);
- os << ')';
- }
- },
- [&os]() { os << ", "; });
- os << ']';
- }
- auto attrs = getAttributes();
- if (!attrs.empty()) {
- os << '{';
- interleave(
- attrs,
- [&os](const NamedAttribute &attr) {
- os << attr.first.strref() << ": " << attr.second;
- },
- [&os]() { os << ", "; });
- os << '}';
- }
- os << ')';
- return;
- } else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) {
- switch (stmtLikeExpr.getKind()) {
- // We only print the lb, ub and step here, which are the StmtBlockLike
- // part of the `affine.for` StmtBlockLikeExpr.
- case ExprKind::For: {
- auto exprGroups = stmtLikeExpr.getAllArgumentGroups();
- assert(exprGroups.size() == 3 &&
- "For StmtBlockLikeExpr expected 3 groups");
- assert(exprGroups[2].size() == 1 && "expected 1 expr for loop step");
- if (exprGroups[0].size() == 1 && exprGroups[1].size() == 1) {
- os << exprGroups[0][0] << " to " << exprGroups[1][0] << " step "
- << exprGroups[2][0];
- } else {
- os << "max(";
- interleaveComma(exprGroups[0], os);
- os << ") to min(";
- interleaveComma(exprGroups[1], os);
- os << ") step " << exprGroups[2][0];
- }
- return;
- }
- default: {
- os << "unknown_stmt";
- }
- }
- }
- os << "unknown_kind(" << static_cast<int>(getKind()) << ")";
-}
-
-void mlir::edsc::Expr::dump() const { this->print(llvm::errs()); }
-
-std::string mlir::edsc::Expr::str() const {
- std::string res;
- llvm::raw_string_ostream os(res);
- this->print(os);
- return res;
-}
-
-llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
- const Expr &expr) {
- expr.print(os);
- return os;
-}
-
-edsc_expr_t makeBindable(mlir_type_t type) {
- return Bindable(Expr(Type(reinterpret_cast<const Type::ImplType *>(type))));
-}
-
-mlir::edsc::UnaryExpr::UnaryExpr(StringRef name, Expr expr)
- : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
- // Initialize with placement new.
- new (storage) detail::ExprStorage(ExprKind::Unary, name, {}, {expr}, {});
-}
-Expr mlir::edsc::UnaryExpr::getExpr() const {
- return static_cast<ImplType *>(storage)->operands.front();
-}
-
-mlir::edsc::BinaryExpr::BinaryExpr(StringRef name, Type result, Expr lhs,
- Expr rhs, ArrayRef<NamedAttribute> attrs)
- : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
- // Initialize with placement new.
- new (storage)
- detail::ExprStorage(ExprKind::Binary, name, {result}, {lhs, rhs}, attrs);
-}
-Expr mlir::edsc::BinaryExpr::getLHS() const {
- return static_cast<ImplType *>(storage)->operands.front();
-}
-Expr mlir::edsc::BinaryExpr::getRHS() const {
- return static_cast<ImplType *>(storage)->operands.back();
-}
-
-mlir::edsc::TernaryExpr::TernaryExpr(StringRef name, Expr cond, Expr lhs,
- Expr rhs)
- : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
- // Initialize with placement new.
- assert(lhs.getResultTypes().size() == 1 && "expected single-result expr");
- assert(rhs.getResultTypes().size() == 1 && "expected single-result expr");
- new (storage)
- detail::ExprStorage(ExprKind::Ternary, name,
- {lhs.getResultTypes().front()}, {cond, lhs, rhs}, {});
-}
-Expr mlir::edsc::TernaryExpr::getCond() const {
- return static_cast<ImplType *>(storage)->operands[0];
-}
-Expr mlir::edsc::TernaryExpr::getLHS() const {
- return static_cast<ImplType *>(storage)->operands[1];
-}
-Expr mlir::edsc::TernaryExpr::getRHS() const {
- return static_cast<ImplType *>(storage)->operands[2];
-}
-
-mlir::edsc::VariadicExpr::VariadicExpr(StringRef name, ArrayRef<Expr> exprs,
- ArrayRef<Type> types,
- ArrayRef<NamedAttribute> attrs,
- ArrayRef<StmtBlock> succ)
- : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
- // Initialize with placement new.
- new (storage)
- detail::ExprStorage(ExprKind::Variadic, name, types, exprs, attrs, succ);
-}
-ArrayRef<Expr> mlir::edsc::VariadicExpr::getExprs() const {
- return storage->operands;
-}
-ArrayRef<Type> mlir::edsc::VariadicExpr::getTypes() const {
- return storage->resultTypes;
-}
-ArrayRef<StmtBlock> mlir::edsc::VariadicExpr::getSuccessors() const {
- return storage->successors;
-}
-
-mlir::edsc::StmtBlockLikeExpr::StmtBlockLikeExpr(ExprKind kind,
- ArrayRef<Expr> exprs,
- ArrayRef<Type> types)
- : Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
- // Initialize with placement new.
- new (storage) detail::ExprStorage(kind, "", types, exprs, {});
-}
-
-static ArrayRef<Expr> getOneArgumentGroupStartingFrom(int start,
- ExprStorage *storage) {
- for (int i = start, e = storage->operands.size(); i < e; ++i) {
- if (!storage->operands[i])
- return storage->operands.slice(start, i - start);
- }
- return storage->operands.drop_front(start);
-}
-
-static SmallVector<ArrayRef<Expr>, 4>
-getAllArgumentGroupsStartingFrom(int start, ExprStorage *storage) {
- SmallVector<ArrayRef<Expr>, 4> groups;
- while (start < storage->operands.size()) {
- auto group = getOneArgumentGroupStartingFrom(start, storage);
- start += group.size() + 1;
- groups.push_back(group);
- }
- return groups;
-}
-
-ArrayRef<Expr> mlir::edsc::Expr::getProperArguments() const {
- return getOneArgumentGroupStartingFrom(0, storage);
-}
-
-SmallVector<ArrayRef<Expr>, 4> mlir::edsc::Expr::getSuccessorArguments() const {
- // Skip the first group containing proper arguments.
- // Note that +1 to size is necessary to step over the nullptrs in the list.
- int start = getOneArgumentGroupStartingFrom(0, storage).size() + 1;
- return getAllArgumentGroupsStartingFrom(start, storage);
-}
-
-ArrayRef<Expr> mlir::edsc::Expr::getSuccessorArguments(int index) const {
- assert(index >= 0 && "argument group index is out of bounds");
- assert(!storage->operands.empty() && "argument list is empty");
-
- // Skip over the first index + 1 groups (also includes proper arguments).
- int start = 0;
- for (int i = 0, e = index + 1; i < e; ++i) {
- assert(start < storage->operands.size() &&
- "argument group index is out of bounds");
- start += getOneArgumentGroupStartingFrom(start, storage).size() + 1;
- }
- return getOneArgumentGroupStartingFrom(start, storage);
-}
-
-SmallVector<ArrayRef<Expr>, 4> mlir::edsc::Expr::getAllArgumentGroups() const {
- return getAllArgumentGroupsStartingFrom(0, storage);
-}
-
-mlir::edsc::Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
- llvm::ArrayRef<Stmt> enclosedStmts) {
- storage = Expr::globalAllocator()->Allocate<detail::StmtStorage>();
- // Initialize with placement new.
- auto enclosedStmtStorage =
- Expr::globalAllocator()->Allocate<Stmt>(enclosedStmts.size());
- std::uninitialized_copy(enclosedStmts.begin(), enclosedStmts.end(),
- enclosedStmtStorage);
- new (storage) detail::StmtStorage{
- lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())};
-}
-
-// Statement with enclosed statements does not have a LHS.
-mlir::edsc::Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
- : Stmt(Bindable(Expr(Type())), rhs, enclosedStmts) {}
-
-edsc_stmt_t makeStmt(edsc_expr_t e) {
- assert(e && "unexpected empty expression");
- return Stmt(Expr(e));
-}
-
-Stmt &mlir::edsc::Stmt::operator=(const Expr &expr) {
- auto types = expr.getResultTypes();
- assert(types.size() == 1 && "single result Expr expected in Stmt::operator=");
- Stmt res(Bindable(Expr(types.front())), expr, {});
- std::swap(res.storage, this->storage);
- return *this;
-}
-
-Expr mlir::edsc::Stmt::getLHS() const {
- return static_cast<ImplType *>(storage)->lhs;
-}
-
-Expr mlir::edsc::Stmt::getRHS() const {
- return static_cast<ImplType *>(storage)->rhs;
-}
-
-llvm::ArrayRef<Stmt> mlir::edsc::Stmt::getEnclosedStmts() const {
- return storage->enclosedStmts;
-}
-
-void mlir::edsc::Stmt::print(raw_ostream &os, Twine indent) const {
- if (!storage) {
- os << "null_storage";
- return;
- }
- auto lhs = getLHS();
- auto rhs = getRHS();
-
- if (auto stmtExpr = rhs.dyn_cast<StmtBlockLikeExpr>()) {
- switch (stmtExpr.getKind()) {
- case ExprKind::For:
- os << indent << "for(" << lhs << " = " << stmtExpr << ") {";
- os << "\n";
- for (const auto &s : getEnclosedStmts()) {
- if (!s.getRHS().isa<StmtBlockLikeExpr>()) {
- os << indent << " ";
- }
- s.print(os, indent + " ");
- os << ";\n";
- }
- os << indent << "}";
- return;
- default: {
- // TODO(ntv): print more statement cases.
- os << "TODO";
- }
- }
- } else {
- os << lhs << " = " << rhs;
- }
-}
-
-void mlir::edsc::Stmt::dump() const { this->print(llvm::errs()); }
-
-std::string mlir::edsc::Stmt::str() const {
- std::string res;
- llvm::raw_string_ostream os(res);
- this->print(os);
- return res;
-}
-
-llvm::raw_ostream &mlir::edsc::operator<<(llvm::raw_ostream &os,
- const Stmt &stmt) {
- stmt.print(os);
- return os;
-}
-
-mlir::edsc::StmtBlock::StmtBlock(llvm::ArrayRef<Stmt> stmts)
- : StmtBlock({}, stmts) {}
-
-mlir::edsc::StmtBlock::StmtBlock(llvm::ArrayRef<Bindable> args,
- llvm::ArrayRef<Stmt> stmts) {
- // Extract block argument types from bindable types.
- // Bindables must have a single type.
- llvm::SmallVector<Type, 8> argTypes;
- argTypes.reserve(args.size());
- for (Bindable arg : args) {
- auto argResults = arg.getResultTypes();
- assert(argResults.size() == 1 &&
- "only single-result expressions are supported");
- argTypes.push_back(argResults.front());
- }
- storage = Expr::globalAllocator()->Allocate<detail::StmtBlockStorage>();
- new (storage) detail::StmtBlockStorage(args, argTypes, stmts);
-}
-
-mlir::edsc::StmtBlock &mlir::edsc::StmtBlock::operator=(ArrayRef<Stmt> stmts) {
- storage->replaceStmts(stmts);
- return *this;
-}
-
-ArrayRef<mlir::edsc::Bindable> mlir::edsc::StmtBlock::getArguments() const {
- return storage->arguments;
-}
-
-ArrayRef<Type> mlir::edsc::StmtBlock::getArgumentTypes() const {
- return storage->argumentTypes;
-}
-
-ArrayRef<mlir::edsc::Stmt> mlir::edsc::StmtBlock::getBody() const {
- return storage->statements;
-}
-
-uint64_t mlir::edsc::StmtBlock::getId() const { return storage->id; }
-
-void mlir::edsc::StmtBlock::print(llvm::raw_ostream &os, Twine indent) const {
- os << indent << "^bb" << storage->id;
- if (!getArgumentTypes().empty())
- os << '(';
- interleaveComma(getArguments(), os);
- if (!getArgumentTypes().empty())
- os << ')';
- os << ":\n";
- for (auto stmt : getBody()) {
- stmt.print(os, indent + " ");
- os << '\n';
- }
-}
-
-std::string mlir::edsc::StmtBlock::str() const {
- std::string result;
- llvm::raw_string_ostream os(result);
- print(os, "");
- return result;
-}
-
-Indexed mlir::edsc::Indexed::operator()(llvm::ArrayRef<Expr> indices) {
- Indexed res(base);
- res.indices = llvm::SmallVector<Expr, 4>(indices.begin(), indices.end());
- return res;
-}
-
-// NOLINTNEXTLINE: unconventional-assign-operator
-Stmt mlir::edsc::Indexed::operator=(Expr expr) {
- return Stmt(store(expr, base, indices));
-}
-
-edsc_indexed_t makeIndexed(edsc_expr_t expr) {
- return edsc_indexed_t{expr, edsc_expr_list_t{nullptr, 0}};
-}
-
-edsc_indexed_t index(edsc_indexed_t indexed, edsc_expr_list_t indices) {
- return edsc_indexed_t{indexed.base, indices};
-}
-
-MaxExpr::MaxExpr(ArrayRef<Expr> arguments) {
- storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
- new (storage) detail::ExprStorage(ExprKind::Variadic, "", {}, arguments, {});
-}
-
-ArrayRef<Expr> MaxExpr::getArguments() const { return storage->operands; }
-
-MinExpr::MinExpr(ArrayRef<Expr> arguments) {
- storage = Expr::globalAllocator()->Allocate<detail::ExprStorage>();
- new (storage) detail::ExprStorage(ExprKind::Variadic, "", {}, arguments, {});
-}
-
-ArrayRef<Expr> MinExpr::getArguments() const { return storage->operands; }
-
-mlir_type_t makeScalarType(mlir_context_t context, const char *name,
- unsigned bitwidth) {
- mlir::MLIRContext *c = reinterpret_cast<mlir::MLIRContext *>(context);
- mlir_type_t res =
- llvm::StringSwitch<mlir_type_t>(name)
- .Case("bf16",
- mlir_type_t{mlir::FloatType::getBF16(c).getAsOpaquePointer()})
- .Case("f16",
- mlir_type_t{mlir::FloatType::getF16(c).getAsOpaquePointer()})
- .Case("f32",
- mlir_type_t{mlir::FloatType::getF32(c).getAsOpaquePointer()})
- .Case("f64",
- mlir_type_t{mlir::FloatType::getF64(c).getAsOpaquePointer()})
- .Case("index",
- mlir_type_t{mlir::IndexType::get(c).getAsOpaquePointer()})
- .Case("i",
- mlir_type_t{
- mlir::IntegerType::get(bitwidth, c).getAsOpaquePointer()})
- .Default(mlir_type_t{nullptr});
- if (!res) {
- llvm_unreachable("Invalid type specifier");
- }
- return res;
-}
-
-mlir_type_t makeMemRefType(mlir_context_t context, mlir_type_t elemType,
- int64_list_t sizes) {
- auto t = mlir::MemRefType::get(
- llvm::ArrayRef<int64_t>(sizes.values, sizes.n),
- mlir::Type::getFromOpaquePointer(elemType),
- {mlir::AffineMap::getMultiDimIdentityMap(
- sizes.n, reinterpret_cast<mlir::MLIRContext *>(context))},
- 0);
- return mlir_type_t{t.getAsOpaquePointer()};
-}
-
-mlir_type_t makeFunctionType(mlir_context_t context, mlir_type_list_t inputs,
- mlir_type_list_t outputs) {
- llvm::SmallVector<mlir::Type, 8> ins(inputs.n), outs(outputs.n);
- for (unsigned i = 0; i < inputs.n; ++i) {
- ins[i] = mlir::Type::getFromOpaquePointer(inputs.types[i]);
- }
- for (unsigned i = 0; i < outputs.n; ++i) {
- outs[i] = mlir::Type::getFromOpaquePointer(outputs.types[i]);
- }
- auto ft = mlir::FunctionType::get(
- ins, outs, reinterpret_cast<mlir::MLIRContext *>(context));
- return mlir_type_t{ft.getAsOpaquePointer()};
-}
-
-mlir_type_t makeIndexType(mlir_context_t context) {
- auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
- auto type = mlir::IndexType::get(ctx);
- return mlir_type_t{type.getAsOpaquePointer()};
-}
-
-mlir_attr_t makeIntegerAttr(mlir_type_t type, int64_t value) {
- auto ty = Type::getFromOpaquePointer(reinterpret_cast<const void *>(type));
- auto attr = IntegerAttr::get(ty, value);
- return mlir_attr_t{attr.getAsOpaquePointer()};
-}
-
-mlir_attr_t makeBoolAttr(mlir_context_t context, bool value) {
- auto *ctx = reinterpret_cast<mlir::MLIRContext *>(context);
- auto attr = BoolAttr::get(value, ctx);
- return mlir_attr_t{attr.getAsOpaquePointer()};
-}
-
-unsigned getFunctionArity(mlir_func_t function) {
- auto *f = reinterpret_cast<mlir::Function *>(function);
- return f->getNumArguments();
-}
+++ /dev/null
-//===- APITest.cpp - Test for EDSC APIs -----------------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-// RUN: %p/api-test | FileCheck %s
-
-#include "mlir/AffineOps/AffineOps.h"
-#include "mlir/EDSC/MLIREmitter.h"
-#include "mlir/EDSC/Types.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/StandardOps/Ops.h"
-#include "mlir/Transforms/LoopUtils.h"
-
-#include "Test.h"
-
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-
-static MLIRContext &globalContext() {
- static thread_local MLIRContext context;
- return context;
-}
-
-static std::unique_ptr<Function> makeFunction(StringRef name,
- ArrayRef<Type> results = {},
- ArrayRef<Type> args = {}) {
- auto &ctx = globalContext();
- auto function = llvm::make_unique<Function>(
- UnknownLoc::get(&ctx), name, FunctionType::get(args, results, &ctx));
- function->addEntryBlock();
- return function;
-}
-
-// Inject a EDSC-constructed infinite loop implemented by mutual branching
-// between two blocks, following the pattern:
-//
-// br ^bb1
-// ^bb1:
-// br ^bb2
-// ^bb2:
-// br ^bb1
-//
-// Use blocks with arguments.
-TEST_FUNC(blocks) {
- using namespace edsc::op;
-
- auto f = makeFunction("blocks");
- FuncBuilder builder(f.get());
- edsc::ScopedEDSCContext context;
- // Declare two blocks. Note that we must declare the blocks before creating
- // branches to them.
- auto type = builder.getIntegerType(32);
- edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type), r(type);
- edsc::StmtBlock b1 = edsc::block({arg1, arg2}, {}),
- b2 = edsc::block({arg3, arg4}, {});
- auto c1 = edsc::constantInteger(type, 42);
- auto c2 = edsc::constantInteger(type, 1234);
-
- // Make an infinite loops by branching between the blocks. Note that copy-
- // assigning a block won't work well with branches, update the body instead.
- b1.set({r = arg1 + arg2, edsc::Branch(b2, {arg1, r})});
- b2.set({edsc::Branch(b1, {arg3, arg4})});
- auto op = edsc::Branch(b2, {c1, c2});
-
- // Emit a branch to b2. This should also emit blocks b2 and b1 that appear as
- // successors to the current block after the branch operation is insterted.
- edsc::MLIREmitter(&builder, f->getLoc()).emitStmt(op);
-
- // clang-format off
- // CHECK-LABEL: @blocks
- // CHECK: %c42_i32 = constant 42 : i32
- // CHECK-NEXT: %c1234_i32 = constant 1234 : i32
- // CHECK-NEXT: br ^bb1(%c42_i32, %c1234_i32 : i32, i32)
- // CHECK-NEXT: ^bb1(%0: i32, %1: i32): // 2 preds: ^bb0, ^bb2
- // CHECK-NEXT: br ^bb2(%0, %1 : i32, i32)
- // CHECK-NEXT: ^bb2(%2: i32, %3: i32): // pred: ^bb1
- // CHECK-NEXT: %4 = addi %2, %3 : i32
- // CHECK-NEXT: br ^bb1(%2, %4 : i32, i32)
- // CHECK-NEXT: }
- // clang-format on
- f->print(llvm::outs());
-}
-
-// Inject two EDSC-constructed blocks with arguments and a conditional branch
-// operation that transfers control to these blocks.
-TEST_FUNC(cond_branch) {
- auto f =
- makeFunction("cond_branch", {}, {IntegerType::get(1, &globalContext())});
-
- FuncBuilder builder(f.get());
- edsc::ScopedEDSCContext context;
- auto i1 = builder.getIntegerType(1);
- auto i32 = builder.getIntegerType(32);
- auto i64 = builder.getIntegerType(64);
- edsc::Expr arg1(i32), arg2(i64), arg3(i32);
- // Declare two blocks with different numbers of arguments.
- edsc::StmtBlock b1 = edsc::block({arg1}, {edsc::Return()}),
- b2 = edsc::block({arg2, arg3}, {edsc::Return()});
- edsc::Expr funcArg(i1);
-
- // Inject the conditional branch.
- auto condBranch = edsc::CondBranch(
- funcArg, b1, {edsc::constantInteger(i32, 32)}, b2,
- {edsc::constantInteger(i64, 64), edsc::constantInteger(i32, 42)});
-
- builder.setInsertionPoint(&*f->begin(), f->begin()->begin());
- edsc::MLIREmitter(&builder, f->getLoc())
- .bind(edsc::Bindable(funcArg), f->getArgument(0))
- .emitStmt(condBranch);
-
- // clang-format off
- // CHECK-LABEL: @cond_branch
- // CHECK: %c0 = constant 0 : index
- // CHECK-NEXT: %c1 = constant 1 : index
- // CHECK-NEXT: %c32_i32 = constant 32 : i32
- // CHECK-NEXT: %c64_i64 = constant 64 : i64
- // CHECK-NEXT: %c42_i32 = constant 42 : i32
- // CHECK-NEXT: cond_br %arg0, ^bb1(%c32_i32 : i32), ^bb2(%c64_i64, %c42_i32 : i64, i32)
- // CHECK-NEXT: ^bb1(%0: i32): // pred: ^bb0
- // CHECK-NEXT: return
- // CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
- // CHECK-NEXT: return
- // clang-format on
- f->print(llvm::outs());
-}
-
-// Inject a EDSC-constructed `affine.for` loop with bounds coming from function
-// arguments.
-TEST_FUNC(dynamic_for_func_args) {
- auto indexType = IndexType::get(&globalContext());
- auto f = makeFunction("dynamic_for_func_args", {}, {indexType, indexType});
- FuncBuilder builder(f.get());
-
- using namespace edsc::op;
- Type index = IndexType::get(f->getContext());
- edsc::ScopedEDSCContext context;
- edsc::Expr lb(index), ub(index), step(index);
- step = edsc::constantInteger(index, 3);
- auto loop = edsc::For(lb, ub, step, {lb * step + ub, step + lb});
- edsc::MLIREmitter(&builder, f->getLoc())
- .bind(edsc::Bindable(lb), f->getArgument(0))
- .bind(edsc::Bindable(ub), f->getArgument(1))
- .emitStmt(loop)
- .emitStmt(edsc::Return());
-
- // clang-format off
- // CHECK-LABEL: func @dynamic_for_func_args(%arg0: index, %arg1: index) {
- // CHECK: affine.for %i0 = (d0) -> (d0)(%arg0) to (d0) -> (d0)(%arg1) step 3 {
- // CHECK: {{.*}} = affine.apply ()[s0] -> (s0 * 3)()[%arg0]
- // CHECK: {{.*}} = affine.apply ()[s0, s1] -> (s1 + s0 * 3)()[%arg0, %arg1]
- // CHECK: {{.*}} = affine.apply ()[s0] -> (s0 + 3)()[%arg0]
- // clang-format on
- f->print(llvm::outs());
-}
-
-// Inject a EDSC-constructed `affine.for` loop with non-constant bounds that are
-// obtained from AffineApplyOp (also constructed using EDSC operator
-// overloads).
-TEST_FUNC(dynamic_for) {
- auto indexType = IndexType::get(&globalContext());
- auto f = makeFunction("dynamic_for", {},
- {indexType, indexType, indexType, indexType});
- FuncBuilder builder(f.get());
-
- edsc::ScopedEDSCContext context;
- edsc::Expr lb1(indexType), lb2(indexType), ub1(indexType), ub2(indexType),
- step(indexType);
- using namespace edsc::op;
- auto lb = lb1 - lb2;
- auto ub = ub1 + ub2;
- auto loop = edsc::For(lb, ub, step, {});
- edsc::MLIREmitter(&builder, f->getLoc())
- .bind(edsc::Bindable(lb1), f->getArgument(0))
- .bind(edsc::Bindable(lb2), f->getArgument(1))
- .bind(edsc::Bindable(ub1), f->getArgument(2))
- .bind(edsc::Bindable(ub2), f->getArgument(3))
- .bindConstant<ConstantIndexOp>(edsc::Bindable(step), 2)
- .emitStmt(loop);
-
- // clang-format off
- // CHECK-LABEL: func @dynamic_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
- // CHECK: %0 = affine.apply ()[s0, s1] -> (s0 - s1)()[%arg0, %arg1]
- // CHECK-NEXT: %1 = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3]
- // CHECK-NEXT: affine.for %i0 = (d0) -> (d0)(%0) to (d0) -> (d0)(%1) step 2 {
- // clang-format on
- f->print(llvm::outs());
-}
-
-// Inject a EDSC-constructed empty `affine.for` loop with max/min bounds that
-// corresponds to
-//
-// for max(%arg0, %arg1) to (%arg2, %arg3) step 1
-//
-TEST_FUNC(max_min_for) {
- auto indexType = IndexType::get(&globalContext());
- auto f = makeFunction("max_min_for", {},
- {indexType, indexType, indexType, indexType});
- FuncBuilder builder(f.get());
-
- edsc::ScopedEDSCContext context;
- edsc::Expr lb1(f->getArgument(0)->getType());
- edsc::Expr lb2(f->getArgument(1)->getType());
- edsc::Expr ub1(f->getArgument(2)->getType());
- edsc::Expr ub2(f->getArgument(3)->getType());
- edsc::Expr iv(builder.getIndexType());
- edsc::Expr step = edsc::constantInteger(builder.getIndexType(), 1);
- auto loop =
- edsc::MaxMinFor(edsc::Bindable(iv), {lb1, lb2}, {ub1, ub2}, step, {});
- edsc::MLIREmitter(&builder, f->getLoc())
- .bind(edsc::Bindable(lb1), f->getArgument(0))
- .bind(edsc::Bindable(lb2), f->getArgument(1))
- .bind(edsc::Bindable(ub1), f->getArgument(2))
- .bind(edsc::Bindable(ub2), f->getArgument(3))
- .emitStmt(loop);
-
- // clang-format off
- // CHECK-LABEL: func @max_min_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
- // CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) {
- // clang-format on
- f->print(llvm::outs());
-}
-
-// Inject EDSC-constructed chain of indirect calls that corresponds to
-//
-// @callee()
-// var x = @second_order_callee(@callee)
-// @callee_args(x, x)
-//
-TEST_FUNC(call_indirect) {
- auto indexType = IndexType::get(&globalContext());
- auto callee = makeFunction("callee");
- auto calleeArgs = makeFunction("callee_args", {}, {indexType, indexType});
- auto secondOrderCallee =
- makeFunction("second_order_callee",
- {FunctionType::get({}, {indexType}, &globalContext())},
- {FunctionType::get({}, {}, &globalContext())});
- auto f = makeFunction("call_indirect");
- FuncBuilder builder(f.get());
-
- auto funcRetIndexType = builder.getFunctionType({}, builder.getIndexType());
-
- edsc::ScopedEDSCContext context;
- edsc::Expr func(callee->getType()), funcArgs(calleeArgs->getType()),
- secondOrderFunc(secondOrderCallee->getType());
- auto stmt = edsc::call(func, {});
- auto chainedCallResult =
- edsc::call(edsc::call(secondOrderFunc, funcRetIndexType, {func}),
- builder.getIndexType(), {});
- auto argsCall = edsc::call(funcArgs, {chainedCallResult, chainedCallResult});
- edsc::MLIREmitter(&builder, f->getLoc())
- .bindConstant<ConstantOp>(edsc::Bindable(func),
- builder.getFunctionAttr(callee.get()))
- .bindConstant<ConstantOp>(edsc::Bindable(funcArgs),
- builder.getFunctionAttr(calleeArgs.get()))
- .bindConstant<ConstantOp>(
- edsc::Bindable(secondOrderFunc),
- builder.getFunctionAttr(secondOrderCallee.get()))
- .emitStmt(stmt)
- .emitStmt(chainedCallResult)
- .emitStmt(argsCall);
-
- // clang-format off
- // CHECK-LABEL: @call_indirect
- // CHECK: %f = constant @callee : () -> ()
- // CHECK: %f_0 = constant @callee_args : (index, index) -> ()
- // CHECK: %f_1 = constant @second_order_callee : (() -> ()) -> (() -> index)
- // CHECK: call_indirect %f() : () -> ()
- // CHECK: %0 = call_indirect %f_1(%f) : (() -> ()) -> (() -> index)
- // CHECK: %1 = call_indirect %0() : () -> index
- // CHECK: call_indirect %f_0(%1, %1) : (index, index) -> ()
- // clang-format on
- f->print(llvm::outs());
-}
-
-// Inject EDSC-constructed 1-D pointwise-add loop with assignments to scalars,
-// `dim` indicates the shape of the memref storing the values.
-static std::unique_ptr<Function> makeAssignmentsFunction(int dim) {
- auto memrefType =
- MemRefType::get({dim}, FloatType::getF32(&globalContext()), {}, 0);
- auto f =
- makeFunction("assignments", {}, {memrefType, memrefType, memrefType});
- FuncBuilder builder(f.get());
-
- edsc::ScopedEDSCContext context;
- edsc::MLIREmitter emitter(&builder, f->getLoc());
-
- edsc::Expr zero = emitter.zero();
- edsc::Expr one = emitter.one();
- auto args = emitter.makeBoundFunctionArguments(f.get());
- auto views = emitter.makeBoundMemRefViews(args.begin(), args.end());
-
- Type indexType = builder.getIndexType();
- edsc::Expr i(indexType);
- edsc::Expr A = args[0], B = args[1], C = args[2];
- edsc::Expr M = views[0].dim(0);
- // clang-format off
- using namespace edsc::op;
- edsc::Stmt scalarA, scalarB, tmp;
- auto block = edsc::block({
- For(i, zero, M, one, {
- scalarA = load(A, {i}),
- scalarB = load(B, {i}),
- tmp = scalarA * scalarB,
- store(tmp, C, {i})
- }),
- });
- // clang-format on
- emitter.emitStmts(block.getBody());
-
- return f;
-}
-
-TEST_FUNC(assignments_1) {
- auto f = makeAssignmentsFunction(4);
-
- // clang-format off
- // CHECK-LABEL: func @assignments(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
- // CHECK: affine.for %[[iv:.*]] = 0 to 4 {
- // CHECK: %[[a:.*]] = load %arg0[%[[iv]]] : memref<4xf32>
- // CHECK: %[[b:.*]] = load %arg1[%[[iv]]] : memref<4xf32>
- // CHECK: %[[tmp:.*]] = mulf %[[a]], %[[b]] : f32
- // CHECK: store %[[tmp]], %arg2[%[[iv]]] : memref<4xf32>
- // clang-format on
- f->print(llvm::outs());
-}
-
-TEST_FUNC(assignments_2) {
- auto f = makeAssignmentsFunction(-1);
-
- // clang-format off
- // CHECK-LABEL: func @assignments(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- // CHECK: affine.for %[[iv:.*]] = {{.*}} to {{.*}} {
- // CHECK: %[[a:.*]] = load %arg0[%[[iv]]] : memref<?xf32>
- // CHECK: %[[b:.*]] = load %arg1[%[[iv]]] : memref<?xf32>
- // CHECK: %[[tmp:.*]] = mulf %[[a]], %[[b]] : f32
- // CHECK: store %[[tmp]], %arg2[%[[iv]]] : memref<?xf32>
- // clang-format on
- f->print(llvm::outs());
-}
-
-int main() {
- RUN_TESTS();
- return 0;
-}