From 3c179b657583c4098d189a475d85f39ff230d924 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 16 Dec 2019 13:32:02 -0800 Subject: [PATCH] Add edsc::ops for pointwise, conv and dilated_conv This CL adds more Linalg EDSC ops and tests to support building pointwise operations along with conv and dilated_conv. This also fixes a bug in the existing linalg_matmul EDSC and beefs up the test. The current set of ops is already enough to build an interesting, albeit simple, model used internally. PiperOrigin-RevId: 285838012 --- mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h | 145 +++++++++++++++++- mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h | 35 +++++ .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 4 +- mlir/include/mlir/EDSC/Intrinsics.h | 14 +- mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 163 ++++++++++++++++++--- mlir/test/EDSC/builder-api-test.cpp | 113 +++++++++++++- 6 files changed, 443 insertions(+), 31 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h index 00da1d6..4213420 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -22,15 +22,17 @@ #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" namespace mlir { class BlockArgument; -namespace edsc { +namespace edsc { enum class IterType { Parallel, Reduction }; inline StringRef toString(IterType t) { @@ -38,7 +40,7 @@ inline StringRef toString(IterType t) { case IterType::Parallel: return getParallelIteratorTypeName(); case IterType::Reduction: - return getParallelIteratorTypeName(); + return getReductionIteratorTypeName(); 
default: llvm_unreachable("Unsupport IterType"); } @@ -78,20 +80,83 @@ inline void defaultRegionBuilder(ArrayRef args) {} Operation *makeLinalgGenericOp( ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, - decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder, + llvm::function_ref)> regionBuilder = + defaultRegionBuilder, ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); +namespace ops { +using edsc::StructuredIndexed; +using edsc::ValueHandle; +using edsc::intrinsics::linalg_yield; + //===----------------------------------------------------------------------===// // EDSC builders for linalg generic operations. //===----------------------------------------------------------------------===// +/// Build the body of a region to compute a multiply-accumulate, under the +/// current ScopedContext, at the current insert point. +void macRegionBuilder(ArrayRef args); + /// TODO(ntv): In the future we should tie these implementations to something in /// Tablegen that generates the proper interfaces and the proper sugared named /// ops. -/// Build a linalg.generic that represents C = A * B in the current -/// ScopedContext. +/// Build a linalg.pointwise, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (i0, ..., in) = (par, ..., par) +/// | +/// | O...(some_subset...(i0, ..., in)) = +/// | some_pointwise_func...(I...(some_other_subset...(i0, ..., in))) +/// ``` +/// +/// This is a very generic entry point that can be configured in many ways to +/// build a perfect loop nest of parallel loops with arbitrarily complex +/// innermost loop code and whatever (explicit) broadcast semantics. +/// +/// This can be used with both out-of-place and in-place semantics. +/// The client is responsible for ensuring the region operations are compatible +/// with in-place semantics and parallelism. + +/// Unary pointwise operation (with broadcast) entry point. 
+using UnaryPointwiseOpBuilder = llvm::function_ref; +Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, + StructuredIndexed I, StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = tanh(I)`. The client is responsible for specifying the proper +/// indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O); + +/// Binary pointwise operation (with broadcast) entry point. +using BinaryPointwiseOpBuilder = + llvm::function_ref; +Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, + StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = I1 + I2`. The client is responsible for specifying the proper +/// indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +/// Build a linalg.pointwise with all `parallel` iterators and a region that +/// computes `O = max(I1, I2)`. The client is responsible for specifying the +/// proper indexings when creating the StructuredIndexed. +Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, + StructuredIndexed O); + +// TODO(ntv): Implement more useful pointwise operations on a per-need basis.
+ +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (m, n, k) = (par, par, red) +/// | +/// | C(m, n) += A(m, k) * B(k, n) +/// ``` Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC); template Operation *linalg_matmul(Container values) { @@ -99,6 +164,76 @@ template Operation *linalg_matmul(Container values) { return linalg_matmul(values[0], values[1], values[2]); } +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (batch, f, [h, w, ...], [kh, kw, ...], c) = +/// | (par, par, [par, par, ...], [red, red, ...], red) +/// | +/// | O(batch, [h, w, ...], f) += +/// | I(batch, +/// | [ +/// | strides[0] * h + dilations[0] * kh, +/// | strides[1] * w + dilations[1] * kw, ... +/// ], +/// | c) +/// | * +/// | W([kh, kw, ...], c, f) +/// ``` +/// If `dilations` or `strides` are left empty, the default value of `1` is used +/// along each relevant dimension. +/// +/// For now `...` must be empty (i.e. only 2-D convolutions are supported). +/// +// TODO(ntv) Extend convolution rank with some template magic. +Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef strides = {}, + ArrayRef dilations = {}); + +template +Operation *linalg_conv_nhwc(Container values, ArrayRef strides = {}, + ArrayRef dilations = {}) { + assert(values.size() == 3 && "Expected exactly 3 values"); + return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations); +} + +/// Build a linalg.generic, under the current ScopedContext, at the current +/// insert point, that computes: +/// ``` +/// (batch, dm, c, [h, w, ...], [kh, kw, ...]) = +/// | (par, par, par, [par, par, ...], [red, red, ...]) +/// | +/// | O(batch, [h, w, ...], c * depth_multiplier + dm) += +/// | I(batch, +/// | [ +/// | strides[0] * h + dilations[0] * kh, +/// | strides[1] * w + dilations[1] * kw, ...
+/// ], +/// | c) +/// | * +/// | W([kh, kw, ...], c, depth_multiplier) +/// ``` +/// If `dilations` or `strides` are left empty, the default value of `1` is used +/// along each relevant dimension. +/// +/// For now `...` must be empty (i.e. only 2-D convolutions are supported). +/// +// TODO(ntv) Extend convolution rank with some template magic. +Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW, + ValueHandle vO, int depth_multiplier = 1, + ArrayRef strides = {}, + ArrayRef dilations = {}); + +template +Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier, + ArrayRef strides = {}, + ArrayRef dilations = {}) { + assert(values.size() == 3 && "Expected exactly 3 values"); + return linalg_dilated_conv_nhwc(values[0], values[1], values[2], + depth_multiplier, strides, dilations); +} + +} // namespace ops } // namespace edsc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h new file mode 100644 index 0000000..f1acab6 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h @@ -0,0 +1,35 @@ +//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ +#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace mlir { +namespace edsc { +namespace intrinsics { + +using linalg_fill = OperationBuilder; +using linalg_yield = OperationBuilder; + +} // namespace intrinsics +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 4f9621c..1f24a90 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -247,13 +247,13 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputs<1>, NOutputs<1>]> { } def FillOp : LinalgLibrary_Op<"fill", [NInputs<0>, NOutputs<1>]> { - let arguments = (ins AnyStridedMemRef:$input, + let arguments = (ins AnyStridedMemRef:$output, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value); let extraClassDeclaration = libraryCallName # [{ ArrayAttr indexing_maps(); ArrayAttr iterator_types() { - unsigned nPar = input()->getType().cast().getRank(); + unsigned nPar = output()->getType().cast().getRank(); MLIRContext *ctx = getContext(); SmallVector iters( nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index 68bd210..6dbb343 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -154,22 +154,22 @@ template struct ValueBuilder : public ValueHandle { /// Folder-based template - ValueBuilder(OperationFolder &folder, Args... args) + ValueBuilder(OperationFolder *folder, Args... 
args) : ValueHandle(ValueHandle::create(folder, detail::unpack(args)...)) {} - ValueBuilder(OperationFolder &folder, ArrayRef vs) + ValueBuilder(OperationFolder *folder, ArrayRef vs) : ValueBuilder(ValueBuilder::create(folder, detail::unpack(vs))) {} template - ValueBuilder(OperationFolder &folder, ArrayRef vs, Args... args) + ValueBuilder(OperationFolder *folder, ArrayRef vs, Args... args) : ValueHandle(ValueHandle::create(folder, detail::unpack(vs), detail::unpack(args)...)) {} template - ValueBuilder(OperationFolder &folder, T t, ArrayRef vs, + ValueBuilder(OperationFolder *folder, T t, ArrayRef vs, Args... args) : ValueHandle(ValueHandle::create(folder, detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {} template - ValueBuilder(OperationFolder &folder, T1 t1, T2 t2, ArrayRef vs, + ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef vs, Args... args) : ValueHandle(ValueHandle::create( folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), @@ -200,6 +200,7 @@ template struct OperationBuilder : public OperationHandle { OperationBuilder() : OperationHandle(OperationHandle::create()) {} }; +using addf = ValueBuilder; using affine_apply = ValueBuilder; using affine_if = OperationBuilder; using affine_load = ValueBuilder; @@ -212,11 +213,14 @@ using constant_int = ValueBuilder; using dealloc = OperationBuilder; using dim = ValueBuilder; using muli = ValueBuilder; +using mulf = ValueBuilder; +using memref_cast = ValueBuilder; using ret = OperationBuilder; using select = ValueBuilder; using std_load = ValueBuilder; using std_store = OperationBuilder; using subi = ValueBuilder; +using tanh = ValueBuilder; using view = ValueBuilder; /// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`. 
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 3daeafe..77e3a1e 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/Dialect/Linalg/EDSC/Builders.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" @@ -26,6 +27,7 @@ using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; +using namespace mlir::edsc::ops; static void getMaxDimIndex(ArrayRef structuredIndices, unsigned &pos) { @@ -42,24 +44,26 @@ static void getMaxDimIndex(ArrayRef structuredIndices, Operation *mlir::edsc::makeLinalgGenericOp( ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, - decltype(defaultRegionBuilder) regionBuilder, ArrayRef otherValues, - ArrayRef otherAttributes) { + llvm::function_ref)> regionBuilder, + ArrayRef otherValues, ArrayRef otherAttributes) { auto &builder = edsc::ScopedContext::getBuilder(); auto *ctx = builder.getContext(); unsigned nInputs = inputs.size(); unsigned nOutputs = outputs.size(); - unsigned rank = 0; - getMaxDimIndex(inputs, rank); - getMaxDimIndex(outputs, rank); + unsigned maxPos = 0; + getMaxDimIndex(inputs, maxPos); + getMaxDimIndex(outputs, maxPos); + // maxPos is 0 indexed, need to turn this into a count (i.e. 
+1) + unsigned nDims = maxPos + 1; SmallVector maps; maps.reserve(nInputs + nOutputs); for (auto in : inputs) maps.push_back( - AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs())); + AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs())); for (auto out : outputs) maps.push_back( - AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs())); + AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs())); unsigned nViews = nInputs + nOutputs; SmallVector values; @@ -105,23 +109,148 @@ Operation *mlir::edsc::makeLinalgGenericOp( return op; } -using linalg_yield = OperationBuilder; +void mlir::edsc::ops::macRegionBuilder(ArrayRef args) { + using edsc::op::operator+; + using edsc::op::operator*; + assert(args.size() == 3 && "expected 3 block arguments"); + ValueHandle a(args[0]), b(args[1]), c(args[2]); + linalg_yield((c + a * b).getValue()); +} + +Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, + StructuredIndexed I, + StructuredIndexed O) { + SmallVector iterTypes(O.getExprs().size(), + edsc::IterType::Parallel); + auto fun = [&unaryOp](ArrayRef args) { + assert(args.size() == 2 && "expected 2 block arguments"); + ValueHandle a(args[0]); + linalg_yield(unaryOp(a)); + }; + return makeLinalgGenericOp(iterTypes, {I}, {O}, fun); +} + +Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, + StructuredIndexed O) { + ; + using edsc::intrinsics::tanh; + UnaryPointwiseOpBuilder unOp( + [](ValueHandle a) -> Value * { return tanh(a); }); + return linalg_pointwise(unOp, I, O); +} + +/// Binary pointwise operation (with broadcast) entry point. 
+Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, + StructuredIndexed I1, + StructuredIndexed I2, + StructuredIndexed O) { + SmallVector iterTypes(O.getExprs().size(), + edsc::IterType::Parallel); + auto fun = [&binaryOp](ArrayRef args) { + assert(args.size() == 3 && "expected 3 block arguments"); + ValueHandle a(args[0]), b(args[1]); + linalg_yield(binaryOp(a, b)); + }; + return makeLinalgGenericOp(iterTypes, {I1, I2}, {O}, fun); +} -Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB, - ValueHandle vC) { +Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1, + StructuredIndexed I2, + StructuredIndexed O) { + using edsc::op::operator+; + BinaryPointwiseOpBuilder binOp( + [](ValueHandle a, ValueHandle b) -> Value * { return a + b; }); + return linalg_pointwise(binOp, I1, I2, O); +} + +Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1, + StructuredIndexed I2, + StructuredIndexed O) { + BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value * { + using edsc::intrinsics::select; + using edsc::op::operator>; + return select(a > b, a, b).getValue(); + }); + return linalg_pointwise(binOp, I1, I2, O); +} + +Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB, + ValueHandle vC) { // clang-format off AffineExpr m, n, k; bindDims(ScopedContext::getContext(), m, n, k); StructuredIndexed A(vA), B(vB), C(vC); return makeLinalgGenericOp( {IterType::Parallel, IterType::Parallel, IterType::Reduction}, - {A({m, n}), B({k, n})}, + {A({m, k}), B({k, n})}, {C({m, n})}, - [](ArrayRef args) { - using edsc::op::operator*; - using edsc::op::operator+; - ValueHandle a(args[0]), b(args[1]), c(args[2]); - linalg_yield((c + a * b).getValue()); - }); + macRegionBuilder); + // clang-format on +} + +Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, + ValueHandle vO, + ArrayRef strides, + ArrayRef dilations) { + MLIRContext *ctx = 
ScopedContext::getContext(); + // TODO(ntv) some template magic to make everything rank-polymorphic. + assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm"); + assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm"); + + // Some short names. NOTE(review): `s[0]`, `s[1]`, `d[0]`, `d[1]` are indexed below even when `strides` / `dilations` are empty, which the asserts above permit and the header documents as "default value of 1" - confirm callers always pass both, or materialize the default here. + auto par = IterType::Parallel; + auto red = IterType::Reduction; + auto s = strides; + auto d = dilations; + + AffineExpr b, f, h, w, kh, kw, c; + bindDims(ctx, b, f, h, w, kh, kw, c); + unsigned numDims = c.cast().getPosition() + 1; + StructuredIndexed I(vI), W(vW), O(vO); + // clang-format off + return makeLinalgGenericOp( + {par, par, par, par, red, red, red}, { + I({b, + // Roundtrip to flattened form to serve as canonicalization and ensure + // consistent ordering of subexpressions. + simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), + simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), + c}), + W({kh, kw, c, f})}, { + O({b, h, w, f})}, + macRegionBuilder); + // clang-format on +} + +Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc( + ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier, + ArrayRef strides, ArrayRef dilations) { + MLIRContext *ctx = ScopedContext::getContext(); + // TODO(ntv) some template magic to make everything rank-polymorphic. + assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm"); + assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm"); + + // Some short names. NOTE(review): as in linalg_conv_nhwc, `s` / `d` alias `strides` / `dilations` without substituting the documented default of 1 when they are empty - confirm before indexing them. + auto par = IterType::Parallel; + auto red = IterType::Reduction; + auto s = strides; + auto d = dilations; + + // clang-format off + AffineExpr b, dm, c, h, w, kh, kw; + bindDims(ctx, b, dm, c, h, w, kh, kw); + unsigned numDims = kw.cast().getPosition() + 1; + StructuredIndexed I(vI), W(vW), O(vO); + return makeLinalgGenericOp( + {par, par, par, par, par, red, red}, { + I({b, + // Roundtrip to flattened form to serve as canonicalization and ensure + // consistent ordering of subexpressions.
+ simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), + simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), + c}), + W({kh, kw, c, dm})}, { + O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})}, + macRegionBuilder); // clang-format on } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index abd1eb0..81bb0b94 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -811,16 +811,61 @@ TEST_FUNC(affine_if_op) { } // clang-format off +// CHECK-LABEL: func @linalg_pointwise +// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [(d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1)], +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK: addf +// CHECK: }: memref, memref, memref +// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [(d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1)], +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK: cmpf "ogt" +// CHECK: select +// CHECK: }: memref, memref, memref +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [(d0, d1) -> (d0, d1), (d0, d1) -> (d0, d1)], +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK: tanh +// CHECK: }: memref, memref +// clang-format on +TEST_FUNC(linalg_pointwise_test) { + using namespace edsc; + using namespace edsc::ops; + + auto f32Type = FloatType::getF32(&globalContext()); + auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0); + auto f = makeFunction("linalg_pointwise", {}, + {memrefType, memrefType, memrefType}); + + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); + AffineExpr i, j; + bindDims(&globalContext(), i, j); + StructuredIndexed SA(A), SB(B), SC(C); + 
linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j})); + linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j})); + linalg_pointwise_tanh(SA({i, j}), SC({i, j})); + + f.print(llvm::outs()); + f.erase(); +} + +// clang-format off // CHECK-LABEL: func @linalg_matmul -// CHECK: linalg.generic +// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [(d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1), (d0, d1, d2) -> (d0, d1)], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} /// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): // CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 // CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 // CHECK: linalg.yield %[[a4]] : f32 // CHECK: }: memref, memref, memref // clang-format on -TEST_FUNC(linalg_matmul) { +TEST_FUNC(linalg_matmul_test) { using namespace edsc; + using namespace edsc::ops; auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0); @@ -835,6 +880,70 @@ TEST_FUNC(linalg_matmul) { f.erase(); } +// clang-format off +// CHECK-LABEL: func @linalg_conv_nhwc +// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6), +// CHECK-SAME: (d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1), +// CHECK-SAME: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} +/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): +// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 +// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 +// CHECK: linalg.yield %[[a4]] : f32 +// CHECK: }: memref, memref, memref +// clang-format on +TEST_FUNC(linalg_conv_nhwc) { + using namespace edsc; + using namespace edsc::ops; + + auto f32Type = FloatType::getF32(&globalContext()); + auto 
memrefType = MemRefType::get({-1, -1, -1, -1}, f32Type, {}, 0); + auto f = makeFunction("linalg_conv_nhwc", {}, + {memrefType, memrefType, memrefType}); + + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())), + /*strides=*/{3, 4}, /*dilations=*/{5, 6}); + + f.print(llvm::outs()); + f.erase(); +} + +// clang-format off +// CHECK-LABEL: func @linalg_dilated_conv_nhwc +// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3 * 3 + d5 * 5, d4 * 4 + d6 * 6, d2), +// CHECK-SAME: (d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1), +// CHECK-SAME: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} +/// CHECK: ^bb1(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): +// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 +// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 +// CHECK: linalg.yield %[[a4]] : f32 +// CHECK: }: memref, memref, memref +// clang-format on +TEST_FUNC(linalg_dilated_conv_nhwc) { + using namespace edsc; + using namespace edsc::ops; + + auto f32Type = FloatType::getF32(&globalContext()); + auto memrefType = MemRefType::get({-1, -1, -1, -1}, f32Type, {}, 0); + auto f = makeFunction("linalg_dilated_conv_nhwc", {}, + {memrefType, memrefType, memrefType}); + + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + linalg_dilated_conv_nhwc( + makeValueHandles(llvm::to_vector<3>(f.getArguments())), + /*depth_multiplier=*/7, + /*strides=*/{3, 4}, /*dilations=*/{5, 6}); + + f.print(llvm::outs()); + f.erase(); +} + int main() { RUN_TESTS(); return 0; -- 2.7.4