from __future__ import print_function
from __future__ import unicode_literals
+from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
def test_fp16momentum_sgd(self, n, nesterov, gc, dc):
assume(core.IsGPUDeviceType(gc.device_type))
gpuvers = workspace.GetDeviceProperties(0)["major"]
- if gpuvers < 6:
+ if gc.device_type == caffe2_pb2.CUDA and gpuvers < 6:
print("No FP16 support because major version {} < 6".format(gpuvers))
return
lr = np.random.rand(1).astype(np.float32)
param_momentum = np.random.rand(n).astype(np.float16)
momentum = 0.9
- nesterov = True
def momentum_sgd(grad, param_momentum, lr, param=None):
if not nesterov:
weight_decay=0.0,
)
+ threshold = 1e-3 if (gc.device_type == caffe2_pb2.HIP) else 1e-4
self.assertReferenceChecks(
device_option=gc,
op=op,
inputs=[grad, param_momentum, lr, param],
- reference=momentum_sgd
+ reference=momentum_sgd,
+ threshold=threshold
)
bool nesterov,
const float wd,
half2* param) {
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
const float lr2 = lr[0];
const half2 LR = __float2half2_rn(lr2);
const half2 momentum = __float2half2_rn(mom);
__hfma(mi_new_half, __high2half(momentum), mi_new_half),
mom_mi_half);
if (param) {
- param_half[N - 1] = __hsub(param_half[i], ng_half[N - 1]);
+ param_half[N - 1] = __hsub(param_half[N - 1], ng_half[N - 1]);
}
}
}
bool nesterov,
const float wd,
half2* param) {
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530 || defined(__HIP_PLATFORM_HCC__)
const float lr2 = lr[0];
const float LR = lr2;
const float momentum = mom;