gsplat.contrib (experimental)

gsplat experimental / research-grade components.

Warning

Everything under gsplat.contrib is experimental. APIs may change or be removed between releases. Pin a version if you depend on it in production.

Contents:

  • gsplat.contrib.dynamic — deformable / 4D Gaussian Splatting (HexPlane field, MLP deformation network, DynamicStrategy).

Dynamic / 4D Gaussians

HexPlane spatio-temporal feature field (experimental).

Multi-resolution 6-plane decomposition of a 4D (x, y, z, t) feature field, ported from G-SHARP v0.2 (training/scene/hexplane.py), itself derived from the K-Planes / 4DGaussians formulation.

Six 2D feature planes — one for each pair of input axes (xy, xz, xt, yz, yt, zt) — are sampled with bilinear interpolation, multiplied element-wise across planes (giving a per-scale feature vector), and concatenated across multi-resolution scales to produce the final feature.
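The dependency-free sketch below checks this plane layout. Plain Python lists stand in for tensors, and `fuse_scale` and `concat_scales` are hypothetical helper names, not part of gsplat. `itertools.combinations` yields the same xy, xz, xt, yz, yt, zt order listed above. The sketch then shows the per-scale element-wise product and the cross-scale concatenation on toy feature vectors.

```python
from itertools import combinations

AXES = "xyzt"

# Every unordered pair of the four input axes, in lexicographic order.
COMBOS = list(combinations(range(4), 2))
PLANE_NAMES = ["".join(AXES[i] for i in pair) for pair in COMBOS]

# Planes that do not involve t (axis 3) are purely spatial.
SPATIAL_IDX = [k for k, pair in enumerate(COMBOS) if 3 not in pair]
TEMPORAL_IDX = [k for k, pair in enumerate(COMBOS) if 3 in pair]


def fuse_scale(plane_samples: list[list[float]]) -> list[float]:
    """Element-wise product of the six per-plane feature vectors of one scale."""
    fused = [1.0] * len(plane_samples[0])
    for vec in plane_samples:
        fused = [a * b for a, b in zip(fused, vec)]
    return fused


def concat_scales(per_scale: list[list[float]]) -> list[float]:
    """concat_features=True: join the per-scale vectors end to end."""
    return [v for vec in per_scale for v in vec]


print(PLANE_NAMES)                # ['xy', 'xz', 'xt', 'yz', 'yt', 'zt']
print(SPATIAL_IDX, TEMPORAL_IDX)  # [0, 1, 3] [2, 4, 5]

# Two scales, each with six 2-channel plane samples.
feature = concat_scales([fuse_scale([[1.0, 2.0]] * 6), fuse_scale([[1.0, 1.0]] * 6)])
print(feature)                    # [1.0, 64.0, 1.0, 1.0]
```

The final length, 4, is the per-plane channel count (2 here) times the number of scales (2). This is the same output_coordinate_dim * len(multires) rule stated for HexPlaneField.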

class HexPlaneField(bounds: float = 1.6, planes_config: dict | None = None, multires: Sequence[int] | None = None)

Multi-resolution 6-plane decomposition of a 4D feature field.

For each scale, six 2D feature planes covering every pair of the four input axes (x, y, z, t) are sampled bilinearly and multiplied element-wise to produce a per-scale feature vector. With concat_features=True (the default and currently the only supported mode), the feature vectors from all scales are concatenated, so the output feature dimensionality is output_coordinate_dim * len(multires).

Spatial coordinates are normalized into [-1, 1] using the axis-aligned bounding box [-bounds, bounds] along x/y/z. The time coordinate is passed through unchanged, so callers should pre-normalize it themselves: torch.nn.functional.grid_sample() maps [-1, 1] onto the full grid extent, and values outside the grid are clamped via padding_mode="border".
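The coordinate handling described above can be sketched in plain Python. `normalize_xyzt` and `border_clamp` are hypothetical names. The sketch assumes the spatial mapping is the linear map from [-bounds, bounds] onto [-1, 1], which amounts to dividing by bounds. It models padding_mode="border" as clamping a grid coordinate into [-1, 1]; a clamped coordinate reads the same edge value that grid_sample() returns for the out-of-range sample.

```python
def normalize_xyzt(point: tuple[float, float, float, float],
                   bounds: float = 1.6) -> tuple[float, float, float, float]:
    """Map x/y/z from [-bounds, bounds] onto [-1, 1]; pass t through unchanged."""
    x, y, z, t = point
    return (x / bounds, y / bounds, z / bounds, t)


def border_clamp(coord: float) -> float:
    """Model of padding_mode="border": coordinates past the grid edge read the edge."""
    return max(-1.0, min(1.0, coord))


p = normalize_xyzt((1.6, -0.8, 0.0, 0.25))
print(p)                  # (1.0, -0.5, 0.0, 0.25)
print(border_clamp(1.3))  # 1.0: an out-of-range time samples the edge time row
```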

Parameters:
  • bounds – Half-extent of the spatial AABB along each axis (default 1.6, matching G-SHARP).

  • planes_config – Optional dict with keys grid_dimensions, input_coordinate_dim, output_coordinate_dim, resolution. The default is suitable for surgical-scene 4D fields (32 features per plane, [64, 64, 64, 25] resolution).

  • multires – Iterable of integer multipliers applied to the spatial grid resolution for each scale; the temporal resolution is kept fixed across scales. Default (1, 2).
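The sketch below makes the multires rule concrete. It uses plain Python, and `scale_resolutions` and `plane_shapes` are hypothetical helpers. It expands the default [64, 64, 64, 25] resolution for multires=(1, 2) and lists the per-plane tensor shapes for one scale.

The sketch rests on three assumptions:
- The batch dimension is 1.
- Each plane has 32 channels, per the planes_config default.
- The (H, W) = (res[j], res[i]) ordering follows the reversed-order layout described under time_smoothness(). Under that layout, H is the time axis for the xt, yt and zt planes.

```python
from itertools import combinations


def scale_resolutions(base=(64, 64, 64, 25), multires=(1, 2)):
    """Spatial resolutions grow with each multiplier; the temporal one stays fixed."""
    *spatial, temporal = base
    return [[r * m for r in spatial] + [temporal] for m in multires]


def plane_shapes(res, out_dim=32):
    """(B, C, H, W) for each axis pair (i, j): H indexes axis j, W indexes axis i."""
    return [(1, out_dim, res[j], res[i]) for i, j in combinations(range(4), 2)]


scales = scale_resolutions()
print(scales)                  # [[64, 64, 64, 25], [128, 128, 128, 25]]
print(plane_shapes(scales[1]))
feat_dim = 32 * len(scales)    # the value DeformNetwork's feature_dim must match
```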

forward(xyzt: Tensor) → Tensor

Sample the multi-resolution HexPlane at the given 4D points.

Parameters:

xyzt – (N, 4) (or any shape ending in 4), with [..., :3] being world-space spatial coordinates inside the AABB and [..., 3] being the temporal coordinate.

Returns:

(N, feat_dim) feature tensor where feat_dim == output_coordinate_dim * len(multires) under concat_features=True.

spatial_planes() → list[Tensor]

Return the flat list of spatial planes across all multi-res scales.

temporal_planes() → list[Tensor]

Return the flat list of spatio-temporal planes across all scales.

Deformation network and per-Gaussian deformation table (experimental).

Public API:

  • DeformNetwork — MLP that consumes HexPlane features and emits per-Gaussian deltas on (means, quats, opacities) at a given time. Its heads are zero-initialised, so at construction it acts as the identity map on its inputs.

  • DeformationTable — per-Gaussian boolean flag indicating whether each Gaussian is animated by the deform-net. Provides prune() / duplicate() / split() for lock-step resize with DefaultStrategy-style densification.

Port targets:

  • holohub/applications/surgical_scene_recon/training/scene/deformation.py

  • _deformation_table / update_deformation_table_with_tool_masks in holohub/applications/surgical_scene_recon/training/gsplat_train.py

class DeformNetwork(feature_dim: int, hidden_dim: int = 64, num_layers: int = 3)

MLP head emitting per-Gaussian deltas on means / quats / opacities.

Architecture: a num_layers-deep ReLU trunk consuming plane_features (typically the output of gsplat.contrib.dynamic.HexPlaneField), followed by three linear heads: 3-d for the position delta, 4-d for the quaternion delta, and 1-d for the opacity delta. The three heads are zero-initialised, so at construction the forward pass returns (means, quats, opacities) unchanged (an identity map). This behaviour is enforced by test_deform_net_zero_init_is_identity.

The t argument is reserved for future time-aware extensions; the current implementation expects time information to already be encoded into plane_features via HexPlaneField.

Parameters:
  • feature_dim – Dimensionality of plane_features (must match the producing HexPlaneField’s feat_dim).

  • hidden_dim – Trunk width. Default 64.

  • num_layers – Number of Linear + ReLU blocks in the trunk (must be >= 1). Default 3.
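The sketch below shows why zero-initialised heads make the at-construction forward pass an identity map. `ToyDeformNetwork` is a hypothetical, dependency-free stand-in, not the gsplat class. It processes a single Gaussian with plain lists, uses arbitrary random trunk weights, and omits t because the documented implementation ignores it. Only the head shapes (3/4/1) and the num_layers Linear + ReLU trunk come from the description above.

```python
import random


def linear(weights, bias, x):
    """y = W x + b for one input vector; W is a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(x):
    return [max(0.0, v) for v in x]


def add(a, da):
    return [u + v for u, v in zip(a, da)]


class ToyDeformNetwork:
    """Hypothetical stand-in: num_layers Linear+ReLU blocks, then 3/4/1-d heads."""

    def __init__(self, feature_dim, hidden_dim=64, num_layers=3, seed=0):
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1")
        rng = random.Random(seed)
        dims = [feature_dim] + [hidden_dim] * num_layers
        # Trunk weights are arbitrary; only the heads need to be zero.
        self.trunk = [
            ([[rng.gauss(0.0, 0.1) for _ in range(d_in)] for _ in range(d_out)], [0.0] * d_out)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]
        self.heads = {
            name: ([[0.0] * hidden_dim for _ in range(k)], [0.0] * k)
            for name, k in (("means", 3), ("quats", 4), ("opacities", 1))
        }

    def forward(self, means, quats, opacities, plane_features):
        h = plane_features
        for weights, bias in self.trunk:
            h = relu(linear(weights, bias, h))
        delta = {name: linear(w, b, h) for name, (w, b) in self.heads.items()}
        return (add(means, delta["means"]),
                add(quats, delta["quats"]),
                add(opacities, delta["opacities"]))


net = ToyDeformNetwork(feature_dim=8, hidden_dim=16)
print(net.forward([0.1, 0.2, 0.3], [1.0, 0.0, 0.0, 0.0], [0.5], [0.3] * 8))
# ([0.1, 0.2, 0.3], [1.0, 0.0, 0.0, 0.0], [0.5])
```

Zeroing only the heads, rather than the whole network, keeps the trunk able to learn useful features from the first step. The deltas start at exactly zero.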

forward(means: Tensor, quats: Tensor, opacities: Tensor, t: Tensor, plane_features: Tensor) → Tuple[Tensor, Tensor, Tensor]

Apply the per-Gaussian deformation deltas.

Parameters:
  • means – (N, 3) Gaussian centres.

  • quats – (N, 4) rotation quaternions (any layout; deltas are added in the same layout).

  • opacities – (N, 1) opacity values (raw or activated; the delta is added in the same space).

  • t – (N, 1) (or broadcastable) time stamp. Currently ignored; kept in the signature for forward-compatibility with time-aware variants.

  • plane_features – (N, feature_dim) features sampled from the HexPlane field. Must share dtype with the other tensors.

Returns:

(means_new, quats_new, opacities_new) — same shapes as the inputs.

Raises:

ValueError – on batch-dim mismatch, plane-feature-dim mismatch, or dtype mismatch among the four tensors.
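The Raises clause can be read as three separate checks. The sketch below is a hypothetical validator; `check_deform_inputs` is not part of gsplat. It works on (shape, dtype) pairs instead of tensors. It assumes that "the four tensors" means means, quats, opacities and plane_features, since t is currently ignored.

```python
def check_deform_inputs(means, quats, opacities, plane_features, feature_dim):
    """Each argument is a (shape, dtype) pair, e.g. ((N, 3), "float32")."""
    named = {"means": means, "quats": quats, "opacities": opacities,
             "plane_features": plane_features}
    batch_sizes = {name: shape[0] for name, (shape, _) in named.items()}
    if len(set(batch_sizes.values())) != 1:
        raise ValueError(f"batch-dim mismatch: {batch_sizes}")
    if plane_features[0][-1] != feature_dim:
        raise ValueError(
            f"plane_features has {plane_features[0][-1]} channels, expected {feature_dim}")
    dtypes = {name: dtype for name, (_, dtype) in named.items()}
    if len(set(dtypes.values())) != 1:
        raise ValueError(f"dtype mismatch: {dtypes}")


ok = dict(means=((5, 3), "float32"), quats=((5, 4), "float32"),
          opacities=((5, 1), "float32"), plane_features=((5, 64), "float32"))
check_deform_inputs(**ok, feature_dim=64)  # passes silently
```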

class DeformationTable(num_gaussians: int, device: device | None = None)

Per-Gaussian boolean table marking which Gaussians are dynamic.

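Originally used by gsplat.contrib.dynamic.DynamicStrategy to decide which Gaussians get fed through DeformNetwork each step. DynamicStrategy now keeps the canonical mask as a plain tensor in state["dynamic_mask"]; this class remains importable for back-compat and for callers that want its helper API. Resize it in lock-step with the gsplat DefaultStrategy densification ops via prune(), duplicate(), and split().

The table is backed by a plain torch.bool tensor (no autograd, no parameters), so it adds zero overhead to the optimiser state.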

Parameters:
  • num_gaussians – Initial Gaussian count.

  • device – Optional device (defaults to CPU).

duplicate(indices: Tensor) → None

Append duplicates of the given indices (DefaultStrategy duplicate op).

Originals stay; one duplicate is appended per index, inheriting the parent’s dynamic flag.

prune(keep_mask: Tensor) → None

Drop Gaussians where keep_mask is False (DefaultStrategy prune op).

Parameters:

keep_mask – (N,) bool tensor with True for surviving Gaussians. Its length must equal the current table size.

set_indices(indices: Tensor, value: bool = True) → None

Mark the given Gaussian indices as dynamic (or static if value is False).

split(indices: Tensor, factor: int = 2) → None

Replace each index with factor children (DefaultStrategy split op).

Original indices are removed; factor children are appended per split index, each inheriting the parent’s dynamic flag.

Parameters:
  • indices – (S,) indices to split.

  • factor – Children per split. Default 2 (matches the gsplat DefaultStrategy convention).
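The sketch below models the resize semantics of set_indices(), prune(), duplicate() and split(). `ToyDeformationTable` is a hypothetical list-backed stand-in, not the gsplat class. Two details are assumptions because the description above does not specify them:
- The initial flag value.
- The order in which split children are appended. This sketch tiles them block-wise: all first children, then all second children.

```python
class ToyDeformationTable:
    """Hypothetical list-backed model of the DeformationTable resize semantics."""

    def __init__(self, num_gaussians: int, init: bool = False):
        # The real class's initial flag value is not documented; `init` is assumed.
        self.flags = [init] * num_gaussians

    def set_indices(self, indices, value: bool = True) -> None:
        for i in indices:
            self.flags[i] = value

    def prune(self, keep_mask) -> None:
        if len(keep_mask) != len(self.flags):
            raise ValueError("keep_mask length must equal the table size")
        self.flags = [f for f, keep in zip(self.flags, keep_mask) if keep]

    def duplicate(self, indices) -> None:
        # Originals stay; one copy per index is appended with the parent's flag.
        self.flags = self.flags + [self.flags[i] for i in indices]

    def split(self, indices, factor: int = 2) -> None:
        # Originals are removed; `factor` children per index are appended.
        # Block-wise tiling of the children is an assumption about ordering.
        chosen = set(indices)
        survivors = [f for i, f in enumerate(self.flags) if i not in chosen]
        children = [self.flags[i] for i in indices] * factor
        self.flags = survivors + children


table = ToyDeformationTable(4)
table.set_indices([1, 3])   # [False, True, False, True]
table.duplicate([1])        # [False, True, False, True, True]
table.split([0, 1])         # [False, True, True, False, True, False, True]
table.prune([True] * 6 + [False])
print(table.flags)          # [False, True, True, False, True, False]
```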

HexPlane regularizers (experimental).

Three regularizers ported from G-SHARP v0.2’s training/scene/regulation.py. They operate on lists of feature-plane tensors of shape (B, C, H, W) — typically the per-scale gsplat.contrib.dynamic.HexPlaneField planes.

  • plane_smoothness() and time_smoothness() share the same math (sum of squared second-difference along the H axis); the distinction is which planes the caller passes. For HexPlane outputs, the spatial planes [xy, xz, yz] (combo indices [0, 1, 3]) are smoothed spatially, and the spatio-temporal planes [xt, yt, zt] (combo indices [2, 4, 5]) are smoothed in time.

  • time_l1() regularizes spatio-temporal planes toward 1.0 (their initialization), encouraging static regions to stay put — matches L1TimePlanes in the G-SHARP source.

The caller decides which planes go to which regularizer; these functions simply iterate over the input list and sum the per-plane results.

hexplane_regularization(field: HexPlaneField, lambda_plane_smooth: float = 1.0, lambda_time_smooth: float = 1.0, lambda_time_l1: float = 1.0) → Tensor

Convenience wrapper: applies the three HexPlane regularizers using the field’s own spatial / temporal plane accessors.

Use this instead of calling plane_smoothness(), time_smoothness(), time_l1() with hand-partitioned plane lists — the spatial / temporal partition is a property of the HexPlane construction and lives on gsplat.contrib.dynamic.HexPlaneField. Hand-rolled partitions drift when HexPlaneField is refactored.

Parameters:
  • field – A HexPlaneField instance.

  • lambda_plane_smooth – Scalar weight for the spatial smoothness term.

  • lambda_time_smooth – Scalar weight for the temporal smoothness term.

  • lambda_time_l1 – Scalar weight for the temporal L1 deviation term.

Returns:

Scalar tensor — weighted sum of the three regularizers.
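The sketch below shows the partition this wrapper hides. It is plain Python, and `partition_planes` and `hexplane_regularization_sketch` are hypothetical names. It assumes the wrapper pairs terms with planes as the module overview describes: spatial smoothness on xy/xz/yz, and temporal smoothness and time_l1 on xt/yt/zt, flattened across scales the way spatial_planes() and temporal_planes() are. The three regularizers are passed in as callables, so only the partition and the weighting are shown.

```python
SPATIAL_COMBOS = (0, 1, 3)   # xy, xz, yz
TEMPORAL_COMBOS = (2, 4, 5)  # xt, yt, zt


def partition_planes(planes_per_scale):
    """Flatten per-scale six-plane lists into (spatial, temporal) lists."""
    spatial = [scale[k] for scale in planes_per_scale for k in SPATIAL_COMBOS]
    temporal = [scale[k] for scale in planes_per_scale for k in TEMPORAL_COMBOS]
    return spatial, temporal


def hexplane_regularization_sketch(planes_per_scale, plane_smoothness, time_smoothness,
                                   time_l1, lambda_plane_smooth=1.0,
                                   lambda_time_smooth=1.0, lambda_time_l1=1.0):
    spatial, temporal = partition_planes(planes_per_scale)
    return (lambda_plane_smooth * plane_smoothness(spatial)
            + lambda_time_smooth * time_smoothness(temporal)
            + lambda_time_l1 * time_l1(temporal))


per_scale = [["xy", "xz", "xt", "yz", "yt", "zt"]] * 2  # labels stand in for planes
spatial, temporal = partition_planes(per_scale)
print(spatial)   # ['xy', 'xz', 'yz', 'xy', 'xz', 'yz']
print(temporal)  # ['xt', 'yt', 'zt', 'xt', 'yt', 'zt']
```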

plane_smoothness(planes: Sequence[Tensor]) → Tensor

Spatial 2D smoothness — sum of mean squared second-difference along H.

Pass the spatial HexPlane planes (combo indices [0, 1, 3] of a 4D HexPlane: xy, xz, yz) for the spatial-smoothness regularizer described in the G-SHARP proposal.

Parameters:

planes – Iterable of 4D (B, C, H, W) plane tensors.

Returns:

Scalar tensor (sum across planes; mean within each plane).
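The smoothness math can be sketched on nested lists with shape (B, C, H, W). `second_difference_smoothness` is a hypothetical name. The second difference along H, p[h+2] - 2 p[h+1] + p[h], is the first difference of the first difference. Each plane contributes the mean of its squared second differences, and the per-plane values are summed. The sketch assumes H >= 3. Passing spatio-temporal planes gives the time_smoothness() value, since the math is shared.

```python
def second_difference_smoothness(planes):
    """Sum over planes of the mean squared second difference along H.

    Each plane is a nested list indexed [b][c][h][w], i.e. shape (B, C, H, W),
    with H >= 3 so that at least one second difference exists.
    """
    total = 0.0
    for plane in planes:
        squares = []
        for batch in plane:
            for channel in batch:
                for h in range(len(channel) - 2):
                    for w in range(len(channel[h])):
                        d2 = channel[h + 2][w] - 2.0 * channel[h + 1][w] + channel[h][w]
                        squares.append(d2 * d2)
        total += sum(squares) / len(squares)
    return total


linear_in_h = [[[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]]]     # (1, 1, 3, 2)
quadratic_in_h = [[[[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]]]]  # (1, 1, 3, 2)
print(second_difference_smoothness([linear_in_h, quadratic_in_h]))  # 4.0
```

A plane that varies linearly along H costs nothing, so the regularizer penalises curvature rather than slope.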

time_l1(planes: Sequence[Tensor]) → Tensor

L1 deviation from 1.0 on spatio-temporal planes.

HexPlane spatio-temporal planes are initialised to 1.0 so deformation starts identity-like. time_l1 penalises their deviation from 1.0, encouraging static tissue to keep an identity (no-time-deformation) response. Matches G-SHARP’s L1TimePlanes regularizer.

Parameters:

planes – Iterable of plane tensors (any shape with mean()).

Returns:

Scalar L1 deviation summed across planes (mean within each plane).
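A plain-Python sketch of the same reduction follows; `time_l1_sketch` is a hypothetical name and flattens each plane into a list of values. Each plane contributes its mean absolute deviation from 1.0, and the contributions are summed, so a freshly initialised plane costs nothing.

```python
def time_l1_sketch(planes):
    """Sum over planes of mean |1 - value|; each plane is a flat list of values."""
    return sum(sum(abs(1.0 - v) for v in plane) / len(plane) for plane in planes)


untouched = [1.0] * 4           # freshly initialised plane: no penalty
drifted = [0.5, 1.5, 1.0, 1.0]  # mean |1 - v| = (0.5 + 0.5) / 4
print(time_l1_sketch([untouched, drifted]))  # 0.25
```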

time_smoothness(planes: Sequence[Tensor]) → Tensor

Temporal smoothness — same math as plane_smoothness().

Pass the spatio-temporal HexPlane planes (combo indices [2, 4, 5] of a 4D HexPlane: xt, yt, zt). For these planes the H axis is the temporal axis (per the HexPlaneField._init_grid_param() reversed-order layout), so the squared second difference along H measures per-plane temporal smoothness.

Parameters:

planes – Iterable of 4D (B, C, H, W) spatio-temporal plane tensors.

Returns:

Scalar tensor.

DynamicStrategy — deformable extension of DefaultStrategy.

Public API:

  • DynamicStrategy — subclass of gsplat.strategy.DefaultStrategy that additionally tracks a per-Gaussian dynamic mask (state["dynamic_mask"]), resized in lock-step with each densify / prune op, so that subsequent forward passes can route only the dynamic Gaussians through the deform-net.

Design notes:

  • gsplat.rasterization has no time axis. The deformation pass itself (HexPlane → DeformNetwork(means, quats, opacities)) lives in the trainer (examples/dynamic_surgical_trainer.py) and runs before rasterization(...) is called, mirroring G-SHARP’s rasterize_splats. This strategy class owns only the densification policy and the dynamic-mask bookkeeping; it does not apply the deform-net itself.

  • The @dataclass decorator is intentionally not applied here: no new fields are added with defaults, so we inherit the parent’s auto-generated __init__ cleanly without dataclass-field ordering errors.
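Because rasterization has no time axis, the trainer deforms the Gaussians before rasterizing them. The sketch below shows one way a trainer could use the dynamic mask to route only flagged Gaussians through the deformation. `deform_dynamic` and `shift_x` are hypothetical, and the sketch does not reflect the actual contents of examples/dynamic_surgical_trainer.py.

```python
def deform_dynamic(means, dynamic_mask, deform_fn):
    """Deform only the flagged Gaussians; static ones pass through unchanged."""
    if len(means) != len(dynamic_mask):
        raise ValueError("dynamic_mask needs exactly one flag per Gaussian")
    return [deform_fn(m) if flag else m for m, flag in zip(means, dynamic_mask)]


def shift_x(m):
    # Toy deformation standing in for HexPlane + DeformNetwork at some time t.
    return (m[0] + 0.5, m[1], m[2])


means = [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (2.0, 0.0, 0.0)]
deformed = deform_dynamic(means, [True, False, True], shift_x)
print(deformed)  # [(0.5, 0.0, 0.0), (1.0, 1.0, 1.0), (2.5, 0.0, 0.0)]
# A rasterization(...) call would then consume `deformed` instead of `means`.
```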

class DynamicStrategy(prune_opa: float = 0.005, grow_grad2d: float = 0.0002, grow_scale3d: float = 0.01, grow_scale2d: float = 0.05, prune_scale3d: float = 0.1, prune_scale2d: float = 0.15, refine_scale2d_stop_iter: int = 0, refine_start_iter: int = 500, refine_stop_iter: int = 15000, reset_every: int = 3000, refine_every: int = 100, pause_refine_after_reset: int = 0, absgrad: bool = False, revised_opacity: bool = False, verbose: bool = False, key_for_gradient: Literal['means2d', 'gradient_2dgs'] = 'means2d')

Deformable-aware densification / pruning strategy.

Extra invariants on top of DefaultStrategy:

  • state["dynamic_mask"] is a per-Gaussian torch.bool tensor of shape (num_gaussians,) that flags which Gaussians the trainer should route through the DeformNet. Resized in lock-step with params["means"] by gsplat’s strategy ops (which iterate every tensor in state and apply the per-Gaussian split / duplicate / prune permutation — see gsplat/strategy/ops.py:135, 191-195, 223-225). Identity is preserved across split (children inherit the parent’s flag).

Note that the HexPlane and DeformNet trainables are not part of params. gsplat’s densification ops blindly iterate every entry in params and split/duplicate/prune them per-Gaussian; non-per-Gaussian tensors (HexPlane plane grids, DeformNet MLP weights) would be indexed with out-of-bounds per-Gaussian indices. Keep those trainables in their own optimizers, wired separately by the trainer (see examples/dynamic_surgical_trainer.py:build_deform_modules).

Historical note: an earlier version of this strategy stored a DeformationTable wrapper at state["deformation_table"] and resized it via a custom _resize_table hook. That hook did not preserve survivor identity across split, and the trainer never consulted the mask anyway. Wiring it through state as a plain tensor (so gsplat’s ops do the right thing) closes both gaps. The wrapper class is still importable for back-compat; the canonical mask is the tensor in state.
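The historical note above explains why a plain tensor in state is resized correctly while a wrapper object is not. The sketch below illustrates this; `LegacyTable` and `duplicate_in_state` are hypothetical names. Python lists stand in for tensors, and only the duplicate op is modelled. The sketch relies only on the behaviour stated above: gsplat's ops iterate every tensor in state and apply the same per-Gaussian edit to each.

```python
class LegacyTable:
    """Stand-in for a non-tensor wrapper object kept in state."""

    def __init__(self, flags):
        self.flags = list(flags)


def duplicate_in_state(state, sel):
    """Model of a densify op: apply the same per-Gaussian edit to every list entry.

    Lists stand in for tensors; any other object is skipped and left untouched.
    """
    for key, value in state.items():
        if isinstance(value, list):
            state[key] = value + [value[i] for i in sel]


state = {
    "dynamic_mask": [True, False],
    "grad2d": [0.1, 0.2],
    "deformation_table": LegacyTable([True, False]),
}
duplicate_in_state(state, [1])
print(state["dynamic_mask"])                  # [True, False, False]
print(len(state["deformation_table"].flags))  # 2: the wrapper went stale
```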

check_sanity(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer]) → None

Sanity-check identical to DefaultStrategy.check_sanity().

The HexPlane / DeformNet trainables live outside params (see the class docstring), so this method has no extra requirements beyond the parent’s “params and optimizers share keys, and per-Gaussian keys means/scales/quats/opacities are present” check.

initialize_state(scene_scale: float = 1.0, num_gaussians: int = 0, device: device | None = None, init_dynamic: bool = True) → Dict[str, Any]

Extend DefaultStrategy.initialize_state() with a per-Gaussian dynamic mask.

The mask is stored under state["dynamic_mask"] as a plain bool tensor (shape (num_gaussians,)). It is not wrapped in a DeformationTable so that gsplat’s densification ops in gsplat.strategy.ops (which iterate every tensor in state and apply the per-Gaussian split / duplicate / prune permutation automatically — see ops.py:135, 191-195, 223-225) can resize it in lock-step with params["means"] with identity preservation across split. The wrapper class is still exposed for callers that want the helper API (see DeformationTable), but the canonical mask now lives in state.

Parameters:
  • scene_scale – Forwarded to the parent.

  • num_gaussians – Initial Gaussian count.

  • device – Device for the mask tensor (defaults to CPU).

  • init_dynamic – Initial value for every flag. True matches the current trainer behaviour (every Gaussian goes through DeformNetwork); set False if you have a static-by-default workflow and intend to flip dynamic indices manually.

Returns:

The strategy state dict with the additional dynamic_mask entry (a torch.bool tensor of shape (num_gaussians,)).

step_post_backward(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer], state: Dict[str, Any], step: int, info: Dict[str, Any], packed: bool = False) → None

Post-backward hook — defers to the parent.

state["dynamic_mask"] is resized in lock-step with the per-Gaussian parameters automatically by gsplat’s densification ops (which iterate every tensor in state and apply the same per-Gaussian permutation as the params — see gsplat/strategy/ops.py:135). No extra hook needed here.

Raises:

RuntimeError – if state["dynamic_mask"] is missing (i.e. initialize_state() wasn’t called first).

step_pre_backward(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer], state: Dict[str, Any], step: int, info: Dict[str, Any]) → None

Pre-backward hook — passthrough to the parent.