From b5ecbb7fd60da7a7dbbc633b81ecc938ec38ea4f Mon Sep 17 00:00:00 2001
From: Geoffrey Martin-Noble <gcmn@google.com>
Date: Fri, 17 May 2019 19:45:45 -0700
Subject: [PATCH] Clean up tablegen vector and tensor types

There was a weird mix of names, styles, and inheritance here. I think this
makes it cleaner and more consistent. We can also have a more principled and
far-reaching refactor of some of this naming, but this seems like a good
improvement regardless

--
PiperOrigin-RevId: 248827005
---
 .../mlir/Dialect/QuantOps/QuantPredicates.td  |  4 +-
 mlir/include/mlir/IR/OpBase.td                | 92 ++++++++--------
 mlir/include/mlir/StandardOps/Ops.td          |  8 +-
 mlir/test/mlir-tblgen/op-operand.td           |  6 +-
 mlir/test/mlir-tblgen/op-result.td            | 14 ++--
 mlir/test/mlir-tblgen/predicate.td            | 20 ++--
 mlir/test/mlir-tblgen/reference-impl.td       |  4 +-
 7 files changed, 63 insertions(+), 85 deletions(-)

diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td b/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
index 028e0df..e26d12d 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
@@ -28,8 +28,8 @@ class quant_TypedPrimitiveOrContainer<Type etype> :
     Type<AnyOf<[etype.predicate,
-                TypedTensor<etype>.predicate,
-                TypedVector<etype>.predicate]>,
+                Tensor<etype>.predicate,
+                Vector<etype>.predicate]>,
          "primitive/tensor/vector of " # etype.description>;
 
 // An implementation of QuantizedType.
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 119f1db..861d9d5 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -195,13 +195,12 @@ def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
 // Whether a type is a ShapedType.
 def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">;
 
+// For a ShapedType, verify that it has a static shape.
+def HasStaticShapePred : CPred<"$_self.cast<ShapedType>().hasStaticShape()">;
+
 // Whether a type is a TupleType.
 def IsTupleTypePred : CPred<"$_self.isa<TupleType>()">;
 
-// For a TensorType, verify that it is a statically shaped tensor.
-def IsStaticShapeTensorTypePred :
-    CPred<"$_self.cast<TensorType>().hasStaticShape()">;
-
 //===----------------------------------------------------------------------===//
 // Dialect definitions
 //===----------------------------------------------------------------------===//
@@ -230,7 +229,7 @@ class Type<Pred condition, string descr = ""> :
 // more than one variadic operand/result, and that operand/result must be the
 // last one in the operand/result list.
 class Variadic<Type type, string descr = "variadic of " # type.description>
-  // TODO: support variadic type conditions
+  // TODO(b/132908002): support variadic type conditions
   : TypeConstraint<type.predicate, descr> {
   Type baseType = type;
 }
@@ -254,7 +253,7 @@ def AnyType : Type<CPred<"true">, "any type">;
 def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
 
 // Any type from the given list
-class AnyTypeOf<list<Type> allowedTypes, string description> : Type<
+class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
     // Satisfy any of the allowed type's condition
     AnyOf<!foreach(t, allowedTypes, t.predicate)>,
     !if(!eq(description, ""),
@@ -330,70 +329,49 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
   code getElementTypeCall = elementTypeCall;
 }
 
+class ShapedContainerType<Type etype, Pred containerPred, string descr> :
+    ContainerType<etype, containerPred,
+                  "$_self.cast<ShapedType>().getElementType()", descr>;
+
 // Vector types.
-class TypedVector<Type t> : ContainerType<t, IsVectorTypePred,
-    "$_self.cast<VectorType>().getElementType()", "vector">;
-
-class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
-    IsVectorTypePred,
-    // Match dims. Construct an ArrayRef with the elements of `dims` by
-    // folding over the list.
-    CPred<"$_self.cast<VectorType>().getShape() == ArrayRef<int64_t>{{" #
-      Stringify<dims>.result # "}">]>,
-    "$_self.cast<VectorType>().getElementType()",
-    "vector"> {
-  list<int> dimensions = dims;
-}
-// Tensor type.
+class Vector<Type t> : ShapedContainerType<t, IsVectorTypePred, "vector">;
 
-// This represents a generic tensor without constraints on elemental type,
-// rank, size. As there is no constraint on elemental type, derive from Type
-// directly instead of ContainerType.
-def Tensor : Type<IsTensorTypePred, "tensor">;
+def AnyVector : Vector<AnyType>;
 
-// A tensor with static shape but no other constraints. Note: as
-// Tensor is a def this doesn't derive from it, but reuses the predicate
-// that must hold for it to be a tensor.
-def StaticShapeTensor
-    : Type<AllOf<[IsTensorTypePred, IsStaticShapeTensorTypePred]>,
-           "statically shaped tensor">;
+// Tensor types.
 
-// For typed tensors.
-class TypedTensor<Type t>
-    : ContainerType<t, IsTensorTypePred,
-                    "$_self.cast<TensorType>().getElementType()",
-                    "tensor">;
+class Tensor<Type t> : ShapedContainerType<t, IsTensorTypePred, "tensor">;
 
-class TypedStaticShapeTensor<Type t>
-    : Type<AllOf<[TypedTensor<t>.predicate, IsStaticShapeTensorTypePred]>,
-           "statically shaped tensor">;
+def AnyTensor : Tensor<AnyType>;
 
-def I1Tensor : TypedTensor<I1>;
-def I8Tensor : TypedTensor<I8>;
-def I16Tensor : TypedTensor<I16>;
-def I32Tensor : TypedTensor<I32>;
-def I64Tensor : TypedTensor<I64>;
+// Any tensor type whose element type is from the given `allowedTypes` list
+class AnyTensorOf<list<Type> allowedTypes, string elementDescription = ""> :
+    Tensor<AnyTypeOf<allowedTypes, elementDescription>>;
 
-def BF16Tensor : TypedTensor<BF16>;
-def F16Tensor : TypedTensor<F16>;
-def F32Tensor : TypedTensor<F32>;
-def F64Tensor : TypedTensor<F64>;
+// TODO(b/130807343) Fix description to contain element information.
+class StaticShapeTensor<Type t>
+    : Type<AllOf<[Tensor<t>.predicate, HasStaticShapePred]>,
+           "statically shaped tensor">;
 
-// Any tensor type whose element type is from the given
-// `allowedTypes` list
-class AnyTensorOf<list<Type> allowedTypes, string elementDescription = ""> :
-    TypedTensor<AnyTypeOf<allowedTypes, elementDescription>>;
+def AnyStaticShapeTensor : StaticShapeTensor<AnyType>;
 
-def VectorOrTensor :
-    AnyTypeOf<[TypedVector<AnyType>, Tensor], "vector or tensor">;
+def I1Tensor : Tensor<I1>;
+def I8Tensor : Tensor<I8>;
+def I16Tensor : Tensor<I16>;
+def I32Tensor : Tensor<I32>;
+def I64Tensor : Tensor<I64>;
+
+def BF16Tensor : Tensor<BF16>;
+def F16Tensor : Tensor<F16>;
+def F32Tensor : Tensor<F32>;
+def F64Tensor : Tensor<F64>;
 
 // This represents a generic tuple without any constraints on elemental type,
 // ranks, or size. As Tuples can contain tensors, vectors, or scalar values
 // there is not only a single elemental type.
 def Tuple : Type<IsTupleTypePred, "tuple">;
 
-
 // A Tuple that only holds elements of a certain type. This cannot inherit from
 // ContainerType because tuples do not always have a single element type that
 // could be retrieved with elementTypeCall.
@@ -438,12 +416,12 @@ def F64MemRef : MemRef<F64>;
 
 // Type constraint for integer-like types: integers, indices, vectors of
 // integers, tensors of integers.
 def IntegerLike : TypeConstraint<AnyOf<[Integer.predicate, Index.predicate,
-        TypedVector<Integer>.predicate, TypedTensor<Integer>.predicate]>,
+        Vector<Integer>.predicate, Tensor<Integer>.predicate]>,
     "integer-like">;
 
 // Type constraint for float-like types: floats, vectors or tensors thereof.
 def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
-        TypedVector<Float>.predicate, TypedTensor<Float>.predicate]>,
+        Vector<Float>.predicate, Tensor<Float>.predicate]>,
     "floating-point-like">;
 
@@ -1037,7 +1015,7 @@ def addBenefit;
 //     def : Pattern<(OneResultOp1:$op1 $arg0, $arg1),
 //                   [(OneResultOp2:$op2 $arg0, $arg1),
 //                    (OneResultOp3 $op2 (OneResultOp4))],
-//                   [(IsStaticShapeTensorTypePred $op1)]>;
+//                   [(HasStaticShapePred $op1)]>;
 // ```
 //
 // `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to
diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td
index e3b521a..9ba3ab6 100644
--- a/mlir/include/mlir/StandardOps/Ops.td
+++ b/mlir/include/mlir/StandardOps/Ops.td
@@ -317,7 +317,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
       %1 = dim %0, 2 : tensor<?x?x?xf32>
   }];
 
-  let arguments = (ins AnyTypeOf<[MemRef<AnyType>, Tensor],
+  let arguments = (ins AnyTypeOf<[MemRef<AnyType>, AnyTensor],
                                  "any tensor or memref type">:$memrefOrTensor,
                        APIntAttr:$index);
   let results = (outs Index);
@@ -366,7 +366,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
       %0 = extract_element %0[%1, %2] : vector<4x4xi32>
   }];
 
-  let arguments = (ins VectorOrTensor:$aggregate,
+  let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
                        Variadic<Index>:$indices);
   let results = (outs AnyType);
@@ -498,8 +498,8 @@ def TensorCastOp : CastOp<"tensor_cast"> {
       %2 = tensor_cast %1 : tensor<*xf32> to tensor<?x?xf32>
   }];
 
-  let arguments = (ins Tensor);
-  let results = (outs Tensor);
+  let arguments = (ins AnyTensor);
+  let results = (outs AnyTensor);
 
   let extraClassDeclaration = [{
     /// Return true if `a` and `b` are valid operand and result pairs for
diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 89f260f..633d4e7 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -37,7 +37,7 @@ def OpB : NS_Op<"one_variadic_operand_op", []> {
 // CHECK: tblgen_state->addOperands(input);
 
 def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
-  let arguments = (ins Variadic<Tensor>:$input1, Variadic<Tensor>:$input2);
+  let arguments = (ins Variadic<AnyTensor>:$input1, Variadic<AnyTensor>:$input2);
 }
 
 // CHECK-LABEL: Operation::operand_range OpC::input1()
@@ -55,7 +55,7 @@ def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
 // CHECK-NEXT: tblgen_state->addOperands(input2);
 
 def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
-  let arguments = (ins Variadic<Tensor>:$input1, Tensor:$input2, Variadic<Tensor>:$input3);
+  let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
 }
 
 // CHECK-LABEL: Operation::operand_range OpD::input1()
@@ -79,7 +79,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
 // CHECK-NEXT: tblgen_state->addOperands(input3);
 
 def OpE : NS_Op<"one_variadic_among_multi_normal_inputs_op", []> {
-  let arguments = (ins Tensor:$input1, Tensor:$input2, Variadic<Tensor>:$input3, Tensor:$input4, Tensor:$input5);
+  let arguments = (ins AnyTensor:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3, AnyTensor:$input4, AnyTensor:$input5);
 }
 
 // CHECK-LABEL: Value *OpE::input1()
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 268f0c0..4bce522 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -48,7 +48,7 @@ def OpC : NS_Op<"three_normal_result_op", []> {
 def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
 def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
   let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
-  let results = (outs Tensor:$y);
+  let results = (outs AnyTensor:$y);
 }
 
 // CHECK-LABEL: OpD definitions
@@ -57,7 +57,7 @@ def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
 
 def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
   let arguments = (ins I32:$x, F32Attr:$attr);
-  let results = (outs Tensor:$y);
+  let results = (outs AnyTensor:$y);
 }
 
 // CHECK-LABEL: OpE definitions
@@ -94,7 +94,7 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
 
 def OpH : NS_Op<"all_variadic_results_op", [SameVariadicResultSize]> {
-  let results = (outs Variadic<Tensor>:$output1, Variadic<Tensor>:$output2);
+  let results = (outs Variadic<AnyTensor>:$output1, Variadic<AnyTensor>:$output2);
 }
 
 // CHECK-LABEL: Operation::result_range OpH::output1()
@@ -113,7 +113,7 @@ def OpH : NS_Op<"all_variadic_results_op", [SameVariadicResultSize]> {
 // CHECK-NEXT: tblgen_state->addTypes(output2);
 
 def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
-  let results = (outs Variadic<Tensor>:$output1, Tensor:$output2, Variadic<Tensor>:$output3);
+  let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
 }
 
 // CHECK-LABEL: Operation::result_range OpI::output1()
@@ -137,7 +137,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
 // CHECK-NEXT: tblgen_state->addTypes(output3);
 
 def OpJ : NS_Op<"one_variadic_among_multi_normal_results_op", []> {
-  let results = (outs Tensor:$output1, Tensor:$output2, Variadic<Tensor>:$output3, Tensor:$output4, Tensor:$output5);
+  let results = (outs AnyTensor:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3, AnyTensor:$output4, AnyTensor:$output5);
 }
 
 // CHECK-LABEL: Value *OpJ::output1()
@@ -159,8 +159,8 @@ def OpJ : NS_Op<"one_variadic_among_multi_normal_results_op", []> {
 // pack to set result type
 // ---
 def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameValueType]> {
-  let arguments = (ins Variadic<Tensor>:$input);
-  let results = (outs Tensor:$result);
+  let arguments = (ins Variadic<AnyTensor>:$input);
+  let results = (outs AnyTensor:$result);
 }
 
 // CHECK-LABEL: OpK::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 4e0da20..58b205a 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -29,7 +29,7 @@ def OpB : NS_Op<"op_for_AllOf_PredOpTrait", [
 def OpC : NS_Op<"op_for_TCopVTEtIs", [
     PredOpTrait<"first operand has i32 element type",
                 TCopVTEtIs<0, I32>>]> {
-  let arguments = (ins Tensor:$x);
+  let arguments = (ins AnyTensor:$x);
 }
 
 // CHECK-LABEL: OpC::verify
@@ -40,7 +40,7 @@ def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [
     PredOpTrait<"first operand is a vector or tensor with the same "
                 "elemental type as itself", TCopVTEtIsSameAs<0, 0>>]> {
-  let arguments = (ins Tensor:$x);
+  let arguments = (ins AnyTensor:$x);
 }
 
 // CHECK-LABEL: OpD::verify
@@ -52,8 +52,8 @@ def OpE : NS_Op<"op_for_TCresVTEtIsSameAsOp", [
     PredOpTrait<"first operand is a vector or tensor with the same "
                 "elemental type as first result", TCresVTEtIsSameAsOp<0, 0>>]> {
-  let arguments = (ins Tensor:$x);
-  let results = (outs Tensor:$y);
+  let arguments = (ins AnyTensor:$x);
+  let results = (outs AnyTensor:$y);
 }
 
 // CHECK-LABEL: OpE::verify
@@ -97,11 +97,11 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
     PredOpTrait<"operands indexed at 0, 2, 3 should all have "
                 "the same type", TCopVTEtAreSameAt<[0, 2, 3]>>]> {
   let arguments = (ins
-    Tensor:$a,
-    Tensor:$b,
-    Tensor:$c,
-    Tensor:$d,
-    Tensor:$e
+    AnyTensor:$a,
+    AnyTensor:$b,
+    AnyTensor:$c,
+    AnyTensor:$d,
+    AnyTensor:$e
   );
 }
 
@@ -116,4 +116,4 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
 }
 
 // CHECK-LABEL: OpK::verify
-// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<TensorType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<TensorType>().getElementType().isInteger(32))))))
+// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isInteger(32))))))
diff --git a/mlir/test/mlir-tblgen/reference-impl.td b/mlir/test/mlir-tblgen/reference-impl.td
index 8bcab6e..69b1787 100644
--- a/mlir/test/mlir-tblgen/reference-impl.td
+++ b/mlir/test/mlir-tblgen/reference-impl.td
@@ -12,8 +12,8 @@ class X_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<X_Dialect, mnemonic, traits>;
 
 def X_AddOp : X_Op<"add">,
-    Arguments<(ins Tensor:$A, Tensor:$B)>,
-    Results<(outs Tensor: $C)> {
+    Arguments<(ins AnyTensor:$A, AnyTensor:$B)>,
+    Results<(outs AnyTensor: $C)> {
   // TODO: extract referenceImplementation to Op.
   code referenceImplementation = [{
     auto ivs = IndexHandle::makeIndexHandles(view_A.rank());
-- 
2.7.4
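
For context, this is how the renamed constraints compose in an op definition.
The sketch below is hypothetical and not part of the commit: Example_Dialect,
Example_Op, and "my_op" are invented names following the pattern of the
mlir-tblgen test files above, while AnyTensor, Tensor<F32>, AnyTensorOf,
Variadic, and AnyStaticShapeTensor are the constraints as defined in OpBase.td
after this patch.

include "mlir/IR/OpBase.td"

// Hypothetical dialect and op base class, mirroring the test files.
def Example_Dialect : Dialect {
  let name = "example";
}

class Example_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<Example_Dialect, mnemonic, traits>;

def Example_MyOp : Example_Op<"my_op", [NoSideEffect]> {
  let arguments = (ins
    AnyTensor:$input,                          // any tensor (old name: Tensor)
    Tensor<F32>:$weights,                      // f32 tensor (old name: TypedTensor<F32>)
    Variadic<AnyTensorOf<[I32, F32]>>:$rest    // zero or more i32-or-f32 tensors
  );
  // The result must be a tensor with a fully static shape
  // (old name: StaticShapeTensor).
  let results = (outs AnyStaticShapeTensor:$output);
}

Under the old naming, the same signature would have mixed Tensor,
TypedTensor<F32>, and StaticShapeTensor; making every unconstrained case an
explicit Any* def is the inconsistency this commit removes.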