auto const tensor_info = graph->tensors()->Get(tensor_id);
auto const tensor_name = tensor_info->name()->str();
auto const tensor_shape = as_tensor_shape(tensor_info->shape());
+ auto const tensor_type = tensor_info->type();
_name_ctx[tensor_id] = tensor_name;
_shape_ctx[tensor_id] = tensor_shape;
+ _type_ctx[tensor_id] = tensor_type;
}
}
{
/**
- * @brief Extracts and holds operand(tensor) information such as name and shape
+ * @brief Extracts and holds operand(tensor) information such as name, shape, and type
*/
class TensorContext
{
const std::string &name(uint32_t tensor_id) { return _name_ctx[tensor_id]; }
const tensor::Shape &shape(uint32_t tensor_id) { return _shape_ctx[tensor_id]; }
+ const tflite::TensorType &type(uint32_t tensor_id) { return _type_ctx[tensor_id]; }
private:
std::map<uint32_t, std::string> _name_ctx;
std::map<uint32_t, tensor::Shape> _shape_ctx;
+ std::map<uint32_t, tflite::TensorType> _type_ctx;
};
/**