add fbgemm fp16 (fbfcpacked) support, add global_init_net in predictor_export_meta...
authorWeiyi Zheng <wyz@fb.com>
Fri, 22 Mar 2019 07:08:50 +0000 (00:08 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 07:19:59 +0000 (00:19 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18257

support adding op in global_init_net. because pred_init_net is per thread, and just doesn't cut it.

Reviewed By: jspark1105

Differential Revision: D14552695

fbshipit-source-id: 53dd44c84ad019019ab9f35fc04d076b7f941ddc

caffe2/python/predictor/predictor_exporter.py
caffe2/python/predictor/predictor_exporter_test.py

index b4a1635..4cf5ded 100644 (file)
@@ -36,7 +36,7 @@ def get_predictor_exporter_helper(submodelNetName):
 class PredictorExportMeta(collections.namedtuple(
     'PredictorExportMeta',
         'predict_net, parameters, inputs, outputs, shapes, name, \
-        extra_init_net, net_type, num_workers, trainer_prefix')):
+        extra_init_net, global_init_net, net_type, num_workers, trainer_prefix')):
     """
     Metadata to be used for serializaing a net.
 
@@ -52,6 +52,13 @@ class PredictorExportMeta(collections.namedtuple(
     num_workers specifies for net type 'dag' how many threads should run ops
 
     trainer_prefix specifies the type of trainer.
+
+    extra_init_net gets appended to pred_init_net, useful for thread local init
+
+    global_init_net gets appended to global_init_net, useful for global init
+    on a shared across threads parameter workspace
+    (in a case of multi-threaded inference)
+
     """
     def __new__(
         cls,
@@ -62,6 +69,7 @@ class PredictorExportMeta(collections.namedtuple(
         shapes=None,
         name="",
         extra_init_net=None,
+        global_init_net=None,
         net_type=None,
         num_workers=None,
         trainer_prefix=None,
@@ -85,7 +93,7 @@ class PredictorExportMeta(collections.namedtuple(
         assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
         return super(PredictorExportMeta, cls).__new__(
             cls, predict_net, parameters, inputs, outputs, shapes, name,
-            extra_init_net, net_type, num_workers, trainer_prefix)
+            extra_init_net, global_init_net, net_type, num_workers, trainer_prefix)
 
     def inputs_name(self):
         return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,
@@ -154,6 +162,9 @@ def _global_init_net(predictor_export_meta):
     net.Proto().external_input.extend([predictor_constants.PREDICTOR_DBREADER])
     net.Proto().external_output.extend(predictor_export_meta.parameters)
 
+    if predictor_export_meta.global_init_net:
+        net.AppendNet(predictor_export_meta.global_init_net)
+
     # Add the model_id in the predict_net to the global_init_net
     utils.AddModelIdArg(predictor_export_meta, net.Proto())
     return net.Proto()
index ef11246..b6b7806 100644 (file)
@@ -100,6 +100,15 @@ class PredictorExporterTest(unittest.TestCase):
 
         extra_init_net = core.Net('extra_init')
         extra_init_net.ConstantFill('data', 'data', value=1.0)
+
+        global_init_net = core.Net('global_init')
+        global_init_net.ConstantFill(
+            [],
+            'global_init_blob',
+            value=1.0,
+            shape=[1, 5],
+            dtype=core.DataType.FLOAT
+        )
         pem = pe.PredictorExportMeta(
             predict_net=self.predictor_export_meta.predict_net,
             parameters=self.predictor_export_meta.parameters,
@@ -107,6 +116,7 @@ class PredictorExporterTest(unittest.TestCase):
             outputs=self.predictor_export_meta.outputs,
             shapes=self.predictor_export_meta.shapes,
             extra_init_net=extra_init_net,
+            global_init_net=global_init_net,
             net_type='dag',
         )
 
@@ -142,11 +152,17 @@ class PredictorExporterTest(unittest.TestCase):
         np.testing.assert_array_equal(
             workspace.FetchBlob("y"), np.zeros(shape=(1, 10)))
 
+        self.assertTrue("global_init_blob" not in workspace.Blobs())
         # Load parameters from DB
         global_init_net = pred_utils.GetNet(meta_net_def,
                                             pc.GLOBAL_INIT_NET_TYPE)
         workspace.RunNetOnce(global_init_net)
 
+        # make sure the extra global_init_net is running
+        self.assertTrue(workspace.HasBlob('global_init_blob'))
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("global_init_blob"), np.ones(shape=(1, 5)))
+
         # Run the net with a reshaped input and verify we are
         # producing good numbers (with our custom implementation)
         workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32))