From: rohithkrn Date: Sat, 15 Dec 2018 00:31:34 +0000 (-0800) Subject: FP16MomentumSGDUpdate Op fix and enable for ROCm (#15150) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2224 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=763b9954f3fbe0b2058eee9fe8055dfdc58ee615;p=platform%2Fupstream%2Fpytorch.git FP16MomentumSGDUpdate Op fix and enable for ROCm (#15150) Summary: 1. Fix a bug in FP16MomentumSGDUpdate operator 2. Enable operator for ROCm Pull Request resolved: https://github.com/pytorch/pytorch/pull/15150 Differential Revision: D13473145 Pulled By: bddppq fbshipit-source-id: 4c5c5f30cb9bba658e3639dbe193fa08a304d306 --- diff --git a/caffe2/python/operator_test/momentum_sgd_test.py b/caffe2/python/operator_test/momentum_sgd_test.py index 27dcb78..bcd0631 100644 --- a/caffe2/python/operator_test/momentum_sgd_test.py +++ b/caffe2/python/operator_test/momentum_sgd_test.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals +from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial @@ -143,7 +144,7 @@ class TestMomentumSGD(serial.SerializedTestCase): def test_fp16momentum_sgd(self, n, nesterov, gc, dc): assume(core.IsGPUDeviceType(gc.device_type)) gpuvers = workspace.GetDeviceProperties(0)["major"] - if gpuvers < 6: + if gc.device_type == caffe2_pb2.CUDA and gpuvers < 6: print("No FP16 support because major version {} < 6".format(gpuvers)) return @@ -152,7 +153,6 @@ class TestMomentumSGD(serial.SerializedTestCase): lr = np.random.rand(1).astype(np.float32) param_momentum = np.random.rand(n).astype(np.float16) momentum = 0.9 - nesterov = True def momentum_sgd(grad, param_momentum, lr, param=None): if not nesterov: @@ -174,11 +174,13 @@ class TestMomentumSGD(serial.SerializedTestCase): weight_decay=0.0, ) + threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP) else 1e-4 self.assertReferenceChecks( device_option=gc, op=op, inputs=[grad, param_momentum, lr, param], - reference=momentum_sgd + reference=momentum_sgd, + threshold=threshold ) diff --git a/caffe2/sgd/fp16_momentum_sgd_op.cu b/caffe2/sgd/fp16_momentum_sgd_op.cu index b7ac0a7..8ec1c85 100644 --- a/caffe2/sgd/fp16_momentum_sgd_op.cu +++ b/caffe2/sgd/fp16_momentum_sgd_op.cu @@ -22,7 +22,7 @@ __global__ void FP16MomentumSGDKernel( bool nesterov, const float wd, half2* param) { -#if __CUDA_ARCH__ >= 530 +#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__) const float lr2 = lr[0]; const half2 LR = __float2half2_rn(lr2); const half2 momentum = __float2half2_rn(mom); @@ -87,7 +87,7 @@ __global__ void FP16MomentumSGDKernel( __hfma(mi_new_half, __high2half(momentum), mi_new_half), mom_mi_half); if (param) { - param_half[N - 1] = __hsub(param_half[i], ng_half[N - 1]); + param_half[N - 1] = __hsub(param_half[N - 1], ng_half[N - 1]); } } } @@ -109,7 +109,7 @@ __global__ void FP16MomentumSGDFP32Kernel( bool nesterov, const float wd, half2* param) { -#if __CUDA_ARCH__ >= 530 +#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__) const float lr2 = lr[0]; const float LR = lr2; const float momentum = mom;