# SPDX-FileCopyrightText: Copyright 2024-2025 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from enum import IntEnum
from abc import ABC
from typing import Any, Callable, Optional, Tuple
import torch
from torch import Tensor
from typing_extensions import Literal
from gsplat._helper import assert_shape
from gsplat.cuda._lidar import (
SpinningDirection,
LidarModelParameters,
RowOffsetStructuredSpinningLidarModelParameters,
RowOffsetStructuredSpinningLidarModelParametersExt as RowOffsetStructuredSpinningLidarModelParametersExtBase,
FOV as FOVBase,
)
# Literal type enumerating the external distortion model names known to this module.
ExternalDistortionModelMeta = Literal["bivariate-windshield"]
# Literal type used to annotate the `camera_model` argument of the projection
# and rasterization functions in this module.
CameraModel = Literal["pinhole", "ortho", "fisheye", "ftheta", "lidar"]
def _make_lazy_cuda_func(name: str) -> Callable:
def call_cuda(*args, **kwargs):
# The following import statement is required to ensure that C++ module
# gsplat/csrc.so is loaded (and JIT-compiled if necessary). Upon module
# load, the gsplat PyTorch operators are imported into the
# torch.ops.gsplat submodule.
# pylint: disable=import-outside-toplevel
from ._backend import _C
return getattr(torch.ops.gsplat, name)(*args, **kwargs)
return call_cuda
def _make_lazy_cuda_cls(name: str) -> Any:
    """Look up the custom class ``torch.classes.gsplat.<name>``.

    When the backend module is unavailable (``_C is None``) or PyTorch reports
    that the class is not registered, a placeholder class from
    ``_unavailable_cuda_cls`` is returned instead.

    Args:
        name: Name of the custom class under ``torch.classes.gsplat``.

    Returns:
        The registered custom class, or a placeholder class.

    Raises:
        RuntimeError: Re-raised when the lookup fails with a message that does
            not look like a "class not registered" error.
    """
    # Importing the backend guarantees that the C++ module gsplat/csrc.so has
    # been loaded (and JIT-compiled if necessary). Loading it is what registers
    # the gsplat custom classes under torch.classes.gsplat.
    # pylint: disable=import-outside-toplevel
    from ._backend import _C

    if _C is not None:
        try:
            return getattr(torch.classes.gsplat, name)
        except RuntimeError as err:
            message = str(err)
            # Only swallow the "class not registered" flavour of the error
            # (e.g. extension built without it or partial load).
            not_registered = "does not exist" in message or "torch::class_" in message
            if not not_registered:
                raise
    return _unavailable_cuda_cls(name)
def _unavailable_cuda_cls(name: str) -> Any:
"""Placeholder class when the CUDA extension is not available."""
class _UnavailableCudaCls:
__name__ = name
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise RuntimeError(
"gsplat CUDA extension is not available (not built or failed to load). "
f"Cannot instantiate '{name}'."
)
return _UnavailableCudaCls
def _make_lazy_cuda_obj(name: str) -> Any:
    """Resolve a dotted attribute path on the backend module ``_C``.

    Args:
        name: Attribute path relative to ``_C``, with components separated by ".".

    Returns:
        The object found by following every component of the path.

    Raises:
        RuntimeError: If the backend module is unavailable (``_C is None``).
        AttributeError: If a component of the path does not exist.
    """
    # pylint: disable=import-outside-toplevel
    from ._backend import _C

    if _C is None:
        raise RuntimeError(
            "gsplat CUDA extension is not available (not built or failed to load). "
            f"Cannot access '{name}'."
        )

    # Walk the path one attribute at a time, starting from the module itself.
    resolved = _C
    for attribute in name.split("."):
        resolved = getattr(resolved, attribute)
    return resolved
class RollingShutterType(IntEnum):
    """Integer codes selecting the shutter mode of a camera.

    ``GLOBAL`` is the default used by :func:`create_camera_model` (``rs_type``).

    NOTE(review): the member names suggest the readout direction of a rolling
    shutter; the exact semantics and the required numeric values are defined
    by the C++ backend -- confirm there before changing any value.
    """

    ROLLING_TOP_TO_BOTTOM = 0
    ROLLING_LEFT_TO_RIGHT = 1
    ROLLING_BOTTOM_TO_TOP = 2
    ROLLING_RIGHT_TO_LEFT = 3
    GLOBAL = 4
class FThetaPolynomialType(IntEnum):
    """Integer codes naming the mapping direction of an f-theta polynomial.

    NOTE(review): presumably these values must match an enum in the C++
    backend -- confirm before renumbering.
    """

    PIXELDIST_TO_ANGLE = 0
    ANGLE_TO_PIXELDIST = 1
# These lookups run at module import time: each name is bound either to the
# custom class registered under torch.classes.gsplat or, when that is not
# available, to a placeholder class that raises on instantiation.
UnscentedTransformParameters = _make_lazy_cuda_cls("UnscentedTransformParameters")
FThetaCameraDistortionParameters = _make_lazy_cuda_cls(
    "FThetaCameraDistortionParameters"
)
class ExternalDistortionModelParameters(ABC):
    """Base class for external distortion model parameters.

    All concrete external distortion models (e.g.
    ``BivariateWindshieldModelParameters``) should inherit from this class so
    that the rendering API can accept any distortion model through a single
    type-erased parameter.

    The class declares no abstract methods; it serves purely as a common
    base type.
    """
class ExternalDistortionReferencePolynomial(IntEnum):
    """Integer codes naming the reference polynomial of an external distortion model.

    Note that numbering starts at 1, not 0.

    NOTE(review): presumably these values must match an enum in the C++
    backend -- confirm before renumbering.
    """

    FORWARD = 1
    BACKWARD = 2
class BivariateWindshieldModelParameters(ExternalDistortionModelParameters):
    """Thin wrapper around the CUDA BivariateWindshieldModelParameters class.

    torch::Library bindings do not allow standalone constants. This wrapper
    fetches MAX_ORDER and MAX_COEFFS from the C++ static getters and exposes
    them as class-level attributes, preserving the existing attribute-access
    calling convention.

    Calling the class does not create an instance of this Python class:
    ``__new__`` returns an instance of the underlying CUDA class. When the
    CUDA extension is unavailable, construction raises the ``RuntimeError``
    of the placeholder class and the default constants are kept.
    """

    # Cached CUDA class (or placeholder); resolved on first construction.
    _cuda_cls = None
    MAX_ORDER: int = 5  # default, overridden by the C++ value once loaded
    MAX_COEFFS: int = 21  # default, overridden by the C++ value once loaded

    @classmethod
    def _ensure_cuda_cls(cls):
        """Resolve and cache the CUDA class and refresh the class constants."""
        if cls._cuda_cls is not None:
            return
        cuda_cls = _make_lazy_cuda_cls("BivariateWindshieldModelParameters")
        # The placeholder returned when the extension is unavailable has no
        # static getters. Look them up defensively so the caller gets the
        # placeholder's informative RuntimeError on construction rather than
        # an AttributeError here.
        get_max_order = getattr(cuda_cls, "get_max_order", None)
        get_max_coeffs = getattr(cuda_cls, "get_max_coeffs", None)
        if get_max_order is not None:
            cls.MAX_ORDER = get_max_order()
        if get_max_coeffs is not None:
            cls.MAX_COEFFS = get_max_coeffs()
        # Cache only after the getters ran, so a failure above cannot leave a
        # half-initialised class behind.
        cls._cuda_cls = cuda_cls

    def __new__(cls):
        cls._ensure_cuda_cls()
        return cls._cuda_cls()
def has_camera_wrappers():
    """Tell whether ``torch.classes.gsplat`` exposes ``BaseCameraModel``.

    Returns:
        True if the attribute lookup succeeds, False if it is missing or if
        PyTorch raises ``RuntimeError`` during the lookup.
    """
    # Imported so that the backend is loaded before probing torch.classes.
    from ._backend import _C

    try:
        available = hasattr(torch.classes.gsplat, "BaseCameraModel")
    except RuntimeError:
        # PyTorch throws a RuntimeError for an unregistered class; since this
        # is only an existence probe, that simply means "not available".
        available = False
    return available
def has_2dgs():
    """Tell whether ``torch.ops.gsplat`` exposes ``projection_2dgs_fused_fwd``."""
    # Imported so that the backend is loaded before probing torch.ops.
    from ._backend import _C

    gsplat_ops = torch.ops.gsplat
    return hasattr(gsplat_ops, "projection_2dgs_fused_fwd")
def has_3dgs():
    """Tell whether ``torch.ops.gsplat`` exposes ``projection_ewa_simple_fwd``."""
    # Imported so that the backend is loaded before probing torch.ops.
    from ._backend import _C

    gsplat_ops = torch.ops.gsplat
    return hasattr(gsplat_ops, "projection_ewa_simple_fwd")
def has_3dgut():
    """Tell whether ``torch.ops.gsplat`` exposes ``projection_ut_3dgs_fused``."""
    # Imported so that the backend is loaded before probing torch.ops.
    from ._backend import _C

    gsplat_ops = torch.ops.gsplat
    return hasattr(gsplat_ops, "projection_ut_3dgs_fused")
def has_adam():
    """Tell whether ``torch.ops.gsplat`` exposes the ``adam`` operator."""
    # Imported so that the backend is loaded before probing torch.ops.
    from ._backend import _C

    gsplat_ops = torch.ops.gsplat
    return hasattr(gsplat_ops, "adam")
def has_reloc():
    """Tell whether ``torch.ops.gsplat`` exposes the ``relocation`` operator."""
    # Imported so that the backend is loaded before probing torch.ops.
    from ._backend import _C

    gsplat_ops = torch.ops.gsplat
    return hasattr(gsplat_ops, "relocation")
def create_camera_model(
    camera_model: str,
    width: Optional[int] = None,
    height: Optional[int] = None,
    principal_points: Optional[Tensor] = None,
    focal_lengths: Optional[Tensor] = None,
    radial_coeffs: Optional[Tensor] = None,
    tangential_coeffs: Optional[Tensor] = None,
    thin_prism_coeffs: Optional[Tensor] = None,
    ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
    external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
    rs_type: RollingShutterType = RollingShutterType.GLOBAL,
    lidar_coeffs: Optional["RowOffsetStructuredSpinningLidarModelParametersExt"] = None,
):
    """Instantiate the CUDA camera model object for ``camera_model``.

    For ``"lidar"`` only ``lidar_coeffs`` is used. Every other value is
    forwarded, together with the remaining arguments, to
    ``BaseCameraModel.create`` of the CUDA extension.

    Args:
        camera_model: Camera model name; ``"lidar"`` selects the lidar path.
        width: Image width. Required for non-lidar models.
        height: Image height. Required for non-lidar models.
        principal_points: Principal points. Required for non-lidar models.
        focal_lengths: Forwarded unchanged to the CUDA factory.
        radial_coeffs: Forwarded unchanged to the CUDA factory.
        tangential_coeffs: Forwarded unchanged to the CUDA factory.
        thin_prism_coeffs: Forwarded unchanged to the CUDA factory.
        ftheta_coeffs: Forwarded unchanged to the CUDA factory.
        external_distortion_coeffs: Forwarded unchanged to the CUDA factory.
        rs_type: Rolling shutter type, forwarded unchanged to the CUDA factory.
        lidar_coeffs: Lidar parameters; converted with ``to_cpp()``. Required
            for the lidar model.

    Returns:
        The camera model object created by the CUDA extension.

    Raises:
        AssertionError: If an argument required by the selected path is None.
    """
    if camera_model != "lidar":
        assert width is not None, "width is required for non-lidar camera models"
        assert height is not None, "height is required for non-lidar camera models"
        assert (
            principal_points is not None
        ), "principal_points is required for non-lidar camera models"
        camera_factory = _make_lazy_cuda_cls("BaseCameraModel")
        return camera_factory.create(
            width,
            height,
            camera_model,
            principal_points,
            focal_lengths,
            radial_coeffs,
            tangential_coeffs,
            thin_prism_coeffs,
            ftheta_coeffs,
            external_distortion_coeffs,
            rs_type,
        )

    assert (
        lidar_coeffs is not None
    ), "lidar_coeffs is required for lidar camera model"
    lidar_model_cls = _make_lazy_cuda_cls("RowOffsetStructuredSpinningLidarModel")
    return lidar_model_cls(lidar_coeffs.to_cpp())
class FOV(FOVBase):
    """``FOVBase`` extended with a conversion to the CUDA ``FOV`` custom class."""

    @classmethod
    def from_base(cls, base: FOVBase) -> "FOV":
        """Create an instance of this class from the fields of ``base``.

        Copies ``start``, ``span`` and ``direction``.
        """
        copied_fields = {
            "start": base.start,
            "span": base.span,
            "direction": base.direction,
        }
        return cls(**copied_fields)

    def to_cpp(self):
        """Build the CUDA ``FOV`` object from ``start`` and ``span``.

        ``direction`` is not passed to the CUDA constructor.
        """
        cuda_fov_cls = _make_lazy_cuda_cls("FOV")
        return cuda_fov_cls(start=self.start, span=self.span)
class RowOffsetStructuredSpinningLidarModelParametersExt(
    RowOffsetStructuredSpinningLidarModelParametersExtBase
):
    """Lidar camera parameters extended with acceleration structures."""

    def to_cpp(self) -> Any:
        """Convert to the C++ custom class instance.

        Collects the fields of this object (including those of ``self.tiling``)
        into keyword arguments for the CUDA
        ``RowOffsetStructuredSpinningLidarModelParametersExt`` constructor.

        Returns:
            The object built by the CUDA custom class.
        """
        cuda_params_cls = _make_lazy_cuda_cls(
            "RowOffsetStructuredSpinningLidarModelParametersExt"
        )
        # NOTE(review): `angles_to_columns_map` is the only array-like field
        # passed without `.contiguous()` -- confirm whether that is intended.
        constructor_kwargs = {
            "row_elevations_rad": self.row_elevations_rad.contiguous(),
            "column_azimuths_rad": self.column_azimuths_rad.contiguous(),
            "row_azimuth_offsets_rad": self.row_azimuth_offsets_rad.contiguous(),
            # The enum member is unwrapped to its underlying value.
            "spinning_direction": self.spinning_direction.value,
            "spinning_frequency_hz": self.spinning_frequency_hz,
            # Both fields of view are re-wrapped as FOV to reach `to_cpp()`.
            "fov_vert_rad": FOV.from_base(self.fov_vert_rad).to_cpp(),
            "fov_horiz_rad": FOV.from_base(self.fov_horiz_rad).to_cpp(),
            "fov_eps_rad": self.fov_eps_rad,
            "angles_to_columns_map": self.angles_to_columns_map,
            "n_bins_azimuth": self.tiling.n_bins_azimuth,
            "n_bins_elevation": self.tiling.n_bins_elevation,
            "cdf_elevation": self.tiling.cdf_elevation.contiguous(),
            "cdf_dense_ray_mask": self.tiling.cdf_dense_ray_mask.contiguous(),
            "tiles_to_elements_map": self.tiling.tiles_to_elements_map.contiguous(),
            "tiles_pack_info": self.tiling.tiles_pack_info.contiguous(),
        }
        return cuda_params_cls(**constructor_kwargs)
[docs]
def world_to_cam(
    means: Tensor,  # [..., N, 3]
    covars: Tensor,  # [..., N, 3, 3]
    viewmats: Tensor,  # [..., C, 4, 4]
) -> Tuple[Tensor, Tensor]:
    """Moves Gaussians from the world frame into each camera frame.

    Deprecated: this delegates to the PyTorch implementation and emits a
    ``DeprecationWarning`` on every call.

    Args:
        means: Gaussian means. [..., N, 3]
        covars: Gaussian covariances. [..., N, 3, 3]
        viewmats: World-to-camera transformation matrices. [..., C, 4, 4]

    Returns:
        A tuple:

        - **Gaussian means in camera coordinate system**. [..., C, N, 3]
        - **Gaussian covariances in camera coordinate system**. [..., C, N, 3, 3]
    """
    from ._torch_impl import _world_to_cam

    warnings.warn(
        "world_to_cam() is removed from the CUDA backend as it's relatively easy to "
        "implement in PyTorch. Currently use the PyTorch implementation instead. "
        "This function will be completely removed in a future release.",
        DeprecationWarning,
    )

    # Leading batch dimensions are taken from `means`; all inputs must share them.
    batch_dims = means.shape[:-2]
    num_gaussians = means.shape[-2]
    num_cameras = viewmats.shape[-3]
    assert means.shape == (*batch_dims, num_gaussians, 3), means.shape
    assert covars.shape == (*batch_dims, num_gaussians, 3, 3), covars.shape
    assert viewmats.shape == (*batch_dims, num_cameras, 4, 4), viewmats.shape

    return _world_to_cam(
        means.contiguous(), covars.contiguous(), viewmats.contiguous()
    )
def adam(
    param: Tensor,
    param_grad: Tensor,
    exp_avg: Tensor,
    exp_avg_sq: Tensor,
    valid: Tensor,
    lr: float,
    b1: float,
    b2: float,
    eps: float,
) -> None:
    """Invoke the gsplat ``adam`` CUDA operator.

    All arguments are forwarded positionally and unchanged; whatever the
    operator returns is discarded.

    NOTE(review): presumably the operator updates ``param``, ``exp_avg`` and
    ``exp_avg_sq`` in place for entries selected by ``valid`` -- confirm
    against the CUDA implementation.
    """
    adam_op = _make_lazy_cuda_func("adam")
    adam_op(param, param_grad, exp_avg, exp_avg_sq, valid, lr, b1, b2, eps)
[docs]
def spherical_harmonics(
    degrees_to_use: int,
    dirs: Tensor,  # [..., 3]
    coeffs: Tensor,  # [..., K, 3]
    masks: Optional[Tensor] = None,  # [...,]
) -> Tensor:
    """Evaluates spherical harmonics.

    Args:
        degrees_to_use: The degree to be used. ``(degrees_to_use + 1) ** 2``
            must not exceed ``K``.
        dirs: Directions. [..., 3]
        coeffs: Coefficients. [..., K, 3]
        masks: Optional boolean masks to skip some computation. [...,] Default: None.

    Returns:
        Spherical harmonics. [..., 3]
    """
    bases_needed = (degrees_to_use + 1) ** 2
    assert bases_needed <= coeffs.shape[-2], coeffs.shape

    batch_dims = dirs.shape[:-1]
    assert dirs.shape == (*batch_dims, 3), dirs.shape

    # `coeffs` must have exactly two trailing dims beyond the batch dims of `dirs`.
    rank_matches = len(coeffs.shape) == len(batch_dims) + 2
    batch_matches = coeffs.shape[:-2] == batch_dims
    rgb_last = coeffs.shape[-1] == 3
    assert rank_matches and batch_matches and rgb_last, coeffs.shape

    if masks is not None:
        assert masks.shape == batch_dims, masks.shape
        masks = masks.contiguous()

    dirs = dirs.contiguous()
    coeffs = coeffs.contiguous()
    return _SphericalHarmonics.apply(degrees_to_use, dirs, coeffs, masks)
[docs]
def quat_scale_to_covar_preci(
    quats: Tensor,  # [..., 4],
    scales: Tensor,  # [..., 3],
    compute_covar: bool = True,
    compute_preci: bool = True,
    triu: bool = False,
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Builds covariance and precision matrices from quaternions and scales.

    Args:
        quats: Quaternions (No need to be normalized). [..., 4]
        scales: Scales. [..., 3]
        compute_covar: Whether to compute covariance matrices. Default: True.
            If False, None is returned in their place.
        compute_preci: Whether to compute precision matrices. Default: True.
            If False, None is returned in their place.
        triu: If True, the returned matrices will be upper triangular. Default: False.

    Returns:
        A tuple:

        - **Covariance matrices**. If `triu` is True the returned shape is [..., 6], otherwise [..., 3, 3].
        - **Precision matrices**. If `triu` is True the returned shape is [..., 6], otherwise [..., 3, 3].
    """
    batch_dims = quats.shape[:-1]
    assert quats.shape == (*batch_dims, 4), quats.shape
    assert scales.shape == (*batch_dims, 3), scales.shape

    covars, precis = _QuatScaleToCovarPreci.apply(
        quats.contiguous(), scales.contiguous(), compute_covar, compute_preci, triu
    )

    # Outputs that were not requested are reported as None.
    if not compute_covar:
        covars = None
    if not compute_preci:
        precis = None
    return covars, precis
def persp_proj(
    means: Tensor,  # [..., C, N, 3]
    covars: Tensor,  # [..., C, N, 3, 3]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
) -> Tuple[Tensor, Tensor]:
    """Perspective projection on Gaussians.

    DEPRECATED: please use `proj` with `camera_model="pinhole"` instead.

    Args:
        means: Gaussian means. [..., C, N, 3]
        covars: Gaussian covariances. [..., C, N, 3, 3]
        Ks: Camera intrinsics. [..., C, 3, 3]
        width: Image width.
        height: Image height.

    Returns:
        A tuple:

        - **Projected means**. [..., C, N, 2]
        - **Projected covariances**. [..., C, N, 2, 2]
    """
    warnings.warn(
        "persp_proj is deprecated and will be removed in a future release. "
        'Use proj with camera_model="pinhole" instead.',
        DeprecationWarning,
    )
    # `proj` has no `ortho` keyword; it selects the projection through
    # `camera_model`, so passing `ortho=False` raised a TypeError on every call.
    return proj(means, covars, Ks, width, height, camera_model="pinhole")
[docs]
def proj(
    means: Tensor,  # [..., C, N, 3]
    covars: Tensor,  # [..., C, N, 3, 3]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
    camera_model: CameraModel = "pinhole",
) -> Tuple[Tensor, Tensor]:
    """Projects Gaussians onto the image plane with the selected camera model.

    Args:
        means: Gaussian means. [..., C, N, 3]
        covars: Gaussian covariances. [..., C, N, 3, 3]
        Ks: Camera intrinsics. [..., C, 3, 3]
        width: Image width.
        height: Image height.
        camera_model: Camera model name. ``"ftheta"`` is rejected here.
            Default: ``"pinhole"``.

    Returns:
        A tuple:

        - **Projected means**. [..., C, N, 2]
        - **Projected covariances**. [..., C, N, 2, 2]
    """
    assert (
        camera_model != "ftheta"
    ), "ftheta camera is only supported via UT, please set with_ut=True in the rasterization()"

    batch_dims = means.shape[:-3]
    num_cameras, num_gaussians = means.shape[-3:-1]
    assert means.shape == (*batch_dims, num_cameras, num_gaussians, 3), means.shape
    assert covars.shape == (
        *batch_dims,
        num_cameras,
        num_gaussians,
        3,
        3,
    ), covars.shape
    assert Ks.shape == (*batch_dims, num_cameras, 3, 3), Ks.shape

    return _Proj.apply(
        means.contiguous(),
        covars.contiguous(),
        Ks.contiguous(),
        width,
        height,
        camera_model,
    )
[docs]
def fully_fused_projection(
    means: Tensor,  # [..., N, 3]
    covars: Optional[Tensor],  # [..., N, 6] or None
    quats: Optional[Tensor],  # [..., N, 4] or None
    scales: Optional[Tensor],  # [..., N, 3] or None
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
    eps2d: float = 0.3,
    near_plane: float = 0.01,
    far_plane: float = 1e10,
    radius_clip: float = 0.0,
    packed: bool = False,
    sparse_grad: bool = False,
    calc_compensations: bool = False,
    camera_model: CameraModel = "pinhole",
    opacities: Optional[Tensor] = None,  # [..., N] or None
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Projects Gaussians to 2D.

    This function fuses the process of computing covariances
    (:func:`quat_scale_to_covar_preci()`), transforming to camera space
    (:func:`world_to_cam()`), and projection (:func:`proj()`).

    .. note::

        Gaussians outside of the camera frustum are ignored during projection,
        so not every element of the output tensors is valid. The output `radii`
        serves as an indicator: zero radii mark elements that are invalid and
        will be ignored by the subsequent rasterization. If `packed=True`, the
        outputs are packed into flattened tensors in which all elements are
        valid. In that case `batch_ids`, `camera_ids` and `gaussian_ids` tensors
        are returned to identify each packed element, essentially following the
        COO sparse tensor format.

    .. note::

        Gaussians may be given either as covariances or as {quaternions, scales},
        which are converted to covariances internally in a fused CUDA kernel.
        Either `covars` or {`quats`, `scales`} should be provided.

    Args:
        means: Gaussian means. [..., N, 3]
        covars: Gaussian covariances (flattened upper triangle). [..., N, 6] Optional.
        quats: Quaternions (No need to be normalized). [..., N, 4] Optional.
        scales: Scales. [..., N, 3] Optional.
        viewmats: World-to-camera matrices. [..., C, 4, 4]
        Ks: Camera intrinsics. [..., C, 3, 3]
        width: Image width.
        height: Image height.
        eps2d: An epsilon added to the 2D covariance for numerical stability. Default: 0.3.
        near_plane: Near plane distance. Default: 0.01.
        far_plane: Far plane distance. Default: 1e10.
        radius_clip: Gaussians with projected radii smaller than this value will be ignored. Default: 0.0.
        packed: If True, the output tensors will be packed into a flattened tensor. Default: False.
        sparse_grad: Only effective when `packed` is True. If True, during backward the gradients
            of {`means`, `covars`, `quats`, `scales`} will be a sparse Tensor in COO layout.
            Requires that there are no batch dimensions. Default: False.
        calc_compensations: If True, a view-dependent opacity compensation factor will be computed,
            which is useful for anti-aliasing. Default: False.
        camera_model: Camera model name. ``"ftheta"`` is rejected here. Default: ``"pinhole"``.
        opacities: Gaussian opacities in range [0, 1]. If provided, they are used to compute
            tighter bounds. [..., N] or None. Default: None.

    Returns:
        A tuple:

        If `packed` is True:

        - **batch_ids**. The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
        - **camera_ids**. The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
        - **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
        - **indptr**. CSR-style index pointer into gaussian_ids for batch-camera pairs. Int32 tensor of shape [B*C+1].
        - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [nnz, 2].
        - **means**. Projected Gaussian means in 2D. [nnz, 2]
        - **depths**. The z-depth of the projected Gaussians. [nnz]
        - **conics**. Inverse of the projected covariances. Flattened upper triangle with [nnz, 3]
        - **compensations**. The view-dependent opacity compensation factor. [nnz]

        If `packed` is False:

        - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [..., C, N, 2].
        - **means**. Projected Gaussian means in 2D. [..., C, N, 2]
        - **depths**. The z-depth of the projected Gaussians. [..., C, N]
        - **conics**. Inverse of the projected covariances. Flattened upper triangle with [..., C, N, 3]
        - **compensations**. The view-dependent opacity compensation factor. [..., C, N]
    """
    batch_dims = means.shape[:-2]
    num_gaussians = means.shape[-2]
    num_cameras = viewmats.shape[-3]
    assert means.shape == (*batch_dims, num_gaussians, 3), means.shape
    assert viewmats.shape == (*batch_dims, num_cameras, 4, 4), viewmats.shape
    assert Ks.shape == (*batch_dims, num_cameras, 3, 3), Ks.shape
    means = means.contiguous()

    # Covariances take precedence; quats/scales are only validated without them.
    if covars is None:
        assert quats is not None, "covars or quats is required"
        assert scales is not None, "covars or scales is required"
        assert quats.shape == (*batch_dims, num_gaussians, 4), quats.shape
        assert scales.shape == (*batch_dims, num_gaussians, 3), scales.shape
        quats = quats.contiguous()
        scales = scales.contiguous()
    else:
        assert covars.shape == (*batch_dims, num_gaussians, 6), covars.shape
        covars = covars.contiguous()

    if sparse_grad:
        assert packed, "sparse_grad is only supported when packed is True"
        assert batch_dims == (), "sparse_grad does not support batch dimensions"

    if opacities is not None:
        assert opacities.shape == (*batch_dims, num_gaussians), opacities.shape
        opacities = opacities.contiguous()

    assert (
        camera_model != "ftheta"
    ), "ftheta camera is only supported via UT, please set with_ut=True in the rasterization()"

    viewmats = viewmats.contiguous()
    Ks = Ks.contiguous()

    # Both autograd functions share the same leading and trailing arguments;
    # the packed variant additionally takes `sparse_grad` in between.
    leading_args = (
        means,
        covars,
        quats,
        scales,
        viewmats,
        Ks,
        width,
        height,
        eps2d,
        near_plane,
        far_plane,
        radius_clip,
    )
    trailing_args = (calc_compensations, camera_model, opacities)

    if packed:
        return _FullyFusedProjectionPacked.apply(
            *leading_args, sparse_grad, *trailing_args
        )
    return _FullyFusedProjection.apply(*leading_args, *trailing_args)
[docs]
@torch.no_grad()
def isect_tiles(
    means2d: Tensor,  # [..., N, 2] or [nnz, 2]
    radii: Tensor,  # [..., N, 2] or [nnz, 2]
    depths: Tensor,  # [..., N] or [nnz]
    tile_size: int,
    tile_width: int,
    tile_height: int,
    sort: bool = True,
    segmented: bool = False,
    packed: bool = False,
    n_images: Optional[int] = None,
    image_ids: Optional[Tensor] = None,
    gaussian_ids: Optional[Tensor] = None,
    conics: Optional[
        Tensor
    ] = None,  # [..., N, 3] or [nnz, 3], enables AccuTile when provided
    opacities: Optional[
        Tensor
    ] = None,  # [..., N] or [nnz], enables AccuTile when provided
) -> Tuple[Tensor, Tensor, Tensor]:
    """Maps projected Gaussians to intersecting tiles.

    When `conics` and `opacities` are provided the kernel uses conservative
    ellipse intersection (AccuTile/SNUGBOX), skipping tiles that the
    opacity-thresholded ellipse does not touch. When either is `None` the
    kernel falls back to the original axis-aligned bounding box.

    Args:
        means2d: Projected Gaussian means. [..., N, 2] if packed is False, [nnz, 2] if packed is True.
        radii: Maximum radii of the projected Gaussians. [..., N, 2] if packed is False, [nnz, 2] if packed is True.
        depths: Z-depth of the projected Gaussians. [..., N] if packed is False, [nnz] if packed is True.
        tile_size: Tile size.
        tile_width: Tile width.
        tile_height: Tile height.
        sort: If True, the returned intersections will be sorted by the intersection ids. Default: True.
        segmented: If True, segmented radix sort will be used to sort the intersections. Default: False.
        packed: If True, the input tensors are packed. Default: False.
        n_images: Number of images. Required if packed is True.
        image_ids: The image indices of the projected Gaussians. Required if packed is True.
        gaussian_ids: The column indices of the projected Gaussians. Required if packed is True.
        conics: Inverse of projected covariances (upper triangle). [..., N, 3] if packed is False,
            [nnz, 3] if packed is True. Enables AccuTile when provided together with opacities.
        opacities: Gaussian opacities. [..., N] if packed is False, [nnz] if packed is True.
            Enables AccuTile when provided together with conics.

    Returns:
        A tuple:

        - **Tiles per Gaussian**. The number of tiles intersected by each Gaussian.
          Int32 [..., N] if packed is False, Int32 [nnz] if packed is True.
        - **Intersection ids**. Each id is a 64-bit integer with the following
          information: image_id (Xc bits) | tile_id (Xt bits) | depth (32 bits).
          Xc and Xt are the maximum number of bits required to represent the image and
          tile ids, respectively. Int64 [n_isects]
        - **Flatten ids**. The global flatten indices in [I * N] or [nnz] (packed). [n_isects]
    """
    if not packed:
        # The image count is the product of all leading (image) dimensions.
        image_dims = means2d.shape[:-2]
        num_images = math.prod(image_dims)
        num_gaussians = means2d.shape[-2]
        assert means2d.shape == (*image_dims, num_gaussians, 2), means2d.shape
        assert radii.shape == (*image_dims, num_gaussians, 2), radii.shape
        assert depths.shape == (*image_dims, num_gaussians), depths.shape
        if conics is not None:
            assert conics.shape == (*image_dims, num_gaussians, 3), conics.shape
        if opacities is not None:
            assert opacities.shape == (*image_dims, num_gaussians), opacities.shape
    else:
        nnz = means2d.size(0)
        assert means2d.shape == (nnz, 2), means2d.shape
        assert radii.shape == (nnz, 2), radii.shape
        assert depths.shape == (nnz,), depths.shape
        if conics is not None:
            assert conics.shape == (nnz, 3), conics.shape
        if opacities is not None:
            assert opacities.shape == (nnz,), opacities.shape
        assert image_ids is not None, "image_ids is required if packed is True"
        assert gaussian_ids is not None, "gaussian_ids is required if packed is True"
        assert n_images is not None, "n_images is required if packed is True"
        image_ids = image_ids.contiguous()
        gaussian_ids = gaussian_ids.contiguous()
        num_images = n_images

    # Optional inputs stay None when absent; otherwise they are made contiguous.
    conics_arg = None if conics is None else conics.contiguous()
    opacities_arg = None if opacities is None else opacities.contiguous()

    intersect_tile = _make_lazy_cuda_func("intersect_tile")
    tiles_per_gauss, isect_ids, flatten_ids = intersect_tile(
        means2d.contiguous(),
        radii.contiguous(),
        depths.contiguous(),
        conics_arg,
        opacities_arg,
        image_ids,
        gaussian_ids,
        num_images,
        tile_size,
        tile_width,
        tile_height,
        sort,
        segmented,
    )
    return tiles_per_gauss, isect_ids, flatten_ids
@torch.no_grad()
def isect_tiles_lidar(
    lidar: RowOffsetStructuredSpinningLidarModelParametersExt,
    means2d: Tensor,  # [..., N, 2] or [nnz, 2]
    radii: Tensor,  # [..., N, 2] or [nnz, 2]
    depths: Tensor,  # [..., N] or [nnz]
    sort: bool = True,
    segmented: bool = False,
    packed: bool = False,
    n_images: Optional[int] = None,
    image_ids: Optional[Tensor] = None,
    gaussian_ids: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Maps projected Gaussians to intersecting tiles of a lidar model.

    Args:
        lidar: Lidar model parameters; converted with ``to_cpp()`` before being
            handed to the CUDA operator.
        means2d: Projected Gaussian means. [..., N, 2] if packed is False, [nnz, 2] if packed is True.
        radii: Maximum radii of the projected Gaussians. [..., N, 2] if packed is False, [nnz, 2] if packed is True.
        depths: Z-depth of the projected Gaussians. [..., N] if packed is False, [nnz] if packed is True.
        sort: If True, the returned intersections will be sorted by the intersection ids. Default: True.
        segmented: If True, segmented radix sort will be used to sort the intersections. Default: False.
        packed: If True, the input tensors are packed. Default: False.
        n_images: Number of images. Required if packed is True.
        image_ids: The image indices of the projected Gaussians. Required if packed is True.
        gaussian_ids: The column indices of the projected Gaussians. Required if packed is True.

    Returns:
        A tuple:

        - **Tiles per Gaussian**. The number of tiles intersected by each Gaussian.
          Int32 [..., N] if packed is False, Int32 [nnz] if packed is True.
        - **Intersection ids**. Each id is a 64-bit integer with the following
          information: image_id (Xc bits) | tile_id (Xt bits) | depth (32 bits).
          Xc and Xt are the maximum number of bits required to represent the image and
          tile ids, respectively. Int64 [n_isects]
        - **Flatten ids**. The global flatten indices in [I * N] or [nnz] (packed). [n_isects]
    """
    if not packed:
        # The image count is the product of all leading (image) dimensions.
        image_dims = means2d.shape[:-2]
        num_images = math.prod(image_dims)
        num_gaussians = means2d.shape[-2]
        assert means2d.shape == image_dims + (num_gaussians, 2), means2d.shape
        assert radii.shape == image_dims + (num_gaussians, 2), radii.shape
        assert depths.shape == image_dims + (num_gaussians,), depths.shape
    else:
        nnz = means2d.size(0)
        assert means2d.shape == (nnz, 2), means2d.shape
        assert radii.shape == (nnz, 2), radii.shape
        assert depths.shape == (nnz,), depths.shape
        assert image_ids is not None, "image_ids is required if packed is True"
        assert gaussian_ids is not None, "gaussian_ids is required if packed is True"
        assert n_images is not None, "n_images is required if packed is True"
        image_ids = image_ids.contiguous()
        gaussian_ids = gaussian_ids.contiguous()
        num_images = n_images

    intersect_tile_lidar = _make_lazy_cuda_func("intersect_tile_lidar")
    tiles_per_gauss, isect_ids, flatten_ids = intersect_tile_lidar(
        lidar.to_cpp(),
        means2d.contiguous(),
        radii.contiguous(),
        depths.contiguous(),
        image_ids,
        gaussian_ids,
        num_images,
        sort,
        segmented,
    )
    return tiles_per_gauss, isect_ids, flatten_ids
[docs]
@torch.no_grad()
def isect_offset_encode(
    isect_ids: Tensor,
    n_images: int,
    tile_width: int,
    tile_height: int,
) -> Tensor:
    """Converts intersection ids into per-tile offsets.

    Args:
        isect_ids: Intersection ids. [n_isects]
        n_images: Number of images.
        tile_width: Tile width.
        tile_height: Tile height.

    Returns:
        Offsets. [I, tile_height, tile_width]
    """
    intersect_offset = _make_lazy_cuda_func("intersect_offset")
    contiguous_ids = isect_ids.contiguous()
    return intersect_offset(contiguous_ids, n_images, tile_width, tile_height)
[docs]
def rasterize_to_pixels(
    means2d: Tensor,  # [..., N, 2] or [nnz, 2]
    conics: Tensor,  # [..., N, 3] or [nnz, 3]
    colors: Tensor,  # [..., N, channels] or [nnz, channels]
    opacities: Tensor,  # [..., N] or [nnz]
    image_width: int,
    image_height: int,
    tile_size: int,
    isect_offsets: Tensor,  # [..., tile_height, tile_width]
    flatten_ids: Tensor,  # [n_isects]
    backgrounds: Optional[Tensor] = None,  # [..., channels]
    masks: Optional[Tensor] = None,  # [..., tile_height, tile_width]
    packed: bool = False,
    absgrad: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Rasterizes Gaussians to pixels.

    Channel counts that the kernel does not support directly are zero-padded
    up to the next power of two before rasterization, and the padding is
    stripped from the rendered colors afterwards.

    Args:
        means2d: Projected Gaussian means. [..., N, 2] if packed is False, [nnz, 2] if packed is True.
        conics: Inverse of the projected covariances with only upper triangle values.
            [..., N, 3] if packed is False, [nnz, 3] if packed is True.
        colors: Gaussian colors or ND features. [..., N, channels] if packed is False,
            [nnz, channels] if packed is True.
        opacities: Gaussian opacities that support per-view values. [..., N] if packed is False,
            [nnz] if packed is True.
        image_width: Image width.
        image_height: Image height.
        tile_size: Tile size.
        isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [..., tile_height, tile_width]
        flatten_ids: The global flatten indices in [I * N] or [nnz] from `isect_tiles()`. [n_isects]
        backgrounds: Background colors. [..., channels]. Default: None.
        masks: Optional tile mask to skip rendering GS to masked tiles. [..., tile_height, tile_width]. Default: None.
        packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False.
        absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False.

    Returns:
        A tuple:

        - **Rendered colors**. [..., image_height, image_width, channels]
        - **Rendered alphas**. [..., image_height, image_width, 1]

    Raises:
        ValueError: If the number of color channels is 0 or greater than 513.
        AssertionError: If an input shape is inconsistent or the tile grid
            does not cover the image.
    """
    # NOTE(review): in packed mode `means2d` is [nnz, 2], so `image_dims` is
    # always empty here and `backgrounds` is checked against `(channels,)`
    # only -- confirm whether per-image backgrounds are meant to be accepted.
    image_dims = means2d.shape[:-2]
    channels = colors.shape[-1]
    device = means2d.device

    if packed:
        nnz = means2d.size(0)
        assert means2d.shape == (nnz, 2), means2d.shape
        assert conics.shape == (nnz, 3), conics.shape
        assert colors.shape[0] == nnz, colors.shape
        assert opacities.shape == (nnz,), opacities.shape
    else:
        N = means2d.size(-2)
        assert means2d.shape == image_dims + (N, 2), means2d.shape
        assert conics.shape == image_dims + (N, 3), conics.shape
        assert colors.shape == image_dims + (N, channels), colors.shape
        assert opacities.shape == image_dims + (N,), opacities.shape
    if backgrounds is not None:
        assert backgrounds.shape == image_dims + (channels,), backgrounds.shape
        backgrounds = backgrounds.contiguous()
    if masks is not None:
        assert masks.shape == isect_offsets.shape, masks.shape
        masks = masks.contiguous()

    # Pad the channels to the nearest supported number if necessary
    if channels > 513 or channels == 0:
        # TODO: maybe worth to support zero channels?
        raise ValueError(f"Unsupported number of color channels: {channels}")

    # Channel counts the kernel is used with directly, without padding.
    supported_channels = (
        1,
        2,
        3,
        4,
        5,
        8,
        9,
        16,
        17,
        24,
        32,
        33,
        64,
        65,
        128,
        129,
        256,
        257,
        512,
        513,
    )
    if channels not in supported_channels:
        # Round up to the next power of two.
        padded_channels = (1 << (channels - 1).bit_length()) - channels
        # Create the padding in the dtype of the tensor it extends. With the
        # default dtype, `torch.cat` would promote half/double inputs to a
        # different dtype on the padded path only.
        colors = torch.cat(
            [
                colors,
                torch.zeros(
                    *colors.shape[:-1],
                    padded_channels,
                    device=device,
                    dtype=colors.dtype,
                ),
            ],
            dim=-1,
        )
        if backgrounds is not None:
            backgrounds = torch.cat(
                [
                    backgrounds,
                    torch.zeros(
                        *backgrounds.shape[:-1],
                        padded_channels,
                        device=device,
                        dtype=backgrounds.dtype,
                    ),
                ],
                dim=-1,
            )
    else:
        padded_channels = 0

    # The tile grid must cover the whole image.
    tile_height, tile_width = isect_offsets.shape[-2:]
    assert (
        tile_height * tile_size >= image_height
    ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
    assert (
        tile_width * tile_size >= image_width
    ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"

    render_colors, render_alphas = _RasterizeToPixels.apply(
        means2d.contiguous(),
        conics.contiguous(),
        colors.contiguous(),
        opacities.contiguous(),
        backgrounds,
        masks,
        image_width,
        image_height,
        tile_size,
        isect_offsets.contiguous(),
        flatten_ids.contiguous(),
        absgrad,
    )

    # Drop the padded channels so the caller sees the original channel count.
    if padded_channels > 0:
        render_colors = render_colors[..., :-padded_channels]
    return render_colors, render_alphas
def rasterize_to_pixels_eval3d(
    means: Tensor,  # [..., N, 3]
    quats: Tensor,  # [..., N, 4]
    scales: Tensor,  # [..., N, 3]
    colors: Tensor,  # [..., C, N, channels] or [nnz, channels]
    opacities: Tensor,  # [..., C, N] or [nnz]
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    image_width: int,
    image_height: int,
    tile_size: int,
    isect_offsets: Tensor,  # [..., C, tile_height, tile_width]
    flatten_ids: Tensor,  # [n_isects]
    backgrounds: Optional[Tensor] = None,  # [..., C, channels]
    masks: Optional[Tensor] = None,  # [..., C, tile_height, tile_width]
    camera_model: CameraModel = "pinhole",
    ut_params: Optional[UnscentedTransformParameters] = None,
    rays: Optional[Tensor] = None,  # [..., C, P, 6] (shape checked by the callee)
    # distortion
    radial_coeffs: Optional[Tensor] = None,  # [..., C, 6] or [..., C, 4]
    tangential_coeffs: Optional[Tensor] = None,  # [..., C, 2]
    thin_prism_coeffs: Optional[Tensor] = None,  # [..., C, 4]
    ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
    lidar_coeffs: Optional[RowOffsetStructuredSpinningLidarModelParametersExt] = None,
    external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
    # rolling shutter
    rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
    viewmats_rs: Optional[Tensor] = None,  # [..., C, 4, 4]
    use_hit_distance: bool = False,
    return_normals: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Rasterizes Gaussians to pixels, evaluating them in 3D world space.

    Counterpart of `rasterize_to_pixels()` in which the Gaussian responses are
    computed in the 3D world space rather than in the 2D image space. Supports
    rolling shutter and camera distortion.

    This is a thin wrapper: every argument is forwarded unchanged to
    `rasterize_to_pixels_eval3d_extra()` (with `return_sample_counts=False`)
    and only the first two outputs of that call are kept. A default
    `UnscentedTransformParameters()` is created when `ut_params` is None.

    Returns:
        A tuple:

        - **Rendered colors**. [..., C, image_height, image_width, channels]
        - **Rendered alphas**. [..., C, image_height, image_width, 1]
    """
    effective_ut_params = (
        UnscentedTransformParameters() if ut_params is None else ut_params
    )

    extra_outputs = rasterize_to_pixels_eval3d_extra(
        means=means,
        quats=quats,
        scales=scales,
        colors=colors,
        opacities=opacities,
        viewmats=viewmats,
        Ks=Ks,
        rays=rays,
        image_width=image_width,
        image_height=image_height,
        tile_size=tile_size,
        isect_offsets=isect_offsets,
        flatten_ids=flatten_ids,
        backgrounds=backgrounds,
        masks=masks,
        camera_model=camera_model,
        ut_params=effective_ut_params,
        radial_coeffs=radial_coeffs,
        tangential_coeffs=tangential_coeffs,
        thin_prism_coeffs=thin_prism_coeffs,
        ftheta_coeffs=ftheta_coeffs,
        lidar_coeffs=lidar_coeffs,
        external_distortion_coeffs=external_distortion_coeffs,
        rolling_shutter=rolling_shutter,
        viewmats_rs=viewmats_rs,
        return_sample_counts=False,
        use_hit_distance=use_hit_distance,
        return_normals=return_normals,
    )

    # Only colors and alphas are part of this function's contract; the
    # remaining debug outputs of the "extra" variant are discarded.
    render_colors, render_alphas = extra_outputs[0], extra_outputs[1]
    return render_colors, render_alphas
def rasterize_to_pixels_eval3d_extra(
    means: Tensor,  # [..., N, 3]
    quats: Tensor,  # [..., N, 4]
    scales: Tensor,  # [..., N, 3]
    colors: Tensor,  # [..., C, N, channels] or [nnz, channels]
    opacities: Tensor,  # [..., C, N] or [nnz]
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    image_width: int,
    image_height: int,
    tile_size: int,
    isect_offsets: Tensor,  # [..., C, tile_height, tile_width]
    flatten_ids: Tensor,  # [n_isects]
    backgrounds: Optional[Tensor] = None,  # [..., C, channels]
    masks: Optional[Tensor] = None,  # [..., C, tile_height, tile_width]
    camera_model: CameraModel = "pinhole",
    ut_params: Optional[UnscentedTransformParameters] = None,
    rays: Optional[Tensor] = None,  # [..., C, P, 6]
    # distortion
    radial_coeffs: Optional[Tensor] = None,  # [..., C, 6] or [..., C, 4]
    tangential_coeffs: Optional[Tensor] = None,  # [..., C, 2]
    thin_prism_coeffs: Optional[Tensor] = None,  # [..., C, 4]
    ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
    lidar_coeffs: Optional[RowOffsetStructuredSpinningLidarModelParametersExt] = None,
    external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
    # rolling shutter
    rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
    viewmats_rs: Optional[Tensor] = None,  # [..., C, 4, 4]
    return_sample_counts: bool = False,
    use_hit_distance: bool = False,
    return_normals: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    """Rasterizes Gaussians to pixels, returning extra information for debugging.

    Similar to `rasterize_to_pixels_eval3d()`, but additionally returns the last
    gaussian id accumulated in each pixel, and optionally the number of
    accumulated samples per pixel and the accumulated normals.

    The inputs are shape-checked, the color channels are padded up to a channel
    count supported by the kernel, the rasterization is dispatched through
    `_RasterizeToPixelsEval3D`, and the padding is stripped from the result.

    Args:
        backgrounds: Optional per-camera background colors. When channel
            padding is required, they are padded with exactly the same layout
            as `colors` so that every channel keeps its own background.
        camera_model: Camera model name. For "lidar", `lidar_coeffs` must be
            given and its tiling must match the shape of `isect_offsets`.
        ut_params: Unscented transform parameters. A default-constructed
            `UnscentedTransformParameters` is used when None.
        rays: Optional float32 tensor of shape [..., C, P, 6].
        return_sample_counts: If True, also return number of accumulated
            samples per pixel. Default: False.
        use_hit_distance: Forwarded to the kernel unchanged. Default: False.
        return_normals: If True, compute and return accumulated normals per pixel.
            Normals are computed from Gaussian quaternions (canonical normal = (0,0,1)
            transformed by rotation, flipped if facing away from ray). Default: False.

    Returns:
        A 5-tuple (the last two entries are None unless the matching flag is set):

        - **Rendered colors**. [..., C, image_height, image_width, channels]
        - **Rendered alphas**. [..., C, image_height, image_width, 1]
        - **Last flatten_idx**. [..., C, image_height, image_width]
        - **Sample counts** (optional). [..., C, image_height, image_width]. If return_sample_counts=True.
        - **Rendered normals** (optional). [..., C, image_height, image_width, 3]. If return_normals=True.

    Raises:
        NotImplementedError: If `colors` is given in packed ([nnz, channels]) layout.
        ValueError: If the number of color channels is 0 or larger than 513.
    """
    if ut_params is None:
        ut_params = UnscentedTransformParameters()

    batch_dims = means.shape[:-2]
    num_batch_dims = len(batch_dims)
    N = means.size(-2)
    C = viewmats.size(-3)
    P = rays.shape[-2] if rays is not None else 0
    channels = colors.shape[-1]
    device = means.device

    assert means.shape == batch_dims + (N, 3), means.shape
    assert quats.shape == batch_dims + (N, 4), quats.shape
    assert scales.shape == batch_dims + (N, 3), scales.shape
    assert viewmats.shape == batch_dims + (C, 4, 4), viewmats.shape
    assert Ks.shape == batch_dims + (C, 3, 3), Ks.shape
    if rays is not None:
        assert_shape("rays", rays, batch_dims + (C, P, 6))
        assert (
            rays.dtype == torch.float32
        ), f"rays must be torch.float32, got {rays.dtype}"

    assert colors.ndim in (num_batch_dims + 2, num_batch_dims + 3), colors.shape
    if colors.ndim == num_batch_dims + 2:
        raise NotImplementedError("packed mode is not supported yet")
    assert colors.shape == batch_dims + (C, N, channels), colors.shape
    assert opacities.shape == colors.shape[:-1], opacities.shape

    if backgrounds is not None:
        assert backgrounds.shape == batch_dims + (C, channels), backgrounds.shape
        backgrounds = backgrounds.contiguous()
    if masks is not None:
        assert masks.shape == isect_offsets.shape, masks.shape
        masks = masks.contiguous()
    if radial_coeffs is not None:
        assert radial_coeffs.shape[:-1] == batch_dims + (C,) and radial_coeffs.shape[
            -1
        ] in (6, 4), radial_coeffs.shape
        radial_coeffs = radial_coeffs.contiguous()
    if tangential_coeffs is not None:
        assert tangential_coeffs.shape == batch_dims + (C, 2), tangential_coeffs.shape
        tangential_coeffs = tangential_coeffs.contiguous()
    if thin_prism_coeffs is not None:
        assert thin_prism_coeffs.shape == batch_dims + (C, 4), thin_prism_coeffs.shape
        thin_prism_coeffs = thin_prism_coeffs.contiguous()
    if viewmats_rs is not None:
        assert viewmats_rs.shape == batch_dims + (C, 4, 4), viewmats_rs.shape
        viewmats_rs = viewmats_rs.contiguous()

    # Pad the channels to the nearest supported number if necessary
    if channels > 513 or channels == 0:
        # TODO: maybe worth to support zero channels?
        raise ValueError(f"Unsupported number of color channels: {channels}")
    supported_channels = (
        1, 2, 3, 4, 5, 8, 9, 16, 17, 24,
        32, 33, 64, 65, 128, 129, 256, 257, 512, 513,
    )  # fmt: skip
    if channels not in supported_channels:
        padded_channels = (1 << (channels - 1).bit_length()) - channels

        def _pad_before_last_channel(t: Tensor) -> Tensor:
            # Insert the zero padding before the last channel so that it stays
            # at CDIM-1. When depth is present it is always the last channel,
            # so this keeps it where the CUDA kernel writes hit_distance.
            # When depth is absent the last channel is preserved through the
            # round-trip.
            zeros = torch.zeros(*t.shape[:-1], padded_channels, device=device)
            return torch.cat([t[..., :-1], zeros, t[..., -1:]], dim=-1)

        colors = _pad_before_last_channel(colors)
        if backgrounds is not None:
            # Backgrounds must use the same channel layout as colors:
            # appending the zeros at the end instead would pair the last color
            # channel with a zero background and move its real background onto
            # a padding channel that is stripped from the output below.
            backgrounds = _pad_before_last_channel(backgrounds)
    else:
        padded_channels = 0

    tile_height, tile_width = isect_offsets.shape[-2:]
    if camera_model == "lidar":
        assert (
            lidar_coeffs is not None
        ), "lidar_coeffs must be provided when camera_model is 'lidar'"
        assert tile_width == lidar_coeffs.tiling.n_bins_azimuth
        assert tile_height == lidar_coeffs.tiling.n_bins_elevation
        # TODO: improve checks. Right now we don't have access to max_pts_per_tile used,
        # hence this assert needs to be commented out.
        # assert tile_width*tile_height*lidar_coeffs.tiling.max_pts_per_tile >= lidar_coeffs.n_rows*lidar_coeffs.n_columns
    else:
        assert (
            tile_height * tile_size >= image_height
        ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
        assert (
            tile_width * tile_size >= image_width
        ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"

    (
        render_colors,
        render_alphas,
        last_ids,
        sample_counts,
        render_normals,
    ) = _RasterizeToPixelsEval3D.apply(
        means.contiguous(),
        quats.contiguous(),
        scales.contiguous(),
        colors.contiguous(),
        opacities.contiguous(),
        backgrounds.contiguous() if backgrounds is not None else None,
        masks.contiguous() if masks is not None else None,
        viewmats.contiguous(),
        Ks.contiguous(),
        image_width,
        image_height,
        tile_size,
        isect_offsets.contiguous(),
        flatten_ids.contiguous(),
        camera_model,
        ut_params,
        rays.contiguous() if rays is not None else None,
        # distortion
        radial_coeffs.contiguous() if radial_coeffs is not None else None,
        tangential_coeffs.contiguous() if tangential_coeffs is not None else None,
        thin_prism_coeffs.contiguous() if thin_prism_coeffs is not None else None,
        ftheta_coeffs,
        lidar_coeffs,
        external_distortion_coeffs,
        # rolling shutter
        rolling_shutter,
        viewmats_rs.contiguous() if viewmats_rs is not None else None,
        # Forward is always collecting the last_ids for the backward pass,
        # no need to tell it to do it.
        return_sample_counts,  # Pass flag to forward
        use_hit_distance,
        return_normals,  # Pass return_normals flag to forward
    )

    if padded_channels > 0:
        # Drop the padding channels that sit between the leading channels and
        # the (preserved) last channel.
        render_colors = torch.cat(
            [render_colors[..., : -padded_channels - 1], render_colors[..., -1:]],
            dim=-1,
        )
    return render_colors, render_alphas, last_ids, sample_counts, render_normals
[docs]
@torch.no_grad()
def rasterize_to_indices_in_range(
    range_start: int,
    range_end: int,
    transmittances: Tensor,  # [..., image_height, image_width]
    means2d: Tensor,  # [..., N, 2]
    conics: Tensor,  # [..., N, 3]
    opacities: Tensor,  # [..., N]
    image_width: int,
    image_height: int,
    tile_size: int,
    isect_offsets: Tensor,  # [..., tile_height, tile_width]
    flatten_ids: Tensor,  # [n_isects]
) -> Tuple[Tensor, Tensor, Tensor]:
    """Rasterizes a batch of Gaussians to images but only returns the indices.

    .. note::

        This function supports iterative rasterization, in which each call of this function
        will rasterize a batch of Gaussians from near to far, defined by `[range_start, range_end)`.
        If a one-step full rasterization is desired, set `range_start` to 0 and `range_end` to a really
        large number, e.g, 1e10.

    Args:
        range_start: The start batch of Gaussians to be rasterized (inclusive).
        range_end: The end batch of Gaussians to be rasterized (exclusive).
        transmittances: Current transmittances. [..., image_height, image_width]
        means2d: Projected Gaussian means. [..., N, 2]
        conics: Inverse of the projected covariances with only upper triangle values. [..., N, 3]
        opacities: Gaussian opacities that support per-view values. [..., N]
        image_width: Image width.
        image_height: Image height.
        tile_size: Tile size.
        isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [..., tile_height, tile_width]
        flatten_ids: The global flatten indices in [I * N] from `isect_tiles()`. [n_isects]

    Returns:
        A tuple:

        - **Gaussian ids**. Gaussian ids for the pixel intersection. A flattened list of shape [M].
        - **Pixel ids**. pixel indices (row-major). A flattened list of shape [M].
        - **Image ids**. image indices. A flattened list of shape [M].
    """
    image_dims = means2d.shape[:-2]
    N = means2d.shape[-2]
    tile_height, tile_width = isect_offsets.shape[-2:]

    # Shape validation of all tensor inputs against the leading image dims.
    expected_transmittance_shape = image_dims + (image_height, image_width)
    assert transmittances.shape == expected_transmittance_shape, transmittances.shape
    assert means2d.shape == image_dims + (N, 2), means2d.shape
    assert conics.shape == image_dims + (N, 3), conics.shape
    assert opacities.shape == image_dims + (N,), opacities.shape
    expected_offsets_shape = image_dims + (tile_height, tile_width)
    assert isect_offsets.shape == expected_offsets_shape, isect_offsets.shape

    # The tile grid must cover the whole image.
    assert (
        tile_height * tile_size >= image_height
    ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
    assert (
        tile_width * tile_size >= image_width
    ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"

    rasterize_indices = _make_lazy_cuda_func("rasterize_to_indices_3dgs")
    out_gauss_ids, out_indices = rasterize_indices(
        range_start,
        range_end,
        transmittances.contiguous(),
        means2d.contiguous(),
        conics.contiguous(),
        opacities.contiguous(),
        image_width,
        image_height,
        tile_size,
        isect_offsets.contiguous(),
        flatten_ids.contiguous(),
    )

    # Split each returned flat index into (image id, pixel id within image).
    out_pixel_ids = out_indices % (image_width * image_height)
    out_image_ids = out_indices // (image_width * image_height)
    return out_gauss_ids, out_pixel_ids, out_image_ids
class _QuatScaleToCovarPreci(torch.autograd.Function):
    """Autograd wrapper turning quaternions and scales into covariance and precision matrices.

    Forward and backward are delegated to the
    `quat_scale_to_covar_preci_fwd` / `quat_scale_to_covar_preci_bwd` ops.
    """

    @staticmethod
    def forward(
        ctx,
        quats: Tensor,  # [..., 4],
        scales: Tensor,  # [..., 3],
        compute_covar: bool = True,
        compute_preci: bool = True,
        triu: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        fwd_op = _make_lazy_cuda_func("quat_scale_to_covar_preci_fwd")
        covars, precis = fwd_op(quats, scales, compute_covar, compute_preci, triu)

        # Keep the inputs and the flags around for the backward pass.
        ctx.save_for_backward(quats, scales)
        ctx.compute_covar = compute_covar
        ctx.compute_preci = compute_preci
        ctx.triu = triu
        return covars, precis

    @staticmethod
    def backward(ctx, v_covars: Tensor, v_precis: Tensor):
        quats, scales = ctx.saved_tensors

        def _prepare(grad, enabled):
            # Gradients of outputs that were not requested are not forwarded;
            # sparse gradients are densified before being made contiguous.
            if not enabled:
                return None
            if grad.is_sparse:
                grad = grad.to_dense()
            return grad.contiguous()

        covar_grad = _prepare(v_covars, ctx.compute_covar)
        preci_grad = _prepare(v_precis, ctx.compute_preci)

        bwd_op = _make_lazy_cuda_func("quat_scale_to_covar_preci_bwd")
        v_quats, v_scales = bwd_op(quats, scales, ctx.triu, covar_grad, preci_grad)

        # One entry per forward input; the three flags get no gradient.
        return (
            v_quats,
            v_scales,
            None,  # compute_covar
            None,  # compute_preci
            None,  # triu
        )
class _Proj(torch.autograd.Function):
    """Perspective fully_fused_projection on Gaussians.

    Wraps the `projection_ewa_simple_fwd` / `projection_ewa_simple_bwd` ops.
    Gradients are produced for `means` and `covars` only.
    """

    @staticmethod
    def forward(
        ctx,
        means: Tensor,  # [..., C, N, 3]
        covars: Tensor,  # [..., C, N, 3, 3]
        Ks: Tensor,  # [..., C, 3, 3]
        width: int,
        height: int,
        camera_model: CameraModel = "pinhole",
    ) -> Tuple[Tensor, Tensor]:
        assert (
            camera_model != "ftheta"
        ), "ftheta camera is only supported via UT, please set with_ut=True in the rasterization()"

        # Resolve the camera model name to the corresponding backend enum value.
        model_enum_name = f"CameraModelType.{camera_model.upper()}"
        camera_model_type = _make_lazy_cuda_obj(model_enum_name)

        project_fwd = _make_lazy_cuda_func("projection_ewa_simple_fwd")
        means2d, covars2d = project_fwd(
            means, covars, Ks, width, height, camera_model_type
        )

        ctx.save_for_backward(means, covars, Ks)
        ctx.width = width
        ctx.height = height
        ctx.camera_model_type = camera_model_type
        return means2d, covars2d

    @staticmethod
    def backward(ctx, v_means2d: Tensor, v_covars2d: Tensor):
        means, covars, Ks = ctx.saved_tensors

        project_bwd = _make_lazy_cuda_func("projection_ewa_simple_bwd")
        v_means, v_covars = project_bwd(
            means,
            covars,
            Ks,
            ctx.width,
            ctx.height,
            ctx.camera_model_type,
            v_means2d.contiguous(),
            v_covars2d.contiguous(),
        )

        # One entry per forward input; only means and covars are differentiable.
        return (
            v_means,
            v_covars,
            None,  # Ks
            None,  # width
            None,  # height
            None,  # camera_model
        )
class _FullyFusedProjection(torch.autograd.Function):
    """Projects Gaussians to 2D.

    Wraps the `projection_ewa_3dgs_fused_fwd` / `projection_ewa_3dgs_fused_bwd`
    ops. Gradients are produced for `means`, `covars`, `quats`, `scales` and
    `viewmats`; every other forward input receives no gradient.
    """

    @staticmethod
    def forward(
        ctx,
        means: Tensor,  # [..., N, 3]
        covars: Tensor,  # [..., N, 6] or None
        quats: Tensor,  # [..., N, 4] or None
        scales: Tensor,  # [..., N, 3] or None
        viewmats: Tensor,  # [..., C, 4, 4]
        Ks: Tensor,  # [..., C, 3, 3]
        width: int,
        height: int,
        eps2d: float,
        near_plane: float,
        far_plane: float,
        radius_clip: float,
        calc_compensations: bool,
        camera_model: CameraModel = "pinhole",
        opacities: Optional[Tensor] = None,  # [..., N] or None
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        assert (
            camera_model != "ftheta"
        ), "ftheta camera is only supported via UT, please set with_ut=True in the rasterization()"

        camera_model_type = _make_lazy_cuda_obj(
            f"CameraModelType.{camera_model.upper()}"
        )

        # "covars" and {"quats", "scales"} are mutually exclusive
        radii, means2d, depths, conics, compensations = _make_lazy_cuda_func(
            "projection_ewa_3dgs_fused_fwd"
        )(
            means,
            covars,
            quats,
            scales,
            opacities,
            viewmats,
            Ks,
            width,
            height,
            eps2d,
            near_plane,
            far_plane,
            radius_clip,
            calc_compensations,
            camera_model_type,
        )
        if not calc_compensations:
            compensations = None

        ctx.save_for_backward(
            means, covars, quats, scales, viewmats, Ks, radii, conics, compensations
        )
        ctx.width = width
        ctx.height = height
        ctx.eps2d = eps2d
        ctx.camera_model_type = camera_model_type

        return radii, means2d, depths, conics, compensations

    @staticmethod
    def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
        (
            means,
            covars,
            quats,
            scales,
            viewmats,
            Ks,
            radii,
            conics,
            compensations,
        ) = ctx.saved_tensors
        needs_grad = ctx.needs_input_grad

        if v_compensations is not None:
            v_compensations = v_compensations.contiguous()

        v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
            "projection_ewa_3dgs_fused_bwd"
        )(
            means,
            covars,
            quats,
            scales,
            viewmats,
            Ks,
            ctx.width,
            ctx.height,
            ctx.eps2d,
            ctx.camera_model_type,
            radii,
            conics,
            compensations,
            v_means2d.contiguous(),
            v_depths.contiguous(),
            v_conics.contiguous(),
            v_compensations,
            needs_grad[4],  # viewmats_requires_grad
        )

        # Drop gradients for inputs that do not require them.
        if not needs_grad[0]:
            v_means = None
        if not needs_grad[1]:
            v_covars = None
        if not needs_grad[2]:
            v_quats = None
        if not needs_grad[3]:
            v_scales = None
        if not needs_grad[4]:
            v_viewmats = None

        # Exactly one entry per forward() input (15), in forward() order.
        return (
            v_means,
            v_covars,
            v_quats,
            v_scales,
            v_viewmats,
            None,  # Ks
            None,  # width
            None,  # height
            None,  # eps2d
            None,  # near_plane
            None,  # far_plane
            None,  # radius_clip
            None,  # calc_compensations
            None,  # camera_model
            None,  # opacities
        )
def fully_fused_projection_with_ut(
    means: Tensor,  # [..., N, 3]
    quats: Tensor,  # [..., N, 4]
    scales: Tensor,  # [..., N, 3]
    opacities: Optional[Tensor],  # [..., N]
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
    eps2d: float = 0.3,
    near_plane: float = 0.01,
    far_plane: float = 1e10,
    radius_clip: float = 0.0,
    calc_compensations: bool = False,
    camera_model: CameraModel = "pinhole",
    ut_params: Optional[UnscentedTransformParameters] = None,
    # distortion
    radial_coeffs: Optional[Tensor] = None,  # [..., C, 6] or [..., C, 4]
    tangential_coeffs: Optional[Tensor] = None,  # [..., C, 2]
    thin_prism_coeffs: Optional[Tensor] = None,  # [..., C, 4]
    ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
    lidar_coeffs: Optional[RowOffsetStructuredSpinningLidarModelParametersExt] = None,
    external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
    # rolling shutter
    rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
    viewmats_rs: Optional[Tensor] = None,  # [..., C, 4, 4]
    global_z_order: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Projects Gaussians to 2D using Unscented Transform (UT).

    Similar to `fully_fused_projection()`, but supports camera distortion and
    rolling shutter.

    .. warning::
        This function is not differentiable to any input.

    Args:
        global_z_order: Defines how Gaussians are sorted for depth ordering. If True (default),
            Gaussians are sorted by their z-coordinate in camera space. If False, they are sorted
            by their Euclidean distance from the camera origin. The z-coordinate sorting is typically
            faster and sufficient for most cases, while Euclidean distance can be useful for scenes
            with wide field-of-view or non-standard camera models. Default: True.

    Returns:
        The tuple `(radii, means2d, depths, conics, compensations)` produced by
        the `projection_ut_3dgs_fused` op; `compensations` is None unless
        `calc_compensations` is True.
    """

    def _contiguous_or_none(tensor: Optional[Tensor]) -> Optional[Tensor]:
        # Optional tensors are forwarded as None when absent.
        return None if tensor is None else tensor.contiguous()

    if ut_params is None:
        ut_params = UnscentedTransformParameters()

    batch_dims = means.shape[:-2]
    N = means.shape[-2]
    C = viewmats.shape[-3]
    per_camera_dims = batch_dims + (C,)

    # Validate the shapes of all tensor inputs.
    assert means.shape == batch_dims + (N, 3), means.shape
    assert quats.shape == batch_dims + (N, 4), quats.shape
    assert scales.shape == batch_dims + (N, 3), scales.shape
    if opacities is not None:
        assert opacities.shape == batch_dims + (N,), opacities.shape
    assert viewmats.shape == per_camera_dims + (4, 4), viewmats.shape
    assert Ks.shape == per_camera_dims + (3, 3), Ks.shape
    if radial_coeffs is not None:
        assert (
            radial_coeffs.shape[:-1] == per_camera_dims
            and radial_coeffs.shape[-1] in (6, 4)
        ), radial_coeffs.shape
    if tangential_coeffs is not None:
        assert tangential_coeffs.shape == per_camera_dims + (2,), tangential_coeffs.shape
    if thin_prism_coeffs is not None:
        assert thin_prism_coeffs.shape == per_camera_dims + (4,), thin_prism_coeffs.shape
    if viewmats_rs is not None:
        assert viewmats_rs.shape == per_camera_dims + (4, 4), viewmats_rs.shape
    if lidar_coeffs is not None:
        assert isinstance(
            lidar_coeffs, RowOffsetStructuredSpinningLidarModelParametersExt
        )

    camera_model_type = _make_lazy_cuda_obj(f"CameraModelType.{camera_model.upper()}")
    if ftheta_coeffs is None:
        ftheta_coeffs = FThetaCameraDistortionParameters()

    project = _make_lazy_cuda_func("projection_ut_3dgs_fused")
    radii, means2d, depths, conics, compensations = project(
        means.contiguous(),
        quats.contiguous(),
        scales.contiguous(),
        _contiguous_or_none(opacities),
        viewmats.contiguous(),
        _contiguous_or_none(viewmats_rs),
        Ks.contiguous(),
        width,
        height,
        eps2d,
        near_plane,
        far_plane,
        radius_clip,
        calc_compensations,
        camera_model_type,
        global_z_order,
        ut_params,
        rolling_shutter,
        _contiguous_or_none(radial_coeffs),
        _contiguous_or_none(tangential_coeffs),
        _contiguous_or_none(thin_prism_coeffs),
        ftheta_coeffs,
        None if lidar_coeffs is None else lidar_coeffs.to_cpp(),
        external_distortion_coeffs,
    )

    return (
        radii,
        means2d,
        depths,
        conics,
        compensations if calc_compensations else None,
    )
class _RasterizeToPixels(torch.autograd.Function):
    """Autograd wrapper around the 3DGS pixel rasterization ops.

    Forward calls `rasterize_to_pixels_3dgs_fwd`, backward calls
    `rasterize_to_pixels_3dgs_bwd`. The background gradient is computed in
    Python from the incoming color gradient and the saved alphas.
    """

    @staticmethod
    def forward(
        ctx,
        means2d: Tensor,  # [..., N, 2] or [nnz, 2]
        conics: Tensor,  # [..., N, 3] or [nnz, 3]
        colors: Tensor,  # [..., N, channels] or [nnz, channels]
        opacities: Tensor,  # [..., N] or [nnz]
        backgrounds: Tensor,  # [..., channels], Optional
        masks: Tensor,  # [..., tile_height, tile_width], Optional
        width: int,
        height: int,
        tile_size: int,
        isect_offsets: Tensor,  # [..., tile_height, tile_width]
        flatten_ids: Tensor,  # [n_isects]
        absgrad: bool,
    ) -> Tuple[Tensor, Tensor]:
        rasterize_fwd = _make_lazy_cuda_func("rasterize_to_pixels_3dgs_fwd")
        render_colors, render_alphas, last_ids = rasterize_fwd(
            means2d,
            conics,
            colors,
            opacities,
            backgrounds,
            masks,
            width,
            height,
            tile_size,
            isect_offsets,
            flatten_ids,
        )

        # The alphas are saved exactly as returned by the op (before the
        # float conversion below).
        ctx.save_for_backward(
            means2d,
            conics,
            colors,
            opacities,
            backgrounds,
            masks,
            isect_offsets,
            flatten_ids,
            render_alphas,
            last_ids,
        )
        ctx.width = width
        ctx.height = height
        ctx.tile_size = tile_size
        ctx.absgrad = absgrad

        # double to float
        return render_colors, render_alphas.float()

    @staticmethod
    def backward(
        ctx,
        v_render_colors: Tensor,  # [..., H, W, 3]
        v_render_alphas: Tensor,  # [..., H, W, 1]
    ):
        (
            means2d,
            conics,
            colors,
            opacities,
            backgrounds,
            masks,
            isect_offsets,
            flatten_ids,
            render_alphas,
            last_ids,
        ) = ctx.saved_tensors
        absgrad = ctx.absgrad

        rasterize_bwd = _make_lazy_cuda_func("rasterize_to_pixels_3dgs_bwd")
        grads = rasterize_bwd(
            means2d,
            conics,
            colors,
            opacities,
            backgrounds,
            masks,
            ctx.width,
            ctx.height,
            ctx.tile_size,
            isect_offsets,
            flatten_ids,
            render_alphas,
            last_ids,
            v_render_colors.contiguous(),
            v_render_alphas.contiguous(),
            absgrad,
        )
        v_means2d_abs, v_means2d, v_conics, v_colors, v_opacities = grads

        if absgrad:
            # Expose the absolute gradient as an attribute on the input tensor.
            means2d.absgrad = v_means2d_abs

        v_backgrounds = None
        if ctx.needs_input_grad[4]:
            # Weight the color gradient by (1 - alpha) and sum over the two
            # image axes to get the background gradient.
            background_weight = (1.0 - render_alphas).float()
            v_backgrounds = (v_render_colors * background_weight).sum(dim=(-3, -2))

        # One entry per forward input.
        return (
            v_means2d,
            v_conics,
            v_colors,
            v_opacities,
            v_backgrounds,
            None,  # masks
            None,  # width
            None,  # height
            None,  # tile_size
            None,  # isect_offsets
            None,  # flatten_ids
            None,  # absgrad
        )
class _RasterizeToPixelsEval3D(torch.autograd.Function):
    """Autograd wrapper around the world-space 3DGS pixel rasterization ops.

    Forward calls `rasterize_to_pixels_from_world_3dgs_fwd`, backward calls
    `rasterize_to_pixels_from_world_3dgs_bwd`. Gradients are returned for
    means, quats, scales, colors, opacities, backgrounds and rays.
    """

    @staticmethod
    def forward(
        ctx,
        means: Tensor,  # [..., N, 3]
        quats: Tensor,  # [..., N, 4]
        scales: Tensor,  # [..., N, 3]
        colors: Tensor,  # [..., C, N, D] or [nnz, D]
        opacities: Tensor,  # [..., C, N] or [nnz]
        backgrounds: Tensor,  # [..., C, D], Optional
        masks: Tensor,  # [..., C, tile_height, tile_width], Optional
        viewmats: Tensor,  # [..., C, 4, 4]
        Ks: Tensor,  # [..., C, 3, 3]
        width: int,
        height: int,
        tile_size: int,
        isect_offsets: Tensor,  # [..., C, tile_height, tile_width]
        flatten_ids: Tensor,  # [..., n_isects]
        camera_model: CameraModel = "pinhole",
        ut_params: Optional[UnscentedTransformParameters] = None,
        rays: Optional[Tensor] = None,  # [..., C, P, 6]
        # distortion
        radial_coeffs: Optional[Tensor] = None,  # [..., C, 6] or [..., C, 4]
        tangential_coeffs: Optional[Tensor] = None,  # [..., C, 2]
        thin_prism_coeffs: Optional[Tensor] = None,  # [..., C, 4]
        ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
        lidar_coeffs: Optional[
            RowOffsetStructuredSpinningLidarModelParametersExt
        ] = None,
        external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
        # rolling shutter
        rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
        viewmats_rs: Optional[Tensor] = None,  # [..., C, 4, 4]
        return_sample_counts: bool = False,
        use_hit_distance: bool = False,
        return_normals: bool = False,
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        # Fill in defaults and convert the Python-side parameter objects.
        if ut_params is None:
            ut_params = UnscentedTransformParameters()
        camera_model_type = _make_lazy_cuda_obj(
            f"CameraModelType.{camera_model.upper()}"
        )
        if ftheta_coeffs is None:
            ftheta_coeffs = FThetaCameraDistortionParameters()
        if lidar_coeffs is not None:
            lidar_coeffs = lidar_coeffs.to_cpp()

        # Optional per-pixel output buffers, shaped (batch_dims, C, H, W[, 3]).
        # They are handed to the forward op and returned to the caller.
        per_pixel_shape = means.shape[:-2] + (viewmats.size(-3), height, width)
        sample_counts = (
            torch.empty(per_pixel_shape, dtype=torch.int32, device=means.device)
            if return_sample_counts
            else None
        )
        render_normals = (
            torch.empty(
                per_pixel_shape + (3,), dtype=torch.float32, device=means.device
            )
            if return_normals
            else None
        )

        rasterize_fwd = _make_lazy_cuda_func("rasterize_to_pixels_from_world_3dgs_fwd")
        render_colors, render_alphas, last_ids = rasterize_fwd(
            means,
            quats,
            scales,
            colors,
            opacities,
            backgrounds,
            masks,
            width,
            height,
            tile_size,
            viewmats,
            viewmats_rs,
            Ks,
            camera_model_type,
            ut_params,
            rolling_shutter,
            rays,
            radial_coeffs,
            tangential_coeffs,
            thin_prism_coeffs,
            ftheta_coeffs,
            lidar_coeffs,
            external_distortion_coeffs,
            isect_offsets,
            flatten_ids,
            use_hit_distance,
            sample_counts,
            render_normals,
        )

        # Tensors go through save_for_backward; everything else is stored as
        # plain attributes on the context.
        ctx.save_for_backward(
            means,
            quats,
            scales,
            colors,
            opacities,
            backgrounds,
            masks,
            viewmats,
            viewmats_rs,
            Ks,
            rays,
            radial_coeffs,
            tangential_coeffs,
            thin_prism_coeffs,
            isect_offsets,
            flatten_ids,
            render_alphas,
            last_ids,
        )
        ctx.width = width
        ctx.height = height
        ctx.tile_size = tile_size
        ctx.camera_model_type = camera_model_type
        ctx.ut_params = ut_params
        ctx.rs_type = rolling_shutter
        ctx.ftheta_coeffs = ftheta_coeffs
        ctx.lidar_coeffs = lidar_coeffs
        ctx.external_distortion_coeffs = external_distortion_coeffs
        ctx.use_hit_distance = use_hit_distance

        return render_colors, render_alphas, last_ids, sample_counts, render_normals

    @staticmethod
    def backward(
        ctx,
        v_render_colors: Tensor,  # [..., C, H, W, 3]
        v_render_alphas: Tensor,  # [..., C, H, W, 1]
        v_last_ids: Optional[Tensor],  # None - last_ids is integer (non-differentiable)
        v_sample_counts: Optional[
            Tensor
        ],  # None - sample_counts is integer (non-differentiable)
        v_render_normals: Optional[Tensor],  # [..., C, H, W, 3]
    ):
        (
            means,
            quats,
            scales,
            colors,
            opacities,
            backgrounds,
            masks,
            viewmats,
            viewmats_rs,
            Ks,
            rays,
            radial_coeffs,
            tangential_coeffs,
            thin_prism_coeffs,
            isect_offsets,
            flatten_ids,
            render_alphas,
            last_ids,
        ) = ctx.saved_tensors

        normals_grad = (
            None if v_render_normals is None else v_render_normals.contiguous()
        )

        rasterize_bwd = _make_lazy_cuda_func("rasterize_to_pixels_from_world_3dgs_bwd")
        (
            v_means,
            v_quats,
            v_scales,
            v_colors,
            v_opacities,
            v_rays,
        ) = rasterize_bwd(
            means,
            quats,
            scales,
            colors,
            opacities,
            backgrounds,
            masks,
            ctx.width,
            ctx.height,
            ctx.tile_size,
            viewmats,
            viewmats_rs,
            Ks,
            ctx.camera_model_type,
            ctx.ut_params,
            ctx.rs_type,
            rays,
            radial_coeffs,
            tangential_coeffs,
            thin_prism_coeffs,
            ctx.ftheta_coeffs,
            ctx.lidar_coeffs,  # already converted to C++ in forward
            ctx.external_distortion_coeffs,
            isect_offsets,
            flatten_ids,
            ctx.use_hit_distance,
            render_alphas,
            last_ids,
            v_render_colors.contiguous(),
            v_render_alphas.contiguous(),
            normals_grad,
        )

        v_backgrounds = None
        if ctx.needs_input_grad[5]:  # backgrounds
            # Weight the color gradient by (1 - alpha) and sum over the two
            # image axes to get the background gradient.
            background_weight = (1.0 - render_alphas).float()
            v_backgrounds = (v_render_colors * background_weight).sum(dim=(-3, -2))

        # No gradient is returned for viewmats; v_rays is returned directly.
        # One entry per forward input (28).
        return (
            v_means,
            v_quats,
            v_scales,
            v_colors,
            v_opacities,
            v_backgrounds,
            None,  # masks
            None,  # viewmats
            None,  # Ks
            None,  # width
            None,  # height
            None,  # tile_size
            None,  # isect_offsets
            None,  # flatten_ids
            None,  # camera_model
            None,  # ut_params
            v_rays,  # rays
            None,  # radial_coeffs
            None,  # tangential_coeffs
            None,  # thin_prism_coeffs
            None,  # ftheta_coeffs
            None,  # lidar_coeffs
            None,  # external_distortion_coeffs
            None,  # rolling_shutter
            None,  # viewmats_rs
            None,  # return_sample_counts (flag, no gradient)
            None,  # use_hit_distance
            None,  # return_normals (flag, no gradient)
        )
class _FullyFusedProjectionPacked(torch.autograd.Function):
    """Projects Gaussians to 2D. Return packed tensors.

    The forward pass calls the ``projection_ewa_3dgs_packed_fwd`` operator and
    the backward pass calls ``projection_ewa_3dgs_packed_bwd``. Gradients are
    produced for ``means``, ``covars``, ``quats``, ``scales`` and ``viewmats``
    only; every other forward input receives ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        means: Tensor,  # [..., N, 3]
        covars: Tensor,  # [..., N, 6] or None
        quats: Tensor,  # [..., N, 4] or None
        scales: Tensor,  # [..., N, 3] or None
        viewmats: Tensor,  # [..., C, 4, 4]
        Ks: Tensor,  # [..., C, 3, 3]
        width: int,
        height: int,
        eps2d: float,
        near_plane: float,
        far_plane: float,
        radius_clip: float,
        sparse_grad: bool,
        calc_compensations: bool,
        camera_model: CameraModel = "pinhole",
        opacities: Optional[Tensor] = None,  # [..., N] or None
    ) -> Tuple[
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Optional[Tensor],
    ]:
        """Run the packed projection operator and stash state for backward.

        Returns:
            ``(batch_ids, camera_ids, gaussian_ids, indptr, radii, means2d,
            depths, conics, compensations)`` exactly as produced by the
            operator, except that ``compensations`` is replaced by ``None``
            when ``calc_compensations`` is False.

        Raises:
            AssertionError: If ``camera_model`` is ``"ftheta"``.
        """
        assert (
            camera_model != "ftheta"
        ), "ftheta camera is only supported via UT, please set with_ut=True in the rasterization()"

        camera_model_type = _make_lazy_cuda_obj(
            f"CameraModelType.{camera_model.upper()}"
        )

        (
            indptr,
            batch_ids,
            camera_ids,
            gaussian_ids,
            radii,
            means2d,
            depths,
            conics,
            compensations,
        ) = _make_lazy_cuda_func("projection_ewa_3dgs_packed_fwd")(
            means,
            covars,  # optional
            quats,  # optional
            scales,  # optional
            opacities,  # optional
            viewmats,
            Ks,
            width,
            height,
            eps2d,
            near_plane,
            far_plane,
            radius_clip,
            calc_compensations,
            camera_model_type,
        )
        if not calc_compensations:
            compensations = None

        ctx.save_for_backward(
            batch_ids,
            camera_ids,
            gaussian_ids,
            means,
            covars,
            quats,
            scales,
            viewmats,
            Ks,
            conics,
            compensations,
        )
        ctx.width = width
        ctx.height = height
        ctx.eps2d = eps2d
        ctx.sparse_grad = sparse_grad
        ctx.camera_model_type = camera_model_type

        return (
            batch_ids,
            camera_ids,
            gaussian_ids,
            indptr,
            radii,
            means2d,
            depths,
            conics,
            compensations,
        )

    @staticmethod
    def backward(
        ctx,
        v_batch_ids,
        v_camera_ids,
        v_gaussian_ids,
        v_indptr,
        v_radii,
        v_means2d,
        v_depths,
        v_conics,
        v_compensations,
    ):
        """Back-propagate through the packed projection.

        Only ``v_means2d``, ``v_depths``, ``v_conics`` and ``v_compensations``
        are consumed; the incoming gradients of the id/indptr/radii outputs are
        ignored. Returns one entry per forward input (16 in total).
        """
        (
            batch_ids,
            camera_ids,
            gaussian_ids,
            means,
            covars,
            quats,
            scales,
            viewmats,
            Ks,
            conics,
            compensations,
        ) = ctx.saved_tensors
        width = ctx.width
        height = ctx.height
        eps2d = ctx.eps2d
        sparse_grad = ctx.sparse_grad
        camera_model_type = ctx.camera_model_type

        if v_compensations is not None:
            v_compensations = v_compensations.contiguous()

        v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
            "projection_ewa_3dgs_packed_bwd"
        )(
            means,
            covars,
            quats,
            scales,
            viewmats,
            Ks,
            width,
            height,
            eps2d,
            camera_model_type,
            batch_ids,
            camera_ids,
            gaussian_ids,
            conics,
            compensations,
            v_means2d.contiguous(),
            v_depths.contiguous(),
            v_conics.contiguous(),
            v_compensations,
            ctx.needs_input_grad[4],  # viewmats_requires_grad
            sparse_grad,
        )

        def _as_input_grad(grad: Tensor, like: Tensor) -> Tensor:
            # Wrap the per-intersection gradient values into a sparse COO tensor
            # shaped like the corresponding input when sparse_grad is enabled.
            if not sparse_grad:
                return grad
            # TODO: gaussian_ids is duplicated so not ideal.
            # An idea is to directly set the attribute (e.g., .sparse_grad) of
            # the tensor but this requires the tensor to be leaf node only. And
            # a customized optimizer would be needed in this case.
            # NOTE(review): len(viewmats) is the size of the leading dimension
            # of viewmats, which is the camera count only when there are no
            # batch dimensions -- confirm for batched inputs.
            return torch.sparse_coo_tensor(
                indices=gaussian_ids[None],
                values=grad,  # [nnz, D]
                size=like.shape,
                is_coalesced=len(viewmats) == 1,
            )

        v_means = _as_input_grad(v_means, means) if ctx.needs_input_grad[0] else None
        v_covars = (
            _as_input_grad(v_covars, covars) if ctx.needs_input_grad[1] else None
        )
        v_quats = _as_input_grad(v_quats, quats) if ctx.needs_input_grad[2] else None
        v_scales = (
            _as_input_grad(v_scales, scales) if ctx.needs_input_grad[3] else None
        )
        if not ctx.needs_input_grad[4]:
            v_viewmats = None

        # One slot per forward() argument, in forward() order.
        return (
            v_means,
            v_covars,
            v_quats,
            v_scales,
            v_viewmats,
            None,  # Ks
            None,  # width
            None,  # height
            None,  # eps2d
            None,  # near_plane
            None,  # far_plane
            None,  # radius_clip
            None,  # sparse_grad
            None,  # calc_compensations
            None,  # camera_model
            None,  # opacities
        )
class _SphericalHarmonics(torch.autograd.Function):
    """Autograd wrapper around the spherical-harmonics CUDA operators.

    Forward evaluates ``spherical_harmonics_fwd``; backward evaluates
    ``spherical_harmonics_bwd`` and yields gradients for ``coeffs`` and, when
    requested, for ``dirs``.
    """

    @staticmethod
    def forward(
        ctx, sh_degree: int, dirs: Tensor, coeffs: Tensor, masks: Tensor
    ) -> Tensor:
        """Evaluate the operator and remember the inputs for backward."""
        sh_fwd = _make_lazy_cuda_func("spherical_harmonics_fwd")
        result = sh_fwd(sh_degree, dirs, coeffs, masks)
        ctx.save_for_backward(dirs, coeffs, masks)
        ctx.sh_degree = sh_degree
        return result

    @staticmethod
    def backward(ctx, v_colors: Tensor):
        """Return gradients for ``(sh_degree, dirs, coeffs, masks)``."""
        dirs, coeffs, masks = ctx.saved_tensors
        # The operator is told whether a direction gradient is needed at all.
        wants_dirs_grad = ctx.needs_input_grad[1]
        sh_bwd = _make_lazy_cuda_func("spherical_harmonics_bwd")
        grad_coeffs, grad_dirs = sh_bwd(
            ctx.sh_degree,
            dirs,
            coeffs,
            masks,
            v_colors.contiguous(),
            wants_dirs_grad,
        )
        # sh_degree and masks never receive a gradient.
        return (
            None,
            grad_dirs if wants_dirs_grad else None,
            grad_coeffs,
            None,
        )
###### 2DGS ######
[docs]
def fully_fused_projection_2dgs(
    means: Tensor,  # [..., N, 3]
    quats: Tensor,  # [..., N, 4]
    scales: Tensor,  # [..., N, 3]
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
    eps2d: float = 0.3,
    near_plane: float = 0.01,
    far_plane: float = 1e10,
    radius_clip: float = 0.0,
    packed: bool = False,
    sparse_grad: bool = False,
) -> Tuple[Tensor, ...]:
    """Prepare Gaussians for rasterization

    This function prepares ray-splat intersection matrices, computes
    per splat bounding box and 2D means in image space.

    Args:
        means: Gaussian means. [..., N, 3]
        quats: Quaternions (No need to be normalized). [..., N, 4].
        scales: Scales. [..., N, 3].
        viewmats: World-to-camera matrices. [..., C, 4, 4]
        Ks: Camera intrinsics. [..., C, 3, 3]
        width: Image width.
        height: Image height.
        eps2d: Forwarded to the projection operator when `packed` is False.
            It is not forwarded when `packed` is True. Default: 0.3.
        near_plane: Near plane distance. Default: 0.01.
        far_plane: Far plane distance. Default: 1e10.
        radius_clip: Gaussians with projected radii smaller than this value will be ignored. Default: 0.0.
        packed: If True, the output tensors will be packed into a flattened tensor. Default: False.
        sparse_grad (Experimental): This is only effective when `packed` is True. If True, during backward the gradients
            of {`means`, `quats`, `scales`} will be a sparse Tensor in COO layout. Default: False.

    Returns:
        A tuple:

        If `packed` is True:

        - **batch_ids**. The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
        - **camera_ids**. The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
        - **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
        - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [nnz, 2].
        - **means**. Projected Gaussian means in 2D. [nnz, 2]
        - **depths**. The z-depth of the projected Gaussians. [nnz]
        - **ray_transforms**. transformation matrices that transforms xy-planes in pixel spaces into splat coordinates (WH)^T in equation (9) in paper [nnz, 3, 3]
        - **normals**. The normals in camera spaces. [nnz, 3]

        If `packed` is False:

        - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [..., C, N, 2].
        - **means**. Projected Gaussian means in 2D. [..., C, N, 2]
        - **depths**. The z-depth of the projected Gaussians. [..., C, N]
        - **ray_transforms**. transformation matrices that transforms xy-planes in pixel spaces into splat coordinates [..., C, N, 3, 3]
        - **normals**. The normals in camera spaces. [..., C, N, 3]

    Raises:
        AssertionError: If an input shape is inconsistent, if `quats` or
            `scales` is None, or if `sparse_grad` is set without `packed`.
    """
    # Leading batch dims and the Gaussian/camera counts are read off the inputs.
    batch_dims = means.shape[:-2]
    N = means.shape[-2]
    C = viewmats.shape[-3]
    assert means.shape == batch_dims + (N, 3), means.shape
    assert viewmats.shape == batch_dims + (C, 4, 4), viewmats.shape
    assert Ks.shape == batch_dims + (C, 3, 3), Ks.shape
    means = means.contiguous()

    assert quats is not None, "quats is required"
    assert scales is not None, "scales is required"
    assert quats.shape == batch_dims + (N, 4), quats.shape
    assert scales.shape == batch_dims + (N, 3), scales.shape
    quats = quats.contiguous()
    scales = scales.contiguous()

    if sparse_grad:
        assert packed, "sparse_grad is only supported when packed is True"

    viewmats = viewmats.contiguous()
    Ks = Ks.contiguous()

    if packed:
        # The packed autograd function has no eps2d parameter.
        return _FullyFusedProjectionPacked2DGS.apply(
            means,
            quats,
            scales,
            viewmats,
            Ks,
            width,
            height,
            near_plane,
            far_plane,
            radius_clip,
            sparse_grad,
        )

    return _FullyFusedProjection2DGS.apply(
        means,
        quats,
        scales,
        viewmats,
        Ks,
        width,
        height,
        eps2d,
        near_plane,
        far_plane,
        radius_clip,
    )
class _FullyFusedProjection2DGS(torch.autograd.Function):
    """Projects Gaussians to 2D.

    Forward calls the ``projection_2dgs_fused_fwd`` operator; backward calls
    ``projection_2dgs_fused_bwd`` and yields gradients for ``means``, ``quats``,
    ``scales`` and ``viewmats``.
    """

    @staticmethod
    def forward(
        ctx,
        means: Tensor,  # [..., N, 3]
        quats: Tensor,  # [..., N, 4]
        scales: Tensor,  # [..., N, 3]
        viewmats: Tensor,  # [..., C, 4, 4]
        Ks: Tensor,  # [..., C, 3, 3]
        width: int,
        height: int,
        eps2d: float,
        near_plane: float,
        far_plane: float,
        radius_clip: float,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Run the fused projection operator.

        Returns:
            ``(radii, means2d, depths, ray_transforms, normals)`` as produced
            by the operator.
        """
        radii, means2d, depths, ray_transforms, normals = _make_lazy_cuda_func(
            "projection_2dgs_fused_fwd"
        )(
            means,
            quats,
            scales,
            viewmats,
            Ks,
            width,
            height,
            eps2d,
            near_plane,
            far_plane,
            radius_clip,
        )
        # Only what the backward operator actually consumes is kept alive;
        # neither `normals` nor `eps2d` is an argument of the backward call.
        ctx.save_for_backward(
            means,
            quats,
            scales,
            viewmats,
            Ks,
            radii,
            ray_transforms,
        )
        ctx.width = width
        ctx.height = height

        return radii, means2d, depths, ray_transforms, normals

    @staticmethod
    def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals):
        """Back-propagate through the fused projection.

        ``v_radii`` is ignored. Returns one entry per forward input (11 in
        total); only the first four can be non-``None``.
        """
        (
            means,
            quats,
            scales,
            viewmats,
            Ks,
            radii,
            ray_transforms,
        ) = ctx.saved_tensors
        width = ctx.width
        height = ctx.height

        # NOTE: this operator takes v_normals before v_ray_transforms.
        v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
            "projection_2dgs_fused_bwd"
        )(
            means,
            quats,
            scales,
            viewmats,
            Ks,
            width,
            height,
            radii,
            ray_transforms,
            v_means2d.contiguous(),
            v_depths.contiguous(),
            v_normals.contiguous(),
            v_ray_transforms.contiguous(),
            ctx.needs_input_grad[3],  # viewmats_requires_grad
        )
        if not ctx.needs_input_grad[0]:
            v_means = None
        if not ctx.needs_input_grad[1]:
            v_quats = None
        if not ctx.needs_input_grad[2]:
            v_scales = None
        if not ctx.needs_input_grad[3]:
            v_viewmats = None

        # One slot per forward() argument, in forward() order.
        return (
            v_means,
            v_quats,
            v_scales,
            v_viewmats,
            None,  # Ks
            None,  # width
            None,  # height
            None,  # eps2d
            None,  # near_plane
            None,  # far_plane
            None,  # radius_clip
        )
class _FullyFusedProjectionPacked2DGS(torch.autograd.Function):
    """Projects Gaussians to 2D. Return packed tensors.

    Forward calls the ``projection_2dgs_packed_fwd`` operator; backward calls
    ``projection_2dgs_packed_bwd`` and yields gradients for ``means``,
    ``quats``, ``scales`` and ``viewmats``. With ``sparse_grad`` the first three
    are returned as sparse COO tensors.
    """

    @staticmethod
    def forward(
        ctx,
        means: Tensor,  # [..., N, 3]
        quats: Tensor,  # [..., N, 4]
        scales: Tensor,  # [..., N, 3]
        viewmats: Tensor,  # [..., C, 4, 4]
        Ks: Tensor,  # [..., C, 3, 3]
        width: int,
        height: int,
        near_plane: float,
        far_plane: float,
        radius_clip: float,
        sparse_grad: bool,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Run the packed projection operator.

        Returns:
            ``(batch_ids, camera_ids, gaussian_ids, radii, means2d, depths,
            ray_transforms, normals)``. The ``indptr`` tensor produced by the
            operator is not returned.
        """
        (
            indptr,
            batch_ids,
            camera_ids,
            gaussian_ids,
            radii,
            means2d,
            depths,
            ray_transforms,
            normals,
        ) = _make_lazy_cuda_func("projection_2dgs_packed_fwd")(
            means,
            quats,
            scales,
            viewmats,
            Ks,
            width,
            height,
            near_plane,
            far_plane,
            radius_clip,
        )
        ctx.save_for_backward(
            batch_ids,
            camera_ids,
            gaussian_ids,
            means,
            quats,
            scales,
            viewmats,
            Ks,
            ray_transforms,
        )
        ctx.width = width
        ctx.height = height
        ctx.sparse_grad = sparse_grad

        return (
            batch_ids,
            camera_ids,
            gaussian_ids,
            radii,
            means2d,
            depths,
            ray_transforms,
            normals,
        )

    @staticmethod
    def backward(
        ctx,
        v_batch_ids,
        v_camera_ids,
        v_gaussian_ids,
        v_radii,
        v_means2d,
        v_depths,
        v_ray_transforms,
        v_normals,
    ):
        """Back-propagate through the packed projection.

        The incoming gradients of the id and radii outputs are ignored.
        Returns one entry per forward input (11 in total); only the first four
        can be non-``None``.
        """
        (
            batch_ids,
            camera_ids,
            gaussian_ids,
            means,
            quats,
            scales,
            viewmats,
            Ks,
            ray_transforms,
        ) = ctx.saved_tensors
        width = ctx.width
        height = ctx.height
        sparse_grad = ctx.sparse_grad

        # NOTE: this operator takes v_ray_transforms before v_normals.
        v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
            "projection_2dgs_packed_bwd"
        )(
            means,
            quats,
            scales,
            viewmats,
            Ks,
            width,
            height,
            batch_ids,
            camera_ids,
            gaussian_ids,
            ray_transforms,
            v_means2d.contiguous(),
            v_depths.contiguous(),
            v_ray_transforms.contiguous(),
            v_normals.contiguous(),
            ctx.needs_input_grad[3],  # viewmats_requires_grad
            sparse_grad,
        )

        def _as_input_grad(grad: Tensor, like: Tensor) -> Tensor:
            # Wrap the per-intersection gradient values into a sparse COO tensor
            # shaped like the corresponding input when sparse_grad is enabled.
            if not sparse_grad:
                return grad
            # TODO: gaussian_ids is duplicated so not ideal.
            # An idea is to directly set the attribute (e.g., .sparse_grad) of
            # the tensor but this requires the tensor to be leaf node only. And
            # a customized optimizer would be needed in this case.
            # NOTE(review): len(viewmats) is the size of the leading dimension
            # of viewmats, which is the camera count only when there are no
            # batch dimensions -- confirm for batched inputs.
            return torch.sparse_coo_tensor(
                indices=gaussian_ids[None],
                values=grad,  # [nnz, D]
                size=like.shape,
                is_coalesced=len(viewmats) == 1,
            )

        v_means = _as_input_grad(v_means, means) if ctx.needs_input_grad[0] else None
        v_quats = _as_input_grad(v_quats, quats) if ctx.needs_input_grad[1] else None
        v_scales = (
            _as_input_grad(v_scales, scales) if ctx.needs_input_grad[2] else None
        )
        if not ctx.needs_input_grad[3]:
            v_viewmats = None

        # One slot per forward() argument, in forward() order.
        return (
            v_means,
            v_quats,
            v_scales,
            v_viewmats,
            None,  # Ks
            None,  # width
            None,  # height
            None,  # near_plane
            None,  # far_plane
            None,  # radius_clip
            None,  # sparse_grad
        )
[docs]
def rasterize_to_pixels_2dgs(
    means2d: Tensor,  # [..., N, 2]
    ray_transforms: Tensor,  # [..., N, 3, 3]
    colors: Tensor,  # [..., N, channels]
    opacities: Tensor,  # [..., N]
    normals: Tensor,  # [..., N, 3]
    densify: Tensor,  # [..., N, 2]
    image_width: int,
    image_height: int,
    tile_size: int,
    isect_offsets: Tensor,  # [..., tile_height, tile_width]
    flatten_ids: Tensor,  # [n_isects]
    backgrounds: Optional[Tensor] = None,  # [..., channels]
    masks: Optional[Tensor] = None,  # [..., tile_height, tile_width]
    packed: bool = False,
    absgrad: bool = False,
    distloss: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Rasterize Gaussians to pixels.

    Args:
        means2d: Projected Gaussian means. [..., N, 2] if packed is False, [nnz, 2] if packed is True.
        ray_transforms: transformation matrices that transforms xy-planes in pixel spaces into splat coordinates. [..., N, 3, 3] if packed is False, [nnz, 3, 3] if packed is True.
        colors: Gaussian colors or ND features. [..., N, channels] if packed is False, [nnz, channels] if packed is True.
        opacities: Gaussian opacities that support per-view values. [..., N] if packed is False, [nnz] if packed is True.
        normals: The normals in camera space. [..., N, 3] if packed is False, [nnz, 3] if packed is True.
        densify: Dummy variable to keep track of gradient for densification. [..., N, 2] if packed is False; presumably [nnz, 2] if packed is True (its shape is not validated here).
        image_width: Image width.
        image_height: Image height.
        tile_size: Tile size.
        isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [..., tile_height, tile_width]
        flatten_ids: The global flatten indices in [I * N] or [nnz] from `isect_tiles()`. [n_isects]
        backgrounds: Background colors. [..., channels]. Default: None.
        masks: Optional tile mask to skip rendering GS to masked tiles. [..., tile_height, tile_width]. Default: None.
        packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False.
        absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False.
        distloss: Flag passed through to the autograd function. Default: False.

    Returns:
        A tuple:

        - **Rendered colors**. [..., image_height, image_width, channels]
        - **Rendered alphas**. [..., image_height, image_width, 1]
        - **Rendered normals**. [..., image_height, image_width, 3]
        - **Rendered distortion**. [..., image_height, image_width, 1]
        - **Rendered median depth**. [..., image_height, image_width, 1]

    Raises:
        ValueError: If the number of color channels is 0 or larger than 512.
        AssertionError: If an input shape is inconsistent or the tile grid does
            not cover the image.
    """
    # NOTE(review): in packed mode means2d is [nnz, 2], so image_dims is () and
    # the backgrounds check below expects shape (channels,) -- confirm intent.
    image_dims = means2d.shape[:-2]
    channels = colors.shape[-1]
    device = means2d.device
    if packed:
        nnz = means2d.size(0)
        assert means2d.shape == (nnz, 2), means2d.shape
        assert ray_transforms.shape == (nnz, 3, 3), ray_transforms.shape
        assert colors.shape[0] == nnz, colors.shape
        assert opacities.shape == (nnz,), opacities.shape
    else:
        N = means2d.size(-2)
        assert means2d.shape == image_dims + (N, 2), means2d.shape
        assert ray_transforms.shape == image_dims + (N, 3, 3), ray_transforms.shape
        assert colors.shape[:-2] == image_dims, colors.shape
        assert opacities.shape == image_dims + (N,), opacities.shape
    if backgrounds is not None:
        assert backgrounds.shape == image_dims + (channels,), backgrounds.shape
        backgrounds = backgrounds.contiguous()

    # Pad the channels to the nearest supported number if necessary
    if channels > 512 or channels == 0:
        # TODO: maybe worth to support zero channels?
        raise ValueError(f"Unsupported number of color channels: {channels}")
    if channels not in (1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512):
        padded_channels = (1 << (channels - 1).bit_length()) - channels
        # Make sure the depth (last channel if present) remains in the last channel after padding (for depth distortion and median depth in CUDA kernel)
        # The padding is zero-filled: uninitialized memory may hold NaN/Inf
        # values that would be handed to the kernel.
        colors = torch.cat(
            [
                colors[..., :-1],
                torch.zeros(*colors.shape[:-1], padded_channels, device=device),
                colors[..., -1:],
            ],
            dim=-1,
        )
        if backgrounds is not None:
            # Pad with the same layout as `colors` so that every background
            # channel stays aligned with its color channel, the last included.
            backgrounds = torch.cat(
                [
                    backgrounds[..., :-1],
                    torch.zeros(
                        *backgrounds.shape[:-1], padded_channels, device=device
                    ),
                    backgrounds[..., -1:],
                ],
                dim=-1,
            )
    else:
        padded_channels = 0

    tile_height, tile_width = isect_offsets.shape[-2:]
    assert (
        tile_height * tile_size >= image_height
    ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
    assert (
        tile_width * tile_size >= image_width
    ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"

    (
        render_colors,
        render_alphas,
        render_normals,
        render_distort,
        render_median,
    ) = _RasterizeToPixels2DGS.apply(
        means2d.contiguous(),
        ray_transforms.contiguous(),
        colors.contiguous(),
        opacities.contiguous(),
        normals.contiguous(),
        densify.contiguous(),
        backgrounds,
        masks,
        image_width,
        image_height,
        tile_size,
        isect_offsets.contiguous(),
        flatten_ids.contiguous(),
        absgrad,
        distloss,
    )

    if padded_channels > 0:
        # Drop the padding that was inserted just before the last channel.
        render_colors = torch.cat(
            [render_colors[..., : -padded_channels - 1], render_colors[..., -1:]],
            dim=-1,
        )

    return render_colors, render_alphas, render_normals, render_distort, render_median
[docs]
@torch.no_grad()
def rasterize_to_indices_in_range_2dgs(
    range_start: int,
    range_end: int,
    transmittances: Tensor,  # [..., image_height, image_width]
    means2d: Tensor,  # [..., N, 2]
    ray_transforms: Tensor,  # [..., N, 3, 3]
    opacities: Tensor,  # [..., N]
    image_width: int,
    image_height: int,
    tile_size: int,
    isect_offsets: Tensor,
    flatten_ids: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Rasterizes a batch of Gaussians to images but only returns the indices.

    .. note::

        This function supports iterative rasterization, in which each call of this function
        will rasterize a batch of Gaussians from near to far, defined by `[range_start, range_end)`.
        If a one-step full rasterization is desired, set `range_start` to 0 and `range_end` to a really
        large number, e.g, 1e10.

    Args:
        range_start: The start batch of Gaussians to be rasterized (inclusive).
        range_end: The end batch of Gaussians to be rasterized (exclusive).
        transmittances: Currently transmittances. [..., image_height, image_width]
        means2d: Projected Gaussian means. [..., N, 2]
        ray_transforms: transformation matrices that transforms xy-planes in pixel spaces into splat coordinates. [..., N, 3, 3]
        opacities: Gaussian opacities that support per-view values. [..., N]
        image_width: Image width.
        image_height: Image height.
        tile_size: Tile size.
        isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [..., tile_height, tile_width]
        flatten_ids: The global flatten indices in [I * N] from `isect_tiles()`. [n_isects]

    Returns:
        A tuple of three tensors:

        - **Gaussian ids**. Gaussian ids for the pixel intersection. A flattened list of shape [M].
        - **Pixel ids**. pixel indices (row-major). A flattened list of shape [M].
        - **Image ids**. The operator's flat index divided by the number of
          pixels per image; presumably the image index flattened over the
          leading dimensions. A flattened list of shape [M].

    Raises:
        AssertionError: If an input shape is inconsistent or the tile grid does
            not cover the image.
    """
    image_dims = means2d.shape[:-2]
    tile_height, tile_width = isect_offsets.shape[-2:]
    N = means2d.shape[-2]
    assert transmittances.shape == image_dims + (
        image_height,
        image_width,
    ), transmittances.shape
    assert means2d.shape == image_dims + (N, 2), means2d.shape
    assert ray_transforms.shape == image_dims + (N, 3, 3), ray_transforms.shape
    assert opacities.shape == image_dims + (N,), opacities.shape
    assert isect_offsets.shape == image_dims + (
        tile_height,
        tile_width,
    ), isect_offsets.shape
    assert (
        tile_height * tile_size >= image_height
    ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
    assert (
        tile_width * tile_size >= image_width
    ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"

    out_gauss_ids, out_indices = _make_lazy_cuda_func("rasterize_to_indices_2dgs")(
        range_start,
        range_end,
        transmittances.contiguous(),
        means2d.contiguous(),
        ray_transforms.contiguous(),
        opacities.contiguous(),
        image_width,
        image_height,
        tile_size,
        isect_offsets.contiguous(),
        flatten_ids.contiguous(),
    )
    # Split the flat index into (image index, pixel index within the image).
    pixels_per_image = image_width * image_height
    out_pixel_ids = out_indices % pixels_per_image
    out_image_ids = out_indices // pixels_per_image
    return out_gauss_ids, out_pixel_ids, out_image_ids
class _RasterizeToPixels2DGS(torch.autograd.Function):
    """Rasterize gaussians 2DGS

    Forward calls the ``rasterize_to_pixels_2dgs_fwd`` operator; backward calls
    ``rasterize_to_pixels_2dgs_bwd``. Gradients are returned for ``means2d``,
    ``ray_transforms``, ``colors``, ``opacities``, ``normals``, ``densify`` and
    (when required) ``backgrounds``.
    """

    @staticmethod
    def forward(
        ctx,
        means2d: Tensor,
        ray_transforms: Tensor,
        colors: Tensor,
        opacities: Tensor,
        normals: Tensor,
        densify: Tensor,
        backgrounds: Optional[Tensor],
        masks: Optional[Tensor],
        width: int,
        height: int,
        tile_size: int,
        isect_offsets: Tensor,
        flatten_ids: Tensor,
        absgrad: bool,
        distloss: bool,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Run the forward rasterization operator.

        Returns:
            ``(render_colors, render_alphas, render_normals, render_distort,
            render_median)``; ``render_alphas`` is converted with ``.float()``
            before being returned.
        """
        # `densify` and `distloss` are not arguments of the forward operator;
        # `densify` is only saved below and handed to the backward operator.
        (
            render_colors,
            render_alphas,
            render_normals,
            render_distort,
            render_median,
            last_ids,
            median_ids,
        ) = _make_lazy_cuda_func("rasterize_to_pixels_2dgs_fwd")(
            means2d,
            ray_transforms,
            colors,
            opacities,
            normals,
            backgrounds,
            masks,
            width,
            height,
            tile_size,
            isect_offsets,
            flatten_ids,
        )

        # The alphas are saved as returned by the operator, i.e. before the
        # float conversion below.
        ctx.save_for_backward(
            means2d,
            ray_transforms,
            colors,
            opacities,
            normals,
            densify,
            backgrounds,
            masks,
            isect_offsets,
            flatten_ids,
            render_colors,
            render_alphas,
            last_ids,
            median_ids,
        )
        ctx.width = width
        ctx.height = height
        ctx.tile_size = tile_size
        ctx.absgrad = absgrad
        # NOTE(review): distloss is stored but backward() never reads it.
        ctx.distloss = distloss

        # double to float
        render_alphas = render_alphas.float()
        return (
            render_colors,
            render_alphas,
            render_normals,
            render_distort,
            render_median,
        )

    @staticmethod
    def backward(
        ctx,
        v_render_colors: Tensor,
        v_render_alphas: Tensor,
        v_render_normals: Tensor,
        v_render_distort: Tensor,
        v_render_median: Tensor,
    ):
        """Back-propagate through the rasterization.

        Returns one entry per forward input (15 in total). When ``absgrad`` was
        set in forward, the first operator output is additionally attached to
        the saved ``means2d`` tensor as ``means2d.absgrad``.
        """
        (
            means2d,
            ray_transforms,
            colors,
            opacities,
            normals,
            densify,
            backgrounds,
            masks,
            isect_offsets,
            flatten_ids,
            render_colors,
            render_alphas,
            last_ids,
            median_ids,
        ) = ctx.saved_tensors
        width = ctx.width
        height = ctx.height
        tile_size = ctx.tile_size
        absgrad = ctx.absgrad

        (
            v_means2d_abs,
            v_means2d,
            v_ray_transforms,
            v_colors,
            v_opacities,
            v_normals,
            v_densify,
        ) = _make_lazy_cuda_func("rasterize_to_pixels_2dgs_bwd")(
            means2d,
            ray_transforms,
            colors,
            opacities,
            normals,
            densify,
            backgrounds,
            masks,
            width,
            height,
            tile_size,
            isect_offsets,
            flatten_ids,
            render_colors,
            render_alphas,
            last_ids,
            median_ids,
            v_render_colors.contiguous(),
            v_render_alphas.contiguous(),
            v_render_normals.contiguous(),
            v_render_distort.contiguous(),
            v_render_median.contiguous(),
            absgrad,
        )
        # NOTE(review): this blocks the host until all queued CUDA work has
        # finished; the reason it is needed is not visible here -- confirm.
        torch.cuda.synchronize()
        if absgrad:
            means2d.absgrad = v_means2d_abs

        # Index 6 of needs_input_grad corresponds to `backgrounds`.
        if ctx.needs_input_grad[6]:
            # Weight the color gradient by (1 - alpha) and reduce over the two
            # image dimensions (dims -3 and -2).
            v_backgrounds = (v_render_colors * (1.0 - render_alphas).float()).sum(
                dim=(-3, -2)
            )
        else:
            v_backgrounds = None

        return (
            v_means2d,
            v_ray_transforms,
            v_colors,
            v_opacities,
            v_normals,
            v_densify,
            v_backgrounds,
            None,  # masks
            None,  # width
            None,  # height
            None,  # tile_size
            None,  # isect_offsets
            None,  # flatten_ids
            None,  # absgrad
            None,  # distloss
        )