# Source code for gsplat.contrib.dynamic.deformation
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Portions of this file (DeformNetwork architecture, DeformationTable
# bookkeeping) are ported from the G-SHARP v0.2 surgical reconstruction
# application; see holohub/applications/surgical_scene_recon/training.
"""Deformation network and per-Gaussian deformation table (experimental).

Public API:

- :class:`DeformNetwork` — MLP that consumes HexPlane features and emits
  per-Gaussian deltas on ``(means, quats, opacities)`` at a given time. Heads
  are zero-initialised so the at-construction behaviour is the identity map
  on its inputs.
- :class:`DeformationTable` — per-Gaussian boolean flag indicating whether
  each Gaussian is animated by the deform-net. Provides
  :meth:`prune` / :meth:`duplicate` / :meth:`split` for lock-step resize
  with ``DefaultStrategy``-style densification.

Port targets:

- ``holohub/applications/surgical_scene_recon/training/scene/deformation.py``
- ``_deformation_table`` / ``update_deformation_table_with_tool_masks`` in
  ``holohub/applications/surgical_scene_recon/training/gsplat_train.py``
"""
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
__all__ = ["DeformNetwork", "DeformationTable"]
[docs]
class DeformNetwork(nn.Module):
    """MLP head emitting per-Gaussian deltas on means / quats / opacities.

    Architecture: a *num_layers*-deep ReLU trunk consuming
    ``plane_features`` (typically the output of
    :class:`gsplat.contrib.dynamic.HexPlaneField`), followed by three linear
    heads — 3-d for the position delta, 4-d for the quaternion delta, and
    1-d for the opacity delta. The three heads are zero-initialised so the
    at-construction forward pass returns ``(means, quats, opacities)``
    unchanged (identity map). Locked by
    ``test_deform_net_zero_init_is_identity``.

    The *t* argument is reserved for future time-aware extensions; the
    current implementation expects time information to already be encoded
    into ``plane_features`` via :class:`HexPlaneField`.

    Args:
        feature_dim: Dimensionality of ``plane_features`` (must match
            the producing :class:`HexPlaneField`'s ``feat_dim``; must be
            ``>= 1``).
        hidden_dim: Trunk width (must be ``>= 1``). Default ``64``.
        num_layers: Number of ``Linear + ReLU`` blocks in the trunk
            (must be ``>= 1``). Default ``3``.

    Raises:
        ValueError: if ``num_layers``, ``feature_dim`` or ``hidden_dim`` is
            smaller than 1.
    """

    def __init__(
        self,
        feature_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}.")
        if feature_dim < 1:
            raise ValueError(f"feature_dim must be >= 1, got {feature_dim}.")
        # Validated like the other two sizes so a bad width fails here with a
        # clear message rather than inside ``nn.Linear``.
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}.")

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Trunk: one input block followed by ``num_layers - 1`` hidden blocks.
        layers: list[nn.Module] = [nn.Linear(feature_dim, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        self.trunk = nn.Sequential(*layers)

        self.pos_head = nn.Linear(hidden_dim, 3)
        self.quat_head = nn.Linear(hidden_dim, 4)
        self.opacity_head = nn.Linear(hidden_dim, 1)

        # Zero-init the heads so the initial deformation is identity.
        # With zero head weights the trunk receives zero gradient on the very
        # first backward pass; the head parameters themselves still receive
        # gradients, so once they move away from zero the trunk learns too.
        for head in (self.pos_head, self.quat_head, self.opacity_head):
            nn.init.zeros_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(
        self,
        means: Tensor,
        quats: Tensor,
        opacities: Tensor,
        t: Tensor,  # noqa: ARG002 — reserved for future time-aware extensions
        plane_features: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Apply the per-Gaussian deformation deltas.

        Args:
            means: ``(N, 3)`` Gaussian centres.
            quats: ``(N, 4)`` rotation quaternions (any layout; deltas are
                added in the same layout).
            opacities: ``(N, 1)`` opacity values (raw or activated; the
                delta is added in the same space). A rank-1 ``(N,)`` tensor
                is rejected; unsqueeze it first.
            t: ``(N, 1)`` (or broadcastable) time stamp. Currently ignored;
                kept in the signature for forward-compatibility with
                time-aware variants.
            plane_features: ``(N, feature_dim)`` features sampled from the
                HexPlane field. Must share dtype with the other tensors.

        Returns:
            ``(means_new, quats_new, opacities_new)`` — same shapes as the
            inputs.

        Raises:
            ValueError: on batch-dim mismatch, plane-feature-dim mismatch,
                dtype mismatch among the four tensors, or when any of the
                four tensors does not have the documented 2-D shape.
        """
        n = means.shape[0]
        if (
            quats.shape[0] != n
            or opacities.shape[0] != n
            or plane_features.shape[0] != n
        ):
            raise ValueError(
                f"DeformNetwork: batch dim mismatch — means {means.shape[0]}, "
                f"quats {quats.shape[0]}, opacities {opacities.shape[0]}, "
                f"plane_features {plane_features.shape[0]}."
            )
        if plane_features.shape[-1] != self.feature_dim:
            raise ValueError(
                f"DeformNetwork: plane_features last dim "
                f"{plane_features.shape[-1]} != feature_dim {self.feature_dim}."
            )
        if not (means.dtype == quats.dtype == opacities.dtype == plane_features.dtype):
            raise ValueError(
                f"DeformNetwork: dtype mismatch — means {means.dtype}, "
                f"quats {quats.dtype}, opacities {opacities.dtype}, "
                f"plane_features {plane_features.dtype}."
            )
        # Guard against silent broadcasting: the heads emit (N, 3) / (N, 4) /
        # (N, 1) deltas, so e.g. an (N,) opacity tensor would broadcast the
        # sum to (N, N) instead of failing.
        for name, tensor, width in (
            ("means", means, 3),
            ("quats", quats, 4),
            ("opacities", opacities, 1),
            ("plane_features", plane_features, self.feature_dim),
        ):
            if tensor.dim() != 2 or tensor.shape[1] != width:
                raise ValueError(
                    f"DeformNetwork: {name} must have shape (N, {width}), "
                    f"got {tuple(tensor.shape)}."
                )

        h = self.trunk(plane_features)
        d_means = self.pos_head(h)
        d_quats = self.quat_head(h)
        d_opacities = self.opacity_head(h)
        return means + d_means, quats + d_quats, opacities + d_opacities
[docs]
class DeformationTable:
    """Per-Gaussian boolean table marking which Gaussians are dynamic.

    Used by :class:`gsplat.contrib.dynamic.DynamicStrategy` to decide which
    Gaussians get fed through :class:`DeformNetwork` each step. Resize
    lock-step with the gsplat ``DefaultStrategy`` densification ops via
    :meth:`prune`, :meth:`duplicate`, and :meth:`split`.

    The table is a plain ``torch.bool`` tensor (no autograd, no parameters)
    so it adds zero overhead to the optimiser state.

    Args:
        num_gaussians: Initial Gaussian count (must be ``>= 0``).
        device: Optional device (defaults to CPU).

    Raises:
        ValueError: if ``num_gaussians`` is negative.
    """

    def __init__(
        self, num_gaussians: int, device: Optional[torch.device] = None
    ) -> None:
        if num_gaussians < 0:
            raise ValueError(
                f"DeformationTable: num_gaussians must be >= 0, got {num_gaussians}."
            )
        # All Gaussians start out static (flag False).
        self.mask = torch.zeros(num_gaussians, dtype=torch.bool, device=device)

    def __len__(self) -> int:
        """Return the number of Gaussians currently tracked."""
        return int(self.mask.shape[0])

    def set_indices(self, indices: Tensor, value: bool = True) -> None:
        """Mark the given Gaussian indices as dynamic (or static if *value* is False)."""
        self.mask[indices] = value

    def prune(self, keep_mask: Tensor) -> None:
        """Drop Gaussians where *keep_mask* is False (DefaultStrategy prune op).

        Args:
            keep_mask: ``(N,)`` bool tensor with ``True`` for surviving
                Gaussians. Length must equal current table size.

        Raises:
            ValueError: if *keep_mask* does not match the table shape or is
                not a ``torch.bool`` tensor.
        """
        if keep_mask.shape != self.mask.shape:
            raise ValueError(
                f"DeformationTable.prune: keep_mask shape {tuple(keep_mask.shape)} "
                f"!= table shape {tuple(self.mask.shape)}."
            )
        # An integer tensor of the right length would be interpreted as index
        # gathering rather than masking and silently corrupt the table, so
        # reject anything that is not a real boolean mask.
        if keep_mask.dtype != torch.bool:
            raise ValueError(
                f"DeformationTable.prune: keep_mask must be a bool tensor, "
                f"got dtype {keep_mask.dtype}."
            )
        self.mask = self.mask[keep_mask]

    def duplicate(self, indices: Tensor) -> None:
        """Append duplicates of the given indices (DefaultStrategy duplicate op).

        Originals stay; one duplicate is appended per index, inheriting
        the parent's dynamic flag.
        """
        self.mask = torch.cat([self.mask, self.mask[indices]], dim=0)

    def split(self, indices: Tensor, factor: int = 2) -> None:
        """Replace each index with *factor* children (DefaultStrategy split op).

        Original indices are removed; ``factor`` children are appended per
        split index, each inheriting the parent's dynamic flag.

        Args:
            indices: ``(S,)`` indices to split.
            factor: Children per split. Default ``2`` (matches the gsplat
                ``DefaultStrategy`` convention).

        Raises:
            ValueError: if *factor* is smaller than 1.
        """
        if factor < 1:
            raise ValueError(
                f"DeformationTable.split: factor must be >= 1, got {factor}."
            )
        keep = torch.ones(self.mask.shape[0], dtype=torch.bool, device=self.mask.device)
        keep[indices] = False
        # Children are grouped per parent: [p0, p0, ..., p1, p1, ...].
        # NOTE(review): confirm the parameter-side split op appends children
        # in this same per-parent order; a tiled order ([p0, p1, p0, p1])
        # would misalign the flags whenever the split parents differ.
        # NOTE(review): a repeated entry in *indices* is removed once but
        # contributes children per occurrence — presumably callers pass
        # unique indices; verify.
        children = self.mask[indices].repeat_interleave(factor)
        self.mask = torch.cat([self.mask[keep], children], dim=0)