clang 23.0.0git
hlsl_intrinsic_helpers.h
Go to the documentation of this file.
1//===----- hlsl_intrinsic_helpers.h - HLSL helpers intrinsics -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef _HLSL_HLSL_INTRINSIC_HELPERS_H_
10#define _HLSL_HLSL_INTRINSIC_HELPERS_H_
11
12namespace hlsl {
13namespace __detail {
14
16 // Use the same scaling factor used by FXC, and DXC for DXIL
17 // (i.e., 255.001953)
18 // https://github.com/microsoft/DirectXShaderCompiler/blob/070d0d5a2beacef9eeb51037a9b04665716fd6f3/lib/HLSL/HLOperationLower.cpp#L666C1-L697C2
19 // The DXC implementation refers to a comment on the following stackoverflow
20 // discussion to justify the scaling factor: "Built-in rounding, necessary
21 // because of truncation. 0.001953 * 256 = 0.5"
22 // https://stackoverflow.com/questions/52103720/why-does-d3dcolortoubyte4-multiplies-components-by-255-001953f
23 return V.zyxw * 255.001953f;
24}
25
// Scalar length: the magnitude of a single value is its absolute value.
template <typename T>
constexpr T length_impl(T X) {
  return abs(X);
}
27
28template <typename T, int N>
29constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
30length_vec_impl(vector<T, N> X) {
31#if (__has_builtin(__builtin_spirv_length))
32 return __builtin_spirv_length(X);
33#else
34 return sqrt(dot(X, X));
35#endif
36}
37
38template <typename T>
39constexpr vector<T, 4> dst_impl(vector<T, 4> Src0, vector<T, 4> Src1) {
40 return {1, Src0[1] * Src1[1], Src0[2], Src1[3]};
41}
42
// Scalar distance: the length of the difference between X and Y.
template <typename T>
constexpr T distance_impl(T X, T Y) {
  T Delta = X - Y;
  return length_impl(Delta);
}
46
47template <typename T, int N>
48constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
49distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
50 return length_vec_impl(X - Y);
51}
52
53constexpr float dot2add_impl(half2 a, half2 b, float c) {
54#if (__has_builtin(__builtin_dx_dot2add))
55 return __builtin_dx_dot2add(a, b, c);
56#else
57 return dot(a, b) + c;
58#endif
59}
60
61template <typename T, int N>
62constexpr enable_if_t<!is_same<double, T>::value, T>
63mul_vec_impl(vector<T, N> x, vector<T, N> y) {
64 return dot(x, y);
65}
66
67// Double vectors do not have a dot intrinsic, so expand manually.
68template <typename T, int N>
70 vector<T, N> y) {
71 T sum = x[0] * y[0];
72 [unroll] for (int i = 1; i < N; ++i) sum = mad(x[i], y[i], sum);
73 return sum;
74}
75
// Scalar reflect: I - 2 * N * I * N. For scalars the product I * N takes the
// place of the dot product used by the vector form.
template <typename T>
constexpr T reflect_impl(T I, T N) {
  T Offset = 2 * N * I * N;
  return I - Offset;
}
79
80template <typename T, int L>
81constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
82#if (__has_builtin(__builtin_spirv_reflect))
83 return __builtin_spirv_reflect(I, N);
84#else
85 return I - 2 * N * dot(I, N);
86#endif
87}
88
89template <typename T, typename U> constexpr T refract_impl(T I, T N, U Eta) {
90#if (__has_builtin(__builtin_spirv_refract))
91 return __builtin_spirv_refract(I, N, Eta);
92#endif
93 T Mul = dot(N, I);
94 T K = 1 - Eta * Eta * (1 - Mul * Mul);
95 T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
96 return select<T>(K < 0, static_cast<T>(0), Result);
97}
98
// Scalar fmod. When __DIRECTX__ is defined the result is rebuilt from the
// quotient X / Y: the fractional part of its absolute value, given the sign of
// the quotient, multiplied by Y. Otherwise the elementwise fmod builtin is
// used.
template <typename T> constexpr T fmod_impl(T X, T Y) {
#if defined(__DIRECTX__)
  T Quotient = X / Y;
  bool NonNegative = Quotient >= 0;
  T Fraction = frac(abs(Quotient));
  // Re-apply the quotient's sign before scaling back by Y.
  T SignedFraction = select<T>(NonNegative, Fraction, -Fraction);
  return SignedFraction * Y;
#else
  return __builtin_elementwise_fmod(X, Y);
#endif
}
109
110template <typename T, int N>
111constexpr vector<T, N> fmod_vec_impl(vector<T, N> X, vector<T, N> Y) {
112#if !defined(__DIRECTX__)
113 return __builtin_elementwise_fmod(X, Y);
114#else
115 vector<T, N> div = X / Y;
116 vector<bool, N> ge = div >= 0;
117 vector<T, N> frc = frac(abs(div));
118 return select<T>(ge, frc, -frc) * Y;
119#endif
120}
121
// Scalar smoothstep. The SPIR-V smoothstep builtin is used when the compiler
// provides it; otherwise X is mapped to S = saturate((X - Min) / (Max - Min))
// and the polynomial (3 - 2 * S) * S * S is returned.
template <typename T> constexpr T smoothstep_impl(T Min, T Max, T X) {
#if !(__has_builtin(__builtin_spirv_smoothstep))
  T Range = Max - Min;
  T Weight = saturate((X - Min) / Range);
  return (3 - 2 * Weight) * Weight * Weight;
#else
  return __builtin_spirv_smoothstep(Min, Max, X);
#endif
}
130
131template <typename T, int N>
132constexpr vector<T, N> smoothstep_vec_impl(vector<T, N> Min, vector<T, N> Max,
133 vector<T, N> X) {
134#if (__has_builtin(__builtin_spirv_smoothstep))
135 return __builtin_spirv_smoothstep(Min, Max, X);
136#else
137 vector<T, N> S = saturate((X - Min) / (Max - Min));
138 return (3 - 2 * S) * S * S;
139#endif
140}
141
142template <typename T> constexpr vector<T, 4> lit_impl(T NDotL, T NDotH, T M) {
143 bool DiffuseCond = NDotL < 0;
144 T Diffuse = select<T>(DiffuseCond, 0, NDotL);
145 vector<T, 4> Result = {1, Diffuse, 0, 1};
146 // clang-format off
147 bool SpecularCond = or(DiffuseCond, (NDotH < 0));
148 // clang-format on
149 T SpecularExp = exp(log(NDotH) * M);
150 Result[2] = select<T>(SpecularCond, 0, SpecularExp);
151 return Result;
152}
153
// Returns N when dot(I, Ng) is negative, and -N otherwise.
template <typename T>
constexpr T faceforward_impl(T N, T I, T Ng) {
  bool KeepNormal = dot(I, Ng) < 0;
  return select(KeepNormal, N, -N);
}
157
// Returns X scaled by two raised to the power Exp, i.e. exp2(Exp) * X.
template <typename T>
constexpr T ldexp_impl(T X, T Exp) {
  T Scale = exp2(Exp);
  return Scale * X;
}
161
// Wraps the firstbithigh builtin and returns its result as K. BitWidth is the
// bit width used to flip the index on DirectX; a builtin result of -1 is
// passed through unchanged.
template <typename K, typename T, int BitWidth>
constexpr K firstbithigh_impl(T X) {
  K Index = __builtin_hlsl_elementwise_firstbithigh(X);
#if defined(__DIRECTX__)
  // The firstbithigh DXIL ops count bits from the wrong side, so the index is
  // mirrored for DirectX. The -1 result is kept as is.
  K Mirrored = (BitWidth - 1) - Index;
  Index = select(Index == -1, Index, Mirrored);
#endif
  return Index;
}
173
// ddx helper: uses the SPIR-V ddx builtin when the compiler provides it, and
// the coarse ddx builtin otherwise.
template <typename T> constexpr T ddx_impl(T input) {
#if !(__has_builtin(__builtin_spirv_ddx))
  return __builtin_hlsl_elementwise_ddx_coarse(input);
#else
  return __builtin_spirv_ddx(input);
#endif
}
181
// ddy helper: uses the SPIR-V ddy builtin when the compiler provides it, and
// the coarse ddy builtin otherwise.
template <typename T> constexpr T ddy_impl(T input) {
#if !(__has_builtin(__builtin_spirv_ddy))
  return __builtin_hlsl_elementwise_ddy_coarse(input);
#else
  return __builtin_spirv_ddy(input);
#endif
}
189
// fwidth helper: uses the SPIR-V fwidth builtin when the compiler provides it.
// Otherwise returns abs(ddx_coarse(input)) + abs(ddy_coarse(input)).
template <typename T> constexpr T fwidth_impl(T input) {
#if !(__has_builtin(__builtin_spirv_fwidth))
  T AbsDdx = abs(ddx_coarse(input));
  T AbsDdy = abs(ddy_coarse(input));
  return AbsDdx + AbsDdy;
#else
  return __builtin_spirv_fwidth(input);
#endif
}
201
202} // namespace __detail
203} // namespace hlsl
204
205#endif // _HLSL_HLSL_INTRINSIC_HELPERS_H_
#define V(N, I)
#define X(type, name)
Definition Value.h:97
__device__ __2f16 b
__device__ __2f16 float c
#define or
Definition iso646.h:24
constexpr vector< T, N > smoothstep_vec_impl(vector< T, N > Min, vector< T, N > Max, vector< T, N > X)
constexpr T length_impl(T X)
constexpr vector< T, 4 > dst_impl(vector< T, 4 > Src0, vector< T, 4 > Src1)
constexpr T faceforward_impl(T N, T I, T Ng)
constexpr T fwidth_impl(T input)
constexpr vector< T, L > reflect_vec_impl(vector< T, L > I, vector< T, L > N)
constexpr T distance_impl(T X, T Y)
constexpr K firstbithigh_impl(T X)
constexpr T reflect_impl(T I, T N)
constexpr T ldexp_impl(T X, T Exp)
constexpr enable_if_t< is_same< float, T >::value||is_same< half, T >::value, T > distance_vec_impl(vector< T, N > X, vector< T, N > Y)
constexpr T fmod_impl(T X, T Y)
typename enable_if< B, T >::Type enable_if_t
Definition hlsl_detail.h:31
constexpr T ddx_impl(T input)
constexpr int4 d3d_color_to_ubyte4_impl(float4 V)
constexpr T smoothstep_impl(T Min, T Max, T X)
constexpr enable_if_t< is_same< float, T >::value||is_same< half, T >::value, T > length_vec_impl(vector< T, N > X)
constexpr T refract_impl(T I, T N, U Eta)
constexpr float dot2add_impl(half2 a, half2 b, float c)
constexpr vector< T, N > fmod_vec_impl(vector< T, N > X, vector< T, N > Y)
constexpr T ddy_impl(T input)
constexpr vector< T, 4 > lit_impl(T NDotL, T NDotH, T M)
constexpr enable_if_t<!is_same< double, T >::value, T > mul_vec_impl(vector< T, N > x, vector< T, N > y)
half ddx_coarse(half)
half mad(half, half, half)
T select(bool, T, T)
ternary operator.
vector< half, 2 > half2
half saturate(half)
half ddy_coarse(half)
half abs(half)
vector< float, 4 > float4
half dot(half, half)
vector< int, 4 > int4
half frac(half)
static const bool value
Definition hlsl_detail.h:17
#define sqrt(__x)
Definition tgmath.h:520
#define exp(__x)
Definition tgmath.h:431
#define exp2(__x)
Definition tgmath.h:670
#define log(__x)
Definition tgmath.h:460