add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
+add_subdirectory(MLProgram)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
--- /dev/null
+add_subdirectory(IR)
--- /dev/null
+set(LLVM_TARGET_DEFINITIONS MLProgramOps.td)
+add_mlir_dialect(MLProgramOps ml_program)
+add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc)
--- /dev/null
+//===- MLProgram.h - MLProgram dialect ----------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
+#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// MLProgramDialect
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// MLProgram Dialect Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.h.inc"
+
+#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
--- /dev/null
+//===- MLProgramBase.td - Base defs for ml_program dialect --*- tablegen -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLPROGRAM_BASE
+#define MLPROGRAM_BASE
+
+include "mlir/IR/OpBase.td"
+
+def MLProgram_Dialect : Dialect {
+ let name = "ml_program";
+ let cppNamespace = "::mlir::ml_program";
+ let description = [{
+ The MLProgram dialect contains structural operations and types for
+ defining a compiled Machine-Learning program, as created from common
+ ML frameworks, such as TensorFlow, PyTorch, JAX, etc. It does not itself
+ define computation ops common to such frameworks but establishes a common
+ programming model for establishing modules, functions, globals and
+ memory model components appropriate for such an abstract level of detail.
+
+ This dialect is under active development, and while stability is an
+ eventual goal, it is not guaranteed at this juncture. Given the early state,
+ it is recommended to inquire further prior to using this dialect.
+ }];
+
+ let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+}
+
+#endif // MLPROGRAM_BASE
--- /dev/null
+//===- MLProgramOps.td - Structural ML Program Ops ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLPROGRAM_OPS
+#define MLPROGRAM_OPS
+
+include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+class MLProgram_Op<string mnemonic, list<Trait> traits = []> :
+ Op<MLProgram_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_FuncOp : MLProgram_Op<"func", [
+ CallableOpInterface, FunctionOpInterface, IsolatedFromAbove,
+ RegionKindInterface, Symbol
+ ]> {
+ let summary = "Function containing a single `SSACFG` region";
+ let description = [{
+ This simple function container represents callables in an ML program where
+ the body is an `SSACFG` region. It must be terminated by a `return` op which
+ yields values with the same arity and types as the `FunctionType` results
+ of the containing `func`.
+
+ This op is a `Symbol` but does not introduce a new `SymbolTable`. As such,
+ it cannot represent nested symbols.
+
+ Example:
+
+ ```mlir
+ ml_program.func private @some_extern(i32) -> i32
+ ml_program.func @compute(%arg0 : i32) -> i32 {
+ ml_program.return %arg0 : i32
+ }
+ ```
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let regions = (region AnyRegion:$body);
+
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // CallableOpInterface
+ //===------------------------------------------------------------------===//
+
+ /// Returns the region on the current operation that is callable. This may
+ /// return null in the case of an external callable object, e.g. an external
+ /// function.
+ ::mlir::Region *getCallableRegion() {
+ return isExternal() ? nullptr : &getBody();
+ }
+
+ /// Returns the results types that the callable region produces when
+ /// executed.
+ ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the argument types of this function.
+ ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+ /// Returns the result types of this function.
+ ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // RegionKindInterface Methods
+ //===------------------------------------------------------------------===//
+ static ::mlir::RegionKind getRegionKind(unsigned index) {
+ return ::mlir::RegionKind::SSACFG;
+ }
+
+ //===------------------------------------------------------------------===//
+ // SymbolOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ bool isDeclaration() { return isExternal(); }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// SubgraphOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
+ CallableOpInterface, FunctionOpInterface, HasOnlyGraphRegion,
+ IsolatedFromAbove, RegionKindInterface, SingleBlock, Symbol
+ ]> {
+ let summary = "An function containing a single `Graph` region";
+ let description = [{
+ This simple function container represents callables in an ML program where
+ the body is a `Graph` region containing a single block. It must be
+ terminated by an `output` op which yields values with the same arity and
+ types as the `FunctionType` results of the containing `subgraph`.
+
+ This op is a `Symbol` but does not introduce a new `SymbolTable`. As such,
+ it cannot represented nested symbols.
+
+ Example:
+
+ ```mlir
+ ml_program.subgraph private @some_extern(i32) -> i32
+ ml_program.subgraph @compute(%arg0 : i32) -> i32 {
+ ml_program.output %arg0 : i32
+ }
+ ```
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let regions = (region AnyRegion:$body);
+
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // CallableOpInterface
+ //===------------------------------------------------------------------===//
+
+ /// Returns the region on the current operation that is callable. This may
+ /// return null in the case of an external callable object, e.g. an external
+ /// function.
+ ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+ /// Returns the results types that the callable region produces when
+ /// executed.
+ ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the argument types of this function.
+ ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+ /// Returns the result types of this function.
+ ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // SymbolOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ bool isDeclaration() { return isExternal(); }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// OutputOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_OutputOp : MLProgram_Op<"output", [
+ NoSideEffect, HasParent<"SubgraphOp">, ReturnLike, Terminator
+ ]> {
+ let summary = "Outputs values from a subgraph function";
+ let description = [{
+ The `output` operation terminates a subgraph by yielding values
+ to the caller.
+ The operation takes variable number of operands and produces no results.
+ The operand number and types must match the signature of the function
+ that contains the operation.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [OpBuilder<(ins), [{
+ build($_builder, $_state, llvm::None);
+ }]>];
+
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_ReturnOp : MLProgram_Op<"return", [
+ NoSideEffect, HasParent<"FuncOp">, ReturnLike, Terminator
+ ]> {
+ let summary = "Returns values from a `func` function";
+ let description = [{
+ The `return` operation terminates a `func` function by yielding values
+ to the caller.
+ The operation takes variable number of operands and produces no results.
+ The operand number and types must match the signature of the function
+ that contains the operation.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [OpBuilder<(ins), [{
+ build($_builder, $_state, llvm::None);
+ }]>];
+
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+ let hasVerifier = 1;
+}
+
+#endif // MLPROGRAM_OPS
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
linalg::LinalgDialect,
math::MathDialect,
memref::MemRefDialect,
+ ml_program::MLProgramDialect,
scf::SCFDialect,
omp::OpenMPDialect,
pdl::PDLDialect,
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
+add_subdirectory(MLProgram)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
--- /dev/null
+add_subdirectory(IR)
--- /dev/null
+add_mlir_dialect_library(MLIRMLProgram
+ MLProgramOps.cpp
+ MLProgramDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram
+
+ DEPENDS
+ MLIRMLProgramOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRDialect
+ MLIRInferTypeOpInterface
+ MLIRIR
+ )
--- /dev/null
+//===- MLProgramDialect.cpp - MLProgram dialect implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+
+using namespace mlir;
+using namespace mlir::ml_program;
+
+#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
+
+void ml_program::MLProgramDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
+ >();
+}
--- /dev/null
+//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/FunctionImplementation.h"
+
+using namespace mlir;
+using namespace mlir::ml_program;
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto buildFuncType =
+ [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+ function_interface_impl::VariadicFlag,
+ std::string &) { return builder.getFunctionType(argTypes, results); };
+
+ return function_interface_impl::parseFunctionOp(
+ parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// SubgraphOp
+//===----------------------------------------------------------------------===//
+
+ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto buildFuncType =
+ [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+ function_interface_impl::VariadicFlag,
+ std::string &) { return builder.getFunctionType(argTypes, results); };
+
+ return function_interface_impl::parseFunctionOp(
+ parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void SubgraphOp::print(OpAsmPrinter &p) {
+ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// OutputOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult OutputOp::verify() {
+ auto function = cast<SubgraphOp>((*this)->getParentOp());
+
+ // The operand number and types must match the function signature.
+ const auto &results = function.getFunctionType().getResults();
+ if (getNumOperands() != results.size())
+ return emitOpError("has ")
+ << getNumOperands() << " operands, but enclosing function (@"
+ << function.getName() << ") outputs " << results.size();
+
+ for (unsigned i = 0, e = results.size(); i != e; ++i)
+ if (getOperand(i).getType() != results[i])
+ return emitError() << "type of output operand " << i << " ("
+ << getOperand(i).getType()
+ << ") doesn't match function result type ("
+ << results[i] << ")"
+ << " in function @" << function.getName();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReturnOp::verify() {
+ auto function = cast<FuncOp>((*this)->getParentOp());
+
+ // The operand number and types must match the function signature.
+ const auto &results = function.getFunctionType().getResults();
+ if (getNumOperands() != results.size())
+ return emitOpError("has ")
+ << getNumOperands() << " operands, but enclosing function (@"
+ << function.getName() << ") returns " << results.size();
+
+ for (unsigned i = 0, e = results.size(); i != e; ++i)
+ if (getOperand(i).getType() != results[i])
+ return emitError() << "type of return operand " << i << " ("
+ << getOperand(i).getType()
+ << ") doesn't match function result type ("
+ << results[i] << ")"
+ << " in function @" << function.getName();
+
+ return success();
+}
--- /dev/null
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -verify-diagnostics %s
+
+ml_program.func @ssa_enforced(%arg0 : i32) -> i32 {
+ // expected-error @+1 {{does not dominate this use}}
+ %1 = "unregistered.dummy"(%0) : (i32) -> i32
+ // expected-note @+1 {{operand defined here}}
+ %0 = "unregistered.dummy"(%arg0) : (i32) -> i32
+ ml_program.return %0 : i32
+}
+
+// -----
+ml_program.func @return_arity_match(%arg0 : i32) -> i32 {
+ // expected-error @+1 {{enclosing function (@return_arity_match) returns 1}}
+ ml_program.return %arg0, %arg0 : i32, i32
+}
+
+// -----
+ml_program.func @return_type_match(%arg0 : i64) -> i32 {
+ // expected-error @+1 {{doesn't match function result}}
+ ml_program.return %arg0 : i64
+}
+
+// -----
+ml_program.subgraph @output_arity_match(%arg0 : i32) -> i32 {
+ // expected-error @+1 {{enclosing function (@output_arity_match) outputs 1}}
+ ml_program.output %arg0, %arg0 : i32, i32
+}
+
+// -----
+ml_program.subgraph @output_type_match(%arg0 : i64) -> i32 {
+ // expected-error @+1 {{doesn't match function result}}
+ ml_program.output %arg0 : i64
+}
--- /dev/null
+// RUN: mlir-opt %s --allow-unregistered-dialect | mlir-opt --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --allow-unregistered-dialect --mlir-print-op-generic | mlir-opt --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: ml_program.func private @extern_func
+ml_program.func private @extern_func(i32) -> i32
+
+// CHECK-LABEL: ml_program.func @defined_func
+ml_program.func @defined_func(%arg0 : i32) -> i32 {
+ ml_program.return %arg0 : i32
+}
+
+// CHECK-LABEL: ml_program.subgraph private @extern_subgraph
+ml_program.subgraph private @extern_subgraph(i32) -> i32
+
+// CHECK-LABEL: ml_program.subgraph @compute_subgraph
+ml_program.subgraph @compute_subgraph(%arg0 : i32) -> i32 {
+ %1 = "unregistered.dummy"(%0) : (i32) -> i32
+ %0 = "unregistered.dummy"(%arg0) : (i32) -> i32
+ ml_program.output %0 : i32
+}
// CHECK-NEXT: llvm
// CHECK-NEXT: math
// CHECK-NEXT: memref
+// CHECK-NEXT: ml_program
// CHECK-NEXT: nvvm
// CHECK-NEXT: omp
// CHECK-NEXT: pdl
":LinalgToSPIRV",
":LinalgToStandard",
":LinalgTransforms",
+ ":MLProgramDialect",
":MathDialect",
":MathToLLVM",
":MathToLibm",
],
)
+##---------------------------------------------------------------------------##
+# MLProgram dialect
+##---------------------------------------------------------------------------##
+
+td_library(
+ name = "MLProgramOpsTdFiles",
+ srcs = [
+ "include/mlir/Dialect/MLProgram/IR/MLProgramBase.td",
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
+ ],
+ includes = ["include"],
+ deps = [
+ ":CallInterfacesTdFiles",
+ ":ControlFlowInterfacesTdFiles",
+ ":FunctionInterfacesTdFiles",
+ ":OpBaseTdFiles",
+ ":RegionKindInterfaceIncGen",
+ ":SideEffectInterfacesTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "MLProgramOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc",
+ ),
+ (
+ ["-gen-dialect-decls"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.h.inc",
+ ),
+ (
+ ["-gen-dialect-defs"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
+ deps = [":MLProgramOpsTdFiles"],
+)
+
+cc_library(
+ name = "MLProgramDialect",
+ srcs = glob([
+ "lib/Dialect/MLProgram/IR/*.cpp",
+ "lib/Dialect/MLProgram/IR/*.h",
+ ]),
+ hdrs = glob([
+ "include/mlir/Dialect/MLProgram/IR/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":ControlFlowInterfaces",
+ ":IR",
+ ":MLProgramOpsIncGen",
+ ":Pass",
+ ":Support",
+ "//llvm:Support",
+ ],
+)
+
+##---------------------------------------------------------------------------##
+# Allocation interfaces
+##---------------------------------------------------------------------------##
+
td_library(
name = "AllocationOpInterfaceTdFiles",
srcs = ["include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"],