[mlir][spirv] Add get() method to TargetEnvAttr taking raw values
authorLei Zhang <antiagainst@google.com>
Mon, 2 Mar 2020 22:27:05 +0000 (17:27 -0500)
committerLei Zhang <antiagainst@google.com>
Wed, 4 Mar 2020 19:01:26 +0000 (14:01 -0500)
Also make the getResourceLimits() method more explicit about its return type.

Differential Revision: https://reviews.llvm.org/D75484

mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
mlir/lib/Dialect/SPIRV/TargetAndABI.cpp

index a20fe6e..01b7758 100644 (file)
@@ -52,6 +52,9 @@ public:
   using Base::Base;
 
   /// Gets a TargetEnvAttr instance.
+  static TargetEnvAttr get(Version version, ArrayRef<Extension> extensions,
+                           ArrayRef<Capability> capabilities,
+                           DictionaryAttr limits);
   static TargetEnvAttr get(IntegerAttr version, ArrayAttr extensions,
                            ArrayAttr capabilities, DictionaryAttr limits);
 
@@ -86,7 +89,7 @@ public:
   ArrayAttr getCapabilitiesAttr();
 
   /// Returns the target resource limits.
-  DictionaryAttr getResourceLimits();
+  ResourceLimitsAttr getResourceLimits();
 
   static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; }
 
index ee5dcb1..f8c5900 100644 (file)
@@ -48,6 +48,27 @@ struct TargetEnvAttributeStorage : public AttributeStorage {
 } // namespace spirv
 } // namespace mlir
 
+spirv::TargetEnvAttr spirv::TargetEnvAttr::get(
+    spirv::Version version, ArrayRef<spirv::Extension> extensions,
+    ArrayRef<spirv::Capability> capabilities, DictionaryAttr limits) {
+  Builder b(limits.getContext());
+
+  auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version));
+
+  SmallVector<Attribute, 4> extAttrs;
+  extAttrs.reserve(extensions.size());
+  for (spirv::Extension ext : extensions)
+    extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));
+
+  SmallVector<Attribute, 4> capAttrs;
+  capAttrs.reserve(capabilities.size());
+  for (spirv::Capability cap : capabilities)
+    capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap)));
+
+  return get(versionAttr, b.getArrayAttr(extAttrs), b.getArrayAttr(capAttrs),
+             limits);
+}
+
 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(IntegerAttr version,
                                                ArrayAttr extensions,
                                                ArrayAttr capabilities,
@@ -98,8 +119,8 @@ ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
   return getImpl()->capabilities.cast<ArrayAttr>();
 }
 
-DictionaryAttr spirv::TargetEnvAttr::getResourceLimits() {
-  return getImpl()->limits.cast<DictionaryAttr>();
+spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() {
+  return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
 }
 
 LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(