- Also adopt variadic llvm::isa<> in more places.
- Fixes https://bugs.llvm.org/show_bug.cgi?id=46445
Differential Revision: https://reviews.llvm.org/D82769
return nullptr;
// Check that the type is either a TensorType or another StructType.
- if (!elementType.isa<mlir::TensorType>() &&
- !elementType.isa<StructType>()) {
+ if (!elementType.isa<mlir::TensorType, StructType>()) {
parser.emitError(typeLoc, "element type for a struct must either "
"be a TensorType or a StructType, got: ")
<< elementType;
return nullptr;
// Check that the type is either a TensorType or another StructType.
- if (!elementType.isa<mlir::TensorType>() &&
- !elementType.isa<StructType>()) {
+ if (!elementType.isa<mlir::TensorType, StructType>()) {
parser.emitError(typeLoc, "element type for a struct must either "
"be a TensorType or a StructType, got: ")
<< elementType;
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
: value(v), exprs(indexings.begin(), indexings.end()) {
- assert((v.getType().isa<MemRefType>() ||
- v.getType().isa<RankedTensorType>() ||
- v.getType().isa<VectorType>()) &&
+ assert((v.getType().isa<MemRefType, RankedTensorType, VectorType>()) &&
"MemRef, RankedTensor or Vector expected");
}
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
: type(t), exprs(indexings.begin(), indexings.end()) {
- assert((t.isa<MemRefType>() || t.isa<RankedTensorType>() ||
- t.isa<VectorType>()) &&
+ assert((t.isa<MemRefType, RankedTensorType, VectorType>()) &&
"MemRef, RankedTensor or Vector expected");
}
bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
+ template <typename First, typename Second, typename... Rest>
+ bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
assert(impl && "isa<> used on a null attribute.");
return U::classof(*this);
}
+
+template <typename First, typename Second, typename... Rest>
+bool Attribute::isa() const {
+ return isa<First>() || isa<Second, Rest...>();
+}
+
template <typename U> U Attribute::dyn_cast() const {
return isa<U>() ? U(impl) : U(nullptr);
}
return false;
auto type = op->getResult(0).getType();
- if (type.isa<IntegerType>() || type.isa<IndexType>())
+ if (type.isa<IntegerType, IndexType>())
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
- if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
+ if (type.isa<VectorType, RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getSplatValue());
/// Returns true if the given type can be used as an element of a vector type.
/// In particular, vectors can consist of integer or float primitives.
static bool isValidElementType(Type t) {
- return t.isa<IntegerType>() || t.isa<FloatType>();
+ return t.isa<IntegerType, FloatType>();
}
ArrayRef<int64_t> getShape() const;
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
- return type.isa<ComplexType>() || type.isa<FloatType>() ||
- type.isa<IntegerType>() || type.isa<OpaqueType>() ||
- type.isa<VectorType>() || type.isa<IndexType>() ||
+ return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
+ IndexType>() ||
(type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
}
bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
+ template <typename First, typename Second, typename... Rest>
+ bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
assert(impl && "isa<> used on a null type.");
return U::classof(*this);
}
+
+template <typename First, typename Second, typename... Rest>
+bool Type::isa() const {
+ return isa<First>() || isa<Second, Rest...>();
+}
+
template <typename U> U Type::dyn_cast() const {
return isa<U>() ? U(impl) : U(nullptr);
}
assert(*this && "isa<> used on a null type.");
return U::classof(*this);
}
+
+ template <typename First, typename Second, typename... Rest>
+ bool isa() const {
+ return isa<First>() || isa<Second, Rest...>();
+ }
+
template <typename U> U dyn_cast() const {
return isa<U>() ? U(ownerAndKind) : U(nullptr);
}
// Walk this 'affine.for' operation to gather all memory regions.
auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
- if (!isa<AffineReadOpInterface>(opInst) &&
- !isa<AffineWriteOpInterface>(opInst)) {
+ if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
// Neither load nor a store op.
return WalkResult::advance();
}
// Collect all load and store ops in loop nest rooted at 'forOp'.
SmallVector<Operation *, 8> loadAndStoreOpInsts;
auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
- if (isa<AffineReadOpInterface>(opInst) ||
- isa<AffineWriteOpInterface>(opInst))
+ if (isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst))
loadAndStoreOpInsts.push_back(opInst);
- else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
- !isa<AffineIfOp>(opInst) &&
+ else if (!isa<AffineForOp, AffineTerminatorOp, AffineIfOp>(opInst) &&
!MemoryEffectOpInterface::hasNoEffect(opInst))
return WalkResult::interrupt();
auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
if (!converted)
return {};
- if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>())
+ if (t.isa<MemRefType, UnrankedMemRefType>())
converted = converted.getPointerTo();
inputs.push_back(converted);
}
FunctionType type, SmallVectorImpl<UnsignedTypePair> &argsInfo) const {
argsInfo.reserve(type.getNumInputs());
for (auto en : llvm::enumerate(type.getInputs())) {
- if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>())
+ if (en.value().isa<MemRefType, UnrankedMemRefType>())
argsInfo.push_back({en.index(), en.value()});
}
}
return failure();
// std.constant should only have vector or tensor types.
- assert(srcType.isa<VectorType>() || srcType.isa<RankedTensorType>());
+ assert((srcType.isa<VectorType, RankedTensorType>()));
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return ValueBuilder<IOp>(lhs, rhs);
} else if (thisType.isa<FloatType>()) {
return ValueBuilder<FOp>(lhs, rhs);
- } else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) {
+ } else if (thisType.isa<VectorType, TensorType>()) {
auto aggregateType = thisType.cast<ShapedType>();
if (aggregateType.getElementType().isSignlessInteger())
return ValueBuilder<IOp>(lhs, rhs);
nest->walk([&](Operation *op) {
if (auto forOp = dyn_cast<AffineForOp>(op))
promoteIfSingleIteration(forOp);
- else if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+ else if (isa<AffineLoadOp, AffineStoreOp>(op))
copyOps.push_back(op);
});
// If the body of a predicated region has a for loop, we don't hoist the
// 'affine.if'.
return false;
- } else if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+ } else if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) {
// TODO(asabne): Support DMA ops.
return false;
} else if (!isa<ConstantOp>(op)) {
for (auto *user : memref.getUsers()) {
// If this memref has a user that is a DMA, give up because these
// operations write to this memref.
- if (isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) {
+ if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) {
return false;
}
// If the memref used by the load/store is used in a store elsewhere in
return nullptr;
// Fuse when consumer is GenericOp or IndexedGenericOp.
- if (isa<GenericOp>(consumer) || isa<IndexedGenericOp>(consumer)) {
+ if (isa<GenericOp, IndexedGenericOp>(consumer)) {
auto linalgOpConsumer = cast<LinalgOp>(consumer);
if (!linalgOpConsumer.hasTensorSemantics())
return nullptr;
- if (isa<GenericOp>(producer) || isa<IndexedGenericOp>(producer)) {
+ if (isa<GenericOp, IndexedGenericOp>(producer)) {
auto linalgOpProducer = cast<LinalgOp>(producer);
if (linalgOpProducer.hasTensorSemantics())
return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer,
static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
Type spec = typeAttr.getValue();
- if (spec.isa<TensorType>() || spec.isa<VectorType>())
+ if (spec.isa<TensorType, VectorType>())
return false;
// The spec should be either a quantized type which is compatible to the
}
// Is the constant value a type expressed in a way that we support?
- if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
- !value.isa<SparseElementsAttr>()) {
+ if (!value.isa<FloatAttr, DenseElementsAttr, SparseElementsAttr>()) {
return failure();
}
return failure();
Type type = value.getType();
- if (type.isa<NoneType>() || type.isa<TensorType>()) {
+ if (type.isa<NoneType, TensorType>()) {
if (parser.parseColonType(type))
return failure();
}
// TODO: Currently only variable initialization with specialization
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
- if (!initOp || !(isa<spirv::GlobalVariableOp>(initOp) ||
- isa<spirv::SpecConstantOp>(initOp))) {
+ if (!initOp ||
+ !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
return varOp.emitOpError("initializer must be result of a "
"spv.specConstant or spv.globalVariable op");
}
static LogicalResult verify(spirv::MergeOp mergeOp) {
auto *parentOp = mergeOp.getParentOp();
- if (!parentOp ||
- (!isa<spirv::SelectionOp>(parentOp) && !isa<spirv::LoopOp>(parentOp)))
+ if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
return mergeOp.emitOpError(
"expected parent op to be 'spv.selection' or 'spv.loop'");
// SPIR-V spec: "Initializer must be an <id> from a constant instruction or
// a global (module scope) OpVariable instruction".
auto *initOp = varOp.getOperand(0).getDefiningOp();
- if (!initOp || !(isa<spirv::ConstantOp>(initOp) || // for normal constant
- isa<spirv::ReferenceOfOp>(initOp) || // for spec constant
- isa<spirv::AddressOfOp>(initOp)))
+ if (!initOp || !isa<spirv::ConstantOp, // for normal constant
+ spirv::ReferenceOfOp, // for spec constant
+ spirv::AddressOfOp>(initOp))
return varOp.emitOpError("initializer must be the result of a "
"constant or spv.globalVariable op");
}
if (value.getType() != type)
return false;
// Finally, check that the attribute kind is handled.
- return value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
- value.isa<ElementsAttr>() || value.isa<UnitAttr>();
+ return value.isa<IntegerAttr, FloatAttr, ElementsAttr, UnitAttr>();
}
void ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
// If the result type is a vector or tensor, the type can be a mask with the
// same elements.
Type resultType = op.getType();
- if (!resultType.isa<TensorType>() && !resultType.isa<VectorType>())
+ if (!resultType.isa<TensorType, VectorType>())
return op.emitOpError()
<< "expected condition to be a signless i1, but got "
<< conditionType;
assert(operands.size() == 1 && "splat takes one operand");
auto constOperand = operands.front();
- if (!constOperand ||
- (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
+ if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
return {};
auto shapedType = getType().cast<ShapedType>();
// Returns the type kind if the given type is a vector or ranked tensor type.
// Returns llvm::None otherwise.
auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
- if (type.isa<VectorType>() || type.isa<RankedTensorType>())
+ if (type.isa<VectorType, RankedTensorType>())
return static_cast<StandardTypes::Kind>(type.getKind());
return llvm::None;
};
}
static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
- if (type.isa<IntegerType>() || type.isa<IndexType>())
+ if (type.isa<IntegerType, IndexType>())
return success();
return emitError(loc, "expected integer or index type");
}
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data,
bool isSplat) {
- assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+ assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
DenseElementsAttr values) {
assert(indices.getType().getElementType().isInteger(64) &&
"expected sparse indices to be 64-bit integer values");
- assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+ assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
}
bool Type::isSignlessIntOrIndex() {
- return isa<IndexType>() || isSignlessInteger();
+ return isSignlessInteger() || isa<IndexType>();
}
bool Type::isSignlessIntOrIndexOrFloat() {
- return isa<IndexType>() || isSignlessInteger() || isa<FloatType>();
+ return isSignlessInteger() || isa<IndexType, FloatType>();
}
bool Type::isSignlessIntOrFloat() {
bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
-bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
+bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }
bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
int64_t ShapedType::getRank() const { return getShape().size(); }
bool ShapedType::hasRank() const {
- return !isa<UnrankedMemRefType>() && !isa<UnrankedTensorType>();
+ return !isa<UnrankedMemRefType, UnrankedTensorType>();
}
int64_t ShapedType::getDimSize(unsigned idx) const {
// Tensors can have vectors and other tensors as elements, other shaped types
// cannot.
assert(isa<TensorType>() && "unsupported element type");
- assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
+ assert((elementType.isa<VectorType, TensorType>()) &&
"unsupported tensor element type");
return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}
auto *context = elementType.getContext();
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
- !elementType.isa<ComplexType>())
+ if (!elementType.isIntOrFloat() &&
+ !elementType.isa<VectorType, ComplexType>())
return emitOptionalError(location, "invalid memref element type"),
MemRefType();
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
- !elementType.isa<ComplexType>())
+ if (!elementType.isIntOrFloat() &&
+ !elementType.isa<VectorType, ComplexType>())
return emitError(loc, "invalid memref element type");
return success();
}
for (Attribute attr : llvm::drop_begin(attrRange, index)) {
/// Check for a nested container attribute, these will also need to be
/// walked.
- if (attr.isa<ArrayAttr>() || attr.isa<DictionaryAttr>()) {
+ if (attr.isa<ArrayAttr, DictionaryAttr>()) {
attrWorklist.push_back(attr);
curAccessChain.push_back(-1);
return WalkResult::advance();
return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
}
- if (!type.isa<IntegerType>() && !type.isa<IndexType>())
+ if (!type.isa<IntegerType, IndexType>())
return emitError(loc, "integer literal not valid for specified type"),
nullptr;
return nullptr;
}
- if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
+ if (!type.isa<RankedTensorType, VectorType>()) {
emitError("elements literal must be a ranked tensor or vector type");
return nullptr;
}
return nullptr;
// Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
- !elementType.isa<ComplexType>())
+ if (!elementType.isIntOrFloat() &&
+ !elementType.isa<VectorType, ComplexType>())
return emitError(typeLoc, "invalid memref element type"), nullptr;
// Parse semi-affine-map-composition.
LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
for (Operation &o : getModuleBody(m).getOperations())
- if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) &&
- !o.isKnownTerminator())
+ if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp>(&o) && !o.isKnownTerminator())
return o.emitOpError("unsupported module-level operation");
return success();
}
unsigned count = 0;
stats->opCountMap[childForOp] = 0;
for (auto &op : *forOp.getBody()) {
- if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
+ if (!isa<AffineForOp, AffineIfOp>(op))
++count;
}
stats->opCountMap[childForOp] = count;
// Collect the loads and stores within the function.
loadsAndStores.clear();
getFunction().walk([&](Operation *op) {
- if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+ if (isa<AffineLoadOp, AffineStoreOp>(op))
loadsAndStores.push_back(op);
});