[mlir][linalg] Add symbolic type conversion to linalg named ops.
authorStella Laurenzo <stellaraccident@gmail.com>
Sat, 27 Feb 2021 02:01:15 +0000 (18:01 -0800)
committerStella Laurenzo <stellaraccident@gmail.com>
Sat, 27 Feb 2021 23:52:35 +0000 (15:52 -0800)
commit2ceedc3a201386c6cbbcea5cec3f5e01d04f6445
tree7b2367e39ed31acaac99ffa3462d68f7f1c0d841
parent5867c18e2c0d403b51a594897d56d935286748e4
[mlir][linalg] Add symbolic type conversion to linalg named ops.

This enables this kind of construct in the DSL to generate a named op that is polymorphic over numeric type variables `T` and `U`, generating the correct arithmetic casts at construction time:

```
@tc_def_op
def polymorphic_matmul(A=TensorDef(T1, S.M, S.K),
                       B=TensorDef(T2, S.K, S.N),
                       C=TensorDef(U, S.M, S.N, output=True)):
  implements(ContractionOpInterface)
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```

Presently, this only supports type variables that are bound to the element type of one of the arguments, although a further extension that allows binding a type variable to an attribute would allow some more expressiveness and may be useful for some formulations. This is left to a future patch. In addition, this patch does not yet materialize the verifier support which ensures that types are bound correctly (for such simple examples, failing to do so will yield IR that fails verification, it just won't yet fail with a precise error).

Note that the full grid of extensions/truncation/int<->float conversions are supported, but many of them are lossy and higher level code needs to be mindful of numerics (it is not the job of this level).

As-is, this should be sufficient for most integer matmul scenarios we work with in typical quantization schemes.

Differential Revision: https://reviews.llvm.org/D97603
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp