matcaffe: allow destruction of individual networks and solvers
authorRok Mandeljc <rok.mandeljc@fe.uni-lj.si>
Mon, 29 Jun 2015 13:48:43 +0000 (15:48 +0200)
committerRok Mandeljc <rok.mandeljc@gmail.com>
Sun, 18 Sep 2016 22:31:56 +0000 (00:31 +0200)
matlab/+caffe/Net.m
matlab/+caffe/Solver.m
matlab/+caffe/private/caffe_.cpp

index e6295bb..349e060 100644 (file)
@@ -68,6 +68,9 @@ classdef Net < handle
       self.layer_names = self.attributes.layer_names;
       self.blob_names = self.attributes.blob_names;
     end
+    function delete (self)
+      caffe_('delete_net', self.hNet_self);
+    end
     function layer = layers(self, layer_name)
       CHECK(ischar(layer_name), 'layer_name must be a string');
       layer = self.layer_vec(self.name2layer_index(layer_name));
index f8bdc4e..2d3c98b 100644 (file)
@@ -36,6 +36,9 @@ classdef Solver < handle
         self.test_nets(n) = caffe.Net(self.attributes.hNet_test_nets(n));
       end
     end
+    function delete (self)
+      caffe_('delete_solver', self.hSolver_self);
+    end
     function iter = iter(self)
       iter = caffe_('solver_get_iter', self.hSolver_self);
     end
index 1b1b2bf..bc04f41 100644 (file)
@@ -197,6 +197,17 @@ static void get_solver(MEX_ARGS) {
   mxFree(solver_file);
 }
 
+// Usage: caffe_('delete_solver', hSolver)
+static void delete_solver(MEX_ARGS) {
+  mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
+      "Usage: caffe_('delete_solver', hSolver)");
+  Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
+  solvers_.erase(std::remove_if(solvers_.begin(), solvers_.end(),
+      [solver] (const shared_ptr< Solver<float> > &solverPtr) {
+      return solverPtr.get() == solver;
+  }), solvers_.end());
+}
+
 // Usage: caffe_('solver_get_attr', hSolver)
 static void solver_get_attr(MEX_ARGS) {
   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -271,6 +282,17 @@ static void get_net(MEX_ARGS) {
   mxFree(phase_name);
 }
 
+// Usage: caffe_('delete_solver', hSolver)
+static void delete_net(MEX_ARGS) {
+  mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
+      "Usage: caffe_('delete_solver', hNet)");
+  Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
+  nets_.erase(std::remove_if(nets_.begin(), nets_.end(),
+      [net] (const shared_ptr< Net<float> > &netPtr) {
+      return netPtr.get() == net;
+  }), nets_.end());
+}
+
 // Usage: caffe_('net_get_attr', hNet)
 static void net_get_attr(MEX_ARGS) {
   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -522,12 +544,14 @@ struct handler_registry {
 static handler_registry handlers[] = {
   // Public API functions
   { "get_solver",         get_solver      },
+  { "delete_solver",      delete_solver   },
   { "solver_get_attr",    solver_get_attr },
   { "solver_get_iter",    solver_get_iter },
   { "solver_restore",     solver_restore  },
   { "solver_solve",       solver_solve    },
   { "solver_step",        solver_step     },
   { "get_net",            get_net         },
+  { "delete_net",         delete_net      },
   { "net_get_attr",       net_get_attr    },
   { "net_forward",        net_forward     },
   { "net_backward",       net_backward    },