Prepare the OpDSL function handling to introduce more function classes. A follow up commit will split ArithFn into UnaryFn and BinaryFn. This revision prepares the split by adding a function kind enum to handle different function types using a single class on the various levels of the stack (for example, there is now one TensorFn and one ScalarFn).
Depends On D119718
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D120108
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- attr_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ attr_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- attr_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul_unsigned
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast_unsigned
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast_unsigned
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast_unsigned
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast_unsigned
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_matmul
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: AZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: BZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: mmt4d
- !ScalarAssign
arg: accum
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: accum
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: AccumType
operands:
- !ScalarExpression
scalar_arg: lhs
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: AccumType
operands:
- !ScalarExpression
scalar_arg: rhs
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_batch_matmul
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: AZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: BZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
- !ScalarAssign
arg: x
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: y
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: vecmat
- !ScalarAssign
arg: x
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: y
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matvec
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
- !ScalarAssign
arg: C
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_1d
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_3d
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_1d_nwc_wcf
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nhwc_hwcf
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nhwc_hwcf_q
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nchw_fchw
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_3d_ndhwc_dhwcf
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_1d_nwc_wc
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_nhwc_hwc
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_nhwc_hwc_q
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_nhwc_hwcm
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_nhwc_hwcm_q
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_sum
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_max
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: max
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_max_unsigned
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: max_unsigned
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast_unsigned
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast_unsigned
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nchw_max
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: max
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_min
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: min
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_min_unsigned
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: min_unsigned
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast_unsigned
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast_unsigned
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_sum
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_max
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: max
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_ndhwc_min
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: min
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_tensor
- !ScalarAssign
arg: O
value: !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: value
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
- !ScalarAssign
arg: O
value: !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: T
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: F64
operands:
- !ScalarExpression
scalar_const: '2147483647 : i64'
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: F64
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_index: 1
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_index: 0
- fn_name: cast
- !ScalarExpression
scalar_arg: seed
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_const: '1103515245 : i64'
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_const: '1103515245 : i64'
- fn_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: I32
operands:
- !ScalarExpression
scalar_const: '12345 : i64'
- fn_name: cast
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: mul
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: sub
operands:
- !ScalarExpression
- !ScalarExpression
scalar_arg: min
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: F64
operands:
- !ScalarExpression
scalar_const: '2.3283063999999999E-10 : f64'
- fn_name: cast
- !ScalarExpression
scalar_arg: min
- fn_name: cast
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: soft_plus_2d
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: log
operands:
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_const: '1.000000e+00 : f64'
- fn_name: cast
- !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: exp
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- fn_name: cast
f"[{', '.join([repr(i) for i in self.indices])}]")
-class TensorArithFn(TensorExpression):
- """Application of an arithmetic function."""
+class TensorFn(TensorExpression):
+ """Application of a tensor function."""
- def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]):
- self.arith_fn = arith_fn
- self.args = tuple(args)
-
- def to_scalar_expression(self) -> ScalarExpression:
- return ScalarArithFn(self.arith_fn.fn_name,
- *[arg.to_scalar_expression() for arg in self.args
- ]).expr()
-
- def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
- super().visit_tensor_exprs(callback)
- for arg in self.args:
- arg.visit_tensor_exprs(callback)
-
- def __repr__(self):
- return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
-
-
-class TensorTypeFn(TensorExpression):
- """Application of a type conversion function."""
-
- def __init__(self, type_fn: Optional["TypeFn"],
- operand_def: Optional["OperandDef"], type_var: TypeVar,
- arg: TensorExpression):
- if bool(type_fn) + bool(operand_def) != 1:
- raise ValueError("Either 'type_fn' or 'operand_def' must be specified")
- self.type_fn = type_fn
+ def __init__(self, kind: "FunctionKind", name: Optional[str],
+ operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
+ args: Sequence[TensorExpression]):
+ if bool(name) + bool(operand_def) != 1:
+ raise ValueError("One of 'name', 'operand_def' must be specified")
+ self.name = name
+ self.kind = kind
self.operand_def = operand_def
self.type_var = type_var
- self.arg = arg
+ self.args = args
def to_scalar_expression(self) -> ScalarExpression:
if self.operand_def:
- assert self.operand_def.name, "TypeFnAttr not registered with an op"
- fn_name = self.type_fn.fn_name if self.type_fn else None
+ assert self.operand_def.name, "TensorFn not registered with an op"
attr_name = self.operand_def.name if self.operand_def else None
- return ScalarTypeFn(fn_name, attr_name, self.type_var,
- self.arg.to_scalar_expression()).expr()
+ args = [arg.to_scalar_expression() for arg in self.args]
+ return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
super().visit_tensor_exprs(callback)
- self.arg.visit_tensor_exprs(callback)
+ for arg in self.args:
+ arg.visit_tensor_exprs(callback)
def __repr__(self):
- return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]"
- f"({self.type_var}, {self.arg})")
+ name = self.operand_def.name if self.operand_def else self.name
+ return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
+ f"args={', '.join(repr(a) for a in self.args)})")
class TensorReduceFn(TensorExpression):
args: Sequence[TensorExpression]):
self.reduce_use = reduce_use
self.lhs = None # type: Optional[TensorUse]
- self.args = tuple(args)
+ self.args = args
def to_scalar_expression(self) -> ScalarExpression:
if self.lhs is None:
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
- return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
+ return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None,
+ None, full_args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
for arg in self.args:
###############################################################################
+class FunctionKind(Enum):
+ ARITH = 0
+ TYPE = 1
+
+
class TypeFnType:
"""Type conversion function.
def __init__(self, fn_name: str):
self.fn_name = fn_name
- def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
- return TensorTypeFn(self, None, type_var, arg)
+ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
+ return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
def __repr__(self):
return f"{self.fn_name}"
def __init__(self, fn_name: str):
self.fn_name = fn_name
- def __call__(self, *args) -> "TensorArithFn":
- return TensorArithFn(self, args)
+ def __call__(self, *args) -> "TensorFn":
+ return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args)
def __repr__(self):
return f"{self.fn_name}"
self.operand_def = OperandDef(
OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
- def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn:
- return TensorTypeFn(None, self.operand_def, type_var, arg)
+ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
+ return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
###############################################################################
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
return linalg.IndexOp(dim_attr).result
- elif expr.arith_fn:
- fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}")
+ elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH:
+ fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}")
operand_values = [
- self.expression(operand) for operand in expr.arith_fn.operands
+ self.expression(operand) for operand in expr.scalar_fn.operands
]
return fn(*operand_values)
- elif expr.type_fn:
- fn_name = expr.type_fn.fn_name
- if expr.type_fn.attr_name:
- fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name]
+ elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE:
+ fn_name = expr.scalar_fn.fn_name
+ if expr.scalar_fn.attr_name:
+ fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
fn = self._get_function(f"_typefn_{fn_name}")
- operand = self.expression(expr.type_fn.operand)
- return fn(expr.type_fn.type_var.name, operand)
+ operand_value = self.expression(expr.scalar_fn.operands[0])
+ return fn(expr.scalar_fn.type_var.name, operand_value)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
def yield_outputs(self, *output_names: str):
from typing import Optional, Sequence
-from .yaml_helper import *
+from .comprehension import *
from .types import *
+from .yaml_helper import *
__all__ = [
"ScalarAssign",
- "ScalarArithFn",
- "ScalarTypeFn",
+ "ScalarFn",
"ScalarArg",
"ScalarConst",
"ScalarIndex",
]
-class ScalarArithFn:
- """A type of ScalarExpression that applies an arithmetic function."""
-
- def __init__(self, fn_name: str, *operands: "ScalarExpression"):
- self.fn_name = fn_name
- self.operands = operands
-
- def expr(self) -> "ScalarExpression":
- return ScalarExpression(arith_fn=self)
-
- def __repr__(self):
- return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})"
-
-
-class ScalarTypeFn:
- """A type of ScalarExpression that applies a type conversion function."""
+class ScalarFn:
+ """A type of ScalarExpression that applies a function."""
- def __init__(self, fn_name: Optional[str], attr_name: Optional[str],
- type_var: TypeVar, operand: "ScalarExpression"):
+ def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
+ attr_name: Optional[str], type_var: Optional["TypeVar"],
+ operands: Sequence["ScalarExpression"]):
+ if bool(fn_name) + bool(attr_name) != 1:
+ raise ValueError("One of 'fn_name', 'attr_name' must be specified")
+ self.kind = kind
self.fn_name = fn_name
self.attr_name = attr_name
self.type_var = type_var
- self.operand = operand
+ self.operands = operands
def expr(self) -> "ScalarExpression":
- return ScalarExpression(type_fn=self)
+ return ScalarExpression(scalar_fn=self)
def __repr__(self):
- return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>"
- f"({self.type_var}, {self.operand})")
+ name = self.fn_name if self.fn_name else self.attr_name
+ return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
+ f"operands=[{', '.join(self.operands)}])")
class ScalarArg:
"""An expression on scalar values.
Can be one of:
- - ScalarArithFn
- - ScalarTypeFn
+ - ScalarFn
- ScalarArg
- ScalarConst
- ScalarIndex
- - ScalarSymbolicCast
"""
yaml_tag = "!ScalarExpression"
def __init__(self,
- arith_fn: Optional[ScalarArithFn] = None,
- type_fn: Optional[ScalarTypeFn] = None,
+ scalar_fn: Optional[ScalarFn] = None,
scalar_arg: Optional[ScalarArg] = None,
scalar_const: Optional[ScalarConst] = None,
scalar_index: Optional[ScalarIndex] = None):
- if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) +
+ if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
bool(scalar_index)) != 1:
- raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', "
- "'scalar_const', 'scalar_index', must be specified")
- self.arith_fn = arith_fn
- self.type_fn = type_fn
+ raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
+ "'scalar_index' must be specified")
+ self.scalar_fn = scalar_fn
self.scalar_arg = scalar_arg
self.scalar_const = scalar_const
self.scalar_index = scalar_index
def to_yaml_custom_dict(self):
- if self.arith_fn:
- return dict(
- arith_fn=dict(
- fn_name=self.arith_fn.fn_name,
- operands=list(self.arith_fn.operands),
- ))
- if self.type_fn:
- # Note that even though operands must be arity 1, we write it the
- # same way as for apply because it allows handling code to be more
- # generic vs having a special form.
- type_fn_dict = dict(
- type_var=self.type_fn.type_var.name,
- operands=[self.type_fn.operand],
- )
- if self.type_fn.fn_name:
- type_fn_dict["fn_name"] = self.type_fn.fn_name
- if self.type_fn.attr_name:
- type_fn_dict["attr_name"] = self.type_fn.attr_name
- return dict(type_fn=type_fn_dict)
+ if self.scalar_fn:
+ scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
+ if self.scalar_fn.fn_name:
+ scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
+ if self.scalar_fn.attr_name:
+ scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
+ if self.scalar_fn.type_var:
+ scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
+ scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
+ return dict(scalar_fn=scalar_fn_dict)
elif self.scalar_arg:
return dict(scalar_arg=self.scalar_arg.arg)
elif self.scalar_const:
- !ScalarAssign
arg: O
value: !ScalarExpression
- arith_fn:
+ scalar_fn:
+ kind: arith
fn_name: add
operands:
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ attr_name: cast
type_var: T
operands:
- !ScalarExpression
scalar_const: '42 : i64'
- attr_name: cast
- !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
+ attr_name: cast
type_var: T
operands:
- !ScalarExpression
scalar_index: 1
- attr_name: cast
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
- !ScalarAssign
arg: O
value: !ScalarExpression
- type_fn:
+ scalar_fn:
+ kind: type
fn_name: cast
type_var: U
operands:
# CHECK: -
# CHECK: arg: C
# CHECK: value:
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
# CHECK: fn_name: add
# CHECK: operands:
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
# CHECK: fn_name: mul
# CHECK: operands:
-# CHECK: type_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: type
+# CHECK: attr_name: cast
# CHECK: type_var: U
# CHECK: operands:
# CHECK: scalar_arg: A
+# CHECK: scalar_fn:
+# CHECK: kind: type
# CHECK: attr_name: cast
-# CHECK: type_fn:
# CHECK: type_var: U
# CHECK: operands:
# CHECK: scalar_arg: B
-# CHECK: attr_name: cast
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: arith
# CHECK: fn_name: sub
# CHECK: operands:
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: arith
# CHECK: fn_name: add
# CHECK: operands:
-# CHECK: type_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: type
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '3.1415926535897931 : f64'
-# CHECK: type_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: type
+# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '42 : i64'
-# CHECK: type_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: type
+# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
-# CHECK: arith_fn:
+# CHECK: scalar_fn:
+# CHECK: kind: arith
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_index: 1
struct ScalarExpression;
-struct ScalarArithFn {
- std::string fnName;
- // NOTE: Must be pure heap allocated container (not SmallVector)
- // due to recursive data type.
- std::vector<ScalarExpression> operands;
-};
+enum class ScalarFnKind { Arith, Type };
-struct ScalarTypeFn {
- std::string typeVar;
+struct ScalarFn {
+ ScalarFnKind kind;
+ Optional<std::string> fnName;
+ Optional<std::string> attrName;
+ Optional<std::string> typeVar;
// NOTE: This must be of arity 1, but to break the self-referential cycle,
// we use a heap allocated vector.
std::vector<ScalarExpression> operands;
- Optional<std::string> fnName;
- Optional<std::string> attrName;
};
struct ScalarExpression {
Optional<std::string> arg;
Optional<std::string> constant;
Optional<int64_t> index;
- Optional<ScalarArithFn> arithFn;
- Optional<ScalarTypeFn> typeFn;
+ Optional<ScalarFn> scalarFn;
};
struct ScalarAssign {
/// - `scalar_arg`: An operation argument.
/// - `scalar_const`: A constant definition.
/// - `scalar_index`: An iteration index.
-/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`).
-/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
+/// - `scalar_fn`: A named function (see `ScalarFn`).
template <>
struct MappingTraits<ScalarExpression> {
static void mapping(IO &io, ScalarExpression &info) {
io.mapOptional("scalar_arg", info.arg);
io.mapOptional("scalar_const", info.constant);
io.mapOptional("scalar_index", info.index);
- io.mapOptional("arith_fn", info.arithFn);
- io.mapOptional("type_fn", info.typeFn);
+ io.mapOptional("scalar_fn", info.scalarFn);
+ }
+};
+
+/// Scalar function kind enum.
+template <>
+struct ScalarEnumerationTraits<ScalarFnKind> {
+ static void enumeration(IO &io, ScalarFnKind &value) {
+ io.enumCase(value, "arith", ScalarFnKind::Arith);
+ io.enumCase(value, "type", ScalarFnKind::Type);
}
};
/// - `add(lhs, rhs)`
/// - `mul(lhs, rhs)`
template <>
-struct MappingTraits<ScalarArithFn> {
- static void mapping(IO &io, ScalarArithFn &info) {
- io.mapRequired("fn_name", info.fnName);
- io.mapRequired("operands", info.operands);
- }
-};
-
-template <>
-struct MappingTraits<ScalarTypeFn> {
- static void mapping(IO &io, ScalarTypeFn &info) {
- io.mapRequired("type_var", info.typeVar);
- io.mapRequired("operands", info.operands);
+struct MappingTraits<ScalarFn> {
+ static void mapping(IO &io, ScalarFn &info) {
+ io.mapRequired("kind", info.kind);
io.mapOptional("fn_name", info.fnName);
io.mapOptional("attr_name", info.attrName);
+ io.mapOptional("type_var", info.typeVar);
+ io.mapRequired("operands", info.operands);
}
};
cppIdent, *expression.index));
return cppIdent;
}
- if (expression.arithFn) {
+ if (expression.scalarFn &&
+ expression.scalarFn->kind == ScalarFnKind::Arith) {
// Apply function.
// Recursively generate operands.
SmallVector<std::string> operandCppValues;
- for (ScalarExpression &operand : expression.arithFn->operands) {
+ for (ScalarExpression &operand : expression.scalarFn->operands) {
auto operandCppValue = generateExpression(operand);
if (!operandCppValue)
return None;
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(
llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
- expression.arithFn->fnName,
+ expression.scalarFn->fnName,
interleaveToString(operandCppValues, ", ")));
return cppIdent;
}
- if (expression.typeFn) {
+ if (expression.scalarFn &&
+ expression.scalarFn->kind == ScalarFnKind::Type) {
// Symbolic cast.
// Operands must be arity 1.
- if (expression.typeFn->operands.size() != 1) {
+ if (expression.scalarFn->operands.size() != 1) {
emitError(genContext.getLoc())
<< "type conversion operand arity must be 1";
return None;
}
Optional<std::string> operandCppValue =
- generateExpression(expression.typeFn->operands[0]);
+ generateExpression(expression.scalarFn->operands[0]);
if (!operandCppValue)
return None;
+ assert(expression.scalarFn->typeVar.hasValue());
Optional<std::string> typeCppValue =
- findTypeValue(expression.typeFn->typeVar, args);
+ findTypeValue(expression.scalarFn->typeVar.getValue(), args);
if (!typeCppValue) {
emitError(genContext.getLoc())
- << "type variable " << expression.typeFn->typeVar
+ << "type variable " << expression.scalarFn->typeVar.getValue()
<< ", used in a type conversion, must map to a predefined or "
<< "an argument type but it does not";
return None;
// Use the function name or the attribute to build the type function.
std::string typeFunc = llvm::formatv(
- "TypeFn::{0}", expression.typeFn->fnName.getValueOr(""));
- if (expression.typeFn->attrName) {
+ "TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
+ if (expression.scalarFn->attrName) {
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
- arg.name == expression.typeFn->attrName.getValue();
+ arg.name == expression.scalarFn->attrName.getValue();
})) {
emitError(genContext.getLoc())
<< "missing type function attribute "
- << expression.typeFn->attrName.getValue();
+ << expression.scalarFn->attrName.getValue();
}
- typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName);
+ typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
}
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(llvm::formatv(