static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
+public:
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
mlir::Operation::operand_range getInputs();
mlir::Operation::operand_range getOutputs();
-public:
/// These are better as methods calling into the ConcreteOp instead of
/// template parameters because methods allow more generic behavior and avoid
/// specializing for number of arguments. All derived classes have
//////////////////////////////////////////////////////////////////////////////
mlir::Value *getInputView(unsigned i);
mlir::Value *getOutputView(unsigned i);
- /// Computes a mapping from all the ranges of the operands to the enclosing
- /// loops. In order to support "broadcast"-style semantics, we need to
- /// consider all the operands (i.e. input operands are not sufficient).
+
+ /// Each op is responsible for declaring how it lowers itself to scalar form,
+ /// given the enclosing parallel and reduction induction variables.
+ /// `emitScalarImplementation` emits the scalar IR for the op in the nesting
+ /// context of the innermost enclosing loop (i.e. `reductionIvs.back()` or
+ /// `parallelIvs.back()`).
+ void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs);
+
+ /// Represents a mapping from the loops to all the ranges of the operands.
/// The operands and their ranges are in the order defined by the particular
/// ConcreteOp implementation; the resulting map must match that order.
- /// This is currently computed but can also be specified explicitly in each
- /// operator to generalize to cases where an analysis is not available.
- mlir::AffineMap operandRangesToLoopsMap();
+ /// In favorable cases, this can be calculated by an analysis, but specifying
+ /// it explicitly is not expensive and generalizes to cases where an analysis
+ /// is not available.
+ /// For details, see the description of loopsToOperandRangesMap in each
+ /// ConcreteOp.
+ mlir::AffineMap loopsToOperandRangesMap();
};
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction();
+
+ /// Inputs to this map will be (%k) coming from enclosing loops.
+ /// Therefore, the mapping to get back to A(K), B(K), C() is:
+ /// (d0) -> (d0, d0)(%k)
+ /// And the operand ranges are:
+ /// (%k, %k)
+ mlir::AffineMap loopsToOperandRangesMap();
+
+ /// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
+ /// to:
+ /// 1. conditionally assign scalarC to 0.0f on the first iteration or load
+ /// C[] from memory (0-D tensor)
+ /// 2. multiply A[r_i] by B[r_i] and add to scalarC
+ /// 3. store back scalarC at C[]
+ ///
+ /// In some compact index notation this could be written:
+ /// cond = (r_i == zero)
+ /// scalarC = select(cond, zerof, C[]);
+ /// C[] = scalarC + A[r_i] * B[r_i];
+ void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs);
};
/// Implements C = A * B where A is a 2-D matrix and B and C are 1-D vectors.
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction();
+
+ /// Inputs to this map will be (%m, %k) coming from enclosing loops.
+ /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
+ /// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
+ /// And the operand ranges are:
+ /// (%m, %k, %k, %m)
+ mlir::AffineMap loopsToOperandRangesMap();
+
+ /// Given an enclosing parallel loop with iv `i` and an enclosing reduction
+ /// loop with iv `r_j`, emits MLIR corresponding to:
+ /// 1. conditionally assign scalarC to 0.0f on the first iteration or load
+ /// C[i]
+ /// 2. multiply A[i, r_j] by B[r_j] and add to scalarC
+ /// 3. store back scalarC at C[i]
+ ///
+ /// In some compact index notation this could be written:
+ /// cond = (r_j == zero)
+ /// scalarC = select(cond, zerof, C(i));
+ /// C(i) = scalarC + A(i, r_j) * B(r_j);
+ void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs);
};
/// Implements C = A * B on 2-D matrices.
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
void writeAsFinerGrainTensorContraction();
+
+ /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
+ /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
+ /// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
+ /// And the operand ranges are:
+ /// (%m, %k, %k, %n, %m, %n)
+ mlir::AffineMap loopsToOperandRangesMap();
+
+ /// Given enclosing parallel loops with ivs `i` and `j`, and an enclosing
+ /// reduction loop with iv `r_k`, emits MLIR corresponding to:
+ /// 1. conditionally assign scalarC to 0.0f on the first iteration or load
+ /// C[i, j]
+ /// 2. multiply A[i, r_k] by B[r_k, j] and add to scalarC
+ /// 3. store back scalarC at C[i, j]
+ ///
+ /// In some compact index notation this could be written:
+ /// cond = (r_k == zero)
+ /// scalarC = select(cond, zerof, C[i, j]);
+ /// C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
+ void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs);
};
} // namespace linalg
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
lowerToFinerGrainedTensorContraction(f);
+ composeSliceOps(f);
// clang-format off
// CHECK-LABEL: func @matmul_as_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+ // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32xf32>">
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
- // CHECK-NEXT: %[[vB:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
- // CHECK-NEXT: %[[vC:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
- // CHECK-NEXT: linalg.matvec {%{{.*}}, %[[vB]]} -> {%[[vC]]}
+ // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
+ // CHECK: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
+ // CHECK: linalg.matvec {%[[vA]], %[[vB]]} -> {%[[vC]]}
// clang-format on
cleanupAndPrintFunction(f);
}
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
lowerToFinerGrainedTensorContraction(f);
lowerToFinerGrainedTensorContraction(f);
+ composeSliceOps(f);
// clang-format off
// CHECK-LABEL: func @matmul_as_dot(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
// CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
- // CHECK-NEXT: %[[vB:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
- // CHECK-NEXT: %[[sC:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+ // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
- // CHECK-NEXT: %[[vA:.*]] = linalg.slice {{.*}}[%i1, *] { dim : 0 } : !linalg<"view<f32>">
- // CHECK-NEXT: %[[vC:.*]] = linalg.slice %[[sC]][%i1] { dim : 0 } : !linalg<"view<0xf32>">
+ // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
+ // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<0xf32>">
// CHECK-NEXT: linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]}
// clang-format on
cleanupAndPrintFunction(f);
}
+TEST_FUNC(matmul_as_loops) {
+ MLIRContext context;
+ Module module(&context);
+ mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+ lowerToLoops(f);
+ composeSliceOps(f);
+ // clang-format off
+ // CHECK-LABEL: func @matmul_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ // CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
+ // CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+ // CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
+ // CHECK: %[[rM:.*]] = linalg.range %c0:%[[M]]:%c1 : !linalg<"range">
+ // CHECK: %[[rN:.*]] = linalg.range %c0:%[[N]]:%c1 : !linalg<"range">
+ // CHECK: %[[rK:.*]] = linalg.range %c0:%[[K]]:%c1 : !linalg<"range">
+ // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view<f32xf32>">
+ // CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view<f32xf32>">
+ // CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view<f32xf32>">
+ // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) {
+ // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) {
+ // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
+ // CHECK: %{{.*}} = cmpi "eq", %{{.*}} : index
+ // CHECK: %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view<f32xf32>">
+ // CHECK: %{{.*}} = select {{.*}} : f32
+ // CHECK: %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view<f32xf32>">
+ // CHECK: %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view<f32xf32>">
+ // CHECK: %{{.*}} = mulf {{.*}} : f32
+ // CHECK: %{{.*}} = addf {{.*}} : f32
+ // CHECK: linalg.store {{.*}}[%i0, %i1] : !linalg<"view<f32xf32>">
+ // clang-format on
+ cleanupAndPrintFunction(f);
+}
+
+TEST_FUNC(matmul_as_matvec_as_loops) {
+ MLIRContext context;
+ Module module(&context);
+ mlir::Function *f =
+ makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
+ lowerToFinerGrainedTensorContraction(f);
+ lowerToLoops(f);
+ composeSliceOps(f);
+ // clang-format off
+ // CHECK-LABEL: func @matmul_as_matvec_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ // CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
+ // CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+ // CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
+ // CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view<f32xf32>">
+ // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
+ // CHECK: %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view<f32>">
+ // CHECK: %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view<f32>">
+ // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
+ // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
+ // CHECK: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
+ // CHECK: %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view<f32>">
+ // CHECK: %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32
+ // CHECK: %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view<f32>">
+ // CHECK: %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view<f32xf32>">
+ // CHECK: %{{.*}} = mulf %[[A]], %[[B]] : f32
+ // CHECK: %{{.*}} = addf %[[C2]], %{{.*}} : f32
+ // CHECK: linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view<f32>">
+ // clang-format on
+ cleanupAndPrintFunction(f);
+}
+
int main() {
RUN_TESTS();
return 0;
--- /dev/null
+//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_ANALYSIS_H_
+#define LINALG3_ANALYSIS_H_
+
+#include "linalg2/Analysis.h"
+
+namespace mlir {
+class AffineMap;
+} // namespace mlir
+
+namespace linalg {
+
+/// Given a `map` specification and a subset of its results
+/// `[beginResult, endResult)`, returns the inverse map that maps result
+/// positions to dim positions. If both `beginResult` and `endResult` are 0,
+/// the whole map is inverted.
+mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0,
+ unsigned endResult = 0);
+
+} // namespace linalg
+
+#endif // LINALG3_ANALYSIS_H_
--- /dev/null
+//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_INTRINSICS_H_
+#define LINALG3_INTRINSICS_H_
+
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+
+namespace linalg {
+namespace intrinsics {
+using load = mlir::edsc::intrinsics::ValueBuilder<LoadOp>;
+using store = mlir::edsc::intrinsics::OperationBuilder<StoreOp>;
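+
+// Illustrative note: wrapped in a TemplatedIndexedValue (as done in the
+// emitScalarImplementation bodies in TensorOps.cpp), these builders let us
+// write indexed expressions such as
+//   C(i) = scalarC + A(i, r_j) * B(r_j);
+// which emit linalg.load / linalg.store under the hood.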
+} // namespace intrinsics
+} // namespace linalg
+
+#endif // LINALG3_INTRINSICS_H_
--- /dev/null
+//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_LOADSTOREOP_H_
+#define LINALG3_LOADSTOREOP_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace linalg {
+
+class ViewType;
+
+/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType
+/// instead of MemRefType.
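+/// For example, with illustrative operand names, the printed form is:
+///   %0 = linalg.load %view[%i, %j] : !linalg<"view<f32xf32>">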
+class LoadOp : public mlir::Op<LoadOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::OneResult> {
+public:
+ using Op::Op;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.load"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *view,
+ mlir::ArrayRef<mlir::Value *> indices = {});
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ unsigned getRank();
+ ViewType getViewType();
+ mlir::Value *getView() { return getOperand(0); }
+ mlir::Operation::operand_range getIndices() {
+ return {operand_begin() + 1, operand_end()};
+ }
+};
+
+/// A linalg.StoreOp is the counterpart of affine.store but operating on
+/// ViewType instead of MemRefType.
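+/// For example, with illustrative operand names, the printed form is:
+///   linalg.store %f, %view[%i, %j] : !linalg<"view<f32xf32>">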
+class StoreOp : public mlir::Op<StoreOp, mlir::OpTrait::VariadicOperands,
+ mlir::OpTrait::ZeroResult> {
+public:
+ using Op::Op;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Hooks to customize the behavior of this op.
+ //////////////////////////////////////////////////////////////////////////////
+ static llvm::StringRef getOperationName() { return "linalg.store"; }
+ static void build(mlir::Builder *b, mlir::OperationState *result,
+ mlir::Value *valueToStore, mlir::Value *view,
+ mlir::ArrayRef<mlir::Value *> indices = {});
+ mlir::LogicalResult verify();
+ static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+ void print(mlir::OpAsmPrinter *p);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Op-specific functionality.
+ //////////////////////////////////////////////////////////////////////////////
+ unsigned getRank();
+ ViewType getViewType();
+ mlir::Value *getValueToStore() { return getOperand(0); }
+ mlir::Value *getView() { return getOperand(1); }
+ mlir::Operation::operand_range getIndices() {
+ return {operand_begin() + 2, operand_end()};
+ }
+};
+
+} // namespace linalg
+
+#endif // LINALG3_LOADSTOREOP_H_
#define LINALG3_OPS_H_
#include "linalg2/Ops.h"
+#include "linalg3/LoadStoreOps.h"
#include "linalg3/TensorOps.h"
#endif // LINALG3_OPS_H_
#define LINALG3_TENSOROPS_INL_H_
#include "linalg1/Common.h"
+#include "linalg1/Utils.h"
#include "linalg2/TensorOps.h"
-
-namespace linalg {
+#include "linalg3/Analysis.h"
+#include "linalg3/Ops.h"
template <class ConcreteOp>
mlir::Value *
return *(getOutputs().begin() + i);
}
-} // namespace linalg
+template <class ConcreteOp>
+mlir::AffineMap
+linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangesMap() {
+ return static_cast<ConcreteOp *>(this)->loopsToOperandRangesMap();
+}
+
+template <class ConcreteOp>
+void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
+ llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs) {
+ static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
+ reductionIvs);
+}
+
+template <class ConcreteOp>
+mlir::AffineMap linalg::operandRangesToLoopsMap(
+ linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+ return inverseSubMap(tensorContraction.loopsToOperandRangesMap());
+}
+
+// Extract the ranges from a given ViewOp or SliceOp.
+//
+// In the case of a ViewOp, things are simple: just traverse the indexings and
+// get all the ranges (i.e. drop the indices).
+//
+// In the case of a SliceOp, things are trickier because we need to handle a
+// potential rank-reduction:
+// 1. Examine the indexing to determine if it is rank-reducing.
+// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such
+// that `d >= slicingDim`. This is to account for the rank reduction.
+// `getViewRootIndexing` is then called on the **parent** view.
+static llvm::SmallVector<mlir::Value *, 8>
+extractRangesFromViewOrSliceOp(mlir::Value *view) {
+ // This expects a viewType which must come from either ViewOp or SliceOp.
+ assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
+ if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
+ return viewOp.getRanges();
+
+ auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+ unsigned slicingDim = sliceOp.getSlicingDim();
+ auto *indexing = *(sliceOp.getIndexings().begin());
+ bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
+ unsigned offset = 0;
+ llvm::SmallVector<mlir::Value *, 8> res;
+ res.reserve(sliceOp.getRank());
+ for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
+ if (d == slicingDim && isRankReducing)
+ offset = 1;
+ auto *parentView = sliceOp.getParentView();
+ auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
+ res.push_back(indexingPosPair.first);
+ }
+ return res;
+}
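+
+// Illustrative example: for a rank-reducing slice such as
+//   %vB = linalg.slice %vParent[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+// over a 2-D parent view with ranges (%rK, %rN), the result is 1-D and the
+// single extracted range is %rK; had the slice been taken along dim 0, the
+// offset would make it %rN instead.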
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res;
+ for (auto *in : tensorContraction.getInputs()) {
+ auto subres = extractRangesFromViewOrSliceOp(in);
+ res.append(subres.begin(), subres.end());
+ }
+ return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res;
+ for (auto *out : tensorContraction.getOutputs()) {
+ auto subres = extractRangesFromViewOrSliceOp(out);
+ res.append(subres.begin(), subres.end());
+ }
+ return res;
+}
+
+template <class ConcreteOp>
+llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
+ linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
+ llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
+ res.append(tmp.begin(), tmp.end());
+ return res;
+}
-#endif // LINALG3_TENSOROPS-INL_H_
+#endif // LINALG3_TENSOROPS_INL_H_
#include "linalg2/TensorOps.h"
+namespace linalg {
+
+///
+/// Ideally all these functions would go in an Analysis library but, as long
+/// as TensorContractionBase is templated, they need to remain close enough
+/// for template instantiation.
+///
+
+/// Takes a `tensorContraction` and returns an AffineMap that can be used to
+/// map the ranges of all its operands to the enclosing loops.
+template <class ConcreteOp>
+mlir::AffineMap operandRangesToLoopsMap(
+ linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+
+/// Takes a `tensorContraction` and returns the ranges of all its operands.
+/// When an operand comes from a ViewOp, things are simple:
+/// just traverse the indexings and get all the ranges
+/// (i.e. drop the rank-reducing indices).
+/// In the case of a SliceOp, things are more involved because we need to handle
+/// potential rank-reductions.
+/// This function abstracts this complexity away and returns all the ranges.
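+/// For example, for MatmulOp (see its loopsToOperandRangesMap) the returned
+/// ranges would be (%m, %k, %k, %n, %m, %n).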
+template <class ConcreteOp>
+llvm::SmallVector<mlir::Value *, 8>
+getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+
+} // namespace linalg
+
/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
/// TensorOps by adding implementations as they are needed in the appropriate
/// step in the tutorial.
/// to only use linalg.view operations.
void composeSliceOps(mlir::Function *f);
-/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec, linalg.dot)
-/// as linalg.matvec (resp. linalg.dot, loop form).
+/// Traverses `f` and rewrites linalg.load and linalg.store to affine.load and
+/// affine.store operations.
+void lowerLinalgLoadStores(mlir::Function *f);
+
+/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec)
+/// as linalg.matvec (resp. linalg.dot).
void lowerToFinerGrainedTensorContraction(mlir::Function *f);
+/// Traverses `f` and rewrites linalg operations in loop form.
+void lowerToLoops(mlir::Function *f);
+
} // namespace linalg
#endif // LINALG3_TRANSFORMS_H_
--- /dev/null
+//===- Analysis.cpp - Implementation of analysis functions for Linalg -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements analysis functions for the linalg dialect, in
+// particular the map inversion utility used to relate operand ranges and
+// enclosing loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg3/Analysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/StandardTypes.h"
+
+using llvm::SmallVector;
+using namespace mlir;
+
+// Compute an inverse map (only works with permutations for now).
+// Note that the mapping is generally non-full rank, so this returns the first
+// seen entry for each dim.
+static AffineMap inversePermutationMap(AffineMap map) {
+ SmallVector<AffineExpr, 4> exprs(map.getNumDims());
+ for (auto en : llvm::enumerate(map.getResults())) {
+ auto expr = en.value();
+ auto d = expr.dyn_cast<AffineDimExpr>();
+ assert(d && "permutation map expected");
+ if (exprs[d.getPosition()])
+ continue;
+ exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
+ }
+ SmallVector<AffineExpr, 4> seenExprs;
+ seenExprs.reserve(map.getNumDims());
+ for (auto expr : exprs)
+ if (expr)
+ seenExprs.push_back(expr);
+ assert(map.getNumSymbols() == 0 && "expected map without symbols");
+ assert(seenExprs.size() == map.getNumInputs() && "map is not invertible");
+ return AffineMap::get(map.getNumResults(), 0, seenExprs, {});
+}
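+
+// Illustrative example, using the MatmulOp map from TensorOps.cpp: inverting
+//   (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)
+// keeps the first result mentioning each dim and yields
+//   (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1),
+// i.e. given the operand ranges (%m, %k, %k, %n, %m, %n) it recovers the loop
+// order (%m, %n, %k).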
+
+mlir::AffineMap linalg::inverseSubMap(AffineMap map, unsigned beginResult,
+ unsigned endResult) {
+ if (beginResult == 0 && endResult == 0)
+ endResult = map.getNumResults();
+ auto subMap = AffineMap::get(
+ map.getNumDims(), map.getNumSymbols(),
+ map.getResults().slice(beginResult, endResult - beginResult), {});
+ return inversePermutationMap(subMap);
+}
--- /dev/null
+//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file registers the Linalg dialect and should live in a standalone
+// library. Linking with this library will create a static global object that
+// performs dialect registration.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg1/Dialect.h"
+#include "linalg1/Types.h"
+#include "linalg3/Ops.h"
+
+using namespace linalg;
+
+LinalgDialect::LinalgDialect(mlir::MLIRContext *context)
+ : Dialect("linalg", context) {
+ addTypes<RangeType, ViewType>();
+ addOperations<DotOp, LoadOp, MatvecOp, MatmulOp, RangeOp, SliceOp, StoreOp,
+ ViewOp>();
+}
+
+// Dialect registration triggers the creation of a `LinalgDialect` object which
+// adds the proper types and operations to the dialect.
+static mlir::DialectRegistration<LinalgDialect> LinalgOps;
--- /dev/null
+//===- LoadStoreOps.cpp - Implementation of linalg Load/Store operations --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements linalg.load and linalg.store operations which allow
+// accessing memory through ViewType values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg3/LoadStoreOps.h"
+#include "linalg3/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+using llvm::ArrayRef;
+using namespace mlir;
+using namespace linalg;
+
+////////////////////////////////////////////////////////////////////////////////
+// LoadOp.
+////////////////////////////////////////////////////////////////////////////////
+void linalg::LoadOp::build(Builder *b, OperationState *result, Value *view,
+ ArrayRef<Value *> indices) {
+ auto viewType = view->getType().cast<ViewType>();
+ result->addOperands(view);
+ result->addOperands(indices);
+ result->addTypes(viewType.getElementType());
+}
+
+void linalg::LoadOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *getView() << '[';
+ p->printOperands(getIndices());
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << getViewType();
+}
+
+bool linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) {
+ llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
+ return false;
+}
+
+LogicalResult linalg::LoadOp::verify() {
+ if (getNumOperands() == 0)
+ return emitOpError("expected a view to load from");
+
+ auto viewType = getView()->getType().dyn_cast<ViewType>();
+ if (!viewType)
+ return emitOpError("first operand must be a view");
+
+ if (getType() != viewType.getElementType())
+ return emitOpError("result type must match element type of the view");
+
+ if (getRank() != getNumOperands() - 1)
+ return emitOpError("incorrect number of indices for load");
+
+ for (auto *idx : getIndices())
+ if (!idx->getType().isIndex())
+ return emitOpError("index to load must have 'index' type");
+
+ return success();
+}
+
+ViewType linalg::LoadOp::getViewType() {
+ return getView()->getType().cast<ViewType>();
+}
+
+unsigned linalg::LoadOp::getRank() { return getViewType().getRank(); }
+
+////////////////////////////////////////////////////////////////////////////////
+// StoreOp.
+////////////////////////////////////////////////////////////////////////////////
+void linalg::StoreOp::build(Builder *b, OperationState *result,
+ Value *valueToStore, Value *view,
+ ArrayRef<Value *> indices) {
+ result->addOperands(valueToStore);
+ result->addOperands(view);
+ result->addOperands(indices);
+}
+
+void linalg::StoreOp::print(OpAsmPrinter *p) {
+ *p << getOperationName() << " " << *getValueToStore();
+ *p << ", " << *getView() << '[';
+ p->printOperands(getIndices());
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << getViewType();
+}
+
+bool linalg::StoreOp::parse(OpAsmParser *parser, OperationState *result) {
+ llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
+ return false;
+}
+
+LogicalResult linalg::StoreOp::verify() {
+ if (getNumOperands() < 2)
+ return emitOpError("expected a value to store and a view");
+
+ // Second operand must be of ViewType.
+ auto viewType = getView()->getType().dyn_cast<ViewType>();
+ if (!viewType)
+ return emitOpError("second operand must be a view");
+
+ // First operand must have the same type as the view's element type.
+ if (getValueToStore()->getType() != viewType.getElementType())
+ return emitOpError("first operand must have same element type as the view");
+
+ if (getNumOperands() != 2 + viewType.getRank())
+ return emitOpError("store index operand count not equal to view rank");
+
+ for (auto *idx : getIndices())
+ if (!idx->getType().isIndex())
+ return emitOpError("index to store must have 'index' type");
+
+ return success();
+}
+
+unsigned linalg::StoreOp::getRank() { return getViewType().getRank(); }
+
+ViewType linalg::StoreOp::getViewType() {
+ return getView()->getType().cast<ViewType>();
+}
#include "linalg1/Analysis.h"
#include "linalg1/Common.h"
-#include "linalg2/Intrinsics.h"
+#include "linalg3/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
using namespace mlir;
using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;
+//////////////////////////////////////////////////////////////////////////////
+// Implementation of DotOp.
+//////////////////////////////////////////////////////////////////////////////
+AffineMap linalg::DotOp::loopsToOperandRangesMap() {
+ // A(K), B(K), C()
+ assert(getRanges(*this).size() == 2);
+ auto *context = ScopedContext::getContext();
+ auto d0 = getAffineDimExpr(0, context); // K
+ // A(K), B(K), C()
+ // (d0) -> (d0, d0)(%k)
+ return AffineMap::get(1, 0, {d0, d0}, {});
+}
+
+void linalg::DotOp::emitScalarImplementation(
+ llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
+ using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
+ linalg::intrinsics::store>;
+ assert(reductionIvs.size() == 1);
+ auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
+ auto *body = innermostLoop.getBody();
+ using edsc::op::operator+;
+ using edsc::op::operator*;
+ using edsc::op::operator==;
+ using edsc::intrinsics::select;
+ ScopedContext scope( // account for affine.terminator in loop.
+ FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+ auto f32 = ScopedContext::getBuilder()->getF32Type();
+ IndexHandle zero(constant_index(0));
+ ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+ IndexHandle r_i(reductionIvs[0]);
+ IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
+ ValueHandle cond = (r_i == zero);
+ ValueHandle scalarC = select(cond, zerof, *C());
+ C() = scalarC + A(r_i) * B(r_i);
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// Implementation of MatvecOp.
+//////////////////////////////////////////////////////////////////////////////
+AffineMap linalg::MatvecOp::loopsToOperandRangesMap() {
+ // A(M, K), B(K), C(M)
+ assert(getRanges(*this).size() == 4);
+ auto *context = ScopedContext::getContext();
+ auto d0 = getAffineDimExpr(0, context); // M
+ auto d1 = getAffineDimExpr(1, context); // K
+ // A(M, K), B(K), C(M)
+ // (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
+ return AffineMap::get(2, 0, {d0, d1, d1, d0}, {});
+}
+
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
// The body expression for dot is: C() = A(r_i) * B(r_i);
// So we must drop the `i` loop from the matvec.
void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
auto *op = getOperation();
- ScopedContext scope(FuncBuilder(op), op->getLoc());
- IndexHandle i;
auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
auto indexingPosPair = getViewRootIndexing(vA, 0);
assert(indexingPosPair.first->getDefiningOp() &&
indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
- linalg::common::LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
- dot(slice(vA, i, 0), vB, slice(vC, i, 0)),
+ // clang-format off
+ ScopedContext scope(FuncBuilder(op), op->getLoc());
+ IndexHandle i;
+ using linalg::common::LoopNestRangeBuilder;
+ LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
+ [&i, &vA, &vB, &vC]() {
+ ValueHandle sliceA = slice(vA, i, 0);
+ ValueHandle sliceC = slice(vC, i, 0);
+ dot(sliceA, vB, sliceC);
+ /// NestedBuilders expect handles, we thus return an IndexHandle.
+ return IndexHandle();
+ }()
});
+ // clang-format on
+}
+
+void linalg::MatvecOp::emitScalarImplementation(
+ llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
+ using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
+ linalg::intrinsics::store>;
+ assert(reductionIvs.size() == 1);
+ auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
+ auto *body = innermostLoop.getBody();
+ using edsc::op::operator+;
+ using edsc::op::operator*;
+ using edsc::op::operator==;
+ using edsc::intrinsics::select;
+ ScopedContext scope( // account for affine.terminator in loop.
+ FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+ auto f32 = ScopedContext::getBuilder()->getF32Type();
+ IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
+ IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
+ IndexHandle zero(constant_index(0));
+ ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+ ValueHandle cond = (r_j == zero);
+ ValueHandle scalarC = select(cond, zerof, *C(i));
+ C(i) = scalarC + A(i, r_j) * B(r_j);
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// Implementation of MatmulOp.
+//////////////////////////////////////////////////////////////////////////////
+AffineMap linalg::MatmulOp::loopsToOperandRangesMap() {
+ // A(M, K), B(K, N), C(M, N)
+ assert(getRanges(*this).size() == 6);
+ auto *context = ScopedContext::getContext();
+ auto d0 = getAffineDimExpr(0, context); // M
+ auto d1 = getAffineDimExpr(1, context); // N
+ auto d2 = getAffineDimExpr(2, context); // K
+ // A(M, K), B(K, N), C(M, N):
+ // (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
+ return AffineMap::get(3, 0, {d0, d2, d2, d1, d0, d1}, {});
}
// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
// So we must drop the `j` loop from the matmul and express the rewrite
// declaratively.
void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
auto *op = getOperation();
- ScopedContext scope(FuncBuilder(op), op->getLoc());
- IndexHandle j;
auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
auto indexingPosPair = getViewRootIndexing(vB, 1);
assert(indexingPosPair.first->getDefiningOp() &&
indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
- linalg::common::LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
- matvec(vA, slice(vB, j, 1), slice(vC, j, 1)),
+ using linalg::common::LoopNestRangeBuilder;
+ // clang-format off
+ ScopedContext scope(FuncBuilder(op), op->getLoc());
+ IndexHandle j;
+ LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
+ [&j, &vA, &vB, &vC]() {
+ ValueHandle sliceB = slice(vB, j, 1);
+ ValueHandle sliceC = slice(vC, j, 1);
+ matvec(vA, sliceB, sliceC);
+ /// NestedBuilders expect handles, we thus return an IndexHandle.
+ return IndexHandle();
+ }()
});
+ // clang-format on
+}
+
+void linalg::MatmulOp::emitScalarImplementation(
+ llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
+ using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
+ linalg::intrinsics::store>;
+ assert(reductionIvs.size() == 1);
+ auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
+ auto *body = innermostLoop.getBody();
+ using edsc::op::operator+;
+ using edsc::op::operator*;
+ using edsc::op::operator==;
+ using edsc::intrinsics::select;
+ ScopedContext scope( // account for affine.terminator in loop.
+ FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+ auto f32 = ScopedContext::getBuilder()->getF32Type();
+ IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
+ IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
+ IndexHandle zero(constant_index(0));
+ ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+ ValueHandle cond = r_k == zero;
+ ValueHandle scalarC = select(cond, zerof, *C(i, j));
+ C(i, j) = scalarC + A(i, r_k) * B(r_k, j);
}
op->erase();
});
}
+
+// Folding eagerly is necessary to abide by affine.for's static step
+// requirement.
+// Returns nullptr if folding is not trivially feasible.
+static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
+ assert(map.getNumResults() == 1 && "single result map expected");
+ auto expr = map.getResult(0);
+ if (auto dim = expr.dyn_cast<AffineDimExpr>())
+ return operands[dim.getPosition()];
+ if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
+ return operands[map.getNumDims() + sym.getPosition()];
+ if (auto cst = expr.dyn_cast<AffineConstantExpr>())
+ return constant_index(cst.getValue());
+ return nullptr;
+}
+
+static Value *makeFoldedComposedAffineApply(AffineMap map,
+ ArrayRef<Value *> operandsRef) {
+ SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
+ fullyComposeAffineMapAndOperands(&map, &operands);
+ if (auto *v = tryFold(map, operands)) {
+ return v;
+ }
+ auto *b = ScopedContext::getBuilder();
+ auto loc = ScopedContext::getLocation();
+ return b->create<AffineApplyOp>(loc, map, operands).getResult();
+}
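+
+// For instance, when the composed single-result map is just a dim or a
+// constant (e.g. (d0) -> (d0) applied to %c1), the operand or a fresh
+// constant_index is returned directly and no affine.apply is created, which
+// keeps the steps of the affine.for loops built below static.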
+
+struct RangeParts {
+ explicit RangeParts(unsigned reserved);
+ RangeParts(ArrayRef<Value *> ranges);
+
+ SmallVector<Value *, 4> makeRanges();
+
+ SmallVector<Value *, 4> mins;
+ SmallVector<Value *, 4> maxes;
+ SmallVector<Value *, 4> steps;
+};
+
+RangeParts::RangeParts(unsigned reserved) {
+ mins.reserve(reserved);
+ maxes.reserve(reserved);
+ steps.reserve(reserved);
+}
+
+static SmallVector<Value *, 4>
+extractFromRanges(ArrayRef<Value *> ranges,
+ std::function<Value *(RangeOp)> extract) {
+ SmallVector<Value *, 4> res;
+ res.reserve(ranges.size());
+ for (auto *v : ranges) {
+ auto r = v->getDefiningOp()->cast<RangeOp>();
+ res.push_back(extract(r));
+ }
+ return res;
+}
+
+RangeParts::RangeParts(ArrayRef<Value *> ranges)
+ : mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
+ maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
+ steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}
+
+SmallVector<Value *, 4> RangeParts::makeRanges() {
+ SmallVector<Value *, 4> res;
+ res.reserve(mins.size());
+ for (auto z : llvm::zip(mins, maxes, steps)) {
+ res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z)));
+ }
+ return res;
+}
+
+static RangeParts makeGenericRangeParts(AffineMap map,
+ ArrayRef<Value *> ranges) {
+ assert(map.getNumInputs() == ranges.size());
+ unsigned numDims = map.getNumDims();
+ assert(map.getNumSymbols() == 0);
+ assert(map.getRangeSizes().empty());
+
+ RangeParts res(map.getNumResults());
+ RangeParts rangeParts(ranges);
+ for (auto expr : map.getResults()) {
+ AffineMap map = AffineMap::get(numDims, 0, expr, {});
+ res.mins.push_back(makeFoldedComposedAffineApply(map, rangeParts.mins));
+ res.maxes.push_back(makeFoldedComposedAffineApply(map, rangeParts.maxes));
+ res.steps.push_back(makeFoldedComposedAffineApply(map, rangeParts.steps));
+ }
+ return res;
+}
+
+SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
+ ArrayRef<Value *> ranges) {
+ return makeGenericRangeParts(map, ranges).makeRanges();
+}
+
+static SmallVector<Value *, 4> makeGenericLoopRanges(
+ AffineMap operandRangesToLoopsMap, ArrayRef<Value *> ranges,
+ llvm::Optional<ArrayRef<Value *>> tileSizes = llvm::None) {
+ RangeParts res = makeGenericRangeParts(operandRangesToLoopsMap, ranges);
+ if (!tileSizes.hasValue())
+ return res.makeRanges();
+ SmallVector<Value *, 4> tiledSteps;
+ for (auto z : llvm::zip(res.steps, *tileSizes)) {
+ auto *step = std::get<0>(z);
+ auto tileSize = std::get<1>(z);
+ auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+ auto tileSizeValue =
+ tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+ assert(stepValue > 0);
+ tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
+ }
+ res.steps = tiledSteps;
+ return res.makeRanges();
+}
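+
+// Illustrative example: with a constant loop step %c1 and a tile size %c8,
+// the tiled step becomes constant_index(8), so the enclosing affine.for still
+// sees a static step.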
+
+template <class ContractionOp>
+static SmallVector<mlir::AffineForOp, 4>
+writeAsLoops(ContractionOp contraction) {
+ ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
+ contraction.getLoc());
+ auto loopRanges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
+ getRanges(contraction));
+
+ SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
+ SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
+ auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
+ auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
+ assert(loopRanges.size() == pivs.size() + rivs.size());
+
+ // clang-format off
+ using linalg::common::LoopNestRangeBuilder;
+ ArrayRef<Value *> ranges(loopRanges);
+ LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({
+ LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({
+ [&contraction, ¶llelIvs, &reductionIvs]() {
+ SmallVector<mlir::Value *, 4> parallel(
+ parallelIvs.begin(), parallelIvs.end());
+ SmallVector<mlir::Value *, 4> reduction(
+ reductionIvs.begin(), reductionIvs.end());
+ contraction.emitScalarImplementation(parallel, reduction);
+ /// NestedBuilders expect handles, we thus return an IndexHandle.
+ return IndexHandle();
+ }()
+ })
+ });
+ // clang-format on
+
+ SmallVector<mlir::AffineForOp, 4> res;
+ res.reserve(pivs.size() + rivs.size());
+ for (auto iv : parallelIvs)
+ res.push_back(getForInductionVarOwner(iv.getValue()));
+ for (auto iv : reductionIvs)
+ res.push_back(getForInductionVarOwner(iv.getValue()));
+ return res;
+}
+
+void linalg::lowerToLoops(mlir::Function *f) {
+ f->walkPostOrder([](Operation *op) {
+ if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
+ writeAsLoops(matmulOp);
+ } else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
+ writeAsLoops(matvecOp);
+ } else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) {
+ writeAsLoops(dotOp);
+ } else {
+ return;
+ }
+ op->erase();
+ });
+}