From 0199d59d3a3ba69c62480ddb6680c9c5f55667b6 Mon Sep 17 00:00:00 2001
From: Jiyan Yang
Date: Tue, 27 Nov 2018 14:49:28 -0800
Subject: [PATCH] Resubmit: Set the correct engine name for position weighted
 pooling when fp16 is used for training

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13768

Reviewed By: xianjiec

Differential Revision: D12996103

fbshipit-source-id: 5ca4cda4210f68ece2b5d6eced8cf52ee91fb36f
---
 caffe2/python/layers/sparse_lookup.py | 41 +++++++++++++++++++++++++----------
 1 file changed, 30 insertions(+), 11 deletions(-)

diff --git a/caffe2/python/layers/sparse_lookup.py b/caffe2/python/layers/sparse_lookup.py
index c4de612..5524785 100644
--- a/caffe2/python/layers/sparse_lookup.py
+++ b/caffe2/python/layers/sparse_lookup.py
@@ -173,6 +173,8 @@ class SparseLookup(ModelLayer):
                 "Train version {} is not currently supported".format(trainer_version)
             )
 
+        self.trainer_version = trainer_version
+
         return default_weight_init
 
     def _gather_wrapper(self, net, version, in_indices, out):
@@ -215,11 +217,22 @@ class SparseLookup(ModelLayer):
         if version in ['fp32', 'fp16']:
             # SparseLengths* Ops will accept either fp16 or fp32 embedding
             # matrix and output fp32 pooled embedding
-            net.__getattr__(layer_name)(
-                op_input,
-                self.output_schema.field_blobs(),
-                grad_on_weights=grad_on_weights,
-            )
+            # A special case here is that we need the FP16 engine for
+            # SparseLengthsWeightedSum when FP16 embeddings are used for
+            # correct backward updates
+            if reducer == "WeightedSum" and version == "fp16":
+                net.SparseLengthsWeightedSum(
+                    op_input,
+                    self.output_schema.field_blobs(),
+                    grad_on_weights=grad_on_weights,
+                    engine='FP16',
+                )
+            else:
+                net.__getattr__(layer_name)(
+                    op_input,
+                    self.output_schema.field_blobs(),
+                    grad_on_weights=grad_on_weights,
+                )
         elif version == 'uint8rowwise':
             op_input.insert(len(op_input), self.scale_bias)
             net.__getattr__(layer_name + '8BitsRowwise')(
@@ -345,6 +358,17 @@ class SparseLookup(ModelLayer):
             raise "Only Sum, Mean, None are supported for IdScoreList input." +\
                 "Trying to create with {}".format(self.reducer)
 
+    def _add_ops(self, net, version='fp32'):
+        if _is_id_list(self.input_record):
+            self._add_ops_id_list(net, version=version)
+        elif _is_id_score_list(self.input_record):
+            self._add_ops_id_score_list(net, version=version)
+        else:
+            raise "Unsupported input type {0}".format(self.input_record)
+
+    def add_train_ops(self, net):
+        self._add_ops(net, self.trainer_version)
+
     def add_ops(self, net):
         cur_scope = get_current_scope()
         version = get_sparse_lookup_predictor_version(
@@ -357,9 +381,4 @@
                                                    'fused_uint8rowwise'}:
             version = 'fp32'
 
-        if _is_id_list(self.input_record):
-            self._add_ops_id_list(net, version=version)
-        elif _is_id_score_list(self.input_record):
-            self._add_ops_id_score_list(net, version=version)
-        else:
-            raise "Unsupported input type {0}".format(self.input_record)
+        self._add_ops(net, version)
-- 
2.7.4
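
Note (illustration, not part of the patch): below is a minimal sketch of the op-level
difference the change produces when building a Caffe2 net by hand. The blob names
(embedding_fp16, pos_weights, indices, lengths, pooled_*) are illustrative assumptions,
not names used in sparse_lookup.py.

    from caffe2.python import core

    net = core.Net("position_weighted_pooling_sketch")

    # Blob names below are illustrative placeholders.
    # With an fp16 embedding table, request the FP16 engine so the gradient of
    # SparseLengthsWeightedSum updates the half-precision weights correctly.
    net.SparseLengthsWeightedSum(
        ["embedding_fp16", "pos_weights", "indices", "lengths"],
        ["pooled_fp16"],
        grad_on_weights=1,
        engine="FP16",
    )

    # With an fp32 embedding table the default engine is kept, which matches the
    # unchanged branch in the patch.
    net.SparseLengthsWeightedSum(
        ["embedding_fp32", "pos_weights", "indices", "lengths"],
        ["pooled_fp32"],
        grad_on_weights=1,
    )

In training, the new add_train_ops entry point reaches this path through
self.trainer_version, while add_ops continues to serve the predictor versions, so the
FP16 engine is only requested when the trainer actually holds fp16 embeddings.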