diff --git a/assets/shaders/inc/meshlet_primitive_cull.h b/assets/shaders/inc/meshlet_primitive_cull.h index 5a50d9bd..f931d62e 100644 --- a/assets/shaders/inc/meshlet_primitive_cull.h +++ b/assets/shaders/inc/meshlet_primitive_cull.h @@ -165,7 +165,15 @@ void meshlet_init_shared() uint meshlet_get_meshlet_index() { - return gl_WorkGroupID.y; + return gl_WorkGroupSize.y == 8 ? gl_WorkGroupID.x : gl_WorkGroupID.y; +} + +uint meshlet_get_sublet_index(uint meshlet_index, uint sublet_index) +{ + if (gl_WorkGroupSize.y == 8) + return 8u * meshlet_index + sublet_index; + else + return 8u * meshlet_index + gl_WorkGroupSize.y * gl_WorkGroupID.x + sublet_index; } void meshlet_emit_primitive(uvec3 prim, vec4 clip_pos, vec4 viewport) diff --git a/tests/assets/shaders/meshlet_debug_plain.mesh b/tests/assets/shaders/meshlet_debug_plain.mesh index ad561b5a..c5c26c42 100644 --- a/tests/assets/shaders/meshlet_debug_plain.mesh +++ b/tests/assets/shaders/meshlet_debug_plain.mesh @@ -120,12 +120,16 @@ void main() #endif #if defined(MESHLET_PRIMITIVE_CULL_WAVE32) && MESHLET_PRIMITIVE_CULL_WAVE32 - uint linear_index = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID; + uint linear_index = gl_SubgroupInvocationID; + uint sublet_index = gl_SubgroupID; #else - uint linear_index = gl_LocalInvocationIndex; + uint linear_index = gl_LocalInvocationID.x; + uint sublet_index = gl_LocalInvocationID.y; #endif - IndirectDrawMesh meshlet = indirect_commands_mesh.draws[8u * task.meshlet_index + gl_WorkGroupID.x]; + sublet_index = meshlet_get_sublet_index(task.meshlet_index, sublet_index); + IndirectDrawMesh meshlet = indirect_commands_mesh.draws[sublet_index]; + mat4 M = transforms.data[task.node_offset]; // Transform positions.