class LoadOp;
class ReturnOp;
class StoreOp;
-namespace spirv {
-class SPIRVDialect;
-}
/// Type conversion from Standard Types to SPIR-V Types.
class SPIRVBasicTypeConverter : public TypeConverter {
public:
- explicit SPIRVBasicTypeConverter(MLIRContext *context);
-
/// Converts types to SPIR-V supported types.
virtual Type convertType(Type t);
-
-protected:
- spirv::SPIRVDialect *spirvDialect;
};
/// Converts a function type according to the requirements of a SPIR-V entry
static StringRef getDialectNamespace() { return "spv"; }
+ /// Checks if the given `type` is valid in SPIR-V dialect.
+ static bool isValidType(Type type);
+
/// Parses a type registered to this dialect.
Type parseType(llvm::StringRef spec, Location loc) const override;
/// Prints a type registered to this dialect.
void printType(Type type, llvm::raw_ostream &os) const override;
- /// Checks if a type is valid in SPIR-V dialect.
- bool isValidSPIRVType(Type t) const;
+ /// Provides a hook for materializing a constant to this dialect.
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+ Location loc) override;
};
} // end namespace spirv
SPV_Type:$constant
);
+ let hasFolder = 1;
+
+ let extraClassDeclaration = [{
+ // Returns true if a constant can be built for the given `type`.
+ static bool isBuildableWith(Type type);
+ }];
+
let hasOpcode = 0;
}
Pointer,
RuntimeArray,
Struct,
+ LAST_SPIRV_TYPE = Struct,
};
}
}
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
- SPIRVBasicTypeConverter basicTypeConverter(context);
+ SPIRVBasicTypeConverter basicTypeConverter;
SPIRVTypeConverter typeConverter(&basicTypeConverter);
OwningRewritePatternList patterns;
patterns.insert<
// Type Conversion
//===----------------------------------------------------------------------===//
-SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context)
- : spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
-
Type SPIRVBasicTypeConverter::convertType(Type t) {
// Check if the type is SPIR-V supported. If so return the type.
- if (spirvDialect->isValidSPIRVType(t)) {
+ if (spirv::SPIRVDialect::isValidType(t)) {
return t;
}
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
type.getNumElements() >= 2 && type.getNumElements() <= 4;
}
-bool SPIRVDialect::isValidSPIRVType(Type type) const {
+bool SPIRVDialect::isValidType(Type type) {
// Allow SPIR-V dialect types
- if (&type.getDialect() == this) {
+ if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
+ type.getKind() <= TypeKind::LAST_SPIRV_TYPE) {
return true;
}
if (isValidSPIRVScalarType(type)) {
llvm_unreachable("unhandled SPIR-V type");
}
}
+
+//===----------------------------------------------------------------------===//
+// Constant
+//===----------------------------------------------------------------------===//
+
+Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ if (!ConstantOp::isBuildableWith(type))
+ return nullptr;
+
+ return builder.create<spirv::ConstantOp>(loc, type, value);
+}
auto elemType = arrayType.getElementType();
for (auto element : value.cast<ArrayAttr>().getValue()) {
if (element.getType() != elemType)
- return constOp.emitOpError(
- "has array element that are not of result array element type");
+ return constOp.emitOpError("has array element whose type (")
+ << element.getType()
+ << ") does not match the result element type (" << elemType
+ << ')';
}
} break;
default:
return success();
}
+OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.empty() && "constant has no operands");
+ return value();
+}
+
+bool spirv::ConstantOp::isBuildableWith(Type type) {
+ // Must be valid SPIR-V type first.
+ if (!SPIRVDialect::isValidType(type))
+ return false;
+
+ if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
+ type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE) {
+ // TODO(antiagainst): support contant struct
+ return type.isa<spirv::ArrayType>();
+ }
+
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// spv.EntryPoint
//===----------------------------------------------------------------------===//
case StandardAttributes::Integer:
case StandardAttributes::Float: {
// Make sure bitwidth is allowed.
- auto *dialect = static_cast<spirv::SPIRVDialect *>(constOp.getDialect());
- if (!dialect->isValidSPIRVType(value.getType()))
+ if (!spirv::SPIRVDialect::isValidType(value.getType()))
return constOp.emitOpError("default value bitwidth disallowed");
return success();
}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.constant
+//===----------------------------------------------------------------------===//
+
+// TODO(antiagainst): test constants in different blocks
+
+func @deduplicate_scalar_constant() -> (i32, i32) {
+ // CHECK: %[[CST:.*]] = spv.constant 42 : i32
+ %0 = spv.constant 42 : i32
+ %1 = spv.constant 42 : i32
+ // CHECK-NEXT: return %[[CST]], %[[CST]]
+ return %0, %1 : i32, i32
+}
+
+// -----
+
+func @deduplicate_vector_constant() -> (vector<3xi32>, vector<3xi32>) {
+ // CHECK: %[[CST:.*]] = spv.constant dense<[1, 2, 3]> : vector<3xi32>
+ %0 = spv.constant dense<[1, 2, 3]> : vector<3xi32>
+ %1 = spv.constant dense<[1, 2, 3]> : vector<3xi32>
+ // CHECK-NEXT: return %[[CST]], %[[CST]]
+ return %0, %1 : vector<3xi32>, vector<3xi32>
+}
+
+// -----
+
+func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.array<1 x vector<2xi32>>) {
+ // CHECK: %[[CST:.*]] = spv.constant [dense<5> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
+ %0 = spv.constant [dense<5> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
+ %1 = spv.constant [dense<5> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
+ // CHECK-NEXT: return %[[CST]], %[[CST]]
+ return %0, %1 : !spv.array<1 x vector<2xi32>>, !spv.array<1 x vector<2xi32>>
+}
// -----
func @array_constant() -> () {
- // expected-error @+1 {{has array element that are not of result array element type}}
+ // expected-error @+1 {{has array element whose type ('vector<2xi32>') does not match the result element type ('vector<2xf32>')}}
%0 = spv.constant [dense<3.0> : vector<2xf32>, dense<4> : vector<2xi32>] : !spv.array<2xvector<2xf32>>
return
}