gsplat.contrib (experimental)
gsplat experimental / research-grade components.
Warning
Everything under gsplat.contrib is experimental. APIs may change
or be removed between releases. Pin a version if you depend on it in
production.
Contents:
gsplat.contrib.dynamic: deformable / 4D Gaussian Splatting (HexPlane field, MLP deformation network, DynamicStrategy).
Dynamic / 4D Gaussians
HexPlane spatio-temporal feature field (experimental).
Multi-resolution 6-plane decomposition of a 4D (x, y, z, t) feature
field, ported from G-SHARP v0.2 (training/scene/hexplane.py), itself
derived from the K-Planes / 4DGaussians formulation.
Six 2D feature planes — one for each pair of input axes (xy, xz, xt,
yz, yt, zt) — are sampled with bilinear interpolation, multiplied
element-wise across planes (giving a per-scale feature vector), and
concatenated across multi-resolution scales to produce the final feature.
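The sampling scheme above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the gsplat implementation: the function name sample_hexplane, the plane storage (one (1, C, H, W) tensor per axis pair, ordered as itertools.combinations over (x, y, z, t)), align_corners=True, and inputs that are already normalized into [-1, 1] are all illustrative choices.

```python
import itertools

import torch
import torch.nn.functional as F

# Axis pairs over (x, y, z, t) in itertools.combinations order:
# (0, 1)=xy, (0, 2)=xz, (0, 3)=xt, (1, 2)=yz, (1, 3)=yt, (2, 3)=zt.
AXIS_PAIRS = list(itertools.combinations(range(4), 2))


def sample_hexplane(planes_per_scale, xyzt):
    """Sketch of multi-resolution HexPlane sampling (hypothetical helper, not gsplat's code).

    planes_per_scale: one list of six (1, C, H, W) planes per scale, ordered as AXIS_PAIRS.
    xyzt: (N, 4) points already normalized into [-1, 1].
    Returns an (N, C * num_scales) tensor.
    """
    per_scale = []
    for planes in planes_per_scale:
        feat = None
        for (a, b), plane in zip(AXIS_PAIRS, planes):
            # grid_sample reads grid[..., 0] along W and grid[..., 1] along H,
            # so axis a indexes the plane width and axis b its height.
            grid = xyzt[:, [a, b]].reshape(1, -1, 1, 2)
            sampled = F.grid_sample(plane, grid, mode="bilinear",
                                    padding_mode="border", align_corners=True)
            sampled = sampled.reshape(plane.shape[1], -1).t()  # (1, C, N, 1) -> (N, C)
            # Element-wise product across the six planes of this scale.
            feat = sampled if feat is None else feat * sampled
        per_scale.append(feat)
    # concat_features=True behaviour: concatenate the per-scale vectors.
    return torch.cat(per_scale, dim=-1)
```

With 32 features per plane and two scales, this sketch returns 64-wide features, one 32-wide block per scale.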
- class HexPlaneField(bounds: float = 1.6, planes_config: dict | None = None, multires: Sequence[int] | None = None)
Multi-resolution 6-plane decomposition of a 4D feature field.
For each scale, six 2D feature planes covering every pair of the four input axes (x, y, z, t) are sampled bilinearly and multiplied element-wise to produce a per-scale feature vector. With concat_features=True (the default and only mode currently supported), the feature vectors from all scales are concatenated; the resulting feature dimensionality is output_coordinate_dim * len(multires).
Spatial coordinates are normalized into [-1, 1] using the axis-aligned bounding box [-bounds, bounds] along x/y/z; the time coordinate is passed through unchanged. Callers should therefore pre-normalize time into the [-1, 1] range that torch.nn.functional.grid_sample() samples; values outside that range are clamped via padding_mode="border".
- Parameters:
bounds – Half-extent of the spatial AABB along each axis (default 1.6, matching G-SHARP).
planes_config – Optional dict with keys grid_dimensions, input_coordinate_dim, output_coordinate_dim, resolution. The default is suitable for surgical-scene 4D fields (32 features per plane, [64, 64, 64, 25] resolution).
multires – Iterable of integer multipliers applied to the spatial grid resolution for each scale; the temporal resolution is kept fixed across scales. Default (1, 2).
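The preprocessing rules above can be sketched as follows. The helper names normalize_xyzt, scale_resolutions, and feat_dim are hypothetical, and the exact normalization formula is an assumption: the standard AABB mapping (p - min) * 2 / (max - min) - 1, which reduces to p / bounds for a symmetric box. Defaults mirror the documented values (bounds=1.6, resolution [64, 64, 64, 25], multires (1, 2), 32 features per plane).

```python
import torch


def normalize_xyzt(xyzt: torch.Tensor, bounds: float = 1.6) -> torch.Tensor:
    """Map x/y/z from [-bounds, bounds] onto [-1, 1]; pass t through unchanged."""
    aabb_min, aabb_max = -bounds, bounds
    xyz = (xyzt[..., :3] - aabb_min) * (2.0 / (aabb_max - aabb_min)) - 1.0
    return torch.cat([xyz, xyzt[..., 3:]], dim=-1)


def scale_resolutions(base=(64, 64, 64, 25), multires=(1, 2)):
    """Per-scale grid resolution: spatial axes multiplied, temporal axis kept fixed."""
    return [[r * m for r in base[:3]] + [base[3]] for m in multires]


def feat_dim(output_coordinate_dim: int = 32, multires=(1, 2)) -> int:
    # concat_features=True: one output_coordinate_dim-wide vector per scale.
    return output_coordinate_dim * len(multires)
```

Under these assumptions the defaults give per-scale resolutions [64, 64, 64, 25] and [128, 128, 128, 25], and a 64-dimensional output feature.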
- forward(xyzt: Tensor) → Tensor
Sample the multi-resolution HexPlane at the given 4D points.
- Parameters:
xyzt – (N, 4) tensor (or any shape ending in 4) with [..., :3] holding world-space spatial coordinates inside the AABB and [..., 3] holding the temporal coordinate.
- Returns:
(N, feat_dim) feature tensor, where feat_dim == output_coordinate_dim * len(multires) under concat_features=True.
Deformation network and per-Gaussian deformation table (experimental).
Public API:
DeformNetwork: MLP that consumes HexPlane features and emits per-Gaussian deltas on (means, quats, opacities) at a given time. Its heads are zero-initialised, so at construction the network is the identity map on its inputs.
DeformationTable: per-Gaussian boolean flag indicating whether each Gaussian is animated by the deform-net. Provides prune() / duplicate() / split() for lock-step resizing with DefaultStrategy-style densification.
Port targets:
holohub/applications/surgical_scene_recon/training/scene/deformation.py
_deformation_table / update_deformation_table_with_tool_masks in holohub/applications/surgical_scene_recon/training/gsplat_train.py
- class DeformNetwork(feature_dim: int, hidden_dim: int = 64, num_layers: int = 3)
MLP head emitting per-Gaussian deltas on means / quats / opacities.
Architecture: a num_layers-deep ReLU trunk consuming plane_features (typically the output of gsplat.contrib.dynamic.HexPlaneField), followed by three linear heads: 3-d for the position delta, 4-d for the quaternion delta, and 1-d for the opacity delta. The three heads are zero-initialised, so the at-construction forward pass returns (means, quats, opacities) unchanged (identity map). This behaviour is locked by the test test_deform_net_zero_init_is_identity.
The t argument is reserved for future time-aware extensions; the current implementation expects time information to already be encoded into plane_features via HexPlaneField.
- Parameters:
feature_dim – Dimensionality of plane_features (must match the producing HexPlaneField's feat_dim).
hidden_dim – Trunk width. Default 64.
num_layers – Number of Linear + ReLU blocks in the trunk (must be >= 1). Default 3.
- forward(means: Tensor, quats: Tensor, opacities: Tensor, t: Tensor, plane_features: Tensor) → Tuple[Tensor, Tensor, Tensor]
Apply the per-Gaussian deformation deltas.
- Parameters:
means – (N, 3) Gaussian centres.
quats – (N, 4) rotation quaternions (any layout; deltas are added in the same layout).
opacities – (N, 1) opacity values (raw or activated; the delta is added in the same space).
t – (N, 1) (or broadcastable) time stamp. Currently ignored; kept in the signature for forward-compatibility with time-aware variants.
plane_features – (N, feature_dim) features sampled from the HexPlane field. Must share dtype with the other tensors.
- Returns:
(means_new, quats_new, opacities_new), with the same shapes as the inputs.
- Raises:
ValueError – on batch-dim mismatch, plane-feature-dim mismatch, or dtype mismatch among the four tensors.
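A minimal PyTorch re-sketch of the architecture and validation rules described above, useful for seeing why zero-initialised heads make the first forward pass an identity map. The class name TinyDeformNet is hypothetical, and the layer layout (num_layers Linear + ReLU blocks followed by three zero-initialised nn.Linear heads) is inferred from the prose rather than copied from gsplat.

```python
from typing import Tuple

import torch
import torch.nn as nn


class TinyDeformNet(nn.Module):
    """Hypothetical sketch of the documented DeformNetwork architecture."""

    def __init__(self, feature_dim: int, hidden_dim: int = 64, num_layers: int = 3):
        super().__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1")
        layers, in_dim = [], feature_dim
        for _ in range(num_layers):  # ReLU trunk: num_layers Linear + ReLU blocks
            layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            in_dim = hidden_dim
        self.trunk = nn.Sequential(*layers)
        self.feature_dim = feature_dim
        self.head_means = nn.Linear(hidden_dim, 3)  # position delta
        self.head_quats = nn.Linear(hidden_dim, 4)  # quaternion delta
        self.head_opac = nn.Linear(hidden_dim, 1)   # opacity delta
        for head in (self.head_means, self.head_quats, self.head_opac):
            nn.init.zeros_(head.weight)  # zero heads => all deltas are exactly 0
            nn.init.zeros_(head.bias)

    def forward(self, means, quats, opacities, t, plane_features) -> Tuple[torch.Tensor, ...]:
        checked = (means, quats, opacities, plane_features)
        if any(x.shape[0] != means.shape[0] for x in checked):
            raise ValueError("batch-dim mismatch")
        if plane_features.shape[-1] != self.feature_dim:
            raise ValueError("plane-feature-dim mismatch")
        if any(x.dtype != means.dtype for x in checked):
            raise ValueError("dtype mismatch")
        h = self.trunk(plane_features)  # t is ignored; time is assumed encoded in the features
        return (means + self.head_means(h),
                quats + self.head_quats(h),
                opacities + self.head_opac(h))
```

Because every head starts at zero, gradients still reach the head weights through the trunk activations, so the network can move away from the identity as soon as training begins.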
- class DeformationTable(num_gaussians: int, device: device | None = None)
Per-Gaussian boolean table marking which Gaussians are dynamic.
Used by gsplat.contrib.dynamic.DynamicStrategy to decide which Gaussians get fed through DeformNetwork each step. Resize it in lock-step with the gsplat DefaultStrategy densification ops via prune(), duplicate(), and split().
The table is a plain torch.bool tensor (no autograd, no parameters), so it adds zero overhead to the optimiser state.
- Parameters:
num_gaussians – Initial Gaussian count.
device – Optional device (defaults to CPU).
- duplicate(indices: Tensor) → None
Append duplicates of the given indices (DefaultStrategy duplicate op).
Originals stay; one duplicate is appended per index, inheriting the parent’s dynamic flag.
- prune(keep_mask: Tensor) → None
Drop Gaussians where keep_mask is False (DefaultStrategy prune op).
- Parameters:
keep_mask – (N,) bool tensor with True for surviving Gaussians. Its length must equal the current table size.
- set_indices(indices: Tensor, value: bool = True) → None
Mark the given Gaussian indices as dynamic (or static if value is False).
- split(indices: Tensor, factor: int = 2) → None
Replace each index with factor children (DefaultStrategy split op).
Original indices are removed; factor children are appended per split index, each inheriting the parent's dynamic flag.
- Parameters:
indices – (S,) indices to split.
factor – Children per split. Default 2 (matches the gsplat DefaultStrategy convention).
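The four table operations above can be illustrated with a plain boolean tensor. MaskTable is a hypothetical stand-in, not the real DeformationTable: its init_value argument is an assumption (the documented constructor has no such argument), and placing split children as tiled copies at the end (via torch.Tensor.repeat) is an assumed ordering.

```python
import torch


class MaskTable:
    """Hypothetical stand-in mirroring the documented DeformationTable semantics."""

    def __init__(self, num_gaussians: int, device=None, init_value: bool = False):
        # init_value is an assumption made for this sketch only.
        self.mask = torch.full((num_gaussians,), init_value, dtype=torch.bool, device=device)

    def set_indices(self, indices: torch.Tensor, value: bool = True) -> None:
        self.mask[indices] = value

    def prune(self, keep_mask: torch.Tensor) -> None:
        if keep_mask.shape[0] != self.mask.shape[0]:
            raise ValueError("keep_mask length must equal the current table size")
        self.mask = self.mask[keep_mask]

    def duplicate(self, indices: torch.Tensor) -> None:
        # Originals stay; one copy per index is appended with the parent's flag.
        self.mask = torch.cat([self.mask, self.mask[indices]])

    def split(self, indices: torch.Tensor, factor: int = 2) -> None:
        # Parents are removed; `factor` children per parent are appended.
        rest = torch.ones_like(self.mask)
        rest[indices] = False
        children = self.mask[indices].repeat(factor)
        self.mask = torch.cat([self.mask[rest], children])
```

For example, splitting indices [0, 1] of the table [False, True, True, True] keeps the two survivors and appends the children [False, True, False, True], giving six entries in total.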
HexPlane regularizers (experimental).
Three regularizers ported from G-SHARP v0.2’s
training/scene/regulation.py. They operate on lists of feature-plane
tensors of shape (B, C, H, W) — typically the per-scale
gsplat.contrib.dynamic.HexPlaneField planes.
plane_smoothness() and time_smoothness() share the same math (the mean squared second difference along the H axis, summed across planes); the distinction is which planes the caller passes. For HexPlane outputs, the spatial planes [xy, xz, yz] (combo indices [0, 1, 3]) are smoothed spatially, and the spatio-temporal planes [xt, yt, zt] (combo indices [2, 4, 5]) are smoothed in time. time_l1() regularizes the spatio-temporal planes toward 1.0 (their initialization), encouraging static regions to stay put; this matches L1TimePlanes in the G-SHARP source.
Caller selects which planes go where; these functions just iterate the input list and sum the result.
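The shared math can be written out in a few lines. This sketch re-derives the documented formulas (mean squared second finite difference along H, and mean absolute deviation from 1.0, each summed across planes); the function names second_diff_smoothness and l1_from_one are hypothetical and deliberately differ from the gsplat functions.

```python
from typing import Sequence

import torch


def _second_diff_sq_mean(plane: torch.Tensor) -> torch.Tensor:
    # First and second finite differences along H (dim -2) of a (B, C, H, W) plane.
    first = plane[..., 1:, :] - plane[..., :-1, :]
    second = first[..., 1:, :] - first[..., :-1, :]
    return second.square().mean()


def second_diff_smoothness(planes: Sequence[torch.Tensor]) -> torch.Tensor:
    """Same formula for spatial and temporal smoothness; only the plane list differs."""
    return sum((_second_diff_sq_mean(p) for p in planes), torch.zeros(()))


def l1_from_one(planes: Sequence[torch.Tensor]) -> torch.Tensor:
    """Mean |1 - plane| per plane, summed: zero when planes sit at their 1.0 init."""
    return sum(((1.0 - p).abs().mean() for p in planes), torch.zeros(()))
```

A plane that varies linearly along H scores exactly zero, since its second difference vanishes; a plane with values h**2 along H scores 4, because its second difference is the constant 2.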
- hexplane_regularization(field: HexPlaneField, lambda_plane_smooth: float = 1.0, lambda_time_smooth: float = 1.0, lambda_time_l1: float = 1.0) → Tensor
Convenience wrapper: applies the three HexPlane regularizers using the field’s own spatial / temporal plane accessors.
Use this instead of calling plane_smoothness(), time_smoothness(), and time_l1() with hand-partitioned plane lists: the spatial / temporal partition is a property of the HexPlane construction and lives on gsplat.contrib.dynamic.HexPlaneField. Hand-rolled partitions drift when HexPlaneField is refactored.
- Parameters:
field – A HexPlaneField instance.
lambda_plane_smooth – Scalar weight for the spatial smoothness term.
lambda_time_smooth – Scalar weight for the temporal smoothness term.
lambda_time_l1 – Scalar weight for the temporal L1 deviation term.
- Returns:
Scalar tensor — weighted sum of the three regularizers.
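The spatial / temporal partition and the weighted sum can be made concrete with a standard-library sketch. It derives the combo indices [0, 1, 3] and [2, 4, 5] from itertools.combinations over "xyzt" and combines the three terms with their lambdas. weighted_hexplane_reg, smooth_fn, and l1_fn are hypothetical; pooling the planes of all scales into one list per term is an assumption about how the wrapper aggregates scales.

```python
import itertools

AXES = "xyzt"
PAIRS = ["".join(p) for p in itertools.combinations(AXES, 2)]  # xy, xz, xt, yz, yt, zt
SPATIAL_IDX = [i for i, p in enumerate(PAIRS) if "t" not in p]  # planes without time
TEMPORAL_IDX = [i for i, p in enumerate(PAIRS) if "t" in p]     # planes with time


def weighted_hexplane_reg(planes_per_scale, smooth_fn, l1_fn,
                          lambda_plane_smooth=1.0, lambda_time_smooth=1.0,
                          lambda_time_l1=1.0):
    """Weighted sum of the three terms; smooth_fn / l1_fn each take a list of planes."""
    spatial = [planes[i] for planes in planes_per_scale for i in SPATIAL_IDX]
    temporal = [planes[i] for planes in planes_per_scale for i in TEMPORAL_IDX]
    return (lambda_plane_smooth * smooth_fn(spatial)
            + lambda_time_smooth * smooth_fn(temporal)
            + lambda_time_l1 * l1_fn(temporal))
```

Deriving the indices from the axis order, rather than hard-coding them, is the same idea the wrapper documentation argues for: the partition follows from how the planes are constructed.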
- plane_smoothness(planes: Sequence[Tensor]) → Tensor
Spatial 2D smoothness — sum of mean squared second-difference along H.
Pass the spatial HexPlane planes (combo indices [0, 1, 3] of a 4D HexPlane: xy, xz, yz) for the spatial-smoothness regularizer described in the G-SHARP proposal.
- Parameters:
planes – Iterable of 4D (B, C, H, W) plane tensors.
- Returns:
Scalar tensor (sum across planes; mean within each plane).
- time_l1(planes: Sequence[Tensor]) → Tensor
L1 deviation from 1.0 on spatio-temporal planes.
HexPlane spatio-temporal planes are initialised to 1.0 so deformation starts identity-like. time_l1 penalises their deviation from 1.0, encouraging static tissue to keep an identity (no-time-deformation) response. Matches G-SHARP's L1TimePlanes regularizer.
- Parameters:
planes – Iterable of plane tensors (any shape supporting mean()).
- Returns:
Scalar L1 deviation summed across planes (mean within each plane).
- time_smoothness(planes: Sequence[Tensor]) → Tensor
Temporal smoothness: same math as plane_smoothness().
Pass the spatio-temporal HexPlane planes (combo indices [2, 4, 5] of a 4D HexPlane: xt, yt, zt). For these planes the H axis is the temporal axis (per the reversed-order layout of HexPlaneField._init_grid_param()), so the mean squared second difference along H is the per-plane temporal smoothness.
- Parameters:
planes – Iterable of 4D (B, C, H, W) spatio-temporal plane tensors.
- Returns:
Scalar tensor.
DynamicStrategy — deformable extension of DefaultStrategy.
Public API:
DynamicStrategy: subclass of gsplat.strategy.DefaultStrategy that additionally tracks a per-Gaussian dynamic mask (state["dynamic_mask"]) and keeps it resized in lock-step with each densify / prune op, so that subsequent forward passes can route only the dynamic Gaussians through the deform-net.
Design notes:
gsplat.rasterization has no time axis. The deformation pass itself (HexPlane → DeformNetwork → (means, quats, opacities)) lives in the trainer (examples/dynamic_surgical_trainer.py) and runs before rasterization(...) is called, mirroring G-SHARP's rasterize_splats. This strategy class only owns the densification policy and the deformation-table bookkeeping; it does not apply the deform-net itself.
The @dataclass decorator is intentionally not applied here: no new fields with defaults are added, so the class inherits the parent's auto-generated __init__ cleanly, without dataclass-field ordering errors.
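The trainer-side routing described above (deform first, then rasterize) can be sketched as a function that applies a deformation callable only to the Gaussians flagged in the dynamic mask. deform_dynamic_subset and deform_fn are hypothetical names; deform_fn stands in for HexPlane sampling plus DeformNetwork, and the rasterization call that would consume the outputs is not shown.

```python
import torch


def deform_dynamic_subset(means, quats, opacities, dynamic_mask, deform_fn):
    """Apply deform_fn to rows flagged in dynamic_mask; static rows pass through.

    deform_fn(means, quats, opacities) -> (means, quats, opacities) is a
    hypothetical stand-in for HexPlane sampling + DeformNetwork.
    """
    idx = dynamic_mask.nonzero(as_tuple=True)[0]
    m_dyn, q_dyn, o_dyn = deform_fn(means[idx], quats[idx], opacities[idx])
    # index_copy is out-of-place, so autograd reaches both static and deformed rows
    # and the caller's tensors are left unmodified.
    return (means.index_copy(0, idx, m_dyn),
            quats.index_copy(0, idx, q_dyn),
            opacities.index_copy(0, idx, o_dyn))
```

Gathering only the flagged rows keeps the deform-net cost proportional to the number of dynamic Gaussians rather than to the full scene.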
- class DynamicStrategy(prune_opa: float = 0.005, grow_grad2d: float = 0.0002, grow_scale3d: float = 0.01, grow_scale2d: float = 0.05, prune_scale3d: float = 0.1, prune_scale2d: float = 0.15, refine_scale2d_stop_iter: int = 0, refine_start_iter: int = 500, refine_stop_iter: int = 15000, reset_every: int = 3000, refine_every: int = 100, pause_refine_after_reset: int = 0, absgrad: bool = False, revised_opacity: bool = False, verbose: bool = False, key_for_gradient: Literal['means2d', 'gradient_2dgs'] = 'means2d')
Deformable-aware densification / pruning strategy.
Extra invariants on top of DefaultStrategy: state["dynamic_mask"] is a per-Gaussian torch.bool tensor of shape (num_gaussians,) that flags which Gaussians the trainer should route through the DeformNet. It is resized in lock-step with params["means"] by gsplat's strategy ops, which iterate every tensor in state and apply the per-Gaussian split / duplicate / prune permutation (see gsplat/strategy/ops.py:135, 191-195, 223-225). Identity is preserved across split (children inherit the parent's flag).
Note that the HexPlane and DeformNet trainables are not part of params. gsplat's densification ops blindly iterate every entry in params and split / duplicate / prune it per-Gaussian, so non-per-Gaussian tensors (HexPlane plane grids, DeformNet MLP weights) would be indexed with per-Gaussian indices that are out of bounds for them. Keep those trainables in their own optimizers, wired separately by the trainer (see examples/dynamic_surgical_trainer.py:build_deform_modules).
Historical note: an earlier version of this strategy stored a DeformationTable wrapper at state["deformation_table"] and resized it via a custom _resize_table hook. That hook did not preserve survivor identity across split, and the trainer never consulted the mask anyway. Storing the mask in state as a plain tensor (so gsplat's ops resize it correctly) closes both gaps. The wrapper class is still importable for back-compat; the canonical mask is the tensor in state.
- check_sanity(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer]) → None
Sanity check identical to DefaultStrategy.check_sanity().
The HexPlane / DeformNet trainables live outside params (see the class docstring), so this method has no extra requirements beyond the parent's check that params and optimizers share keys and that the per-Gaussian keys means / scales / quats / opacities are present.
- initialize_state(scene_scale: float = 1.0, num_gaussians: int = 0, device: device | None = None, init_dynamic: bool = True) → Dict[str, Any]
Extend DefaultStrategy.initialize_state() with a per-Gaussian dynamic mask.
The mask is stored under state["dynamic_mask"] as a plain bool tensor of shape (num_gaussians,). It is not wrapped in a DeformationTable, so that gsplat's densification ops in gsplat.strategy.ops (which iterate every tensor in state and apply the per-Gaussian split / duplicate / prune permutation automatically; see ops.py:135, 191-195, 223-225) can resize it in lock-step with params["means"], preserving identity across split. The wrapper class is still exposed for callers that want the helper API (see DeformationTable), but the canonical mask now lives in state.
- Parameters:
scene_scale – Forwarded to the parent.
num_gaussians – Initial Gaussian count.
device – Device for the mask tensor (defaults to CPU).
init_dynamic – Initial value for every flag. True matches the current trainer behaviour (every Gaussian goes through DeformNetwork); set False if you have a static-by-default workflow and intend to flip dynamic indices manually.
- Returns:
The strategy state dict with the additional dynamic_mask entry (a torch.bool tensor of shape (num_gaussians,)).
- step_post_backward(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer], state: Dict[str, Any], step: int, info: Dict[str, Any], packed: bool = False) → None
Post-backward hook — defers to the parent.
state["dynamic_mask"] is resized in lock-step with the per-Gaussian parameters automatically by gsplat's densification ops, which iterate every tensor in state and apply the same per-Gaussian permutation as the params (see gsplat/strategy/ops.py:135). No extra hook is needed here.
- Raises:
RuntimeError – if state["dynamic_mask"] is missing (i.e. initialize_state() wasn't called first).
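The lock-step argument in this section can be illustrated with a simplified split. Like the behaviour described for gsplat's strategy ops, it applies one per-Gaussian permutation to every tensor in params and to every per-Gaussian tensor in state. init_dynamic_mask and apply_split are hypothetical helpers. Two further assumptions apply: only state tensors whose leading dimension equals the Gaussian count are resized, and the resampling of child means and shrinking of child scales that a real split performs are omitted, since only the bookkeeping is shown.

```python
import torch


def init_dynamic_mask(num_gaussians: int, init_dynamic: bool = True, device=None):
    # Mirrors the documented state["dynamic_mask"]: a plain bool tensor of shape (N,).
    return torch.full((num_gaussians,), init_dynamic, dtype=torch.bool, device=device)


def apply_split(params: dict, state: dict, sel: torch.Tensor, factor: int = 2) -> None:
    """Keep rows not in sel, then append `factor` tiled copies of the selected rows.

    The same permutation is applied to every params tensor and every per-Gaussian
    state tensor, which keeps a plain mask in state aligned with params["means"].
    """
    n = params["means"].shape[0]
    rest = torch.ones(n, dtype=torch.bool, device=params["means"].device)
    rest[sel] = False

    def resize(v: torch.Tensor) -> torch.Tensor:
        repeats = [factor] + [1] * (v.dim() - 1)
        return torch.cat([v[rest], v[sel].repeat(*repeats)])

    for k in list(params):
        params[k] = resize(params[k])
    for k, v in list(state.items()):
        if isinstance(v, torch.Tensor) and v.shape[:1] == (n,):
            state[k] = resize(v)  # children inherit the parent's flag
```

Because the mask travels through exactly the same indexing as the parameters, no strategy-specific resize hook is needed.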