using constant_index = ValueBuilder<ConstantIndexOp>;
using constant_int = ValueBuilder<ConstantIntOp>;
using dealloc = OperationBuilder<DeallocOp>;
+using dim = ValueBuilder<DimOp>;
using load = ValueBuilder<LoadOp>;
using ret = OperationBuilder<ReturnOp>;
using select = ValueBuilder<SelectOp>;
std::pair<mlir::Value *, unsigned> getViewRootIndexing(mlir::Value *view,
unsigned dim);
-////////////////////////////////////////////////////////////////////////////////
-/// Helper functions to avoid dispatching at all client sites.
-////////////////////////////////////////////////////////////////////////////////
-/// Asserts `view` is of ViewType and returns its rank.
-unsigned getViewRank(mlir::Value *view);
-
} // namespace linalg
#endif // LINALG1_ANALYSIS_H_
--- /dev/null
+//===- Utils.h - Linalg dialect utility functions definitions -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_UTILS_H_
+#define LINALG1_UTILS_H_
+
+namespace mlir {
+class Value;
+} // namespace mlir
+
+namespace linalg {
+
+/// Asserts `view` is of ViewType and returns its rank.
+unsigned getViewRank(mlir::Value *view);
+
+} // namespace linalg
+
+#endif // LINALG1_UTILS_H_
unsigned parentDim = dim > sliceDim ? dim + 1 : dim;
return getViewRootIndexing(parentView, parentDim);
}
-
-////////////////////////////////////////////////////////////////////////////////
-/// Helper functions to avoid dispatch at all client sites.
-////////////////////////////////////////////////////////////////////////////////
-unsigned linalg::getViewRank(Value *view) {
- assert(view->getType().isa<ViewType>() && "expected a ViewType");
- if (auto viewOp = view->getDefiningOp()->dyn_cast<ViewOp>())
- return viewOp.getRank();
- return view->getDefiningOp()->dyn_cast<SliceOp>().getRank();
-}
#include "linalg1/Analysis.h"
#include "linalg1/Ops.h"
#include "linalg1/Types.h"
+#include "linalg1/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
--- /dev/null
+//===- Utils.cpp - Implementation of utiliy functions for Linalg ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements utility functions for the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg1/Utils.h"
+#include "linalg1/Ops.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace linalg;
+
+unsigned linalg::getViewRank(Value *view) {
+ assert(view->getType().isa<ViewType>() && "expected a ViewType");
+ if (auto viewOp = view->getDefiningOp()->dyn_cast<ViewOp>())
+ return viewOp.getRank();
+ return view->getDefiningOp()->dyn_cast<SliceOp>().getRank();
+}
--- /dev/null
+//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#ifndef LINALG2_TENSOROPS_INL_H_
+#define LINALG2_TENSOROPS_INL_H_
+
+#include "linalg2/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace linalg {
+
+template <class ConcreteOp>
+mlir::Operation::operand_range
+linalg::TensorContractionBase<ConcreteOp>::getInputs() {
+ auto *op = static_cast<ConcreteOp *>(this)->getOperation();
+ return {op->operand_begin(), op->operand_begin() + getNumInputs()};
+}
+
+template <class ConcreteOp>
+mlir::Operation::operand_range
+linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
+ auto *op = static_cast<ConcreteOp *>(this)->getOperation();
+ return {op->operand_begin() + getNumInputs(),
+ op->operand_begin() + getNumInputs() + getNumOutputs()};
+}
+
+template <class ConcreteOp>
+mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
+ auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
+ if (getNumInputs() <= 0)
+ concreteOp->emitOpError("expected at least one input");
+ if (getNumOutputs() <= 0)
+ concreteOp->emitOpError("expected at least one output");
+ if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
+ concreteOp->emitOpError("expected " +
+ llvm::Twine(getNumInputs() + getNumOutputs()) +
+ " operands");
+ }
+ for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
+ if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
+ return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+ " not a ViewType");
+ }
+ for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
+ ++i) {
+ auto viewType =
+ concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
+ if (!viewType)
+ return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+ " not a ViewType");
+ if (viewType.getRank() != getNumParallelDims())
+ return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+ " must be of rank " +
+ llvm::Twine(getNumParallelDims()));
+ }
+ return mlir::success();
+}
+
+template <class ConcreteOp>
+bool linalg::TensorContractionBase<ConcreteOp>::parse(
+ mlir::OpAsmParser *parser, mlir::OperationState *result) {
+ llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
+}
+
+// A TensorContraction prints as:
+//
+// ```{.mlir}
+// concrete_op_name {%0, %1} -> {%2}
+// ```
+//
+// Where %0, %1 is an ssa-value holding a View, %2 is an ssa-value holding a
+// view.
+template <class ConcreteOp>
+void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
+ *p << static_cast<ConcreteOp *>(this)->getOperationName() << " {";
+ auto *lastInput = *std::prev(getInputs().end());
+ for (auto *i : getInputs()) {
+ *p << *i << ((i == lastInput) ? "} -> {" : ", ");
+ }
+ auto *lastOutput = *std::prev(getOutputs().end());
+ for (auto *o : getOutputs()) {
+ *p << *o << ((o == lastOutput) ? "}" : ",");
+ }
+}
+
+} // namespace linalg
+
+#endif // LINALG2_TENSOROPS_INL_H_
// limitations under the License.
// =============================================================================
-#ifndef LINALG2_MATMULOP_H_
-#define LINALG2_MATMULOP_H_
+#ifndef LINALG2_TENSOROPS_H_
+#define LINALG2_TENSOROPS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
+namespace mlir {
+class AffineForOp;
+} // namespace mlir
+
namespace linalg {
/// A generic TensorContraction base class which captures the generic behavior
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
TensorContractionBase() = default;
-
- mlir::Type getInputElementType(unsigned i);
- mlir::Type getOutputElementType(unsigned i);
- mlir::Value *getInputView(unsigned i);
- mlir::Value *getOutputView(unsigned i);
- mlir::Value *getInputMemRef(unsigned i);
- mlir::Value *getOutputMemRef(unsigned i);
mlir::Operation::operand_range getInputs();
mlir::Operation::operand_range getOutputs();
unsigned getNumReductionDims() {
return static_cast<ConcreteOp *>(this)->numReductionDims;
};
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Used in Linalg3 and later.
+ //////////////////////////////////////////////////////////////////////////////
+ mlir::Value *getInputView(unsigned i);
+ mlir::Value *getOutputView(unsigned i);
+ /// Computes a mapping from all the ranges of the operands to the enclosing
+ /// loops. In order to support "broadcast"-style semantics, we need to
+ /// consider all the operands (i.e. input operands are not sufficient).
+ /// The operands and their ranges are in the order defined by the particular
+ /// ConcreteOp implementation, the resulting map must match those.
+ /// This is currently computed but can also be specified explicitly in each
+ /// operator to generalize to cases where an analysis is not available.
+ mlir::AffineMap operandRangesToLoopsMap();
};
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
static constexpr unsigned numOutputs = 1;
static constexpr unsigned numParallelDims = 0;
static constexpr unsigned numReductionDims = 1;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Used in Linalg3 and later.
+ //////////////////////////////////////////////////////////////////////////////
+ /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+ /// loop over matvec). Does nothing by default.
+ void writeAsFinerGrainTensorContraction();
};
/// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors.
static constexpr unsigned numOutputs = 1;
static constexpr unsigned numParallelDims = 1;
static constexpr unsigned numReductionDims = 1;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Used in Linalg3 and later.
+ //////////////////////////////////////////////////////////////////////////////
+ /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+ /// loop over matvec). Does nothing by default.
+ void writeAsFinerGrainTensorContraction();
};
/// Implements C = A * B on 2-D matrices.
static constexpr unsigned numOutputs = 1;
static constexpr unsigned numParallelDims = 2;
static constexpr unsigned numReductionDims = 1;
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Used in Linalg3 and later.
+ //////////////////////////////////////////////////////////////////////////////
+ /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+ /// loop over matvec). Does nothing by default.
+ void writeAsFinerGrainTensorContraction();
};
+
} // namespace linalg
-#endif // LINALG2_MATMULOP_H_
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#include "linalg2/TensorOps-inl.h"
+
+#endif // LINALG2_TENSOROPS_H_
//
//===----------------------------------------------------------------------===//
-#include "linalg2/Analysis.h"
+#include "linalg1/Utils.h"
#include "linalg2/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
using namespace mlir;
using namespace linalg;
-template <class ConcreteOp>
-Type linalg::TensorContractionBase<ConcreteOp>::getInputElementType(
- unsigned idx) {
- return getInputView(idx)
- ->getType()
- .template cast<ViewType>()
- .getElementType();
-}
-
-template <class ConcreteOp>
-Type linalg::TensorContractionBase<ConcreteOp>::getOutputElementType(
- unsigned idx) {
- return getOutputView(idx)
- ->getType()
- .template cast<ViewType>()
- .getElementType();
-}
-
-template <class ConcreteOp>
-Value *linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned idx) {
- return *(getInputs().begin() + idx);
-}
-
-template <class ConcreteOp>
-Value *linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned idx) {
- return *(getOutputs().begin() + idx);
-}
-
-template <class ConcreteOp>
-Value *linalg::TensorContractionBase<ConcreteOp>::getInputMemRef(unsigned idx) {
- return getViewSupportingMemRef(*(getInputs().begin() + idx));
-}
-
-template <class ConcreteOp>
-Value *
-linalg::TensorContractionBase<ConcreteOp>::getOutputMemRef(unsigned idx) {
- return getViewSupportingMemRef(*(getOutputs().begin() + idx));
-}
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getInputs() {
- auto *op = static_cast<ConcreteOp *>(this)->getOperation();
- return {op->operand_begin(), op->operand_begin() + getNumInputs()};
-}
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
- auto *op = static_cast<ConcreteOp *>(this)->getOperation();
- return {op->operand_begin() + getNumInputs(),
- op->operand_begin() + getNumInputs() + getNumOutputs()};
-}
-
-template <class ConcreteOp>
-LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
- auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
- if (getNumInputs() <= 0)
- concreteOp->emitOpError("expected at least one input");
- if (getNumOutputs() <= 0)
- concreteOp->emitOpError("expected at least one output");
- if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
- concreteOp->emitOpError(
- "expected " + Twine(getNumInputs() + getNumOutputs()) + " operands");
- }
- for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
- if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
- return concreteOp->emitOpError("operand " + Twine(i) + " not a ViewType");
- }
- for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
- ++i) {
- auto viewType =
- concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
- if (!viewType)
- return concreteOp->emitOpError("operand " + Twine(i) + " not a ViewType");
- if (viewType.getRank() != getNumParallelDims())
- return concreteOp->emitOpError("operand " + Twine(i) +
- " must be of rank " +
- Twine(getNumParallelDims()));
- }
- return success();
-}
-
-template <class ConcreteOp>
-bool linalg::TensorContractionBase<ConcreteOp>::parse(OpAsmParser *parser,
- OperationState *result) {
- llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
-}
-
-// A TensorContraction prints as:
-//
-// ```{.mlir}
-// concrete_op_name {%0, %1} -> {%2}
-// ```
-//
-// Where %0, %1 is an ssa-value holding a View, %2 is an ssa-value holding a
-// view.
-template <class ConcreteOp>
-void linalg::TensorContractionBase<ConcreteOp>::print(OpAsmPrinter *p) {
- *p << static_cast<ConcreteOp *>(this)->getOperationName() << " {";
- auto *lastInput = *std::prev(getInputs().end());
- for (auto *i : getInputs()) {
- *p << *i << ((i == lastInput) ? "} -> {" : ", ");
- }
- auto *lastOutput = *std::prev(getOutputs().end());
- for (auto *o : getOutputs()) {
- *p << *o << ((o == lastOutput) ? "}" : ",");
- }
-}
-
//////////////////////////////////////////////////////////////////////////////
// Op-specific Dot.
//////////////////////////////////////////////////////////////////////////////
--- /dev/null
+//===- Example.cpp - Our running example ----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+// RUN: %p/test | FileCheck %s
+
+#include "TestHarness.h"
+#include "linalg1/Common.h"
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+#include "linalg3/Transforms.h"
+#include "mlir/IR/OpImplementation.h"
+
+using llvm::StringRef;
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace linalg;
+using namespace linalg::common;
+using namespace linalg::intrinsics;
+
+Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+ MLIRContext *context = module.getContext();
+ auto dynamic2DMemRefType = floatMemRefType<2>(context);
+ mlir::Function *f = linalg::common::makeFunction(
+ module, name,
+ {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
+
+ ScopedContext scope(f);
+ // clang-format off
+ ValueHandle
+ M = dim(f->getArgument(0), 0),
+ N = dim(f->getArgument(2), 1),
+ K = dim(f->getArgument(0), 1),
+ rM = range(constant_index(0), M, constant_index(1)),
+ rN = range(constant_index(0), N, constant_index(1)),
+ rK = range(constant_index(0), K, constant_index(1)),
+ vA = view(f->getArgument(0), {rM, rK}),
+ vB = view(f->getArgument(1), {rK, rN}),
+ vC = view(f->getArgument(2), {rM, rN});
+ matmul(vA, vB, vC);
+ ret();
+ // clang-format on
+
+ return f;
+}
+
+TEST_FUNC(matmul_as_matvec) {
+ MLIRContext context;
+ Module module(&context);
+ mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
+ lowerToFinerGrainedTensorContraction(f);
+ // clang-format off
+ // CHECK-LABEL: func @matmul_as_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ // CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+ // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
+ // CHECK-NEXT: %[[vB:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+ // CHECK-NEXT: %[[vC:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+ // CHECK-NEXT: linalg.matvec {%{{.*}}, %[[vB]]} -> {%[[vC]]}
+ // clang-format on
+ cleanupAndPrintFunction(f);
+}
+
+TEST_FUNC(matmul_as_dot) {
+ MLIRContext context;
+ Module module(&context);
+ mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
+ lowerToFinerGrainedTensorContraction(f);
+ lowerToFinerGrainedTensorContraction(f);
+ // clang-format off
+ // CHECK-LABEL: func @matmul_as_dot(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+ // CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
+ // CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+ // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
+ // CHECK-NEXT: %[[vB:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+ // CHECK-NEXT: %[[sC:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+ // CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
+ // CHECK-NEXT: %[[vA:.*]] = linalg.slice {{.*}}[%i1, *] { dim : 0 } : !linalg<"view<f32>">
+ // CHECK-NEXT: %[[vC:.*]] = linalg.slice %[[sC]][%i1] { dim : 0 } : !linalg<"view<0xf32>">
+ // CHECK-NEXT: linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]}
+ // clang-format on
+ cleanupAndPrintFunction(f);
+}
+
+int main() {
+ RUN_TESTS();
+ return 0;
+}
--- /dev/null
+//===- Ops.h - Linalg Ops single entry point ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_OPS_H_
+#define LINALG3_OPS_H_
+
+#include "linalg2/Ops.h"
+#include "linalg3/TensorOps.h"
+
+#endif // LINALG3_OPS_H_
--- /dev/null
+//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#ifndef LINALG3_TENSOROPS_INL_H_
+#define LINALG3_TENSOROPS_INL_H_
+
+#include "linalg1/Common.h"
+#include "linalg2/TensorOps.h"
+
+namespace linalg {
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned i) {
+ return *(getInputs().begin() + i);
+}
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned i) {
+ return *(getOutputs().begin() + i);
+}
+
+} // namespace linalg
+
+#endif // LINALG3_TENSOROPS-INL_H_
--- /dev/null
+//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_TENSOROPS_H_
+#define LINALG3_TENSOROPS_H_
+
+#include "linalg2/TensorOps.h"
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#include "linalg3/TensorOps-inl.h"
+
+#endif // LINALG3_TENSOROPS_H_
--- /dev/null
+//===- Transforms.h - Linalg dialect Transformations definition -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_TRANSFORMS_H_
+#define LINALG3_TRANSFORMS_H_
+
+#include "linalg2/Transforms.h"
+
+namespace mlir {
+class Function;
+} // namespace mlir
+
+namespace linalg {
+
+/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
+/// to only use linalg.view operations.
+void composeSliceOps(mlir::Function *f);
+
+/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec, linalg.dot)
+/// as linalg.matvec (resp. linalg.dot, loop form).
+void lowerToFinerGrainedTensorContraction(mlir::Function *f);
+
+} // namespace linalg
+
+#endif // LINALG3_TRANSFORMS_H_
--- /dev/null
+//===- TensorOps.cpp - Implementation of the linalg TensorOps operation ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple IR operation to create new tensor computation
+// operations in the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg1/Analysis.h"
+#include "linalg1/Common.h"
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace linalg;
+using namespace linalg::intrinsics;
+
+// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
+// The body expression for dot is: C() = A(r_i) * B(r_i);
+// So we must drop the `i` loop from the matvec.
+void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
+ auto *op = getOperation();
+ ScopedContext scope(FuncBuilder(op), op->getLoc());
+ IndexHandle i;
+ auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
+ auto indexingPosPair = getViewRootIndexing(vA, 0);
+ assert(indexingPosPair.first->getDefiningOp() &&
+ indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
+ linalg::common::LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
+ dot(slice(vA, i, 0), vB, slice(vC, i, 0)),
+ });
+}
+
+// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
+// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
+// So we must drop the `j` loop from the matmul.
+// This is fine because parallel dimensions permute: we can just do it
+// declaratively.
+void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
+ auto *op = getOperation();
+ ScopedContext scope(FuncBuilder(op), op->getLoc());
+ IndexHandle j;
+ auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
+ auto indexingPosPair = getViewRootIndexing(vB, 1);
+ assert(indexingPosPair.first->getDefiningOp() &&
+ indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
+ linalg::common::LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
+ matvec(vA, slice(vB, j, 1), slice(vC, j, 1)),
+ });
+}
--- /dev/null
+//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements analyses and transformations for the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg3/Transforms.h"
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace linalg;
+using namespace linalg::intrinsics;
+
+void linalg::composeSliceOps(mlir::Function *f) {
+ f->walkPostOrder<SliceOp>([](SliceOp sliceOp) {
+ auto *sliceResult = sliceOp.getResult();
+ auto viewOp = createFullyComposedView(sliceResult);
+ sliceResult->replaceAllUsesWith(viewOp.getResult());
+ sliceOp.erase();
+ });
+}
+
+void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
+ f->walkPostOrder([](Operation *op) {
+ if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
+ matmulOp.writeAsFinerGrainTensorContraction();
+ } else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
+ matvecOp.writeAsFinerGrainTensorContraction();
+ } else {
+ return;
+ }
+ op->erase();
+ });
+}