Merge pull request #4737 from rokm/matcaffe-individual-destruct
[platform/upstream/caffeonacl.git] / matlab / +caffe / private / caffe_.cpp
index 4e466e6..a32bd5e 100644 (file)
@@ -197,6 +197,17 @@ static void get_solver(MEX_ARGS) {
   mxFree(solver_file);
 }
 
+// Usage: caffe_('delete_solver', hSolver)
+static void delete_solver(MEX_ARGS) {
+  mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
+      "Usage: caffe_('delete_solver', hSolver)");
+  Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
+  solvers_.erase(std::remove_if(solvers_.begin(), solvers_.end(),
+      [solver] (const shared_ptr< Solver<float> > &solverPtr) {
+      return solverPtr.get() == solver;
+  }), solvers_.end());
+}
+
 // Usage: caffe_('solver_get_attr', hSolver)
 static void solver_get_attr(MEX_ARGS) {
   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -271,6 +282,17 @@ static void get_net(MEX_ARGS) {
   mxFree(phase_name);
 }
 
+// Usage: caffe_('delete_solver', hSolver)
+static void delete_net(MEX_ARGS) {
+  mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
+      "Usage: caffe_('delete_solver', hNet)");
+  Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
+  nets_.erase(std::remove_if(nets_.begin(), nets_.end(),
+      [net] (const shared_ptr< Net<float> > &netPtr) {
+      return netPtr.get() == net;
+  }), nets_.end());
+}
+
 // Usage: caffe_('net_get_attr', hNet)
 static void net_get_attr(MEX_ARGS) {
   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -522,12 +544,14 @@ struct handler_registry {
 static handler_registry handlers[] = {
   // Public API functions
   { "get_solver",         get_solver      },
+  { "delete_solver",      delete_solver   },
   { "solver_get_attr",    solver_get_attr },
   { "solver_get_iter",    solver_get_iter },
   { "solver_restore",     solver_restore  },
   { "solver_solve",       solver_solve    },
   { "solver_step",        solver_step     },
   { "get_net",            get_net         },
+  { "delete_net",         delete_net      },
   { "net_get_attr",       net_get_attr    },
   { "net_forward",        net_forward     },
   { "net_backward",       net_backward    },