Refactored enum_set
authorAndrey Tuganov <andreyt@google.com>
Thu, 9 Mar 2017 23:24:35 +0000 (18:24 -0500)
committerDavid Neto <dneto@google.com>
Fri, 10 Mar 2017 18:38:32 +0000 (13:38 -0500)
- removed forgotten file enum_set.cpp
- added IsEmpty and HasAnyOf
- hidden unsafe functions Add(uint32_t), Contains(uint32_t)
- added new tests

source/enum_set.cpp [deleted file]
source/enum_set.h
test/CMakeLists.txt
test/enum_set_test.cpp [moved from test/capability_set_test.cpp with 51% similarity]

diff --git a/source/enum_set.cpp b/source/enum_set.cpp
deleted file mode 100644 (file)
index d3e046c..0000000
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright (c) 2016 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "enum_set.h"
-
-#include "spirv/1.1/spirv.hpp"
-
-namespace {
-
-// Determines whether the given enum value can be represented
-// as a bit in a uint64_t mask. If so, then returns that mask bit.
-// Otherwise, returns 0.
-uint64_t AsMask(uint32_t word) {
-  if (word > 63) return 0;
-  return uint64_t(1) << word;
-}
-}
-
-namespace libspirv {
-
-template<typename EnumType>
-void EnumSet<EnumType>::Add(uint32_t word) {
-  if (auto new_bits = AsMask(word)) {
-    mask_ |= new_bits;
-  } else {
-    Overflow().insert(word);
-  }
-}
-
-template<typename EnumType>
-bool EnumSet<EnumType>::Contains(uint32_t word) const {
-  // We shouldn't call Overflow() since this is a const method.
-  if (auto bits = AsMask(word)) {
-    return mask_ & bits;
-  } else if (auto overflow = overflow_.get()) {
-    return overflow->find(word) != overflow->end();
-  }
-  // The word is large, but the set doesn't have large members, so
-  // it doesn't have an overflow set.
-  return false;
-}
-
-// Applies f to each capability in the set, in order from smallest enum
-// value to largest.
-void CapabilitySet::ForEach(std::function<void(SpvCapability)> f) const {
-  for (uint32_t i = 0; i < 64; ++i) {
-    if (mask_ & AsMask(i)) f(static_cast<SpvCapability>(i));
-  }
-  if (overflow_) {
-    for (uint32_t c : *overflow_) f(static_cast<SpvCapability>(c));
-  }
-}
-
-}  // namespace libspirv
index 6f9be32..0abc594 100644 (file)
@@ -39,7 +39,7 @@ class EnumSet {
 
  public:
   // Construct an empty set.
-  EnumSet() = default;
+  EnumSet() {}
   // Construct an set with just the given enum value.
   explicit EnumSet(EnumType c) { Add(c); }
   // Construct an set from an initializer list of enum values.
@@ -67,10 +67,52 @@ class EnumSet {
 
   // Adds the given enum value to the set.  This has no effect if the
   // enum value is already in the set.
-  void Add(EnumType c) { Add(ToWord(c)); }
+  void Add(EnumType c) { AddWord(ToWord(c)); }
+
+  // Returns true if this enum value is in the set.
+  bool Contains(EnumType c) const { return ContainsWord(ToWord(c)); }
+
+  // Applies f to each enum in the set, in order from smallest enum
+  // value to largest.
+  void ForEach(std::function<void(EnumType)> f) const {
+    for (uint32_t i = 0; i < 64; ++i) {
+      if (mask_ & AsMask(i)) f(static_cast<EnumType>(i));
+    }
+    if (overflow_) {
+      for (uint32_t c : *overflow_) f(static_cast<EnumType>(c));
+    }
+  }
+
+  // Returns true if the set is empty.
+  bool IsEmpty() const {
+    if (mask_) return false;
+    if (overflow_ && !overflow_->empty()) return false;
+    return true;
+  }
+
+  // Returns true if the set contains ANY of the elements of |in_set|,
+  // or if |in_set| is empty.
+  bool HasAnyOf(const EnumSet<EnumType>& in_set) const {
+    if (in_set.IsEmpty()) return true;
+
+    if (mask_ & in_set.mask_)
+      return true;
+
+    if (!overflow_ || !in_set.overflow_)
+      return false;
+
+    for (uint32_t item : *in_set.overflow_) {
+      if (overflow_->find(item) != overflow_->end())
+        return true;
+    }
+
+    return false;
+  }
+
+ private:
   // Adds the given enum value (as a 32-bit word) to the set.  This has no
   // effect if the enum value is already in the set.
-  void Add(uint32_t word) {
+  void AddWord(uint32_t word) {
     if (auto new_bits = AsMask(word)) {
       mask_ |= new_bits;
     } else {
@@ -78,10 +120,8 @@ class EnumSet {
     }
   }
 
-  // Returns true if this enum value is in the set.
-  bool Contains(EnumType c) const { return Contains(ToWord(c)); }
   // Returns true if the enum represented as a 32-bit word is in the set.
-  bool Contains(uint32_t word) const {
+  bool ContainsWord(uint32_t word) const {
     // We shouldn't call Overflow() since this is a const method.
     if (auto bits = AsMask(word)) {
       return (mask_ & bits) != 0;
@@ -93,18 +133,6 @@ class EnumSet {
     return false;
   }
 
-  // Applies f to each enum in the set, in order from smallest enum
-  // value to largest.
-  void ForEach(std::function<void(EnumType)> f) const {
-    for (uint32_t i = 0; i < 64; ++i) {
-      if (mask_ & AsMask(i)) f(static_cast<EnumType>(i));
-    }
-    if (overflow_) {
-      for (uint32_t c : *overflow_) f(static_cast<EnumType>(c));
-    }
-  }
-
- private:
   // Returns the enum value as a uint32_t.
   uint32_t ToWord(EnumType value) const {
     static_assert(sizeof(EnumType) <= sizeof(uint32_t),
index 6017793..3fc6719 100644 (file)
@@ -74,8 +74,8 @@ set(TEST_SOURCES
   binary_strnlen_s_test.cpp
   binary_to_text_test.cpp
   binary_to_text.literal_test.cpp
-  capability_set_test.cpp
   comment_test.cpp
+  enum_set_test.cpp
   ext_inst.glsl_test.cpp
   ext_inst.opencl_test.cpp
   fix_word_test.cpp
similarity index 51%
rename from test/capability_set_test.cpp
rename to test/enum_set_test.cpp
index 4a0111d..86207fe 100644 (file)
 
 namespace {
 
+using libspirv::EnumSet;
 using libspirv::CapabilitySet;
 using spvtest::ElementsIn;
 using ::testing::Eq;
 using ::testing::ValuesIn;
 
-TEST(CapabilitySet, DefaultIsEmpty) {
-  CapabilitySet c;
+TEST(EnumSet, IsEmpty1) {
+  EnumSet<uint32_t> set;
+  EXPECT_TRUE(set.IsEmpty());
+  set.Add(0);
+  EXPECT_FALSE(set.IsEmpty());
+}
+
+TEST(EnumSet, IsEmpty2) {
+  EnumSet<uint32_t> set;
+  EXPECT_TRUE(set.IsEmpty());
+  set.Add(150);
+  EXPECT_FALSE(set.IsEmpty());
+}
+
+TEST(EnumSet, IsEmpty3) {
+  EnumSet<uint32_t> set(4);
+  EXPECT_FALSE(set.IsEmpty());
+}
+
+TEST(EnumSet, IsEmpty4) {
+  EnumSet<uint32_t> set(300);
+  EXPECT_FALSE(set.IsEmpty());
+}
+
+TEST(EnumSetHasAnyOf, EmptySetEmptyQuery) {
+  const EnumSet<uint32_t> set;
+  const EnumSet<uint32_t> empty;
+  EXPECT_TRUE(set.HasAnyOf(empty));
+  EXPECT_TRUE(EnumSet<uint32_t>().HasAnyOf(EnumSet<uint32_t>()));
+}
+
+TEST(EnumSetHasAnyOf, MaskSetEmptyQuery) {
+  EnumSet<uint32_t> set;
+  const EnumSet<uint32_t> empty;
+  set.Add(5);
+  set.Add(8);
+  EXPECT_TRUE(set.HasAnyOf(empty));
+}
+
+TEST(EnumSetHasAnyOf, OverflowSetEmptyQuery) {
+  EnumSet<uint32_t> set;
+  const EnumSet<uint32_t> empty;
+  set.Add(200);
+  set.Add(300);
+  EXPECT_TRUE(set.HasAnyOf(empty));
+}
+
+TEST(EnumSetHasAnyOf, EmptyQuery) {
+  EnumSet<uint32_t> set;
+  const EnumSet<uint32_t> empty;
+  set.Add(5);
+  set.Add(8);
+  set.Add(200);
+  set.Add(300);
+  EXPECT_TRUE(set.HasAnyOf(empty));
+}
+
+TEST(EnumSetHasAnyOf, EmptyQueryAlwaysTrue) {
+  EnumSet<uint32_t> set;
+  const EnumSet<uint32_t> empty;
+  EXPECT_TRUE(set.HasAnyOf(empty));
+  set.Add(5);
+  EXPECT_TRUE(set.HasAnyOf(empty));
+
+  EXPECT_TRUE(EnumSet<uint32_t>(100).HasAnyOf(EnumSet<uint32_t>()));
+}
+
+TEST(EnumSetHasAnyOf, ReflexiveMask) {
+  EnumSet<uint32_t> set(3);
+  set.Add(24);
+  set.Add(30);
+  EXPECT_TRUE(set.HasAnyOf(set));
+}
+
+TEST(EnumSetHasAnyOf, ReflexiveOverflow) {
+  EnumSet<uint32_t> set(200);
+  set.Add(300);
+  set.Add(400);
+  EXPECT_TRUE(set.HasAnyOf(set));
+}
+
+TEST(EnumSetHasAnyOf, Reflexive) {
+  EnumSet<uint32_t> set(3);
+  set.Add(24);
+  set.Add(300);
+  set.Add(400);
+  EXPECT_TRUE(set.HasAnyOf(set));
+}
+
+TEST(EnumSetHasAnyOf, EmptySetHasNone) {
+  EnumSet<uint32_t> set;
+  EnumSet<uint32_t> items;
+  for (uint32_t i = 0; i < 200; ++i) {
+    items.Add(i);
+    EXPECT_FALSE(set.HasAnyOf(items));
+    EXPECT_FALSE(set.HasAnyOf(EnumSet<uint32_t>(i)));
+  }
+}
+
+TEST(EnumSetHasAnyOf, MaskSetMaskQuery) {
+  EnumSet<uint32_t> set(0);
+  EnumSet<uint32_t> items(1);
+  EXPECT_FALSE(set.HasAnyOf(items));
+  set.Add(2);
+  items.Add(3);
+  EXPECT_FALSE(set.HasAnyOf(items));
+  set.Add(3);
+  EXPECT_TRUE(set.HasAnyOf(items));
+  set.Add(4);
+  EXPECT_TRUE(set.HasAnyOf(items));
+}
+
+TEST(EnumSetHasAnyOf, OverflowSetOverflowQuery) {
+  EnumSet<uint32_t> set(100);
+  EnumSet<uint32_t> items(200);
+  EXPECT_FALSE(set.HasAnyOf(items));
+  set.Add(300);
+  items.Add(400);
+  EXPECT_FALSE(set.HasAnyOf(items));
+  set.Add(200);
+  EXPECT_TRUE(set.HasAnyOf(items));
+  set.Add(500);
+  EXPECT_TRUE(set.HasAnyOf(items));
+}
+
+TEST(EnumSetHasAnyOf, GeneralCase) {
+  EnumSet<uint32_t> set(0);
+  EnumSet<uint32_t> items(100);
+  EXPECT_FALSE(set.HasAnyOf(items));
+  set.Add(300);
+  items.Add(4);
+  EXPECT_FALSE(set.HasAnyOf(items));
+  set.Add(5);
+  items.Add(500);
+  EXPECT_FALSE(set.HasAnyOf(items));
+  set.Add(500);
+  EXPECT_TRUE(set.HasAnyOf(items));
+  EXPECT_FALSE(set.HasAnyOf(EnumSet<uint32_t>(20)));
+  EXPECT_FALSE(set.HasAnyOf(EnumSet<uint32_t>(600)));
+  EXPECT_TRUE(set.HasAnyOf(EnumSet<uint32_t>(5)));
+  EXPECT_TRUE(set.HasAnyOf(EnumSet<uint32_t>(300)));
+  EXPECT_TRUE(set.HasAnyOf(EnumSet<uint32_t>(0)));
+}
+
+TEST(EnumSet, DefaultIsEmpty) {
+  EnumSet<uint32_t> set;
   for (uint32_t i = 0; i < 1000; ++i) {
-    EXPECT_FALSE(c.Contains(i));
-    EXPECT_FALSE(c.Contains(static_cast<SpvCapability>(i)));
+    EXPECT_FALSE(set.Contains(i));
   }
 }
 
@@ -37,16 +181,16 @@ TEST(CapabilitySet, ConstructSingleMemberMatrix) {
   CapabilitySet s(SpvCapabilityMatrix);
   EXPECT_TRUE(s.Contains(SpvCapabilityMatrix));
   EXPECT_FALSE(s.Contains(SpvCapabilityShader));
-  EXPECT_FALSE(s.Contains(1000));
+  EXPECT_FALSE(s.Contains(static_cast<SpvCapability>(1000)));
 }
 
 TEST(CapabilitySet, ConstructSingleMemberMaxInMask) {
   CapabilitySet s(static_cast<SpvCapability>(63));
   EXPECT_FALSE(s.Contains(SpvCapabilityMatrix));
   EXPECT_FALSE(s.Contains(SpvCapabilityShader));
-  EXPECT_TRUE(s.Contains(63));
-  EXPECT_FALSE(s.Contains(64));
-  EXPECT_FALSE(s.Contains(1000));
+  EXPECT_TRUE(s.Contains(static_cast<SpvCapability>(63)));
+  EXPECT_FALSE(s.Contains(static_cast<SpvCapability>(64)));
+  EXPECT_FALSE(s.Contains(static_cast<SpvCapability>(1000)));
 }
 
 TEST(CapabilitySet, ConstructSingleMemberMinOverflow) {
@@ -54,41 +198,34 @@ TEST(CapabilitySet, ConstructSingleMemberMinOverflow) {
   CapabilitySet s(static_cast<SpvCapability>(64));
   EXPECT_FALSE(s.Contains(SpvCapabilityMatrix));
   EXPECT_FALSE(s.Contains(SpvCapabilityShader));
-  EXPECT_FALSE(s.Contains(63));
-  EXPECT_TRUE(s.Contains(64));
-  EXPECT_FALSE(s.Contains(1000));
+  EXPECT_FALSE(s.Contains(static_cast<SpvCapability>(63)));
+  EXPECT_TRUE(s.Contains(static_cast<SpvCapability>(64)));
+  EXPECT_FALSE(s.Contains(static_cast<SpvCapability>(1000)));
 }
 
 TEST(CapabilitySet, ConstructSingleMemberMaxOverflow) {
   // Check the max 32-bit signed int.
-  CapabilitySet s(SpvCapability(0x7fffffffu));
+  CapabilitySet s(static_cast<SpvCapability>(0x7fffffffu));
   EXPECT_FALSE(s.Contains(SpvCapabilityMatrix));
   EXPECT_FALSE(s.Contains(SpvCapabilityShader));
-  EXPECT_FALSE(s.Contains(1000));
-  EXPECT_TRUE(s.Contains(0x7fffffffu));
+  EXPECT_FALSE(s.Contains(static_cast<SpvCapability>(1000)));
+  EXPECT_TRUE(s.Contains(static_cast<SpvCapability>(0x7fffffffu)));
 }
 
 TEST(CapabilitySet, AddEnum) {
   CapabilitySet s(SpvCapabilityShader);
   s.Add(SpvCapabilityKernel);
+  s.Add(static_cast<SpvCapability>(42));
   EXPECT_FALSE(s.Contains(SpvCapabilityMatrix));
   EXPECT_TRUE(s.Contains(SpvCapabilityShader));
   EXPECT_TRUE(s.Contains(SpvCapabilityKernel));
-}
-
-TEST(CapabilitySet, AddInt) {
-  CapabilitySet s(SpvCapabilityShader);
-  s.Add(42);
-  EXPECT_FALSE(s.Contains(SpvCapabilityMatrix));
-  EXPECT_TRUE(s.Contains(SpvCapabilityShader));
-  EXPECT_TRUE(s.Contains(42));
   EXPECT_TRUE(s.Contains(static_cast<SpvCapability>(42)));
 }
 
 TEST(CapabilitySet, InitializerListEmpty) {
   CapabilitySet s{};
   for (uint32_t i = 0; i < 1000; i++) {
-    EXPECT_FALSE(s.Contains(i));
+    EXPECT_FALSE(s.Contains(static_cast<SpvCapability>(i)));
   }
 }