// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <metal_stdlib>
using namespace metal;

// Use half precision for ~2x throughput on Apple GPUs
// Video processing doesn't require 32-bit precision

// Pre-multiplied matrix: kCrosstalk * kBt709ToBt2020 (column-major for Metal)
// Combines BT.709→BT.2020 conversion with crosstalk desaturation (alpha=0.05)
// This saves one matrix multiply per pixel
// The entries are consistent with these row-major factors:
//   kBt709ToBt2020 = [0.6274 0.3293 0.0433; 0.0691 0.9195 0.0114; 0.0164 0.0880 0.8956]
//   kCrosstalk     = [1-2a a a; a 1-2a a; a a 1-2a], a = 0.05
// e.g. row 0 of the product is (0.5689, 0.3467, 0.0843) = first entry of each column below.
// If alpha changes, this matrix AND kInverseCrosstalk must be regenerated together.
constant half3x3 kBt709ToBt2020WithCrosstalk = half3x3(
    half3(0.5689h, 0.0944h, 0.0496h),  // Column 0
    half3(0.3467h, 0.8484h, 0.1416h),  // Column 1
    half3(0.0843h, 0.0572h, 0.8088h)   // Column 2
);

// RGB BT.2020 to CIE 1931 XYZ (column-major for Metal)
// Row 1 (the Y output) is (0.2627, 0.6780, 0.0593): the BT.2020 luma weights,
// the same ones used by hlgInverseOOTF and rgbToYcbcr10bit.
constant half3x3 kRGB2020toXYZ = half3x3(
    half3(0.6370h, 0.2627h, 0.0000h),  // Column 0
    half3(0.1446h, 0.6780h, 0.0281h),  // Column 1
    half3(0.1689h, 0.0593h, 1.0610h)   // Column 2
);

// CIE 1931 XYZ to RGB BT.2020 (column-major for Metal)
// Inverse of kRGB2020toXYZ to 4 decimal places; keep the two in sync.
constant half3x3 kXYZtoRGB2020 = half3x3(
    half3(1.7167h, -0.6667h, 0.0176h),   // Column 0
    half3(-0.3557h, 1.6165h, -0.0428h),  // Column 1
    half3(-0.2534h, 0.0158h, 0.9421h)    // Column 2
);

// Inverse crosstalk matrix (precomputed for alpha = 0.05, column-major)
// Inverse of [1-2a a a; a 1-2a a; a a 1-2a]:
//   diagonal     = (1 - a) / (1 - 3a) = 0.95 / 0.85 ≈  1.1176
//   off-diagonal =     -a  / (1 - 3a) = -0.05 / 0.85 ≈ -0.0588
// The matrix is symmetric, so column- and row-major layouts are identical.
constant half3x3 kInverseCrosstalk = half3x3(
    half3(1.1176h, -0.0588h, -0.0588h),   // Column 0
    half3(-0.0588h, 1.1176h, -0.0588h),   // Column 1
    half3(-0.0588h, -0.0588h, 1.1176h)    // Column 2
);

// Shared sampler for the 1-texel-high LUTs: normalized [0, 1] coordinates,
// edge clamping outside that range, and linear interpolation between entries.
// NOTE(review): if any LUT uses a 32-bit float pixel format (e.g. r32Float),
// linear filtering requires device support (MTLDevice.supports32BitFloatFiltering);
// confirm the LUT format chosen on the host side.
constexpr sampler lutSampler(filter::linear, address::clamp_to_edge, coord::normalized);

// LUT-based BT.1886 EOTF (gamma 2.4): maps each non-linear channel to linear light.
//
// rgb: per-channel values used directly as the normalized LUT coordinate.
// lut: 1-texel-high table; only the red channel of each sample is used.
// A texture2d with height = 1 (sampled at v = 0.5) stands in for texture1d
// for iOS Simulator compatibility.
// NOTE(review): u = value hits texel centers only if the LUT was generated at
// (i + 0.5) / width; if it was generated at i / (width - 1), lookups are off by
// up to half a texel. Confirm against the LUT builder.
inline half3 bt1886EOTF_LUT(half3 rgb, texture2d<float, access::sample> lut) {
    const float3 u = float3(rgb);
    const float red   = lut.sample(lutSampler, float2(u.x, 0.5f)).r;
    const float green = lut.sample(lutSampler, float2(u.y, 0.5f)).r;
    const float blue  = lut.sample(lutSampler, float2(u.z, 0.5f)).r;
    return half3(half(red), half(green), half(blue));
}

// LUT-based inverse tone-mapping transfer applied to a single luminance value.
//
// Y:   luminance used directly as the normalized LUT coordinate.
// lut: 1-texel-high table; only the red channel of the sample is used.
// A texture2d with height = 1 (sampled at v = 0.5) stands in for texture1d
// for iOS Simulator compatibility.
inline half inverseToneMapTransfer_LUT(half Y, texture2d<float, access::sample> lut) {
    const float2 uv = float2(float(Y), 0.5f);
    const float4 texel = lut.sample(lutSampler, uv);
    return half(texel.r);
}

// LUT-based HLG OETF: encodes each scene-linear channel to an HLG signal value.
//
// rgb: per-channel values used directly as the normalized LUT coordinate.
// lut: 1-texel-high table; only the red channel of each sample is used.
// A texture2d with height = 1 (sampled at v = 0.5) stands in for texture1d
// for iOS Simulator compatibility.
inline half3 hlgOETF_LUT(half3 rgb, texture2d<float, access::sample> lut) {
    half3 encoded;
    for (int channel = 0; channel < 3; ++channel) {
        const float2 uv = float2(float(rgb[channel]), 0.5f);
        encoded[channel] = half(lut.sample(lutSampler, uv).r);
    }
    return encoded;
}

// HLG inverse OOTF: converts display-linear BT.2020 RGB to scene-linear RGB.
//
// With the display gain normalized to 1 the inverse OOTF is
//   E_S = Y_D^((1 - gamma) / gamma) * E_D,
// and gamma = 1.2 for a 1000 cd/m^2 nominal-peak display, giving exponent -1/6.
// pow() runs once per pixel on the scalar luminance, which is harder to LUT
// than the per-channel transfers.
//
// rgbDisplay: display-linear BT.2020 RGB (the pipeline in this file clamps it
//             to [0, 1] before calling).
// Returns:    scene-linear RGB; components can exceed 1 (e.g. pure red gets a
//             gain of ~1.25), so callers clamp before encoding.
inline half3 hlgInverseOOTF(half3 rgbDisplay) {
    const float exponent = -1.0f / 6.0f;  // (1 - 1.2) / 1.2, for a 1000 nits display

    // Display luminance from the BT.2020 luma weights.
    const half Y_D = 0.2627h * rgbDisplay.r + 0.6780h * rgbDisplay.g + 0.0593h * rgbDisplay.b;

    // pow(0, negative) is +inf and pow(negative, y) is NaN/undefined under fast math.
    // Pass non-positive luminance through unchanged instead of poisoning the pixel.
    if (Y_D <= 0.0h) {
        return rgbDisplay;
    }

    // Evaluate the power in float. Half-subnormal luminances (< 2^-14) are normal
    // floats, and the result is at most 16 (for Y_D = 2^-24), so it fits in half.
    const half gain = half(pow(float(Y_D), exponent));
    return gain * rgbDisplay;
}

// Converts one 8-bit video-range BT.709 YCbCr sample to non-linear RGB in [0, 1].
//
// y, cb, cr: normalized samples, assumed to be code / 255 as read from unorm
//            8-bit planes (the `* 255` below undoes that normalization).
// Luma is expanded from [16, 235], chroma from [16, 240] centered on 128.
// Luma is clamped to [0, 1] before the matrix; the RGB result is clamped as well.
inline half3 ycbcrToRgb(half y, half cb, half cr) {
    const half luma = clamp((y * 255.0h - 16.0h) / 219.0h, 0.0h, 1.0h);
    const half blueDiff = (cb * 255.0h - 128.0h) / 224.0h;
    const half redDiff = (cr * 255.0h - 128.0h) / 224.0h;

    // BT.709 Y'CbCr -> R'G'B'.
    const half3 rgb = half3(luma + 1.5748h * redDiff,
                            luma - 0.1873h * blueDiff - 0.4681h * redDiff,
                            luma + 1.8556h * blueDiff);
    return clamp(rgb, 0.0h, 1.0h);
}

// Converts normalized BT.2020 R'G'B' to 10-bit video-range Y'CbCr code values.
//
// Uses the BT.2020 luma weights (non-constant-luminance form). Outputs are in
// code units and are not rounded to integers:
//   y in [64, 940], cb / cr in [64, 960] with 512 meaning neutral chroma.
inline void rgbToYcbcr10bit(half3 rgb, thread half& y, thread half& cb, thread half& cr) {
    const half luma = 0.2627h * rgb.r + 0.6780h * rgb.g + 0.0593h * rgb.b;
    const half blueDiff = (rgb.b - luma) / 1.8814h;
    const half redDiff = (rgb.r - luma) / 1.4746h;

    // Shift chroma into [0, 1], then scale everything into 10-bit video range.
    const half yCode = luma * 876.0h + 64.0h;
    const half cbCode = (blueDiff + 0.5h) * 896.0h + 64.0h;
    const half crCode = (redDiff + 0.5h) * 896.0h + 64.0h;

    y = clamp(yCode, 64.0h, 940.0h);
    cb = clamp(cbCode, 64.0h, 960.0h);
    cr = clamp(crCode, 64.0h, 960.0h);
}

// BT.2446 Method C inverse tone mapping for a single pixel (LUT-accelerated).
//
// Stages: BT.1886 EOTF -> BT.709-to-BT.2020 with crosstalk -> XYZ -> Yxy ->
// inverse tone-map Y only (chromaticity xy is preserved) -> XYZ -> BT.2020 ->
// inverse crosstalk -> HLG inverse OOTF -> HLG OETF.
//
// rgb709:            non-linear BT.709 RGB; negative components are treated as 0.
// bt1886LUT, inverseToneMapLUT, hlgOetfLUT:
//                    1-texel-high LUTs (texture2d instead of texture1d for
//                    iOS Simulator compatibility).
// Returns:           HLG-encoded BT.2020 RGB clamped to [0, 1].
inline half3 inverseTonemapBT2446MethodC(
    half3 rgb709,
    texture2d<float, access::sample> bt1886LUT,
    texture2d<float, access::sample> inverseToneMapLUT,
    texture2d<float, access::sample> hlgOetfLUT
) {
    // Linear-light BT.709, then BT.2020 primaries with crosstalk in one multiply.
    const half3 linear709 = bt1886EOTF_LUT(max(rgb709, 0.0h), bt1886LUT);
    const half3 crossed = clamp(kBt709ToBt2020WithCrosstalk * linear709, 0.0h, 1.0h);

    // XYZ -> Yxy. A zero X+Y+Z sum leaves chromaticity at (0, 0).
    const half3 xyz = kRGB2020toXYZ * crossed;
    const half lumaSdr = xyz.y;
    const half xyzSum = xyz.x + xyz.y + xyz.z;
    half chromaX = 0.0h;
    half chromaY = 0.0h;
    if (xyzSum > 0.0h) {
        chromaX = xyz.x / xyzSum;
        chromaY = xyz.y / xyzSum;
    }

    // Expand luminance only. The LUT is pre-normalized: SDR Y [0,1] -> HDR linear Y [0,1].
    const half lumaHdr = inverseToneMapTransfer_LUT(lumaSdr, inverseToneMapLUT);

    // Yxy -> XYZ with the original chromaticity; neutral fallback when y == 0.
    half3 xyzHdr = half3(lumaHdr);
    if (chromaY > 0.0h) {
        const half ratio = lumaHdr / chromaY;
        xyzHdr = half3(ratio * chromaX, lumaHdr, ratio * (1.0h - chromaX - chromaY));
    }

    // Back to BT.2020 RGB, then undo the crosstalk, clamping after each matrix.
    const half3 crossedHdr = clamp(kXYZtoRGB2020 * xyzHdr, 0.0h, 1.0h);
    const half3 rgb2020Hdr = clamp(kInverseCrosstalk * crossedHdr, 0.0h, 1.0h);

    // Display light -> scene light, then HLG encode through the OETF LUT.
    const half3 sceneLight = hlgInverseOOTF(rgb2020Hdr);
    const half3 encoded = hlgOETF_LUT(clamp(sceneLight, 0.0h, 1.0h), hlgOetfLUT);
    return clamp(encoded, 0.0h, 1.0h);
}

// Main compute kernel for inverse tonemapping
// Input: 8-bit biplanar YCbCr (NV12)
// Output: 10-bit biplanar YCbCr (P010)
//
// Each thread handles one 2x2 luma block plus the single chroma sample that
// block shares. The grid is presumably sized to the chroma plane
// (ceil(width/2) x ceil(height/2)); confirm against the dispatch code.
//
// Frames with odd width/height: the last 2x2 blocks extend past the image.
// - Luma reads are clamped to the last valid row/column (edge replication).
// - Luma writes outside the frame are skipped.
// Because a replicated pixel duplicates an in-frame neighbour, the chroma
// average equals the average over the in-frame pixels of the block.
// Output Y is assumed to have the same dimensions as input Y.
kernel void inverseTonemapKernel(
    texture2d<half, access::read> inputY [[texture(0)]],
    texture2d<half, access::read> inputCbCr [[texture(1)]],
    texture2d<half, access::write> outputY [[texture(2)]],
    texture2d<half, access::write> outputCbCr [[texture(3)]],
    texture2d<float, access::sample> bt1886LUT [[texture(4)]],
    texture2d<float, access::sample> inverseToneMapLUT [[texture(5)]],
    texture2d<float, access::sample> hlgOetfLUT [[texture(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint2 lumaSize = uint2(inputY.get_width(), inputY.get_height());
    const uint2 basePos = gid * 2;

    // Threads whose block starts outside the frame have nothing to do.
    if (basePos.x >= lumaSize.x || basePos.y >= lumaSize.y) {
        return;
    }
    const uint2 lumaMax = lumaSize - 1u;

    // Shared chroma for the 2x2 block. The coordinate is clamped in case the
    // chroma plane was allocated as floor(width/2) for odd widths.
    const uint2 chromaInMax = uint2(inputCbCr.get_width(), inputCbCr.get_height()) - 1u;
    const half2 cbcr = inputCbCr.read(min(gid, chromaInMax)).rg;

    // Accumulate chroma in float: four 10-bit codes can sum past 2048, where
    // half can no longer represent every integer.
    float cbSum = 0.0f;
    float crSum = 0.0f;

    // Visit (0,0), (1,0), (0,1), (1,1): the same order as the per-pixel code it replaces.
    for (uint i = 0; i < 4; ++i) {
        const uint2 pos = basePos + uint2(i & 1u, i >> 1u);
        const bool insideFrame = pos.x < lumaSize.x && pos.y < lumaSize.y;

        const half yIn = inputY.read(min(pos, lumaMax)).r;
        const half3 rgb = ycbcrToRgb(yIn, cbcr.r, cbcr.g);
        const half3 hlg = inverseTonemapBT2446MethodC(rgb, bt1886LUT, inverseToneMapLUT, hlgOetfLUT);

        half outY;
        half outCb;
        half outCr;
        rgbToYcbcr10bit(hlg, outY, outCb, outCr);

        cbSum += float(outCb);
        crSum += float(outCr);

        // Write Y normalized to [0, 1] for the 10-bit texture; skip pixels past the edge.
        if (insideFrame) {
            outputY.write(half4(outY / 1023.0h, 0.0h, 0.0h, 1.0h), pos);
        }
    }

    // Average chroma for the 2x2 block and write it, guarding the chroma plane bounds.
    if (gid.x < outputCbCr.get_width() && gid.y < outputCbCr.get_height()) {
        const half avgCb = half((cbSum * 0.25f) / 1023.0f);
        const half avgCr = half((crSum * 0.25f) / 1023.0f);
        outputCbCr.write(half4(avgCb, avgCr, 0.0h, 1.0h), gid);
    }
}
