#include <metal_stdlib>

using namespace metal;

// Decodes the piecewise sRGB transfer curve, per component.
// Despite the function's name, the math maps gamma-encoded values to linear
// light: values at or below 0.04045 are divided by 12.92, all others go through
// ((v + 0.055) / 1.055) ^ 2.4.
float3 RGB_TO_SRGB(float3 rgb) {
    const float3 linearSegment = rgb / 12.92;
    const float3 powerSegment = pow((rgb + 0.055) / 1.055, 2.4);
    // select(a, b, c) yields b in the lanes where c is true and a elsewhere.
    return select(powerSegment, linearSegment, rgb <= 0.04045);
}

// Applies the HLG OETF (ITU-R BT.2100) to each component of a linear colour.
//   E <= 1/12 : sqrt(3 * E)
//   E >  1/12 : a * ln(12 * E - b) + c
// Negative components are treated as 0: the OETF is only defined for
// non-negative input and sqrt() of a negative value would yield NaN.
// Components above 1.0 are not clamped here and map to values above 1.0.
float3 BT2020_TO_HLG2100(float3 rgb) {
    // HLG constants for the OETF
    const float aHLG = 0.17883277;
    const float bHLG = 0.28466892;
    const float cHLG = 0.55991073;

    const float3 scene = max(rgb, float3(0.0));

    // Both segments are evaluated for every lane; select() keeps the one that
    // applies. The log() argument can be negative in lanes below the 1/12
    // threshold, but those lanes take the sqrt() result instead.
    const float3 lowSegment = sqrt(3.0 * scene);
    const float3 highSegment = aHLG * log(12.0 * scene - bHLG) + cHLG;
    return select(lowSegment, highSegment, scene > 1.0 / 12.0);
}

// Inverse tone-map transfer curve from ITU-R BT.2446-1, method C.
// Takes an SDR luminance on the scale where 100.0 is SDR peak and returns the
// expanded luminance on the same scale: a straight line (slope 1 / k1) below
// the inflection point and an exponential segment at or above it.
float inverseTonemapTransferMethodC(float Y_SDR) {
  const float k1 = 0.83802;
  const float k2 = 15.09968;
  const float k3 = 0.74204;
  const float k4 = 78.99439;
  // SDR luminance at which the curve switches segments; the matching HDR
  // value is kneeSDR / k1 (about 69.807).
  const float kneeSDR = 58.5;

  // Linear segment.
  if (Y_SDR < kneeSDR) {
    return Y_SDR / k1;
  }

  // Exponential segment.
  return (exp((Y_SDR - k4) / k2) + k3) * kneeSDR / k1;
}

// Luminance scale factors used by adjustBrightness(): Y is multiplied by
// peak_SDR before the inverse tone-map curve and divided by peak_HDR after it.
constant float peak_HDR = 332.383; // = inverseTonemapTransferMethodC(100.0)
constant float peak_SDR = 100.0;

// Computes dst = mat * vec for a 3x3 matrix stored as mat[row][column].
// Each output element is the dot product of one matrix row with vec.
void ff_matrix_mul_3x3_vec(float dst[3],  float vec[3], float mat[3][3])
{
    for (int row = 0; row < 3; ++row) {
        // Accumulate left to right, in the same order as a single sum expression.
        float acc = vec[0] * mat[row][0];
        acc += vec[1] * mat[row][1];
        acc += vec[2] * mat[row][2];
        dst[row] = acc;
    }
}

// Two-argument maximum / minimum helpers.
// These are macros, so the selected argument is evaluated twice; do not pass
// expressions with side effects. When the comparison is false (including when
// an operand is NaN) FFMAX yields b and FFMIN yields a.
#define FFMAX(a,b) ((a) > (b) ? (a) : (b))
#define FFMIN(a,b) ((a) > (b) ? (b) : (a))

// Applies the inverse tone-map luminance curve to a linear RGB colour while
// keeping its xy chromaticity.
//
// Steps: RGB -> CIE 1931 XYZ -> Yxy; Y is scaled by peak_SDR, passed through
// inverseTonemapTransferMethodC() and divided by peak_HDR; XYZ is rebuilt from
// the new Y and the original (x, y), converted back to RGB and clamped to [0, 1].
//
// Parameters:
//   linearSRGB - linear-light RGB. Despite the name, the matrices used here are
//                the BT.2020 <-> XYZ ones, so BT.2020 primaries are expected.
//   itmBrightnessAdjustmentD65WhitePointMaxDistance,
//   itmBrightnessAdjustmentMaxBrightnessScale - accepted for interface
//                compatibility but currently have no effect (see note below).
// Returns: the adjusted linear RGB, each component clamped to [0, 1].
//
// NOTE(review): an earlier revision computed the xy distance to the D65 white
// point (0.3127, 0.3290) together with K = 40 and
// maxScale = itmBrightnessAdjustmentMaxBrightnessScale (reference values:
// maxScale = 0.8, dist = 0.2), intending to reduce brightness for colours near
// white. None of those values ever reached the output, so that dead computation
// has been removed. The scaling itself is still unimplemented - confirm whether
// it should be.
float3 adjustBrightness(float3 linearSRGB, float itmBrightnessAdjustmentD65WhitePointMaxDistance, float itmBrightnessAdjustmentMaxBrightnessScale) {
    (void)itmBrightnessAdjustmentD65WhitePointMaxDistance;
    (void)itmBrightnessAdjustmentMaxBrightnessScale;

    // Convert linearized RGB with BT.2020 primaries to CIE 1931 XYZ
    float kRGB2020toCIE1931XYZ[3][3] = {
        {0.6370, 0.1446, 0.1689},
        {0.2627, 0.6780, 0.0593},
        {0.0000, 0.0281, 1.0610},
    };

    // Convert CIE 1931 XYZ to linearized RGB with BT.2020 primaries
    float kCIE1931XYZtoRGB2020[3][3] = {
        {1.7167, -0.3557, -0.2534},
        {-0.6667, 1.6165, 0.0158},
        {0.0176, -0.0428, 0.9421},
    };

    // RGB -> XYZ
    float rgbCross[3] = { linearSRGB.r, linearSRGB.g, linearSRGB.b };
    float XYZ[3];
    ff_matrix_mul_3x3_vec(XYZ, rgbCross, kRGB2020toCIE1931XYZ);

    // XYZ -> xy chromaticity; a zero (or negative) sum maps to x = y = 0.
    const float sum = XYZ[0] + XYZ[1] + XYZ[2];
    const float x = (sum > 0) ? (XYZ[0] / sum) : 0.0;
    const float y = (sum > 0) ? (XYZ[1] / sum) : 0.0;

    // Y_HDR = inverseTonemapTransfer(Y_SDR), renormalized so that SDR peak
    // (Y = 1) maps back to 1.
    float Y_HDR = XYZ[1];
    Y_HDR *= peak_SDR;
    Y_HDR = inverseTonemapTransferMethodC(Y_HDR);
    Y_HDR /= peak_HDR;

    // Rebuild XYZ from the new luminance and the original chromaticity:
    //   X = (x / y) * Y,  Z = ((1 - x - y) / y) * Y
    float XYZ_HDR[3];
    XYZ_HDR[0] = (y > 0) ? ((x / y) * Y_HDR) : 0.0;
    XYZ_HDR[1] = Y_HDR;
    XYZ_HDR[2] = (y > 0) ? (((1 - x - y) / y) * Y_HDR) : 0.0;

    // XYZ -> RGB, clamped to the displayable [0, 1] range.
    float rgbCross_HDR[3];
    ff_matrix_mul_3x3_vec(rgbCross_HDR, XYZ_HDR, kCIE1931XYZtoRGB2020);
    for (int i = 0; i < 3; ++i) {
        rgbCross_HDR[i] = FFMIN(1.0, FFMAX(0, rgbCross_HDR[i]));
    }

    linearSRGB.r = rgbCross_HDR[0];
    linearSRGB.g = rgbCross_HDR[1];
    linearSRGB.b = rgbCross_HDR[2];

    return linearSRGB;
}

// Colour-space matrices padded to 4x4 so the fourth (alpha) component passes
// through unchanged.
// The float4x4 constructor takes column vectors, so each float4 below is one
// column; written this way the matrices are meant to be applied with the colour
// on the left (rgba * M), which is how BT709_TO_HLG_2100_INV uses the product.
// With that convention MAT_RGBA_BT709_TO_BT2020 applies the BT.709 -> XYZ step
// first and the XYZ -> BT.2020 step second.
constant float4x4 MAT_RGBA_TO_XYZ_BT709 = float4x4(float4(0.4124, 0.3576, 0.1805, 0.0), float4(0.2126, 0.7152, 0.07218, 0.0), float4(0.01933, 0.1192, 0.9504, 0.0), float4(0.0, 0.0, 0.0, 1.0));
constant float4x4 MAT_XYZ_TO_RGBA_BT2020 = float4x4(float4(1.7167, -0.3557, -0.2534, 0.0), float4(-0.6667, 1.6165, 0.01577, 0.0), float4(0.01764, -0.04277, 0.9421, 0.0), float4(0.0, 0.0, 0.0, 1.0));
constant float4x4 MAT_RGBA_BT709_TO_BT2020 = MAT_RGBA_TO_XYZ_BT709 * MAT_XYZ_TO_RGBA_BT2020;

// Converts a gamma-encoded BT.709 SDR colour to an HLG-encoded colour with
// BT.2020 primaries, expanding luminance with the BT.2446 method C inverse
// tone-map curve.
//
// Parameters:
//   sdrRGB    - gamma-encoded BT.709 RGB.
//   peakDSLNits - upper bound used for the final output clamp.
//   enableInverseTonemapBrightnessAdjustment - non-zero runs adjustBrightness()
//               on the linear BT.2020 colour before the expansion below.
//   itmBrightnessAdjustmentD65WhitePointMaxDistance,
//   itmBrightnessAdjustmentMaxBrightnessScale - forwarded to adjustBrightness().
//   texCoords - texture coordinates, used only to seed the dither noise.
// Returns: HLG signal values clamped to [0, peakDSLNits].
//
// NOTE(review): adjustBrightness() already applies the method C curve to the
// luminance, so with the adjustment enabled the curve is applied twice -
// confirm this is intended.
float3 BT709_TO_HLG_2100_INV(float3 sdrRGB, float peakDSLNits, float enableInverseTonemapBrightnessAdjustment, float itmBrightnessAdjustmentD65WhitePointMaxDistance, float itmBrightnessAdjustmentMaxBrightnessScale, float2 texCoords) {

    // Convert from gamma-corrected BT.709 to linear
    float3 linearSRGB = RGB_TO_SRGB(sdrRGB);

    // Re-express the colour in BT.2020 primaries and clamp to [0, 1].
    float4 sdrRGB_linear{1.0};
    sdrRGB_linear.rgb = linearSRGB;
    float4 hdrRGB2020_displayLinear = clamp(sdrRGB_linear * MAT_RGBA_BT709_TO_BT2020, 0.0, 1.0);

    // From here on the colour is linear BT.2020, not BT.709.
    float3 linearBT2020 = hdrRGB2020_displayLinear.rgb;

    if (enableInverseTonemapBrightnessAdjustment != 0.0) {
        linearBT2020 = adjustBrightness(linearBT2020, itmBrightnessAdjustmentD65WhitePointMaxDistance, itmBrightnessAdjustmentMaxBrightnessScale);
    }

    // Relative luminance. The data has BT.2020 primaries at this point, so the
    // BT.2020 weights (the Y row of the RGB -> XYZ matrix) are the right ones;
    // the BT.709 weights would mis-weight saturated colours.
    const float hdrLuminance = dot(float3(0.2627, 0.6780, 0.0593), linearBT2020);

    // Apply ITU-R BT.2446 Method C on the scale where 100.0 is SDR peak.
    const float Y_SDR_norm = hdrLuminance * 100.0;
    float Y_HDR_norm = inverseTonemapTransferMethodC(Y_SDR_norm);

    // Cap the expanded luminance at 2.8x SDR peak instead of the curve's full
    // ~3.32x (332 / 100) to keep whites from becoming overly bright.
    const float maxExpansionRatio = 2.8;
    Y_HDR_norm = min(Y_HDR_norm, 100.0 * maxExpansionRatio);

    // Per-pixel gain that takes the current luminance to the expanded one.
    float expansionFactor = (hdrLuminance > 0) ? (Y_HDR_norm / 100.0) / hdrLuminance : 1.0;

    // Roll off very bright regions: gain above 2x grows at half the rate.
    if (expansionFactor > 2.0) {
        expansionFactor = 2.0 + (expansionFactor - 2.0) * 0.5;
    }

    const float3 hdrRGB = linearBT2020 * expansionFactor;

    // Convert to HLG BT.2100
    float3 hlgRGB = BT2020_TO_HLG2100(hdrRGB);

    // Dither against banding with pseudo-random noise of +/- 0.5 / 255 derived
    // from the texture coordinates. The noise is added to the encoded signal,
    // not to linear light: added before the OETF it could push dark pixels
    // negative (sqrt() of a negative value is NaN) and the steep sqrt segment
    // would amplify it to several percent of full scale near black.
    const float2 screenPos = texCoords * 1024.0;
    float noise = fract(sin(dot(screenPos, float2(12.9898, 78.233))) * 43758.5453);
    noise = (noise - 0.5) * (1.0/255.0);
    hlgRGB += float3(noise);

    // Clamp to ensure valid output range [0, peakDSLNits]
    hlgRGB = clamp(hlgRGB, float3(0.0), float3(peakDSLNits));

    return hlgRGB;
}

// Per-vertex data. The same struct is used as the element type of the vertex
// buffer read by vertexShader and as the interpolated input of the fragment
// shader, so the field order and types must match what the host code uploads.
struct Vertex {
    // Position; tagged [[position]] so it is the vertex stage's output position.
    float4 position [[position]];
    // Texture coordinates used to sample the Y and UV planes.
    float2 texCoords;
};

// Pass-through vertex stage: reads vertex `vid` from buffer 0, keeps its x/y
// position (forcing z = 0 and w = 1) and forwards its texture coordinates.
vertex Vertex vertexShader(device const Vertex *vertices [[buffer(0)]],
                           uint vid [[vertex_id]]) {
    device const Vertex &source = vertices[vid];

    Vertex result;
    result.texCoords = source.texCoords;
    result.position = float4(source.position.xy, 0, 1);
    return result;
}

// Fragment stage: samples a bi-planar YUV frame (Y in texture 0, interleaved
// UV in texture 1), converts it to RGB and runs the SDR -> HLG inverse tone map.
// The four float uniforms in buffers 0-3 are forwarded unchanged to
// BT709_TO_HLG_2100_INV. Output alpha is always 1.
fragment float4 yuv420ToRGBFragmentShader(Vertex in [[stage_in]],
                                          texture2d<float, access::sample> yTexture [[texture(0)]],
                                          texture2d<float, access::sample> uvTexture [[texture(1)]],
                                          sampler texSampler [[ sampler(0) ]],
                                          constant float &peakDSLNits  [[buffer(0)]],
                                          constant float &itmBrightnessAdjustmentD65WhitePointMaxDistance  [[buffer(1)]],
                                          constant float &itmBrightnessAdjustmentMaxBrightnessScale  [[buffer(2)]],
                                          constant float &enableInverseTonemapBrightnessAdjustment [[buffer(3)]]) {
    // YUV -> RGB matrix, listed column by column: the first three values
    // multiply Y, the next three U, the last three V.
    const float3x3 yuvToRgbMatrix = float3x3(
        1.164,  1.164,  1.164,
        0.0,   -0.213,  2.112,
        1.793, -0.533,  0.0
    );
    // Removes the 16/255 luma offset and re-centres chroma around zero.
    const float3 offset = float3(float(-16)/float(255), -0.5, -0.5);

    const float luma = yTexture.sample(texSampler, in.texCoords).x;
    const float2 chroma = uvTexture.sample(texSampler, in.texCoords).xy;
    const float3 yuv = float3(luma, chroma);

    const float3 rgb = yuvToRgbMatrix * (yuv + offset);

    // inverseTonemap
    const float3 hlgColor = BT709_TO_HLG_2100_INV(rgb,
                                                  peakDSLNits,
                                                  enableInverseTonemapBrightnessAdjustment,
                                                  itmBrightnessAdjustmentD65WhitePointMaxDistance,
                                                  itmBrightnessAdjustmentMaxBrightnessScale,
                                                  in.texCoords);

    return float4(hlgColor, 1.0);
}
