[Metadata/V3] Implement metadata for npubinfmt v3 (TRIV2)
authorDongju Chae <dongju.chae@samsung.com>
Thu, 23 Apr 2020 10:57:51 +0000 (19:57 +0900)
committer송욱/On-Device Lab(SR)/Staff Engineer/삼성전자 <wook16.song@samsung.com>
Fri, 8 May 2020 03:40:20 +0000 (12:40 +0900)
This patch implements metadata for npubinfmt v3 (TRIV2).

Signed-off-by: Dongju Chae <dongju.chae@samsung.com>
src/core/ne-model.cc
src/core/ne-model.h
src/core/ne-segment-table.cc

index 35d60e6..feee9bf 100644 (file)
@@ -57,6 +57,9 @@ Metadata::extractMetadata (void *data)
   case 2:
     meta_ins = new Metadata_v2 (meta_data);
     break;
+  case 3:
+    meta_ins = new Metadata_v3 (meta_data);
+    break;
   default:
     logerr (TAG, "Invalid NPU binary format version: %d\n",
         NPUBIN_VERSION (meta_data->magiccode));
@@ -162,6 +165,75 @@ Metadata_v2::getOutputTensorSize (uint32_t idx, data_layout layout) const
   return tensor_size;
 }
 
+/** @brief constructor of npubinfmt v3 */
+Metadata_v3::Metadata_v3 (npubin_meta *meta)
+  : Metadata (meta)
+{
+}
+
+/** @brief sanity check for npubinfmt v3 */
+bool
+Metadata_v3::checkSanity () const
+{
+  if (getVersion () != 3)
+    return false;
+  if (getSegmentsNum () > MAX_SEGMENTS)
+    return false;
+  if (getInputNum () > getSegmentsNum () || getInputNum () > MAX_TENSORS)
+    return false;
+  if (getOutputNum () > getSegmentsNum () || getOutputNum () > MAX_TENSORS)
+    return false;
+
+  if (getWeightSegmentIndex () >= getSegmentsNum ())
+      return false;
+  for (uint32_t i = 0; i < getInputNum (); i++) {
+    if (getInputSegmentIndex (i) >= getSegmentsNum ())
+      return false;
+  }
+  for (uint32_t i = 0; i < getOutputNum (); i++) {
+    if (getOutputSegmentIndex (i) >= getSegmentsNum ())
+      return false;
+  }
+
+  return getSize () == getMetaSize () + getProgramSize () + getWeightSize ();
+}
+
+/**
+ * @brief calculate input tensor size depending on specified layout
+ * @todo need to fix this when the data manipulation in TRIV2 is finalized
+ */
+uint32_t
+Metadata_v3::getInputTensorSize (uint32_t idx, data_layout layout) const
+{
+  assert (idx < getInputNum ());
+
+  const uint32_t *dims = getInputDims (idx);
+  uint32_t tensor_size = 1;
+
+  for (uint32_t rank_idx = 0; rank_idx < MAX_RANK; rank_idx++)
+    tensor_size *= dims[rank_idx];
+
+  return tensor_size;
+}
+
+/**
+ * @brief calculate output tensor size depending on specified layout
+ * @todo need to fix this when the data manipulation in TRIV2 is finalized
+ */
+uint32_t
+Metadata_v3::getOutputTensorSize (uint32_t idx, data_layout layout) const
+{
+  assert (idx < getOutputNum ());
+
+  const uint32_t *dims = getOutputDims (idx);
+  uint32_t tensor_size = 1;
+
+  for (uint32_t rank_idx = 0; rank_idx < MAX_RANK; rank_idx++)
+    tensor_size *= dims[rank_idx];
+
+  return tensor_size;
+}
+
 /** @brief constructor of model class */
 Model::Model (const HWmemImpl* impl)
   : HWmem (impl), meta_ (nullptr)
index 1e0afb8..d9e705e 100644 (file)
@@ -67,6 +67,15 @@ class Metadata {
     virtual uint32_t getOutputQuantZero (uint32_t idx) const = 0;
     virtual float getOutputQuantScale (uint32_t idx) const = 0;
 
+    virtual data_type getInputQuantType (uint32_t idx) const { return DATA_TYPE_SRNPU; }
+    virtual data_type getOutputQuantType (uint32_t idx) const { return DATA_TYPE_SRNPU; }
+
+    virtual uint32_t getSegmentsNum () const { return 0; }
+    virtual uint32_t getSegmentSize (uint32_t idx) const { return 0; }
+    virtual uint32_t getWeightSegmentIndex () const { return 0; }
+    virtual uint32_t getInputSegmentIndex (uint32_t idx) const { return 0; }
+    virtual uint32_t getOutputSegmentIndex (uint32_t idx) const { return 0; }
+
     uint64_t getSize () const { return meta_->size; }
     uint64_t getProgramSize () const { return meta_->program_size; }
     uint64_t getWeightSize () const { return meta_->weight_size; }
@@ -190,6 +199,78 @@ class Metadata_v2 : public Metadata {
     }
 };
 
+/** @brief metadata version 3: support a segment table */
+class Metadata_v3 : public Metadata {
+  public:
+    Metadata_v3 (npubin_meta *meta);
+
+    bool checkSanity () const;
+
+    uint32_t getInputNum () const override { return meta_->input_seg_num; }
+    uint32_t getOutputNum () const override { return meta_->output_seg_num; }
+
+    uint32_t getInputOffset (uint32_t idx) const override { return 0; }
+    uint32_t getOutputOffset (uint32_t idx) const override { return 0; }
+
+    uint32_t getInputTensorSize (uint32_t idx, data_layout layout) const override;
+    uint32_t getOutputTensorSize (uint32_t idx, data_layout layout) const override;
+
+    uint32_t getInputElemSize (uint32_t idx) const override { return 1; }
+    uint32_t getOutputElemSize (uint32_t idx) const override { return 1; }
+
+    const uint32_t* getInputDims (uint32_t idx) const override {
+      assert (idx < getInputNum ());
+      return meta_->input_seg_dims[idx];
+    }
+    const uint32_t* getOutputDims (uint32_t idx) const override {
+      assert (idx < getOutputNum ());
+      return meta_->output_seg_dims[idx];
+    }
+
+    uint32_t getInputQuantZero (uint32_t idx) const override {
+      assert (idx < getInputNum ());
+      return meta_->input_seg_quant_z[idx];
+    }
+    float getInputQuantScale (uint32_t idx) const override {
+      assert (idx < getInputNum ());
+      return meta_->input_seg_quant_s[idx];
+    }
+    uint32_t getOutputQuantZero (uint32_t idx) const override {
+      assert (idx < getOutputNum ());
+      return meta_->output_seg_quant_z[idx];
+    }
+    float getOutputQuantScale (uint32_t idx) const override {
+      assert (idx < getOutputNum ());
+      return meta_->output_seg_quant_s[idx];
+    }
+
+    data_type getInputQuantType (uint32_t idx) const override {
+      assert (idx < getInputNum ());
+      return meta_->input_seg_quant_type[idx];
+    }
+    data_type getOutputQuantType (uint32_t idx) const override {
+      assert (idx < getOutputNum ());
+      return meta_->output_seg_quant_type[idx];
+    }
+
+    uint32_t getSegmentsNum () const override { return meta_->segment_num; }
+    uint32_t getSegmentSize (uint32_t idx) const override {
+      assert (idx < getSegmentsNum ());
+      return meta_->segment_size[idx];
+    }
+    uint32_t getWeightSegmentIndex () const override {
+      return meta_->weight_seg_idx;
+    }
+    uint32_t getInputSegmentIndex (uint32_t idx) const override {
+      assert (idx < getInputNum ());
+      return meta_->input_seg_idx[idx];
+    }
+    uint32_t getOutputSegmentIndex (uint32_t idx) const override {
+      assert (idx < getOutputNum ());
+      return meta_->output_seg_idx[idx];
+    }
+};
+
 /** @brief model class derived from hwmem */
 class Model : public HWmem {
   public:
index 3279d7f..25c7bd1 100644 (file)
@@ -47,6 +47,7 @@ SegmentTable::~SegmentTable ()
  * @brief create segments according to on metadata info
  * @param[in] meta the metadata
  * @return 0 if no error, otherwise a negative errno
+ * @note we assume that # weight segments is always 1. (fix impl when it's changed)
  */
 int
 SegmentTable::createSegments (const Metadata *meta)
@@ -73,20 +74,18 @@ SegmentTable::createSegments (const Metadata *meta)
     return -EINVAL;
   }
 
-  /** TODO: extract exact info from metadata. let's use dummy info for now */
-  uint32_t seg_num = 4;
-  uint32_t seg_sizes[] = { 0x1000, 0x2000, 0x3000, 0x4000 };
-
   HWmem * hwmem;
   int status;
-  for (uint32_t i = 0; i < seg_num; i++) {
+  for (uint32_t i = 0; i < meta->getSegmentsNum (); i++) {
+    uint32_t size = meta->getSegmentSize (i);
+
     hwmem = new HWmem (new HWmemDevice);
     hwmem->setDriverAPI (getDriverAPI ());
 
-    status = hwmem->alloc (seg_sizes [i]);
+    status = hwmem->alloc (size);
     if (status != 0) {
       logerr (TAG, "Failed to allocate %uth segment with size %u: %d\n",
-          i, seg_sizes [i], status);
+          i, size, status);
       segments_.clear ();
       return -EINVAL;
     }
@@ -105,19 +104,21 @@ SegmentTable::createSegments (const Metadata *meta)
     reinterpret_cast<uint64_t *>(getData())[i] = unsigned_dmabuf;
   }
 
-  /** dummy info */
-  num_total_segments_ = seg_num;
-  num_weight_segments_ = 1;
-  num_input_segments_ = 1;
-  num_output_segments_ = 1;
-
   weight_seg_idx_ = new uint32_t [num_weight_segments_];
   input_seg_idx_ = new uint32_t [num_input_segments_];
   output_seg_idx_ = new uint32_t [num_output_segments_];
 
-  weight_seg_idx_[0] = 1;
-  input_seg_idx_[0] = 2;
-  output_seg_idx_[0] = 3;
+  /** segment index validity is already checked in Metadata's checkSanity () */
+  num_total_segments_ = meta->getSegmentsNum ();
+  num_weight_segments_ = 1;
+  num_input_segments_ = meta->getInputNum ();
+  num_output_segments_ = meta->getOutputNum ();
+
+  weight_seg_idx_[0] = meta->getWeightSegmentIndex ();
+  for (uint32_t i = 0; i < num_input_segments_; i++)
+    input_seg_idx_[i] = meta->getInputSegmentIndex (i);
+  for (uint32_t i = 0; i < num_output_segments_; i++)
+    output_seg_idx_[i] = meta->getOutputSegmentIndex (i);
 
   return 0;
 }
@@ -130,24 +131,20 @@ SegmentTable::createSegments (const Metadata *meta)
 HWmem *
 SegmentTable::getWeightSegment (uint32_t idx)
 {
-  if (idx >= MAX_TENSORS) {
-    logerr (TAG, "Invalid weight segment index (%u). Should be less than %u\n",
-        idx, MAX_TENSORS);
-    return nullptr;
-  }
-
   if (weight_seg_idx_ == nullptr) {
     logerr (TAG, "No valid segments in this table, maybe uninitialized?\n");
     return nullptr;
   }
 
-  uint32_t seg_idx = weight_seg_idx_[idx];
-  if (seg_idx >= MAX_SEGMENTS) {
-    logerr (TAG, "Invalid segment index (%u). Should be less than %u\n",
-        seg_idx, MAX_SEGMENTS);
+  if (idx >= num_weight_segments_) {
+    logerr (TAG, "Invalid weight segment index (%u). Should be less than %u\n",
+        idx, num_weight_segments_);
     return nullptr;
   }
 
+  uint32_t seg_idx = weight_seg_idx_[idx];
+  assert (seg_idx < segments_.size ());  /** this is ensured in checkSanity() */
+
   return segments_ [seg_idx];
 }
 
@@ -159,24 +156,20 @@ SegmentTable::getWeightSegment (uint32_t idx)
 HWmem *
 SegmentTable::getInputSegment (uint32_t idx)
 {
-  if (idx >= MAX_TENSORS) {
-    logerr (TAG, "Invalid input segment index (%u). Should be less than %u\n",
-        idx, MAX_TENSORS);
-    return nullptr;
-  }
-
   if (input_seg_idx_ == nullptr) {
     logerr (TAG, "No valid segments in this table, maybe uninitialized?\n");
     return nullptr;
   }
 
-  uint32_t seg_idx = input_seg_idx_[idx];
-  if (seg_idx >= MAX_SEGMENTS) {
-    logerr (TAG, "Invalid segment index (%u). Should be less than %u\n",
-        seg_idx, MAX_SEGMENTS);
+  if (idx >= num_input_segments_) {
+    logerr (TAG, "Invalid input segment index (%u). Should be less than %u\n",
+        idx, num_input_segments_);
     return nullptr;
   }
 
+  uint32_t seg_idx = input_seg_idx_[idx];
+  assert (seg_idx < segments_.size ());  /** this is ensured in checkSanity() */
+
   return segments_ [seg_idx];
 }
 
@@ -188,23 +181,19 @@ SegmentTable::getInputSegment (uint32_t idx)
 HWmem *
 SegmentTable::getOutputSegment (uint32_t idx)
 {
-  if (idx >= MAX_TENSORS) {
-    logerr (TAG, "Invalid output segment index (%u). Should be less than %u\n",
-        idx, MAX_TENSORS);
-    return nullptr;
-  }
-
   if (output_seg_idx_ == nullptr) {
     logerr (TAG, "No valid segments in this table, maybe uninitialized?\n");
     return nullptr;
   }
 
-  uint32_t seg_idx = output_seg_idx_[idx];
-  if (seg_idx >= MAX_SEGMENTS) {
-    logerr (TAG, "Invalid segment index (%u). Should be less than %u\n",
-        seg_idx, MAX_SEGMENTS);
+  if (idx >= num_output_segments_) {
+    logerr (TAG, "Invalid output segment index (%u). Should be less than %u\n",
+        idx, num_output_segments_);
     return nullptr;
   }
 
+  uint32_t seg_idx = output_seg_idx_[idx];
+  assert (seg_idx < segments_.size ());  /** this is ensured in checkSanity() */
+
   return segments_ [seg_idx];
 }