# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Portions of this file (HexPlane construction) are adapted from the
# Nerfstudio K-Planes / HexPlane reference implementation; see
# https://github.com/sarafridov/K-Planes for the upstream.
"""HexPlane spatio-temporal feature field (experimental).
Multi-resolution 6-plane decomposition of a 4D ``(x, y, z, t)`` feature
field, ported from G-SHARP v0.2 (`training/scene/hexplane.py`), itself
derived from the K-Planes / 4DGaussians formulation.
Six 2D feature planes — one for each pair of input axes ``(xy, xz, xt,
yz, yt, zt)`` — are sampled with bilinear interpolation, multiplied
element-wise across planes (giving a per-scale feature vector), and
concatenated across multi-resolution scales to produce the final feature.
"""
from __future__ import annotations
import itertools
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
__all__ = ["HexPlaneField"]
# ---------------------------------------------------------------------------
# Internal helpers (private; faithful port of the G-SHARP implementations)
# ---------------------------------------------------------------------------
def _normalize_aabb(pts: Tensor, aabb: Tensor) -> Tensor:
    """Remap *pts* linearly so that ``aabb[0]`` maps to ``-1`` and ``aabb[1]`` to ``+1``.

    Mirrors the ``normalize_aabb`` helper of the G-SHARP source. When the
    box is stored as ``aabb = [[+b, +b, +b], [-b, -b, -b]]`` (first row is
    the positive corner) the result simplifies to ``-pts / b``; that sign
    flip is intentional and kept for parity with upstream.

    Args:
        pts: Coordinates to remap; must broadcast against ``aabb[0]``.
        aabb: Tensor whose first two entries along dim 0 are the two box
            corners.

    Returns:
        The remapped coordinates, with the broadcast shape of the inputs.
    """
    first_corner = aabb[0]
    extent = aabb[1] - first_corner
    return (pts - first_corner) * (2.0 / extent) - 1.0
def _grid_sample_wrapper(
    grid: Tensor, coords: Tensor, align_corners: bool = True
) -> Tensor:
    """Sample *grid* at *coords* with ``F.grid_sample`` and return point-major features.

    The sampling uses ``mode="bilinear"`` and ``padding_mode="border"``.
    A leading batch dimension is added to *grid* when it has exactly one
    more dimension than the coordinate width, and to *coords* when it is
    2-D. The batch dimension of the result is squeezed away again when it
    has size 1.

    Args:
        grid: Feature grid; after the optional unsqueeze its first two
            dimensions are read as ``(batch, feature_dim)``.
        coords: Sample locations whose last dimension (2 or 3) is the
            coordinate width.
        align_corners: Forwarded to ``F.grid_sample``.

    Returns:
        Tensor of shape ``(n, feature_dim)`` for a batch of one, otherwise
        ``(batch, n, feature_dim)``, where ``n`` is the number of sample
        points per batch element.

    Raises:
        NotImplementedError: If the coordinate width is neither 2 nor 3.
    """
    coord_width = coords.shape[-1]

    batched_grid = grid.unsqueeze(0) if grid.dim() == coord_width + 1 else grid
    batched_coords = coords.unsqueeze(0) if coords.dim() == 2 else coords

    if coord_width != 2 and coord_width != 3:
        raise NotImplementedError(
            f"_grid_sample_wrapper supports 2D / 3D coords only; got {coord_width}D."
        )

    # grid_sample wants (B, H_out, W_out, 2) or (B, D_out, H_out, W_out, 3);
    # pad with singleton output dims so the points sit on the last output axis.
    singleton_dims = [1] * (coord_width - 1)
    sample_shape = (
        [batched_coords.shape[0]] + singleton_dims + list(batched_coords.shape[1:])
    )
    batched_coords = batched_coords.view(sample_shape)

    batch_size, channels = batched_grid.shape[:2]
    num_points = batched_coords.shape[-2]

    sampled = F.grid_sample(
        batched_grid,
        batched_coords,
        mode="bilinear",
        padding_mode="border",
        align_corners=align_corners,
    )
    # (B, C, 1[, 1], n) -> (B, C, n) -> (B, n, C); drop B when it is 1.
    sampled = sampled.view(batch_size, channels, num_points)
    return sampled.transpose(-1, -2).squeeze(0)
def _init_grid_param(
    grid_nd: int,
    in_dim: int,
    out_dim: int,
    reso: Sequence[int],
    a: float = 0.1,
    b: float = 0.5,
) -> nn.ParameterList:
    """Create one learnable feature grid per ``grid_nd``-sized subset of the input axes.

    Axis subsets are enumerated with ``itertools.combinations`` over
    ``range(in_dim)``. Each grid has shape ``[1, out_dim, *sizes]`` where
    ``sizes`` are the resolutions of the chosen axes in reversed axis
    order.

    Args:
        grid_nd: Number of axes spanned by each grid.
        in_dim: Number of input axes; must equal ``len(reso)``.
        out_dim: Feature channels stored in every grid.
        reso: Per-axis resolution.
        a: Lower bound of the uniform initialisation.
        b: Upper bound of the uniform initialisation.

    Returns:
        The grids, ordered like the enumerated axis subsets.

    Raises:
        ValueError: If ``in_dim != len(reso)`` or ``grid_nd > in_dim``.
    """
    if len(reso) != in_dim:
        raise ValueError(f"_init_grid_param: in_dim={in_dim} != len(reso)={len(reso)}.")
    if in_dim < grid_nd:
        raise ValueError(f"_init_grid_param: grid_nd={grid_nd} > in_dim={in_dim}.")

    time_axis_present = in_dim == 4
    planes = nn.ParameterList()
    for axes in itertools.combinations(range(in_dim), grid_nd):
        plane_shape = [1, out_dim]
        plane_shape.extend(reso[axis] for axis in reversed(axes))
        plane = nn.Parameter(torch.empty(plane_shape))
        if time_axis_present and 3 in axes:
            # Grids touching the time axis (index 3) start at one so that
            # deformation starts identity-like; matches G-SHARP convention.
            nn.init.ones_(plane)
        else:
            nn.init.uniform_(plane, a=a, b=b)
        planes.append(plane)
    return planes
def _interpolate_ms_features(
    pts: Tensor,
    ms_grids: nn.ModuleList,
    grid_dimensions: int,
    concat_features: bool,
) -> Tensor:
    """Sample every scale's grids at *pts* and combine them into one feature tensor.

    Within a scale, each grid is sampled at the matching axis subset of
    *pts* and the samples are multiplied element-wise. Across scales the
    per-scale products are either concatenated along the last dimension or
    summed.

    Args:
        pts: Sample coordinates; the size of the last dimension determines
            which axis subsets are enumerated.
        ms_grids: One indexable collection of grids per scale, ordered like
            ``itertools.combinations(range(pts.shape[-1]), grid_dimensions)``.
        grid_dimensions: Number of axes spanned by each grid.
        concat_features: Concatenate per-scale features when true, sum them
            otherwise.

    Returns:
        The combined feature tensor. In sum mode, an empty ``torch.zeros(0)``
        is returned when nothing was accumulated.
    """
    axis_subsets = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))

    per_scale: list[Tensor | float] = []
    for scale_grids in ms_grids:
        # Running element-wise product over this scale's grids.
        product: Tensor | float = 1.0
        for grid_idx, axes in enumerate(axis_subsets):
            plane = scale_grids[grid_idx]
            channels = plane.shape[1]
            sampled = _grid_sample_wrapper(plane, pts[..., axes])
            product = product * sampled.view(-1, channels)
        per_scale.append(product)

    if concat_features:
        return torch.cat(per_scale, dim=-1)

    total: Tensor | float = 0.0
    for scale_feature in per_scale:
        total = total + scale_feature
    if isinstance(total, Tensor):
        return total
    return torch.zeros(0)
# ---------------------------------------------------------------------------
# Public class
# ---------------------------------------------------------------------------
# Plane configuration used by ``HexPlaneField`` when the caller supplies none.
# ``resolution`` holds one entry per input axis; the constructor scales the
# first three entries per multi-resolution level and keeps the rest fixed.
_DEFAULT_PLANE_CONFIG: dict = dict(
    grid_dimensions=2,
    input_coordinate_dim=4,
    output_coordinate_dim=32,
    resolution=[64, 64, 64, 25],
)

# Default per-scale multipliers used by ``HexPlaneField`` when ``multires``
# is not given.
_DEFAULT_MULTIRES: tuple[int, ...] = (1, 2)
class HexPlaneField(nn.Module):
    """Multi-resolution 6-plane decomposition of a 4D feature field.

    For each scale, six 2D feature planes covering every pair of the four
    input axes ``(x, y, z, t)`` are sampled bilinearly and multiplied
    element-wise to produce a per-scale feature vector. With
    ``concat_features=True`` (the default and only mode currently supported)
    feature vectors from all scales are concatenated; the resulting feature
    dimensionality is ``output_coordinate_dim * len(multires)``.

    Spatial coordinates are remapped linearly so that the axis-aligned
    bounding box ``[-bounds, bounds]`` along x/y/z lands on ``[-1, 1]``.
    The box is stored with its ``+bounds`` corner first, so the remap is
    sign-flipped (``+bounds`` maps to ``-1``), which is kept for parity
    with upstream. The time coordinate is passed through unchanged (callers
    should pre-normalize it into a sensible range — values outside the grid
    are clamped via ``padding_mode="border"`` in
    :func:`torch.nn.functional.grid_sample`).

    Args:
        bounds: Half-extent of the spatial AABB along each axis (default
            ``1.6``, matching G-SHARP).
        planes_config: Optional dict with keys ``grid_dimensions``,
            ``input_coordinate_dim``, ``output_coordinate_dim``,
            ``resolution``. Keys that are omitted fall back to the module
            defaults. The default is suitable for surgical-scene 4D fields
            (32 features per plane, ``[64, 64, 64, 25]`` resolution).
        multires: Iterable of integer multipliers applied to the *spatial*
            grid resolution for each scale; the temporal resolution is
            kept fixed across scales. Default ``(1, 2)``.
    """

    # ------------------------------------------------------------------
    # Plane partition: keep the spatial / temporal grouping next to the
    # grid construction code so regularizers don't have to hardcode the
    # (0,1,3) / (2,4,5) indices.
    # ------------------------------------------------------------------
    _SPATIAL_PLANE_IDXS: Tuple[int, ...] = (0, 1, 3)  # xy, xz, yz
    _TEMPORAL_PLANE_IDXS: Tuple[int, ...] = (2, 4, 5)  # xt, yt, zt

    def __init__(
        self,
        bounds: float = 1.6,
        planes_config: Optional[dict] = None,
        multires: Optional[Sequence[int]] = None,
    ) -> None:
        """Build one set of feature planes per multi-resolution scale.

        Raises:
            ValueError: Propagated from the grid construction when the
                configured ``input_coordinate_dim`` does not match the
                length of ``resolution`` or is smaller than
                ``grid_dimensions``.
        """
        super().__init__()

        # Overlay caller-supplied keys on the defaults so that a partial
        # ``planes_config`` overrides only the keys it names.
        config = dict(_DEFAULT_PLANE_CONFIG)
        if planes_config is not None:
            config.update(planes_config)
        # Own a private copy of the resolution so that later mutation of
        # ``self.grid_config`` cannot leak into the module-level default
        # or into the caller's dict.
        config["resolution"] = list(config["resolution"])

        multires_seq = list(_DEFAULT_MULTIRES if multires is None else multires)

        self.bounds = float(bounds)
        # First row is the +bounds corner, second row the -bounds corner.
        aabb = torch.tensor(
            [[bounds, bounds, bounds], [-bounds, -bounds, -bounds]],
            dtype=torch.float32,
        )
        # Stored as a Parameter with gradients disabled.
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config = config
        self.multires = multires_seq
        self.concat_features = True
        self.grids = nn.ModuleList()
        self.feat_dim = 0

        base_reso = config["resolution"]
        for res in self.multires:
            # Multi-res only on spatial axes; time grid stays fixed across scales.
            scaled_reso = [r * res for r in base_reso[:3]] + base_reso[3:]
            gp = _init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=scaled_reso,
            )
            if self.concat_features:
                # Concatenation stacks every scale's channels.
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]
            self.grids.append(gp)

    def forward(self, xyzt: Tensor) -> Tensor:
        """Sample the multi-resolution HexPlane at the given 4D points.

        Args:
            xyzt: ``(N, 4)`` (or any shape ending in 4) with ``[..., :3]``
                being world-space spatial coordinates inside the AABB and
                ``[..., 3]`` being the temporal coordinate.

        Returns:
            ``(N, feat_dim)`` feature tensor where ``feat_dim ==
            output_coordinate_dim * len(multires)`` under
            ``concat_features=True``. Leading dimensions of ``xyzt`` are
            flattened into ``N``.

        Raises:
            ValueError: If the last dimension of ``xyzt`` is not 4.
        """
        if xyzt.shape[-1] != 4:
            raise ValueError(
                f"HexPlaneField.forward: xyzt last dim must be 4, "
                f"got shape {tuple(xyzt.shape)}."
            )
        # Only the spatial part is remapped; time is used as given.
        xyz_norm = _normalize_aabb(xyzt[..., :3], self.aabb)
        t = xyzt[..., 3:]
        pts = torch.cat([xyz_norm, t], dim=-1).reshape(-1, 4)
        return _interpolate_ms_features(
            pts,
            ms_grids=self.grids,
            grid_dimensions=self.grid_config["grid_dimensions"],
            concat_features=self.concat_features,
        )

    def spatial_planes(self) -> list[Tensor]:
        """Return the flat list of spatial planes across all multi-res scales."""
        return [
            scale_grids[i]
            for scale_grids in self.grids
            for i in self._SPATIAL_PLANE_IDXS
        ]

    def temporal_planes(self) -> list[Tensor]:
        """Return the flat list of spatio-temporal planes across all scales."""
        return [
            scale_grids[i]
            for scale_grids in self.grids
            for i in self._TEMPORAL_PLANE_IDXS
        ]