Apply a level of sugaring to the linalg.generic EDSC - NFC
authorNicolas Vasilache <ntv@google.com>
Sat, 14 Dec 2019 00:35:49 +0000 (16:35 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 14 Dec 2019 01:39:46 +0000 (17:39 -0800)
Make the declarative C++ builder API simpler to use so we can start chaining these ops together.

PiperOrigin-RevId: 285496266

mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
mlir/test/EDSC/builder-api-test.cpp

index 3618ec1..00da1d6 100644 (file)
 #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
 #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
 
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 
 namespace mlir {
 class BlockArgument;
 namespace edsc {
 
+enum class IterType { Parallel, Reduction };
+
+inline StringRef toString(IterType t) {
+  switch (t) {
+  case IterType::Parallel:
+    return getParallelIteratorTypeName();
+  case IterType::Reduction:
+    return getParallelIteratorTypeName();
+  default:
+    llvm_unreachable("Unsupport IterType");
+  }
+}
+
+/// A StructuredIndexed represents a captured value that can be indexed and
+/// passed to the `makeLinalgGenericOp`. It allows writing intuitive index
+/// expressions such as:
+///
+/// ```
+///      StructuredIndexed A(vA), B(vB), C(vC);
+///      makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
+/// ```
+struct StructuredIndexed {
+  StructuredIndexed(Value *v) : value(v) {}
+  StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
+    return StructuredIndexed(value, indexings);
+  }
+
+  operator Value *() const /* implicit */ { return value; }
+  ArrayRef<AffineExpr> getExprs() { return exprs; }
+
+private:
+  StructuredIndexed(Value *v, ArrayRef<AffineExpr> indexings)
+      : value(v), exprs(indexings.begin(), indexings.end()) {
+    assert(v->getType().isa<MemRefType>() && "MemRefType expected");
+  }
+  StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
+      : StructuredIndexed(v.getValue(), indexings) {}
+
+  Value *value;
+  SmallVector<AffineExpr, 4> exprs;
+};
+
 inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {}
 
-/// EDSC entry point to build linalg.generic operations programmatically.
 Operation *makeLinalgGenericOp(
-    ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions,
-    ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews,
-    ArrayRef<StringRef> iteratorTypes,
-    decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder);
+    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+    ArrayRef<StructuredIndexed> outputs,
+    decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder,
+    ArrayRef<Value *> otherValues = {},
+    ArrayRef<Attribute> otherAttributes = {});
+
+//===----------------------------------------------------------------------===//
+// EDSC builders for linalg generic operations.
+//===----------------------------------------------------------------------===//
+
+/// TODO(ntv): In the future we should tie these implementations to something in
+/// Tablegen that generates the proper interfaces and the proper sugared named
+/// ops.
+
+/// Build a linalg.generic that represents C = A * B in the current
+/// ScopedContext.
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
+
+template <typename Container> Operation *linalg_matmul(Container values) {
+  assert(values.size() == 3 && "Expected exactly 3 values");
+  return linalg_matmul(values[0], values[1], values[2]);
+}
 
 } // namespace edsc
 } // namespace mlir
index 606160b..3daeafe 100644 (file)
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Support/Functional.h"
 
 using namespace mlir;
 using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+
+static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
+                           unsigned &pos) {
+  for (auto sidx : structuredIndices) {
+    for (auto expr : sidx.getExprs()) {
+      expr.walk([&pos](AffineExpr e) {
+        if (auto d = e.dyn_cast<AffineDimExpr>())
+          pos = std::max(pos, d.getPosition());
+      });
+    }
+  }
+}
 
 Operation *mlir::edsc::makeLinalgGenericOp(
-    ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions,
-    ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews,
-    ArrayRef<StringRef> iteratorTypes,
-    decltype(defaultRegionBuilder) regionBuilder) {
+    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+    ArrayRef<StructuredIndexed> outputs,
+    decltype(defaultRegionBuilder) regionBuilder, ArrayRef<Value *> otherValues,
+    ArrayRef<Attribute> otherAttributes) {
   auto &builder = edsc::ScopedContext::getBuilder();
   auto *ctx = builder.getContext();
+  unsigned nInputs = inputs.size();
+  unsigned nOutputs = outputs.size();
+  unsigned rank = 0;
+  getMaxDimIndex(inputs, rank);
+  getMaxDimIndex(outputs, rank);
 
   SmallVector<AffineMap, 4> maps;
-  maps.reserve(mapExpressions.size());
-  for (auto exprs : mapExpressions)
-    maps.push_back(AffineMap::get(indices.size(), 0, exprs));
+  maps.reserve(nInputs + nOutputs);
+  for (auto in : inputs)
+    maps.push_back(
+        AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs()));
+  for (auto out : outputs)
+    maps.push_back(
+        AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs()));
 
-  SmallVector<Value *, 4> views;
-  views.reserve(inputViews.size() + outputViews.size());
-  views.append(inputViews.begin(), inputViews.end());
-  views.append(outputViews.begin(), outputViews.end());
+  unsigned nViews = nInputs + nOutputs;
+  SmallVector<Value *, 4> values;
+  values.reserve(nViews);
+  values.append(inputs.begin(), inputs.end());
+  values.append(outputs.begin(), outputs.end());
 
+  auto iteratorStrTypes = functional::map(toString, iteratorTypes);
+  // clang-format off
   auto *op =
       edsc::ScopedContext::getBuilder()
           .create<linalg::GenericOp>(
-              edsc::ScopedContext::getLocation(), views,
-              IntegerAttr::get(IntegerType::get(64, ctx), inputViews.size()),
-              IntegerAttr::get(IntegerType::get(64, ctx), outputViews.size()),
+              edsc::ScopedContext::getLocation(),
+              values,
+              IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
+              IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
               builder.getAffineMapArrayAttr(maps),
-              builder.getStrArrayAttr(iteratorTypes), StringAttr() /*doc*/,
-              FlatSymbolRefAttr() /*fun*/, StringAttr() /*library_call*/
+              builder.getStrArrayAttr(iteratorStrTypes),
+              StringAttr() /*doc*/,
+              FlatSymbolRefAttr() /*fun*/,
+              StringAttr() /*library_call*/
+              /* TODO: other attributes in op */
               )
           .getOperation();
+  // clang-format on
 
   using namespace edsc;
   SmallVector<Type, 4> blockTypes;
-  blockTypes.reserve(views.size());
-  for (auto *v : views)
-    blockTypes.push_back(getElementTypeOrSelf(v));
+  blockTypes.reserve(values.size());
+  for (auto it : llvm::enumerate(values))
+    blockTypes.push_back((it.index() < nViews)
+                             ? getElementTypeOrSelf(it.value())
+                             : it.value()->getType());
 
   assert(op->getRegions().front().empty());
   op->getRegions().front().push_front(new Block);
@@ -70,3 +104,24 @@ Operation *mlir::edsc::makeLinalgGenericOp(
       [&] { regionBuilder(b.getBlock()->getArguments()); });
   return op;
 }
+
+using linalg_yield = OperationBuilder<linalg::YieldOp>;
+
+Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB,
+                                     ValueHandle vC) {
+  // clang-format off
+  AffineExpr m, n, k;
+  bindDims(ScopedContext::getContext(), m, n, k);
+  StructuredIndexed A(vA), B(vB), C(vC);
+  return makeLinalgGenericOp(
+    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
+    {A({m, n}), B({k, n})},
+    {C({m, n})},
+    [](ArrayRef<BlockArgument *> args) {
+      using edsc::op::operator*;
+      using edsc::op::operator+;
+      ValueHandle a(args[0]), b(args[1]), c(args[2]);
+      linalg_yield((c + a * b).getValue());
+  });
+  // clang-format on
+}
index dc17305..abd1eb0 100644 (file)
@@ -821,32 +821,15 @@ TEST_FUNC(affine_if_op) {
 // clang-format on
 TEST_FUNC(linalg_matmul) {
   using namespace edsc;
-  using namespace edsc::intrinsics;
-  using namespace edsc::op;
-  using linalg_yield = OperationBuilder<linalg::YieldOp>;
 
   auto f32Type = FloatType::getF32(&globalContext());
   auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
   auto f =
       makeFunction("linalg_matmul", {}, {memrefType, memrefType, memrefType});
 
-  // clang-format off
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
-  Value *A(f.getArgument(0)), *B(f.getArgument(1)), *C(f.getArgument(2));
-  AffineExpr m, n, k;
-  bindDims(f.getContext(), m, n, k);
-  makeLinalgGenericOp(
-    {m, n, k},
-    {{m, n}, {k, n}, {m, n}},
-    {A, B},
-    {C},
-    {"parallel", "parallel", "reduction"},
-    [](ArrayRef<BlockArgument *> args) {
-      ValueHandle a(args[0]), b(args[1]), c(args[2]);
-      linalg_yield((c + a * b).getValue());
-  });
-  // clang-format on
+  linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
 
   f.print(llvm::outs());
   f.erase();