using namespace mlir::edsc;
using namespace mlir::edsc::detail;
-// Factors out the boilerplate that is needed to build and answer the
-// following simple question:
-// Given a set of Value* `values`, how do I get the resulting op(`values`)
-//
-// This is a very loaded question and generally cannot be answered properly.
-// For instance, an LLVM operation has many attributes that may not fit within
-// this simplistic framing (e.g. overflow behavior etc).
-//
-// Still, MLIR is a higher-level IR and the Halide experience shows it is
-// possible to build useful EDSCs with the right amount of sugar.
-//
-// To build EDSCs we need to be able to conveniently support simple operations
-// such as `add` on the type system. This captures the possible behaviors. In
-// the future, this should be automatically constructed from an abstraction
-// that is common to the IR verifier, but for now we need to get off the ground
-// manually.
-//
-// This is expected to be a "dialect-specific" functionality: certain dialects
-// will not have a simple definition. Two such cases that come to mind are:
-// 1. what does it mean to have an operator* on an opaque tensor dialect
-// (dot, vector, hadamard, kronecker ?)-product;
-// 2. LLVM add with attributes like overflow.
-// This is all left for future consideration; in the meantime let's separate
-// concerns and implement useful infrastructure without solving all problems at
-// once.
-
-/// Returns the element type if the type is VectorType or MemRefType; returns
-/// getType if the type is scalar.
-static Type getElementType(const Value &v) {
- if (auto vec = v.getType().dyn_cast<mlir::VectorType>()) {
- return vec.getElementType();
- }
- if (auto mem = v.getType().dyn_cast<mlir::MemRefType>()) {
- return mem.getElementType();
- }
- return v.getType();
-}
-
-static bool isIndexElement(const Value &v) {
- return getElementType(v).isIndex();
-}
-static bool isIntElement(const Value &v) {
- return getElementType(v).isa<IntegerType>();
-}
-static bool isFloatElement(const Value &v) {
- return getElementType(v).isa<FloatType>();
-}
-
-static Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) {
- if (isIndexElement(*a)) {
- auto *context = builder->getContext();
- auto d0 = getAffineDimExpr(0, context);
- auto d1 = getAffineDimExpr(1, context);
- auto map = AffineMap::get(2, 0, {d0 + d1}, {});
- return makeComposedAffineApply(builder, location, map, {a, b});
- } else if (isIntElement(*a)) {
- return builder->create<AddIOp>(location, a, b)->getResult();
- }
- assert(isFloatElement(*a) && "Expected float element");
- return builder->create<AddFOp>(location, a, b)->getResult();
-}
-
-static Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) {
- if (isIndexElement(*a)) {
- auto *context = builder->getContext();
- auto d0 = getAffineDimExpr(0, context);
- auto d1 = getAffineDimExpr(1, context);
- auto map = AffineMap::get(2, 0, {d0 - d1}, {});
- return makeComposedAffineApply(builder, location, map, {a, b});
- } else if (isIntElement(*a)) {
- return builder->create<SubIOp>(location, a, b)->getResult();
- }
- assert(isFloatElement(*a) && "Expected float element");
- return builder->create<SubFOp>(location, a, b)->getResult();
-}
-
-static Value *mul(FuncBuilder *builder, Location location, Value *a, Value *b) {
- if (!isFloatElement(*a)) {
- return builder->create<MulIOp>(location, a, b)->getResult();
- }
- assert(isFloatElement(*a) && "Expected float element");
- return builder->create<MulFOp>(location, a, b)->getResult();
-}
-
static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
const auto *inst = v.getDefiningInst();
if (inst) {