Merge pull request #18419 from TolyaTalamanov:at/generic-inference
[platform/upstream/opencv.git] / modules / gapi / include / opencv2 / gapi / infer.hpp
index 50086dd..4fdd2df 100644 (file)
@@ -121,6 +121,45 @@ struct GInferBase {
     }
 };
 
+// Struct stores network input/output names.
+// Used by infer<Generic>
+struct InOutInfo
+{
+    std::vector<std::string> in_names;
+    std::vector<std::string> out_names;
+};
+
+/**
+ * @{
+ * @brief G-API object used to collect network inputs
+ */
+class GAPI_EXPORTS GInferInputs
+{
+public:
+    cv::GMat& operator[](const std::string& name);
+    const std::unordered_map<std::string, cv::GMat>& getBlobs() const;
+
+private:
+    std::unordered_map<std::string, cv::GMat> in_blobs;
+};
+/** @} */
+
+/**
+ * @{
+ * @brief G-API object used to collect network outputs
+ */
+struct GAPI_EXPORTS GInferOutputs
+{
+public:
+    GInferOutputs(std::shared_ptr<cv::GCall> call);
+    cv::GMat at(const std::string& name);
+
+private:
+    std::shared_ptr<cv::GCall> m_call;
+    InOutInfo* m_info = nullptr;
+    std::unordered_map<std::string, cv::GMat> out_blobs;
+};
+/** @} */
 
 // Base "Infer list" kernel.
 // All notes from "Infer" kernel apply here as well.
@@ -254,6 +293,45 @@ typename Net::Result infer(Args&&... args) {
     return GInfer<Net>::on(std::forward<Args>(args)...);
 }
 
+/**
+ * @brief Special network type
+ */
+struct Generic { };
+
+/**
+ * @brief Calculates response for generic network
+ *
+ * @param tag a network tag
+ * @param inputs networks's inputs
+ * @return a GInferOutputs
+ */
+template<typename T = Generic> GInferOutputs
+infer(const std::string& tag, const GInferInputs& inputs)
+{
+    std::vector<GArg> input_args;
+    std::vector<std::string> input_names;
+
+    const auto& blobs = inputs.getBlobs();
+    for (auto&& p : blobs)
+    {
+        input_names.push_back(p.first);
+        input_args.emplace_back(p.second);
+    }
+
+    GKinds kinds(blobs.size(), cv::detail::OpaqueKind::CV_MAT);
+    auto call = std::make_shared<cv::GCall>(GKernel{
+                GInferBase::id(),
+                tag,
+                GInferBase::getOutMeta,
+                {}, // outShape will be filled later
+                std::move(kinds)
+            });
+
+    call->setArgs(std::move(input_args));
+    call->params() = InOutInfo{input_names, {}};
+
+    return GInferOutputs{std::move(call)};
+}
 
 } // namespace gapi
 } // namespace cv