Skip to content

Commit

Permalink
refactor: gaussian cloud transform component
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed Dec 18, 2024
1 parent 12c6d52 commit 4a9434f
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 48 deletions.
2 changes: 1 addition & 1 deletion examples/multi_camera.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,12 @@ pub fn setup_surfel_compare(
}

commands.spawn((
Transform::from_translation(Vec3::new(spacing, spacing, 0.0)),
GaussianCloudHandle(
gaussian_assets.add(GaussianCloud::from_gaussians(red_gaussians))
),
GaussianCloudSettings {
aabb: true,
transform: Transform::from_translation(Vec3::new(spacing, spacing, 0.0)),
gaussian_mode: GaussianMode::GaussianSurfel,
..default()
},
Expand Down
16 changes: 12 additions & 4 deletions src/gaussian/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use bevy::{
prelude::*,
render::{
primitives::Aabb,
sync_world::SyncToRenderWorld,
view::visibility::{
check_visibility,
NoFrustumCulling,
Expand Down Expand Up @@ -85,7 +86,6 @@ pub fn calculate_bounds(
for (entity, cloud_handle) in &without_aabb {
if let Some(cloud) = gaussian_clouds.get(cloud_handle) {
if let Some(aabb) = cloud.compute_aabb() {
info!("Inserting aabb {:?}", aabb);
commands.entity(entity).try_insert(aabb);
}
}
Expand All @@ -102,7 +102,12 @@ pub fn calculate_bounds(
Reflect,
)]
#[reflect(Component, Default)]
#[require(GaussianCloudSettings, Transform, Visibility)]
#[require(
GaussianCloudSettings,
SyncToRenderWorld,
Transform,
Visibility,
)]
pub struct GaussianCloudHandle(pub Handle<GaussianCloud>);

impl From<Handle<GaussianCloud>> for GaussianCloudHandle {
Expand Down Expand Up @@ -225,9 +230,12 @@ impl GaussianCloud {
let mut min = Vec3::splat(f32::INFINITY);
let mut max = Vec3::splat(f32::NEG_INFINITY);

// TODO: find a more correct aabb bound derived from scalar max gaussian scale
let max_scale = 0.1;

for position in self.position_iter() {
min = min.min(Vec3::from(*position));
max = max.max(Vec3::from(*position));
min = min.min(Vec3::from(*position) - Vec3::splat(max_scale));
max = max.max(Vec3::from(*position) + Vec3::splat(max_scale));
}

Aabb::from_min_max(min, max).into()
Expand Down
3 changes: 0 additions & 3 deletions src/gaussian/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ pub struct GaussianCloudSettings {
pub aabb: bool,
pub global_opacity: f32,
pub global_scale: f32,
// TODO: move transform to GaussianCloudHandle entity
pub transform: Transform,
pub opacity_adaptive_radius: bool,
pub visualize_bounding_box: bool,
pub sort_mode: SortMode,
Expand All @@ -87,7 +85,6 @@ impl Default for GaussianCloudSettings {
aabb: false,
global_opacity: 1.0,
global_scale: 1.0,
transform: Transform::IDENTITY,
opacity_adaptive_radius: true,
visualize_bounding_box: false,
sort_mode: SortMode::default(),
Expand Down
57 changes: 19 additions & 38 deletions src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ use bevy::{
Render,
RenderApp,
RenderSet,
sync_world::RenderEntity,
},
};

Expand Down Expand Up @@ -184,8 +185,8 @@ impl Plugin for RenderPipelinePlugin {
.add_systems(
Render,
(
queue_gaussian_bind_group.in_set(RenderSet::Queue),
queue_gaussian_view_bind_groups.in_set(RenderSet::Queue),
queue_gaussian_bind_group.in_set(RenderSet::PrepareBindGroups),
queue_gaussian_view_bind_groups.in_set(RenderSet::PrepareBindGroups),
queue_gaussians.in_set(RenderSet::Queue),
),
);
Expand Down Expand Up @@ -305,21 +306,16 @@ fn queue_gaussians(
>,
gaussian_splatting_bundles: Query<GpuGaussianBundleQuery>,
) {
info!("\nqueueing gaussians");
let warmup = views.iter().any(|(_, _, camera, _, _)| camera.warmup);
if warmup {
return;
}

info!("warmup complete");

// TODO: condition this system based on GaussianCloudBindGroup attachment
if gaussian_cloud_uniform.buffer().is_none() {
return;
};

info!("uniform buffer found");

let draw_custom = transparent_3d_draw_functions.read().id::<DrawGaussians>();

for (
Expand All @@ -329,35 +325,25 @@ fn queue_gaussians(
visible_entities,
msaa,
) in &mut views {
info!("view entity found");

let Some(transparent_phase) = transparent_render_phases.get_mut(&view_entity) else {
continue;
};

info!("transparent phase found");

for (
_entity,
cloud_handle,
sorted_entries_handle,
settings,
_,
) in &gaussian_splatting_bundles {
info!("gaussian entity found");

if gaussian_clouds.get(cloud_handle).is_none() {
return;
}

info!("queue - gaussian cloud asset found");

if sorted_entries.get(sorted_entries_handle).is_none() {
return;
}

info!("sorted entries asset found");

let msaa = msaa.cloned().unwrap_or_default();

let key = GaussianCloudPipelineKey {
Expand All @@ -377,8 +363,6 @@ fn queue_gaussians(
// let rangefinder = view.rangefinder3d();

for (render_entity, visible_entity) in visible_entities.iter::<With<GaussianCloudHandle>>() {
info!("adding transparent phase item");

transparent_phase.add(Transparent3d {
entity: (*render_entity, *visible_entity),
draw_function: draw_custom,
Expand Down Expand Up @@ -738,7 +722,7 @@ type DrawGaussians = (
);


#[derive(Component, ShaderType, Clone)]
#[derive(Component, ShaderType, Clone, Copy)]
pub struct GaussianCloudUniform {
pub transform: Mat4,
pub global_opacity: f32,
Expand All @@ -755,15 +739,15 @@ pub fn extract_gaussians(
gaussian_cloud_res: Res<RenderAssets<GpuGaussianCloud>>,
gaussians_query: Extract<
Query<(
Entity,
RenderEntity,
&ViewVisibility,
&GaussianCloudHandle,
&SortedEntriesHandle,
&GaussianCloudSettings,
&GlobalTransform,
)>,
>,
) {
info!("\nextracting gaussians");
let mut commands_list = Vec::with_capacity(*prev_commands_len);
// let visible_gaussians = gaussians_query.iter().filter(|(_, vis, ..)| vis.is_visible());

Expand All @@ -773,38 +757,32 @@ pub fn extract_gaussians(
cloud_handle,
sorted_entries,
settings,
transform,
) in gaussians_query.iter() {
info!("gaussian entity found");

// if !visibility.get() {
// continue;
// }

info!("gaussian entity visible");
if !visibility.get() {
continue;
}

if let Some(load_state) = asset_server.get_load_state(&cloud_handle.0) {
if load_state.is_loading() {
continue;
}
}

info!("gaussian cloud asset loaded");

if gaussian_cloud_res.get(cloud_handle).is_none() {
continue;
}

info!("gaussian cloud asset found");

let cloud = gaussian_cloud_res.get(cloud_handle).unwrap();

let settings_uniform = GaussianCloudUniform {
transform: settings.transform.compute_matrix(),
transform: transform.compute_matrix(),
global_opacity: settings.global_opacity,
global_scale: settings.global_scale,
count: cloud.count as u32,
count_root_ceil: (cloud.count as f32).sqrt().ceil() as u32,
};

commands_list.push((
entity,
GpuGaussianSplattingBundle {
Expand Down Expand Up @@ -961,20 +939,17 @@ pub fn queue_gaussian_view_bind_groups(
>,
globals_buffer: Res<GlobalsBuffer>,
) {
info!("\nqueueing gaussian view bind groups");
if let (
Some(view_binding),
Some(globals),
) = (
view_uniforms.uniforms.binding(),
globals_buffer.buffer.binding(),
) {
info!("view binding and globals found");
for (
entity,
_extracted_view,
) in &views {
info!("view entity found");
let layout = &gaussian_cloud_pipeline.view_layout;

let entries = vec![
Expand Down Expand Up @@ -1045,14 +1020,20 @@ impl<P: PhaseItem, const I: usize> RenderCommand<P> for SetGaussianUniformBindGr
fn render<'w>(
_item: &P,
_view: (),
gaussian_cloud_index: Option<ROQueryItem<Self::ItemQuery>>,
gaussian_cloud_index: Option<ROQueryItem<'w, Self::ItemQuery>>,
bind_groups: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) -> RenderCommandResult {
let bind_groups = bind_groups.into_inner();
let bind_group = bind_groups.base_bind_group.as_ref().expect("bind group not initialized");

let mut set_bind_group = |indices: &[u32]| pass.set_bind_group(I, bind_group, indices);

if gaussian_cloud_index.is_none() {
info!("skipping gaussian uniform bind group\n");
return RenderCommandResult::Skip;
}

let gaussian_cloud_index = gaussian_cloud_index.unwrap().index();
set_bind_group(&[gaussian_cloud_index]);

Expand Down
4 changes: 3 additions & 1 deletion src/sort/rayon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub fn rayon_sort(
&GaussianCloudHandle,
&SortedEntriesHandle,
&GaussianCloudSettings,
&GlobalTransform,
)>,
mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
mut cameras: Query<
Expand All @@ -59,6 +60,7 @@ pub fn rayon_sort(
gaussian_cloud_handle,
sorted_entries_handle,
settings,
transform,
) in gaussian_clouds.iter() {
if settings.sort_mode != SortMode::Rayon {
continue;
Expand Down Expand Up @@ -90,7 +92,7 @@ pub fn rayon_sort(
.enumerate()
.for_each(|(idx, (position, sort_entry))| {
let position = Vec3A::from_slice(position.as_ref());
let position = settings.transform.compute_affine().transform_point3a(position);
let position = transform.affine().transform_point3a(position);

let delta = trigger.last_camera_position - position;

Expand Down
4 changes: 3 additions & 1 deletion src/sort/std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub fn std_sort(
&GaussianCloudHandle,
&SortedEntriesHandle,
&GaussianCloudSettings,
&GlobalTransform,
)>,
mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
mut cameras: Query<
Expand All @@ -59,6 +60,7 @@ pub fn std_sort(
gaussian_cloud_handle,
sorted_entries_handle,
settings,
transform,
) in gaussian_clouds.iter() {
if settings.sort_mode != SortMode::Std {
continue;
Expand Down Expand Up @@ -90,7 +92,7 @@ pub fn std_sort(
.enumerate()
.for_each(|(idx, (position, sort_entry))| {
let position = Vec3A::from_slice(position.as_ref());
let position = settings.transform.compute_affine().transform_point3a(position);
let position = transform.affine().transform_point3a(position);

let delta = trigger.last_camera_position - position;

Expand Down

0 comments on commit 4a9434f

Please sign in to comment.