Convert Metal shaders to FP16

This commit is contained in:
Cameron Gutman
2025-12-29 01:45:49 -06:00
parent 8ee82421b3
commit 93dc6d6b60
2 changed files with 44 additions and 55 deletions

View File

@@ -1,3 +1,4 @@
#include <metal_stdlib>
using namespace metal;
struct Vertex
@@ -8,10 +9,10 @@ struct Vertex
// Parameters consumed by the YUV fragment shaders through their
// `constant CscParams &cscParams [[ buffer(0) ]]` argument.
// NOTE(review): this listing carries a float-typed and a half-typed declaration
// for every member name. Going by the "Convert Metal shaders to FP16" title,
// the half-typed set is presumably the post-change layout - confirm against
// the raw diff.
struct CscParams
{
// Conversion matrix applied to the scaled, offset-corrected YUV sample.
float3 matrix[3];
// Per-channel values subtracted from the scaled YUV sample.
float3 offsets;
// Divided by the luminance plane's width/height and added to the chroma
// sampling coordinates.
float2 chromaOffset;
// Multiplier applied to the sampled YUV triple before offsets are removed.
float bitnessScaleFactor;
// Half-precision declarations of the same four members.
half3x3 matrix;
half3 offsets;
half2 chromaOffset;
half bitnessScaleFactor;
};
// Sampler shared by every fragment shader in this file: normalized
// coordinates, clamp-to-edge addressing and linear filtering.
constexpr sampler s(coord::normalized, address::clamp_to_edge, filter::linear);
@@ -21,48 +22,40 @@ vertex Vertex vs_draw(constant Vertex *vertices [[ buffer(0) ]], uint id [[ vert
return vertices[id];
}
// Fragment shader for two-plane YUV input. Luma is read from the .r channel of
// texture(0) and both chroma components from the .rg channels of texture(1);
// the result is converted with cscParams and returned with alpha fixed at 1.0.
// NOTE(review): two signature lines (float4 and half4 return types) appear
// back to back in this listing; presumably the float4 line is the pre-change
// one - confirm against the raw diff.
fragment float4 ps_draw_biplanar(Vertex v [[ stage_in ]],
fragment half4 ps_draw_biplanar(Vertex v [[ stage_in ]],
constant CscParams &cscParams [[ buffer(0) ]],
texture2d<half> luminancePlane [[ texture(0) ]],
texture2d<half> chrominancePlane [[ texture(1) ]])
{
// Divide the chroma offset by the luminance plane's dimensions so it can be
// added to the normalized sampling coordinates. This step is done in float.
float2 chromaOffset = float2(cscParams.chromaOffset) / float2(luminancePlane.get_width(),
luminancePlane.get_height());
half3 yuv = half3(luminancePlane.sample(s, v.texCoords).r,
chrominancePlane.sample(s, v.texCoords + chromaOffset).rg);
// Scale first, then subtract the per-channel offsets.
yuv *= cscParams.bitnessScaleFactor;
yuv -= cscParams.offsets;
// Vector-times-matrix treats yuv as a row vector, so component i of the
// product is dot(yuv, column i of the matrix).
return half4(yuv * cscParams.matrix, 1.0h);
}
// Fragment shader for three-plane YUV input. Luma comes from texture(0) and
// one chroma component each from texture(1) and texture(2), all read from the
// .r channel; the result is converted with cscParams and returned with alpha
// fixed at 1.0.
// NOTE(review): float-typed parameter and body lines are interleaved with the
// half-typed ones in this listing (including a two-texture `chrominancePlane`
// parameter and a `.rg` chroma sample). They look like pre-change lines left
// by the diff rendering - confirm against the raw diff.
fragment half4 ps_draw_triplanar(Vertex v [[ stage_in ]],
constant CscParams &cscParams [[ buffer(0) ]],
texture2d<float> luminancePlane [[ texture(0) ]],
texture2d<float> chrominancePlane [[ texture(1) ]])
texture2d<half> luminancePlane [[ texture(0) ]],
texture2d<half> chrominancePlaneU [[ texture(1) ]],
texture2d<half> chrominancePlaneV [[ texture(2) ]])
{
// Component-wise form of the chroma offset division.
float2 chromaOffset = float2(cscParams.chromaOffset.x / luminancePlane.get_width(),
cscParams.chromaOffset.y / luminancePlane.get_height());
float3 yuv = float3(luminancePlane.sample(s, v.texCoords).r,
chrominancePlane.sample(s, v.texCoords + chromaOffset).rg);
// Vector form of the same division; the result is kept in float and added to
// the normalized sampling coordinates of both chroma planes.
float2 chromaOffset = float2(cscParams.chromaOffset) / float2(luminancePlane.get_width(),
luminancePlane.get_height());
half3 yuv = half3(luminancePlane.sample(s, v.texCoords).r,
chrominancePlaneU.sample(s, v.texCoords + chromaOffset).r,
chrominancePlaneV.sample(s, v.texCoords + chromaOffset).r);
// Scale first, then subtract the per-channel offsets.
yuv *= cscParams.bitnessScaleFactor;
yuv -= cscParams.offsets;
// Explicit per-channel dot products against matrix[0..2].
float3 rgb;
rgb.r = dot(yuv, cscParams.matrix[0]);
rgb.g = dot(yuv, cscParams.matrix[1]);
rgb.b = dot(yuv, cscParams.matrix[2]);
return float4(rgb, 1.0f);
// Vector-times-matrix treats yuv as a row vector, so component i of the
// product is dot(yuv, column i of the matrix).
return half4(yuv * cscParams.matrix, 1.0h);
}
// Single-precision (float) three-plane YUV fragment shader. Luma comes from
// texture(0) and one chroma component each from texture(1) and texture(2), all
// read from the .r channel; the converted colour is returned with alpha 1.0.
// NOTE(review): given the "Convert Metal shaders to FP16" title this float
// definition is presumably the pre-change version - confirm against the raw
// diff.
fragment float4 ps_draw_triplanar(Vertex v [[ stage_in ]],
constant CscParams &cscParams [[ buffer(0) ]],
texture2d<float> luminancePlane [[ texture(0) ]],
texture2d<float> chrominancePlaneU [[ texture(1) ]],
texture2d<float> chrominancePlaneV [[ texture(2) ]])
{
// Divide each chroma offset component by the matching luminance plane
// dimension so it can be added to the normalized sampling coordinates.
float2 chromaOffset = float2(cscParams.chromaOffset.x / luminancePlane.get_width(),
cscParams.chromaOffset.y / luminancePlane.get_height());
float3 yuv = float3(luminancePlane.sample(s, v.texCoords).r,
chrominancePlaneU.sample(s, v.texCoords + chromaOffset).r,
chrominancePlaneV.sample(s, v.texCoords + chromaOffset).r);
// Scale first, then subtract the per-channel offsets.
yuv *= cscParams.bitnessScaleFactor;
yuv -= cscParams.offsets;
// Each output channel is the dot product of yuv with one matrix entry.
float3 rgb;
rgb.r = dot(yuv, cscParams.matrix[0]);
rgb.g = dot(yuv, cscParams.matrix[1]);
rgb.b = dot(yuv, cscParams.matrix[2]);
return float4(rgb, 1.0f);
}
// Fragment shader for RGB input: returns the sample of texture(0) at
// v.texCoords unchanged, with no colour conversion applied.
// NOTE(review): a float and a half signature/parameter pair appear back to
// back in this listing; presumably the float pair is the pre-change one -
// confirm against the raw diff.
fragment float4 ps_draw_rgb(Vertex v [[ stage_in ]],
texture2d<float> rgbTexture [[ texture(0) ]])
fragment half4 ps_draw_rgb(Vertex v [[ stage_in ]],
texture2d<half> rgbTexture [[ texture(0) ]])
{
return rgbTexture.sample(s, v.texCoords);
}

View File

@@ -24,15 +24,15 @@ extern "C" {
// Host-side conversion parameters: a matrix and a per-channel offset vector,
// stored as Apple simd types.
// NOTE(review): float-typed and half-typed declarations of both members are
// listed together here; the half-typed pair is presumably the post-change
// layout - confirm against the raw diff.
struct CscParams
{
// Single-precision: three float3 entries plus a float3 offset.
simd_float3 matrix[3];
simd_float3 offsets;
// Half-precision: one 3x3 half matrix plus a half3 offset.
simd_half3x3 matrix;
simd_half3 offsets;
};
// Host-side parameter block: the CscParams above followed by the chroma offset
// and the bitness scale factor.
// NOTE(review): the shader-side CscParams declares matrix, offsets,
// chromaOffset and bitnessScaleFactor in this same order, so this struct is
// presumably what gets bound at buffer(0). If so, the two layouts (size,
// alignment, padding) must stay in agreement - confirm when changing either.
struct ParamBuffer
{
CscParams cscParams;
// Single-precision declarations.
simd_float2 chromaOffset;
float bitnessScaleFactor;
// Half-precision declarations of the same two members.
simd_half2 chromaOffset;
simd_half1 bitnessScaleFactor;
};
struct Vertex
@@ -262,18 +262,14 @@ public:
getFramePremultipliedCscConstants(frame, cscMatrix, yuvOffsets);
getFrameChromaCositingOffsets(frame, chromaOffset);
// Copy the row-major CSC matrix into column-major for Metal
for (int i = 0; i < 3; i++) {
paramBuffer.cscParams.matrix[i] = simd_make_float3(cscMatrix[0 + i],
cscMatrix[3 + i],
cscMatrix[6 + i]);
}
paramBuffer.cscParams.offsets = simd_make_float3(yuvOffsets[0],
yuvOffsets[1],
yuvOffsets[2]);
paramBuffer.chromaOffset = simd_make_float2(chromaOffset[0],
chromaOffset[1]);
paramBuffer.cscParams.matrix = simd_matrix(simd_make_half3(cscMatrix[0], cscMatrix[3], cscMatrix[6]),
simd_make_half3(cscMatrix[1], cscMatrix[4], cscMatrix[7]),
simd_make_half3(cscMatrix[2], cscMatrix[5], cscMatrix[8]));
paramBuffer.cscParams.offsets = simd_make_half3(yuvOffsets[0],
yuvOffsets[1],
yuvOffsets[2]);
paramBuffer.chromaOffset = simd_make_half2(chromaOffset[0],
chromaOffset[1]);
// Set the EDR metadata for HDR10 to enable OS tonemapping
if (frame->color_trc == AVCOL_TRC_SMPTE2084 && m_MasteringDisplayColorVolume != nullptr) {