DATA_LAYOUT_SRNPU = DATA_LAYOUT_TRIV,
/**< alias for backward-compatibility */
DATA_LAYOUT_TRIV2, /**< customized layout for TRIV2 (based on NHWC) */
+ DATA_LAYOUT_MODEL, /**< use the same data layout specified in model metadata */
} data_layout;
/**
DATA_TYPE_INT64,
DATA_TYPE_UINT64,
DATA_TYPE_FLOAT64,
+ /* unknown */
+ DATA_TYPE_MODEL, /**< use the same data type specified in model metadata */
} data_type;
/**
return -ENOENT;
}
+ /* check the given model before running */
+ if (!model->finalize ()) {
+ logerr (TAG, "Failed to finalize the model. Please see the log messages\n");
+ return -EINVAL;
+ }
+
device_->setAsyncMode (mode);
return device_->run (NPUINPUT_HOST, model, input, cb, cb_data, sequence);
}
}
/**
+ * @brief finalize the model instance with user-provided configurations
+ * @return true if no error. otherwise false
+ */
+bool
+Model::finalize () {
+ uint32_t input_num = getInputTensorNum ();
+ uint32_t output_num = getOutputTensorNum ();
+
+ /** check tensors info */
+ if (input_num != in_.num_info) {
+ logerr (TAG, "The number of input tensors is different. Please set setNPU_dataInfo()\n");
+ return false;
+ }
+
+ if (output_num != out_.num_info) {
+ logerr (TAG, "The number of output tensors is different. Please set setNPU_dataInfo()\n");
+ return false;
+ }
+
+ /** evaluate data layout/type if required */
+ int version = getMetadata ()->getVersion ();
+ for (uint32_t idx = 0; idx < input_num; idx++) {
+ if (in_.info[idx].layout == DATA_LAYOUT_MODEL) {
+ if (version >= 3)
+ in_.info[idx].layout = DATA_LAYOUT_TRIV2;
+ else
+ in_.info[idx].layout = DATA_LAYOUT_SRNPU;
+ }
+ if (in_.info[idx].type == DATA_TYPE_MODEL) {
+ if (version >= 3)
+ in_.info[idx].type = getMetadata ()->getInputQuantType (idx);
+ else
+ in_.info[idx].type = DATA_TYPE_SRNPU;
+ }
+ }
+
+ for (uint32_t idx = 0; idx < output_num; idx++) {
+ if (out_.info[idx].layout == DATA_LAYOUT_MODEL) {
+ if (version >= 3)
+ out_.info[idx].layout = DATA_LAYOUT_TRIV2;
+ else
+ out_.info[idx].layout = DATA_LAYOUT_SRNPU;
+ }
+ if (out_.info[idx].type == DATA_TYPE_MODEL) {
+ if (version >= 3)
+ out_.info[idx].type = getMetadata ()->getOutputQuantType (idx);
+ else
+ out_.info[idx].type = DATA_TYPE_SRNPU;
+ }
+ }
+
+ return true;
+}
+
+/**
* @brief get the size of data type
* @return the data size
*/