Extend the OpDSL syntax with an optional `domain` function to specify an explicit dimension order. The extension is needed to provide more control over the dimension order instead of deducing it implicitly depending on the formulation of the tensor comprehension. Additionally, the patch also ensures the symbols are ordered according to the operand definitions of the operation.
Differential Revision: https://reviews.llvm.org/D105117
+
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
name: A
usage: InputOperand
type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- !LinalgOperandDefConfig
name: B
usage: InputOperand
type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+ shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- !LinalgOperandDefConfig
name: C
usage: OutputOperand
type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+ shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
name: A
usage: InputOperand
type_var: T1
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- !LinalgOperandDefConfig
name: B
usage: InputOperand
type_var: T2
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- !LinalgOperandDefConfig
name: C
usage: OutputOperand
type_var: U
- shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+ shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
name: y
usage: InputOperand
type_var: T1
- shape_map: affine_map<()[s0, s1] -> (s1)>
+ shape_map: affine_map<()[s0, s1] -> (s0)>
- !LinalgOperandDefConfig
name: A
usage: InputOperand
type_var: T2
- shape_map: affine_map<()[s0, s1] -> (s1, s0)>
+ shape_map: affine_map<()[s0, s1] -> (s0, s1)>
- !LinalgOperandDefConfig
name: x
usage: OutputOperand
type_var: U
- shape_map: affine_map<()[s0, s1] -> (s0)>
+ shape_map: affine_map<()[s0, s1] -> (s1)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d1)>
usage: InputOperand
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s0, s4, s5, s3)>
+ (s0, s1, s2, s3)>
- !LinalgOperandDefConfig
name: K
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s6, s7, s3)>
+ (s4, s5, s3)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s0, s1, s2, s3)>
+ (s0, s6, s7, s3)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
- s10, s11] -> (d0, d1 * s8 + d4 * s10, d2 * s9 + d5 * s11, d3)>
+ s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
- s10, s11] -> (d4, d5, d3)>
+ s10, s11] -> (d3, d4, d5)>
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
- s10, s11] -> (d0, d1, d2, d3)>
+ s10, s11] -> (d0, d1, d2, d5)>
iterator_types:
- parallel
- parallel
- parallel
- - parallel
- reduction
- reduction
+ - parallel
assignments:
- !ScalarAssign
arg: O
usage: InputOperand
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s0, s4, s5, s3)>
+ (s0, s1, s2, s3)>
- !LinalgOperandDefConfig
name: K
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s10, s11)>
+ (s4, s5)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s0, s1, s2, s3)>
+ (s0, s6, s7, s3)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
- -> (s6, s7)>
+ -> (s8, s9)>
- !LinalgOperandDefConfig
name: dilations
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
- -> (s8, s9)>
+ -> (s10, s11)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
- s10, s11] -> (d2, d3 * s6 + d0 * s8, d4 * s7 + d1 * s9, d5)>
+ s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
- s10, s11] -> (d0, d1)>
+ s10, s11] -> (d3, d4)>
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
- s10, s11] -> (d2, d3, d4, d5)>
+ s10, s11] -> (d0, d1, d2, d5)>
iterator_types:
- - reduction
- - reduction
- parallel
- parallel
- parallel
+ - reduction
+ - reduction
- parallel
assignments:
- !ScalarAssign
"""Visits all tensor expression reachable by the expression."""
callback(self)
- def _get_all_dim_defs(self) -> Set[DimDef]:
- """Recursively gets all DimDef affine expressions that are referenced."""
+ def collect_dim_uses(self, uses: Set["DimDef"]):
+ """Collects all DimDefs reachable through this expression."""
results = set()
def visit_dim_def(dim_def):
if isinstance(dim_def, DimDef):
- results.add(dim_def)
+ uses.add(dim_def)
def visit_affine_exprs(expr):
if isinstance(expr, TensorUse):
ind.visit_affine_exprs(visit_dim_def)
self.visit_tensor_exprs(visit_affine_exprs)
- return results
def collect_tensor_uses(self, uses: Set["TensorUse"]):
"""Collects all TensorUses reachable through this expression."""
reduced into. Any indices referenced on the rhs and not in self are
considered reduction dims and will be ordered as encountered on the rhs.
"""
- rhs_dims = rhs._get_all_dim_defs()
- lhs_dims = self._get_all_dim_defs()
+ rhs_dims = set()
+ lhs_dims = set()
+ rhs.collect_dim_uses(rhs_dims)
+ self.collect_dim_uses(lhs_dims)
return rhs_dims - lhs_dims
def __repr__(self):
f"number of index_dims {len(index_dims)}")
if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
raise ValueError(f"TensorDef requires index dims of type DimDef but "
- f"got {type(index_dims)}")
+ f"got {index_dims}")
kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
self.operand_def = OperandDef(
kind, type_var, size_exprs=shape, index_dims=index_dims)
def __init__(self, *sizes: SymbolDef):
if any(not isinstance(size, SymbolDef) for size in sizes):
raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got "
- f"{type(sizes)}")
+ f"{sizes}")
self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
self.metadata = OpMetadataDef(
name=name, cpp_class_name=cpp_class_name, doc=doc)
self.registered_operands = dict() # type: Dict[str, OperandDef]
+ self.domain = list() # type: List[DimDef]
self.comprehensions = list() # type: List[Comprehension]
self._affine_state = AffineBuildState()
def __init__(self,
comprehension: Comprehension,
+ domain: Sequence[DimDef],
registered_operands: Sequence[OperandDef],
context: Optional[_ir.Context] = None):
self.context = context if context is not None else _ir.Context()
self.operands = dict() # type: Dict[OperandDef, OperandDefConfig]
self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
- # Compute the ordered set of writes and collect the tensor, capture, and
- # index uses.
+ # Compute the ordered set of writes and collect the tensor, capture, dims,
+ # and index uses.
collected_tensor_uses = set()
collected_scalar_uses = set()
+ collected_dim_uses = set()
collected_indices = set()
for write_use, read_use in zip(comprehension.definitions,
comprehension.values):
collected_tensor_uses.add(write_use)
read_use.collect_tensor_uses(collected_tensor_uses)
read_use.collect_scalar_uses(collected_scalar_uses)
+ read_use.collect_dim_uses(collected_dim_uses)
+ write_use.collect_dim_uses(collected_dim_uses)
read_use.collect_indices(collected_indices)
+ # Set domain to the sorted list of uses if no domain annotation is given.
+ if not domain:
+ domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
+
+ # Verify the domain dimensions match the used dimensions.
+ if (len(domain) != len(collected_dim_uses) or
+ any(dim not in collected_dim_uses for dim in domain)):
+ raise ValueError(f"Expected the annotated domain dimensions {domain} to "
+ f"match the set of dimension used by the tensor "
+ f"comprehension {collected_dim_uses}")
+
+ # Instantiate the dimensions in the given order.
+ with self.context:
+ local_state = AffineBuildState(
+ global_state=self.affine_state, allow_new_symbols=False)
+ for dim in domain:
+ dim.build(state=local_state)
+
# Collect all attribute definitions.
collected_attr_defs = list()
for operand in registered_operands:
collected_index_defs = list()
for operand in registered_operands:
if operand.index_dims:
+ if any(dim not in collected_dim_uses for dim in operand.index_dims):
+ raise ValueError(f"Expected all index dims {operand.index_dims} of "
+ f"operand {operand.name} to have uses.")
collected_index_defs.append(operand)
- # Add all definitions before uses, so process twice.
+ # Collect the operand definitions of all tensor/scalar uses, attributes, and
+ # shape-only tensors.
+ all_operand_defs = list()
for use in collected_tensor_uses:
- self.add_operand(use.operand_def)
+ all_operand_defs.append(use.operand_def)
for use in collected_scalar_uses:
- self.add_operand(use.operand_def)
+ all_operand_defs.append(use.operand_def)
for definition in collected_attr_defs:
- self.add_operand(definition)
+ all_operand_defs.append(definition)
+ for definition in collected_index_defs:
+ all_operand_defs.append(definition)
+
+ # Add all operands in registration order to ensure the symbols are
+ # registered in the order they appear.
+ all_operand_defs = sorted(
+ all_operand_defs, key=lambda operand_def: operand_def.registered_index)
+ for operand_def in all_operand_defs:
+ self.add_operand(operand_def)
+
+ # Add all shape-only tensor index_dim annotations and all tensor uses.
for definition in collected_index_defs:
- if definition not in self.operands:
- self.add_operand(definition)
self.add_indexed_operand(definition)
for use in collected_tensor_uses:
self.add_tensor_use(use)
LinalgOpConfig(
tc_op_def.metadata,
structured_op=LinalgStructuredOpConfig(
- tc_op_def.comprehensions[0],
+ tc_op_def.comprehensions[0], tc_op_def.domain,
tc_op_def.registered_operands.values(), context)),
]
def implements(*interfaces: OpInterfaceDef):
current_op_def().metadata.implements.extend(interfaces)
+
+
+def domain(*dimensions: DimDef):
+ if current_op_def().domain:
+ raise ValueError(f"Expected only one set of domain dimensions per operator")
+ if any(not isinstance(dim, DimDef) for dim in dimensions):
+ raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
+ current_op_def().domain.extend(dimensions)
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ domain(D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n])
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ domain(D.m, D.n)
implements(ContractionOpInterface)
x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n])
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ domain(D.n, D.m)
implements(ContractionOpInterface)
x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n])
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]) * cast(U, K[D.kh, D.kw, D.c])
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
+ domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
element seed the random number generation. The min and max operands limit
the range of the generated random numbers.
"""
+ domain(D.m, D.n)
multiplier = cast(I32, const(1103515245))
increment = cast(I32, const(12345))
rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
# CHECK: name: A
# CHECK: usage: InputOperand
# CHECK: type_var: T
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
# CHECK: name: B
# CHECK: usage: InputOperand
# CHECK: type_var: T
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
# CHECK: name: C
# CHECK: usage: OutputOperand
# CHECK: type_var: U
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
# CHECK: name: I
# CHECK: usage: InputOperand
# CHECK: type_var: T
-# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
# CHECK: name: O
# CHECK: usage: OutputOperand
# CHECK: type_var: T
-# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
# CHECK: name: strides
# CHECK: usage: IndexAttribute
# CHECK: type_var: I64
I=TensorDef(T, S.IH, S.IW),
O=TensorDef(T, S.OH, S.OW, output=True),
strides=AttributeDef(S.SH, S.SW)):
- O[D.oh, D.ow] = I[D.h * S.SH, D.w * S.SW]
+ O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(T, S.M, S.N, output=True)):
+ domain(D.m, D.n, D.k)
C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
+ domain(D.m, D.n, D.k)
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
+ domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]) * cast(U, K[D.kh, D.kw, D.c])
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
+ domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
# CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
# Convolution indexing maps.
- # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
- # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
- # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+ # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
# Pooling indexing maps.
- # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3 * 2 + d0, d4 * 4 + d1 * 2, d5)>
- # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
- # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+ # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
# CHECK-LABEL: func @test_matmul_mono
# CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
# CHECK-LABEL: @test_f32i32_conv
# CHECK: linalg.generic
# CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
- # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
# CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
# CHECK-NEXT: %[[FILTER_CAST:.+]] = fptosi %[[FILTER:.+]] : f32 to i32
# CHECK-LABEL: @test_f32i32_pooling
# CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
- # CHECK-SAME: iterator_types = ["reduction", "reduction", "parallel", "parallel", "parallel", "parallel"]
+ # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
# CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
# CHECK-NEXT: %[[SUM:.+]] = addi %[[OUT]], %[[IN_CAST]] : i32
from mlir.dialects.linalg.opdsl.lang import *
+
# CHECK: ---
# CHECK-LABEL: matmul
# CHECK: implements:
# CHECK-NEXT: - LinalgContractionOpInterface
@linalg_structured_op
-def matmul(A=TensorDef(T, S.M, S.K),
- B=TensorDef(T, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+ A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
implements(ContractionOpInterface)
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
# dims auto discovered emits the right shape, indexing maps and iterator types.
# CHECK: ---
# CHECK-LABEL: matmul
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
# CHECK: static_indexing_maps:
# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
+ domain(D.m, D.n, D.k)
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
+
# Verifies that the index_dims of shape-only operands translate to correct
# indexing maps.
# CHECK: ---
# CHECK-LABEL: pool
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)>
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1)>
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2)>
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)>
# CHECK: static_indexing_maps:
-# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1 * 2 + d0)>
-# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0 * 2 + d1)>
# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)>
# CHECK: iterator_types:
-# CHECK-NEXT: - reduction
# CHECK-NEXT: - parallel
+# CHECK-NEXT: - reduction
@linalg_structured_op
-def pool(I=TensorDef(T, S.I),
- K=TensorDef(T, S.K, index_dims=[D.k]),
- O=TensorDef(U, S.O, output=True)):
+def pool(
+ I=TensorDef(T, S.I),
+ K=TensorDef(T, S.K, index_dims=[D.k]),
+ O=TensorDef(U, S.O, output=True)):
+ domain(D.o, D.k)
O[D.o] += cast(U, I[D.o * 2 + D.k])