[mlir][OpDSL] Add type function attributes.
authorgysit <gysit@google.com>
Fri, 25 Feb 2022 08:12:34 +0000 (08:12 +0000)
committergysit <gysit@google.com>
Fri, 25 Feb 2022 08:25:23 +0000 (08:25 +0000)
commit51fdd802c794faf6e5b57cccd6ea181c7f8a83e9
tree6eec6a12cddd18d803af616870d678559698252e
parent3fe6f9388fd3ef61d47b1f8b573596fc44b2f144
[mlir][OpDSL] Add type function attributes.

Previously, OpDSL operation used hardcoded type conversion operations (cast or cast_unsigned). Supporting signed and unsigned casts thus meant implementing two different operations. Type function attributes allow us to define a single operation that has a cast type function attribute which at operation instantiation time may be set to cast or cast_unsigned. We may for example, defina a matmul operation with a cast argument:

```
@linalg_structured_op
def matmul(A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```

When instantiating the operation the attribute may be set to the desired cast function:

```
linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
```

The revsion introduces a enum in the Linalg dialect that maps one-by-one to the type functions defined by OpDSL.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D119718
24 files changed:
mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/test/python/dialects/linalg/opdsl/arguments.py
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
mlir/test/python/dialects/linalg/ops.py
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel