ods_def<MatmulOp>
implements_interface<LinalgContractionOpInterface> :
def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
- C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
+ C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
}
ods_def<MatmulColumnMajorOp>
implements_interface<LinalgContractionOpInterface> :
def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
- C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
+ C(n, m) = std_addf<k>(C(n, m), std_mulf(A(k, m), B(n, k)));
}
ods_def<MatmulI8I8I32Op>
def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
// TODO: ideally something closer to
// C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
- C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
+ C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
}
ods_def<MatmulI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
- C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
+ C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
}
ods_def<MatmulI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
- C(m, n) = std_addi<k>(std_muli(A(m, k), B(k, n)));
+ C(m, n) = std_addi<k>(C(m, n), std_muli(A(m, k), B(k, n)));
}
ods_def<MatvecOp>
implements_interface<LinalgContractionOpInterface> :
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
- x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
+ x(m) = std_addf<n>(x(m), std_mulf(A(m, n), y(n)));
}
ods_def<MatvecI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(std_sexti32(std_muli(A(m, n), y(n))));
+ x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
}
ods_def<MatvecI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(std_sexti32(std_muli(A(m, n), y(n))));
+ x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
}
ods_def<MatvecI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
- x(m) = std_addi<n>(std_muli(A(m, n), y(n)));
+ x(m) = std_addi<n>(x(m), std_muli(A(m, n), y(n)));
}
ods_def<VecmatOp>
implements_interface<LinalgContractionOpInterface> :
def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
- x(n) = std_addf<m>(std_mulf(y(m), A(m, n)));
+ x(n) = std_addf<m>(x(n), std_mulf(y(m), A(m, n)));
}
ods_def<VecmatI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(std_sexti32(std_muli(y(m), A(m, n))));
+ x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
}
ods_def<VecmatI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(std_sexti32(std_muli(y(m), A(m, n))));
+ x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
}
-
ods_def<VecmatI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
- x(n) = std_addi<m>(std_muli(y(m), A(m, n)));
+ x(n) = std_addi<m>(x(n), std_muli(y(m), A(m, n)));
}
ods_def<DotOp>
implements_interface<LinalgContractionOpInterface> :
def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
- C() = std_addf<m>(std_mulf(A(m), B(m)));
+ C() = std_addf<m>(C(), std_mulf(A(m), B(m)));
}
ods_def<DotI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
- C() = std_addi<m>(std_sexti32(std_muli(A(m), B(m))));
+ C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
}
ods_def<DotI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
- C() = std_addi<m>(std_sexti32(std_muli(A(m), B(m))));
+ C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
}
-
ods_def<DotI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
- C() = std_addi<m>(std_muli(A(m), B(m)));
+ C() = std_addi<m>(C(), std_muli(A(m), B(m)));
}
-
ods_def<BatchMatmulOp>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
- C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
+ C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
}
ods_def<BatchMatmulI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
- C(b, m, n) = std_addi<k>(std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+ C(b, m, n) =
+ std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
}
ods_def<BatchMatmulI16I16I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
- C(b, m, n) = std_addi<k>(std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+ C(b, m, n) =
+ std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
}
ods_def<BatchMatmulI32I32I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i32_i32_i32(A: i32(Batch, M, K), B: i32(Batch, K, N)) -> (C: i32(Batch, M, N)) {
- C(b, m, n) = std_addi<k>(std_muli(A(b, m, k), B(b, k, n)));
+ C(b, m, n) = std_addi<k>(C(b, m, n), std_muli(A(b, m, k), B(b, k, n)));
}
ods_def<ConvWOp>:
def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
- O(w) = std_addf<kw>(std_mulf(I(w + kw), K(kw)));
+ O(w) = std_addf<kw>(O(w), std_mulf(I(w + kw), K(kw)));
}
ods_def<ConvNWCOp>:
def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
- O(n, w, f) = std_addf<kw>(std_mulf(I(n, w + kw, c), K(f, kw, c)));
+ O(n, w, f) = std_addf<kw>(O(n, w, f), std_mulf(I(n, w + kw, c), K(f, kw, c)));
}
ods_def<ConvNCWOp>:
def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
- O(n, f, w) = std_addf<kw>(std_mulf(I(n, c, w + kw), K(f, c, kw)));
+ O(n, f, w) = std_addf<kw>(O(n, f, w), std_mulf(I(n, c, w + kw), K(f, c, kw)));
}
ods_def<ConvHWOp>:
def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
- O(h, w) = std_addf<kh, kw>(std_mulf(I(h + kh, w + kw), K(kh, kw)));
+ O(h, w) = std_addf<kh, kw>(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
}
ods_def<ConvNHWCOp>:
def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
- O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
- I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+ O(n, h, w, f) = std_addf<kh, kw>(
+ O(n, h, w, f), std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
}
ods_def<ConvNCHWOp>:
def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
- O(n, f, h, w) = std_addf<kh, kw>(std_mulf(
- I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+ O(n, f, h, w) = std_addf<kh, kw>(
+ O(n, f, h, w), std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
}
ods_def<ConvDHWOp>:
def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
- O(d, h, w) = std_addf<kd, kh, kw>(std_mulf(
- I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
+ O(d, h, w) = std_addf<kd, kh, kw>(
+ O(d, h, w), std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
}
ods_def<ConvNDHWCOp>:
def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
- O(n, d, h, w, f) = std_addf<kd, kh, kw>(std_mulf(
- I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+ O(n, d, h, w, f) = std_addf<kd, kh, kw>(
+ O(n, d, h, w, f),
+ std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
}
ods_def<ConvNCDHWOp>:
def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
- O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
- I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+ O(n, f, d, h, w) = std_addf<kd, kh, kw>(
+ O(n, f, d, h, w),
+ std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
}
ods_def<DepthwiseConvInputNHWCFilterHWCOp>:
Note: this op only supports channel multiplier == 1.
"""
{
- O(n, oh, ow, c) = std_addf<kh, kw>(std_mulf(
- I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c)));
+ O(n, oh, ow, c) = std_addf<kh, kw>(
+ O(n, oh, ow, c),
+ std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
+ K(kh, kw, c)));
}
ods_def<ConvInputNWCFilterWCFOp>:
"""
{
O(n, w, f) = std_addf<kw>(
+ O(n, w, f),
std_mulf(I(n, w * strides[0] + kw * dilations[0], c), K(kw, c, f)));
}
"""
{
O(n, f, w) = std_addf<kw>(
+ O(n, f, w),
std_mulf(I(n, c, w * strides[0] + kw * dilations[0]), K(kw, c, f)));
}
order of (`N`, `H`, `W`, `F`, `KH`, `KW`, `C`).
"""
{
- O(n, h, w, f) =
- std_addf<kh, kw>(std_mulf(I(n, h * strides[0] + kh * dilations[0],
- w * strides[1] + kw * dilations[1], c),
- K(kh, kw, c, f)));
+ O(n, h, w, f) = std_addf<kh, kw>(
+ O(n, h, w, f), std_mulf(I(n, h * strides[0] + kh * dilations[0],
+ w * strides[1] + kw * dilations[1], c),
+ K(kh, kw, c, f)));
}
ods_def<ConvInputNCHWFilterHWCFOp>:
order of (`N`, `F`, `H`, `W`, `KH`, `KW`, `C`).
"""
{
- O(n, f, h, w) =
- std_addf<kh, kw>(std_mulf(I(n, c, h * strides[0] + kh * dilations[0],
- w * strides[1] + kw * dilations[1]),
- K(kh, kw, c, f)));
+ O(n, f, h, w) = std_addf<kh, kw>(
+ O(n, f, h, w), std_mulf(I(n, c, h * strides[0] + kh * dilations[0],
+ w * strides[1] + kw * dilations[1]),
+ K(kh, kw, c, f)));
}
ods_def<ConvInputNDHWCFilterDHWCFOp>:
order of (`N`, `D`, `H`, `W`, `F`, `KD`, `KH`, `KW`, `C`).
"""
{
- O(n, d, h, w, f) =
- std_addf<kd, kh, kw>(std_mulf(I(n, d * strides[0] + kd * dilations[0],
- h * strides[1] + kh * dilations[1],
- w * strides[2] + kw * dilations[2], c),
- K(kd, kh, kw, c, f)));
+ O(n, d, h, w, f) = std_addf<kd, kh, kw>(
+ O(n, d, h, w, f), std_mulf(I(n, d * strides[0] + kd * dilations[0],
+ h * strides[1] + kh * dilations[1],
+ w * strides[2] + kw * dilations[2], c),
+ K(kd, kh, kw, c, f)));
}
ods_def<ConvInputNCDHWFilterDHWCFOp>:
order of (`N`, `F`, `D`, `H`, `W`, `KD`, `KH`, `KW`, `C`).
"""
{
- O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
- I(n, c, d * strides[0] + kd * dilations[0],
- h * strides[1] + kh * dilations[1], w * strides[2] + kw * dilations[2]),
- K(kd, kh, kw, c, f)));
+ O(n, f, d, h, w) = std_addf<kd, kh, kw>(
+ O(n, f, d, h, w), std_mulf(I(n, c, d * strides[0] + kd * dilations[0],
+ h * strides[1] + kh * dilations[1],
+ w * strides[2] + kw * dilations[2]),
+ K(kd, kh, kw, c, f)));
}