Add edsc::ops for pointwise, conv and dilated_conv
authorNicolas Vasilache <ntv@google.com>
Mon, 16 Dec 2019 21:32:02 +0000 (13:32 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 16 Dec 2019 21:42:38 +0000 (13:42 -0800)
This CL adds more Linalg EDSC ops and tests to support building pointwise operations along with conv and dilated_conv.
This also fixes a bug in the existing linalg_matmul EDSC and beefs up the test.

The current set of ops is already enough to build an interesting, albeit simple, model used internally.

PiperOrigin-RevId: 285838012

mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h [new file with mode: 0644]
mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
mlir/include/mlir/EDSC/Intrinsics.h
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
mlir/test/EDSC/builder-api-test.cpp

index 00da1d6..4213420 100644 (file)
 #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
 #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
 
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 
 namespace mlir {
 class BlockArgument;
-namespace edsc {
 
+namespace edsc {
 enum class IterType { Parallel, Reduction };
 
 inline StringRef toString(IterType t) {
@@ -38,7 +40,7 @@ inline StringRef toString(IterType t) {
   case IterType::Parallel:
     return getParallelIteratorTypeName();
   case IterType::Reduction:
-    return getParallelIteratorTypeName();
+    return getReductionIteratorTypeName();
   default:
     llvm_unreachable("Unsupport IterType");
   }
@@ -78,20 +80,83 @@ inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {}
 Operation *makeLinalgGenericOp(
     ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
-    decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder,
+    llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder =
+        defaultRegionBuilder,
     ArrayRef<Value *> otherValues = {},
     ArrayRef<Attribute> otherAttributes = {});
 
+namespace ops {
+using edsc::StructuredIndexed;
+using edsc::ValueHandle;
+using edsc::intrinsics::linalg_yield;
+
 //===----------------------------------------------------------------------===//
 // EDSC builders for linalg generic operations.
 //===----------------------------------------------------------------------===//
 
+/// Build the body of a region to compute a multiply-accumulate, under the
+/// current ScopedContext, at the current insert point.
+void macRegionBuilder(ArrayRef<BlockArgument *> args);
+
 /// TODO(ntv): In the future we should tie these implementations to something in
 /// Tablegen that generates the proper interfaces and the proper sugared named
 /// ops.
 
-/// Build a linalg.generic that represents C = A * B in the current
-/// ScopedContext.
+/// Build a linalg.pointwise, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (i0, ..., in) = (par, ..., par)
+///    |
+///    |  O...(some_subset...(i0, ..., in)) =
+///    |    some_pointwise_func...(I...(some_other_subset...(i0, ..., in)))
+/// ```
+///
+/// This is a very generic entry point that can be configured in many ways to
+/// build a perfect loop nest of parallel loops with arbitrarily complex
+/// innermost loop code and whatever (explicit) broadcast semantics.
+///
+/// This can be used with both out-of-place and in-place semantics.
+/// The client is responsible for ensuring the region operations are compatible
+/// with in-place semantics and parallelism.
+
+/// Unary pointwise operation (with broadcast) entry point.
+using UnaryPointwiseOpBuilder = llvm::function_ref<Value *(ValueHandle)>;
+Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
+                            StructuredIndexed I, StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = tanh(I)`. The client is responsible for specifying the proper
+/// indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
+
+/// Binary pointwise operation (with broadcast) entry point.
+using BinaryPointwiseOpBuilder =
+    llvm::function_ref<Value *(ValueHandle, ValueHandle)>;
+Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
+                            StructuredIndexed I1, StructuredIndexed I2,
+                            StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = I1 + I2`. The client is responsible for specifying the proper
+/// indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
+                                StructuredIndexed O);
+
+/// Build a linalg.pointwise with all `parallel` iterators and a region that
+/// computes `O = max(I!, I2)`. The client is responsible for specifying the
+/// proper indexings when creating the StructuredIndexed.
+Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
+                                StructuredIndexed O);
+
+// TODO(ntv): Implement more useful pointwise operations on a per-need basis.
+
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (m, n, k) = (par, par, seq)
+///    |
+///    |  C(m, n) += A(m, k) * B(k, n)
+/// ```
 Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
 
 template <typename Container> Operation *linalg_matmul(Container values) {
@@ -99,6 +164,76 @@ template <typename Container> Operation *linalg_matmul(Container values) {
   return linalg_matmul(values[0], values[1], values[2]);
 }
 
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (batch, f, [h, w, ...], [kh, kw, ...], c) =
+///    |  (par, par, [par, par, ...], [red, red, ...], red)
+///    |
+///    | O(batch, [h, w, ...], f) +=
+///    |   I(batch,
+///    |     [
+///    |       stride[0] * h + dilations[0] * kh,
+///    |       stride[1] * w + dilations[1] * kw, ...
+///          ],
+///    |     c)
+///    |   *
+///    |   W([kh, kw, ...], c, f)
+/// ```
+/// If `dilations` or `strides` are left empty, the default value of `1` is used
+/// along each relevant dimension.
+///
+/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
+///
+// TODO(ntv) Extend convolution rank with some template magic.
+Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO,
+                            ArrayRef<int> strides = {},
+                            ArrayRef<int> dilations = {});
+
+template <typename Container>
+Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
+                            ArrayRef<int> dilations = {}) {
+  assert(values.size() == 3 && "Expected exactly 3 values");
+  return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations);
+}
+
+/// Build a linalg.generic, under the current ScopedContext, at the current
+/// insert point, that computes:
+/// ```
+///    (batch, dm, c, [h, w, ...], [kh, kw, ...]) =
+///    |  (par, par, par, [par, par, ...], [red, red, ...])
+///    |
+///    | O(batch, [h, w, ...], c * depth_multiplier) +=
+///    |   I(batch,
+///    |     [
+///    |       stride[0] * h + dilations[0] * kh,
+///    |       stride[1] * w + dilations[1] * kw, ...
+///          ],
+///    |     c)
+///    |   *
+///    |   W([kh, kw, ...], c, depth_multiplier)
+/// ```
+/// If `dilations` or `strides` are left empty, the default value of `1` is used
+/// along each relevant dimension.
+///
+/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
+///
+// TODO(ntv) Extend convolution rank with some template magic.
+Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
+                                    ValueHandle vO, int depth_multiplier = 1,
+                                    ArrayRef<int> strides = {},
+                                    ArrayRef<int> dilations = {});
+
+template <typename Container>
+Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier,
+                                    ArrayRef<int> strides = {},
+                                    ArrayRef<int> dilations = {}) {
+  assert(values.size() == 3 && "Expected exactly 3 values");
+  return linalg_dilated_conv_nhwc(values[0], values[1], values[2],
+                                  depth_multiplier, strides, dilations);
+}
+
+} // namespace ops
 } // namespace edsc
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
new file mode 100644 (file)
index 0000000..f1acab6
--- /dev/null
@@ -0,0 +1,35 @@
+//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
+#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace edsc {
+namespace intrinsics {
+
+using linalg_fill = OperationBuilder<linalg::FillOp>;
+using linalg_yield = OperationBuilder<linalg::YieldOp>;
+
+} // namespace intrinsics
+} // namespace edsc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
index 4f9621c..1f24a90 100644 (file)
@@ -247,13 +247,13 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputs<1>, NOutputs<1>]> {
 }
 
 def FillOp : LinalgLibrary_Op<"fill", [NInputs<0>, NOutputs<1>]> {
-  let arguments = (ins AnyStridedMemRef:$input,
+  let arguments = (ins AnyStridedMemRef:$output,
                    AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value);
   let extraClassDeclaration = libraryCallName # [{
     ArrayAttr indexing_maps();
 
     ArrayAttr iterator_types() {
-      unsigned nPar = input()->getType().cast<ShapedType>().getRank();
+      unsigned nPar = output()->getType().cast<ShapedType>().getRank();
       MLIRContext *ctx = getContext();
       SmallVector<Attribute, 8> iters(
         nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
index 68bd210..6dbb343 100644 (file)
@@ -154,22 +154,22 @@ template <typename Op> struct ValueBuilder : public ValueHandle {
 
   /// Folder-based
   template <typename... Args>
-  ValueBuilder(OperationFolder &folder, Args... args)
+  ValueBuilder(OperationFolder *folder, Args... args)
       : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(args)...)) {}
-  ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs)
+  ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs)
       : ValueBuilder(ValueBuilder::create<Op>(folder, detail::unpack(vs))) {}
   template <typename... Args>
-  ValueBuilder(OperationFolder &folder, ArrayRef<ValueHandle> vs, Args... args)
+  ValueBuilder(OperationFolder *folder, ArrayRef<ValueHandle> vs, Args... args)
       : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(vs),
                                             detail::unpack(args)...)) {}
   template <typename T, typename... Args>
-  ValueBuilder(OperationFolder &folder, T t, ArrayRef<ValueHandle> vs,
+  ValueBuilder(OperationFolder *folder, T t, ArrayRef<ValueHandle> vs,
                Args... args)
       : ValueHandle(ValueHandle::create<Op>(folder, detail::unpack(t),
                                             detail::unpack(vs),
                                             detail::unpack(args)...)) {}
   template <typename T1, typename T2, typename... Args>
-  ValueBuilder(OperationFolder &folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
+  ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef<ValueHandle> vs,
                Args... args)
       : ValueHandle(ValueHandle::create<Op>(
             folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs),
@@ -200,6 +200,7 @@ template <typename Op> struct OperationBuilder : public OperationHandle {
   OperationBuilder() : OperationHandle(OperationHandle::create<Op>()) {}
 };
 
+using addf = ValueBuilder<AddFOp>;
 using affine_apply = ValueBuilder<AffineApplyOp>;
 using affine_if = OperationBuilder<AffineIfOp>;
 using affine_load = ValueBuilder<AffineLoadOp>;
@@ -212,11 +213,14 @@ using constant_int = ValueBuilder<ConstantIntOp>;
 using dealloc = OperationBuilder<DeallocOp>;
 using dim = ValueBuilder<DimOp>;
 using muli = ValueBuilder<MulIOp>;
+using mulf = ValueBuilder<MulFOp>;
+using memref_cast = ValueBuilder<MemRefCastOp>;
 using ret = OperationBuilder<ReturnOp>;
 using select = ValueBuilder<SelectOp>;
 using std_load = ValueBuilder<LoadOp>;
 using std_store = OperationBuilder<StoreOp>;
 using subi = ValueBuilder<SubIOp>;
+using tanh = ValueBuilder<TanhOp>;
 using view = ValueBuilder<ViewOp>;
 
 /// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
index 3daeafe..77e3a1e 100644 (file)
@@ -16,6 +16,7 @@
 // =============================================================================
 
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/EDSC/Builders.h"
 #include "mlir/EDSC/Intrinsics.h"
@@ -26,6 +27,7 @@
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
+using namespace mlir::edsc::ops;
 
 static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
                            unsigned &pos) {
@@ -42,24 +44,26 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
 Operation *mlir::edsc::makeLinalgGenericOp(
     ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
     ArrayRef<StructuredIndexed> outputs,
-    decltype(defaultRegionBuilder) regionBuilder, ArrayRef<Value *> otherValues,
-    ArrayRef<Attribute> otherAttributes) {
+    llvm::function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder,
+    ArrayRef<Value *> otherValues, ArrayRef<Attribute> otherAttributes) {
   auto &builder = edsc::ScopedContext::getBuilder();
   auto *ctx = builder.getContext();
   unsigned nInputs = inputs.size();
   unsigned nOutputs = outputs.size();
-  unsigned rank = 0;
-  getMaxDimIndex(inputs, rank);
-  getMaxDimIndex(outputs, rank);
+  unsigned maxPos = 0;
+  getMaxDimIndex(inputs, maxPos);
+  getMaxDimIndex(outputs, maxPos);
+  // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
+  unsigned nDims = maxPos + 1;
 
   SmallVector<AffineMap, 4> maps;
   maps.reserve(nInputs + nOutputs);
   for (auto in : inputs)
     maps.push_back(
-        AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs()));
+        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
   for (auto out : outputs)
     maps.push_back(
-        AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs()));
+        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
 
   unsigned nViews = nInputs + nOutputs;
   SmallVector<Value *, 4> values;
@@ -105,23 +109,148 @@ Operation *mlir::edsc::makeLinalgGenericOp(
   return op;
 }
 
-using linalg_yield = OperationBuilder<linalg::YieldOp>;
+void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument *> args) {
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  assert(args.size() == 3 && "expected 3 block arguments");
+  ValueHandle a(args[0]), b(args[1]), c(args[2]);
+  linalg_yield((c + a * b).getValue());
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
+                                             StructuredIndexed I,
+                                             StructuredIndexed O) {
+  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
+                                           edsc::IterType::Parallel);
+  auto fun = [&unaryOp](ArrayRef<BlockArgument *> args) {
+    assert(args.size() == 2 && "expected 2 block arguments");
+    ValueHandle a(args[0]);
+    linalg_yield(unaryOp(a));
+  };
+  return makeLinalgGenericOp(iterTypes, {I}, {O}, fun);
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
+                                                  StructuredIndexed O) {
+  ;
+  using edsc::intrinsics::tanh;
+  UnaryPointwiseOpBuilder unOp(
+      [](ValueHandle a) -> Value * { return tanh(a); });
+  return linalg_pointwise(unOp, I, O);
+}
+
+/// Binary pointwise operation (with broadcast) entry point.
+Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
+                                             StructuredIndexed I1,
+                                             StructuredIndexed I2,
+                                             StructuredIndexed O) {
+  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
+                                           edsc::IterType::Parallel);
+  auto fun = [&binaryOp](ArrayRef<BlockArgument *> args) {
+    assert(args.size() == 3 && "expected 3 block arguments");
+    ValueHandle a(args[0]), b(args[1]);
+    linalg_yield(binaryOp(a, b));
+  };
+  return makeLinalgGenericOp(iterTypes, {I1, I2}, {O}, fun);
+}
 
-Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                     ValueHandle vC) {
+Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
+                                                 StructuredIndexed I2,
+                                                 StructuredIndexed O) {
+  using edsc::op::operator+;
+  BinaryPointwiseOpBuilder binOp(
+      [](ValueHandle a, ValueHandle b) -> Value * { return a + b; });
+  return linalg_pointwise(binOp, I1, I2, O);
+}
+
+Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
+                                                 StructuredIndexed I2,
+                                                 StructuredIndexed O) {
+  BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value * {
+    using edsc::intrinsics::select;
+    using edsc::op::operator>;
+    return select(a > b, a, b).getValue();
+  });
+  return linalg_pointwise(binOp, I1, I2, O);
+}
+
+Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
+                                          ValueHandle vC) {
   // clang-format off
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(vC);
   return makeLinalgGenericOp(
     {IterType::Parallel, IterType::Parallel, IterType::Reduction},
-    {A({m, n}), B({k, n})},
+    {A({m, k}), B({k, n})},
     {C({m, n})},
-    [](ArrayRef<BlockArgument *> args) {
-      using edsc::op::operator*;
-      using edsc::op::operator+;
-      ValueHandle a(args[0]), b(args[1]), c(args[2]);
-      linalg_yield((c + a * b).getValue());
-  });
+    macRegionBuilder);
+  // clang-format on
+}
+
+Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
+                                             ValueHandle vO,
+                                             ArrayRef<int> strides,
+                                             ArrayRef<int> dilations) {
+  MLIRContext *ctx = ScopedContext::getContext();
+  // TODO(ntv) some template magic to make everything rank-polymorphic.
+  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
+  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
+
+  // Some short names.
+  auto par = IterType::Parallel;
+  auto red = IterType::Reduction;
+  auto s = strides;
+  auto d = dilations;
+
+  AffineExpr b, f, h, w, kh, kw, c;
+  bindDims(ctx, b, f, h, w, kh, kw, c);
+  unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
+  StructuredIndexed I(vI), W(vW), O(vO);
+  // clang-format off
+  return makeLinalgGenericOp(
+    {par, par, par, par, red, red, red}, {
+      I({b,
+         // Roundtrip to flattened form to serve as canonicalization and ensure
+         // consistent ordering of subexpressions.
+         simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
+         simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
+         c}),
+      W({kh, kw, c, f})}, {
+      O({b, h, w, f})},
+    macRegionBuilder);
+  // clang-format on
+}
+
+Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
+    ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
+    ArrayRef<int> strides, ArrayRef<int> dilations) {
+  MLIRContext *ctx = ScopedContext::getContext();
+  // TODO(ntv) some template magic to make everything rank-polymorphic.
+  assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm");
+  assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm");
+
+  // Some short names.
+  auto par = IterType::Parallel;
+  auto red = IterType::Reduction;
+  auto s = strides;
+  auto d = dilations;
+
+  // clang-format off
+  AffineExpr b, dm, c, h, w, kh, kw;
+  bindDims(ctx, b, dm, c, h, w, kh, kw);
+  unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
+  StructuredIndexed I(vI), W(vW), O(vO);
+  return makeLinalgGenericOp(
+    {par, par, par, par, par, red, red}, {
+      I({b,
+         // Roundtrip to flattened form to serve as canonicalization and ensure
+         // consistent ordering of subexpressions.
+         simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
+         simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
+         c}),
+      W({kh, kw, c, dm})}, {
+      O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
+    macRegionBuilder);
   // clang-format on
 }
index abd1eb0..81bb0b9 100644 (file)
@@ -811,16 +811,61 @@ TEST_FUNC(affine_if_op) {
 }
 
 // clang-format off
+// CHECK-LABEL: func @linalg_pointwise
+//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [(d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1)],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+//       CHECK:       addf
+//       CHECK:     }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [(d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1)],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+//       CHECK:       cmpf "ogt"
+//       CHECK:       select
+//       CHECK:   }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+//       CHECK:   linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
+// CHECK-SAME:      indexing_maps = [(d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1)],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel"]}
+//       CHECK:     tanh
+//       CHECK:   }: memref<?x?xf32>, memref<?x?xf32>
+// clang-format on
+TEST_FUNC(linalg_pointwise_test) {
+  using namespace edsc;
+  using namespace edsc::ops;
+
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+  auto f = makeFunction("linalg_pointwise", {},
+                        {memrefType, memrefType, memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
+  AffineExpr i, j;
+  bindDims(&globalContext(), i, j);
+  StructuredIndexed SA(A), SB(B), SC(C);
+  linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
+  linalg_pointwise_tanh(SA({i, j}), SC({i, j}));
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
+// clang-format off
 // CHECK-LABEL: func @linalg_matmul
-//       CHECK:   linalg.generic
+//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [(d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1), (d0, d1, d2) -> (d0, d1)],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
 ///      CHECK:   ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
 //       CHECK:     %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
 //       CHECK:     %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
 //       CHECK:     linalg.yield %[[a4]] : f32
 //       CHECK:   }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
 // clang-format on
-TEST_FUNC(linalg_matmul) {
+TEST_FUNC(linalg_matmul_test) {
   using namespace edsc;
+  using namespace edsc::ops;
 
   auto f32Type = FloatType::getF32(&globalContext());
   auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
@@ -835,6 +880,70 @@ TEST_FUNC(linalg_matmul) {
   f.erase();
 }
 
+// clang-format off
+// CHECK-LABEL: func @linalg_conv_nhwc
+//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6),
+// CHECK-SAME: (d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1),
+// CHECK-SAME: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
+///      CHECK:   ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
+//       CHECK:     %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
+//       CHECK:     %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
+//       CHECK:     linalg.yield %[[a4]] : f32
+//       CHECK:   }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+// clang-format on
+TEST_FUNC(linalg_conv_nhwc) {
+  using namespace edsc;
+  using namespace edsc::ops;
+
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({-1, -1, -1, -1}, f32Type, {}, 0);
+  auto f = makeFunction("linalg_conv_nhwc", {},
+                        {memrefType, memrefType, memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
+                   /*strides=*/{3, 4}, /*dilations=*/{5, 6});
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
+// clang-format off
+// CHECK-LABEL: func @linalg_dilated_conv_nhwc
+//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3 * 3 + d5 * 5, d4 * 4 + d6 * 6, d2),
+// CHECK-SAME: (d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1),
+// CHECK-SAME: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+///      CHECK:   ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
+//       CHECK:     %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
+//       CHECK:     %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
+//       CHECK:     linalg.yield %[[a4]] : f32
+//       CHECK:   }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+// clang-format on
+TEST_FUNC(linalg_dilated_conv_nhwc) {
+  using namespace edsc;
+  using namespace edsc::ops;
+
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({-1, -1, -1, -1}, f32Type, {}, 0);
+  auto f = makeFunction("linalg_dilated_conv_nhwc", {},
+                        {memrefType, memrefType, memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  linalg_dilated_conv_nhwc(
+      makeValueHandles(llvm::to_vector<3>(f.getArguments())),
+      /*depth_multiplier=*/7,
+      /*strides=*/{3, 4}, /*dilations=*/{5, 6});
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;