Introduce runtime shape class.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 31 May 2018 17:15:59 +0000 (10:15 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 17:18:52 +0000 (10:18 -0700)
PiperOrigin-RevId: 198739017

tensorflow/contrib/lite/kernels/internal/types.h

index d5293ed..98ca21d 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,6 +15,9 @@ limitations under the License.
 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
 #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
 
+#include <cstring>
+#include <iterator>
+
 #include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
 
 namespace tflite {
@@ -44,6 +47,101 @@ struct Dims {
   int strides[N];
 };
 
+class RuntimeShape {
+ public:
+  // Shapes with dimensions up to 4 are stored directly in the structure, while
+  // larger shapes are separately allocated.
+  static constexpr int kMaxSmallSize = 4;
+
+  RuntimeShape() : size_(0) {}
+
+  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
+    if (dimensions_count > kMaxSmallSize) {
+      dims_pointer_ = new int32[dimensions_count];
+    }
+  }
+
+  RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
+    ReplaceWith(dimensions_count, dims_data);
+  }
+
+  ~RuntimeShape() {
+    if (size_ > kMaxSmallSize) {
+      delete[] dims_pointer_;
+    }
+  }
+
+  inline const int32 DimensionsCount() const { return size_; }
+  inline const int32 Dims(int i) const {
+    TFLITE_DCHECK_GE(i, 0);
+    TFLITE_DCHECK_LT(i, size_);
+    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
+  }
+  inline void SetDim(int i, int32 val) {
+    TFLITE_DCHECK_GE(i, 0);
+    TFLITE_DCHECK_LT(i, size_);
+    if (size_ > kMaxSmallSize) {
+      dims_pointer_[i] = val;
+    } else {
+      dims_[i] = val;
+    }
+  }
+  inline int32* DimsData() {
+    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+  }
+  inline const int32* DimsData() const {
+    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+  }
+
+  inline void Resize(int dimensions_count) {
+    if (size_ > kMaxSmallSize) {
+      delete[] dims_pointer_;
+    }
+    size_ = dimensions_count;
+    if (dimensions_count > kMaxSmallSize) {
+      dims_pointer_ = new int32[dimensions_count];
+    }
+  }
+
+  inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
+    Resize(dimensions_count);
+    int32* dst_dims = DimsData();
+    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
+  }
+
+  template <typename T>
+  inline void BuildFrom(const T& src_iterable) {
+    const int dimensions_count =
+        std::distance(src_iterable.begin(), src_iterable.end());
+    Resize(dimensions_count);
+    int32* data = DimsData();
+    for (auto it : src_iterable) {
+      *data = it;
+      ++data;
+    }
+  }
+
+  // Returns the total count of elements, that is the size when flattened into a
+  // vector.
+  inline const int FlatSize() const {
+    int buffer_size = 1;
+    const int* dims_data = DimsData();
+    for (int i = 0; i < size_; i++) {
+      const int dim = dims_data[i];
+      TFLITE_DCHECK_GE(dim, 1);
+      buffer_size *= dim;
+    }
+    return buffer_size;
+  }
+
+ private:
+  int32 size_;
+  union {
+    int32 dims_[kMaxSmallSize];
+    int32* dims_pointer_;
+  };
+};
+
 // Gets next index to iterate through a multidimensional array.
 inline bool NextIndex(const int num_dims, const int* dims, int* current) {
   TFLITE_DCHECK_GT(num_dims, 0);