Skip to content

Commit

Permalink
test: radix sort order
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed Dec 11, 2023
1 parent 92aa330 commit 0db4d38
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 26 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ jobs:

- name: gaussian render test
run: cargo run --bin test_gaussian

- name: radix sort test
run: cargo run --bin test_radix --features="debug_gpu"
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ path = "tests/gpu/gaussian.rs"
[[bin]]
name = "test_radix"
path = "tests/gpu/radix.rs"
required-features = ["debug_gpu"]


[[bench]]
Expand Down
4 changes: 4 additions & 0 deletions src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ pub struct GpuGaussianCloud {
pub draw_indirect_buffer: Buffer,

pub radix_sort_buffers: GpuRadixBuffers,

#[cfg(feature = "debug_gpu")]
pub debug_gpu: GaussianCloud,
}
impl RenderAsset for GaussianCloud {
type ExtractedAsset = GaussianCloud;
Expand Down Expand Up @@ -195,6 +198,7 @@ impl RenderAsset for GaussianCloud {
count,
draw_indirect_buffer,
radix_sort_buffers: GpuRadixBuffers::new(count, render_device),
debug_gpu: gaussian_cloud,
})
}
}
Expand Down
28 changes: 15 additions & 13 deletions src/render/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use bevy::{
renderer::{
RenderContext,
RenderDevice,
RenderQueue,
},
render_graph::{
Node,
Expand Down Expand Up @@ -514,6 +513,7 @@ impl Node for RadixSortNode {
);
}

// TODO: add options to only complete a fraction of the sorting process
{
let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

Expand Down Expand Up @@ -594,27 +594,29 @@ impl Node for RadixSortNode {
// TODO: move to test_radix
// #[cfg(feature = "debug_gpu")]
{
wgpu::util::DownloadBuffer::read_buffer(
render_context.render_device().wgpu_device(),
world.get_resource::<RenderQueue>().unwrap().0.as_ref(),
&cloud.radix_sort_buffers.entry_buffer_a.slice(
0..cloud.radix_sort_buffers.entry_buffer_a.size()
),
|buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
// println!("{:X?}", transmute_slice::<u8, u32>(&*buffer.unwrap()));
}
);
// wgpu::util::DownloadBuffer::read_buffer(
// render_context.render_device().wgpu_device(),
// queue,
// world.get_resource::<RenderQueue>().unwrap().0.as_ref(),
// &cloud.radix_sort_buffers.entry_buffer_a.slice(
// 0..cloud.radix_sort_buffers.entry_buffer_a.size()
// ),
// |buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
// let binding = buffer.unwrap();
// let u32_muck = bytemuck::cast_slice::<u8, u32>(&*binding);
// println!("{:X?}", &u32_muck);
// }
// );
// wgpu::util::DownloadBuffer::read_buffer(
// render_context.render_device().wgpu_device(),
// world.get_resource::<RenderQueue>().unwrap().0.as_ref(),
// &self.sorting_buffer.slice(0..self.sorting_buffer_size as u64 - 4 * 5),
// |buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
// println!("{:X?}", transmute_slice::<u8, [u32; 256]>(&*buffer.unwrap()));
// }
// );
// wgpu::util::DownloadBuffer::read_buffer(
// render_context.render_device().wgpu_device(),
// queue,
// world.get_resource::<RenderQueue>().unwrap().0.as_ref(),
// &self.entry_buffer_a.slice(..),
// |buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
// println!("{:X?}", transmute_slice::<u8, [(u32, u32); 2048]>(&*buffer.unwrap()));
Expand Down
7 changes: 7 additions & 0 deletions src/render/transform.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ fn world_to_clip(world_pos: vec3<f32>) -> vec4<f32> {
return homogenous_pos / (homogenous_pos.w + 0.000000001);
}


// fn world_to_clip(world_pos: vec3<f32>) -> vec4<f32> {
// let homogenous_pos = view.unjittered_view_proj * vec4<f32>(world_pos, 1.0);
// return homogenous_pos / (homogenous_pos.w + 0.000000001);
// }


fn in_frustum(clip_space_pos: vec3<f32>) -> bool {
return abs(clip_space_pos.x) < 1.1
&& abs(clip_space_pos.y) < 1.1
Expand Down
2 changes: 2 additions & 0 deletions tests/gpu/_harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ pub fn test_harness_app(


/// Progress flags shared between a GPU test and the code that drives it.
pub struct TestState {
/// Set to `true` by a test once its inputs are available
/// (e.g. the radix test sets it when the cloud's render asset can be fetched).
pub test_loaded: bool,
/// Set to `true` when the test has finished; tests check this to know when to stop.
pub test_completed: bool,
}

impl Default for TestState {
fn default() -> Self {
TestState {
test_loaded: false,
test_completed: false,
}
}
Expand Down
5 changes: 4 additions & 1 deletion tests/gpu/gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ fn capture_ready(
let state_clone = Arc::clone(&state);
let buffer_clone = Arc::clone(&buffer);

let state = state.lock().unwrap();
let mut state = state.lock().unwrap();
state.test_loaded = true;

if state.test_completed {
{
let captures = buffer.lock().unwrap();
Expand All @@ -141,6 +143,7 @@ fn capture_ready(

save_captures(buffer.clone());
test_stability(buffer);
// TODO: add correctness test (use CPU gaussian pipeline?)

// TODO: add post-frame test registration

Expand Down
161 changes: 149 additions & 12 deletions tests/gpu/radix.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
use std::sync::{
Arc,
Mutex,
use std::{
process::exit,
sync::{
Arc,
Mutex,
},
};

use bevy::{
prelude::*,
app::AppExit,
core::FrameCount,
render::view::screenshot::ScreenshotManager,
window::PrimaryWindow,
core_pipeline::core_3d::CORE_3D,
render::{
RenderApp,
renderer::{
RenderContext,
RenderQueue,
},
render_graph::{
Node,
NodeRunError,
RenderGraphApp,
RenderGraphContext,
},
render_asset::RenderAssets, view::ExtractedView,
},
};

use bevy_gaussian_splatting::{
Expand All @@ -20,19 +35,37 @@ use bevy_gaussian_splatting::{
use _harness::{
TestHarness,
test_harness_app,
TestState,
TestStateArc,
};

mod _harness;


/// Render graph node names registered by this test binary.
pub mod node {
// Name under which `RadixTestNode` is added to the `CORE_3D` render graph in `main`.
pub const RADIX_SORT_TEST: &str = "radix_sort_test";
}


/// Entry point of the radix sort GPU test.
///
/// Builds the shared test harness app at 512x512, registers the scene setup
/// and verification systems, and, when a render sub-app is available, adds
/// `RadixTestNode` to the `CORE_3D` render graph with an edge from it to the
/// end of the main pass.
fn main() {
    let harness = TestHarness {
        resolution: (512.0, 512.0),
    };
    let mut app = test_harness_app(harness);

    app.add_systems(Startup, setup);
    app.add_systems(Update, verify_radix_stages);

    // The render sub-app may be absent; in that case the test node is simply not installed.
    if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
        render_app.add_render_graph_node::<RadixTestNode>(
            CORE_3D,
            node::RADIX_SORT_TEST,
        );
        render_app.add_render_graph_edge(
            CORE_3D,
            node::RADIX_SORT_TEST,
            bevy::core_pipeline::core_3d::graph::node::END_MAIN_PASS,
        );
    }

    app.run();
}
Expand Down Expand Up @@ -60,10 +93,114 @@ fn setup(
}


fn verify_radix_stages(
mut exit: EventWriter<AppExit>,
frame_count: Res<FrameCount>,
state: Local<TestStateArc>,
) {
/// Render graph node that reads back the radix sort entry buffer and checks
/// that the gaussians it references are ordered by distance from the view.
pub struct RadixTestNode {
/// Query over the `Handle<GaussianCloud>` components; iterated in `run`.
gaussian_clouds: QueryState<(
&'static Handle<GaussianCloud>,
)>,
/// Shared loaded/completed flags for this test.
state: TestStateArc,
/// Query over extracted views; the view transform's translation is used as the camera position.
views: QueryState<(
&'static ExtractedView,
)>,
/// Frame count captured in `update` once the test is loaded; `0` means "not captured yet".
start_frame: u32,
}

impl FromWorld for RadixTestNode {
    /// Creates the node with fresh query states, a default (not loaded, not
    /// completed) test state, and no start frame recorded yet.
    fn from_world(world: &mut World) -> Self {
        let gaussian_clouds = world.query();
        let views = world.query();
        let state = Arc::new(Mutex::new(TestState::default()));

        Self {
            gaussian_clouds,
            state,
            views,
            start_frame: 0,
        }
    }
}


impl Node for RadixTestNode {
    /// Advances the test's lifecycle and refreshes the cached queries.
    ///
    /// Exits the process with status 0 once the test has been marked
    /// completed. After the test is loaded, the first observed frame count is
    /// stored as the start frame, and the test is marked completed once
    /// `FRAME_LIMIT` further frames have elapsed.
    fn update(
        &mut self,
        world: &mut World,
    ) {
        const FRAME_LIMIT: u32 = 10;

        let mut test_state = self.state.lock().unwrap();
        if test_state.test_completed {
            exit(0);
        }

        let current_frame = world.get_resource::<FrameCount>().unwrap().0;

        if test_state.test_loaded {
            // A start frame of 0 doubles as the "not recorded yet" marker.
            if self.start_frame == 0 {
                self.start_frame = current_frame;
            }

            if current_frame >= self.start_frame + FRAME_LIMIT {
                test_state.test_completed = true;
            }
        }

        self.gaussian_clouds.update_archetypes(world);
        self.views.update_archetypes(world);
    }

    /// Downloads the radix sort entry buffer for every (view, cloud) pair and
    /// asserts that the referenced gaussians are in non-decreasing order of
    /// distance from the view's translation.
    ///
    /// Clouds whose render asset is not available yet are skipped; the first
    /// available one marks the test as loaded.
    fn run(
        &self,
        _graph: &mut RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), NodeRunError> {
        let frame = world.get_resource::<FrameCount>().unwrap().0;
        println!("radix sort test: running frame {}...", frame);

        for (view,) in self.views.iter_manual(world) {
            let camera_position = view.transform.translation();
            println!("radix sort test: camera position: {:?}", camera_position);

            for (cloud_handle,) in self.gaussian_clouds.iter_manual(world) {
                let render_clouds = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap();

                let mut test_state = self.state.lock().unwrap();
                let Some(cloud) = render_clouds.get(cloud_handle) else {
                    continue;
                };
                test_state.test_loaded = true;

                // CPU-side copy of the gaussians, moved into the readback callback below.
                let gaussians = cloud.debug_gpu.gaussians.clone();

                println!("radix sort test: {} gaussians", gaussians.len());

                let entry_buffer = &cloud.radix_sort_buffers.entry_buffer_a;

                wgpu::util::DownloadBuffer::read_buffer(
                    render_context.render_device().wgpu_device(),
                    world.get_resource::<RenderQueue>().unwrap().0.as_ref(),
                    &entry_buffer.slice(0..entry_buffer.size()),
                    move |buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
                        let downloaded = buffer.unwrap();
                        let words = bytemuck::cast_slice::<u8, u32>(&*downloaded);

                        // Every second word, starting at offset 1, is read as a gaussian index
                        // (presumably the buffer holds (key, index) pairs; confirm against the sort shader).
                        let mut previous_depth = 0.0;
                        for &gaussian_index in words.iter().skip(1).step_by(2) {
                            let position = gaussians[gaussian_index as usize].position;
                            let position_vec3 = Vec3::new(position[0], position[1], position[2]);
                            let depth = (position_vec3 - camera_position).length();

                            assert!(previous_depth <= depth, "radix sort, non-decreasing check failed: {} > {}", previous_depth, depth);

                            previous_depth = depth;
                        }

                        // After a non-decreasing walk the last depth is also the maximum.
                        assert!(previous_depth > 0.0, "radix sort, max depth check failed: {}", previous_depth);

                        // TODO: analyze incorrectly sorted gaussian positions or upstream buffers (e.g. histogram sort error vs. position of gaussian distance from correctly sorted index)
                    }
                );
            }
        }

        Ok(())
    }
}

0 comments on commit 0db4d38

Please sign in to comment.