[mlir][OpDSL] Refactor function handling.
authorgysit <gysit@google.com>
Fri, 25 Feb 2022 15:04:38 +0000 (15:04 +0000)
committergysit <gysit@google.com>
Fri, 25 Feb 2022 15:05:32 +0000 (15:05 +0000)
Prepare the OpDSL function handling to introduce more function classes. A follow up commit will split ArithFn into UnaryFn and BinaryFn. This revision prepares the split by adding a function kind enum to handle different function types using a single class on the various levels of the stack (for example, there is now one TensorFn and one ScalarFn).

Depends On D119718

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D120108

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

index 5ebd121..fed9d39 100644 (file)
@@ -45,29 +45,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: C
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                attr_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
-                attr_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                attr_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
-                attr_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul_unsigned
@@ -109,29 +113,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: C
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast_unsigned
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
-                fn_name: cast_unsigned
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast_unsigned
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
-                fn_name: cast_unsigned
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
@@ -183,51 +191,59 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: C
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: A
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: AZp
-                    fn_name: cast
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: B
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: BZp
-                    fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: mmt4d
@@ -280,29 +296,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: accum
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: accum
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: AccumType
                 operands:
                 - !ScalarExpression
                   scalar_arg: lhs
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: AccumType
                 operands:
                 - !ScalarExpression
                   scalar_arg: rhs
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul
@@ -345,29 +365,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: C
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_batch_matmul
@@ -420,51 +444,59 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: C
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: A
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: AZp
-                    fn_name: cast
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: B
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: BZp
-                    fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
@@ -505,29 +537,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: x
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: x
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: y
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: vecmat
@@ -568,29 +604,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: x
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: x
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: y
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matvec
@@ -632,29 +672,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: C
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: dot
@@ -694,29 +738,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: C
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: C
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: A
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_1d
@@ -757,29 +805,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d
@@ -822,29 +874,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_3d
@@ -890,29 +946,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_1d_nwc_wcf
@@ -970,29 +1030,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nhwc_hwcf
@@ -1064,29 +1128,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nhwc_hwcf_q
@@ -1171,51 +1239,59 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: IZp
-                    fn_name: cast
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: K
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: KZp
-                    fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw_fchw
@@ -1287,29 +1363,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_3d_ndhwc_dhwcf
@@ -1383,29 +1463,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_1d_nwc_wc
@@ -1462,29 +1546,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_nhwc_hwc
@@ -1551,29 +1639,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_nhwc_hwc_q
@@ -1651,51 +1743,59 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: IZp
-                    fn_name: cast
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: K
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: KZp
-                    fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_nhwc_hwcm
@@ -1763,29 +1863,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: I
-                fn_name: cast
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: K
-                fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_nhwc_hwcm_q
@@ -1865,51 +1969,59 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: mul
             operands:
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: IZp
-                    fn_name: cast
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: sub
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: K
-                    fn_name: cast
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: KZp
-                    fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_sum
@@ -1975,18 +2087,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_max
@@ -2052,18 +2166,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: max
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_max_unsigned
@@ -2129,18 +2245,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: max_unsigned
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast_unsigned
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast_unsigned
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nchw_max
@@ -2206,18 +2324,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: max
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_min
@@ -2283,18 +2403,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: min
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_min_unsigned
@@ -2360,18 +2482,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: min_unsigned
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast_unsigned
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast_unsigned
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_ndhwc_sum
@@ -2443,18 +2567,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_ndhwc_max
@@ -2526,18 +2652,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: max
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_ndhwc_min
@@ -2609,18 +2737,20 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: min
         operands:
         - !ScalarExpression
           scalar_arg: O
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            fn_name: cast
             type_var: U
             operands:
             - !ScalarExpression
               scalar_arg: I
-            fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_tensor
@@ -2651,12 +2781,13 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      type_fn:
+      scalar_fn:
+        kind: type
+        fn_name: cast
         type_var: U
         operands:
         - !ScalarExpression
           scalar_arg: value
-        fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
@@ -2703,107 +2834,128 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      type_fn:
+      scalar_fn:
+        kind: type
+        fn_name: cast
         type_var: T
         operands:
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: add
             operands:
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: mul
                 operands:
                 - !ScalarExpression
-                  arith_fn:
+                  scalar_fn:
+                    kind: arith
                     fn_name: add
                     operands:
                     - !ScalarExpression
-                      type_fn:
+                      scalar_fn:
+                        kind: type
+                        fn_name: cast
                         type_var: F64
                         operands:
                         - !ScalarExpression
                           scalar_const: '2147483647 : i64'
-                        fn_name: cast
                     - !ScalarExpression
-                      type_fn:
+                      scalar_fn:
+                        kind: type
+                        fn_name: cast
                         type_var: F64
                         operands:
                         - !ScalarExpression
-                          arith_fn:
+                          scalar_fn:
+                            kind: arith
                             fn_name: add
                             operands:
                             - !ScalarExpression
-                              arith_fn:
+                              scalar_fn:
+                                kind: arith
                                 fn_name: mul
                                 operands:
                                 - !ScalarExpression
-                                  arith_fn:
+                                  scalar_fn:
+                                    kind: arith
                                     fn_name: add
                                     operands:
                                     - !ScalarExpression
-                                      type_fn:
+                                      scalar_fn:
+                                        kind: type
+                                        fn_name: cast
                                         type_var: I32
                                         operands:
                                         - !ScalarExpression
                                           scalar_index: 1
-                                        fn_name: cast
                                     - !ScalarExpression
-                                      arith_fn:
+                                      scalar_fn:
+                                        kind: arith
                                         fn_name: add
                                         operands:
                                         - !ScalarExpression
-                                          arith_fn:
+                                          scalar_fn:
+                                            kind: arith
                                             fn_name: mul
                                             operands:
                                             - !ScalarExpression
-                                              arith_fn:
+                                              scalar_fn:
+                                                kind: arith
                                                 fn_name: add
                                                 operands:
                                                 - !ScalarExpression
-                                                  type_fn:
+                                                  scalar_fn:
+                                                    kind: type
+                                                    fn_name: cast
                                                     type_var: I32
                                                     operands:
                                                     - !ScalarExpression
                                                       scalar_index: 0
-                                                    fn_name: cast
                                                 - !ScalarExpression
                                                   scalar_arg: seed
                                             - !ScalarExpression
-                                              type_fn:
+                                              scalar_fn:
+                                                kind: type
+                                                fn_name: cast
                                                 type_var: I32
                                                 operands:
                                                 - !ScalarExpression
                                                   scalar_const: '1103515245 : i64'
-                                                fn_name: cast
                                         - !ScalarExpression
-                                          type_fn:
+                                          scalar_fn:
+                                            kind: type
+                                            fn_name: cast
                                             type_var: I32
                                             operands:
                                             - !ScalarExpression
                                               scalar_const: '12345 : i64'
-                                            fn_name: cast
                                 - !ScalarExpression
-                                  type_fn:
+                                  scalar_fn:
+                                    kind: type
+                                    fn_name: cast
                                     type_var: I32
                                     operands:
                                     - !ScalarExpression
                                       scalar_const: '1103515245 : i64'
-                                    fn_name: cast
                             - !ScalarExpression
-                              type_fn:
+                              scalar_fn:
+                                kind: type
+                                fn_name: cast
                                 type_var: I32
                                 operands:
                                 - !ScalarExpression
                                   scalar_const: '12345 : i64'
-                                fn_name: cast
-                        fn_name: cast
                 - !ScalarExpression
-                  arith_fn:
+                  scalar_fn:
+                    kind: arith
                     fn_name: mul
                     operands:
                     - !ScalarExpression
-                      arith_fn:
+                      scalar_fn:
+                        kind: arith
                         fn_name: sub
                         operands:
                         - !ScalarExpression
@@ -2811,15 +2963,15 @@ structured_op: !LinalgStructuredOpConfig
                         - !ScalarExpression
                           scalar_arg: min
                     - !ScalarExpression
-                      type_fn:
+                      scalar_fn:
+                        kind: type
+                        fn_name: cast
                         type_var: F64
                         operands:
                         - !ScalarExpression
                           scalar_const: '2.3283063999999999E-10 : f64'
-                        fn_name: cast
             - !ScalarExpression
               scalar_arg: min
-        fn_name: cast
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: soft_plus_2d
@@ -2852,28 +3004,33 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: log
         operands:
         - !ScalarExpression
-          arith_fn:
+          scalar_fn:
+            kind: arith
             fn_name: add
             operands:
             - !ScalarExpression
-              type_fn:
+              scalar_fn:
+                kind: type
+                fn_name: cast
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_const: '1.000000e+00 : f64'
-                fn_name: cast
             - !ScalarExpression
-              arith_fn:
+              scalar_fn:
+                kind: arith
                 fn_name: exp
                 operands:
                 - !ScalarExpression
-                  type_fn:
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast
                     type_var: U
                     operands:
                     - !ScalarExpression
                       scalar_arg: I
-                    fn_name: cast
index 68c0880..d26aa07 100644 (file)
@@ -133,55 +133,36 @@ class TensorUse(TensorExpression):
             f"[{', '.join([repr(i) for i in self.indices])}]")
 
 
-class TensorArithFn(TensorExpression):
-  """Application of an arithmetic function."""
+class TensorFn(TensorExpression):
+  """Application of a tensor function."""
 
-  def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]):
-    self.arith_fn = arith_fn
-    self.args = tuple(args)
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarArithFn(self.arith_fn.fn_name,
-                         *[arg.to_scalar_expression() for arg in self.args
-                          ]).expr()
-
-  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
-    super().visit_tensor_exprs(callback)
-    for arg in self.args:
-      arg.visit_tensor_exprs(callback)
-
-  def __repr__(self):
-    return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
-
-
-class TensorTypeFn(TensorExpression):
-  """Application of a type conversion function."""
-
-  def __init__(self, type_fn: Optional["TypeFn"],
-               operand_def: Optional["OperandDef"], type_var: TypeVar,
-               arg: TensorExpression):
-    if bool(type_fn) + bool(operand_def) != 1:
-      raise ValueError("Either 'type_fn' or 'operand_def' must be specified")
-    self.type_fn = type_fn
+  def __init__(self, kind: "FunctionKind", name: Optional[str],
+               operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
+               args: Sequence[TensorExpression]):
+    if bool(name) + bool(operand_def) != 1:
+      raise ValueError("One of 'name', 'operand_def' must be specified")
+    self.name = name
+    self.kind = kind
     self.operand_def = operand_def
     self.type_var = type_var
-    self.arg = arg
+    self.args = args
 
   def to_scalar_expression(self) -> ScalarExpression:
     if self.operand_def:
-      assert self.operand_def.name, "TypeFnAttr not registered with an op"
-    fn_name = self.type_fn.fn_name if self.type_fn else None
+      assert self.operand_def.name, "TensorFn not registered with an op"
     attr_name = self.operand_def.name if self.operand_def else None
-    return ScalarTypeFn(fn_name, attr_name, self.type_var,
-                        self.arg.to_scalar_expression()).expr()
+    args = [arg.to_scalar_expression() for arg in self.args]
+    return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
 
   def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
     super().visit_tensor_exprs(callback)
-    self.arg.visit_tensor_exprs(callback)
+    for arg in self.args:
+      arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
-    return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]"
-            f"({self.type_var}, {self.arg})")
+    name = self.operand_def.name if self.operand_def else self.name
+    return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
+            f"args={', '.join(repr(a) for a in self.args)})")
 
 
 class TensorReduceFn(TensorExpression):
@@ -194,7 +175,7 @@ class TensorReduceFn(TensorExpression):
                args: Sequence[TensorExpression]):
     self.reduce_use = reduce_use
     self.lhs = None  # type: Optional[TensorUse]
-    self.args = tuple(args)
+    self.args = args
 
   def to_scalar_expression(self) -> ScalarExpression:
     if self.lhs is None:
@@ -202,7 +183,8 @@ class TensorReduceFn(TensorExpression):
                        f"bound to its lhs: {self}")
     full_args = [self.lhs.to_scalar_expression()
                 ] + [arg.to_scalar_expression() for arg in self.args]
-    return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
+    return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None,
+                    None, full_args).expr()
 
   def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
     for arg in self.args:
@@ -259,6 +241,11 @@ class index(TensorExpression):
 ###############################################################################
 
 
+class FunctionKind(Enum):
+  ARITH = 0
+  TYPE = 1
+
+
 class TypeFnType:
   """Type conversion function.
 
@@ -269,8 +256,8 @@ class TypeFnType:
   def __init__(self, fn_name: str):
     self.fn_name = fn_name
 
-  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
-    return TensorTypeFn(self, None, type_var, arg)
+  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
+    return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
 
   def __repr__(self):
     return f"{self.fn_name}"
@@ -301,8 +288,8 @@ class ArithFnType:
   def __init__(self, fn_name: str):
     self.fn_name = fn_name
 
-  def __call__(self, *args) -> "TensorArithFn":
-    return TensorArithFn(self, args)
+  def __call__(self, *args) -> "TensorFn":
+    return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args)
 
   def __repr__(self):
     return f"{self.fn_name}"
@@ -562,8 +549,8 @@ class TypeFnAttrDef:
     self.operand_def = OperandDef(
         OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
 
-  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn:
-    return TensorTypeFn(None, self.operand_def, type_var, arg)
+  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
+    return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
 
 
 ###############################################################################
index fc8c13b..07050f5 100644 (file)
@@ -270,19 +270,19 @@ class _BodyBuilder:
       dim_attr = IntegerAttr.get(
           IntegerType.get_signless(64), expr.scalar_index.dim)
       return linalg.IndexOp(dim_attr).result
-    elif expr.arith_fn:
-      fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}")
+    elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH:
+      fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}")
       operand_values = [
-          self.expression(operand) for operand in expr.arith_fn.operands
+          self.expression(operand) for operand in expr.scalar_fn.operands
       ]
       return fn(*operand_values)
-    elif expr.type_fn:
-      fn_name = expr.type_fn.fn_name
-      if expr.type_fn.attr_name:
-        fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name]
+    elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE:
+      fn_name = expr.scalar_fn.fn_name
+      if expr.scalar_fn.attr_name:
+        fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
       fn = self._get_function(f"_typefn_{fn_name}")
-      operand = self.expression(expr.type_fn.operand)
-      return fn(expr.type_fn.type_var.name, operand)
+      operand_value = self.expression(expr.scalar_fn.operands[0])
+      return fn(expr.scalar_fn.type_var.name, operand_value)
     raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
 
   def yield_outputs(self, *output_names: str):
index af21b40..aa894dc 100644 (file)
@@ -15,13 +15,13 @@ can be easily consumed from the C++ side, not necessarily for ergonomics.
 
 from typing import Optional, Sequence
 
-from .yaml_helper import *
+from .comprehension import *
 from .types import *
+from .yaml_helper import *
 
 __all__ = [
     "ScalarAssign",
-    "ScalarArithFn",
-    "ScalarTypeFn",
+    "ScalarFn",
     "ScalarArg",
     "ScalarConst",
     "ScalarIndex",
@@ -29,36 +29,27 @@ __all__ = [
 ]
 
 
-class ScalarArithFn:
-  """A type of ScalarExpression that applies an arithmetic function."""
-
-  def __init__(self, fn_name: str, *operands: "ScalarExpression"):
-    self.fn_name = fn_name
-    self.operands = operands
-
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(arith_fn=self)
-
-  def __repr__(self):
-    return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})"
-
-
-class ScalarTypeFn:
-  """A type of ScalarExpression that applies a type conversion function."""
+class ScalarFn:
+  """A type of ScalarExpression that applies a function."""
 
-  def __init__(self, fn_name: Optional[str], attr_name: Optional[str],
-               type_var: TypeVar, operand: "ScalarExpression"):
+  def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
+               attr_name: Optional[str], type_var: Optional["TypeVar"],
+               operands: Sequence["ScalarExpression"]):
+    if bool(fn_name) + bool(attr_name) != 1:
+      raise ValueError("One of 'fn_name', 'attr_name' must be specified")
+    self.kind = kind
     self.fn_name = fn_name
     self.attr_name = attr_name
     self.type_var = type_var
-    self.operand = operand
+    self.operands = operands
 
   def expr(self) -> "ScalarExpression":
-    return ScalarExpression(type_fn=self)
+    return ScalarExpression(scalar_fn=self)
 
   def __repr__(self):
-    return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>"
-            f"({self.type_var}, {self.operand})")
+    name = self.fn_name if self.fn_name else self.attr_name
+    return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
+            f"operands=[{', '.join(self.operands)}])")
 
 
 class ScalarArg:
@@ -104,51 +95,38 @@ class ScalarExpression(YAMLObject):
   """An expression on scalar values.
 
   Can be one of:
-    - ScalarArithFn
-    - ScalarTypeFn
+    - ScalarFn
     - ScalarArg
     - ScalarConst
     - ScalarIndex
-    - ScalarSymbolicCast
   """
   yaml_tag = "!ScalarExpression"
 
   def __init__(self,
-               arith_fn: Optional[ScalarArithFn] = None,
-               type_fn: Optional[ScalarTypeFn] = None,
+               scalar_fn: Optional[ScalarFn] = None,
                scalar_arg: Optional[ScalarArg] = None,
                scalar_const: Optional[ScalarConst] = None,
                scalar_index: Optional[ScalarIndex] = None):
-    if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) +
+    if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
         bool(scalar_index)) != 1:
-      raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', "
-                       "'scalar_const', 'scalar_index', must be specified")
-    self.arith_fn = arith_fn
-    self.type_fn = type_fn
+      raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
+                       "'scalar_index' must be specified")
+    self.scalar_fn = scalar_fn
     self.scalar_arg = scalar_arg
     self.scalar_const = scalar_const
     self.scalar_index = scalar_index
 
   def to_yaml_custom_dict(self):
-    if self.arith_fn:
-      return dict(
-          arith_fn=dict(
-              fn_name=self.arith_fn.fn_name,
-              operands=list(self.arith_fn.operands),
-          ))
-    if self.type_fn:
-      # Note that even though operands must be arity 1, we write it the
-      # same way as for apply because it allows handling code to be more
-      # generic vs having a special form.
-      type_fn_dict = dict(
-          type_var=self.type_fn.type_var.name,
-          operands=[self.type_fn.operand],
-      )
-      if self.type_fn.fn_name:
-        type_fn_dict["fn_name"] = self.type_fn.fn_name
-      if self.type_fn.attr_name:
-        type_fn_dict["attr_name"] = self.type_fn.attr_name
-      return dict(type_fn=type_fn_dict)
+    if self.scalar_fn:
+      scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
+      if self.scalar_fn.fn_name:
+        scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
+      if self.scalar_fn.attr_name:
+        scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
+      if self.scalar_fn.type_var:
+        scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
+      scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
+      return dict(scalar_fn=scalar_fn_dict)
     elif self.scalar_arg:
       return dict(scalar_arg=self.scalar_arg.arg)
     elif self.scalar_const:
index f4019e8..660637e 100644 (file)
@@ -39,23 +39,26 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      arith_fn:
+      scalar_fn:
+        kind: arith
         fn_name: add
         operands:
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            attr_name: cast
             type_var: T
             operands:
             - !ScalarExpression
               scalar_const: '42 : i64'
-            attr_name: cast
         - !ScalarExpression
-          type_fn:
+          scalar_fn:
+            kind: type
+            attr_name: cast
             type_var: T
             operands:
             - !ScalarExpression
               scalar_index: 1
-            attr_name: cast
 
 # ODS-LABEL:  def Test1Op : LinalgStructuredBase_Op<"test1"
 
@@ -236,7 +239,8 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      type_fn:
+      scalar_fn:
+        kind: type
         fn_name: cast
         type_var: U
         operands:
index 926a05e..5b87216 100644 (file)
@@ -9,22 +9,24 @@ from mlir.dialects.linalg.opdsl.lang import *
 # CHECK:  -
 # CHECK:    arg: C
 # CHECK:    value:
-# CHECK:      arith_fn:
+# CHECK:      scalar_fn:
 # CHECK:        fn_name: add
 # CHECK:        operands:
-# CHECK:          arith_fn:
+# CHECK:          scalar_fn:
 # CHECK:            fn_name: mul
 # CHECK:            operands:
-# CHECK:              type_fn:
+# CHECK:              scalar_fn:
+# CHECK:                kind: type
+# CHECK:                attr_name: cast
 # CHECK:                type_var: U
 # CHECK:                operands:
 # CHECK:                  scalar_arg: A
+# CHECK:              scalar_fn:
+# CHECK:                kind: type
 # CHECK:                attr_name: cast
-# CHECK:              type_fn:
 # CHECK:                type_var: U
 # CHECK:                operands:
 # CHECK:                  scalar_arg: B
-# CHECK:                attr_name: cast
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
@@ -39,21 +41,28 @@ def matmul(
 # CHECK: assignments:
 # CHECK:  -
 # CHECK:    arg: O
-# CHECK:      arith_fn:
+# CHECK:      scalar_fn:
+# CHECK:        kind: arith
 # CHECK:        fn_name: sub
 # CHECK:        operands:
-# CHECK:          arith_fn:
+# CHECK:          scalar_fn:
+# CHECK:            kind: arith
 # CHECK:            fn_name: add
 # CHECK:            operands:
-# CHECK:              type_fn:
+# CHECK:              scalar_fn:
+# CHECK:                kind: type
 # CHECK:                type_var: T
 # CHECK:                operands:
 # CHECK:                  scalar_const: '3.1415926535897931 : f64'
-# CHECK:              type_fn:
+# CHECK:              scalar_fn:
+# CHECK:                kind: type
+# CHECK:                fn_name: cast
 # CHECK:                type_var: T
 # CHECK:                operands:
 # CHECK:                  scalar_const: '42 : i64'
-# CHECK:          type_fn:
+# CHECK:          scalar_fn:
+# CHECK:            kind: type
+# CHECK:            fn_name: cast
 # CHECK:            type_var: T
 # CHECK:            operands:
 # CHECK:              scalar_const: '1.{{[0]*}}e+03 : f64'
@@ -70,7 +79,8 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
 # CHECK: assignments:
 # CHECK:  -
 # CHECK:    arg: O
-# CHECK:      arith_fn:
+# CHECK:      scalar_fn:
+# CHECK:        kind: arith
 # CHECK:        fn_name: add
 # CHECK:        operands:
 # CHECK:          scalar_index: 1
index 7c850e6..d1fc9ac 100644 (file)
@@ -90,28 +90,23 @@ struct LinalgIndexingMapsConfig {
 
 struct ScalarExpression;
 
-struct ScalarArithFn {
-  std::string fnName;
-  // NOTE: Must be pure heap allocated container (not SmallVector)
-  // due to recursive data type.
-  std::vector<ScalarExpression> operands;
-};
+enum class ScalarFnKind { Arith, Type };
 
-struct ScalarTypeFn {
-  std::string typeVar;
+struct ScalarFn {
+  ScalarFnKind kind;
+  Optional<std::string> fnName;
+  Optional<std::string> attrName;
+  Optional<std::string> typeVar;
   // NOTE: This must be of arity 1, but to break the self-referential cycle,
   // we use a heap allocated vector.
   std::vector<ScalarExpression> operands;
-  Optional<std::string> fnName;
-  Optional<std::string> attrName;
 };
 
 struct ScalarExpression {
   Optional<std::string> arg;
   Optional<std::string> constant;
   Optional<int64_t> index;
-  Optional<ScalarArithFn> arithFn;
-  Optional<ScalarTypeFn> typeFn;
+  Optional<ScalarFn> scalarFn;
 };
 
 struct ScalarAssign {
@@ -265,16 +260,23 @@ struct MappingTraits<ScalarAssign> {
 ///   - `scalar_arg`: An operation argument.
 ///   - `scalar_const`: A constant definition.
 ///   - `scalar_index`: An iteration index.
-///   - `arith_fn`: A named arithmetic function (see `ScalarArithFn`).
-///   - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
+///   - `scalar_fn`: A named function (see `ScalarFn`).
 template <>
 struct MappingTraits<ScalarExpression> {
   static void mapping(IO &io, ScalarExpression &info) {
     io.mapOptional("scalar_arg", info.arg);
     io.mapOptional("scalar_const", info.constant);
     io.mapOptional("scalar_index", info.index);
-    io.mapOptional("arith_fn", info.arithFn);
-    io.mapOptional("type_fn", info.typeFn);
+    io.mapOptional("scalar_fn", info.scalarFn);
+  }
+};
+
+/// Scalar function kind enum.
+template <>
+struct ScalarEnumerationTraits<ScalarFnKind> {
+  static void enumeration(IO &io, ScalarFnKind &value) {
+    io.enumCase(value, "arith", ScalarFnKind::Arith);
+    io.enumCase(value, "type", ScalarFnKind::Type);
   }
 };
 
@@ -284,20 +286,13 @@ struct MappingTraits<ScalarExpression> {
 ///   - `add(lhs, rhs)`
 ///   - `mul(lhs, rhs)`
 template <>
-struct MappingTraits<ScalarArithFn> {
-  static void mapping(IO &io, ScalarArithFn &info) {
-    io.mapRequired("fn_name", info.fnName);
-    io.mapRequired("operands", info.operands);
-  }
-};
-
-template <>
-struct MappingTraits<ScalarTypeFn> {
-  static void mapping(IO &io, ScalarTypeFn &info) {
-    io.mapRequired("type_var", info.typeVar);
-    io.mapRequired("operands", info.operands);
+struct MappingTraits<ScalarFn> {
+  static void mapping(IO &io, ScalarFn &info) {
+    io.mapRequired("kind", info.kind);
     io.mapOptional("fn_name", info.fnName);
     io.mapOptional("attr_name", info.attrName);
+    io.mapOptional("type_var", info.typeVar);
+    io.mapRequired("operands", info.operands);
   }
 };
 
@@ -1060,11 +1055,12 @@ if ({0}Iter != attrs.end()) {{
                                         cppIdent, *expression.index));
           return cppIdent;
         }
-        if (expression.arithFn) {
+        if (expression.scalarFn &&
+            expression.scalarFn->kind == ScalarFnKind::Arith) {
           // Apply function.
           // Recursively generate operands.
           SmallVector<std::string> operandCppValues;
-          for (ScalarExpression &operand : expression.arithFn->operands) {
+          for (ScalarExpression &operand : expression.scalarFn->operands) {
             auto operandCppValue = generateExpression(operand);
             if (!operandCppValue)
               return None;
@@ -1073,28 +1069,30 @@ if ({0}Iter != attrs.end()) {{
           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
           stmts.push_back(
               llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
-                            expression.arithFn->fnName,
+                            expression.scalarFn->fnName,
                             interleaveToString(operandCppValues, ", ")));
           return cppIdent;
         }
-        if (expression.typeFn) {
+        if (expression.scalarFn &&
+            expression.scalarFn->kind == ScalarFnKind::Type) {
           // Symbolic cast.
           // Operands must be arity 1.
-          if (expression.typeFn->operands.size() != 1) {
+          if (expression.scalarFn->operands.size() != 1) {
             emitError(genContext.getLoc())
                 << "type conversion operand arity must be 1";
             return None;
           }
           Optional<std::string> operandCppValue =
-              generateExpression(expression.typeFn->operands[0]);
+              generateExpression(expression.scalarFn->operands[0]);
           if (!operandCppValue)
             return None;
 
+          assert(expression.scalarFn->typeVar.hasValue());
           Optional<std::string> typeCppValue =
-              findTypeValue(expression.typeFn->typeVar, args);
+              findTypeValue(expression.scalarFn->typeVar.getValue(), args);
           if (!typeCppValue) {
             emitError(genContext.getLoc())
-                << "type variable " << expression.typeFn->typeVar
+                << "type variable " << expression.scalarFn->typeVar.getValue()
                 << ", used in a type conversion, must map to a predefined or "
                 << "an argument type but it does not";
             return None;
@@ -1102,17 +1100,17 @@ if ({0}Iter != attrs.end()) {{
 
           // Use the function name or the attribute to build the type function.
           std::string typeFunc = llvm::formatv(
-              "TypeFn::{0}", expression.typeFn->fnName.getValueOr(""));
-          if (expression.typeFn->attrName) {
+              "TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
+          if (expression.scalarFn->attrName) {
             if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
                   return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
-                         arg.name == expression.typeFn->attrName.getValue();
+                         arg.name == expression.scalarFn->attrName.getValue();
                 })) {
               emitError(genContext.getLoc())
                   << "missing type function attribute "
-                  << expression.typeFn->attrName.getValue();
+                  << expression.scalarFn->attrName.getValue();
             }
-            typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName);
+            typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
           }
           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
           stmts.push_back(llvm::formatv(