Start a Linalg dialect
authorNicolas Vasilache <ntv@google.com>
Thu, 18 Apr 2019 15:25:54 +0000 (08:25 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 18 Apr 2019 18:50:27 +0000 (11:50 -0700)
    This CL starts implementing a Linalg dialect with the objective of supporting
    optimizing compilation of loops and library calls for a subset of common linear
    algebra operations.

    This CL starts by simply adding a linalg.range type and an operation with the
    proper roundtripping test.

--

PiperOrigin-RevId: 244189468

mlir/examples/Linalg/Linalg1/include/linalg1/Types.h
mlir/include/mlir/Linalg/LinalgOps.h [new file with mode: 0644]
mlir/include/mlir/Linalg/LinalgTypes.h [new file with mode: 0644]
mlir/lib/CMakeLists.txt
mlir/lib/Linalg/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Linalg/LinalgOps.cpp [new file with mode: 0644]
mlir/lib/Linalg/LinalgRegistration.cpp [new file with mode: 0644]
mlir/lib/Linalg/LinalgTypes.cpp [new file with mode: 0644]
mlir/test/Linalg/roundtrip.mlir [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt

index b2fa7fd..5032e96 100644 (file)
@@ -23,9 +23,9 @@
 namespace linalg {
 
 enum LinalgTypes {
-  Range = mlir::Type::FIRST_LINALG_TYPE,
+  Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
   View,
-  LAST_USED_LINALG_TYPE = View,
+  FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View,
 };
 
 } // namespace linalg
diff --git a/mlir/include/mlir/Linalg/LinalgOps.h b/mlir/include/mlir/Linalg/LinalgOps.h
new file mode 100644 (file)
index 0000000..7921822
--- /dev/null
@@ -0,0 +1,53 @@
+//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_LINALGOPS_H_
+#define MLIR_LINALG_LINALGOPS_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+/// A RangeOp is used to create a value of RangeType from 3 values of type index
+/// that represent the min, max and step values of the range.
+class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
+                          OpTrait::OneResult, OpTrait::HasNoSideEffect> {
+public:
+  using Op::Op;
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Hooks to customize the behavior of this op.
+  //////////////////////////////////////////////////////////////////////////////
+  static llvm::StringRef getOperationName() { return "linalg.range"; }
+  static void build(Builder *b, OperationState *result, Value *min, Value *max,
+                    Value *step);
+  LogicalResult verify();
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Op-specific functionality.
+  //////////////////////////////////////////////////////////////////////////////
+  Value *min() { return getOperand(0); }
+  Value *max() { return getOperand(1); }
+  Value *step() { return getOperand(2); }
+};
+
+} // namespace mlir
+
+#endif // MLIR_LINALG_LINALGOPS_H_
diff --git a/mlir/include/mlir/Linalg/LinalgTypes.h b/mlir/include/mlir/Linalg/LinalgTypes.h
new file mode 100644 (file)
index 0000000..2d2c74e
--- /dev/null
@@ -0,0 +1,59 @@
+//===- LinalgTypes.h - Linalg Types ---------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_LINALGTYPES_H_
+#define MLIR_LINALG_LINALGTYPES_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+class MLIRContext;
+
+enum LinalgTypes {
+  Range = Type::FIRST_LINALG_TYPE,
+  LAST_USED_LINALG_TYPE = Range,
+};
+
+class LinalgDialect : public Dialect {
+public:
+  explicit LinalgDialect(MLIRContext *context);
+
+  /// Parse a type registered to this dialect.
+  Type parseType(llvm::StringRef spec, Location loc) const override;
+
+  /// Print a type registered to this dialect.
+  void printType(Type type, llvm::raw_ostream &os) const override;
+};
+
+/// A RangeType represents a minimal range abstraction (min, max, step).
+class RangeType : public Type::TypeBase<RangeType, Type> {
+public:
+  // Used for generic hooks in TypeBase.
+  using Base::Base;
+  /// Construction hook.
+  static RangeType get(MLIRContext *context) {
+    /// Custom, uniq'ed construction in the MLIRContext.
+    return Base::get(context, LinalgTypes::Range);
+  }
+  /// Used to implement llvm-style cast.
+  static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
+};
+
+} // namespace mlir
+
+#endif // MLIR_LINALG_LINALGTYPES_H_
index 3d05cd1..920cf79 100644 (file)
@@ -6,6 +6,7 @@ add_subdirectory(ExecutionEngine)
 add_subdirectory(FxpMathOps)
 add_subdirectory(IR)
 add_subdirectory(LLVMIR)
+add_subdirectory(Linalg)
 add_subdirectory(Parser)
 add_subdirectory(Pass)
 add_subdirectory(Quantization)
diff --git a/mlir/lib/Linalg/CMakeLists.txt b/mlir/lib/Linalg/CMakeLists.txt
new file mode 100644 (file)
index 0000000..b1df307
--- /dev/null
@@ -0,0 +1,8 @@
+add_llvm_library(MLIRLinalg
+  LinalgOps.cpp
+  LinalgRegistration.cpp
+  LinalgTypes.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg
+  )
diff --git a/mlir/lib/Linalg/LinalgOps.cpp b/mlir/lib/Linalg/LinalgOps.cpp
new file mode 100644 (file)
index 0000000..bba47fb
--- /dev/null
@@ -0,0 +1,67 @@
+//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a the Linalg operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Linalg/LinalgOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Linalg/LinalgTypes.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+
+void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
+                          Value *max, Value *step) {
+  result->addOperands({min, max, step});
+  result->addTypes({RangeType::get(b->getContext())});
+}
+
+// Verification is simply that a RangeOp takes 3 index ssa-value.
+mlir::LogicalResult mlir::RangeOp::verify() {
+  if (!min() || !min()->getType().isa<IndexType>())
+    return emitOpError("first operand should be of type index");
+  if (!max() || !max()->getType().isa<IndexType>())
+    return emitOpError("second operand should be of type index");
+  if (!step() || !step()->getType().isa<IndexType>())
+    return emitOpError("third operand should be of type index");
+  return mlir::success();
+}
+
+// A RangeOp prints as:
+//
+// ```{.mlir}
+//   linalg.range %0:%1:%2 : !linalg.range
+// ```
+void mlir::RangeOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
+     << " : " << getType();
+}
+
+bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
+  RangeType type;
+  auto affineIntTy = parser->getBuilder().getIndexType();
+  return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
+         parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
+         parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
+         parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
+         parser->addTypeToList(type, result->types);
+}
diff --git a/mlir/lib/Linalg/LinalgRegistration.cpp b/mlir/lib/Linalg/LinalgRegistration.cpp
new file mode 100644 (file)
index 0000000..3637037
--- /dev/null
@@ -0,0 +1,24 @@
+//===- LinalgRegistration.cpp - Register the linalg dialect statically ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Linalg/LinalgOps.h"
+#include "mlir/Linalg/LinalgTypes.h"
+
+using namespace mlir;
+
+// Static initialization for LinalgOps dialect registration.
+static DialectRegistration<LinalgDialect> LinalgOps;
diff --git a/mlir/lib/Linalg/LinalgTypes.cpp b/mlir/lib/Linalg/LinalgTypes.cpp
new file mode 100644 (file)
index 0000000..7aabfd2
--- /dev/null
@@ -0,0 +1,53 @@
+//===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the Linalg dialect types and dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Linalg/LinalgTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Linalg/LinalgOps.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+
+mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
+    : Dialect("linalg", context) {
+  addTypes<RangeType>();
+  addOperations<RangeOp>();
+}
+
+Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
+  MLIRContext *context = getContext();
+  if (spec == "range")
+    return RangeType::get(getContext());
+  return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
+}
+
+/// RangeType prints as just "range".
+static void print(RangeType rt, raw_ostream &os) { os << "range"; }
+
+void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
+  switch (type.getKind()) {
+  default:
+    llvm_unreachable("Unhandled Linalg type");
+  case LinalgTypes::Range:
+    print(type.cast<RangeType>(), os);
+    break;
+  }
+}
diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir
new file mode 100644 (file)
index 0000000..f98558a
--- /dev/null
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s -verify | mlir-opt | FileCheck %s
+
+func @range(%arg0: index, %arg1: index, %arg2: index) {
+  %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
+  return
+}
+// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
+//  CHECK-NEXT:  %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
\ No newline at end of file
index a68b30d..f6e9ac7 100644 (file)
@@ -3,6 +3,7 @@ set(LIBS
   MLIRAnalysis
   MLIREDSC
   MLIRFxpMathOps
+  MLIRLinalg
   MLIRLLVMIR
   MLIRParser
   MLIRPass