#ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
#define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
namespace mlir {
class BlockArgument;
namespace edsc {
+enum class IterType { Parallel, Reduction };
+
+inline StringRef toString(IterType t) {
+ switch (t) {
+ case IterType::Parallel:
+ return getParallelIteratorTypeName();
+ case IterType::Reduction:
+ return getParallelIteratorTypeName();
+ default:
+ llvm_unreachable("Unsupport IterType");
+ }
+}
+
+/// A StructuredIndexed represents a captured value that can be indexed and
+/// passed to the `makeLinalgGenericOp`. It allows writing intuitive index
+/// expressions such as:
+///
+/// ```
+/// StructuredIndexed A(vA), B(vB), C(vC);
+/// makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
+/// ```
+struct StructuredIndexed {
+ StructuredIndexed(Value *v) : value(v) {}
+ StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
+ return StructuredIndexed(value, indexings);
+ }
+
+ operator Value *() const /* implicit */ { return value; }
+ ArrayRef<AffineExpr> getExprs() { return exprs; }
+
+private:
+ StructuredIndexed(Value *v, ArrayRef<AffineExpr> indexings)
+ : value(v), exprs(indexings.begin(), indexings.end()) {
+ assert(v->getType().isa<MemRefType>() && "MemRefType expected");
+ }
+ StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
+ : StructuredIndexed(v.getValue(), indexings) {}
+
+ Value *value;
+ SmallVector<AffineExpr, 4> exprs;
+};
+
inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {}
-/// EDSC entry point to build linalg.generic operations programmatically.
Operation *makeLinalgGenericOp(
- ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions,
- ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews,
- ArrayRef<StringRef> iteratorTypes,
- decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder);
+ ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+ ArrayRef<StructuredIndexed> outputs,
+ decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder,
+ ArrayRef<Value *> otherValues = {},
+ ArrayRef<Attribute> otherAttributes = {});
+
+//===----------------------------------------------------------------------===//
+// EDSC builders for linalg generic operations.
+//===----------------------------------------------------------------------===//
+
+/// TODO(ntv): In the future we should tie these implementations to something in
+/// Tablegen that generates the proper interfaces and the proper sugared named
+/// ops.
+
+/// Build a linalg.generic that represents C = A * B in the current
+/// ScopedContext.
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
+
+template <typename Container> Operation *linalg_matmul(Container values) {
+ assert(values.size() == 3 && "Expected exactly 3 values");
+ return linalg_matmul(values[0], values[1], values[2]);
+}
} // namespace edsc
} // namespace mlir
// limitations under the License.
// =============================================================================
-#include "mlir/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/Functional.h"
using namespace mlir;
using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+
+static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
+ unsigned &pos) {
+ for (auto sidx : structuredIndices) {
+ for (auto expr : sidx.getExprs()) {
+ expr.walk([&pos](AffineExpr e) {
+ if (auto d = e.dyn_cast<AffineDimExpr>())
+ pos = std::max(pos, d.getPosition());
+ });
+ }
+ }
+}
Operation *mlir::edsc::makeLinalgGenericOp(
- ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions,
- ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews,
- ArrayRef<StringRef> iteratorTypes,
- decltype(defaultRegionBuilder) regionBuilder) {
+ ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+ ArrayRef<StructuredIndexed> outputs,
+ decltype(defaultRegionBuilder) regionBuilder, ArrayRef<Value *> otherValues,
+ ArrayRef<Attribute> otherAttributes) {
auto &builder = edsc::ScopedContext::getBuilder();
auto *ctx = builder.getContext();
+ unsigned nInputs = inputs.size();
+ unsigned nOutputs = outputs.size();
+ unsigned rank = 0;
+ getMaxDimIndex(inputs, rank);
+ getMaxDimIndex(outputs, rank);
SmallVector<AffineMap, 4> maps;
- maps.reserve(mapExpressions.size());
- for (auto exprs : mapExpressions)
- maps.push_back(AffineMap::get(indices.size(), 0, exprs));
+ maps.reserve(nInputs + nOutputs);
+ for (auto in : inputs)
+ maps.push_back(
+ AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs()));
+ for (auto out : outputs)
+ maps.push_back(
+ AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs()));
- SmallVector<Value *, 4> views;
- views.reserve(inputViews.size() + outputViews.size());
- views.append(inputViews.begin(), inputViews.end());
- views.append(outputViews.begin(), outputViews.end());
+ unsigned nViews = nInputs + nOutputs;
+ SmallVector<Value *, 4> values;
+ values.reserve(nViews);
+ values.append(inputs.begin(), inputs.end());
+ values.append(outputs.begin(), outputs.end());
+ auto iteratorStrTypes = functional::map(toString, iteratorTypes);
+ // clang-format off
auto *op =
edsc::ScopedContext::getBuilder()
.create<linalg::GenericOp>(
- edsc::ScopedContext::getLocation(), views,
- IntegerAttr::get(IntegerType::get(64, ctx), inputViews.size()),
- IntegerAttr::get(IntegerType::get(64, ctx), outputViews.size()),
+ edsc::ScopedContext::getLocation(),
+ values,
+ IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
+ IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
builder.getAffineMapArrayAttr(maps),
- builder.getStrArrayAttr(iteratorTypes), StringAttr() /*doc*/,
- FlatSymbolRefAttr() /*fun*/, StringAttr() /*library_call*/
+ builder.getStrArrayAttr(iteratorStrTypes),
+ StringAttr() /*doc*/,
+ FlatSymbolRefAttr() /*fun*/,
+ StringAttr() /*library_call*/
+ /* TODO: other attributes in op */
)
.getOperation();
+ // clang-format on
using namespace edsc;
SmallVector<Type, 4> blockTypes;
- blockTypes.reserve(views.size());
- for (auto *v : views)
- blockTypes.push_back(getElementTypeOrSelf(v));
+ blockTypes.reserve(values.size());
+ for (auto it : llvm::enumerate(values))
+ blockTypes.push_back((it.index() < nViews)
+ ? getElementTypeOrSelf(it.value())
+ : it.value()->getType());
assert(op->getRegions().front().empty());
op->getRegions().front().push_front(new Block);
[&] { regionBuilder(b.getBlock()->getArguments()); });
return op;
}
+
+using linalg_yield = OperationBuilder<linalg::YieldOp>;
+
+Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB,
+ ValueHandle vC) {
+ // clang-format off
+ AffineExpr m, n, k;
+ bindDims(ScopedContext::getContext(), m, n, k);
+ StructuredIndexed A(vA), B(vB), C(vC);
+ return makeLinalgGenericOp(
+ {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+ {A({m, n}), B({k, n})},
+ {C({m, n})},
+ [](ArrayRef<BlockArgument *> args) {
+ using edsc::op::operator*;
+ using edsc::op::operator+;
+ ValueHandle a(args[0]), b(args[1]), c(args[2]);
+ linalg_yield((c + a * b).getValue());
+ });
+ // clang-format on
+}
// clang-format on
TEST_FUNC(linalg_matmul) {
using namespace edsc;
- using namespace edsc::intrinsics;
- using namespace edsc::op;
- using linalg_yield = OperationBuilder<linalg::YieldOp>;
auto f32Type = FloatType::getF32(&globalContext());
auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
auto f =
makeFunction("linalg_matmul", {}, {memrefType, memrefType, memrefType});
- // clang-format off
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
- Value *A(f.getArgument(0)), *B(f.getArgument(1)), *C(f.getArgument(2));
- AffineExpr m, n, k;
- bindDims(f.getContext(), m, n, k);
- makeLinalgGenericOp(
- {m, n, k},
- {{m, n}, {k, n}, {m, n}},
- {A, B},
- {C},
- {"parallel", "parallel", "reduction"},
- [](ArrayRef<BlockArgument *> args) {
- ValueHandle a(args[0]), b(args[1]), c(args[2]);
- linalg_yield((c + a * b).getValue());
- });
- // clang-format on
+ linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
f.print(llvm::outs());
f.erase();