-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#include <cstring>
+#include <iterator>
+
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
int strides[N];
};
+class RuntimeShape {
+ public:
+ // Shapes with dimensions up to 4 are stored directly in the structure, while
+ // larger shapes are separately allocated.
+ static constexpr int kMaxSmallSize = 4;
+
+ RuntimeShape() : size_(0) {}
+
+ explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
+ if (dimensions_count > kMaxSmallSize) {
+ dims_pointer_ = new int32[dimensions_count];
+ }
+ }
+
+ RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
+ ReplaceWith(dimensions_count, dims_data);
+ }
+
+ ~RuntimeShape() {
+ if (size_ > kMaxSmallSize) {
+ delete[] dims_pointer_;
+ }
+ }
+
+ inline const int32 DimensionsCount() const { return size_; }
+ inline const int32 Dims(int i) const {
+ TFLITE_DCHECK_GE(i, 0);
+ TFLITE_DCHECK_LT(i, size_);
+ return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
+ }
+ inline void SetDim(int i, int32 val) {
+ TFLITE_DCHECK_GE(i, 0);
+ TFLITE_DCHECK_LT(i, size_);
+ if (size_ > kMaxSmallSize) {
+ dims_pointer_[i] = val;
+ } else {
+ dims_[i] = val;
+ }
+ }
+ inline int32* DimsData() {
+ return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+ }
+ inline const int32* DimsData() const {
+ return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+ }
+
+ inline void Resize(int dimensions_count) {
+ if (size_ > kMaxSmallSize) {
+ delete[] dims_pointer_;
+ }
+ size_ = dimensions_count;
+ if (dimensions_count > kMaxSmallSize) {
+ dims_pointer_ = new int32[dimensions_count];
+ }
+ }
+
+ inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
+ Resize(dimensions_count);
+ int32* dst_dims = DimsData();
+ std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
+ }
+
+ template <typename T>
+ inline void BuildFrom(const T& src_iterable) {
+ const int dimensions_count =
+ std::distance(src_iterable.begin(), src_iterable.end());
+ Resize(dimensions_count);
+ int32* data = DimsData();
+ for (auto it : src_iterable) {
+ *data = it;
+ ++data;
+ }
+ }
+
+ // Returns the total count of elements, that is the size when flattened into a
+ // vector.
+ inline const int FlatSize() const {
+ int buffer_size = 1;
+ const int* dims_data = DimsData();
+ for (int i = 0; i < size_; i++) {
+ const int dim = dims_data[i];
+ TFLITE_DCHECK_GE(dim, 1);
+ buffer_size *= dim;
+ }
+ return buffer_size;
+ }
+
+ private:
+ int32 size_;
+ union {
+ int32 dims_[kMaxSmallSize];
+ int32* dims_pointer_;
+ };
+};
+
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
TFLITE_DCHECK_GT(num_dims, 0);