Add meta-data to express dynamic shapes in ITensorInfo
authorGeorgios Pinitas <georgios.pinitas@arm.com>
Fri, 8 Jan 2021 03:14:31 +0000 (03:14 +0000)
committerGeorgios Pinitas <georgios.pinitas@arm.com>
Tue, 12 Jan 2021 03:50:44 +0000 (03:50 +0000)
Add `get_tensor_shape_state` and `set_tensor_shape_state` to inject
shape dynamism.
The state is represented by an array of integers which index maps to the
respective shape dimension index.
If -1 is passed as a dimension state then the corresponding dimension
is dynamic.

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I3a8a5ad109b90d4df8545b460a9f8dfcc13dfa0f
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4784
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>

arm_compute/core/ITensorInfo.h
arm_compute/core/SubTensorInfo.h
arm_compute/core/TensorInfo.h
src/core/SubTensorInfo.cpp
src/core/TensorInfo.cpp

index 3eb7239460f0350951f3fc34417007576fe977f0..9ddafce7c0b7309fa8133c0c780b095fd6190e59 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -39,6 +39,9 @@ namespace arm_compute
 /** Store the tensor's metadata */
 class ITensorInfo : public misc::ICloneable<ITensorInfo>
 {
+public:
+    using TensorDimsState = Coordinates;
+
 public:
     /** Default virtual destructor */
     virtual ~ITensorInfo() = default;
@@ -81,6 +84,17 @@ public:
      * @return Reference to this ITensorInfo object
      */
     virtual ITensorInfo &set_tensor_shape(const TensorShape &shape) = 0;
+    /** Set the state for each dimension of the tensor
+     *
+     * This sets the state of each dimension of the shape in terms of dynamic behavior using -1 where appropriate.
+     * The index in the state is a 1 to 1 mapping with the shape dimension index.
+     * For example if you want to express [?, 3, 3] as a dynamic input then [-1, 3, 3] has to be set as a state
+     *
+     * @param[in] state Tensor dimensions state
+     *
+     * @return Reference to this ITensorInfo object
+     */
+    virtual ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) = 0;
     /** Set the quantization settings (scale and offset) of the tensor.
      *
      * @param[in] quantization_info QuantizationInfo containing the scale and offset
@@ -170,6 +184,11 @@ public:
      * @return A vector with the size for each dimension of the tensor
      */
     virtual const TensorShape &tensor_shape() const = 0;
+    /** State of each dimension of the tensor shape
+     *
+     * @return A vector with the state for each dimension of the tensor, where -1 specifies dynamic dimension
+     */
+    virtual const TensorDimsState &tensor_dims_state() const = 0;
     /** Data type used for each element of the tensor
      *
      * @return Tensor data type
@@ -212,13 +231,6 @@ public:
      * @return Reference to this ITensorInfo object
      */
     virtual ITensorInfo &set_is_resizable(bool is_resizable) = 0;
-    /** Set the flag whether the tensor size is dynamic.
-     *
-     * @param[in] is_dynamic Flag that marks the tensor if it's dynamic.
-     *
-     * @return Reference to this ITensorInfo object
-     */
-    virtual ITensorInfo &set_is_dynamic(bool is_dynamic) = 0;
     /** Valid region of the tensor. All elements in the valid region have defined values, i.e. are not undefined.
      *
      * @return The valid region.
index 6654ccf00af22c4c99451a26c17202987fbc64d6..1b2278d99b13c2d2e4e1339d38f826d1c7771371 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -98,6 +98,7 @@ public:
         return *this;
     };
     ITensorInfo &set_tensor_shape(const TensorShape &shape) override;
+    ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override;
     ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override
     {
         ARM_COMPUTE_ERROR_ON(_parent == nullptr);
@@ -155,6 +156,11 @@ public:
         ARM_COMPUTE_ERROR_ON(_parent == nullptr);
         return _tensor_shape;
     }
+    const TensorDimsState &tensor_dims_state() const override
+    {
+        ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+        return _dims_state;
+    }
     DataType data_type() const override
     {
         ARM_COMPUTE_ERROR_ON(_parent == nullptr);
@@ -196,12 +202,6 @@ public:
         _parent->set_is_resizable(is_resizable);
         return *this;
     }
-    ITensorInfo &set_is_dynamic(bool is_dynamic) override
-    {
-        ARM_COMPUTE_ERROR_ON(_parent == nullptr);
-        _parent->set_is_dynamic(is_dynamic);
-        return *this;
-    }
     ValidRegion valid_region() const override
     {
         return _valid_region;
@@ -228,11 +228,12 @@ public:
     }
 
 private:
-    ITensorInfo *_parent;
-    TensorShape  _tensor_shape;
-    Coordinates  _coords;
-    ValidRegion  _valid_region;
-    bool         _extend_parent;
+    ITensorInfo    *_parent;
+    TensorShape     _tensor_shape;
+    TensorDimsState _dims_state;
+    Coordinates     _coords;
+    ValidRegion     _valid_region;
+    bool            _extend_parent;
 };
 } // namespace arm_compute
 #endif /*ARM_COMPUTE_SUBTENSORINFO_H */
index 31f27328dd67d79ef5fdf622ce971e0a2717ace2..42a969e01b55b41e417dbe9e8d94b7cc592d70c0 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2019 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -224,6 +224,7 @@ public:
     ITensorInfo &set_num_channels(int num_channels) override;
     ITensorInfo &set_format(Format format) override;
     ITensorInfo &set_tensor_shape(const TensorShape &shape) override;
+    ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override;
     ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override;
     ITensorInfo &set_data_layout(const DataLayout &data_layout) override;
     ITensorInfo &reset_padding() override;
@@ -262,6 +263,10 @@ public:
     {
         return _tensor_shape;
     }
+    const TensorDimsState &tensor_dims_state() const override
+    {
+        return _dims_state;
+    }
     DataType data_type() const override
     {
         return _data_type;
@@ -288,18 +293,13 @@ public:
     }
     bool is_dynamic() const override
     {
-        return _is_dynamic;
+        return std::find(std::cbegin(_dims_state), std::cend(_dims_state), -1) != std::cend(_dims_state);
     }
     ITensorInfo &set_is_resizable(bool is_resizable) override
     {
         _is_resizable = is_resizable;
         return *this;
     }
-    ITensorInfo &set_is_dynamic(bool is_dynamic) override
-    {
-        _is_dynamic = is_dynamic;
-        return *this;
-    }
     ValidRegion valid_region() const override
     {
         return _valid_region;
@@ -329,10 +329,10 @@ private:
     Strides          _strides_in_bytes;
     size_t           _num_channels;
     TensorShape      _tensor_shape;
+    TensorDimsState  _dims_state;
     DataType         _data_type;
     Format           _format;
     bool             _is_resizable;
-    bool             _is_dynamic;
     ValidRegion      _valid_region;
     PaddingSize      _padding;
     QuantizationInfo _quantization_info;
index bb8ecf60ea5eaf6201e9f3357991c0bbe6cd5496..6279992e89085f9e332e4517ebb1e9aa363a6e6e 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -56,12 +56,12 @@ TensorShape extend_parent_shape(TensorShape parent_shape, TensorShape shape, Coo
 } // namespace
 
 SubTensorInfo::SubTensorInfo()
-    : _parent(nullptr), _tensor_shape(), _coords(), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(false)
+    : _parent(nullptr), _tensor_shape(), _dims_state(), _coords(), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(false)
 {
 }
 
 SubTensorInfo::SubTensorInfo(ITensorInfo *parent, TensorShape tensor_shape, Coordinates coords, bool extend_parent)
-    : _parent(parent), _tensor_shape(tensor_shape), _coords(coords), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(extend_parent)
+    : _parent(parent), _tensor_shape(tensor_shape), _dims_state(), _coords(coords), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(extend_parent)
 {
     ARM_COMPUTE_ERROR_ON(parent == nullptr);
 
@@ -107,6 +107,13 @@ ITensorInfo &SubTensorInfo::set_tensor_shape(const TensorShape &shape)
     return *this;
 }
 
+ITensorInfo &SubTensorInfo::set_tensor_dims_state(const TensorDimsState &state)
+{
+    ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+    _dims_state = state;
+    return *this;
+}
+
 bool SubTensorInfo::extend_padding(const PaddingSize &padding)
 {
     ARM_COMPUTE_ERROR_ON(_parent == nullptr);
index 7b1f9c542a18c42b9aeb54106ac5dd23120f3821..bedfe147b0cd375fe58c48090626ff9b28c2df1f 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,7 +35,7 @@
 using namespace arm_compute;
 
 TensorInfo::TensorInfo()
-    : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true }, _is_dynamic{ false },
+    : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _dims_state(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true },
       _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW)
 {
 }
@@ -48,10 +48,10 @@ TensorInfo::TensorInfo(const ITensorInfo &info)
     _strides_in_bytes              = info.strides_in_bytes();
     _num_channels                  = info.num_channels();
     _tensor_shape                  = info.tensor_shape();
+    _dims_state                    = info.tensor_dims_state();
     _data_type                     = info.data_type();
     _format                        = info.format();
     _is_resizable                  = info.is_resizable();
-    _is_dynamic                    = info.is_dynamic();
     _valid_region                  = info.valid_region();
     _padding                       = info.padding();
     _quantization_info             = info.quantization_info();
@@ -371,6 +371,12 @@ ITensorInfo &TensorInfo::set_tensor_shape(const TensorShape &shape)
     return *this;
 }
 
+ITensorInfo &TensorInfo::set_tensor_dims_state(const TensorDimsState &state)
+{
+    _dims_state = state;
+    return *this;
+}
+
 ITensorInfo &TensorInfo::set_quantization_info(const QuantizationInfo &quantization_info)
 {
     _quantization_info = quantization_info;