[FFmpeg-devel] [PATCH v2 3/4] libavcodec/vulkan: Add vulkan vc2 shaders

IndecisiveTurtle geoster3d at gmail.com
Sat Mar 8 17:33:57 EET 2025


---
 libavcodec/vulkan/vc2_dwt_haar.comp          |  66 +++++++
 libavcodec/vulkan/vc2_dwt_haar_subgroup.comp |  89 +++++++++
 libavcodec/vulkan/vc2_dwt_hor_legall.comp    |  66 +++++++
 libavcodec/vulkan/vc2_dwt_upload.comp        |  29 +++
 libavcodec/vulkan/vc2_dwt_ver_legall.comp    |  62 +++++++
 libavcodec/vulkan/vc2_encode.comp            | 173 ++++++++++++++++++
 libavcodec/vulkan/vc2_slice_sizes.comp       | 183 +++++++++++++++++++
 7 files changed, 668 insertions(+)
 create mode 100644 libavcodec/vulkan/vc2_dwt_haar.comp
 create mode 100644 libavcodec/vulkan/vc2_dwt_haar_subgroup.comp
 create mode 100644 libavcodec/vulkan/vc2_dwt_hor_legall.comp
 create mode 100644 libavcodec/vulkan/vc2_dwt_upload.comp
 create mode 100644 libavcodec/vulkan/vc2_dwt_ver_legall.comp
 create mode 100644 libavcodec/vulkan/vc2_encode.comp
 create mode 100644 libavcodec/vulkan/vc2_slice_sizes.comp

diff --git a/libavcodec/vulkan/vc2_dwt_haar.comp b/libavcodec/vulkan/vc2_dwt_haar.comp
new file mode 100644
index 0000000000..793053b6c1
--- /dev/null
+++ b/libavcodec/vulkan/vc2_dwt_haar.comp
@@ -0,0 +1,79 @@
+#extension GL_EXT_scalar_block_layout : require
+#extension GL_EXT_buffer_reference : require
+
+/* Invocations per workgroup; indexing below assumes a 16x16 layout */
+#define LOCAL_X 256
+
+struct Plane {
+    ivec2 dim;
+    ivec2 dwt_dim;
+};
+
+layout(push_constant, scalar) uniform ComputeInfo {
+    int s;
+    int plane_idx;
+    int wavelet_depth;
+    Plane planes[3];
+};
+
+shared int local_coef[LOCAL_X];
+
+/*
+ * One Haar lifting step between this invocation and the invocation whose
+ * local index differs by 'dist'. The lower index of the pair receives the
+ * low-pass value and the higher one the high-pass value; both are scaled
+ * by (1 << shift). Contains barriers, so every invocation of the workgroup
+ * must call it; only 'active' invocations publish and update their value.
+ */
+int haar_step(int value, bool active, uint dist, int shift)
+{
+    uint self_id = gl_LocalInvocationIndex;
+    uint other_id = self_id ^ dist;
+
+    if (active)
+        local_coef[self_id] = value;
+    barrier();
+
+    if (active) {
+        int other = local_coef[other_id];
+        bool is_low = self_id < other_id;
+        int a = is_low ? value : other;
+        int b = is_low ? other : value;
+        int dst_b = (b - a) * (1 << shift);
+        int dst_a = a * (1 << shift) + ((dst_b + 1) >> 1);
+        value = is_low ? dst_a : dst_b;
+    }
+    /* All partners must have read local_coef before anyone overwrites it */
+    barrier();
+
+    return value;
+}
+
+void main()
+{
+    ivec2 coord = ivec2(gl_GlobalInvocationID.xy);
+    int value = imageLoad(planes0[plane_idx], coord).x;
+
+    /*
+     * Perform Haar wavelet on the 16x16 local workgroup with shared memory.
+     * No invocation leaves the loop early and its bound is uniform, so every
+     * barrier() in haar_step() is reached in uniform control flow.
+     * NOTE(review): partners are found by flipping bits of the local index,
+     * so this only covers wavelet_depth <= 4 — confirm on the host side.
+     */
+    for (int i = 0; i < wavelet_depth; i++)
+    {
+        /* Only every (1 << i)-th coefficient in x and y takes part in level i */
+        ivec2 mask = ivec2((1 << i) - 1);
+        bool active = all(equal(coord & mask, ivec2(0)));
+
+        /* Horizontal haar wavelet: partner is +1, +2, +4 etc columns away */
+        value = haar_step(value, active, uint(1 << i), s);
+
+        /* Vertical haar wavelet: partner is +1, +2, +4 etc rows (16 invocations each) away */
+        value = haar_step(value, active, uint(1 << i) << 4, 0);
+    }
+
+    /* Store value */
+    imageStore(planes0[plane_idx], coord, ivec4(value));
+}
diff --git a/libavcodec/vulkan/vc2_dwt_haar_subgroup.comp b/libavcodec/vulkan/vc2_dwt_haar_subgroup.comp
new file mode 100644
index 0000000000..da0b9f72ca
--- /dev/null
+++ b/libavcodec/vulkan/vc2_dwt_haar_subgroup.comp
@@ -0,0 +1,77 @@
+#extension GL_EXT_scalar_block_layout : require
+#extension GL_KHR_shader_subgroup_basic : require
+#extension GL_KHR_shader_subgroup_shuffle : require
+#extension GL_EXT_buffer_reference : require
+
+#define TILE_DIM 8
+
+layout(scalar, buffer_reference, buffer_reference_align = 4) buffer DwtCoef {
+    int coef_buf[];
+};
+
+struct Plane {
+    ivec2 dim;
+    ivec2 dwt_dim;
+};
+
+layout(push_constant, scalar) uniform ComputeInfo {
+    int s;
+    int plane_idx;
+    int wavelet_depth;
+    Plane planes[3];
+    DwtCoef pbuf[3];
+};
+
+/*
+ * One Haar lifting step between this lane and the lane 'lane_dist' away.
+ * The lower lane of the pair ends up with the low-pass value, the upper lane
+ * with the high-pass value; both are scaled by (1 << shift).
+ */
+int haar_lift(int coef, uint lane_dist, int shift)
+{
+    uint partner = gl_SubgroupInvocationID ^ lane_dist;
+    int partner_coef = subgroupShuffle(coef, partner);
+    bool low_lane = gl_SubgroupInvocationID < partner;
+
+    int lo = low_lane ? coef : partner_coef;
+    int hi = low_lane ? partner_coef : coef;
+    int diff = (hi - lo) * (1 << shift);
+    int sum = lo * (1 << shift) + ((diff + 1) >> 1);
+    return low_lane ? sum : diff;
+}
+
+/* True if 'lvl' < wavelet_depth and 'pos' is a multiple of (1 << lvl) in x and y */
+bool on_level_grid(int lvl, ivec2 pos)
+{
+    if (lvl >= wavelet_depth)
+        return false;
+    ivec2 low_bits = pos & ivec2((1 << lvl) - 1);
+    return all(equal(low_bits, ivec2(0)));
+}
+
+void main() {
+    ivec2 local_coord = ivec2(gl_LocalInvocationIndex & 7, gl_LocalInvocationIndex >> 3);
+    ivec2 coord = ivec2(gl_WorkGroupID.xy) * ivec2(TILE_DIM) + local_coord;
+    ivec2 dwt_dim = planes[plane_idx].dwt_dim;
+    if (any(greaterThanEqual(coord, dwt_dim)))
+        return;
+
+    int index = coord.y * dwt_dim.x + coord.x;
+    int coef = pbuf[plane_idx].coef_buf[index];
+
+    /*
+     * Vertical partners are 8, 16, 32 lanes away, so a 64-wide subgroup
+     * covers three levels of the 8x8 tile; other sizes get two.
+     * NOTE(review): assumes gl_SubgroupInvocationID equals
+     * gl_LocalInvocationIndex modulo the subgroup size — confirm per driver.
+     */
+    int levels = gl_SubgroupSize == 64 ? 3 : 2;
+    for (int lvl = 0; lvl < levels && on_level_grid(lvl, local_coord); lvl++) {
+        uint dist = 1u << lvl;
+        coef = haar_lift(coef, dist, s);      /* horizontal */
+        coef = haar_lift(coef, dist << 3, 0); /* vertical */
+    }
+
+    // Store value
+    pbuf[plane_idx].coef_buf[index] = coef;
+}
diff --git a/libavcodec/vulkan/vc2_dwt_hor_legall.comp b/libavcodec/vulkan/vc2_dwt_hor_legall.comp
new file mode 100644
index 0000000000..3eece4ab48
--- /dev/null
+++ b/libavcodec/vulkan/vc2_dwt_hor_legall.comp
@@ -0,0 +1,65 @@
+#extension GL_EXT_scalar_block_layout : require
+#extension GL_EXT_buffer_reference : require
+
+struct Plane {
+    ivec2 dim;
+    ivec2 dwt_dim;
+};
+
+layout(push_constant, scalar) uniform ComputeInfo {
+    int s;
+    int diff_offset;
+    int level;
+    Plane planes[3];
+};
+
+/* Load column 'coord_x' of the row owned by this invocation */
+int image_load(int coord_x)
+{
+    ivec2 pos = ivec2(coord_x, int(gl_GlobalInvocationID.x));
+    return imageLoad(planes0[gl_GlobalInvocationID.z], pos).x;
+}
+
+/* Store 'value' at column 'coord_x' of the row owned by this invocation */
+void image_store(int coord_x, int value)
+{
+    ivec2 pos = ivec2(coord_x, int(gl_GlobalInvocationID.x));
+    imageStore(planes0[gl_GlobalInvocationID.z], pos, ivec4(value));
+}
+
+/*
+ * Horizontal LeGall 5/3 lifting of one row at DWT level 'level'.
+ * The invocation owns row gl_GlobalInvocationID.x of plane
+ * gl_GlobalInvocationID.z; only rows and columns that are multiples of
+ * (1 << level) take part.
+ */
+void main()
+{
+    int row = int(gl_GlobalInvocationID.x);
+    ivec2 area = planes[gl_GlobalInvocationID.z].dwt_dim;
+    int dist = 1 << level;
+    int pair = 2 * dist;
+    int end = area.x;
+
+    if (row >= area.y || (row & (dist - 1)) != 0)
+        return;
+
+    // Shift in one bit that is used for additional precision
+    for (int x = 0; x < end; x += dist)
+        image_store(x, image_load(x) << 1);
+
+    // Lifting stage 2: odd samples minus the rounded mean of their neighbours
+    for (int x = dist; x < end - dist; x += pair) {
+        int predicted = (image_load(x - dist) + image_load(x + dist) + 1) >> 1;
+        image_store(x, image_load(x) - predicted);
+    }
+    int last_odd = end - dist;
+    image_store(last_odd, image_load(last_odd) - ((2 * image_load(last_odd - dist) + 1) >> 1));
+
+    // Lifting stage 1: even samples plus a rounded quarter of their neighbours' sum
+    image_store(0, image_load(0) + ((2 * image_load(dist) + 2) >> 2));
+    for (int x = pair; x <= end - pair; x += pair) {
+        int update = (image_load(x - dist) + image_load(x + dist) + 2) >> 2;
+        image_store(x, image_load(x) + update);
+    }
+}
diff --git a/libavcodec/vulkan/vc2_dwt_upload.comp b/libavcodec/vulkan/vc2_dwt_upload.comp
new file mode 100644
index 0000000000..6de3721d3b
--- /dev/null
+++ b/libavcodec/vulkan/vc2_dwt_upload.comp
@@ -0,0 +1,38 @@
+#extension GL_EXT_scalar_block_layout : require
+#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_buffer_reference : require
+
+layout(scalar, buffer_reference, buffer_reference_align = 1) buffer PlaneBuf {
+    uint8_t data[];
+};
+
+struct Plane {
+    ivec2 dim;
+    ivec2 dwt_dim;
+};
+
+layout(push_constant, scalar) uniform ComputeInfo {
+    int s;
+    int diff_offset;
+    int level;
+    Plane planes[3];
+};
+
+/*
+ * Copy one texel of the source plane (planes1) into the coefficient image
+ * (planes0), subtracting diff_offset. Invocations outside the source area
+ * read the nearest edge texel instead.
+ */
+void main()
+{
+    ivec2 coord = ivec2(gl_GlobalInvocationID.xy);
+    uint plane_idx = gl_GlobalInvocationID.z;
+    /*
+     * dim is presumably the source plane size, so the last valid texel is
+     * dim - 1; clamping to dim itself would read out of bounds.
+     */
+    ivec2 coord_i = clamp(coord, ivec2(0), planes[plane_idx].dim - 1);
+    uint texel = imageLoad(planes1[plane_idx], coord_i).x;
+    int result = int(texel - diff_offset);
+    imageStore(planes0[plane_idx], coord, ivec4(result));
+}
diff --git a/libavcodec/vulkan/vc2_dwt_ver_legall.comp b/libavcodec/vulkan/vc2_dwt_ver_legall.comp
new file mode 100644
index 0000000000..28cfb97a7a
--- /dev/null
+++ b/libavcodec/vulkan/vc2_dwt_ver_legall.comp
@@ -0,0 +1,61 @@
+#extension GL_EXT_scalar_block_layout : require
+#extension GL_EXT_buffer_reference : require
+
+struct Plane {
+    ivec2 dim;
+    ivec2 dwt_dim;
+};
+
+layout(push_constant, scalar) uniform ComputeInfo {
+    int s;
+    int diff_offset;
+    int level;
+    Plane planes[3];
+};
+
+/* Load row 'coord_y' of the column owned by this invocation */
+int image_load(int coord_y)
+{
+    ivec2 pos = ivec2(int(gl_GlobalInvocationID.x), coord_y);
+    return imageLoad(planes0[gl_GlobalInvocationID.z], pos).x;
+}
+
+/* Store 'value' at row 'coord_y' of the column owned by this invocation */
+void image_store(int coord_y, int value)
+{
+    ivec2 pos = ivec2(int(gl_GlobalInvocationID.x), coord_y);
+    imageStore(planes0[gl_GlobalInvocationID.z], pos, ivec4(value));
+}
+
+/*
+ * Vertical LeGall 5/3 lifting of one column at DWT level 'level'.
+ * The invocation owns column gl_GlobalInvocationID.x of plane
+ * gl_GlobalInvocationID.z; only rows and columns that are multiples of
+ * (1 << level) take part. No precision bit is shifted in here.
+ */
+void main()
+{
+    int col = int(gl_GlobalInvocationID.x);
+    ivec2 area = planes[gl_GlobalInvocationID.z].dwt_dim;
+    int dist = 1 << level;
+    int pair = 2 * dist;
+    int end = area.y;
+
+    if (col >= area.x || (col & (dist - 1)) != 0)
+        return;
+
+    // Lifting stage 2: odd samples minus the rounded mean of their neighbours
+    for (int y = dist; y < end - pair; y += pair) {
+        int predicted = (image_load(y - dist) + image_load(y + dist) + 1) >> 1;
+        image_store(y, image_load(y) - predicted);
+    }
+    int last_odd = end - dist;
+    image_store(last_odd, image_load(last_odd) - ((2 * image_load(last_odd - dist) + 1) >> 1));
+
+    // Lifting stage 1: even samples plus a rounded quarter of their neighbours' sum
+    image_store(0, image_load(0) + ((2 * image_load(dist) + 2) >> 2));
+    for (int y = pair; y <= end - pair; y += pair) {
+        int update = (image_load(y + dist) + image_load(y - dist) + 2) >> 2;
+        image_store(y, image_load(y) + update);
+    }
+}
diff --git a/libavcodec/vulkan/vc2_encode.comp b/libavcodec/vulkan/vc2_encode.comp
new file mode 100644
index 0000000000..da64d9c6d8
--- /dev/null
+++ b/libavcodec/vulkan/vc2_encode.comp
@@ -0,0 +1,201 @@
+#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_scalar_block_layout : require
+#extension GL_EXT_buffer_reference : require
+
+#define MAX_DWT_LEVELS (5)
+
+struct SliceArgs {
+    int quant_idx;
+    int bytes;
+    int pb_start;
+    int pad;
+};
+
+struct Plane {
+    ivec2 dim;
+    ivec2 dwt_dim;
+};
+
+layout(std430, buffer_reference, buffer_reference_align = 16) buffer SliceArgBuf {
+    SliceArgs args[];
+};
+layout(scalar, buffer_reference, buffer_reference_align = 1) buffer BitBuf {
+    uint data[];
+};
+layout(scalar, buffer_reference, buffer_reference_align = 4) buffer QuantLuts {
+    int quant[5][4];
+    int ff_dirac_qscale_tab[116];
+};
+
+layout(push_constant, scalar) uniform ComputeInfo {
+    BitBuf bytestream;
+    QuantLuts luts;
+    SliceArgBuf slice;
+    ivec2 num_slices;
+    Plane planes[3];
+    int wavelet_depth;
+    int size_scaler;
+    int prefix_bytes;
+};
+
+/*
+ * Interleave the low 'count' bits of 'v' (most significant first) as
+ * "0 b" pairs, giving a 2 * count bit pattern.
+ */
+uint interleave_bits(uint v, int count)
+{
+    uint pattern = 0;
+    for (int i = count - 1; i >= 0; i--)
+        pattern = (pattern << 2) | ((v >> i) & 1u);
+    return pattern;
+}
+
+/*
+ * Write 'val' as an interleaved exp-Golomb code of 2 * bits + 1 bits, where
+ * bits = floor(log2(val + 1)). Long codes are split over two put_bits()
+ * calls so no intermediate value needs more than 32 bits.
+ */
+void put_vc2_ue_uint(inout PutBitContext pb, uint val)
+{
+    if (val == 0)
+    {
+        put_bits(pb, 1, 1);
+        return;
+    }
+    val++;
+
+    /* Number of bits below the leading one (0 if val + 1 wrapped) */
+    int bits = max(findMSB(val), 0);
+
+    /* Pairs above bit 15 go first, leaving a tail of at most 31 bits */
+    if (bits > 15) {
+        int head = bits - 15;
+        put_bits(pb, 2 * head, int(interleave_bits(val >> 15, head)));
+        bits = 15;
+    }
+
+    put_bits(pb, 2 * bits + 1, int((interleave_bits(val, bits) << 1) | 1u));
+}
+
+int quants[MAX_DWT_LEVELS][4];
+
+/* Position in the coefficient image of coordinate 'index' in subband (lvl, h) */
+int subband_coord(int index, int h, int lvl)
+{
+    int coord = index;
+    coord <<= 1;
+    coord |= h;
+    coord <<= (wavelet_depth-lvl-1);
+    return coord;
+}
+
+void main()
+{
+    int slice_index = int(gl_GlobalInvocationID.x);
+    int max_index = num_slices.x * num_slices.y;
+    if (slice_index >= max_index)
+        return;
+
+    /*
+     * Step 2. Quantize and encode.
+     * pb_start is relative to this workgroup; add the size of each preceding
+     * workgroup, i.e. pb_start + bytes of its last slice.
+     * NOTE(review): assumes the slice size pass used the same workgroup size.
+     */
+    int pb_start = slice.args[slice_index].pb_start;
+    int workgroup_x = int(gl_WorkGroupSize.x);
+    for (int i = 0, index = workgroup_x - 1; i < gl_WorkGroupID.x; i++) {
+        pb_start += slice.args[index].pb_start + slice.args[index].bytes;
+        index += workgroup_x;
+    }
+    ivec2 slice_coord = ivec2(slice_index % num_slices.x, slice_index / num_slices.x);
+    int slice_bytes_max = slice.args[slice_index].bytes;
+    int quant_index = slice.args[slice_index].quant_idx;
+
+    PutBitContext pb;
+    init_put_bits(pb, OFFBUF(u8buf, bytestream, pb_start), slice_bytes_max);
+
+    for (int level = 0; level < wavelet_depth; level++)
+        for (int orientation = int(level > 0); orientation < 4; orientation++)
+            quants[level][orientation] = max(quant_index - luts.quant[level][orientation], 0);
+
+    /*
+     * The slice size pass budgets prefix_bytes ahead of the quant index, so
+     * leave room for them here as well.
+     * NOTE(review): their content is not written; confirm it is don't-care.
+     */
+    if (prefix_bytes > 0)
+        skip_put_bytes(pb, prefix_bytes);
+
+    /* Write quant index for this slice */
+    put_bits(pb, 8, quant_index);
+
+    /* Luma + 2 Chroma planes */
+    for (int p = 0; p < 3; p++)
+    {
+        int pad_s, pad_c;
+        int bytes_start = put_bytes_count(pb);
+
+        /* Save current location and write a zero length byte for now */
+        uint64_t write_ptr_start = pb.buf;
+        int bit_left_start = pb.bit_left;
+        put_bits(pb, 8, 0);
+
+        for (int level = 0; level < wavelet_depth; level++)
+        {
+            ivec2 band_size = planes[p].dwt_dim >> (wavelet_depth - level);
+            for (int o = int(level > 0); o < 4; o++)
+            {
+                /* Encode subband */
+                int left = band_size.x * (slice_coord.x) / num_slices.x;
+                int right = band_size.x * (slice_coord.x+1) / num_slices.x;
+                int top = band_size.y * (slice_coord.y) / num_slices.y;
+                int bottom = band_size.y * (slice_coord.y+1) / num_slices.y;
+
+                const int q_idx = quants[level][o];
+                const int qfactor = luts.ff_dirac_qscale_tab[q_idx];
+
+                const int yh = o >> 1;
+                const int xh = o & 1;
+
+                for (int y = top; y < bottom; y++)
+                {
+                    /* The row position does not change across the inner loop */
+                    int sy = subband_coord(y, yh, level);
+                    for (int x = left; x < right; x++)
+                    {
+                        int sx = subband_coord(x, xh, level);
+                        int coef = imageLoad(planes0[p], ivec2(sx, sy)).x;
+                        uint c_abs = uint(abs(coef));
+                        c_abs = (c_abs << 2) / qfactor;
+                        put_vc2_ue_uint(pb, c_abs);
+                        if (c_abs != 0)
+                            put_bits(pb, 1, int(coef < 0));
+                    }
+                }
+            }
+        }
+        flush_put_bits(pb);
+        int bytes_len = put_bytes_count(pb) - bytes_start - 1;
+        if (p == 2)
+        {
+            /* The last plane is padded up to the full slice budget */
+            int len_diff = slice_bytes_max - put_bytes_count(pb);
+            pad_s = align((bytes_len + len_diff), size_scaler)/size_scaler;
+            pad_c = (pad_s*size_scaler) - bytes_len;
+        }
+        else
+        {
+            pad_s = align(bytes_len, size_scaler)/size_scaler;
+            pad_c = (pad_s*size_scaler) - bytes_len;
+        }
+        /* Patch the length byte reserved at the start of this plane */
+        uint64_t start_ptr = write_ptr_start + ((BUF_BITS - bit_left_start) >> 3);
+        u8buf(start_ptr).v = uint8_t(pad_s);
+        /*
+         * vc2-reference uses padding that decodes to '0' coeffs.
+         * NOTE(review): the padding is only skipped here, not filled in.
+         */
+        skip_put_bytes(pb, pad_c);
+    }
+}
diff --git a/libavcodec/vulkan/vc2_slice_sizes.comp b/libavcodec/vulkan/vc2_slice_sizes.comp
new file mode 100644
index 0000000000..9c048f3664
--- /dev/null
+++ b/libavcodec/vulkan/vc2_slice_sizes.comp
@@ -0,0 +1,201 @@
+#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_scalar_block_layout : require
+#extension GL_EXT_buffer_reference : require
+
+#define DIRAC_MAX_QUANT_INDEX 116
+#define MAX_DWT_LEVELS 5
+
+struct SliceArgs {
+    int quant_idx;
+    int bytes;
+    int pb_start;
+    int pad;
+};
+
+struct Plane {
+    ivec2 dim;
+    ivec2 dwt_dim;
+};
+
+layout(std430, buffer_reference) buffer SliceArgBuf {
+    SliceArgs args[];
+};
+layout(scalar, buffer_reference, buffer_reference_align = 4) buffer QuantLuts {
+    int quant[5][4];
+    int ff_dirac_qscale_tab[116];
+};
+
+layout(push_constant, scalar) uniform ComputeInfo {
+    QuantLuts luts;
+    SliceArgBuf slice;
+    ivec2 num_slices;
+    Plane planes[3];
+    int wavelet_depth;
+    int size_scaler;
+    int prefix_bytes;
+    int bits_ceil;
+    int bits_floor;
+};
+
+/*
+ * Length in bits of the interleaved exp-Golomb code for 'val':
+ * 2 * floor(log2(val + 1)) + 1. The max() turns a wrapped val + 1 == 0
+ * into a single-bit code.
+ */
+int count_vc2_ue_uint(uint val)
+{
+    return 2 * max(findMSB(val + 1u), 0) + 1;
+}
+
+int cache[DIRAC_MAX_QUANT_INDEX];
+int quants[MAX_DWT_LEVELS][4];
+shared int slice_sizes[gl_WorkGroupSize.x];
+
+/* Position in the coefficient image of coordinate 'index' in subband (lvl, h) */
+int subband_coord(int index, int h, int lvl)
+{
+    int coord = index;
+    coord <<= 1;
+    coord |= h;
+    coord <<= (wavelet_depth-lvl-1);
+    return coord;
+}
+
+/*
+ * Bit cost of this invocation's slice at 'quant_index', memoised in cache[].
+ * 0 marks an empty entry: a computed cost always includes the quant index byte.
+ */
+int count_hq_slice(int quant_index)
+{
+    if (cache[quant_index] != 0)
+        return cache[quant_index];
+
+    int bits = 8*prefix_bytes;
+    bits += 8; /* quant_idx */
+
+    for (int level = 0; level < wavelet_depth; level++)
+        for (int orientation = int(level > 0); orientation < 4; orientation++)
+            quants[level][orientation] = max(quant_index - luts.quant[level][orientation], 0);
+
+    int slice_index = int(gl_GlobalInvocationID.x);
+    ivec2 slice_coord = ivec2(slice_index % num_slices.x, slice_index / num_slices.x);
+    for (int p = 0; p < 3; p++)
+    {
+        int bytes_start = bits >> 3;
+        bits += 8;
+
+        for (int level = 0; level < wavelet_depth; level++)
+        {
+            ivec2 band_dim = planes[p].dwt_dim >> (wavelet_depth - level);
+            for (int o = int(level > 0); o < 4; o++)
+            {
+                const int left = band_dim.x * slice_coord.x / num_slices.x;
+                const int right = band_dim.x * (slice_coord.x+1) / num_slices.x;
+                const int top = band_dim.y * slice_coord.y / num_slices.y;
+                const int bottom = band_dim.y * (slice_coord.y+1) / num_slices.y;
+
+                const int q_idx = quants[level][o];
+                const int qfactor = luts.ff_dirac_qscale_tab[q_idx];
+
+                const int yh = o >> 1;
+                const int xh = o & 1;
+
+                for (int y = top; y < bottom; y++)
+                {
+                    /* The row position does not change across the inner loop */
+                    int sy = subband_coord(y, yh, level);
+                    for (int x = left; x < right; x++)
+                    {
+                        int sx = subband_coord(x, xh, level);
+                        int coef = imageLoad(planes0[p], ivec2(sx, sy)).x;
+                        uint c_abs = uint(abs(coef));
+                        c_abs = (c_abs << 2) / qfactor;
+                        bits += count_vc2_ue_uint(c_abs);
+                        bits += int(c_abs > 0);
+                    }
+                }
+            }
+        }
+        /* Byte-align, then account for the padding to a size_scaler multiple */
+        bits = align(bits, 8);
+        int bytes_len = (bits >> 3) - bytes_start - 1;
+        int pad_s = align(bytes_len, size_scaler) / size_scaler;
+        int pad_c = (pad_s * size_scaler) - bytes_len;
+        bits += pad_c * 8;
+    }
+
+    cache[quant_index] = bits;
+    return bits;
+}
+
+int ssize_round(int b)
+{
+    return align(b, size_scaler) + 4 + prefix_bytes;
+}
+
+/*
+ * Search for a quant index whose slice cost lies within
+ * [bits_floor, bits_ceil], starting from the quant_idx already stored for
+ * the slice. Writes the chosen index and byte size back to slice.args and
+ * returns the byte size.
+ */
+int rate_control(int slice_index)
+{
+    for (int i = 0; i < DIRAC_MAX_QUANT_INDEX; i++)
+        cache[i] = 0;
+
+    const int q_ceil = DIRAC_MAX_QUANT_INDEX;
+    const int top = bits_ceil;
+    const int bottom = bits_floor;
+    int quant_buf[2] = int[2](-1, -1);
+    int quant = slice.args[slice_index].quant_idx;
+    int step = 1;
+    int bits_last = 0;
+    int bits = count_hq_slice(quant);
+    while ((bits > top) || (bits < bottom))
+    {
+        const int signed_step = bits > top ? +step : -step;
+        quant = clamp(quant + signed_step, 0, q_ceil-1);
+        bits = count_hq_slice(quant);
+        if (quant_buf[1] == quant)
+        {
+            quant = max(quant_buf[0], quant);
+            bits = quant == quant_buf[0] ? bits_last : bits;
+            break;
+        }
+        step = clamp(step / 2, 1, (q_ceil - 1) / 2);
+        quant_buf[1] = quant_buf[0];
+        quant_buf[0] = quant;
+        bits_last = bits;
+    }
+    int bytes = ssize_round(bits >> 3);
+    slice.args[slice_index].quant_idx = clamp(quant, 0, q_ceil-1);
+    slice.args[slice_index].bytes = bytes;
+    return bytes;
+}
+
+void main()
+{
+    int slice_index = int(gl_GlobalInvocationID.x);
+    int max_index = num_slices.x * num_slices.y;
+    bool in_range = slice_index < max_index;
+
+    /*
+     * Out-of-range invocations must still reach barrier(), which requires
+     * uniform control flow, so they only contribute a size of 0.
+     */
+    int bytes = 0;
+    if (in_range)
+        bytes = rate_control(slice_index);
+    slice_sizes[gl_LocalInvocationIndex] = bytes;
+    barrier();
+
+    if (!in_range)
+        return;
+
+    /* Prefix sum for all slices in current workgroup */
+    int total_bytes = 0;
+    for (int i = 0; i < gl_LocalInvocationIndex; i++)
+        total_bytes += slice_sizes[i];
+    slice.args[slice_index].pb_start = total_bytes;
+}
-- 
2.48.1



More information about the ffmpeg-devel mailing list