diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f720a3dd..693d14a8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,3 +44,9 @@ jobs: - name: test run: cargo test + + # - name: gaussian render test + # run: cargo run --bin test_gaussian + + # - name: radix sort test + # run: cargo run --bin test_radix --features="debug_gpu" diff --git a/.vscode/bevy_gaussian_splatting.code-workspace b/.vscode/bevy_gaussian_splatting.code-workspace new file mode 100644 index 00000000..b355fe5c --- /dev/null +++ b/.vscode/bevy_gaussian_splatting.code-workspace @@ -0,0 +1,11 @@ +{ + "folders": [ + { + "path": "../" + }, + { + "path": "../../bevy" + } + ], + "settings": {} +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 8ff516f6..99e1d334 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,44 +1,90 @@ [package] name = "bevy_gaussian_splatting" description = "bevy gaussian splatting render pipeline plugin" -version = "0.4.0" +version = "0.5.0" edition = "2021" authors = ["mosure "] license = "MIT" -keywords = ["bevy", "gaussian-splatting", "render-pipeline", "ply"] -categories = ["computer-vision", "graphics", "rendering", "rendering::data-formats"] +keywords = [ + "bevy", + "gaussian-splatting", + "render-pipeline", + "ply", +] +categories = [ + "computer-vision", + "graphics", + "rendering", + "rendering::data-formats", +] homepage = "https://github.com/mosure/bevy_gaussian_splatting" repository = "https://github.com/mosure/bevy_gaussian_splatting" readme = "README.md" -exclude = [".devcontainer", ".github", "docs", "dist", "build", "assets", "credits"] +exclude = [ + ".devcontainer", + ".github", + "docs", + "dist", + "build", + "assets", + "credits", +] default-run = "viewer" [features] -default = ["io_flexbuffers", "io_ply"] +default = [ + "io_flexbuffers", + "io_ply", + "sort_radix", + "sort_rayon", + "tooling", + "viewer", +] + +debug_gpu = [] io_bincode2 = ["bincode2", "flate2"] io_flexbuffers = ["flexbuffers"] io_ply = ["ply-rs"] +sort_radix = [] +sort_rayon = [ + "rayon", + "wasm-bindgen-rayon", +] + +tooling = [ + "byte-unit", +] + +viewer = [ + "bevy-inspector-egui", + "bevy_panorbit_camera", +] + + [dependencies] -bevy-inspector-egui = "0.21" -bevy_panorbit_camera = "0.9" +bevy-inspector-egui = { version = "0.21", optional = true } +bevy_panorbit_camera = { version = "0.9", optional = true } bincode2 = { version = "2.0", optional = true } -byte-unit = "5.0.3" +byte-unit = { version = "5.0", optional = true } bytemuck = "1.14" flate2 = { version = "1.0", optional = true } flexbuffers = { version = "2.0", optional = true } ply-rs = { version = "0.1", optional = true } rand = "0.8" +rayon = { version = "1.8", optional = true } serde = "1.0" +static_assertions = "1.1" wgpu = "0.17.1" [target.'cfg(target_arch = "wasm32")'.dependencies] console_error_panic_hook = "0.1" wasm-bindgen = "0.2" +wasm-bindgen-rayon = { version = "1.0", optional = true } # TODO: use minimal bevy features @@ -50,7 +96,6 @@ default-features = true version = "0.12" default-features = true features = [ - 'bevy_ci_testing', 'debug_glam_assert', ] @@ -71,10 +116,6 @@ features = [ criterion = { version = "0.5", features = ["html_reports"] } -# remove after close: https://github.com/jakobhellermann/bevy-inspector-egui/issues/163 -[profile.dev.package."bevy-inspector-egui"] -opt-level = 1 - [profile.dev.package."*"] opt-level = 3 @@ -99,17 +140,27 @@ path = "src/lib.rs" [[bin]] name = "viewer" path = "viewer/viewer.rs" +required-features = ["viewer"] [[bin]] name = "ply_to_gcloud" path = "tools/ply_to_gcloud.rs" -required-features = ["io_ply"] +required-features = ["io_ply", "tooling"] [[bin]] name = "compare_aabb_obb" path = "tools/compare_aabb_obb.rs" +[[bin]] +name = "test_gaussian" +path = "tests/gpu/gaussian.rs" + +[[bin]] +name = "test_radix" +path = "tests/gpu/radix.rs" +required-features = ["debug_gpu", "sort_radix"] + [[bench]] name = "io" diff --git a/src/gaussian.rs b/src/gaussian.rs index fa39718d..c83b6c72 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -21,6 +21,8 @@ use serde::{ ser::SerializeTuple, }; +use crate::sort::SortMode; + const fn num_sh_coefficients(degree: usize) -> usize { if degree == 0 { @@ -214,6 +216,7 @@ pub struct GaussianCloudSettings { pub global_scale: f32, pub global_transform: GlobalTransform, pub visualize_bounding_box: bool, + pub sort_mode: SortMode, } impl Default for GaussianCloudSettings { @@ -223,6 +226,7 @@ impl Default for GaussianCloudSettings { global_scale: 2.0, global_transform: Transform::IDENTITY.into(), visualize_bounding_box: false, + sort_mode: SortMode::default(), } } } diff --git a/src/io/gcloud/mod.rs b/src/io/gcloud/mod.rs index 50c5e235..08e4e9ce 100644 --- a/src/io/gcloud/mod.rs +++ b/src/io/gcloud/mod.rs @@ -1,5 +1,17 @@ +use static_assertions::assert_cfg; + + #[cfg(feature = "io_bincode2")] pub mod bincode2; #[cfg(feature = "io_flexbuffers")] pub mod flexbuffers; + + +assert_cfg!( + any( + feature = "io_bincode2", + feature = "io_flexbuffers", + ), + "no gcloud io enabled", +); diff --git a/src/lib.rs b/src/lib.rs index 51bb1a66..6ff57acd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,9 @@ use render::RenderPipelinePlugin; pub mod gaussian; pub mod io; +pub mod morph; pub mod render; +pub mod sort; pub mod utils; diff --git a/src/render/morph/mod.rs b/src/morph/mod.rs similarity index 100% rename from src/render/morph/mod.rs rename to src/morph/mod.rs diff --git a/src/render/morph/particle.rs b/src/morph/particle.rs similarity index 100% rename from src/render/morph/particle.rs rename to src/morph/particle.rs diff --git a/src/render/morph/particle.wgsl b/src/morph/particle.wgsl similarity index 100% rename from src/render/morph/particle.wgsl rename to src/morph/particle.wgsl diff --git a/src/render/morph/wavelet.rs b/src/morph/wavelet.rs similarity index 100% rename from src/render/morph/wavelet.rs rename to src/morph/wavelet.rs diff --git a/src/render/morph/wavelet.wgsl b/src/morph/wavelet.wgsl similarity index 100% rename from src/render/morph/wavelet.wgsl rename to src/morph/wavelet.wgsl diff --git a/src/render/mod.rs b/src/render/mod.rs index 52f249f4..5f5974cc 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -62,17 +62,13 @@ use crate::{ GaussianCloudSettings, MAX_SH_COEFF_COUNT, }, - render::{ - morph::MorphPlugin, - sort::RadixSortPlugin, + morph::MorphPlugin, + sort::{ + SortPlugin, + SortedEntries, }, }; -use self::sort::GpuRadixBuffers; - -pub mod morph; -pub mod sort; - const BINDINGS_SHADER_HANDLE: Handle = Handle::weak_from_u128(675257236); const GAUSSIAN_SHADER_HANDLE: Handle = Handle::weak_from_u128(68294581); @@ -117,7 +113,7 @@ impl Plugin for RenderPipelinePlugin { app.add_plugins(UniformComponentPlugin::::default()); app.add_plugins(( MorphPlugin, - RadixSortPlugin, + SortPlugin, )); if let Ok(render_app) = app.get_sub_app_mut(RenderApp) { @@ -150,6 +146,7 @@ impl Plugin for RenderPipelinePlugin { pub struct GpuGaussianSplattingBundle { pub settings: GaussianCloudSettings, pub settings_uniform: GaussianCloudUniform, + pub sorted_entries: Handle, pub verticies: Handle, } @@ -160,7 +157,8 @@ pub struct GpuGaussianCloud { pub draw_indirect_buffer: Buffer, - pub radix_sort_buffers: GpuRadixBuffers, + #[cfg(feature = "debug_gpu")] + pub debug_gpu: GaussianCloud, } impl RenderAsset for GaussianCloud { type ExtractedAsset = GaussianCloud; @@ -183,18 +181,25 @@ impl RenderAsset for GaussianCloud { let count = gaussian_cloud.gaussians.len(); - let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor { + // let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor { + // label: Some("draw indirect buffer"), + // size: std::mem::size_of::() as u64, + // usage: BufferUsages::INDIRECT | BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::COPY_SRC, + // mapped_at_creation: false, + // }); + + let draw_indirect_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor { label: Some("draw indirect buffer"), - size: std::mem::size_of::() as u64, + contents: bytemuck::cast_ref::<[u32; 4], [u8; 16]>(&[4, count as u32, 0, 0]), usage: BufferUsages::INDIRECT | BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::COPY_SRC, - mapped_at_creation: false, }); Ok(GpuGaussianCloud { gaussian_buffer, count, draw_indirect_buffer, - radix_sort_buffers: GpuRadixBuffers::new(count, render_device), + #[cfg(feature = "debug_gpu")] + debug_gpu: gaussian_cloud, }) } } @@ -208,9 +213,11 @@ fn queue_gaussians( mut pipelines: ResMut>, pipeline_cache: Res, gaussian_clouds: Res>, + sorted_entries: Res>, gaussian_splatting_bundles: Query<( Entity, &Handle, + &Handle, &GaussianCloudSettings, )>, mut views: Query<( @@ -225,30 +232,44 @@ fn queue_gaussians( let draw_custom = transparent_3d_draw_functions.read().id::(); - for (_view, mut transparent_phase) in &mut views { - for (entity, cloud, settings) in &gaussian_splatting_bundles { - if let Some(_cloud) = gaussian_clouds.get(cloud) { - let key = GaussianCloudPipelineKey { - aabb: settings.aabb, - visualize_bounding_box: settings.visualize_bounding_box, - }; - - let pipeline = pipelines.specialize(&pipeline_cache, &custom_pipeline, key); - - // // TODO: distance to gaussian cloud centroid - // let rangefinder = view.rangefinder3d(); - - transparent_phase.add(Transparent3d { - entity, - draw_function: draw_custom, - distance: 0.0, - // distance: rangefinder - // .distance_translation(&mesh_instance.transforms.transform.translation), - pipeline, - batch_range: 0..1, - dynamic_offset: None, - }); + for ( + _view, + mut transparent_phase, + ) in &mut views { + for ( + entity, + cloud_handle, + sorted_entries_handle, + settings, + ) in &gaussian_splatting_bundles { + if gaussian_clouds.get(cloud_handle).is_none() { + return; + } + + if sorted_entries.get(sorted_entries_handle).is_none() { + return; } + + let key = GaussianCloudPipelineKey { + aabb: settings.aabb, + visualize_bounding_box: settings.visualize_bounding_box, + }; + + let pipeline = pipelines.specialize(&pipeline_cache, &custom_pipeline, key); + + // // TODO: distance to gaussian cloud centroid + // let rangefinder = view.rangefinder3d(); + + transparent_phase.add(Transparent3d { + entity, + draw_function: draw_custom, + distance: 0.0, + // distance: rangefinder + // .distance_translation(&mesh_instance.transforms.transform.translation), + pipeline, + batch_range: 0..1, + dynamic_offset: None, + }); } } } @@ -357,27 +378,27 @@ impl FromWorld for GaussianCloudPipeline { // TODO: allow setting shader defines via API // TODO: separate shader defines for each pipeline -struct ShaderDefines { - radix_bits_per_digit: u32, - radix_digit_places: u32, - radix_base: u32, - entries_per_invocation_a: u32, - entries_per_invocation_c: u32, - workgroup_invocations_a: u32, - workgroup_invocations_c: u32, - workgroup_entries_a: u32, - workgroup_entries_c: u32, - sorting_buffer_size: u32, - - temporal_sort_window_size: u32, +pub struct ShaderDefines { + pub radix_bits_per_digit: u32, + pub radix_digit_places: u32, + pub radix_base: u32, + pub entries_per_invocation_a: u32, + pub entries_per_invocation_c: u32, + pub workgroup_invocations_a: u32, + pub workgroup_invocations_c: u32, + pub workgroup_entries_a: u32, + pub workgroup_entries_c: u32, + pub sorting_buffer_size: u32, + + pub temporal_sort_window_size: u32, } impl ShaderDefines { - fn max_tile_count(&self, count: usize) -> u32 { + pub fn max_tile_count(&self, count: usize) -> u32 { (count as u32 + self.workgroup_entries_c - 1) / self.workgroup_entries_c } - fn sorting_status_counters_buffer_size(&self, count: usize) -> usize { + pub fn sorting_status_counters_buffer_size(&self, count: usize) -> usize { self.radix_base as usize * self.max_tile_count(count) as usize * std::mem::size_of::() } } @@ -413,7 +434,7 @@ impl Default for ShaderDefines { } } -fn shader_defs( +pub fn shader_defs( aabb: bool, visualize_bounding_box: bool, ) -> Vec { @@ -544,6 +565,7 @@ pub fn extract_gaussians( // &ComputedVisibility, &Visibility, &Handle, + &Handle, &GaussianCloudSettings, )>, >, @@ -555,6 +577,7 @@ pub fn extract_gaussians( entity, visibility, verticies, + sorted_entries, settings, ) in gaussians_query.iter() { if visibility == Visibility::Hidden { @@ -570,6 +593,7 @@ pub fn extract_gaussians( GpuGaussianSplattingBundle { settings: settings.clone(), settings_uniform, + sorted_entries: sorted_entries.clone(), verticies: verticies.clone(), }, )); @@ -581,7 +605,7 @@ pub fn extract_gaussians( #[derive(Resource, Default)] pub struct GaussianUniformBindGroups { - base_bind_group: Option, + pub base_bind_group: Option, } #[derive(Component)] @@ -598,9 +622,11 @@ fn queue_gaussian_bind_group( gaussian_uniforms: Res>, asset_server: Res, gaussian_cloud_res: Res>, + sorted_entries_res: Res>, gaussian_clouds: Query<( Entity, &Handle, + &Handle, )>, ) { let Some(model) = gaussian_uniforms.buffer() else { @@ -622,9 +648,13 @@ fn queue_gaussian_bind_group( ], )); - for (entity, cloud_handle) in gaussian_clouds.iter() { + for ( + entity, + cloud_handle, + sorted_entries_handle, + ) in gaussian_clouds.iter() { // TODO: add asset loading indicator (and maybe streamed loading) - if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle) { + if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle){ continue; } @@ -632,7 +662,16 @@ fn queue_gaussian_bind_group( continue; } - let cloud = gaussian_cloud_res.get(cloud_handle).unwrap(); + if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) { + continue; + } + + if sorted_entries_res.get(sorted_entries_handle).is_none() { + continue; + } + + let cloud: &GpuGaussianCloud = gaussian_cloud_res.get(cloud_handle).unwrap(); + let sorted_entries = sorted_entries_res.get(sorted_entries_handle).unwrap(); commands.entity(entity).insert(GaussianCloudBindGroup { cloud_bind_group: render_device.create_bind_group( @@ -656,7 +695,7 @@ fn queue_gaussian_bind_group( BindGroupEntry { binding: 0, resource: BindingResource::Buffer(BufferBinding { - buffer: &cloud.radix_sort_buffers.entry_buffer_a, + buffer: &sorted_entries.sorted_entry_buffer, offset: 0, size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64), }), diff --git a/src/render/transform.wgsl b/src/render/transform.wgsl index f7513bc1..50d13ebf 100644 --- a/src/render/transform.wgsl +++ b/src/render/transform.wgsl @@ -8,6 +8,13 @@ fn world_to_clip(world_pos: vec3) -> vec4 { return homogenous_pos / (homogenous_pos.w + 0.000000001); } + +// fn world_to_clip(world_pos: vec3) -> vec4 { +// let homogenous_pos = view.unjittered_view_proj * vec4(world_pos, 1.0); +// return homogenous_pos / (homogenous_pos.w + 0.000000001); +// } + + fn in_frustum(clip_space_pos: vec3) -> bool { return abs(clip_space_pos.x) < 1.1 && abs(clip_space_pos.y) < 1.1 diff --git a/src/sort/mod.rs b/src/sort/mod.rs new file mode 100644 index 00000000..f24662b0 --- /dev/null +++ b/src/sort/mod.rs @@ -0,0 +1,220 @@ +use bevy::{ + prelude::*, + asset::LoadState, + ecs::system::{ + lifetimeless::SRes, + SystemParamItem, + }, + reflect::TypeUuid, + render::{ + render_asset::{ + RenderAsset, + RenderAssetPlugin, + PrepareAssetError, + }, + render_resource::*, + renderer::RenderDevice, + }, +}; +use bytemuck::{ + Pod, + Zeroable, +}; +use static_assertions::assert_cfg; + +use crate::{ + GaussianCloud, + GaussianCloudSettings, +}; + + +#[cfg(feature = "sort_radix")] +pub mod radix; + +#[cfg(feature = "sort_rayon")] +pub mod rayon; + + +assert_cfg!( + any( + feature = "sort_radix", + feature = "sort_rayon", + ), + "no sort mode enabled", +); + + +#[derive( + Component, + Debug, + Clone, + PartialEq, + Reflect, +)] +pub enum SortMode { + None, + + #[cfg(feature = "sort_radix")] + Radix, + + #[cfg(feature = "sort_rayon")] + Rayon, +} + +impl Default for SortMode { + #[allow(unreachable_code)] + fn default() -> Self { + #[cfg(feature = "sort_rayon")] + return Self::Rayon; + + #[cfg(feature = "sort_radix")] + return Self::Radix; + + Self::None + } +} + + +#[derive(Default)] +pub struct SortPlugin; + +impl Plugin for SortPlugin { + fn build(&self, app: &mut App) { + #[cfg(feature = "sort_radix")] + app.add_plugins(radix::RadixSortPlugin); + + #[cfg(feature = "sort_rayon")] + app.add_plugins(rayon::RayonSortPlugin); + + + app.register_type::(); + app.init_asset::(); + app.register_asset_reflect::(); + + app.add_plugins(RenderAssetPlugin::::default()); + + app.add_systems(Update, auto_insert_sorted_entries); + } +} + + +fn auto_insert_sorted_entries( + mut commands: Commands, + asset_server: Res, + gaussian_clouds_res: Res>, + mut sorted_entries_res: ResMut>, + gaussian_clouds: Query<( + Entity, + &Handle, + &GaussianCloudSettings, + Without>, + )>, +) { + for ( + entity, + gaussian_cloud_handle, + _settings, + _, + ) in gaussian_clouds.iter() { + // // TODO: specialize vertex shader for sort mode (e.g. draw_indirect but no sort indirection) + // if settings.sort_mode == SortMode::None { + // continue; + // } + + if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) { + continue; + } + + let cloud = gaussian_clouds_res.get(gaussian_cloud_handle); + if cloud.is_none() { + continue; + } + let cloud = cloud.unwrap(); + + // TODO: move gaussian_cloud and sorted_entry assets into an asset bundle + let sorted_entries = sorted_entries_res.add(SortedEntries { + sorted: (0..cloud.gaussians.len()) + .map(|idx| { + SortEntry { + key: 1, + index: idx as u32, + } + }) + .collect(), + }); + + commands.entity(entity) + .insert(sorted_entries); + } +} + + +#[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Reflect, + ShaderType, + Pod, + Zeroable, +)] +#[repr(C)] +pub struct SortEntry { + pub key: u32, + pub index: u32, +} + +// TODO: add RenderAssetPlugin for SortedEntries & auto-insert to GaussianCloudBundles if their sort mode is not None +// supports pre-sorting or CPU sorting in main world, initializes the sorting_entry_buffer +#[derive( + Clone, + Asset, + Debug, + Default, + PartialEq, + Reflect, + TypeUuid, +)] +#[uuid = "ac2f08eb-fa13-ccdd-ea11-51571ea332d5"] +pub struct SortedEntries { + pub sorted: Vec, +} + +impl RenderAsset for SortedEntries { + type ExtractedAsset = SortedEntries; + type PreparedAsset = GpuSortedEntry; + type Param = SRes; + + fn extract_asset(&self) -> Self::ExtractedAsset { + self.clone() + } + + fn prepare_asset( + sorted_entries: Self::ExtractedAsset, + render_device: &mut SystemParamItem, + ) -> Result> { + let sorted_entry_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor { + label: Some("sorted_entry_buffer"), + contents: bytemuck::cast_slice(sorted_entries.sorted.as_slice()), + usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST | BufferUsages::STORAGE, + }); + + let count = sorted_entries.sorted.len(); + + Ok(GpuSortedEntry { + sorted_entry_buffer, + count, + }) + } +} + + +// TODO: support instancing and multiple cameras +// separate entry_buffer_a binding into unique a bind group to optimize buffer updates +#[derive(Debug, Clone)] +pub struct GpuSortedEntry { + pub sorted_entry_buffer: Buffer, + pub count: usize, +} diff --git a/src/render/sort/mod.rs b/src/sort/radix.rs similarity index 67% rename from src/render/sort/mod.rs rename to src/sort/radix.rs index b0667714..b9a1f92f 100644 --- a/src/render/sort/mod.rs +++ b/src/sort/radix.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use bevy::{ prelude::*, asset::{ @@ -5,10 +7,6 @@ use bevy::{ LoadState, }, core_pipeline::core_3d::CORE_3D, - ecs::system::{ - lifetimeless::SRes, - SystemParamItem, - }, render::{ render_asset::RenderAssets, render_resource::*, @@ -31,6 +29,7 @@ use bevy::{ use crate::{ gaussian::GaussianCloud, + GaussianCloudSettings, render::{ GaussianCloudBindGroup, GaussianCloudPipeline, @@ -39,6 +38,11 @@ use crate::{ ShaderDefines, shader_defs, }, + sort::{ + SortEntry, + SortedEntries, + SortMode, + }, }; @@ -49,7 +53,6 @@ pub mod node { pub const RADIX_SORT: &str = "radix_sort"; } - #[derive(Default)] pub struct RadixSortPlugin; @@ -88,6 +91,9 @@ impl Plugin for RadixSortPlugin { queue_radix_bind_group.in_set(RenderSet::QueueMeshes), ), ); + + render_app.init_resource::(); + render_app.add_systems(ExtractSchedule, update_sort_buffers); } } @@ -99,21 +105,25 @@ impl Plugin for RadixSortPlugin { } } +#[derive(Resource, Default)] +pub struct RadixSortBuffers { + pub asset_map: HashMap< + AssetId, + GpuRadixBuffers, + >, +} -// TODO: allow swapping of sort backends -// requires GaussianCloud RenderAsset dependency #[derive(Debug, Clone)] pub struct GpuRadixBuffers { pub sorting_global_buffer: Buffer, pub sorting_status_counter_buffer: Buffer, pub sorting_pass_buffers: [Buffer; 4], - pub entry_buffer_a: Buffer, pub entry_buffer_b: Buffer, } impl GpuRadixBuffers { pub fn new( count: usize, - render_device: &mut SystemParamItem>, + render_device: &RenderDevice, ) -> Self { let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor { label: Some("sorting global buffer"), @@ -141,16 +151,9 @@ impl GpuRadixBuffers { .try_into() .unwrap(); - let entry_buffer_a = render_device.create_buffer(&BufferDescriptor { - label: Some("entry buffer a"), - size: (count * std::mem::size_of::<(u32, u32)>()) as u64, - usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC, - mapped_at_creation: false, - }); - let entry_buffer_b = render_device.create_buffer(&BufferDescriptor { label: Some("entry buffer b"), - size: (count * std::mem::size_of::<(u32, u32)>()) as u64, + size: (count * std::mem::size_of::()) as u64, usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC, mapped_at_creation: false, }); @@ -159,13 +162,29 @@ impl GpuRadixBuffers { sorting_global_buffer, sorting_status_counter_buffer, sorting_pass_buffers, - entry_buffer_a, entry_buffer_b, } } } +fn update_sort_buffers( + gpu_gaussian_clouds: Res>, + mut sort_buffers: ResMut, + render_device: Res, +) { + for (asset_id, cloud,) in gpu_gaussian_clouds.iter() { + // TODO: handle cloud resize operations and resolve leaked stale buffers + if sort_buffers.asset_map.contains_key(&asset_id) { + continue; + } + + let gpu_radix_buffers = GpuRadixBuffers::new(cloud.count, &render_device); + sort_buffers.asset_map.insert(asset_id, gpu_radix_buffers); + } +} + + #[derive(Resource)] pub struct RadixSortPipeline { pub radix_sort_layout: BindGroupLayout, @@ -232,7 +251,7 @@ impl FromWorld for RadixSortPipeline { ty: BindingType::Buffer { ty: BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, - min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64), + min_binding_size: BufferSize::new(std::mem::size_of::() as u64), }, count: None, }, @@ -242,7 +261,7 @@ impl FromWorld for RadixSortPipeline { ty: BindingType::Buffer { ty: BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, - min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64), + min_binding_size: BufferSize::new(std::mem::size_of::() as u64), }, count: None, }, @@ -309,12 +328,26 @@ pub fn queue_radix_bind_group( render_device: Res, asset_server: Res, gaussian_cloud_res: Res>, + sorted_entries_res: Res>, gaussian_clouds: Query<( Entity, &Handle, + &Handle, + &GaussianCloudSettings, )>, + sort_buffers: Res, ) { - for (entity, cloud_handle) in gaussian_clouds.iter() { + for ( + entity, + cloud_handle, + sorted_entries_handle, + settings, + ) in gaussian_clouds.iter() { + if settings.sort_mode != SortMode::Radix { + continue; + } + + // TODO: deduplicate asset load checks if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle) { continue; } @@ -323,23 +356,37 @@ pub fn queue_radix_bind_group( continue; } + if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) { + continue; + } + + if sorted_entries_res.get(sorted_entries_handle).is_none() { + continue; + } + + if !sort_buffers.asset_map.contains_key(&cloud_handle.id()) { + continue; + } + let cloud = gaussian_cloud_res.get(cloud_handle).unwrap(); + let sorted_entries = sorted_entries_res.get(sorted_entries_handle).unwrap(); + let sorting_assets = &sort_buffers.asset_map[&cloud_handle.id()]; let sorting_global_entry = BindGroupEntry { binding: 1, resource: BindingResource::Buffer(BufferBinding { - buffer: &cloud.radix_sort_buffers.sorting_global_buffer, + buffer: &sorting_assets.sorting_global_buffer, offset: 0, - size: BufferSize::new(cloud.radix_sort_buffers.sorting_global_buffer.size()), + size: BufferSize::new(sorting_assets.sorting_global_buffer.size()), }), }; let sorting_status_counters_entry = BindGroupEntry { binding: 2, resource: BindingResource::Buffer(BufferBinding { - buffer: &cloud.radix_sort_buffers.sorting_status_counter_buffer, + buffer: &sorting_assets.sorting_status_counter_buffer, offset: 0, - size: BufferSize::new(cloud.radix_sort_buffers.sorting_status_counter_buffer.size()), + size: BufferSize::new(sorting_assets.sorting_status_counter_buffer.size()), }), }; @@ -361,7 +408,7 @@ pub fn queue_radix_bind_group( BindGroupEntry { binding: 0, resource: BindingResource::Buffer(BufferBinding { - buffer: &cloud.radix_sort_buffers.sorting_pass_buffers[idx], + buffer: &sorting_assets.sorting_pass_buffers[idx], offset: 0, size: BufferSize::new(std::mem::size_of::() as u64), }), @@ -373,24 +420,24 @@ pub fn queue_radix_bind_group( binding: 4, resource: BindingResource::Buffer(BufferBinding { buffer: if idx % 2 == 0 { - &cloud.radix_sort_buffers.entry_buffer_a + &sorted_entries.sorted_entry_buffer } else { - &cloud.radix_sort_buffers.entry_buffer_b + &sorting_assets.entry_buffer_b }, offset: 0, - size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64), + size: BufferSize::new((cloud.count as usize * std::mem::size_of::()) as u64), }), }, BindGroupEntry { binding: 5, resource: BindingResource::Buffer(BufferBinding { buffer: if idx % 2 == 0 { - &cloud.radix_sort_buffers.entry_buffer_b + &sorting_assets.entry_buffer_b } else { - &cloud.radix_sort_buffers.entry_buffer_a + &sorted_entries.sorted_entry_buffer }, offset: 0, - size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64), + size: BufferSize::new((cloud.count as usize * std::mem::size_of::()) as u64), }), }, ], @@ -407,10 +454,6 @@ pub fn queue_radix_bind_group( } - - - - pub struct RadixSortNode { gaussian_clouds: QueryState<( &'static Handle, @@ -475,8 +518,7 @@ impl Node for RadixSortNode { let pipeline_cache = world.resource::(); let pipeline = world.resource::(); let gaussian_uniforms = world.resource::(); - - let command_encoder = render_context.command_encoder(); + let sort_buffers = world.resource::(); for ( view_bind_group, @@ -489,102 +531,109 @@ impl Node for RadixSortNode { ) in self.gaussian_clouds.iter_manual(world) { let cloud = world.get_resource::>().unwrap().get(cloud_handle).unwrap(); - let radix_digit_places = ShaderDefines::default().radix_digit_places; + assert!(sort_buffers.asset_map.contains_key(&cloud_handle.id())); + let sorting_assets = &sort_buffers.asset_map[&cloud_handle.id()]; { - command_encoder.clear_buffer( - &cloud.radix_sort_buffers.sorting_global_buffer, - 0, - None, - ); - - command_encoder.clear_buffer( - &cloud.radix_sort_buffers.sorting_status_counter_buffer, - 0, - None, - ); - - command_encoder.clear_buffer( - &cloud.draw_indirect_buffer, - 0, - None, - ); - } + let command_encoder = render_context.command_encoder(); + let radix_digit_places = ShaderDefines::default().radix_digit_places; - { - let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); - - pass.set_bind_group( - 0, - &view_bind_group.value, - &[view_uniform_offset.offset], - ); - pass.set_bind_group( - 1, - gaussian_uniforms.base_bind_group.as_ref().unwrap(), - &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex - ); - pass.set_bind_group( - 2, - &cloud_bind_group.cloud_bind_group, - &[] - ); - pass.set_bind_group( - 3, - &radix_bind_group.radix_sort_bind_groups[1], - &[], - ); - - let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap(); - pass.set_pipeline(radix_sort_a); - - let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a; - pass.dispatch_workgroups((cloud.count as u32 + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1); - - - let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap(); - pass.set_pipeline(radix_sort_b); - - pass.dispatch_workgroups(1, radix_digit_places, 1); - } + { + command_encoder.clear_buffer( + &sorting_assets.sorting_global_buffer, + 0, + None, + ); + + command_encoder.clear_buffer( + &sorting_assets.sorting_status_counter_buffer, + 0, + None, + ); - for pass_idx in 0..radix_digit_places { - if pass_idx > 0 { command_encoder.clear_buffer( - &cloud.radix_sort_buffers.sorting_status_counter_buffer, + &cloud.draw_indirect_buffer, 0, None, ); } - let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); - - let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap(); - pass.set_pipeline(&radix_sort_c); - - pass.set_bind_group( - 0, - &view_bind_group.value, - &[view_uniform_offset.offset], - ); - pass.set_bind_group( - 1, - gaussian_uniforms.base_bind_group.as_ref().unwrap(), - &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex - ); - pass.set_bind_group( - 2, - &cloud_bind_group.cloud_bind_group, - &[] - ); - pass.set_bind_group( - 3, - &radix_bind_group.radix_sort_bind_groups[pass_idx as usize], - &[], - ); - - let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c; - pass.dispatch_workgroups(1, (cloud.count as u32 + workgroup_entries_c - 1) / workgroup_entries_c, 1); + // TODO: add options to only complete a fraction of the sorting process + { + let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); + + pass.set_bind_group( + 0, + &view_bind_group.value, + &[view_uniform_offset.offset], + ); + pass.set_bind_group( + 1, + gaussian_uniforms.base_bind_group.as_ref().unwrap(), + &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex + ); + pass.set_bind_group( + 2, + &cloud_bind_group.cloud_bind_group, + &[] + ); + pass.set_bind_group( + 3, + &radix_bind_group.radix_sort_bind_groups[1], + &[], + ); + + let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap(); + pass.set_pipeline(radix_sort_a); + + let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a; + pass.dispatch_workgroups((cloud.count as u32 + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1); + + + let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap(); + pass.set_pipeline(radix_sort_b); + + pass.dispatch_workgroups(1, radix_digit_places, 1); + } + + for pass_idx in 0..radix_digit_places { + if pass_idx > 0 { + command_encoder.clear_buffer( + &sorting_assets.sorting_status_counter_buffer, + 0, + None, + ); + } + + let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); + + let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap(); + pass.set_pipeline(&radix_sort_c); + + pass.set_bind_group( + 0, + &view_bind_group.value, + &[view_uniform_offset.offset], + ); + pass.set_bind_group( + 1, + gaussian_uniforms.base_bind_group.as_ref().unwrap(), + &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex + ); + pass.set_bind_group( + 2, + &cloud_bind_group.cloud_bind_group, + &[] + ); + pass.set_bind_group( + 3, + &radix_bind_group.radix_sort_bind_groups[pass_idx as usize], + &[], + ); + + let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c; + pass.dispatch_workgroups(1, (cloud.count as u32 + workgroup_entries_c - 1) / workgroup_entries_c, 1); + } } } } diff --git a/src/render/sort/radix.wgsl b/src/sort/radix.wgsl similarity index 100% rename from src/render/sort/radix.wgsl rename to src/sort/radix.wgsl diff --git a/src/sort/rayon.rs b/src/sort/rayon.rs new file mode 100644 index 00000000..25f1a301 --- /dev/null +++ b/src/sort/rayon.rs @@ -0,0 +1,106 @@ +use bevy::{ + prelude::*, + asset::LoadState, + utils::Instant, +}; + +use rayon::prelude::*; + +use crate::{ + GaussianCloud, + GaussianCloudSettings, + sort::{ + SortedEntries, + SortMode, + }, +}; + + +#[derive(Default)] +pub struct RayonSortPlugin; + +impl Plugin for RayonSortPlugin { + fn build(&self, app: &mut App) { + app.add_systems(Update, rayon_sort); + } +} + +pub fn rayon_sort( + asset_server: Res, + gaussian_clouds_res: Res>, + mut sorted_entries_res: ResMut>, + gaussian_clouds: Query<( + &Handle, + &Handle, + &GaussianCloudSettings, + )>, + cameras: Query<( + &GlobalTransform, + &Camera3d, + )>, + mut last_camera_position: Local, + mut last_sort_time: Local>, +) { + let period = std::time::Duration::from_millis(100); + if let Some(last_sort_time) = last_sort_time.as_ref() { + if last_sort_time.elapsed() < period { + return; + } + } + + // TODO: move sort to render world, use extracted views and update the existing buffer instead of creating new + + for ( + camera_transform, + _camera, + ) in cameras.iter() { + let camera_position = camera_transform.compute_transform().translation; + if *last_camera_position == camera_position { + return; + } + + for ( + gaussian_cloud_handle, + sorted_entries_handle, + settings, + ) in gaussian_clouds.iter() { + if settings.sort_mode != SortMode::Rayon { + continue; + } + + if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) { + continue; + } + + if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) { + continue; + } + + if let Some(gaussian_cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) { + if let Some(sorted_entries) = sorted_entries_res.get_mut(sorted_entries_handle) { + assert_eq!(gaussian_cloud.gaussians.len(), sorted_entries.sorted.len()); + + *last_camera_position = camera_position; + *last_sort_time = Some(Instant::now()); + + gaussian_cloud.gaussians.par_iter() + .zip(sorted_entries.sorted.par_iter_mut()) + .enumerate() + .for_each(|(idx, (gaussian, sort_entry))| { + let position = Vec3::from_slice(gaussian.position.as_ref()); + let delta = camera_position - position; + + sort_entry.key = bytemuck::cast(delta.length_squared()); + sort_entry.index = idx as u32; + }); + + sorted_entries.sorted.par_sort_unstable_by(|a, b| { + bytemuck::cast::(b.key).partial_cmp(&bytemuck::cast::(a.key)).unwrap() + }); + + // TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect) + } + } + } + } +} diff --git a/src/render/sort/temporal.wgsl b/src/sort/temporal.wgsl similarity index 100% rename from src/render/sort/temporal.wgsl rename to src/sort/temporal.wgsl diff --git a/src/utils.rs b/src/utils.rs index 6af5b927..e6781970 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,6 @@ +#[cfg(target_arch = "wasm32")] +pub use wasm_bindgen_rayon::init_thread_pool; + pub fn setup_hooks() { #[cfg(debug_assertions)] diff --git a/tests/gaussian.rs b/tests/gaussian.rs index e69de29b..5cb2a9fd 100644 --- a/tests/gaussian.rs +++ b/tests/gaussian.rs @@ -0,0 +1,17 @@ +use bevy_gaussian_splatting::{ + GaussianCloud, + io::codec::GaussianCloudCodec, + random_gaussians, +}; + + +#[test] +fn test_codec() { + let count = 100; + + let gaussians = random_gaussians(count); + let encoded = gaussians.encode(); + let decoded = GaussianCloud::decode(encoded.as_slice()); + + assert_eq!(gaussians, decoded); +} diff --git a/tests/gpu/_harness.rs b/tests/gpu/_harness.rs new file mode 100644 index 00000000..81821b43 --- /dev/null +++ b/tests/gpu/_harness.rs @@ -0,0 +1,62 @@ +use std::sync::{Arc, Mutex}; + +use bevy::prelude::*; + +use bevy_gaussian_splatting::GaussianSplattingPlugin; + + +// scraping this in CI for now until bevy ci testing is more stable +// #[test] + main thread, and windowless screenshot limitations exist +// see: https://github.com/anchpop/endless-sea/blob/3b8481f1152293907794d60e920d4cc5a7ca8f40/src/tests/helpers.rs#L69-L83 + + +#[derive(Resource)] +pub struct TestHarness { + pub resolution: (f32, f32), +} + +pub fn test_harness_app( + harness: TestHarness, +) -> App { + let mut app = App::new(); + + app.insert_resource(ClearColor(Color::rgb_u8(0, 0, 0))); + app.add_plugins( + DefaultPlugins + .set(WindowPlugin { + primary_window: Some(Window { + fit_canvas_to_parent: false, + mode: bevy::window::WindowMode::Windowed, + present_mode: bevy::window::PresentMode::AutoVsync, + prevent_default_event_handling: false, + resolution: harness.resolution.into(), + title: "bevy_gaussian_splatting pipeline test".to_string(), + ..default() + }), + ..default() + }), + ); + + app.add_plugins(GaussianSplattingPlugin); + + app.insert_resource(harness); + + app +} + + +pub struct TestState { + pub test_loaded: bool, + pub test_completed: bool, +} + +impl Default for TestState { + fn default() -> Self { + TestState { + test_loaded: false, + test_completed: false, + } + } +} + +pub type TestStateArc = Arc>; diff --git a/tests/gpu/gaussian.rs b/tests/gpu/gaussian.rs new file mode 100644 index 00000000..506e3299 --- /dev/null +++ b/tests/gpu/gaussian.rs @@ -0,0 +1,166 @@ +use std::sync::{ + Arc, + Mutex, +}; + +use bevy::{ + prelude::*, + app::AppExit, + core::FrameCount, + render::view::screenshot::ScreenshotManager, + window::PrimaryWindow, +}; + +use bevy_gaussian_splatting::{ + GaussianCloud, + GaussianSplattingBundle, + random_gaussians, +}; + +use _harness::{ + TestHarness, + test_harness_app, + TestStateArc, +}; + +mod _harness; + + +fn main() { + let mut app = test_harness_app(TestHarness { + resolution: (512.0, 512.0), + }); + + app.add_systems(Startup, setup); + app.add_systems(Update, capture_ready); + + app.run(); +} + +fn setup( + mut commands: Commands, + mut gaussian_assets: ResMut>, +) { + let cloud = gaussian_assets.add(random_gaussians(10000)); + + commands.spawn(( + GaussianSplattingBundle { + cloud, + ..default() + }, + Name::new("gaussian_cloud"), + )); + + commands.spawn(( + Camera3dBundle { + transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)), + ..default() + }, + )); +} + +fn check_image_equality(image: &Image, other: &Image) -> bool { + if image.width() != other.width() || image.height() != other.height() { + return false; + } + + for (word, other_word) in image.data.iter().zip(other.data.iter()) { + if word != other_word { + return false; + } + } + + true +} + +fn test_stability(captures: Arc>>) { + let all_frames_similar = captures.lock().unwrap().iter() + .fold(Some(None), |acc, image| { + match acc { + Some(acc_image) => { + if let Some(acc_image) = acc_image { + if check_image_equality(acc_image, image) { + Some(Some(acc_image)) + } else { + None + } + } else { + Some(Some(image)) + } + }, + None => None, + } + }).is_some(); + assert!(all_frames_similar, "all frames are not the same"); +} + +fn save_captures(captures: Arc>>) { + captures.lock().unwrap().iter() + .enumerate() + .for_each(|(i, image)| { + let path = format!("target/tmp/test_gaussian_frame_{}.png", i); + + let dyn_img = image.clone().try_into_dynamic().unwrap(); + let img = dyn_img.to_rgba8(); + img.save(path).unwrap(); + }); +} + +fn capture_ready( + // gaussian_cloud_assets: Res>, + // asset_server: Res, + // gaussian_clouds: Query< + // Entity, + // &Handle, + // >, + main_window: Query>, + mut screenshot_manager: ResMut, + mut exit: EventWriter, + frame_count: Res, + state: Local, + buffer: Local>>>, +) { + let buffer = buffer.to_owned(); + + let buffer_frames = 10; + let wait_frames = 10; // wait for gaussian cloud to load + if frame_count.0 < wait_frames { + return; + } + + let state_clone = Arc::clone(&state); + let buffer_clone = Arc::clone(&buffer); + + let mut state = state.lock().unwrap(); + state.test_loaded = true; + + if state.test_completed { + { + let captures = buffer.lock().unwrap(); + let frame_count = captures.len(); + assert_eq!(frame_count, buffer_frames, "captured {} frames, expected {}", frame_count, buffer_frames); + } + + save_captures(buffer.clone()); + test_stability(buffer); + // TODO: add correctness test (use CPU gaussian pipeline to compare results) + + exit.send(AppExit); + return; + } + + if let Ok(window_entity) = main_window.get_single() { + screenshot_manager.take_screenshot(window_entity, move |image: Image| { + let has_non_zero_data = image.data.iter().fold(false, |non_zero, &x| non_zero || x != 0); + assert!(has_non_zero_data, "screenshot is all zeros"); + + let mut buffer = buffer_clone.lock().unwrap(); + buffer.push(image); + + if buffer.len() >= buffer_frames { + let mut state = state_clone.lock().unwrap(); + state.test_completed = true; + } + }).unwrap(); + } +} diff --git a/tests/gpu/radix.rs b/tests/gpu/radix.rs new file mode 100644 index 00000000..d072b8ae --- /dev/null +++ b/tests/gpu/radix.rs @@ -0,0 +1,230 @@ +use std::{ + process::exit, + sync::{ + Arc, + Mutex, + }, +}; + +use bevy::{ + prelude::*, + core::FrameCount, + core_pipeline::core_3d::{ + CORE_3D, + Transparent3d, + }, + render::{ + RenderApp, + renderer::{ + RenderContext, + RenderQueue, + }, + render_asset::RenderAssets, + render_graph::{ + Node, + NodeRunError, + RenderGraphApp, + RenderGraphContext, + }, + render_phase::RenderPhase, + view::ExtractedView, + }, +}; + +use bevy_gaussian_splatting::{ + GaussianCloud, + GaussianSplattingBundle, + random_gaussians, + sort::SortedEntries, +}; + +use _harness::{ + TestHarness, + test_harness_app, + TestState, + TestStateArc, +}; + +mod _harness; + + +pub mod node { + pub const RADIX_SORT_TEST: &str = "radix_sort_test"; +} + + +fn main() { + let mut app = test_harness_app(TestHarness { + resolution: (512.0, 512.0), + }); + + app.add_systems(Startup, setup); + + if let Ok(render_app) = app.get_sub_app_mut(RenderApp) { + render_app + .add_render_graph_node::( + CORE_3D, + node::RADIX_SORT_TEST, + ) + .add_render_graph_edge( + CORE_3D, + node::RADIX_SORT_TEST, + bevy::core_pipeline::core_3d::graph::node::END_MAIN_PASS, + ); + } + + app.run(); +} + +fn setup( + mut commands: Commands, + mut gaussian_assets: ResMut>, +) { + let cloud = gaussian_assets.add(random_gaussians(10000)); + + commands.spawn(( + GaussianSplattingBundle { + cloud, + settings: GaussianCloudSettings { + sort_mode: SortMode::Radix, + ..default() + }, + ..default() + }, + Name::new("gaussian_cloud"), + )); + + commands.spawn(( + Camera3dBundle { + transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)), + ..default() + }, + )); +} + + +pub struct RadixTestNode { + gaussian_clouds: QueryState<( + &'static Handle, + &'static Handle, + )>, + state: TestStateArc, + views: QueryState<( + &'static ExtractedView, + &'static RenderPhase, + )>, + start_frame: u32, +} + +impl FromWorld for RadixTestNode { + fn from_world(world: &mut World) -> Self { + Self { + gaussian_clouds: world.query(), + state: Arc::new(Mutex::new(TestState::default())), + views: world.query(), + start_frame: 0, + } + } +} + + +impl Node for RadixTestNode { + fn update( + &mut self, + world: &mut World, + ) { + let mut state = self.state.lock().unwrap(); + if state.test_completed { + exit(0); + } + + if state.test_loaded && self.start_frame == 0 { + self.start_frame = world.get_resource::().unwrap().0; + } + + let frame_count = world.get_resource::().unwrap().0; + const FRAME_LIMIT: u32 = 10; + if state.test_loaded && frame_count >= self.start_frame + FRAME_LIMIT { + state.test_completed = true; + } + + self.gaussian_clouds.update_archetypes(world); + self.views.update_archetypes(world); + } + + fn run( + &self, + _graph: &mut RenderGraphContext, + render_context: &mut RenderContext, + world: &World, + ) -> Result<(), NodeRunError> { + for (view, _phase,) in self.views.iter_manual(world) { + let camera_position = view.transform.translation(); + + for ( + cloud_handle, + sorted_entries_handle, + ) in self.gaussian_clouds.iter_manual(world) { + let gaussian_cloud_res = world.get_resource::>().unwrap(); + let sorted_entries_res = world.get_resource::>().unwrap(); + + let mut state = self.state.lock().unwrap(); + if gaussian_cloud_res.get(cloud_handle).is_none() || sorted_entries_res.get(sorted_entries_handle).is_none() { + continue; + } else if !state.test_loaded { + state.test_loaded = true; + } + + let cloud = gaussian_cloud_res.get(cloud_handle).unwrap(); + let sorted_entries = sorted_entries_res.get(sorted_entries_handle).unwrap(); + let gaussians = cloud.debug_gpu.gaussians.clone(); + + wgpu::util::DownloadBuffer::read_buffer( + render_context.render_device().wgpu_device(), + world.get_resource::().unwrap().0.as_ref(), + &sorted_entries.sorted_entry_buffer.slice( + 0..sorted_entries.sorted_entry_buffer.size() + ), + move |buffer: Result| { + let binding = buffer.unwrap(); + let u32_muck = bytemuck::cast_slice::(&*binding); + + let mut radix_sorted_indices = Vec::new(); + for i in (1..u32_muck.len()).step_by(2) { + radix_sorted_indices.push((i, u32_muck[i] as usize)); + } + + // TODO: depth order validation over ndc cells + + radix_sorted_indices.iter() + .fold(0.0, |depth_acc, &(entry_idx, idx)| { + if idx == 0 || u32_muck[entry_idx - 1] == 0xffffffff || u32_muck[entry_idx - 1] == 0x0 { + return depth_acc; + } + + let position = gaussians[idx].position; + let position_vec3 = Vec3::new(position[0], position[1], position[2]); + let depth = (position_vec3 - camera_position).length(); + + let depth_is_non_decreasing = depth_acc <= depth; + if !depth_is_non_decreasing { + println!( + "radix keys: [..., {:#010x}, {:#010x}, {:#010x}, ...]", + u32_muck[entry_idx - 1 - 2], + u32_muck[entry_idx - 1], + u32_muck[entry_idx - 1 + 2], + ); + } + + assert!(depth_is_non_decreasing, "radix sort, non-decreasing check failed: {} > {}", depth_acc, depth); + + depth_acc.max(depth) + }); + } + ); + } + } + + Ok(()) + } +} diff --git a/viewer/viewer.rs b/viewer/viewer.rs index 0431ea0f..2854ffe8 100644 --- a/viewer/viewer.rs +++ b/viewer/viewer.rs @@ -17,7 +17,7 @@ use bevy_gaussian_splatting::{ GaussianCloud, GaussianSplattingBundle, GaussianSplattingPlugin, - render::morph::{ + morph::{ ParticleBehaviors, random_particle_behaviors, },