+++ /dev/null
-//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
-/// TensorOps by adding implementations as they are needed in the appropriate
-/// step in the tutorial.
-#ifndef LINALG2_TENSOROPS_INL_H_
-#define LINALG2_TENSOROPS_INL_H_
-
-#include "linalg2/Ops.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace linalg {
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getInputs() {
- auto *op = static_cast<ConcreteOp *>(this)->getOperation();
- return {op->operand_begin(), op->operand_begin() + getNumInputs()};
-}
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
- auto *op = static_cast<ConcreteOp *>(this)->getOperation();
- return {op->operand_begin() + getNumInputs(),
- op->operand_begin() + getNumInputs() + getNumOutputs()};
-}
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() {
- return {getInputs().begin(), getOutputs().end()};
-}
-
-template <class ConcreteOp>
-mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
- auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
- if (getNumInputs() <= 0)
- concreteOp->emitOpError("expected at least one input");
- if (getNumOutputs() <= 0)
- concreteOp->emitOpError("expected at least one output");
- if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
- concreteOp->emitOpError("expected " +
- llvm::Twine(getNumInputs() + getNumOutputs()) +
- " operands");
- }
- for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
- if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
- return concreteOp->emitOpError("operand " + llvm::Twine(i) +
- " not a ViewType");
- }
- for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
- ++i) {
- auto viewType =
- concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
- if (!viewType)
- return concreteOp->emitOpError("operand " + llvm::Twine(i) +
- " not a ViewType");
- if (viewType.getRank() != getNumParallelDims())
- return concreteOp->emitOpError("operand " + llvm::Twine(i) +
- " must be of rank " +
- llvm::Twine(getNumParallelDims()));
- }
- return mlir::success();
-}
-
-template <class ConcreteOp>
-bool linalg::TensorContractionBase<ConcreteOp>::parse(
- mlir::OpAsmParser *parser, mlir::OperationState *result) {
- llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
-}
-
-// A TensorContraction prints as:
-//
-// ```{.mlir}
-// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
-// ```
-//
-// for example:
-//
-// ```
-// linalg.matmul(%0, %1, %2) : view<?x?xf32>
-// ```
-//
-// Where %0, %1 and %2 are ssa-values of type ViewType.
-template <class ConcreteOp>
-void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
- *p << static_cast<ConcreteOp *>(this)->getOperationName() << "(";
- auto *last = *std::prev(getInputsAndOutputs().end());
- for (auto *i : getInputsAndOutputs()) {
- *p << *i << ((i == last) ? "" : ", ");
- }
- *p << ") : ";
- auto *lastOutput = *std::prev(getOutputs().end());
- for (auto *o : getOutputs()) {
- *p << o->getType() << ((o == lastOutput) ? "" : ",");
- }
-}
-
-} // namespace linalg
-
-#endif // LINALG2_TENSOROPS_INL_H_
/// A generic TensorContraction base class which captures the generic behavior
/// of tensor contraction operations (with broadcast).
-template <class ConcreteOp> class TensorContractionBase {
-protected:
- using TensorContractionBaseType = TensorContractionBase<ConcreteOp>;
-
- //////////////////////////////////////////////////////////////////////////////
- // Hooks to customize the behavior of this op.
- //////////////////////////////////////////////////////////////////////////////
- /// Generic implementation of hooks that should be called from `ConcreteType`s
- mlir::LogicalResult verify();
- static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
- void print(mlir::OpAsmPrinter *p);
-
+class TensorContractionBase {
public:
+ virtual ~TensorContractionBase() {}
+
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
- TensorContractionBase() = default;
+ virtual llvm::StringRef getTensorContractionName() = 0;
mlir::Operation::operand_range getInputs();
mlir::Operation::operand_range getOutputs();
- mlir::Operation::operand_range getInputsAndOutputs();
+ mlir::Operation::operand_range getInputsAndOutputs() {
+ return {getInputs().begin(), getOutputs().end()};
+ }
/// These are better as methods calling into the ConcreteOp instead of
/// template parameters because methods allow more generic behavior and avoid
/// specializing for number of arguments. All derived classes have
/// `VariadicOperands` and a build method from both an ArrayRef<mlirValue*>
/// and the proper number of mlir::Value*.
- unsigned getNumInputs() {
- return static_cast<ConcreteOp *>(this)->numInputs;
- };
- unsigned getNumOutputs() {
- return static_cast<ConcreteOp *>(this)->numOutputs;
- };
- unsigned getNumParallelDims() {
- return static_cast<ConcreteOp *>(this)->numParallelDims;
- };
- unsigned getNumReductionDims() {
- return static_cast<ConcreteOp *>(this)->numReductionDims;
- };
+ virtual unsigned getNumInputs() = 0;
+ virtual unsigned getNumOutputs() = 0;
+ virtual unsigned getNumParallelDims() = 0;
+ virtual unsigned getNumReductionDims() = 0;
//////////////////////////////////////////////////////////////////////////////
// Used in Linalg3 and later.
: getOutputView(viewIndex - getNumInputs());
}
+ /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+ /// loop over matvec). Does nothing by default.
+ virtual void writeAsFinerGrainTensorContraction() {}
+
/// Each op is responsible for declaring how it lowers itself to scalar form,
/// given the enclosing parallel and reduction induction variables.
/// `emitScalarImplementation` emits the scalar IR for the op in the nesting
/// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or
/// `parallel.back()`).
- void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
- llvm::ArrayRef<mlir::Value *> reductionIvs);
+ virtual void
+ emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs) {}
/// Represents a mapping from the loops to all the ranges of the operands.
/// The operands and their ranges are in the order defined by the particular
/// it explicitly is not expensive and generalizes to cases where an analysis
/// is not available. For details, see the description of
/// loopsToOperandRangeMaps in each ConcreteOp.
- llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+ virtual llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() {
+ return llvm::SmallVector<mlir::AffineMap, 8>();
+ }
};
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
-class DotOp : public TensorContractionBase<DotOp>,
+class DotOp : public TensorContractionBase,
public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult> {
public:
using Op::Op;
- using TensorContractionBaseType =
- TensorContractionBase::TensorContractionBaseType;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
- static constexpr unsigned numInputs = 2;
- static constexpr unsigned numOutputs = 1;
- static constexpr unsigned numParallelDims = 0;
- static constexpr unsigned numReductionDims = 1;
+ llvm::StringRef getTensorContractionName() override {
+ return getOperationName();
+ }
+ unsigned getNumInputs() override { return 2; }
+ unsigned getNumOutputs() override { return 1; }
+ unsigned getNumParallelDims() override { return 0; }
+ unsigned getNumReductionDims() override { return 1; }
+#if LINALG_STEP > 2
//////////////////////////////////////////////////////////////////////////////
// Used in Linalg3 and later.
//////////////////////////////////////////////////////////////////////////////
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
- void writeAsFinerGrainTensorContraction();
+ void writeAsFinerGrainTensorContraction() override;
/// Inputs to this map will be (%k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(K), B(K), C() is:
/// (d0) -> (d0, d0)(%k)
/// And the operands ranges are:
/// (%k, %k)
- llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+ llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
/// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
/// to:
/// cond = (r_i == zero)
/// scalarC = select(cond, zerof, C[]);
/// C[] = scalarC + A[r_i] * B[r_i];
- void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
- llvm::ArrayRef<mlir::Value *> reductionIvs);
+ void
+ emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs) override;
+#endif // LINALG_STEP
};
/// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors.
-class MatvecOp : public TensorContractionBase<MatvecOp>,
+class MatvecOp : public TensorContractionBase,
public mlir::Op<MatvecOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult> {
public:
using Op::Op;
- using TensorContractionBaseType =
- TensorContractionBase::TensorContractionBaseType;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
- static constexpr unsigned numInputs = 2;
- static constexpr unsigned numOutputs = 1;
- static constexpr unsigned numParallelDims = 1;
- static constexpr unsigned numReductionDims = 1;
+ llvm::StringRef getTensorContractionName() override {
+ return getOperationName();
+ }
+ unsigned getNumInputs() override { return 2; }
+ unsigned getNumOutputs() override { return 1; }
+ unsigned getNumParallelDims() override { return 1; }
+ unsigned getNumReductionDims() override { return 1; }
+#if LINALG_STEP > 2
//////////////////////////////////////////////////////////////////////////////
// Used in Linalg3 and later.
//////////////////////////////////////////////////////////////////////////////
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
- void writeAsFinerGrainTensorContraction();
+ void writeAsFinerGrainTensorContraction() override;
/// Inputs to this map will be (%m, %k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
/// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
/// And the operands ranges are:
/// (%m, %k, %k, %m)
- llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+ llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
/// Given an enclosing parallel loop with iv `i` and an enclosing parallel
/// loop with iv `r_j`, emits MLIR corresponding to:
/// cond = (r_j == zero)
/// scalarC = select(cond, zerof, C(i));
/// C(i) = scalarC + A(i, r_j) * B(r_j);
- void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
- llvm::ArrayRef<mlir::Value *> reductionIvs);
+ void
+ emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs) override;
+#endif // LINALG_STEP
};
/// Implements C = A * B on 2-D matrices.
-class MatmulOp : public TensorContractionBase<MatmulOp>,
+class MatmulOp : public TensorContractionBase,
public mlir::Op<MatmulOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult> {
public:
using Op::Op;
- using TensorContractionBaseType =
- TensorContractionBase::TensorContractionBaseType;
//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
- static constexpr unsigned numInputs = 2;
- static constexpr unsigned numOutputs = 1;
- static constexpr unsigned numParallelDims = 2;
- static constexpr unsigned numReductionDims = 1;
+ llvm::StringRef getTensorContractionName() override {
+ return getOperationName();
+ }
+ unsigned getNumInputs() override { return 2; }
+ unsigned getNumOutputs() override { return 1; }
+ unsigned getNumParallelDims() override { return 2; }
+ unsigned getNumReductionDims() override { return 1; }
+#if LINALG_STEP > 2
//////////////////////////////////////////////////////////////////////////////
// Used in Linalg3 and later.
//////////////////////////////////////////////////////////////////////////////
/// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
/// loop over matvec). Does nothing by default.
- void writeAsFinerGrainTensorContraction();
+ void writeAsFinerGrainTensorContraction() override;
/// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
/// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
/// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
/// And the operands ranges are:
/// (%m, %k, %k, %n, %m, %n)
- llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+ llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
/// Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing
/// reduction loop with iv `r_k`, emits MLIR corresponding to:
/// cond = (r_k == zero)
/// scalarC = select(cond, zerof, C[i, j]);
/// C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
- void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
- llvm::ArrayRef<mlir::Value *> reductionIvs);
+ void
+ emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+ llvm::ArrayRef<mlir::Value *> reductionIvs) override;
+#endif // LINALG_STEP
};
} // namespace linalg
-/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
-/// TensorOps by adding implementations as they are needed in the appropriate
-/// step in the tutorial.
-#include "linalg2/TensorOps-inl.h"
-
#endif // LINALG2_TENSOROPS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
-using llvm::ArrayRef;
-using llvm::Twine;
-
using namespace mlir;
using namespace linalg;
+#define TENSOR_CONTRACTION_DISPATCH(FUNCTION_NAME) \
+ if (getTensorContractionName() == MatmulOp::getOperationName()) { \
+ return FUNCTION_NAME(static_cast<MatmulOp &>(*this)); \
+ } \
+ if (getTensorContractionName() == MatvecOp::getOperationName()) { \
+ return FUNCTION_NAME(static_cast<MatvecOp &>(*this)); \
+ } \
+ if (getTensorContractionName() == DotOp::getOperationName()) { \
+ return FUNCTION_NAME(static_cast<DotOp &>(*this)); \
+ } \
+ llvm_unreachable("Missing linalg op");
+
+template <typename ConcreteOp>
+static mlir::Operation::operand_range getInputs(ConcreteOp &concreteOp) {
+ return {concreteOp.operand_begin(),
+ concreteOp.operand_begin() + concreteOp.getNumInputs()};
+}
+
+mlir::Operation::operand_range linalg::TensorContractionBase::getInputs() {
+ TENSOR_CONTRACTION_DISPATCH(::getInputs);
+}
+
+template <typename ConcreteOp>
+static mlir::Operation::operand_range getOutputs(ConcreteOp &concreteOp) {
+ return {concreteOp.operand_begin() + concreteOp.getNumInputs(),
+ concreteOp.operand_begin() + concreteOp.getNumInputs() +
+ concreteOp.getNumOutputs()};
+}
+
+mlir::Operation::operand_range linalg::TensorContractionBase::getOutputs() {
+ TENSOR_CONTRACTION_DISPATCH(::getOutputs);
+}
+
+template <typename LinalgOp>
+static mlir::LogicalResult verifyLinalgOp(LinalgOp op) {
+ if (op.getNumInputs() <= 0)
+ op.emitOpError("expected at least one input");
+ if (op.getNumOutputs() <= 0)
+ op.emitOpError("expected at least one output");
+ if (op.getNumOperands() != op.getNumInputs() + op.getNumOutputs()) {
+ op.emitOpError("expected " +
+ llvm::Twine(op.getNumInputs() + op.getNumOutputs()) +
+ " operands");
+ }
+ for (unsigned i = 0, e = op.getNumInputs(); i < e; ++i) {
+ if (!op.getOperand(i)->getType().template isa<ViewType>())
+ return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType");
+ }
+ for (unsigned i = op.getNumInputs(),
+ e = op.getNumInputs() + op.getNumOutputs();
+ i < e; ++i) {
+ auto viewType = op.getOperand(i)->getType().template dyn_cast<ViewType>();
+ if (!viewType)
+ return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType");
+ if (viewType.getRank() != op.getNumParallelDims())
+ return op.emitOpError("operand " + llvm::Twine(i) + " must be of rank " +
+ llvm::Twine(op.getNumParallelDims()));
+ }
+ return mlir::success();
+}
+
+// A TensorContraction prints as:
+//
+// ```{.mlir}
+// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
+// ```
+//
+// for example:
+//
+// ```
+// linalg.matmul(%0, %1, %2) : view<?x?xf32>
+// ```
+//
+// Where %0, %1 and %2 are ssa-values of type ViewType.
+template <typename LinalgOp>
+static void printLinalgOp(mlir::OpAsmPrinter *p, LinalgOp op) {
+ *p << op.getOperationName() << "(";
+ auto *last = *std::prev(op.getInputsAndOutputs().end());
+ for (auto *i : op.getInputsAndOutputs()) {
+ *p << *i << ((i == last) ? "" : ", ");
+ }
+ *p << ") : ";
+ auto *lastOutput = *std::prev(op.getOutputs().end());
+ for (auto *o : op.getOutputs()) {
+ *p << o->getType() << ((o == lastOutput) ? "" : ",");
+ }
+}
+
//////////////////////////////////////////////////////////////////////////////
// Op-specific Dot.
//////////////////////////////////////////////////////////////////////////////
}
LogicalResult linalg::DotOp::verify() {
- if (failed(TensorContractionBaseType::verify()))
+ if (failed(verifyLinalgOp(*this)))
return failure();
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
unsigned index = 0;
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::DotOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
- return TensorContractionBaseType::parse(parser, result);
+ llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
}
-void linalg::DotOp::print(mlir::OpAsmPrinter *p) {
- TensorContractionBaseType::print(p);
-}
+void linalg::DotOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
//////////////////////////////////////////////////////////////////////////////
// Op-specific Matvec.
}
LogicalResult linalg::MatvecOp::verify() {
- if (failed(TensorContractionBaseType::verify()))
+ if (failed(verifyLinalgOp(*this)))
return failure();
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
if (getViewRank(A) != 2)
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::MatvecOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
- return TensorContractionBaseType::parse(parser, result);
+ llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
}
-void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) {
- TensorContractionBaseType::print(p);
-}
+void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
//////////////////////////////////////////////////////////////////////////////
// Op-specific Matmul.
}
LogicalResult linalg::MatmulOp::verify() {
- if (failed(TensorContractionBaseType::verify()))
+ if (failed(verifyLinalgOp(*this)))
return failure();
auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
unsigned index = 0;
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::MatmulOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
- return TensorContractionBaseType::parse(parser, result);
+ llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
}
-void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) {
- TensorContractionBaseType::print(p);
-}
+void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
+++ /dev/null
-//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
-/// TensorOps by adding implementations as they are needed in the appropriate
-/// step in the tutorial.
-#ifndef LINALG3_TENSOROPS_INL_H_
-#define LINALG3_TENSOROPS_INL_H_
-
-#include "linalg1/Common.h"
-#include "linalg1/Utils.h"
-#include "linalg2/TensorOps.h"
-#include "linalg3/Analysis.h"
-#include "linalg3/Ops.h"
-
-template <class ConcreteOp>
-mlir::Value *
-linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
- return *(getInputs().begin() + viewIndex);
-}
-
-template <class ConcreteOp>
-mlir::Value *
-linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
- return *(getOutputs().begin() + viewIndex);
-}
-
-template <class ConcreteOp>
-llvm::SmallVector<mlir::AffineMap, 8>
-linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
- return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps();
-}
-
-template <class ConcreteOp>
-void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
- llvm::ArrayRef<mlir::Value *> parallelIvs,
- llvm::ArrayRef<mlir::Value *> reductionIvs) {
- static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
- reductionIvs);
-}
-
-template <class ConcreteOp>
-mlir::AffineMap linalg::operandRangesToLoopsMap(
- linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
- mlir::AffineMap current;
- // Individual submaps may not be invertible but their union must be invertible
- // by construction.
- for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
- if (!m)
- continue;
- if (!current) {
- current = m;
- continue;
- }
- llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
- current.getResults().end());
- results.append(m.getResults().begin(), m.getResults().end());
- current = mlir::AffineMap::get(
- std::max(current.getNumDims(), m.getNumDims()),
- current.getNumSymbols() + m.getNumSymbols(), results, {});
- }
- return inverseSubMap(current);
-}
-
-// Extract the ranges from a given ViewOp or SliceOp.
-//
-// In the case of a ViewOp, things are simple: just traverse the indexings and
-// get all the ranges (i.e. drop the indices).
-//
-// In the case of a SliceOp, things are trickier because we need to handle a
-// potential rank-reduction:
-// 1. Examine the indexing to determine if it is rank-reducing.
-// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such
-// that `d >= slicingDim`. This is to account for the rank reduction.
-// `getRootIndex` is then called on the **parent** view
-static llvm::SmallVector<mlir::Value *, 8>
-extractRangesFromViewOrSliceOp(mlir::Value *view) {
- // This expects a viewType which must come from either ViewOp or SliceOp.
- assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
- if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
- return viewOp.getRanges();
-
- auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
- unsigned slicingDim = sliceOp.getSlicingDim();
- auto *indexing = *(sliceOp.getIndexings().begin());
- bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
- unsigned offset = 0;
- llvm::SmallVector<mlir::Value *, 8> res;
- res.reserve(sliceOp.getRank());
- for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
- if (d == slicingDim && isRankReducing)
- offset = 1;
- auto *parentView = sliceOp.getParentView();
- auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
- res.push_back(indexingPosPair.first);
- }
- return res;
-}
-
-template <class ConcreteOp>
-static llvm::SmallVector<mlir::Value *, 8>
-getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
- llvm::SmallVector<mlir::Value *, 8> res;
- for (auto *in : tensorContraction.getInputs()) {
- auto subres = extractRangesFromViewOrSliceOp(in);
- res.append(subres.begin(), subres.end());
- }
- return res;
-}
-
-template <class ConcreteOp>
-static llvm::SmallVector<mlir::Value *, 8>
-getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
- llvm::SmallVector<mlir::Value *, 8> res;
- for (auto *out : tensorContraction.getOutputs()) {
- auto subres = extractRangesFromViewOrSliceOp(out);
- res.append(subres.begin(), subres.end());
- }
- return res;
-}
-
-template <class ConcreteOp>
-llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
- linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
- llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
- llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
- res.append(tmp.begin(), tmp.end());
- return res;
-}
-
-#endif // LINALG3_TENSOROPS_INL_H_
/// Takes a `tensorContraction` and a returns an AffineMap that can be used to
/// map ranges to enclosing loops for all the operands' ranges.
-template <class ConcreteOp>
-mlir::AffineMap operandRangesToLoopsMap(
- linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+mlir::AffineMap
+operandRangesToLoopsMap(linalg::TensorContractionBase &tensorContraction);
/// Takes a `tensorContraction` and returns the ranges of all its operands.
/// When an operand comes from a ViewOp, things are simple:
/// In the case of a SliceOp, things are more involved because we need to handle
/// potential rank-reductions.
/// This function abstracts this complexity away and returns all the ranges.
-template <class ConcreteOp>
llvm::SmallVector<mlir::Value *, 8>
-getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+getRanges(linalg::TensorContractionBase &tensorContraction);
} // namespace linalg
-/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
-/// TensorOps by adding implementations as they are needed in the appropriate
-/// step in the tutorial.
-#include "linalg3/TensorOps-inl.h"
-
#endif // LINALG3_TENSOROPS_H_
#include "mlir/IR/Types.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/LLVMIR/Transforms.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
#include "linalg1/ConvertToLLVMDialect.h"
#include "linalg1/LLVMIntrinsics.h"
//
//===----------------------------------------------------------------------===//
-#include "linalg1/Analysis.h"
#include "linalg1/Common.h"
+#include "linalg3/Analysis.h"
#include "linalg3/Intrinsics.h"
#include "linalg3/Ops.h"
#include "mlir/IR/Builders.h"
using namespace linalg;
using namespace linalg::intrinsics;
+mlir::Value *linalg::TensorContractionBase::getInputView(unsigned viewIndex) {
+ return *(getInputs().begin() + viewIndex);
+}
+
+mlir::Value *linalg::TensorContractionBase::getOutputView(unsigned viewIndex) {
+ return *(getOutputs().begin() + viewIndex);
+}
+
+mlir::AffineMap linalg::operandRangesToLoopsMap(
+ linalg::TensorContractionBase &tensorContraction) {
+ mlir::AffineMap current;
+ // Individual submaps may not be invertible but their union must be invertible
+ // by construction.
+ for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
+ if (!m)
+ continue;
+ if (!current) {
+ current = m;
+ continue;
+ }
+ llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
+ current.getResults().end());
+ results.append(m.getResults().begin(), m.getResults().end());
+ current = mlir::AffineMap::get(
+ std::max(current.getNumDims(), m.getNumDims()),
+ current.getNumSymbols() + m.getNumSymbols(), results, {});
+ }
+ return inverseSubMap(current);
+}
+
+// Extract the ranges from a given ViewOp or SliceOp.
+//
+// In the case of a ViewOp, things are simple: just traverse the indexings and
+// get all the ranges (i.e. drop the indices).
+//
+// In the case of a SliceOp, things are trickier because we need to handle a
+// potential rank-reduction:
+// 1. Examine the indexing to determine if it is rank-reducing.
+// 2. If it is rank-reducing, an offset of 1 is added to the dimensions such
+// that `d >= slicingDim`. This is to account for the rank reduction.
+// `getRootIndex` is then called on the **parent** view
+static llvm::SmallVector<mlir::Value *, 8>
+extractRangesFromViewOrSliceOp(mlir::Value *view) {
+ // This expects a viewType which must come from either ViewOp or SliceOp.
+ assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
+ if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
+ return viewOp.getRanges();
+
+ auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+ unsigned slicingDim = sliceOp.getSlicingDim();
+ auto *indexing = *(sliceOp.getIndexings().begin());
+ bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
+ unsigned offset = 0;
+ llvm::SmallVector<mlir::Value *, 8> res;
+ res.reserve(sliceOp.getRank());
+ for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
+ if (d == slicingDim && isRankReducing)
+ offset = 1;
+ auto *parentView = sliceOp.getParentView();
+ auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
+ res.push_back(indexingPosPair.first);
+ }
+ return res;
+}
+
+static llvm::SmallVector<mlir::Value *, 8>
+getInputRanges(linalg::TensorContractionBase &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res;
+ for (auto *in : tensorContraction.getInputs()) {
+ auto subres = extractRangesFromViewOrSliceOp(in);
+ res.append(subres.begin(), subres.end());
+ }
+ return res;
+}
+
+static llvm::SmallVector<mlir::Value *, 8>
+getOutputRanges(linalg::TensorContractionBase &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res;
+ for (auto *out : tensorContraction.getOutputs()) {
+ auto subres = extractRangesFromViewOrSliceOp(out);
+ res.append(subres.begin(), subres.end());
+ }
+ return res;
+}
+
+llvm::SmallVector<mlir::Value *, 8>
+linalg::getRanges(linalg::TensorContractionBase &tensorContraction) {
+ llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
+ llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
+ res.append(tmp.begin(), tmp.end());
+ return res;
+}
+
//////////////////////////////////////////////////////////////////////////////
// Implementation of DotOp.
//////////////////////////////////////////////////////////////////////////////
//===----------------------------------------------------------------------===//
#include "linalg3/Transforms.h"
+#include "linalg1/Common.h"
#include "linalg2/Intrinsics.h"
#include "linalg3/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
//===----------------------------------------------------------------------===//
#include "linalg4/Transforms.h"
+#include "linalg1/Common.h"
+#include "linalg1/Utils.h"
#include "linalg3/Intrinsics.h"
#include "linalg3/TensorOps.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/LoopUtils.h"
v->getDefiningOp()->dyn_cast<ConstantIndexOp>().getValue() == 0;
}
-template <typename ConcreteOp>
static llvm::SmallVector<Value *, 4>
-makeTiledRanges(TensorContractionBase<ConcreteOp> &contraction,
- ArrayRef<Value *> allRanges, llvm::ArrayRef<Value *> ivs,
+makeTiledRanges(TensorContractionBase &contraction, ArrayRef<Value *> allRanges,
+ llvm::ArrayRef<Value *> ivs,
llvm::ArrayRef<Value *> tileSizes) {
assert(ivs.size() == tileSizes.size());
if (ivs.empty())
return RangeParts(allRanges).makeRanges();
- auto *op = static_cast<ConcreteOp *>(&contraction);
RangeParts result(allRanges.size());
RangeParts rangeParts(allRanges);
- for (auto map : op->loopsToOperandRangeMaps()) {
+ for (auto map : contraction.loopsToOperandRangeMaps()) {
// 1. Take the first ivs results of the map, the other ones are not composed
// but merely copied over.
assert(map.getNumSymbols() == 0);
assert(map.getRangeSizes().empty());
MLIRContext *context = ScopedContext::getContext();
- unsigned numParallel = op->getNumParallelDims();
- unsigned numReduction = op->getNumReductionDims();
+ unsigned numParallel = contraction.getNumParallelDims();
+ unsigned numReduction = contraction.getNumReductionDims();
if (ivs.size() < numParallel + numReduction) {
// Inject zeros in positions that are not tiled.
SmallVector<AffineExpr, 4> dimReplacements(numParallel + numReduction);
return result.makeRanges();
}
-template <class ConcreteOp>
static SmallVector<Value *, 4>
-makeTiledViews(linalg::TensorContractionBase<ConcreteOp> &contraction,
+makeTiledViews(linalg::TensorContractionBase &contraction,
ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes) {
auto tiledRanges =
makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes);
return res;
}
-template <class ConcreteOp>
+template <typename ConcreteOp>
static SmallVector<mlir::AffineForOp, 8>
-writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
+writeContractionAsTiledViews(ConcreteOp &contraction,
ArrayRef<Value *> tileSizes) {
assert(tileSizes.size() <=
contraction.getNumParallelDims() + contraction.getNumReductionDims());
- auto *op = static_cast<ConcreteOp *>(&contraction);
- ScopedContext scope(mlir::FuncBuilder(op->getOperation()), op->getLoc());
+ ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
+ contraction.getLoc());
SmallVector<IndexHandle, 4> ivs(tileSizes.size());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);