diff --git a/amd/device-libs/ocml/src/rsqrtD.cl b/amd/device-libs/ocml/src/rsqrtD.cl index 0430645e542e4..69e13ad99d6ca 100644 --- a/amd/device-libs/ocml/src/rsqrtD.cl +++ b/amd/device-libs/ocml/src/rsqrtD.cl @@ -11,8 +11,7 @@ CONSTATTR double MATH_MANGLE(rsqrt)(double x) { double y0 = BUILTIN_AMDGPU_RSQRT_F64(x); - double e = MATH_MAD(-x*y0, y0, 1.0); - double y1 = MATH_MAD(y0*e, MATH_MAD(e, 0.375, 0.5), y0); - return BUILTIN_CLASS_F64(y0, CLASS_PSUB|CLASS_PNOR) ? y1 : y0; + double e = MATH_MAD(-y0 * (x == PINF_F64 || x == 0.0 ? y0 : x), y0, 1.0); + return MATH_MAD(y0*e, MATH_MAD(e, 0.375, 0.5), y0); } diff --git a/amd/device-libs/test/compile/rsqrt.cl b/amd/device-libs/test/compile/rsqrt.cl index 1a44c1ad539ab..11c21c5fde61f 100644 --- a/amd/device-libs/test/compile/rsqrt.cl +++ b/amd/device-libs/test/compile/rsqrt.cl @@ -26,13 +26,14 @@ float test_rsqrt_f32(float x) { // CHECK-LABEL: {{^}}test_rsqrt_f64: // CHECK: v_rsq_f64 +// CHECK: v_cmp_class_f64 +// CHECK: v_cndmask_b32 +// CHECK: v_cndmask_b32 // CHECK: v_mul_f64 // CHECK: v_fma_f64 // CHECK: v_mul_f64 // CHECK: v_fma_f64 // CHECK: v_fma_f64 -// CHECK: v_cndmask_b32 -// CHECK: v_cndmask_b32 double test_rsqrt_f64(double x) { return rsqrt(x); }