Added Huffman codec to utils
authorAndrey Tuganov <andreyt@google.com>
Thu, 25 May 2017 15:21:12 +0000 (11:21 -0400)
committerDavid Neto <dneto@google.com>
Thu, 29 Jun 2017 18:51:01 +0000 (14:51 -0400)
Attached ids to Huffman nodes for deterministic internal node
comparison.

source/util/bit_stream.h
source/util/huffman_codec.h [new file with mode: 0644]
test/CMakeLists.txt
test/huffman_codec.cpp [new file with mode: 0644]

index a139b63..14c3a66 100644 (file)
@@ -17,6 +17,7 @@
 #ifndef LIBSPIRV_UTIL_BIT_STREAM_H_
 #define LIBSPIRV_UTIL_BIT_STREAM_H_
 
+#include <algorithm>
 #include <bitset>
 #include <cstdint>
 #include <string>
diff --git a/source/util/huffman_codec.h b/source/util/huffman_codec.h
new file mode 100644 (file)
index 0000000..2e74d6b
--- /dev/null
@@ -0,0 +1,299 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Contains utils for reading, writing and debug printing bit streams.
+
+#ifndef LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
+#define LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
+
+#include <algorithm>
+#include <cassert>
+#include <functional>
+#include <queue>
+#include <iomanip>
+#include <map>
+#include <memory>
+#include <ostream>
+#include <sstream>
+#include <stack>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+namespace spvutils {
+
+// Used to generate and apply a Huffman coding scheme.
+// |Val| is the type of variable being encoded (for example a string or a
+// literal).
+template <class Val>
+class HuffmanCodec {
+  struct Node;
+
+ public:
+  // Creates Huffman codec from a histogramm.
+  // Histogramm counts must not be zero.
+  explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
+    if (hist.empty()) return;
+
+    // Heuristic estimate.
+    all_nodes_.reserve(3 * hist.size());
+
+    // The queue is sorted in ascending order by weight (or by node id if
+    // weights are equal).
+    std::vector<Node*> queue_vector;
+    queue_vector.reserve(hist.size());
+    std::priority_queue<Node*, std::vector<Node*>,
+        std::function<bool(const Node*, const Node*)>>
+            queue(LeftIsBigger, std::move(queue_vector));
+
+    // Put all leaves in the queue.
+    for (const auto& pair : hist) {
+      Node* node = CreateNode();
+      node->val = pair.first;
+      node->weight = pair.second;
+      assert(node->weight);
+      queue.push(node);
+    }
+
+    // Form the tree by combining two subtrees with the least weight,
+    // and pushing the root of the new tree in the queue.
+    while (true) {
+      // We push a node at the end of each iteration, so the queue is never
+      // supposed to be empty at this point, unless there are no leaves, but
+      // that case was already handled.
+      assert(!queue.empty());
+      Node* right = queue.top();
+      queue.pop();
+
+      // If the queue is empty at this point, then the last node is
+      // the root of the complete Huffman tree.
+      if (queue.empty()) {
+        root_ = right;
+        break;
+      }
+
+      Node* left = queue.top();
+      queue.pop();
+
+      // Combine left and right into a new tree and push it into the queue.
+      Node* parent = CreateNode();
+      parent->weight = right->weight + left->weight;
+      parent->left = left;
+      parent->right = right;
+      queue.push(parent);
+    }
+
+    // Traverse the tree and form encoding table.
+    CreateEncodingTable();
+  }
+
+  // Prints the Huffman tree in the following format:
+  // w------w------'x'
+  //        w------'y'
+  // Where w stands for the weight of the node.
+  // Right tree branches appear above left branches. Taking the right path
+  // adds 1 to the code, taking the left adds 0.
+  void PrintTree(std::ostream& out) {
+    PrintTreeInternal(out, root_, 0);
+  }
+
+  // Traverses the tree and prints the Huffman table: value, code
+  // and optionally node weight for every leaf.
+  void PrintTable(std::ostream& out, bool print_weights = true) {
+    std::queue<std::pair<Node*, std::string>> queue;
+    queue.emplace(root_, "");
+
+    while (!queue.empty()) {
+      const Node* node = queue.front().first;
+      const std::string code = queue.front().second;
+      queue.pop();
+      if (!node->right && !node->left) {
+        out << node->val;
+        if (print_weights)
+            out << " " << node->weight;
+        out << " " << code << std::endl;
+      } else {
+        if (node->left)
+          queue.emplace(node->left, code + "0");
+
+        if (node->right)
+          queue.emplace(node->right, code + "1");
+      }
+    }
+  }
+
+  // Returns the Huffman table. The table was built at at construction time,
+  // this function just returns a const reference.
+  const std::unordered_map<Val, std::pair<uint64_t, size_t>>&
+      GetEncodingTable() const {
+    return encoding_table_;
+  }
+
+  // Encodes |val| and stores its Huffman code in the lower |num_bits| of
+  // |bits|. Returns false of |val| is not in the Huffman table.
+  bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) {
+    auto it = encoding_table_.find(val);
+    if (it == encoding_table_.end())
+      return false;
+    *bits = it->second.first;
+    *num_bits = it->second.second;
+    return true;
+  }
+
+  // Reads bits one-by-one using callback |read_bit| until a match is found.
+  // Matching value is stored in |val|. Returns false if |read_bit| terminates
+  // before a code was mathced.
+  // |read_bit| has type bool func(bool* bit). When called, the next bit is
+  // stored in |bit|. |read_bit| returns false if the stream terminates
+  // prematurely.
+  bool DecodeFromStream(const std::function<bool(bool*)>& read_bit, Val* val) {
+    Node* node = root_;
+    while (true) {
+      assert(node);
+
+      if (node->left == nullptr && node->right == nullptr) {
+        *val = node->val;
+        return true;
+      }
+
+      bool go_right;
+      if (!read_bit(&go_right))
+        return false;
+
+      if (go_right)
+        node = node->right;
+      else
+        node = node->left;
+    }
+
+    assert (0);
+    return false;
+  }
+
+ private:
+  // Huffman tree node.
+  struct Node {
+    Val val = Val();
+    uint32_t weight = 0;
+    // Ids are issued sequentially starting from 1. Ids are used as an ordering
+    // tie-breaker, to make sure that the ordering (and resulting coding scheme)
+    // are consistent accross multiple platforms.
+    uint32_t id = 0;
+    Node* left = nullptr;
+    Node* right = nullptr;
+  };
+
+  // Returns true if |left| has bigger weight than |right|. Node ids are
+  // used as tie-breaker.
+  static bool LeftIsBigger(const Node* left, const Node* right) {
+    if (left->weight == right->weight) {
+      assert (left->id != right->id);
+      return left->id > right->id;
+    }
+    return left->weight > right->weight;
+  }
+
+  // Prints subtree (helper function used by PrintTree).
+  static void PrintTreeInternal(std::ostream& out, Node* node, size_t depth) {
+    if (!node)
+      return;
+
+    const size_t kTextFieldWidth = 7;
+
+    if (!node->right && !node->left) {
+      out << node->val << std::endl;
+    } else {
+      if (node->right) {
+        std::stringstream label;
+        label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
+              << node->right->weight;
+        out << label.str();
+        PrintTreeInternal(out, node->right, depth + 1);
+      }
+
+      if (node->left) {
+        out << std::string(depth * kTextFieldWidth, ' ');
+        std::stringstream label;
+        label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
+              << node->left->weight;
+        out << label.str();
+        PrintTreeInternal(out, node->left, depth + 1);
+      }
+    }
+  }
+
+  // Traverses the Huffman tree and saves paths to the leaves as bit
+  // sequences to encoding_table_.
+  void CreateEncodingTable() {
+    struct Context {
+      Context(Node* in_node, uint64_t in_bits, size_t in_depth)
+          :  node(in_node), bits(in_bits), depth(in_depth) {}
+      Node* node;
+      // Huffman tree depth cannot exceed 64 as histogramm counts are expected
+      // to be positive and limited by numeric_limits<uint32_t>::max().
+      // For practical applications tree depth would be much smaller than 64.
+      uint64_t bits;
+      size_t depth;
+    };
+
+    std::queue<Context> queue;
+    queue.emplace(root_, 0, 0);
+
+    while (!queue.empty()) {
+      const Context& context = queue.front();
+      const Node* node = context.node;
+      const uint64_t bits = context.bits;
+      const size_t depth = context.depth;
+      queue.pop();
+
+      if (!node->right && !node->left) {
+        auto insertion_result = encoding_table_.emplace(
+            node->val, std::pair<uint64_t, size_t>(bits, depth));
+        assert(insertion_result.second);
+        (void)insertion_result;
+      } else {
+        if (node->left)
+          queue.emplace(node->left, bits, depth + 1);
+
+        if (node->right)
+          queue.emplace(node->right, bits | (1ULL << depth), depth + 1);
+      }
+    }
+  }
+
+  // Creates new Huffman tree node and stores it in the deleter array.
+  Node* CreateNode() {
+    all_nodes_.emplace_back(new Node());
+    all_nodes_.back()->id = next_node_id_++;
+    return all_nodes_.back().get();
+  }
+
+  // Huffman tree root.
+  Node* root_ = nullptr;
+
+  // Huffman tree deleter.
+  std::vector<std::unique_ptr<Node>> all_nodes_;
+
+  // Encoding table value -> {bits, num_bits}.
+  // Huffman codes are expected to never exceed 64 bit length (this is in fact
+  // impossible if frequencies are stored as uint32_t).
+  std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;
+
+  // Next node id issued by CreateNode();
+  uint32_t next_node_id_ = 1;
+};
+
+}  // namespace spvutils
+
+#endif  // LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
index 926dade..c26de52 100644 (file)
@@ -169,6 +169,11 @@ add_spvtools_unittest(
   SRCS bit_stream.cpp
   LIBS ${SPIRV_TOOLS})
 
+add_spvtools_unittest(
+  TARGET huffman_codec
+  SRCS huffman_codec.cpp
+  LIBS ${SPIRV_TOOLS})
+
 add_subdirectory(opt)
 add_subdirectory(val)
 add_subdirectory(stats)
diff --git a/test/huffman_codec.cpp b/test/huffman_codec.cpp
new file mode 100644 (file)
index 0000000..80f7d8f
--- /dev/null
@@ -0,0 +1,220 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Contains utils for reading, writing and debug printing bit streams.
+
+#include <map>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+
+#include "util/huffman_codec.h"
+#include "util/bit_stream.h"
+#include "gmock/gmock.h"
+
+namespace {
+
+using spvutils::HuffmanCodec;
+using spvutils::BitsToStream;
+
+const std::map<std::string, uint32_t>& GetTestSet() {
+  static const std::map<std::string, uint32_t> hist = {
+    {"a", 4},
+    {"e", 7},
+    {"f", 3},
+    {"h", 2},
+    {"i", 3},
+    {"m", 2},
+    {"n", 2},
+    {"s", 2},
+    {"t", 2},
+    {"l", 1},
+    {"o", 2},
+    {"p", 1},
+    {"r", 1},
+    {"u", 1},
+    {"x", 1},
+  };
+
+  return hist;
+}
+
+class TestBitReader {
+ public:
+  TestBitReader(const std::string& bits) : bits_(bits) {}
+
+  bool ReadBit(bool* bit) {
+    if (pos_ < bits_.length()) {
+      *bit = bits_[pos_++] == '0' ? false : true;
+      return true;
+    }
+    return false;
+  }
+
+ private:
+  std::string bits_;
+  size_t pos_ = 0;
+};
+
+TEST(Huffman, PrintTree) {
+  HuffmanCodec<std::string> huffman(GetTestSet());
+  std::stringstream ss;
+  huffman.PrintTree(ss);
+
+  const std::string expected = std::string(R"(
+15-----7------e
+       8------4------a
+              4------2------m
+                     2------n
+19-----8------4------2------o
+                     2------s
+              4------2------t
+                     2------1------l
+                            1------p
+       11-----5------2------1------r
+                            1------u
+                     3------f
+              6------3------i
+                     3------1------x
+                            2------h
+)").substr(1);
+
+  EXPECT_EQ(expected, ss.str());
+}
+
+TEST(Huffman, PrintTable) {
+  HuffmanCodec<std::string> huffman(GetTestSet());
+  std::stringstream ss;
+  huffman.PrintTable(ss);
+
+  const std::string expected = std::string(R"(
+e 7 11
+a 4 101
+i 3 0001
+f 3 0010
+t 2 0101
+s 2 0110
+o 2 0111
+n 2 1000
+m 2 1001
+h 2 00000
+x 1 00001
+u 1 00110
+r 1 00111
+p 1 01000
+l 1 01001
+)").substr(1);
+
+  EXPECT_EQ(expected, ss.str());
+}
+
+TEST(Huffman, TestValidity) {
+  HuffmanCodec<std::string> huffman(GetTestSet());
+  const auto& encoding_table = huffman.GetEncodingTable();
+  std::vector<std::string> codes;
+  for (const auto& entry : encoding_table) {
+    codes.push_back(BitsToStream(entry.second.first, entry.second.second));
+  }
+
+  std::sort(codes.begin(), codes.end());
+
+  ASSERT_LT(codes.size(), 20u) << "Inefficient test ahead";
+
+  for (size_t i = 0; i < codes.size(); ++i) {
+    for (size_t j = i + 1; j < codes.size(); ++j) {
+      ASSERT_FALSE(codes[i] == codes[j].substr(0, codes[i].length()))
+          << codes[i] << " is prefix of " << codes[j];
+    }
+  }
+}
+
+TEST(Huffman, TestEncode) {
+  HuffmanCodec<std::string> huffman(GetTestSet());
+
+  uint64_t bits = 0;
+  size_t num_bits = 0;
+
+  EXPECT_TRUE(huffman.Encode("e", &bits, &num_bits));
+  EXPECT_EQ(2u, num_bits);
+  EXPECT_EQ("11", BitsToStream(bits, num_bits));
+
+  EXPECT_TRUE(huffman.Encode("a", &bits, &num_bits));
+  EXPECT_EQ(3u, num_bits);
+  EXPECT_EQ("101", BitsToStream(bits, num_bits));
+
+  EXPECT_TRUE(huffman.Encode("x", &bits, &num_bits));
+  EXPECT_EQ(5u, num_bits);
+  EXPECT_EQ("00001", BitsToStream(bits, num_bits));
+
+  EXPECT_FALSE(huffman.Encode("y", &bits, &num_bits));
+}
+
+TEST(Huffman, TestDecode) {
+  HuffmanCodec<std::string> huffman(GetTestSet());
+  TestBitReader bit_reader("01001""0001""1000""00110""00001""00");
+  auto read_bit = [&bit_reader](bool* bit) {
+    return bit_reader.ReadBit(bit);
+  };
+
+  std::string decoded;
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ("l", decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ("i", decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ("n", decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ("u", decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ("x", decoded);
+
+  ASSERT_FALSE(huffman.DecodeFromStream(read_bit, &decoded));
+}
+
+TEST(Huffman, TestDecodeNumbers) {
+  const std::map<uint32_t, uint32_t> hist = { {1, 10}, {2, 5}, {3, 15} };
+  HuffmanCodec<uint32_t> huffman(hist);
+
+  TestBitReader bit_reader("1""1""01""00""01""1");
+  auto read_bit = [&bit_reader](bool* bit) {
+    return bit_reader.ReadBit(bit);
+  };
+
+  uint32_t decoded;
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ(3u, decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ(3u, decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ(2u, decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ(1u, decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ(2u, decoded);
+
+  ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+  EXPECT_EQ(3u, decoded);
+}
+
+}  // anonymous namespace