From 8709bcacfb3a06847b47bb6b47e8556db43f3a43 Mon Sep 17 00:00:00 2001 From: Matt Arsenault Date: Wed, 7 Dec 2022 22:49:27 -0500 Subject: [PATCH] clang: Add __builtin_elementwise_fma I didn't understand why the other builtins have promotion logic, or how it would apply for a ternary operation. Implicit conversions are evil to begin with, and even more so when the purpose is to get an exact IR intrinsic. This checks all the arguments have the same type. --- clang/docs/LanguageExtensions.rst | 1 + clang/include/clang/Basic/Builtins.def | 1 + clang/include/clang/Sema/Sema.h | 1 + clang/lib/CodeGen/CGBuiltin.cpp | 2 + clang/lib/Sema/SemaChecking.cpp | 52 ++++++++++++---- clang/test/CodeGen/builtins-elementwise-math.c | 78 ++++++++++++++++++++++++ clang/test/Sema/builtins-elementwise-math.c | 83 ++++++++++++++++++++++++++ 7 files changed, 207 insertions(+), 11 deletions(-) diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst index 2595b5d..c0ea8af 100644 --- a/clang/docs/LanguageExtensions.rst +++ b/clang/docs/LanguageExtensions.rst @@ -631,6 +631,7 @@ Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±in =========================================== ================================================================ ========================================= T __builtin_elementwise_abs(T x) return the absolute value of a number x; the absolute value of signed integer and floating point types the most negative integer remains the most negative integer + T __builtin_elementwise_fma(T x, T y, T z) fused multiply add, (x * y) + z. floating point types T __builtin_elementwise_ceil(T x) return the smallest integral value greater than or equal to x floating point types T __builtin_elementwise_sin(T x) return the sine of x interpreted as an angle in radians floating point types T __builtin_elementwise_cos(T x) return the cosine of x interpreted as an angle in radians floating point types diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def index 4128841..6db599a 100644 --- a/clang/include/clang/Basic/Builtins.def +++ b/clang/include/clang/Basic/Builtins.def @@ -671,6 +671,7 @@ BUILTIN(__builtin_elementwise_sin, "v.", "nct") BUILTIN(__builtin_elementwise_trunc, "v.", "nct") BUILTIN(__builtin_elementwise_canonicalize, "v.", "nct") BUILTIN(__builtin_elementwise_copysign, "v.", "nct") +BUILTIN(__builtin_elementwise_fma, "v.", "nct") BUILTIN(__builtin_elementwise_add_sat, "v.", "nct") BUILTIN(__builtin_elementwise_sub_sat, "v.", "nct") BUILTIN(__builtin_reduce_max, "v.", "nct") diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 41691ea..0c6a388 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -13531,6 +13531,7 @@ private: bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc); bool SemaBuiltinElementwiseMath(CallExpr *TheCall); + bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall); bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall); bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall); diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 52ec6e0..1535b14 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -3118,6 +3118,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, emitUnaryBuiltin(*this, E, llvm::Intrinsic::canonicalize, "elt.trunc")); case Builtin::BI__builtin_elementwise_copysign: return RValue::get(emitBinaryBuiltin(*this, E, llvm::Intrinsic::copysign)); + case Builtin::BI__builtin_elementwise_fma: + return RValue::get(emitTernaryBuiltin(*this, E, llvm::Intrinsic::fma)); case Builtin::BI__builtin_elementwise_add_sat: case Builtin::BI__builtin_elementwise_sub_sat: { Value *Op0 = EmitScalarExpr(E->getArg(0)); diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index eded606..485351f 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -2626,20 +2626,16 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID, return ExprError(); QualType ArgTy = TheCall->getArg(0)->getType(); - QualType EltTy = ArgTy; - - if (auto *VecTy = EltTy->getAs()) - EltTy = VecTy->getElementType(); - if (!EltTy->isFloatingType()) { - Diag(TheCall->getArg(0)->getBeginLoc(), - diag::err_builtin_invalid_arg_type) - << 1 << /* float ty*/ 5 << ArgTy; - + if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(), + ArgTy, 1)) + return ExprError(); + break; + } + case Builtin::BI__builtin_elementwise_fma: { + if (SemaBuiltinElementwiseTernaryMath(TheCall)) return ExprError(); - } break; } - // These builtins restrict the element type to integer // types only. case Builtin::BI__builtin_elementwise_add_sat: @@ -17877,6 +17873,40 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { return false; } +bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) { + if (checkArgCount(*this, TheCall, 3)) + return true; + + Expr *Args[3]; + for (int I = 0; I < 3; ++I) { + ExprResult Converted = UsualUnaryConversions(TheCall->getArg(I)); + if (Converted.isInvalid()) + return true; + Args[I] = Converted.get(); + } + + int ArgOrdinal = 1; + for (Expr *Arg : Args) { + if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(), + ArgOrdinal++)) + return true; + } + + for (int I = 1; I < 3; ++I) { + if (Args[0]->getType().getCanonicalType() != + Args[I]->getType().getCanonicalType()) { + return Diag(Args[0]->getBeginLoc(), + diag::err_typecheck_call_different_arg_types) + << Args[0]->getType() << Args[I]->getType(); + } + + TheCall->setArg(I, Args[I]); + } + + TheCall->setType(Args[0]->getType()); + return false; +} + bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall) { if (checkArgCount(*this, TheCall, 1)) return true; diff --git a/clang/test/CodeGen/builtins-elementwise-math.c b/clang/test/CodeGen/builtins-elementwise-math.c index 1571d2b..1b48a12 100644 --- a/clang/test/CodeGen/builtins-elementwise-math.c +++ b/clang/test/CodeGen/builtins-elementwise-math.c @@ -1,5 +1,9 @@ // RUN: %clang_cc1 -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s +typedef _Float16 half; + +typedef half half2 __attribute__((ext_vector_type(2))); +typedef float float2 __attribute__((ext_vector_type(2))); typedef float float4 __attribute__((ext_vector_type(4))); typedef short int si8 __attribute__((ext_vector_type(8))); typedef unsigned int u4 __attribute__((ext_vector_type(4))); @@ -525,3 +529,77 @@ void test_builtin_elementwise_copysign(float f1, float f2, double d1, double d2, // CHECK-NEXT: call <2 x double> @llvm.copysign.v2f64(<2 x double> , <2 x double> [[V2F64]]) v2f64 = __builtin_elementwise_copysign((double2)1.0, v2f64); } + +void test_builtin_elementwise_fma(float f32, double f64, + float2 v2f32, float4 v4f32, + double2 v2f64, double3 v3f64, + const float4 c_v4f32, + half f16, half2 v2f16) { + // CHECK-LABEL: define void @test_builtin_elementwise_fma( + // CHECK: [[F32_0:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: [[F32_1:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: [[F32_2:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: call float @llvm.fma.f32(float [[F32_0]], float [[F32_1]], float [[F32_2]]) + float f2 = __builtin_elementwise_fma(f32, f32, f32); + + // CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]]) + double d2 = __builtin_elementwise_fma(f64, f64, f64); + + // CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]]) + float4 tmp_v4f32 = __builtin_elementwise_fma(v4f32, v4f32, v4f32); + + + // FIXME: Are we really still doing the 3 vector load workaround + // CHECK: [[V3F64_LOAD_0:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_0:%.+]] = shufflevector + // CHECK-NEXT: [[V3F64_LOAD_1:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_1:%.+]] = shufflevector + // CHECK-NEXT: [[V3F64_LOAD_2:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_2:%.+]] = shufflevector + // CHECK-NEXT: call <3 x double> @llvm.fma.v3f64(<3 x double> [[V3F64_0]], <3 x double> [[V3F64_1]], <3 x double> [[V3F64_2]]) + v3f64 = __builtin_elementwise_fma(v3f64, v3f64, v3f64); + + // CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]]) + v2f64 = __builtin_elementwise_fma(f64, f64, f64); + + // CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]]) + v4f32 = __builtin_elementwise_fma(c_v4f32, c_v4f32, c_v4f32); + + // CHECK: [[F16_0:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[F16_1:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: call half @llvm.fma.f16(half [[F16_0]], half [[F16_1]], half [[F16_2]]) + half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_2:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]]) + half2 tmp0_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, v2f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[V2F16_2_INSERT:%.+]] = insertelement + // CHECK-NEXT: [[V2F16_2:%.+]] = shufflevector <2 x half> [[V2F16_2_INSERT]], <2 x half> poison, <2 x i32> zeroinitializer + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]]) + half2 tmp1_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> ) + half2 tmp2_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)4.0); + +} diff --git a/clang/test/Sema/builtins-elementwise-math.c b/clang/test/Sema/builtins-elementwise-math.c index cb8b797..c803fce 100644 --- a/clang/test/Sema/builtins-elementwise-math.c +++ b/clang/test/Sema/builtins-elementwise-math.c @@ -4,6 +4,8 @@ typedef double double2 __attribute__((ext_vector_type(2))); typedef double double4 __attribute__((ext_vector_type(4))); typedef float float2 __attribute__((ext_vector_type(2))); typedef float float4 __attribute__((ext_vector_type(4))); + +typedef int int2 __attribute__((ext_vector_type(2))); typedef int int3 __attribute__((ext_vector_type(3))); typedef unsigned unsigned3 __attribute__((ext_vector_type(3))); typedef unsigned unsigned4 __attribute__((ext_vector_type(4))); @@ -572,3 +574,84 @@ void test_builtin_elementwise_copysign(int i, short s, double d, float f, float4 float2 tmp9 = __builtin_elementwise_copysign(v4f32, v4f32); // expected-error@-1 {{initializing 'float2' (vector of 2 'float' values) with an expression of incompatible type 'float4' (vector of 4 'float' values)}} } + +void test_builtin_elementwise_fma(int i32, int2 v2i32, short i16, + double f64, double2 v2f64, double2 v3f64, + float f32, float2 v2f32, float v3f32, float4 v4f32, + const float4 c_v4f32, + int3 v3i32, int *ptr) { + + f32 = __builtin_elementwise_fma(); + // expected-error@-1 {{too few arguments to function call, expected 3, have 0}} + + f32 = __builtin_elementwise_fma(f32); + // expected-error@-1 {{too few arguments to function call, expected 3, have 1}} + + f32 = __builtin_elementwise_fma(f32, f32); + // expected-error@-1 {{too few arguments to function call, expected 3, have 2}} + + f32 = __builtin_elementwise_fma(f32, f32, f32, f32); + // expected-error@-1 {{too many arguments to function call, expected 3, have 4}} + + f32 = __builtin_elementwise_fma(f64, f32, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f32 = __builtin_elementwise_fma(f32, f64, f32); + // expected-error@-1 {{arguments are of different types ('float' vs 'double')}} + + f32 = __builtin_elementwise_fma(f32, f32, f64); + // expected-error@-1 {{arguments are of different types ('float' vs 'double')}} + + f32 = __builtin_elementwise_fma(f32, f32, f64); + // expected-error@-1 {{arguments are of different types ('float' vs 'double')}} + + f64 = __builtin_elementwise_fma(f64, f32, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f64 = __builtin_elementwise_fma(f64, f64, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f64 = __builtin_elementwise_fma(f64, f32, f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + v2f64 = __builtin_elementwise_fma(v2f32, f64, f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}} + + v2f64 = __builtin_elementwise_fma(v2f32, v2f64, f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double2' (vector of 2 'double' values)}} + + v2f64 = __builtin_elementwise_fma(v2f32, f64, v2f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}} + + v2f64 = __builtin_elementwise_fma(f64, v2f32, v2f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'float2' (vector of 2 'float' values)}} + + v2f64 = __builtin_elementwise_fma(f64, v2f64, v2f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'double2' (vector of 2 'double' values)}} + + i32 = __builtin_elementwise_fma(i32, i32, i32); + // expected-error@-1 {{1st argument must be a floating point type (was 'int')}} + + v2i32 = __builtin_elementwise_fma(v2i32, v2i32, v2i32); + // expected-error@-1 {{1st argument must be a floating point type (was 'int2' (vector of 2 'int' values))}} + + f32 = __builtin_elementwise_fma(f32, f32, i32); + // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}} + + f32 = __builtin_elementwise_fma(f32, i32, f32); + // expected-error@-1 {{2nd argument must be a floating point type (was 'int')}} + + f32 = __builtin_elementwise_fma(f32, f32, i32); + // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}} + + + _Complex float c1, c2, c3; + c1 = __builtin_elementwise_fma(c1, f32, f32); + // expected-error@-1 {{1st argument must be a floating point type (was '_Complex float')}} + + c2 = __builtin_elementwise_fma(f32, c2, f32); + // expected-error@-1 {{2nd argument must be a floating point type (was '_Complex float')}} + + c3 = __builtin_elementwise_fma(f32, f32, c3); + // expected-error@-1 {{3rd argument must be a floating point type (was '_Complex float')}} +} -- 2.7.4