From 4d4cb17da8509156ca690e3d7eaf2e00ab606780 Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 25 Feb 2022 15:04:38 +0000 Subject: [PATCH] [mlir][OpDSL] Refactor function handling. Prepare the OpDSL function handling to introduce more function classes. A follow up commit will split ArithFn into UnaryFn and BinaryFn. This revision prepares the split by adding a function kind enum to handle different function types using a single class on the various levels of the stack (for example, there is now one TensorFn and one ScalarFn). Depends On D119718 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D120108 --- .../Linalg/IR/LinalgNamedStructuredOps.yaml | 627 +++++++++++++-------- .../dialects/linalg/opdsl/lang/comprehension.py | 77 ++- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 18 +- .../mlir/dialects/linalg/opdsl/lang/scalar_expr.py | 86 ++- .../test-linalg-ods-yaml-gen.yaml | 16 +- .../python/dialects/linalg/opdsl/assignments.py | 32 +- .../mlir-linalg-ods-yaml-gen.cpp | 80 ++- 7 files changed, 535 insertions(+), 401 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index 5ebd121..fed9d39 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -45,29 +45,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: C value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + attr_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - attr_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + attr_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - attr_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matmul_unsigned @@ -109,29 +113,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: C value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: A - fn_name: cast_unsigned - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: B - fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_matmul @@ -183,51 +191,59 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: C value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: AZp - fn_name: cast - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: BZp - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: mmt4d @@ -280,29 +296,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: accum value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: accum - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: AccumType operands: - !ScalarExpression scalar_arg: lhs - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: AccumType operands: - !ScalarExpression scalar_arg: rhs - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matmul @@ -345,29 +365,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: C value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_batch_matmul @@ -420,51 +444,59 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: C value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: AZp - fn_name: cast - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: BZp - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matvec @@ -505,29 +537,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: x value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: x - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: y - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: vecmat @@ -568,29 +604,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: x value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: x - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: y - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matvec @@ -632,29 +672,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: C value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: dot @@ -694,29 +738,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: C value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d @@ -757,29 +805,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d @@ -822,29 +874,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d @@ -890,29 +946,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d_nwc_wcf @@ -970,29 +1030,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf @@ -1064,29 +1128,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf_q @@ -1171,51 +1239,59 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp - fn_name: cast - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nchw_fchw @@ -1287,29 +1363,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d_ndhwc_dhwcf @@ -1383,29 +1463,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_1d_nwc_wc @@ -1462,29 +1546,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwc @@ -1551,29 +1639,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwc_q @@ -1651,51 +1743,59 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp - fn_name: cast - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwcm @@ -1763,29 +1863,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwcm_q @@ -1865,51 +1969,59 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp - fn_name: cast - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_sum @@ -1975,18 +2087,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max @@ -2052,18 +2166,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: max operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max_unsigned @@ -2129,18 +2245,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: max_unsigned operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nchw_max @@ -2206,18 +2324,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: max operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min @@ -2283,18 +2403,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: min operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min_unsigned @@ -2360,18 +2482,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: min_unsigned operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast_unsigned --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_sum @@ -2443,18 +2567,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_max @@ -2526,18 +2652,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: max operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_min @@ -2609,18 +2737,20 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: min operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_tensor @@ -2651,12 +2781,13 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: value - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_rng_2d @@ -2703,107 +2834,128 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: T operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: F64 operands: - !ScalarExpression scalar_const: '2147483647 : i64' - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: F64 operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_index: 1 - fn_name: cast - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_index: 0 - fn_name: cast - !ScalarExpression scalar_arg: seed - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' - fn_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' - fn_name: cast - fn_name: cast - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: mul operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: sub operands: - !ScalarExpression @@ -2811,15 +2963,15 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: min - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: F64 operands: - !ScalarExpression scalar_const: '2.3283063999999999E-10 : f64' - fn_name: cast - !ScalarExpression scalar_arg: min - fn_name: cast --- !LinalgOpConfig metadata: !LinalgOpMetadata name: soft_plus_2d @@ -2852,28 +3004,33 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: log operands: - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_const: '1.000000e+00 : f64' - fn_name: cast - !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: exp operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - fn_name: cast diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 68c0880..d26aa07 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -133,55 +133,36 @@ class TensorUse(TensorExpression): f"[{', '.join([repr(i) for i in self.indices])}]") -class TensorArithFn(TensorExpression): - """Application of an arithmetic function.""" +class TensorFn(TensorExpression): + """Application of a tensor function.""" - def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]): - self.arith_fn = arith_fn - self.args = tuple(args) - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarArithFn(self.arith_fn.fn_name, - *[arg.to_scalar_expression() for arg in self.args - ]).expr() - - def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): - super().visit_tensor_exprs(callback) - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})" - - -class TensorTypeFn(TensorExpression): - """Application of a type conversion function.""" - - def __init__(self, type_fn: Optional["TypeFn"], - operand_def: Optional["OperandDef"], type_var: TypeVar, - arg: TensorExpression): - if bool(type_fn) + bool(operand_def) != 1: - raise ValueError("Either 'type_fn' or 'operand_def' must be specified") - self.type_fn = type_fn + def __init__(self, kind: "FunctionKind", name: Optional[str], + operand_def: Optional["OperandDef"], type_var: Optional[TypeVar], + args: Sequence[TensorExpression]): + if bool(name) + bool(operand_def) != 1: + raise ValueError("One of 'name', 'operand_def' must be specified") + self.name = name + self.kind = kind self.operand_def = operand_def self.type_var = type_var - self.arg = arg + self.args = args def to_scalar_expression(self) -> ScalarExpression: if self.operand_def: - assert self.operand_def.name, "TypeFnAttr not registered with an op" - fn_name = self.type_fn.fn_name if self.type_fn else None + assert self.operand_def.name, "TensorFn not registered with an op" attr_name = self.operand_def.name if self.operand_def else None - return ScalarTypeFn(fn_name, attr_name, self.type_var, - self.arg.to_scalar_expression()).expr() + args = [arg.to_scalar_expression() for arg in self.args] + return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): super().visit_tensor_exprs(callback) - self.arg.visit_tensor_exprs(callback) + for arg in self.args: + arg.visit_tensor_exprs(callback) def __repr__(self): - return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]" - f"({self.type_var}, {self.arg})") + name = self.operand_def.name if self.operand_def else self.name + return (f"{self.kind.name}.{name}(type_var={self.type_var}, " + f"args={', '.join(repr(a) for a in self.args)})") class TensorReduceFn(TensorExpression): @@ -194,7 +175,7 @@ class TensorReduceFn(TensorExpression): args: Sequence[TensorExpression]): self.reduce_use = reduce_use self.lhs = None # type: Optional[TensorUse] - self.args = tuple(args) + self.args = args def to_scalar_expression(self) -> ScalarExpression: if self.lhs is None: @@ -202,7 +183,8 @@ class TensorReduceFn(TensorExpression): f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr() + return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None, + None, full_args).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): for arg in self.args: @@ -259,6 +241,11 @@ class index(TensorExpression): ############################################################################### +class FunctionKind(Enum): + ARITH = 0 + TYPE = 1 + + class TypeFnType: """Type conversion function. @@ -269,8 +256,8 @@ class TypeFnType: def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType": - return TensorTypeFn(self, None, type_var, arg) + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) def __repr__(self): return f"{self.fn_name}" @@ -301,8 +288,8 @@ class ArithFnType: def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, *args) -> "TensorArithFn": - return TensorArithFn(self, args) + def __call__(self, *args) -> "TensorFn": + return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args) def __repr__(self): return f"{self.fn_name}" @@ -562,8 +549,8 @@ class TypeFnAttrDef: self.operand_def = OperandDef( OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name) - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn: - return TensorTypeFn(None, self.operand_def, type_var, arg) + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg]) ############################################################################### diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index fc8c13b..07050f5 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -270,19 +270,19 @@ class _BodyBuilder: dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result - elif expr.arith_fn: - fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}") + elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH: + fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}") operand_values = [ - self.expression(operand) for operand in expr.arith_fn.operands + self.expression(operand) for operand in expr.scalar_fn.operands ] return fn(*operand_values) - elif expr.type_fn: - fn_name = expr.type_fn.fn_name - if expr.type_fn.attr_name: - fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name] + elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE: + fn_name = expr.scalar_fn.fn_name + if expr.scalar_fn.attr_name: + fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name] fn = self._get_function(f"_typefn_{fn_name}") - operand = self.expression(expr.type_fn.operand) - return fn(expr.type_fn.type_var.name, operand) + operand_value = self.expression(expr.scalar_fn.operands[0]) + return fn(expr.scalar_fn.type_var.name, operand_value) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") def yield_outputs(self, *output_names: str): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index af21b40..aa894dc 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -15,13 +15,13 @@ can be easily consumed from the C++ side, not necessarily for ergonomics. from typing import Optional, Sequence -from .yaml_helper import * +from .comprehension import * from .types import * +from .yaml_helper import * __all__ = [ "ScalarAssign", - "ScalarArithFn", - "ScalarTypeFn", + "ScalarFn", "ScalarArg", "ScalarConst", "ScalarIndex", @@ -29,36 +29,27 @@ __all__ = [ ] -class ScalarArithFn: - """A type of ScalarExpression that applies an arithmetic function.""" - - def __init__(self, fn_name: str, *operands: "ScalarExpression"): - self.fn_name = fn_name - self.operands = operands - - def expr(self) -> "ScalarExpression": - return ScalarExpression(arith_fn=self) - - def __repr__(self): - return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})" - - -class ScalarTypeFn: - """A type of ScalarExpression that applies a type conversion function.""" +class ScalarFn: + """A type of ScalarExpression that applies a function.""" - def __init__(self, fn_name: Optional[str], attr_name: Optional[str], - type_var: TypeVar, operand: "ScalarExpression"): + def __init__(self, kind: "FunctionKind", fn_name: Optional[str], + attr_name: Optional[str], type_var: Optional["TypeVar"], + operands: Sequence["ScalarExpression"]): + if bool(fn_name) + bool(attr_name) != 1: + raise ValueError("One of 'fn_name', 'attr_name' must be specified") + self.kind = kind self.fn_name = fn_name self.attr_name = attr_name self.type_var = type_var - self.operand = operand + self.operands = operands def expr(self) -> "ScalarExpression": - return ScalarExpression(type_fn=self) + return ScalarExpression(scalar_fn=self) def __repr__(self): - return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>" - f"({self.type_var}, {self.operand})") + name = self.fn_name if self.fn_name else self.attr_name + return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, " + f"operands=[{', '.join(self.operands)}])") class ScalarArg: @@ -104,51 +95,38 @@ class ScalarExpression(YAMLObject): """An expression on scalar values. Can be one of: - - ScalarArithFn - - ScalarTypeFn + - ScalarFn - ScalarArg - ScalarConst - ScalarIndex - - ScalarSymbolicCast """ yaml_tag = "!ScalarExpression" def __init__(self, - arith_fn: Optional[ScalarArithFn] = None, - type_fn: Optional[ScalarTypeFn] = None, + scalar_fn: Optional[ScalarFn] = None, scalar_arg: Optional[ScalarArg] = None, scalar_const: Optional[ScalarConst] = None, scalar_index: Optional[ScalarIndex] = None): - if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) + + if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index)) != 1: - raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', " - "'scalar_const', 'scalar_index', must be specified") - self.arith_fn = arith_fn - self.type_fn = type_fn + raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or " + "'scalar_index' must be specified") + self.scalar_fn = scalar_fn self.scalar_arg = scalar_arg self.scalar_const = scalar_const self.scalar_index = scalar_index def to_yaml_custom_dict(self): - if self.arith_fn: - return dict( - arith_fn=dict( - fn_name=self.arith_fn.fn_name, - operands=list(self.arith_fn.operands), - )) - if self.type_fn: - # Note that even though operands must be arity 1, we write it the - # same way as for apply because it allows handling code to be more - # generic vs having a special form. - type_fn_dict = dict( - type_var=self.type_fn.type_var.name, - operands=[self.type_fn.operand], - ) - if self.type_fn.fn_name: - type_fn_dict["fn_name"] = self.type_fn.fn_name - if self.type_fn.attr_name: - type_fn_dict["attr_name"] = self.type_fn.attr_name - return dict(type_fn=type_fn_dict) + if self.scalar_fn: + scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower()) + if self.scalar_fn.fn_name: + scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name + if self.scalar_fn.attr_name: + scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name + if self.scalar_fn.type_var: + scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name + scalar_fn_dict["operands"] = list(self.scalar_fn.operands) + return dict(scalar_fn=scalar_fn_dict) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) elif self.scalar_const: diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml index f4019e8..660637e 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -39,23 +39,26 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - arith_fn: + scalar_fn: + kind: arith fn_name: add operands: - !ScalarExpression - type_fn: + scalar_fn: + kind: type + attr_name: cast type_var: T operands: - !ScalarExpression scalar_const: '42 : i64' - attr_name: cast - !ScalarExpression - type_fn: + scalar_fn: + kind: type + attr_name: cast type_var: T operands: - !ScalarExpression scalar_index: 1 - attr_name: cast # ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1" @@ -236,7 +239,8 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - type_fn: + scalar_fn: + kind: type fn_name: cast type_var: U operands: diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py index 926a05e..5b87216 100644 --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -9,22 +9,24 @@ from mlir.dialects.linalg.opdsl.lang import * # CHECK: - # CHECK: arg: C # CHECK: value: -# CHECK: arith_fn: +# CHECK: scalar_fn: # CHECK: fn_name: add # CHECK: operands: -# CHECK: arith_fn: +# CHECK: scalar_fn: # CHECK: fn_name: mul # CHECK: operands: -# CHECK: type_fn: +# CHECK: scalar_fn: +# CHECK: kind: type +# CHECK: attr_name: cast # CHECK: type_var: U # CHECK: operands: # CHECK: scalar_arg: A +# CHECK: scalar_fn: +# CHECK: kind: type # CHECK: attr_name: cast -# CHECK: type_fn: # CHECK: type_var: U # CHECK: operands: # CHECK: scalar_arg: B -# CHECK: attr_name: cast @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), @@ -39,21 +41,28 @@ def matmul( # CHECK: assignments: # CHECK: - # CHECK: arg: O -# CHECK: arith_fn: +# CHECK: scalar_fn: +# CHECK: kind: arith # CHECK: fn_name: sub # CHECK: operands: -# CHECK: arith_fn: +# CHECK: scalar_fn: +# CHECK: kind: arith # CHECK: fn_name: add # CHECK: operands: -# CHECK: type_fn: +# CHECK: scalar_fn: +# CHECK: kind: type # CHECK: type_var: T # CHECK: operands: # CHECK: scalar_const: '3.1415926535897931 : f64' -# CHECK: type_fn: +# CHECK: scalar_fn: +# CHECK: kind: type +# CHECK: fn_name: cast # CHECK: type_var: T # CHECK: operands: # CHECK: scalar_const: '42 : i64' -# CHECK: type_fn: +# CHECK: scalar_fn: +# CHECK: kind: type +# CHECK: fn_name: cast # CHECK: type_var: T # CHECK: operands: # CHECK: scalar_const: '1.{{[0]*}}e+03 : f64' @@ -70,7 +79,8 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)): # CHECK: assignments: # CHECK: - # CHECK: arg: O -# CHECK: arith_fn: +# CHECK: scalar_fn: +# CHECK: kind: arith # CHECK: fn_name: add # CHECK: operands: # CHECK: scalar_index: 1 diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 7c850e6..d1fc9ac 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -90,28 +90,23 @@ struct LinalgIndexingMapsConfig { struct ScalarExpression; -struct ScalarArithFn { - std::string fnName; - // NOTE: Must be pure heap allocated container (not SmallVector) - // due to recursive data type. - std::vector operands; -}; +enum class ScalarFnKind { Arith, Type }; -struct ScalarTypeFn { - std::string typeVar; +struct ScalarFn { + ScalarFnKind kind; + Optional fnName; + Optional attrName; + Optional typeVar; // NOTE: This must be of arity 1, but to break the self-referential cycle, // we use a heap allocated vector. std::vector operands; - Optional fnName; - Optional attrName; }; struct ScalarExpression { Optional arg; Optional constant; Optional index; - Optional arithFn; - Optional typeFn; + Optional scalarFn; }; struct ScalarAssign { @@ -265,16 +260,23 @@ struct MappingTraits { /// - `scalar_arg`: An operation argument. /// - `scalar_const`: A constant definition. /// - `scalar_index`: An iteration index. -/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`). -/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`). +/// - `scalar_fn`: A named function (see `ScalarFn`). template <> struct MappingTraits { static void mapping(IO &io, ScalarExpression &info) { io.mapOptional("scalar_arg", info.arg); io.mapOptional("scalar_const", info.constant); io.mapOptional("scalar_index", info.index); - io.mapOptional("arith_fn", info.arithFn); - io.mapOptional("type_fn", info.typeFn); + io.mapOptional("scalar_fn", info.scalarFn); + } +}; + +/// Scalar function kind enum. +template <> +struct ScalarEnumerationTraits { + static void enumeration(IO &io, ScalarFnKind &value) { + io.enumCase(value, "arith", ScalarFnKind::Arith); + io.enumCase(value, "type", ScalarFnKind::Type); } }; @@ -284,20 +286,13 @@ struct MappingTraits { /// - `add(lhs, rhs)` /// - `mul(lhs, rhs)` template <> -struct MappingTraits { - static void mapping(IO &io, ScalarArithFn &info) { - io.mapRequired("fn_name", info.fnName); - io.mapRequired("operands", info.operands); - } -}; - -template <> -struct MappingTraits { - static void mapping(IO &io, ScalarTypeFn &info) { - io.mapRequired("type_var", info.typeVar); - io.mapRequired("operands", info.operands); +struct MappingTraits { + static void mapping(IO &io, ScalarFn &info) { + io.mapRequired("kind", info.kind); io.mapOptional("fn_name", info.fnName); io.mapOptional("attr_name", info.attrName); + io.mapOptional("type_var", info.typeVar); + io.mapRequired("operands", info.operands); } }; @@ -1060,11 +1055,12 @@ if ({0}Iter != attrs.end()) {{ cppIdent, *expression.index)); return cppIdent; } - if (expression.arithFn) { + if (expression.scalarFn && + expression.scalarFn->kind == ScalarFnKind::Arith) { // Apply function. // Recursively generate operands. SmallVector operandCppValues; - for (ScalarExpression &operand : expression.arithFn->operands) { + for (ScalarExpression &operand : expression.scalarFn->operands) { auto operandCppValue = generateExpression(operand); if (!operandCppValue) return None; @@ -1073,28 +1069,30 @@ if ({0}Iter != attrs.end()) {{ std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back( llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent, - expression.arithFn->fnName, + expression.scalarFn->fnName, interleaveToString(operandCppValues, ", "))); return cppIdent; } - if (expression.typeFn) { + if (expression.scalarFn && + expression.scalarFn->kind == ScalarFnKind::Type) { // Symbolic cast. // Operands must be arity 1. - if (expression.typeFn->operands.size() != 1) { + if (expression.scalarFn->operands.size() != 1) { emitError(genContext.getLoc()) << "type conversion operand arity must be 1"; return None; } Optional operandCppValue = - generateExpression(expression.typeFn->operands[0]); + generateExpression(expression.scalarFn->operands[0]); if (!operandCppValue) return None; + assert(expression.scalarFn->typeVar.hasValue()); Optional typeCppValue = - findTypeValue(expression.typeFn->typeVar, args); + findTypeValue(expression.scalarFn->typeVar.getValue(), args); if (!typeCppValue) { emitError(genContext.getLoc()) - << "type variable " << expression.typeFn->typeVar + << "type variable " << expression.scalarFn->typeVar.getValue() << ", used in a type conversion, must map to a predefined or " << "an argument type but it does not"; return None; @@ -1102,17 +1100,17 @@ if ({0}Iter != attrs.end()) {{ // Use the function name or the attribute to build the type function. std::string typeFunc = llvm::formatv( - "TypeFn::{0}", expression.typeFn->fnName.getValueOr("")); - if (expression.typeFn->attrName) { + "TypeFn::{0}", expression.scalarFn->fnName.getValueOr("")); + if (expression.scalarFn->attrName) { if (llvm::none_of(args, [&](LinalgOperandDef &arg) { return arg.kind == LinalgOperandDefKind::TypeFnAttr && - arg.name == expression.typeFn->attrName.getValue(); + arg.name == expression.scalarFn->attrName.getValue(); })) { emitError(genContext.getLoc()) << "missing type function attribute " - << expression.typeFn->attrName.getValue(); + << expression.scalarFn->attrName.getValue(); } - typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName); + typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName); } std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back(llvm::formatv( -- 2.7.4