FP16MomentumSGDUpdate Op fix and enable for ROCm (#15150)
authorrohithkrn <rohith.nallamaddi@gmail.com>
Sat, 15 Dec 2018 00:31:34 +0000 (16:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 15 Dec 2018 00:33:45 +0000 (16:33 -0800)
Summary:
1. Fix a bug in FP16MomentumSGDUpdate operator
2. Enable operator for ROCm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15150

Differential Revision: D13473145

Pulled By: bddppq

fbshipit-source-id: 4c5c5f30cb9bba658e3639dbe193fa08a304d306

caffe2/python/operator_test/momentum_sgd_test.py
caffe2/sgd/fp16_momentum_sgd_op.cu

index 27dcb78..bcd0631 100644 (file)
@@ -3,6 +3,7 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
+from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.serialized_test.serialized_test_util as serial
@@ -143,7 +144,7 @@ class TestMomentumSGD(serial.SerializedTestCase):
     def test_fp16momentum_sgd(self, n, nesterov, gc, dc):
         assume(core.IsGPUDeviceType(gc.device_type))
         gpuvers = workspace.GetDeviceProperties(0)["major"]
-        if gpuvers < 6:
+        if gc.device_type == caffe2_pb2.CUDA and gpuvers < 6:
             print("No FP16 support because major version {} < 6".format(gpuvers))
             return
 
@@ -152,7 +153,6 @@ class TestMomentumSGD(serial.SerializedTestCase):
         lr = np.random.rand(1).astype(np.float32)
         param_momentum = np.random.rand(n).astype(np.float16)
         momentum = 0.9
-        nesterov = True
 
         def momentum_sgd(grad, param_momentum, lr, param=None):
             if not nesterov:
@@ -174,11 +174,13 @@ class TestMomentumSGD(serial.SerializedTestCase):
             weight_decay=0.0,
         )
 
+        threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP) else 1e-4
         self.assertReferenceChecks(
             device_option=gc,
             op=op,
             inputs=[grad, param_momentum, lr, param],
-            reference=momentum_sgd
+            reference=momentum_sgd,
+            threshold=threshold
         )
 
 
index b7ac0a7..8ec1c85 100644 (file)
@@ -22,7 +22,7 @@ __global__ void FP16MomentumSGDKernel(
     bool nesterov,
     const float wd,
     half2* param) {
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
   const float lr2 = lr[0];
   const half2 LR = __float2half2_rn(lr2);
   const half2 momentum = __float2half2_rn(mom);
@@ -87,7 +87,7 @@ __global__ void FP16MomentumSGDKernel(
             __hfma(mi_new_half, __high2half(momentum), mi_new_half),
             mom_mi_half);
         if (param) {
-          param_half[N - 1] = __hsub(param_half[i], ng_half[N - 1]);
+          param_half[N - 1] = __hsub(param_half[N - 1], ng_half[N - 1]);
         }
       }
     }
@@ -109,7 +109,7 @@ __global__ void FP16MomentumSGDFP32Kernel(
     bool nesterov,
     const float wd,
     half2* param) {
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
   const float lr2 = lr[0];
   const float LR = lr2;
   const float momentum = mom;