From f3cf6ed789683db0f2e1a5da9926b82f935d9d3b Mon Sep 17 00:00:00 2001 From: Weiyi Zheng Date: Fri, 22 Mar 2019 00:08:50 -0700 Subject: [PATCH] add fbgemm fp16 (fbfcpacked) support, add global_init_net in predictor_export_meta (#18257) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18257 support adding op in global_init_net. because pred_init_net is per thread, and just doesn't cut it. Reviewed By: jspark1105 Differential Revision: D14552695 fbshipit-source-id: 53dd44c84ad019019ab9f35fc04d076b7f941ddc --- caffe2/python/predictor/predictor_exporter.py | 15 +++++++++++++-- caffe2/python/predictor/predictor_exporter_test.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/caffe2/python/predictor/predictor_exporter.py b/caffe2/python/predictor/predictor_exporter.py index b4a1635..4cf5ded 100644 --- a/caffe2/python/predictor/predictor_exporter.py +++ b/caffe2/python/predictor/predictor_exporter.py @@ -36,7 +36,7 @@ def get_predictor_exporter_helper(submodelNetName): class PredictorExportMeta(collections.namedtuple( 'PredictorExportMeta', 'predict_net, parameters, inputs, outputs, shapes, name, \ - extra_init_net, net_type, num_workers, trainer_prefix')): + extra_init_net, global_init_net, net_type, num_workers, trainer_prefix')): """ Metadata to be used for serializaing a net. @@ -52,6 +52,13 @@ class PredictorExportMeta(collections.namedtuple( num_workers specifies for net type 'dag' how many threads should run ops trainer_prefix specifies the type of trainer. + + extra_init_net gets appended to pred_init_net, useful for thread local init + + global_init_net gets appended to global_init_net, useful for global init + on a shared across threads parameter workspace + (in a case of multi-threaded inference) + """ def __new__( cls, @@ -62,6 +69,7 @@ class PredictorExportMeta(collections.namedtuple( shapes=None, name="", extra_init_net=None, + global_init_net=None, net_type=None, num_workers=None, trainer_prefix=None, @@ -85,7 +93,7 @@ class PredictorExportMeta(collections.namedtuple( assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef)) return super(PredictorExportMeta, cls).__new__( cls, predict_net, parameters, inputs, outputs, shapes, name, - extra_init_net, net_type, num_workers, trainer_prefix) + extra_init_net, global_init_net, net_type, num_workers, trainer_prefix) def inputs_name(self): return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE, @@ -154,6 +162,9 @@ def _global_init_net(predictor_export_meta): net.Proto().external_input.extend([predictor_constants.PREDICTOR_DBREADER]) net.Proto().external_output.extend(predictor_export_meta.parameters) + if predictor_export_meta.global_init_net: + net.AppendNet(predictor_export_meta.global_init_net) + # Add the model_id in the predict_net to the global_init_net utils.AddModelIdArg(predictor_export_meta, net.Proto()) return net.Proto() diff --git a/caffe2/python/predictor/predictor_exporter_test.py b/caffe2/python/predictor/predictor_exporter_test.py index ef11246..b6b7806 100644 --- a/caffe2/python/predictor/predictor_exporter_test.py +++ b/caffe2/python/predictor/predictor_exporter_test.py @@ -100,6 +100,15 @@ class PredictorExporterTest(unittest.TestCase): extra_init_net = core.Net('extra_init') extra_init_net.ConstantFill('data', 'data', value=1.0) + + global_init_net = core.Net('global_init') + global_init_net.ConstantFill( + [], + 'global_init_blob', + value=1.0, + shape=[1, 5], + dtype=core.DataType.FLOAT + ) pem = pe.PredictorExportMeta( predict_net=self.predictor_export_meta.predict_net, parameters=self.predictor_export_meta.parameters, @@ -107,6 +116,7 @@ class PredictorExporterTest(unittest.TestCase): outputs=self.predictor_export_meta.outputs, shapes=self.predictor_export_meta.shapes, extra_init_net=extra_init_net, + global_init_net=global_init_net, net_type='dag', ) @@ -142,11 +152,17 @@ class PredictorExporterTest(unittest.TestCase): np.testing.assert_array_equal( workspace.FetchBlob("y"), np.zeros(shape=(1, 10))) + self.assertTrue("global_init_blob" not in workspace.Blobs()) # Load parameters from DB global_init_net = pred_utils.GetNet(meta_net_def, pc.GLOBAL_INIT_NET_TYPE) workspace.RunNetOnce(global_init_net) + # make sure the extra global_init_net is running + self.assertTrue(workspace.HasBlob('global_init_blob')) + np.testing.assert_array_equal( + workspace.FetchBlob("global_init_blob"), np.ones(shape=(1, 5))) + # Run the net with a reshaped input and verify we are # producing good numbers (with our custom implementation) workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32)) -- 2.7.4