ArmNN
 20.02
BaseIterator.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
9 #include <armnn/TypesUtils.hpp>
11 
12 #include <ResolveType.hpp>
13 
14 #include <boost/assert.hpp>
15 
16 namespace armnn
17 {
18 
/// Abstract, type-erased cursor over a flat tensor buffer.
///
/// The concrete iterators derived from this (via Decoder/Encoder) keep a start
/// pointer and a current pointer; every positioning operation returns *this so
/// calls can be chained.
class BaseIterator
{
public:
    BaseIterator() {}

    virtual ~BaseIterator() {}

    /// Moves to element `index`, counted from the start of the buffer.
    /// `axisIndex` is only meaningful to per-axis iterators, which use it to pick
    /// their per-axis quantization parameters; other iterators ignore it.
    virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0;

    /// Advances by one element.
    virtual BaseIterator& operator++() = 0;

    /// Advances by `increment` elements.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    /// Moves back by `increment` elements.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;

    /// Repositions to element `index`, counted from the start of the buffer.
    /// Unlike a conventional subscript, this mutates the iterator and returns it
    /// rather than returning the element; read or write via Get()/Set() afterwards.
    virtual BaseIterator& operator[](const unsigned int index) = 0;
};
36 
37 template<typename IType>
38 class Decoder : public BaseIterator
39 {
40 public:
41  Decoder() {}
42 
43  virtual ~Decoder() {}
44 
45  virtual void Reset(void*) = 0;
46 
47  virtual IType Get() const = 0;
48 };
49 
50 template<typename IType>
51 class Encoder : public BaseIterator
52 {
53 public:
54  Encoder() {}
55 
56  virtual ~Encoder() {}
57 
58  virtual void Reset(void*) = 0;
59 
60  virtual void Set(IType right) = 0;
61 
62  virtual IType Get() const = 0;
63 };
64 
65 template<typename T, typename Base>
66 class TypedIterator : public Base
67 {
68 public:
69  TypedIterator(T* data = nullptr)
70  : m_Iterator(data), m_Start(data)
71  {}
72 
73  void Reset(void* data) override
74  {
75  m_Iterator = reinterpret_cast<T*>(data);
76  m_Start = m_Iterator;
77  }
78 
80  {
81  BOOST_ASSERT(m_Iterator);
82  ++m_Iterator;
83  return *this;
84  }
85 
86  TypedIterator& operator+=(const unsigned int increment) override
87  {
88  BOOST_ASSERT(m_Iterator);
89  m_Iterator += increment;
90  return *this;
91  }
92 
93  TypedIterator& operator-=(const unsigned int increment) override
94  {
95  BOOST_ASSERT(m_Iterator);
96  m_Iterator -= increment;
97  return *this;
98  }
99 
100  TypedIterator& operator[](const unsigned int index) override
101  {
102  BOOST_ASSERT(m_Iterator);
103  m_Iterator = m_Start + index;
104  return *this;
105  }
106 
107  TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override
108  {
109  IgnoreUnused(axisIndex);
110  BOOST_ASSERT(m_Iterator);
111  m_Iterator = m_Start + index;
112  return *this;
113  }
114 
115 protected:
118 };
119 
120 class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder<float>>
121 {
122 public:
123  QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
124  : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
125 
126  QASymm8Decoder(const float scale, const int32_t offset)
127  : QASymm8Decoder(nullptr, scale, offset) {}
128 
129  float Get() const override
130  {
131  return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
132  }
133 
134 private:
135  const float m_Scale;
136  const int32_t m_Offset;
137 };
138 
139 class QASymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
140 {
141 public:
142  QASymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
143  : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
144 
145  QASymmS8Decoder(const float scale, const int32_t offset)
146  : QASymmS8Decoder(nullptr, scale, offset) {}
147 
148  float Get() const override
149  {
150  return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
151  }
152 
153 private:
154  const float m_Scale;
155  const int32_t m_Offset;
156 };
157 
158 class QSymmS8Decoder : public TypedIterator<const int8_t, Decoder<float>>
159 {
160 public:
161  QSymmS8Decoder(const int8_t* data, const float scale, const int32_t offset)
162  : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
163 
164  QSymmS8Decoder(const float scale, const int32_t offset)
165  : QSymmS8Decoder(nullptr, scale, offset) {}
166 
167  float Get() const override
168  {
169  return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
170  }
171 
172 private:
173  const float m_Scale;
174  const int32_t m_Offset;
175 };
176 
177 class QSymm16Decoder : public TypedIterator<const int16_t, Decoder<float>>
178 {
179 public:
180  QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
181  : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
182 
183  QSymm16Decoder(const float scale, const int32_t offset)
184  : QSymm16Decoder(nullptr, scale, offset) {}
185 
186  float Get() const override
187  {
188  return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
189  }
190 
191 private:
192  const float m_Scale;
193  const int32_t m_Offset;
194 };
195 
// Decoder that reads BFloat16 elements and returns them as float.
// NOTE(review): source lines 199, 202 and 208 are missing from this listing:
//   - 199: the pointer constructor's declarator (the signature residue at the
//     end of this page lists `BFloat16Decoder(const BFloat16 *data)`);
//   - 202: presumably the no-argument constructor `BFloat16Decoder()`;
//   - 208: the statement that fills `val`, presumably a call to the
//     ConvertBFloat16ToFloat32 helper also listed in that residue.
// As shown, Get() would always return 0.f; restore those lines before relying on it.
196 class BFloat16Decoder : public TypedIterator<const BFloat16, Decoder<float>>
197 {
198 public:
200  : TypedIterator(data) {}
201 
203  : BFloat16Decoder(nullptr) {}
204 
205  float Get() const override
206  {
// Converts the single element at the current position into `val`
// (the conversion statement itself is missing from this listing).
207  float val = 0.f;
209  return val;
210  }
211 };
212 
// Decoder that reads Half (FP16) elements and returns them as float.
// NOTE(review): source lines 219 and 225 are missing from this listing:
//   - 219: presumably the no-argument constructor `Float16Decoder()`, which
//     owns the `: Float16Decoder(nullptr)` delegation below;
//   - 225: the statement that fills `val`, presumably a call to the
//     ConvertFloat16To32 helper listed in the signature residue at the end
//     of this page.
// As shown, Get() would always return 0.f.
213 class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
214 {
215 public:
216  Float16Decoder(const Half* data)
217  : TypedIterator(data) {}
218 
220  : Float16Decoder(nullptr) {}
221 
222  float Get() const override
223  {
224  float val = 0.f;
226  return val;
227  }
228 };
229 
230 class Float32Decoder : public TypedIterator<const float, Decoder<float>>
231 {
232 public:
233  Float32Decoder(const float* data)
234  : TypedIterator(data) {}
235 
237  : Float32Decoder(nullptr) {}
238 
239  float Get() const override
240  {
241  return *m_Iterator;
242  }
243 };
244 
245 class ScaledInt32Decoder : public TypedIterator<const int32_t, Decoder<float>>
246 {
247 public:
248  ScaledInt32Decoder(const int32_t* data, const float scale)
249  : TypedIterator(data), m_Scale(scale) {}
250 
251  ScaledInt32Decoder(const float scale)
252  : ScaledInt32Decoder(nullptr, scale) {}
253 
254  float Get() const override
255  {
256  return static_cast<float>(*m_Iterator) * m_Scale;
257  }
258 
259 private:
260  const float m_Scale;
261 };
262 
263 class Int32Decoder : public TypedIterator<const int32_t, Decoder<float>>
264 {
265 public:
266  Int32Decoder(const int32_t* data)
267  : TypedIterator(data) {}
268 
270  : Int32Decoder(nullptr) {}
271 
272  float Get() const override
273  {
274  return static_cast<float>(*m_Iterator);
275  }
276 };
277 
278 class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
279 {
280 public:
281  QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
282  : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
283 
284  QASymm8Encoder(const float scale, const int32_t offset)
285  : QASymm8Encoder(nullptr, scale, offset) {}
286 
287  void Set(float right) override
288  {
289  *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
290  }
291 
292  float Get() const override
293  {
294  return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
295  }
296 
297 private:
298  const float m_Scale;
299  const int32_t m_Offset;
300 };
301 
302 class QASymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
303 {
304 public:
305  QASymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
306  : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
307 
308  QASymmS8Encoder(const float scale, const int32_t offset)
309  : QASymmS8Encoder(nullptr, scale, offset) {}
310 
311  void Set(float right) override
312  {
313  *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
314  }
315 
316  float Get() const override
317  {
318  return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
319  }
320 
321 private:
322  const float m_Scale;
323  const int32_t m_Offset;
324 };
325 
326 class QSymmS8Encoder : public TypedIterator<int8_t, Encoder<float>>
327 {
328 public:
329  QSymmS8Encoder(int8_t* data, const float scale, const int32_t offset)
330  : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
331 
332  QSymmS8Encoder(const float scale, const int32_t offset)
333  : QSymmS8Encoder(nullptr, scale, offset) {}
334 
335  void Set(float right) override
336  {
337  *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale, m_Offset);
338  }
339 
340  float Get() const override
341  {
342  return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
343  }
344 
345 private:
346  const float m_Scale;
347  const int32_t m_Offset;
348 };
349 
350 class QSymm16Encoder : public TypedIterator<int16_t, Encoder<float>>
351 {
352 public:
353  QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
354  : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
355 
356  QSymm16Encoder(const float scale, const int32_t offset)
357  : QSymm16Encoder(nullptr, scale, offset) {}
358 
359  void Set(float right) override
360  {
361  *m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
362  }
363 
364  float Get() const override
365  {
366  return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
367  }
368 
369 private:
370  const float m_Scale;
371  const int32_t m_Offset;
372 };
373 
// Encoder that stores floats into a BFloat16 buffer and reads them back as float.
// NOTE(review): source lines 377, 380, 385 and 391 are missing from this listing:
//   - 377: the pointer constructor's declarator (the signature residue at the
//     end of this page lists `BFloat16Encoder(armnn::BFloat16 *data)`);
//   - 380: presumably the no-argument constructor `BFloat16Encoder()`;
//   - 385: the body of Set(), presumably a ConvertFloat32ToBFloat16 call;
//   - 391: the statement that fills `val`, presumably a ConvertBFloat16ToFloat32 call.
// As shown, Set() ignores `right` and Get() always returns 0.f.
374 class BFloat16Encoder : public TypedIterator<armnn::BFloat16, Encoder<float>>
375 {
376 public:
378  : TypedIterator(data) {}
379 
381  : BFloat16Encoder(nullptr) {}
382 
383  void Set(float right) override
384  {
386  }
387 
388  float Get() const override
389  {
390  float val = 0.f;
392  return val;
393  }
394 };
395 
// Encoder that stores floats into a Half (FP16) buffer and reads them back as float.
// NOTE(review): source lines 399, 402, 407 and 413 are missing from this listing:
//   - 399: the pointer constructor's declarator (the signature residue at the
//     end of this page lists `Float16Encoder(Half *data)`);
//   - 402: presumably the no-argument constructor `Float16Encoder()`;
//   - 407: the body of Set(), presumably a ConvertFloat32To16 call;
//   - 413: the statement that fills `val`, presumably a ConvertFloat16To32 call.
// As shown, Set() ignores `right` and Get() always returns 0.f.
396 class Float16Encoder : public TypedIterator<Half, Encoder<float>>
397 {
398 public:
400  : TypedIterator(data) {}
401 
403  : Float16Encoder(nullptr) {}
404 
405  void Set(float right) override
406  {
408  }
409 
410  float Get() const override
411  {
412  float val = 0.f;
414  return val;
415  }
416 };
417 
418 class Float32Encoder : public TypedIterator<float, Encoder<float>>
419 {
420 public:
421  Float32Encoder(float* data)
422  : TypedIterator(data) {}
423 
425  : Float32Encoder(nullptr) {}
426 
427  void Set(float right) override
428  {
429  *m_Iterator = right;
430  }
431 
432  float Get() const override
433  {
434  return *m_Iterator;
435  }
436 };
437 
438 class Int32Encoder : public TypedIterator<int32_t, Encoder<float>>
439 {
440 public:
441  Int32Encoder(int32_t* data)
442  : TypedIterator(data) {}
443 
445  : Int32Encoder(nullptr) {}
446 
447  void Set(float right) override
448  {
449  *m_Iterator = static_cast<int32_t>(right);
450  }
451 
452  float Get() const override
453  {
454  return static_cast<float>(*m_Iterator);
455  }
456 };
457 
458 class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
459 {
460 public:
461  BooleanEncoder(uint8_t* data)
462  : TypedIterator(data) {}
463 
465  : BooleanEncoder(nullptr) {}
466 
467  void Set(bool right) override
468  {
469  *m_Iterator = right;
470  }
471 
472  bool Get() const override
473  {
474  return *m_Iterator;
475  }
476 };
477 
478 // PerAxisIterator for per-axis quantization
// Iterator over a contiguous buffer of T that also tracks m_AxisIndex, which the
// per-axis decoders/encoders use to index their per-axis scale vectors.
// NOTE(review): source lines 504 (the operator++ declarator; the signature
// residue at the end of this page lists `PerAxisIterator & operator++() override`)
// and 537-538 (presumably the `m_Iterator` / `m_Start` pointer members, both
// initialised from `T* data` in the constructor) are missing from this listing.
479 template<typename T, typename Base>
480 class PerAxisIterator : public Base
481 {
482 public:
483  // axisFactor is used to calculate axisIndex
// NOTE(review): axisFactor defaults to 0, but operator++, +=, -= and [] all
// evaluate `% m_AxisFactor`; calling them on an iterator built with the default
// is a division by zero (undefined behaviour).
484  PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
485  : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
486  {}
487 
488  // This should be called to set index for per-axis Encoder/Decoder
// Sets the element position (relative to the buffer start) and the axis index
// explicitly. Unlike the other positioning operators, it does not derive
// m_AxisIndex itself.
// This override declares no default for axisIndex, while BaseIterator's
// declaration defaults it to 0. Default arguments are bound statically, so a
// one-argument call only compiles through a BaseIterator reference.
489  PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override
490  {
491  BOOST_ASSERT(m_Iterator);
492  m_Iterator = m_Start + index;
493  m_AxisIndex = axisIndex;
494  return *this;
495  }
496 
// Rebinds to a new buffer, rewinds to its first element and resets the axis index to 0.
497  void Reset(void* data) override
498  {
499  m_Iterator = reinterpret_cast<T*>(data);
500  m_Start = m_Iterator;
501  m_AxisIndex = 0;
502  }
503 
// NOTE(review): the `operator++` declarator (source line 504) is missing here.
// NOTE(review): this and the three operators below compute m_AxisIndex from the
// element *value* at the new position (`*m_Iterator`), not from the position
// itself. That looks like a bug: an axis index would normally derive from the
// element offset (e.g. (m_Iterator - m_Start) / m_AxisFactor, wrapped by the
// number of per-axis scales). For signed T, a negative value also converts to a
// huge unsigned number. Confirm the intended formula.
// NOTE(review): this operator also dereferences after advancing. If a caller
// steps onto the one-past-the-end position, that read is out of bounds.
505  {
506  BOOST_ASSERT(m_Iterator);
507  ++m_Iterator;
508  m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
509  return *this;
510  }
511 
512  PerAxisIterator& operator+=(const unsigned int increment) override
513  {
514  BOOST_ASSERT(m_Iterator);
515  m_Iterator += increment;
516  m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
517  return *this;
518  }
519 
520  PerAxisIterator& operator-=(const unsigned int decrement) override
521  {
522  BOOST_ASSERT(m_Iterator);
523  m_Iterator -= decrement;
524  m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
525  return *this;
526  }
527 
// Absolute reposition relative to m_Start (see the axis-index note above operator++).
528  PerAxisIterator& operator[](const unsigned int index) override
529  {
530  BOOST_ASSERT(m_Iterator);
531  m_Iterator = m_Start + index;
532  m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
533  return *this;
534  }
535 
536  protected:
// NOTE(review): the m_Iterator / m_Start declarations (source lines 537-538) are missing here.
// Index into the derived class's per-axis scale vector; not bounds-checked here.
539  unsigned int m_AxisIndex;
// Divisor used when deriving m_AxisIndex in the positioning operators.
540  unsigned int m_AxisFactor;
541 };
542 
543 class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
544 {
545 public:
546  QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
547  : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
548 
549  float Get() const override
550  {
551  return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
552  }
553 
554  // Get scale of the current value
555  float GetScale() const
556  {
557  return m_Scale[m_AxisIndex];
558  }
559 
560 private:
561  std::vector<float> m_Scale;
562 };
563 
564 class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
565 {
566 public:
567  QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
568  : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
569 
570  void Set(float right)
571  {
572  *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
573  }
574 
575  float Get() const
576  {
577  return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
578  }
579 
580  // Get scale of the current value
581  float GetScale() const
582  {
583  return m_Scale[m_AxisIndex];
584  }
585 
586 private:
587  std::vector<float> m_Scale;
588 };
589 
590 class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>>
591 {
592 public:
593  ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor)
594  : PerAxisIterator(data, axisFactor), m_Scales(scales) {}
595 
596  float Get() const override
597  {
598  return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0);
599  }
600 
601  // Get scale of the current value
602  float GetScale() const
603  {
604  return m_Scales[m_AxisIndex];
605  }
606 
607 private:
608  std::vector<float> m_Scales;
609 };
610 
611 } // namespace armnn
PerAxisIterator & operator++() override
Float32Decoder(const float *data)
QSymm8PerAxisDecoder(const int8_t *data, const std::vector< float > &scale, unsigned int axisFactor)
void Set(float right) override
ScaledInt32Decoder(const float scale)
BFloat16Decoder(const BFloat16 *data)
PerAxisIterator(T *data=nullptr, unsigned int axisFactor=0)
virtual BaseIterator & operator-=(const unsigned int increment)=0
float Get() const override
PerAxisIterator & operator[](const unsigned int index) override
QSymmS8Decoder(const float scale, const int32_t offset)
float Get() const override
QSymm16Encoder(const float scale, const int32_t offset)
static void ConvertBFloat16ToFloat32(const void *srcBFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
void Reset(void *data) override
void Set(float right) override
ScaledInt32Decoder(const int32_t *data, const float scale)
QSymm16Decoder(const int16_t *data, const float scale, const int32_t offset)
void Set(bool right) override
float Get() const override
QSymmS8Encoder(int8_t *data, const float scale, const int32_t offset)
ScaledInt32PerAxisDecoder(const int32_t *data, const std::vector< float > &scales, unsigned int axisFactor)
BFloat16Encoder(armnn::BFloat16 *data)
void Set(float right) override
QSymmS8Decoder(const int8_t *data, const float scale, const int32_t offset)
Copyright (c) 2020 ARM Limited.
void IgnoreUnused(Ts &&...)
float Get() const override
void Set(float right) override
float Get() const override
float Get() const override
QASymmS8Decoder(const int8_t *data, const float scale, const int32_t offset)
float Get() const override
Int32Decoder(const int32_t *data)
float Get() const override
virtual BaseIterator & operator[](const unsigned int index)=0
TypedIterator & operator[](const unsigned int index) override
QASymmS8Decoder(const float scale, const int32_t offset)
QASymm8Encoder(const float scale, const int32_t offset)
static void ConvertFloat32To16(const float *srcFloat32Buffer, size_t numElements, void *dstFloat16Buffer)
Converts a buffer of FP32 values to FP16, and stores in the given dstFloat16Buffer.
float Get() const override
QSymmS8Encoder(const float scale, const int32_t offset)
Int32Encoder(int32_t *data)
void Set(float right) override
void Reset(void *data) override
virtual ~Decoder()
virtual BaseIterator & operator++()=0
float Get() const override
QASymm8Encoder(uint8_t *data, const float scale, const int32_t offset)
float Get() const override
static void ConvertFloat16To32(const void *srcFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
BooleanEncoder(uint8_t *data)
float Get() const override
float Get() const override
float Get() const override
Float16Encoder(Half *data)
QSymm8PerAxisEncoder(int8_t *data, const std::vector< float > &scale, unsigned int axisFactor)
QASymm8Decoder(const float scale, const int32_t offset)
TypedIterator & operator++() override
Float16Decoder(const Half *data)
virtual BaseIterator & SetIndex(unsigned int index, unsigned int axisIndex=0)=0
float Get() const override
float Get() const override
PerAxisIterator & SetIndex(unsigned int index, unsigned int axisIndex) override
QASymmS8Encoder(const float scale, const int32_t offset)
PerAxisIterator & operator-=(const unsigned int decrement) override
virtual BaseIterator & operator+=(const unsigned int increment)=0
static void ConvertFloat32ToBFloat16(const float *srcFloat32Buffer, size_t numElements, void *dstBFloat16Buffer)
void Set(float right) override
QSymm16Decoder(const float scale, const int32_t offset)
Float32Encoder(float *data)
float Get() const override
QASymm8Decoder(const uint8_t *data, const float scale, const int32_t offset)
TypedIterator & operator+=(const unsigned int increment) override
virtual ~Encoder()
float Get() const override
QSymm16Encoder(int16_t *data, const float scale, const int32_t offset)
PerAxisIterator & operator+=(const unsigned int increment) override
void Set(float right) override
half_float::half Half
Definition: Half.hpp:16
QASymmS8Encoder(int8_t *data, const float scale, const int32_t offset)
TypedIterator & SetIndex(unsigned int index, unsigned int axisIndex=0) override
TypedIterator & operator-=(const unsigned int increment) override
TypedIterator(T *data=nullptr)
void Set(float right) override
bool Get() const override
float Get() const override