class PredictorExportMeta(collections.namedtuple(
'PredictorExportMeta',
'predict_net, parameters, inputs, outputs, shapes, name, \
- extra_init_net, net_type, num_workers, trainer_prefix')):
+ extra_init_net, global_init_net, net_type, num_workers, trainer_prefix')):
"""
Metadata to be used for serializaing a net.
num_workers specifies for net type 'dag' how many threads should run ops
trainer_prefix specifies the type of trainer.
+
+ extra_init_net gets appended to pred_init_net, useful for thread local init
+
+ global_init_net gets appended to the exported model's global_init_net, useful
+ for global init
+ on a shared across threads parameter workspace
+ (in a case of multi-threaded inference)
+
"""
def __new__(
cls,
shapes=None,
name="",
extra_init_net=None,
+ global_init_net=None,
net_type=None,
num_workers=None,
trainer_prefix=None,
assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
return super(PredictorExportMeta, cls).__new__(
cls, predict_net, parameters, inputs, outputs, shapes, name,
- extra_init_net, net_type, num_workers, trainer_prefix)
+ extra_init_net, global_init_net, net_type, num_workers, trainer_prefix)
def inputs_name(self):
return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,
net.Proto().external_input.extend([predictor_constants.PREDICTOR_DBREADER])
net.Proto().external_output.extend(predictor_export_meta.parameters)
+ if predictor_export_meta.global_init_net:
+ net.AppendNet(predictor_export_meta.global_init_net)
+
# Add the model_id in the predict_net to the global_init_net
utils.AddModelIdArg(predictor_export_meta, net.Proto())
return net.Proto()
extra_init_net = core.Net('extra_init')
extra_init_net.ConstantFill('data', 'data', value=1.0)
+
+ global_init_net = core.Net('global_init')
+ global_init_net.ConstantFill(
+ [],
+ 'global_init_blob',
+ value=1.0,
+ shape=[1, 5],
+ dtype=core.DataType.FLOAT
+ )
pem = pe.PredictorExportMeta(
predict_net=self.predictor_export_meta.predict_net,
parameters=self.predictor_export_meta.parameters,
outputs=self.predictor_export_meta.outputs,
shapes=self.predictor_export_meta.shapes,
extra_init_net=extra_init_net,
+ global_init_net=global_init_net,
net_type='dag',
)
np.testing.assert_array_equal(
workspace.FetchBlob("y"), np.zeros(shape=(1, 10)))
+ self.assertTrue("global_init_blob" not in workspace.Blobs())
# Load parameters from DB
global_init_net = pred_utils.GetNet(meta_net_def,
pc.GLOBAL_INIT_NET_TYPE)
workspace.RunNetOnce(global_init_net)
+ # make sure the extra global_init_net was actually run
+ self.assertTrue(workspace.HasBlob('global_init_blob'))
+ np.testing.assert_array_equal(
+ workspace.FetchBlob("global_init_blob"), np.ones(shape=(1, 5)))
+
# Run the net with a reshaped input and verify we are
# producing good numbers (with our custom implementation)
workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32))