Source code for gsplat.rendering

# SPDX-FileCopyrightText: Copyright 2024-2025 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Dict, Optional, Tuple, cast

import torch
import torch.distributed
import torch.nn.functional as F
from torch import Tensor
from typing_extensions import Literal
from ._helper import assert_shape

from .cuda._wrapper import (
    RollingShutterType,
    CameraModel,
    FThetaCameraDistortionParameters,
    FThetaPolynomialType,
    RowOffsetStructuredSpinningLidarModelParametersExt,
    UnscentedTransformParameters,
    ExternalDistortionModelMeta,
    ExternalDistortionModelParameters,
    ExternalDistortionReferencePolynomial,
    BivariateWindshieldModelParameters,
    fully_fused_projection,
    fully_fused_projection_2dgs,
    fully_fused_projection_with_ut,
    isect_offset_encode,
    isect_tiles,
    isect_tiles_lidar,
    rasterize_to_pixels,
    rasterize_to_pixels_2dgs,
    rasterize_to_pixels_eval3d,
    rasterize_to_pixels_eval3d_extra,
    spherical_harmonics,
)
from .distributed import (
    all_gather_int32,
    all_gather_tensor_list,
    all_to_all_int32,
    all_to_all_tensor_list,
)
from .utils import depth_to_normal, get_projection_matrix

# Gaussian depth modes (D/ED): use projection depth (controlled by global_z_order)
# Hit distance modes (d/Ed): compute along-ray distance in rasterization
RenderMode = Literal["RGB", "d", "Ed", "D", "ED", "RGB-d", "RGB-Ed", "RGB+D", "RGB+ED"]

RasterizeMode = Literal["classic", "antialiased"]


# TODO: RenderMode should be an enum so that we can add these query methods to it.
# The problem is that it would break backward compatibility due to some of the symbols
# used, e.g. "RGB+D" or "RGB-d".
def render_mode_has_color(mode: RenderMode) -> bool:
    return mode in {"RGB", "RGB-d", "RGB-Ed", "RGB+D", "RGB+ED"}


def render_mode_has_hit_distance(mode: RenderMode) -> bool:
    return mode in {"d", "Ed", "RGB-d", "RGB-Ed"}


def render_mode_has_depth(mode: RenderMode) -> bool:
    return mode in {"D", "ED", "RGB+D", "RGB+ED"}


def render_mode_has_expected_depth(mode: RenderMode) -> bool:
    return mode in {"Ed", "ED", "RGB-Ed", "RGB+ED"}


def render_mode_has_depth_channel(mode: RenderMode) -> bool:
    return render_mode_has_depth(mode) or render_mode_has_hit_distance(mode)


def render_mode_has_only_depth_channel(mode: RenderMode) -> bool:
    return render_mode_has_depth_channel(mode) and not render_mode_has_color(mode)


def render_mode_has_only_color(mode: RenderMode) -> bool:
    return not render_mode_has_depth_channel(mode) and render_mode_has_color(mode)
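
# Illustrative sketch (not part of the library API): the predicates above decompose
# a render mode into orthogonal queries. For example, "RGB+ED" carries color plus an
# alpha-normalized (expected) depth channel, while "Ed" is a depth-only mode:
#
#   assert render_mode_has_color("RGB+ED")
#   assert render_mode_has_depth("RGB+ED") and not render_mode_has_hit_distance("RGB+ED")
#   assert render_mode_has_expected_depth("RGB+ED")
#   assert render_mode_has_only_depth_channel("Ed")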


def _compute_view_dirs_packed(
    means: Tensor,  # [..., N, 3]
    campos: Tensor,  # [..., C, 3]
    batch_ids: Tensor,  # [nnz]
    camera_ids: Tensor,  # [nnz]
    gaussian_ids: Tensor,  # [nnz]
    indptr: Tensor,  # [B*C+1]
    B: int,
    C: int,
) -> Tensor:
    """Compute view directions for packed Gaussian-camera pairs.

    This function computes the view directions (means - campos) for each
    Gaussian-camera pair in the packed format. It automatically selects between
    a simple vectorized approach and an optimized loop-based approach, based on
    the data size and on whether campos requires gradients.

    Args:
        means: The 3D centers of the Gaussians. [..., N, 3]
        campos: Camera positions in world coordinates [..., C, 3]
        batch_ids: The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
        camera_ids: The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
        gaussian_ids: The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
        indptr: CSR-style index pointer into gaussian_ids for batch-camera pairs. Int32 tensor of shape [B*C+1].
        B: Number of batches
        C: Number of cameras

    Returns:
        dirs: View directions [nnz, 3]
    """
    N = means.shape[-2]
    nnz = batch_ids.shape[0]
    device = means.device
    means_flat = means.view(B, N, 3)
    campos_flat = campos.view(B, C, 3)

    if B * C == 1:
        # Single batch-camera pair. No indexed lookup for campos is needed.
        dirs = means_flat[0, gaussian_ids] - campos_flat[0, 0]  # [nnz, 3]
    else:
        avg_means_per_camera = nnz / (B * C)
        split_batch_camera_ops = (
            avg_means_per_camera > 10000
            and campos_flat.is_cuda
            and campos_flat.requires_grad
        )

        if not split_batch_camera_ops:
            # Simple vectorized indexing for campos.
            dirs = (
                means_flat[batch_ids, gaussian_ids] - campos_flat[batch_ids, camera_ids]
            )  # [nnz, 3]
        else:
            # For large N with pose optimization: split into B*C separate operations
            # to avoid many-to-one indexing of campos in the backward pass. This speeds
            # up the backward pass and is more impactful when GPU occupancy is high.
            dirs = torch.empty((nnz, 3), dtype=means_flat.dtype, device=device)
            indptr_cpu = indptr.cpu()
            for b_idx in range(B):
                for c_idx in range(C):
                    bc_idx = b_idx * C + c_idx
                    start_idx = indptr_cpu[bc_idx].item()
                    end_idx = indptr_cpu[bc_idx + 1].item()
                    if start_idx == end_idx:
                        continue

                    # Get the gaussian indices for this batch-camera pair and compute dirs
                    gids = gaussian_ids[start_idx:end_idx]
                    dirs[start_idx:end_idx] = (
                        means_flat[b_idx, gids] - campos_flat[b_idx, c_idx]
                    )

    return dirs
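
# A minimal sketch of the packed CSR layout consumed above, using synthetic tensors
# rather than the output of any gsplat call: with B = 1 batch and C = 2 cameras,
# indptr[bc] .. indptr[bc + 1] brackets the slice of gaussian_ids belonging to
# batch-camera pair bc, so the per-pair loop and the global gather index the same
# elements.
#
#   _gaussian_ids = torch.tensor([0, 2, 1, 2, 3], dtype=torch.int32)
#   _camera_ids = torch.tensor([0, 0, 1, 1, 1], dtype=torch.int32)
#   _indptr = torch.tensor([0, 2, 5], dtype=torch.int32)  # [B * C + 1]
#   for bc in range(2):
#       s, e = _indptr[bc].item(), _indptr[bc + 1].item()
#       assert (_camera_ids[s:e] == bc).all()  # entries are grouped by pair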


def normalize_features_layout(
    features: Tensor,
    batch_dims: tuple,
    C: int,
    trailing_dims: tuple,
    batch_ids: Optional[Tensor] = None,
    camera_ids: Optional[Tensor] = None,
    feature_ids: Optional[Tensor] = None,
) -> Tensor:
    """Normalize per-view or per-gaussian feature tensor layout to (nnz, *trailing) or (*batch_dims, C, *trailing)."""
    B = math.prod(batch_dims)
    N = features.shape[-(len(trailing_dims) + 1)]

    # per-view features?
    if features.shape == batch_dims + (C, N) + trailing_dims:
        # packed?
        if feature_ids is not None:
            # [..., C, N, *trailing] -> [nnz, *trailing]
            return features.view(B, C, N, *trailing_dims)[
                batch_ids, camera_ids, feature_ids
            ]
        else:
            # already (..., C, N, *trailing)
            return features
    # per-gaussian features?
    else:
        assert features.shape == (*batch_dims, N, *trailing_dims)
        # packed?
        if feature_ids is not None:
            # [..., N, *trailing] -> [nnz, *trailing]
            return features.view(B, N, *trailing_dims)[batch_ids, feature_ids]
        else:
            # (..., N, *trailing) -> (..., C, N, *trailing)
            return torch.broadcast_to(
                features.unsqueeze(len(batch_dims)), batch_dims + (C, N, *trailing_dims)
            )
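
# Shape sketch for normalize_features_layout (synthetic values): in non-packed mode a
# per-Gaussian tensor [N, *trailing] is broadcast to the per-view layout
# [C, N, *trailing], while per-view input passes through unchanged.
#
#   _feats = torch.rand(100, 3)  # [N, D] with no batch dims
#   _out = normalize_features_layout(_feats, (), C=2, trailing_dims=(3,))
#   assert _out.shape == (2, 100, 3)  # [C, N, D]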


def viewmat_to_camera_position(viewmats: Tensor) -> Tensor:
    """Camera position in world from world-to-camera 4x4 matrix without full inverse.

    For V = [R | t; 0 1], inv(V) has translation -R^T t, so camera position is -R^T t.
    This avoids torch.inverse and does not fail on singular 4x4 (e.g. degenerate poses).
    """
    R = viewmats[..., :3, :3]
    t = viewmats[..., :3, 3]
    return -(R.mT @ t.unsqueeze(-1)).squeeze(-1)
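
# Sanity-check sketch for the identity above (assumes an invertible, rigid pose): for
# a world-to-camera matrix with orthonormal rotation, -R^T t matches the translation
# column of the full 4x4 inverse.
#
#   _R = torch.linalg.qr(torch.randn(3, 3))[0]  # orthonormal 3x3
#   _t = torch.randn(3)
#   _V = torch.eye(4)
#   _V[:3, :3], _V[:3, 3] = _R, _t
#   _campos = viewmat_to_camera_position(_V[None])
#   assert torch.allclose(_campos[0], torch.inverse(_V)[:3, 3], atol=1e-5)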


def compute_directions(
    batch_dims: tuple,
    means: Tensor,
    viewmats: Tensor,
    batch_ids: Optional[Tensor] = None,
    camera_ids: Optional[Tensor] = None,
    gaussian_ids: Optional[Tensor] = None,
    indptr: Optional[Tensor] = None,  # [B*C+1]
    *,
    viewmats_rs: Optional[Tensor] = None,
) -> Tensor:
    # Compute cameras' absolute positions (no 4x4 inverse; robust to singular viewmats)
    campos = viewmat_to_camera_position(viewmats)
    if viewmats_rs is not None:
        campos_rs = viewmat_to_camera_position(viewmats_rs)
        campos = 0.5 * (campos + campos_rs)

    # Compute the direction of each gaussian wrt. its camera
    if gaussian_ids is None:
        dirs = means[..., None, :, :] - campos[..., None, :]
    else:
        B = math.prod(batch_dims)
        C = campos.shape[-2]
        dirs = _compute_view_dirs_packed(
            means, campos, batch_ids, camera_ids, gaussian_ids, indptr, B, C
        )  # [nnz, 3]
    return F.normalize(dirs, p=2, dim=-1)


def rasterization(
    means: Tensor,  # [..., N, 3]
    quats: Tensor,  # [..., N, 4]
    scales: Tensor,  # [..., N, 3]
    opacities: Tensor,  # [..., N]
    colors: Optional[Tensor],  # [..., (C,) N, D] or [..., (C,) N, K, 3]; None for depth-only render_modes
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
    near_plane: float = 0.01,
    far_plane: float = 1e10,
    radius_clip: float = 0.0,
    eps2d: float = 0.3,
    sh_degree: Optional[int] = None,
    packed: bool = True,
    tile_size: int = 16,
    backgrounds: Optional[Tensor] = None,
    render_mode: RenderMode = "RGB",
    sparse_grad: bool = False,
    absgrad: bool = False,
    rasterize_mode: RasterizeMode = "classic",
    channel_chunk: int = 32,
    distributed: bool = False,
    camera_model: CameraModel = "pinhole",
    segmented: bool = False,
    covars: Optional[Tensor] = None,
    with_ut: bool = False,
    with_eval3d: bool = False,
    return_normals: bool = False,
    global_z_order: bool = True,
    rays: Optional[Tensor] = None,  # [..., C, H, W, 6] -> ox, oy, oz, dx*spread, dy*spread, dz*spread
    # distortion
    radial_coeffs: Optional[Tensor] = None,  # [..., C, 6] or [..., C, 4]
    tangential_coeffs: Optional[Tensor] = None,  # [..., C, 2]
    thin_prism_coeffs: Optional[Tensor] = None,  # [..., C, 4]
    ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
    lidar_coeffs: Optional[RowOffsetStructuredSpinningLidarModelParametersExt] = None,
    external_distortion_coeffs: Optional[ExternalDistortionModelParameters] = None,
    # rolling shutter
    rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
    viewmats_rs: Optional[Tensor] = None,  # [..., C, 4, 4]
    # unscented transform (for 3DGUT)
    ut_params: Optional[UnscentedTransformParameters] = None,
    # extra signal channels (order in output: RGB, depth, extra)
    extra_signals: Optional[Tensor] = None,  # [..., (C,) N, E] or [..., (C,) N, K, 3] when extra_signals_sh_degree set
    extra_signals_sh_degree: Optional[int] = None,  # Currently only None or 3 is accepted.
) -> Tuple[Tensor, Tensor, Dict]:
    """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C).

    This function provides a handful of features for 3D Gaussian rasterization, which
    we detail in the following notes. A complete profiling of these features can be
    found in the :ref:`profiling` page.

    .. note::
        **Multi-GPU Distributed Rasterization**: This function can be used in a
        multi-GPU distributed scenario by setting `distributed` to True. When
        `distributed` is True, a subset of the total Gaussians can be passed into this
        function on each rank, and the function will collaboratively render a set of
        images using Gaussians from all ranks. Note that to achieve balanced
        computation, it is recommended (though not enforced) to have a similar number
        of Gaussians on each rank. We do, however, enforce that the number of cameras
        to be rendered on each rank is the same. The function will return the rendered
        images corresponding to the input cameras on each rank, and allows gradients
        to flow back to the Gaussians living on other ranks. For the details, please
        refer to the paper `On Scaling Up 3D Gaussian Splatting Training
        <https://arxiv.org/abs/2406.18533>`_.

    .. note::
        **Batch Rasterization**: This function allows for rasterizing a set of 3D
        Gaussians to a batch of images in one go, by simply providing the batched
        `viewmats` and `Ks`.

    .. note::
        **Support N-D Features**: If `sh_degree` is None, `colors` is expected to have
        shape [..., N, D] or [..., C, N, D], in which D is the channel of the features
        to be rendered. The computation is slow when D > 32 at the moment.
        If `sh_degree` is set, `colors` is expected to be the SH coefficients with
        shape [..., N, K, 3] or [..., C, N, K, 3], where K is the number of SH bases.
        In this case, it is expected that :math:`(\\textit{sh_degree} + 1) ^ 2 \\leq K`,
        where `sh_degree` controls the activated bases in the SH coefficients.

    .. note::
        **Depth Rendering**: This function supports colors and/or depths via
        `render_mode`.

        **Gaussian Depth Modes** (use projection depth, controlled by `global_z_order`):

        - "D": Accumulated Gaussian depth :math:`\\sum_i w_i z_i`
        - "ED": Expected Gaussian depth :math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`
        - "RGB+D": RGB + accumulated Gaussian depth
        - "RGB+ED": RGB + expected Gaussian depth

        **Hit Distance Modes** (compute along-ray distance in rasterization):

        - "d": Accumulated hit distance :math:`\\sum_i w_i d_i`
        - "Ed": Expected hit distance :math:`\\frac{\\sum_i w_i d_i}{\\sum_i w_i}`
        - "RGB-d": RGB + accumulated hit distance
        - "RGB-Ed": RGB + expected hit distance

        "RGB" renders only the colored image. For combined modes, depth is the last
        channel. When extra_signals are present, render_colors is RGB + depth only
        (4 channels); extra channels are returned in ``meta["render_extra_signals"]``.

    .. note::
        **Extra signals**: Optional `extra_signals` are rendered and returned in
        ``meta["render_extra_signals"]`` (shape [..., C, height, width, E]). If
        `extra_signals_sh_degree` is set, extra_signals are SH coefficients
        [..., N, K, 3] evaluated per view.

    .. note::
        **Memory-Speed Trade-off**: The `packed` argument provides a trade-off between
        memory footprint and runtime. If `packed` is True, the intermediate results
        are packed into sparse tensors, which is more memory efficient but might be
        slightly slower. This is especially helpful when the scene is large and each
        camera sees only a small portion of the scene. If `packed` is False, the
        intermediate results have shape [..., C, N, ...], which is faster but might
        consume more memory.

    .. note::
        **Sparse Gradients**: If `sparse_grad` is True, the gradients for
        {means, quats, scales} will be stored in a `COO sparse layout
        <https://pytorch.org/docs/stable/generated/torch.sparse_coo_tensor.html>`_.
        This can be helpful for saving memory during training when the scene is large
        and each iteration only activates a small portion of the Gaussians. Usually a
        sparse optimizer is required to work with sparse gradients, such as
        `torch.optim.SparseAdam
        <https://pytorch.org/docs/stable/generated/torch.optim.SparseAdam.html#sparseadam>`_.
        This argument is only effective when `packed` is True.

    .. note::
        **Speed-up for Large Scenes**: The `radius_clip` argument is extremely helpful
        for speeding up large scale scenes or scenes with large depth of field.
        Gaussians with a 2D radius smaller than or equal to this value (in pixel units)
        will be skipped during rasterization. This skips all the far-away Gaussians
        that are too small to be seen in the image. But be warned that if there are
        close-up Gaussians that are also below this threshold, they will be skipped as
        well (which rarely happens in practice). This is disabled by default by
        setting `radius_clip` to 0.0.

    .. note::
        **Antialiased Rendering**: If `rasterize_mode` is "antialiased", the function
        will apply a view-dependent compensation factor
        :math:`\\rho=\\sqrt{\\frac{Det(\\Sigma)}{Det(\\Sigma + \\epsilon I)}}` to the
        Gaussian opacities, where :math:`\\Sigma` is the projected 2D covariance matrix
        and :math:`\\epsilon` is the `eps2d`. This makes the rendered image more
        antialiased, as proposed in the paper `Mip-Splatting: Alias-free 3D Gaussian
        Splatting <https://arxiv.org/pdf/2311.16493>`_.

    .. note::
        **AbsGrad**: If `absgrad` is True, the absolute gradients of the projected
        2D means will be computed during the backward pass, which could be accessed by
        `meta["means2d"].absgrad`. This is an implementation of the paper `AbsGS:
        Recovering Fine Details for 3D Gaussian Splatting
        <https://arxiv.org/abs/2404.10484>`_, which is shown to be more effective for
        splitting Gaussians during training.

    .. note::
        **Camera Distortion and Rolling Shutter**: The function supports rendering
        with the OpenCV distortion formulas for pinhole and fisheye cameras
        (`radial_coeffs`, `tangential_coeffs`, `thin_prism_coeffs`). It also supports
        rolling shutter rendering with the `rolling_shutter` argument. We take
        reference from the paper `3DGUT: Enabling Distorted Cameras and Secondary Rays
        in Gaussian Splatting <https://arxiv.org/abs/2412.12507>`_.

    .. warning::
        This function is currently not differentiable w.r.t. the camera intrinsics `Ks`.

    Args:
        means: The 3D centers of the Gaussians. [..., N, 3]
        quats: The quaternions of the Gaussians (wxyz convention). They are not
            required to be normalized. [..., N, 4]
        scales: The scales of the Gaussians. [..., N, 3]
        opacities: The opacities of the Gaussians. [..., N]
        colors: The colors of the Gaussians. [..., (C,) N, D] or [..., (C,) N, K, 3]
            for SH coefficients.
        viewmats: The world-to-cam transformation of the cameras. [..., C, 4, 4]
        Ks: The camera intrinsics. [..., C, 3, 3]
        width: The width of the image. For lidar sensors this is ignored and the width
            is taken from lidar_coeffs.n_columns.
        height: The height of the image. For lidar sensors this is ignored and the
            height is taken from lidar_coeffs.n_rows.
        near_plane: The near plane for clipping. Default is 0.01.
        far_plane: The far plane for clipping. Default is 1e10.
        radius_clip: Gaussians with a 2D radius smaller than or equal to this value
            will be skipped. This is extremely helpful for speeding up large scale
            scenes. Default is 0.0.
        eps2d: An epsilon added to the eigenvalues of the projected 2D covariance
            matrices. This prevents the projected Gaussians from becoming too small.
            For example, eps2d=0.3 leads to a minimal size of 3 pixel units.
            Default is 0.3.
        sh_degree: The SH degree to use, which can be smaller than the total number of
            bands. If set, `colors` should be [..., (C,) N, K, 3] SH coefficients,
            else `colors` should be [..., (C,) N, D] post-activation color values.
            Default is None.
        packed: Whether to use packed mode, which is more memory efficient but might
            or might not be as fast. Default is True.
        tile_size: The size of the tiles for rasterization. Default is 16.
            (Note: other values are not tested)
        backgrounds: The background colors. [..., C, D]. Default is None.
        render_mode: The rendering mode. Supported modes are "RGB", "d", "Ed", "D",
            "ED", "RGB-d", "RGB-Ed", "RGB+D", and "RGB+ED". "RGB" renders the colored
            image. Gaussian depth modes (D, ED, RGB+D, RGB+ED) use projection depth.
            Hit distance modes (d, Ed, RGB-d, RGB-Ed) compute along-ray distance.
            Expected modes (Ed, ED) are normalized by opacity. Default is "RGB".
        sparse_grad: If true, the gradients for {means, quats, scales} will be stored
            in a COO sparse layout. This can be helpful for saving memory.
            Default is False.
        absgrad: If true, the absolute gradients of the projected 2D means will be
            computed during the backward pass, which could be accessed by
            `meta["means2d"].absgrad`. Default is False.
        rasterize_mode: The rasterization mode. Supported modes are "classic" and
            "antialiased". Default is "classic".
        channel_chunk: The number of channels to render in one go. Default is 32.
            If the required rendering channels exceed this value, the rendering will
            be done in a loop over chunks.
        distributed: Whether to use distributed rendering. Default is False. If True,
            the input Gaussians are expected to be a subset of the scene on each rank,
            and the function will collaboratively render the images for all ranks.
        camera_model: The camera model to use. Supported models are "pinhole",
            "ortho", "fisheye", and "ftheta". Default is "pinhole".
        segmented: Whether to use segmented radix sort. Default is False. Segmented
            radix sort performs sorting in segments, which is more efficient for the
            sorting operation itself. However, since it requires offset indices as
            input, additional global memory access is needed, which results in slower
            overall performance in most use cases.
        covars: Optional covariance matrices of the Gaussians. If provided, `quats`
            and `scales` will be ignored. [..., N, 3, 3]. Default is None.
        with_ut: Whether to use the Unscented Transform (UT) for projection.
            Default is False.
        with_eval3d: Whether to calculate the Gaussian response in 3D world space,
            instead of 2D image space. Default is False.
        return_normals: Whether to compute and return accumulated normals per pixel.
            Normals are computed from the Gaussian quaternions (canonical normal =
            (0, 0, 1) transformed by the rotation, flipped if facing away from the
            ray). Requires with_eval3d=True. Default is False.
        global_z_order: Whether to use z-depth (True) or Euclidean distance (False)
            for sorting Gaussians during rasterization. When True, Gaussians are
            sorted by their z-coordinate in camera space. When False, they are sorted
            by their Euclidean distance from the camera origin. Default is True.
        radial_coeffs: OpenCV pinhole/fisheye radial distortion coefficients.
            Default is None. For a pinhole camera the shape should be [..., C, 6];
            for a fisheye camera the shape should be [..., C, 4].
        tangential_coeffs: OpenCV pinhole tangential distortion coefficients.
            Default is None. The shape should be [..., C, 2] if provided.
        thin_prism_coeffs: OpenCV pinhole thin prism distortion coefficients.
            Default is None. The shape should be [..., C, 4] if provided.
        ftheta_coeffs: F-Theta camera distortion coefficients shared by all cameras.
            Default is None. See `FThetaCameraDistortionParameters` for details.
        rolling_shutter: The rolling shutter type. The default,
            `RollingShutterType.GLOBAL`, means a global shutter.
        viewmats_rs: The second viewmat used when rolling shutter is enabled.
            Default is None.

    Returns:
        A tuple:

        **render_colors**: The rendered colors. [..., C, height, width, X].
        X depends on the `render_mode` and input `colors`. If `render_mode` is "RGB",
        X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is "RGB+D"
        or "RGB+ED", X is D+1.

        **render_alphas**: The rendered alphas. [..., C, height, width, 1].

        **meta**: A dictionary of intermediate results of the rasterization.

    Examples:

    .. code-block:: python

        >>> # define Gaussians
        >>> means = torch.randn((100, 3), device=device)
        >>> quats = torch.randn((100, 4), device=device)
        >>> scales = torch.rand((100, 3), device=device) * 0.1
        >>> colors = torch.rand((100, 3), device=device)
        >>> opacities = torch.rand((100,), device=device)
        >>> # define cameras
        >>> viewmats = torch.eye(4, device=device)[None, :, :]
        >>> Ks = torch.tensor([
        >>>     [300., 0., 150.], [0., 300., 100.], [0., 0., 1.]], device=device)[None, :, :]
        >>> width, height = 300, 200
        >>> # render
        >>> colors, alphas, meta = rasterization(
        >>>     means, quats, scales, opacities, colors, viewmats, Ks, width, height
        >>> )
        >>> print(colors.shape, alphas.shape)
        torch.Size([1, 200, 300, 3]) torch.Size([1, 200, 300, 1])
        >>> print(meta.keys())
        dict_keys(['camera_ids', 'gaussian_ids', 'radii', 'means2d', 'depths',
        'conics', 'opacities', 'tile_width', 'tile_height', 'tiles_per_gauss',
        'isect_ids', 'flatten_ids', 'isect_offsets', 'width', 'height', 'tile_size'])

    """
    meta = {}
    has_color = render_mode_has_color(render_mode)
    if colors is None and has_color:
        raise ValueError(
            f"colors must be provided when render_mode='{render_mode}' includes RGB. "
            f"Pass colors=None only for depth-only render modes (D, d, Ed, ED)."
        )
    if colors is None and sh_degree is not None:
        raise ValueError(
            f"sh_degree must be None when colors is None, got sh_degree={sh_degree}."
        )
    external_distortion_coeffs = cast(
        Optional[BivariateWindshieldModelParameters], external_distortion_coeffs
    )
    if lidar_coeffs is not None:
        width = lidar_coeffs.n_columns
        height = lidar_coeffs.n_rows

    batch_dims = means.shape[:-2]
    num_batch_dims = len(batch_dims)
    B = math.prod(batch_dims)
    N = means.shape[-2]
    C = viewmats.shape[-3]
    D = (
        colors.shape[-1] if has_color else 0
    )  # number of primary color channels; 0 for depth-only
    I = B * C
    H = height
    W = width
    device = means.device

    assert means.shape == batch_dims + (N, 3), means.shape
    if covars is None:
        assert quats.shape == batch_dims + (N, 4), quats.shape
        assert scales.shape == batch_dims + (N, 3), scales.shape
    else:
        assert covars.shape == batch_dims + (N, 3, 3), covars.shape
        quats, scales = None, None
        # convert covars from 3x3 matrix to upper-triangular 6D vector
        tri_indices = ([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
        covars = covars[..., tri_indices[0], tri_indices[1]]
    assert opacities.shape == batch_dims + (N,), opacities.shape
    assert viewmats.shape == batch_dims + (C, 4, 4), viewmats.shape
    assert Ks.shape == batch_dims + (C, 3, 3), Ks.shape
    if rays is not None:
        assert_shape("rays", rays, batch_dims + (C, H, W, 6))
    assert global_z_order or with_ut, "global_z_order can be false only if with_ut=True"
    assert (camera_model == "lidar") == (
        lidar_coeffs is not None
    ), "Lidar coefficients must be given if and only if camera model is lidar"

    def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
        view_list = list(
            map(
                lambda x: x.split(int(x.shape[0] / C), dim=0),
                world_view.split([C * N_i for N_i in N_world], dim=0),
            )
        )
        return torch.stack([torch.cat(l, dim=0) for l in zip(*view_list)], dim=0)

    def check_features(features: Tensor, sh_degree: Optional[int], name: str) -> None:
        channels = features.shape[-1]
        if sh_degree is None:
            # treat colors as post-activation values, should be in shape [..., N, D] or [..., C, N, D]
            assert (
                features.dim() == num_batch_dims + 2
                and features.shape[:-1] == (*batch_dims, N)
            ) or (
                features.dim() == num_batch_dims + 3
                and features.shape[:-1] == (*batch_dims, C, N)
            ), f"{name}'s shape {features.shape=} must be either {(*batch_dims, N, channels)} or {(*batch_dims, C, N, channels)}"
            if distributed:
                assert (
                    features.dim() == num_batch_dims + 2
                ), f"Distributed mode only supports per-Gaussian {name}."
        else:
            # treat features as SH coefficients, should be in shape [..., N, K, 3] or [..., C, N, K, 3]
            # Allowing for activating partial SH bands
            assert (
                features.dim() == num_batch_dims + 3
                and features.shape[:-2] == batch_dims + (N,)
                and channels == 3
            ) or (
                features.dim() == num_batch_dims + 4
                and features.shape[:-2] == batch_dims + (C, N)
                and channels == 3
            ), f"{name}'s shape {features.shape=} must be either {(*batch_dims, N, 3)} or {(*batch_dims, C, N, 3)}"
            assert (sh_degree + 1) ** 2 <= features.shape[-2], features.shape
            if distributed:
                assert (
                    features.dim() == num_batch_dims + 3
                ), f"Distributed mode only supports per-Gaussian {name}."

    # Skip colors validation for depth-only modes (colors are ignored/overwritten)
    if has_color:
        check_features(colors, sh_degree, "colors")
    if extra_signals is not None:
        check_features(extra_signals, extra_signals_sh_degree, "extra signals")

    if absgrad:
        assert not distributed, "AbsGrad is not supported in distributed mode."
    if (
        radial_coeffs is not None
        or tangential_coeffs is not None
        or thin_prism_coeffs is not None
        or ftheta_coeffs is not None
        or rolling_shutter != RollingShutterType.GLOBAL
    ):
        assert (
            with_ut
        ), "Distortion and rolling shutter are only supported with `with_ut=True`."
    if rolling_shutter != RollingShutterType.GLOBAL:
        assert (
            viewmats_rs is not None
        ), "Rolling shutter requires viewmats_rs to be provided."
    else:
        assert (
            viewmats_rs is None
        ), "viewmats_rs should be None for global rolling shutter."
    if with_ut or with_eval3d:
        assert (quats is not None) and (
            scales is not None
        ), "UT and eval3d require quats and scales to be provided."
        assert packed is False, "Packed mode is not supported with UT."
        assert sparse_grad is False, "Sparse grad is not supported with UT."
    if return_normals and not with_eval3d:
        raise ValueError(
            "return_normals=True requires with_eval3d=True. "
            "Normal computation is only supported in eval3d mode."
        )
    # Validate hit distance modes require eval3d
    if render_mode_has_hit_distance(render_mode) and not with_eval3d:
        raise ValueError(
            f"Hit distance mode '{render_mode}' requires with_eval3d=True. "
            f"Classic mode only supports Gaussian depth modes ('D', 'ED', 'RGB+D', 'RGB+ED'). "
            f"Either set with_eval3d=True or use a Gaussian depth render_mode."
        )

    # Implement the multi-GPU strategy proposed in
    # `On Scaling Up 3D Gaussian Splatting Training <https://arxiv.org/abs/2406.18533>`.
    #
    # If in distributed mode, we distribute the projection computation over Gaussians
    # and the rasterize computation over cameras. So first we gather the cameras
    # from all ranks for projection.
    if distributed:
        assert batch_dims == (), "Distributed mode does not support batch dimensions"
        world_rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

        # Gather the number of Gaussians in each rank.
        N_world = all_gather_int32(world_size, N, device=device)

        # Enforce that the number of cameras is the same across all ranks.
        C_world = [C] * world_size
        viewmats, Ks = all_gather_tensor_list(world_size, [viewmats, Ks])
        if viewmats_rs is not None:
            (viewmats_rs,) = all_gather_tensor_list(world_size, [viewmats_rs])
        if rays is not None:
            (rays,) = all_gather_tensor_list(world_size, [rays])

        # Silently change C from local #Cameras to global #Cameras.
        C = len(viewmats)

    if with_ut:
        # Use provided UT parameters or create default
        if ut_params is None:
            ut_params = UnscentedTransformParameters()
        proj_results = fully_fused_projection_with_ut(
            means=means,
            quats=quats,
            scales=scales,
            opacities=opacities,  # use opacities to compute a tighter bound for radii.
            viewmats=viewmats,
            Ks=Ks,
            width=width,
            height=height,
            eps2d=eps2d,
            near_plane=near_plane,
            far_plane=far_plane,
            radius_clip=radius_clip,
            calc_compensations=(rasterize_mode == "antialiased"),
            camera_model=camera_model,
            ut_params=ut_params,
            radial_coeffs=radial_coeffs,
            tangential_coeffs=tangential_coeffs,
            thin_prism_coeffs=thin_prism_coeffs,
            ftheta_coeffs=ftheta_coeffs,
            lidar_coeffs=lidar_coeffs,
            external_distortion_coeffs=external_distortion_coeffs,
            rolling_shutter=rolling_shutter,
            viewmats_rs=viewmats_rs,
            global_z_order=global_z_order,
        )
    else:
        if lidar_coeffs is not None:
            raise ValueError(
                "Lidar coefficients given but with_ut=False. Lidar camera model requires with_ut=True."
            )
        # Project Gaussians to 2D. Directly passing in {quats, scales} is faster than precomputing covars.
        proj_results = fully_fused_projection(
            means,
            covars,
            quats,
            scales,
            viewmats,
            Ks,
            width,
            height,
            eps2d=eps2d,
            packed=packed,
            near_plane=near_plane,
            far_plane=far_plane,
            radius_clip=radius_clip,
            sparse_grad=sparse_grad,
            calc_compensations=(rasterize_mode == "antialiased"),
            camera_model=camera_model,
            opacities=opacities,  # use opacities to compute a tighter bound for radii.
        )

    if packed:
        # The results are packed into shape [nnz, ...]. All elements are valid.
        (
            batch_ids,
            camera_ids,
            gaussian_ids,
            indptr,
            radii,
            means2d,
            depths,
            conics,
            compensations,
        ) = proj_results
        opacities = opacities.view(B, N)[batch_ids, gaussian_ids]  # [nnz]
        image_ids = batch_ids * C + camera_ids
    else:
        # The results are with shape [..., C, N, ...]. Only the elements with radii > 0 are valid.
        radii, means2d, depths, conics, compensations = proj_results
        opacities = torch.broadcast_to(
            opacities[..., None, :], batch_dims + (C, N)
        )  # [..., C, N]
        indptr, batch_ids, camera_ids, gaussian_ids = None, None, None, None
        image_ids = None

    if compensations is not None:
        opacities = opacities * compensations

    valid_gaussians = (radii > 0).all(dim=-1)

    meta.update(
        {
            # global batch and camera ids
            "batch_ids": batch_ids,
            "camera_ids": camera_ids,
            # local gaussian_ids
            "gaussian_ids": gaussian_ids,
            "radii": radii,
            "means2d": means2d,
            "depths": depths,
            "conics": conics,
            "opacities": opacities,
        }
    )

    # Assemble proj_features: evaluate SH (if needed) for colors and extra_signals,
    # then concatenate them into the feature tensor that will be passed through
    # distributed communication and finally to rasterization.
    # In depth-only modes, skip color processing entirely (colors_sh_degree=None).
    colors_sh_degree = sh_degree if has_color else None
    proj_features = None
    if colors_sh_degree is None and extra_signals_sh_degree is None:
        # No SH evaluation needed: concatenate and normalize in one pass.
        # Note: we avoid torch.cat for single-element lists because it always
        # allocates a new tensor even when there is nothing to concatenate.
        feature_list = []
        if has_color:
            feature_list.append(colors)
        if extra_signals is not None:
            feature_list.append(extra_signals)
        if feature_list:
            proj_features = (
                torch.cat(feature_list, dim=-1)
                if len(feature_list) > 1
                else feature_list[0]
            )
            proj_features = normalize_features_layout(
                proj_features,
                batch_dims,
                C,
                proj_features.shape[-1:],
                batch_ids,
                camera_ids,
                gaussian_ids,
            )
    else:
        # At least one signal needs SH evaluation. Normalize and evaluate each
        # independently, then concatenate.
        dirs = compute_directions(
            batch_dims,
            means,
            viewmats,
            batch_ids,
            camera_ids,
            gaussian_ids,
            indptr,
            viewmats_rs=viewmats_rs,
        )
        feature_list = []
        if has_color:
            colors_tail = (
                colors.shape[-2:] if colors_sh_degree is not None else colors.shape[-1:]
            )
            colors = normalize_features_layout(
                colors, batch_dims, C, colors_tail, batch_ids, camera_ids, gaussian_ids
            )
            if colors_sh_degree is not None:
                colors = spherical_harmonics(
                    colors_sh_degree, dirs, colors, masks=valid_gaussians
                )
                # Make sure colors >= 0 so that it's apples-to-apples with the Inria CUDA backend
                colors = torch.clamp_min(colors + 0.5, 0.0)
            feature_list.append(colors)
        if extra_signals is not None:
            es_tail = (
                extra_signals.shape[-2:]
                if extra_signals_sh_degree is not None
                else extra_signals.shape[-1:]
            )
            extra_signals = normalize_features_layout(
                extra_signals,
                batch_dims,
                C,
                es_tail,
                batch_ids,
                camera_ids,
                gaussian_ids,
            )
            if extra_signals_sh_degree is not None:
                extra_signals = spherical_harmonics(
                    extra_signals_sh_degree, dirs, extra_signals, masks=valid_gaussians
                )
                extra_signals = extra_signals + 0.5
            feature_list.append(extra_signals)
        if feature_list:
            proj_features = (
                torch.cat(feature_list, dim=-1)
                if len(feature_list) > 1
                else feature_list[0]
            )

    # If in distributed mode, we need to scatter the GSs to the destination ranks, based
    # on which cameras they are visible to, which we already figured out in the projection
    # stage.
    if distributed:
        if packed:
            # count how many elements need to be sent to each rank
            cnts = torch.bincount(camera_ids, minlength=C)  # all cameras
            cnts = cnts.split(C_world, dim=0)
            cnts = [cuts.sum() for cuts in cnts]

            # all to all communication across all ranks. After this step, each rank
            # would have all the necessary GSs to render its own images.
            collected_splits = all_to_all_int32(world_size, cnts, device=device)
            (radii,) = all_to_all_tensor_list(
                world_size, [radii], cnts, output_splits=collected_splits
            )
            if proj_features is not None:
                (
                    means2d,
                    depths,
                    conics,
                    opacities,
                    proj_features,
                ) = all_to_all_tensor_list(
                    world_size,
                    [means2d, depths, conics, opacities, proj_features],
                    cnts,
                    output_splits=collected_splits,
                )
            else:
                (means2d, depths, conics, opacities) = all_to_all_tensor_list(
                    world_size,
                    [means2d, depths, conics, opacities],
                    cnts,
                    output_splits=collected_splits,
                )

            # before sending the data, we should turn the camera_ids from global to local.
            # i.e. the camera_ids produced by the projection stage are over all cameras world-wide,
            # so we need to turn them into camera_ids that are local to each rank.
            offsets = torch.tensor(
                [0] + C_world[:-1], device=camera_ids.device, dtype=camera_ids.dtype
            )
            offsets = torch.cumsum(offsets, dim=0)
            offsets = offsets.repeat_interleave(torch.stack(cnts))
            camera_ids = camera_ids - offsets

            # and turn gaussian ids from local to global.
            offsets = torch.tensor(
                [0] + N_world[:-1],
                device=gaussian_ids.device,
                dtype=gaussian_ids.dtype,
            )
            offsets = torch.cumsum(offsets, dim=0)
            offsets = offsets.repeat_interleave(torch.stack(cnts))
            gaussian_ids = gaussian_ids + offsets

            # all to all communication across all ranks.
            camera_ids, gaussian_ids = all_to_all_tensor_list(
                world_size,
                [camera_ids, gaussian_ids],
                cnts,
                output_splits=collected_splits,
            )

            # Silently change C from global #Cameras to local #Cameras.
            C = C_world[world_rank]

        else:
            # Silently change C from global #Cameras to local #Cameras.
            C = C_world[world_rank]

            # all to all communication across all ranks. After this step, each rank
            # would have all the necessary GSs to render its own images.
            (radii,) = all_to_all_tensor_list(
                world_size,
                [radii.flatten(0, 1)],
                splits=[C_i * N for C_i in C_world],
                output_splits=[C * N_i for N_i in N_world],
            )
            radii = reshape_view(C, radii, N_world)
            if proj_features is not None:
                (
                    means2d,
                    depths,
                    conics,
                    opacities,
                    proj_features,
                ) = all_to_all_tensor_list(
                    world_size,
                    [
                        means2d.flatten(0, 1),
                        depths.flatten(0, 1),
                        conics.flatten(0, 1),
                        opacities.flatten(0, 1),
                        proj_features.flatten(0, 1),
                    ],
                    splits=[C_i * N for C_i in C_world],
                    output_splits=[C * N_i for N_i in N_world],
                )
                proj_features = reshape_view(C, proj_features, N_world)
            else:
                (means2d, depths, conics, opacities) = all_to_all_tensor_list(
                    world_size,
                    [
                        means2d.flatten(0, 1),
                        depths.flatten(0, 1),
                        conics.flatten(0, 1),
                        opacities.flatten(0, 1),
                    ],
                    splits=[C_i * N for C_i in C_world],
                    output_splits=[C * N_i for N_i in N_world],
                )
            means2d = reshape_view(C, means2d, N_world)
            depths = reshape_view(C, depths, N_world)
            conics = reshape_view(C, conics, N_world)
            opacities = reshape_view(C, opacities, N_world)

    # Rasterize to pixels.
    # Append depth channel to proj_features if needed.
    # Layout is [proj_features(D+E) | depth(1)], with depth always last.
    # In depth-only modes proj_features may not be set yet (no colors, no extra_signals).
    if render_mode_has_depth_channel(render_mode):
        if render_mode_has_hit_distance(render_mode):
            depth_channel = torch.zeros_like(
                depths[..., None]
            )  # kernel overwrites with hit distance
        else:
            depth_channel = depths[..., None]  # projection depth
        if proj_features is not None:
            proj_features = torch.cat((proj_features, depth_channel), dim=-1)
            if backgrounds is not None:
                backgrounds = torch.cat(
                    [
                        backgrounds,
                        torch.zeros((*batch_dims, C, 1), device=backgrounds.device),
                    ],
                    dim=-1,
                )
        else:
            proj_features = depth_channel
            if backgrounds is not None:
                backgrounds = torch.zeros(
                    (*batch_dims, C, 1), device=backgrounds.device
                )
    else:
        assert render_mode_has_only_color(render_mode)
        assert proj_features is not None

    if lidar_coeffs is not None:
        tile_width = lidar_coeffs.tiling.n_bins_azimuth
        tile_height = lidar_coeffs.tiling.n_bins_elevation
        tiles_per_gauss, isect_ids, flatten_ids = isect_tiles_lidar(
            lidar_coeffs,
            means2d,
            radii,
            depths,
            segmented=segmented,
            packed=packed,
            n_images=I,
            image_ids=image_ids,
            gaussian_ids=gaussian_ids,
        )
    else:
        # Identify intersecting tiles
        tile_width = math.ceil(width / float(tile_size))
        tile_height = math.ceil(height / float(tile_size))
        tiles_per_gauss, isect_ids, flatten_ids = isect_tiles(
            means2d,
            radii,
            depths,
            tile_size,
            tile_width,
            tile_height,
            segmented=segmented,
            packed=packed,
            n_images=I,
            image_ids=image_ids,
            gaussian_ids=gaussian_ids,
            conics=None if with_ut else conics,
            opacities=None if with_ut else opacities,
        )

    # print("rank", world_rank, "Before isect_offset_encode")
    isect_offsets = isect_offset_encode(isect_ids, I, tile_width, tile_height)
    isect_offsets = isect_offsets.reshape(batch_dims + (C, tile_height, tile_width))

    meta.update(
        {
            "tile_width": tile_width,
            "tile_height": tile_height,
            "tiles_per_gauss": tiles_per_gauss,
            "isect_ids": isect_ids,
            "flatten_ids": flatten_ids,
            "isect_offsets": isect_offsets,
            "width": width,
            "height": height,
            "tile_size": tile_size,
            "n_batches": B,
            "n_cameras": C,
        }
    )

    # print("rank", world_rank, "Before rasterize_to_pixels")
    if proj_features.shape[-1] > channel_chunk:  # slice into chunks
        n_chunks = (proj_features.shape[-1] + channel_chunk - 1) // channel_chunk
        render_colors, render_alphas = [], []
        render_normals = None  # Only compute normals in first chunk
        for i in range(n_chunks):
            features_chunk = proj_features[
                ..., i * channel_chunk : (i + 1) * channel_chunk
            ]
            backgrounds_chunk = (
                backgrounds[..., i * channel_chunk : (i + 1) * channel_chunk]
                if backgrounds is not None
                else None
            )
            if with_eval3d:
                # Only compute normals in first chunk (normals don't depend on colors)
                return_normals_chunk = return_normals if i == 0 else False
                (
                    render_colors_,
                    render_alphas_,
                    _,
                    _,
                    render_normals_,
                ) = rasterize_to_pixels_eval3d_extra(
                    means=means,
                    quats=quats,
                    scales=scales,
                    colors=features_chunk,
                    opacities=opacities,
                    viewmats=viewmats,
                    Ks=Ks,
                    rays=rays,
                    image_width=width,
                    image_height=height,
                    tile_size=tile_size,
                    isect_offsets=isect_offsets,
                    flatten_ids=flatten_ids,
                    backgrounds=backgrounds_chunk,
                    camera_model=camera_model,
                    radial_coeffs=radial_coeffs,
                    tangential_coeffs=tangential_coeffs,
                    thin_prism_coeffs=thin_prism_coeffs,
                    ftheta_coeffs=ftheta_coeffs,
                    lidar_coeffs=lidar_coeffs,
                    external_distortion_coeffs=external_distortion_coeffs,
                    rolling_shutter=rolling_shutter,
                    viewmats_rs=viewmats_rs,
                    use_hit_distance=render_mode_has_hit_distance(render_mode),
                    return_normals=return_normals_chunk,
                )
                if i == 0 and render_normals_ is not None:
                    render_normals = render_normals_
            else:
                if rays is not None:
                    raise ValueError(
                        "Rays input is only supported with with_eval3d=True"
                    )
                render_colors_, render_alphas_ = rasterize_to_pixels(
                    means2d,
                    conics,
                    features_chunk,
                    opacities,
                    width,
                    height,
                    tile_size,
                    isect_offsets,
                    flatten_ids,
                    backgrounds=backgrounds_chunk,
                    packed=packed,
                    absgrad=absgrad,
                )
            render_colors.append(render_colors_)
            render_alphas.append(render_alphas_)
        render_colors = torch.cat(render_colors, dim=-1)
        render_alphas = render_alphas[0]  # discard the rest
    else:
        render_normals = None
        if with_eval3d:
            (
                render_colors,
                render_alphas,
                _,
                _,
                render_normals,
            ) = rasterize_to_pixels_eval3d_extra(
                means=means,
                quats=quats,
                scales=scales,
                colors=proj_features,
                opacities=opacities,
                viewmats=viewmats,
                Ks=Ks,
                rays=rays,
                image_width=width,
                image_height=height,
                tile_size=tile_size,
                isect_offsets=isect_offsets,
                flatten_ids=flatten_ids,
                backgrounds=backgrounds,
                camera_model=camera_model,
                radial_coeffs=radial_coeffs,
                tangential_coeffs=tangential_coeffs,
                thin_prism_coeffs=thin_prism_coeffs,
                ftheta_coeffs=ftheta_coeffs,
                lidar_coeffs=lidar_coeffs,
                external_distortion_coeffs=external_distortion_coeffs,
                rolling_shutter=rolling_shutter,
                viewmats_rs=viewmats_rs,
                use_hit_distance=render_mode_has_hit_distance(render_mode),
                return_normals=return_normals,
            )
        else:
            if rays is not None:
                raise ValueError("Rays input is only supported with with_eval3d=True")
            render_colors, render_alphas = rasterize_to_pixels(
                means2d,
                conics,
                proj_features,
                opacities,
                width,
                height,
                tile_size,
                isect_offsets,
                flatten_ids,
                backgrounds=backgrounds,
                packed=packed,
                absgrad=absgrad,
            )

    if extra_signals is not None:
        # Extract the extra signals (per ray) from render_colors
        E = extra_signals.shape[-1]
        meta["render_extra_signals"] = render_colors[..., D : D + E]
        # Leave only colors (and possibly depth)
        if render_mode_has_depth_channel(render_mode):
            render_depth = render_colors[..., -1:]
            # Normalize depth for expected modes (Ed, ED, RGB-Ed, RGB+ED)
            if render_mode_has_expected_depth(render_mode):
                render_depth = render_depth / render_alphas.clamp(min=1e-10)
            render_colors = torch.cat([render_colors[..., 0:D], render_depth], dim=-1)
        else:
            render_colors = render_colors[..., 0:D]
    else:
        # Normalize depth for expected modes (Ed, ED, RGB-Ed, RGB+ED)
        if render_mode_has_expected_depth(render_mode):
            # normalize the accumulated depth to get the expected depth
            render_colors = torch.cat(
                [
                    render_colors[..., :-1],
                    render_colors[..., -1:] / render_alphas.clamp(min=1e-10),
                ],
                dim=-1,
            )

    # Add normals to meta if computed
    if return_normals:
        meta["normals"] = render_normals

    return render_colors, render_alphas, meta
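
# Usage sketch for a combined depth mode (assumes the toy inputs from the docstring
# example above and a CUDA device): with render_mode="RGB+ED" the alpha-normalized
# depth is appended as the last channel, so render_colors has D + 1 = 4 channels.
#
#   colors4, alphas, meta = rasterization(
#       means, quats, scales, opacities, colors, viewmats, Ks, width, height,
#       render_mode="RGB+ED",
#   )
#   rgb, depth = colors4[..., :3], colors4[..., 3:]  # depth is always last
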
def _maybe_evaluate_sh(
    sh_degree, features, means, radii, viewmats, batch_dims, C, N, clamp
):
    num_batch_dims = len(batch_dims)
    # Turn features into [..., C, N, D] or [..., nnz, D] to pass into rasterize_to_pixels()
    if sh_degree is None:
        # Colors are post-activation values, with shape [..., N, D] or [..., C, N, D]
        if features.dim() == num_batch_dims + 2:
            # Turn [..., N, D] into [..., C, N, D]
            features = torch.broadcast_to(
                features[..., None, :, :], batch_dims + (C, N, -1)
            )
        else:
            # features is already [..., C, N, D]
            pass
    else:
        # Colors are SH coefficients, with shape [..., N, K, 3] or [..., C, N, K, 3]
        camtoworlds = torch.inverse(viewmats)  # [..., C, 4, 4]
        dirs = means[..., None, :, :] - camtoworlds[..., None, :3, 3]  # [..., C, N, 3]
        masks = (radii > 0).all(dim=-1)  # [..., C, N]
        if features.dim() == num_batch_dims + 3:
            # Turn [..., N, K, 3] into [..., C, N, K, 3]
            shs = torch.broadcast_to(
                features[..., None, :, :, :], batch_dims + (C, N, -1, 3)
            )  # [..., C, N, K, 3]
        else:
            # features is already [..., C, N, K, 3]
            shs = features
        features = spherical_harmonics(
            sh_degree, dirs, shs, masks=masks
        )  # [..., C, N, 3]
        if clamp:
            # make it apples-to-apples with Inria's CUDA backend.
            features = torch.clamp_min(features + 0.5, 0.0)
        else:
            features = features + 0.5
    return features


def _rasterization(
    means: Tensor,  # [..., N, 3]
    quats: Tensor,  # [..., N, 4]
    scales: Tensor,  # [..., N, 3]
    opacities: Tensor,  # [..., N]
    colors: Tensor,  # [..., (C,) N, D] or [..., (C,) N, K, 3]
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
    near_plane: float = 0.01,
    far_plane: float = 1e10,
    eps2d: float = 0.3,
    sh_degree: Optional[int] = None,
    tile_size: int = 16,
    rays: Optional[Tensor] = None,  # [..., C, H, W, 6] -> ox, oy, oz, dx*spread, dy*spread, dz*spread
    backgrounds: Optional[Tensor] = None,
    render_mode: RenderMode = "RGB",
    rasterize_mode: RasterizeMode = "classic",
    channel_chunk: int = 32,
    batch_per_iter: int = 100,
    with_eval3d: bool = False,
    with_ut: bool = False,
    camera_model: CameraModel = "pinhole",
    lidar_coeffs: Optional[RowOffsetStructuredSpinningLidarModelParametersExt] = None,
    extra_signals: Optional[Tensor] = None,  # [..., (C,) N, 3] or [..., (C,) N, K, 3] when extra_signals_sh_degree set
    extra_signals_sh_degree: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Dict]:
    """A version of rasterization() that relies on PyTorch's autograd.

    .. note::
        This function still relies on gsplat's CUDA backend for some computation, but
        the entire differentiable graph is composed of PyTorch (and nerfacc)
        operations, so PyTorch's autograd can be used for backpropagation.

    .. note::
        This function relies on installing the latest nerfacc, via:
        pip install git+https://github.com/nerfstudio-project/nerfacc

    .. note::
        Compared to rasterization(), this function does not support some arguments
        such as `packed`, `sparse_grad` and `absgrad`.
    """
    from gsplat.cuda._torch_impl import (
        _fully_fused_projection,
        _rasterize_to_pixels,
    )
    from gsplat.cuda._torch_impl_eval3d import _rasterize_to_pixels_eval3d
    from gsplat.cuda._torch_impl_ut import _fully_fused_projection_with_ut
    from gsplat.cuda._math import _quat_scale_to_covar_preci

    if lidar_coeffs is not None:
        width = lidar_coeffs.n_columns
        height = lidar_coeffs.n_rows

    has_color = render_mode_has_color(render_mode)
    batch_dims = means.shape[:-2]
    num_batch_dims = len(batch_dims)
    B = math.prod(batch_dims)
    N = means.shape[-2]
    C = viewmats.shape[-3]
    D = (
        colors.shape[-1] if has_color else 0
    )  # number of primary color channels; 0 for depth-only
    I = B * C
    H = height
    W = width
    device = means.device

    assert means.shape == batch_dims + (N, 3), means.shape
    assert quats.shape == batch_dims + (N, 4), quats.shape
    assert scales.shape == batch_dims + (N, 3), scales.shape
    assert opacities.shape == batch_dims + (N,), opacities.shape
    assert viewmats.shape == batch_dims + (C, 4, 4), viewmats.shape
    assert Ks.shape == batch_dims + (C, 3, 3), Ks.shape
    assert rays is None or rays.shape == batch_dims + (C, H, W, 6), rays.shape

    if has_color:
        if sh_degree is None:
            # treat colors as post-activation values, should be in shape [..., N, D] or [..., C, N, D]
            assert (
                colors.dim() == num_batch_dims + 2
                and colors.shape[:-1] == batch_dims + (N,)
            ) or (
                colors.dim() == num_batch_dims + 3
                and colors.shape[:-1] == batch_dims + (C, N)
            ), colors.shape
        else:
            # treat colors as SH coefficients, should be in shape [..., N, K, 3] or [..., C, N, K, 3]
            # Allowing for activating partial SH bands
            assert (
                colors.dim() == num_batch_dims + 3
                and colors.shape[:-2] == batch_dims + (N,)
                and colors.shape[-1] == 3
            ) or (
                colors.dim() == num_batch_dims + 4
                and colors.shape[:-2] == batch_dims + (C, N)
                and colors.shape[-1] == 3
            ), colors.shape
            assert (sh_degree + 1) ** 2 <= colors.shape[-2], colors.shape

    if with_ut:
        radii, means2d, depths, conics, compensations = _fully_fused_projection_with_ut(
            means=means,
            quats=quats,
            scales=scales,
            opacities=opacities,
            viewmats=viewmats,
            Ks=Ks,
            width=width,
            height=height,
            eps2d=eps2d,
            near_plane=near_plane,
            far_plane=far_plane,
            calc_compensations=(rasterize_mode == "antialiased"),
            camera_model=camera_model,
            lidar_coeffs=lidar_coeffs,
        )
    else:
        if rays is not None:
            raise ValueError("Rays input is only supported with with_eval3d=True")
        assert camera_model == "pinhole", camera_model
        # Project Gaussians to 2D.
        # The results are with shape [..., C, N, ...]. Only the elements with radii > 0 are valid.
        covars, _ = _quat_scale_to_covar_preci(quats, scales, True, False, triu=False)
        radii, means2d, depths, conics, compensations = _fully_fused_projection(
            means,
            covars,
            viewmats,
            Ks,
            width,
            height,
            eps2d=eps2d,
            near_plane=near_plane,
            far_plane=far_plane,
            calc_compensations=(rasterize_mode == "antialiased"),
        )

    opacities = torch.broadcast_to(
        opacities[..., None, :], batch_dims + (C, N)
    )  # [..., C, N]
    batch_ids, camera_ids, gaussian_ids = None, None, None
    image_ids = None

    if compensations is not None:
        opacities = opacities * compensations

    # Identify intersecting tiles
    if lidar_coeffs is not None:
        tile_width = lidar_coeffs.tiling.n_bins_azimuth
        tile_height = lidar_coeffs.tiling.n_bins_elevation
        tiles_per_gauss, isect_ids, flatten_ids = isect_tiles_lidar(
            lidar_coeffs,
            means2d,
            radii,
            depths,
            packed=False,
            n_images=I,
            image_ids=image_ids,
            gaussian_ids=gaussian_ids,
        )
    else:
        tile_width = math.ceil(width / float(tile_size))
        tile_height = math.ceil(height / float(tile_size))
        tiles_per_gauss, isect_ids, flatten_ids = isect_tiles(
            means2d,
            radii,
            depths,
            tile_size,
            tile_width,
            tile_height,
            packed=False,
            n_images=I,
            image_ids=image_ids,
            gaussian_ids=gaussian_ids,
            conics=None if with_ut else conics,
            opacities=None if with_ut else opacities,
        )
    isect_offsets = isect_offset_encode(isect_ids, I, tile_width, tile_height)
    isect_offsets = isect_offsets.reshape(batch_dims + (C, tile_height, tile_width))

    # Turn colors into [..., C, N, D] or [..., nnz, D] to pass into rasterize_to_pixels()
    # Make sure they're clamped if evaluating SH.
    if has_color:
        colors = _maybe_evaluate_sh(
            sh_degree, colors, means, radii, viewmats, batch_dims, C, N, True
        )
    # Now do the same to the extra signals.
    if extra_signals is not None:
        # Do not clamp it.
        extra_signals = _maybe_evaluate_sh(
            extra_signals_sh_degree,
            extra_signals,
            means,
            radii,
            viewmats,
            batch_dims,
            C,
            N,
            False,
        )
        if has_color:
            # Concatenate colors and extra_signals for joint rasterization.
            assert colors.shape[:-1] == extra_signals.shape[:-1], (
                colors.shape,
                extra_signals.shape,
            )
            colors = torch.cat([colors, extra_signals], dim=-1)

    # Rasterize to pixels
    if render_mode_has_depth_channel(render_mode) and render_mode_has_color(
        render_mode
    ):
        colors = torch.cat((colors, depths[..., None]), dim=-1)
        if backgrounds is not None:
            backgrounds = torch.cat(
                [
                    backgrounds,
                    torch.zeros(batch_dims + (C, 1), device=backgrounds.device),
                ],
                dim=-1,
            )
    elif render_mode_has_only_depth_channel(render_mode):
        # In depth-only mode, extra_signals were not concatenated into colors
        # above. Place them before depth so depth stays last.
        if extra_signals is not None and not has_color:
            colors = torch.cat([extra_signals, depths[..., None]], dim=-1)
        else:
            colors = depths[..., None]
        if backgrounds is not None:
            backgrounds = torch.zeros(
                batch_dims + (C, colors.shape[-1]), device=backgrounds.device
            )
    else:
        # RGB
        pass

    # Chunking logic for both eval3d and standard paths
    if colors.shape[-1] > channel_chunk:  # slice into chunks
        n_chunks = (colors.shape[-1] + channel_chunk - 1) // channel_chunk
        render_colors, render_alphas = [], []
        for i in range(n_chunks):
            colors_chunk = colors[..., i * channel_chunk : (i + 1) * channel_chunk]
            backgrounds_chunk = (
                backgrounds[..., i * channel_chunk : (i + 1) * channel_chunk]
                if backgrounds is not None
                else None
            )
            if with_eval3d:
                # Using CUDA code due to its speed. This function is already
                # being thoroughly tested in test_basic.py
                render_colors_, render_alphas_ = rasterize_to_pixels_eval3d(
                    means=means,
                    quats=quats,
                    scales=scales,
                    colors=colors_chunk,
                    opacities=opacities,
                    viewmats=viewmats,
                    camera_model=camera_model,
                    Ks=Ks,
                    image_width=width,
                    image_height=height,
                    rays=rays,
                    lidar_coeffs=lidar_coeffs,
                    tile_size=tile_size,
                    isect_offsets=isect_offsets,
                    flatten_ids=flatten_ids,
                    backgrounds=backgrounds_chunk,
                )
            else:
                if rays is not None:
                    raise ValueError(
                        "Rays input is only supported with with_eval3d=True"
                    )
                assert camera_model == "pinhole", camera_model
                render_colors_, render_alphas_ = _rasterize_to_pixels(
                    means2d,
                    conics,
                    colors_chunk,
                    opacities,
                    width,
                    height,
                    tile_size,
                    isect_offsets,
                    flatten_ids,
                    backgrounds=backgrounds_chunk,
                    batch_per_iter=batch_per_iter,
                )
            render_colors.append(render_colors_)
            render_alphas.append(render_alphas_)
        render_colors = torch.cat(render_colors, dim=-1)
        render_alphas = render_alphas[0]  # discard the rest
    else:
        # No chunking needed
        if with_eval3d:
            # Using CUDA code due to its speed. This function is already
            # being thoroughly tested in test_basic.py
            render_colors, render_alphas = rasterize_to_pixels_eval3d(
                means=means,
                quats=quats,
                scales=scales,
                colors=colors,
                opacities=opacities,
                viewmats=viewmats,
                Ks=Ks,
                image_width=width,
                image_height=height,
                rays=rays,
                camera_model=camera_model,
                lidar_coeffs=lidar_coeffs,
                tile_size=tile_size,
                isect_offsets=isect_offsets,
                flatten_ids=flatten_ids,
                backgrounds=backgrounds,
            )
        else:
            if rays is not None:
                raise ValueError("Rays input is only supported with with_eval3d=True")
            assert camera_model == "pinhole", camera_model
            render_colors, render_alphas = _rasterize_to_pixels(
                means2d,
                conics,
                colors,
                opacities,
                width,
                height,
                tile_size,
                isect_offsets,
                flatten_ids,
                backgrounds=backgrounds,
                batch_per_iter=batch_per_iter,
            )

    if extra_signals is not None:
        # Extract the extra signals (per ray) from render_colors
        E = extra_signals.shape[-1]
        render_extra_signals = render_colors[..., D : D + E]
        # Leave only colors (and possibly depth)
        if render_mode_has_depth_channel(render_mode):
            render_depth = render_colors[..., -1:]
            # Normalize depth for expected modes (Ed, ED, RGB-Ed, RGB+ED)
            if render_mode_has_expected_depth(render_mode):
                render_depth = render_depth / render_alphas.clamp(min=1e-10)
            render_colors = torch.cat([render_colors[..., 0:D], render_depth], dim=-1)
        else:
            render_colors = render_colors[..., 0:D]
    else:
        render_extra_signals = None
        # Normalize depth for expected modes (Ed, ED, RGB-Ed, RGB+ED)
        if render_mode_has_expected_depth(render_mode):
            # normalize the accumulated depth to get the expected depth
            render_depth = render_colors[..., -1:] / render_alphas.clamp(min=1e-10)
            render_colors = torch.cat([render_colors[..., :D], render_depth], dim=-1)

    meta = {
        "batch_ids": batch_ids,
        "camera_ids": camera_ids,
        "gaussian_ids": gaussian_ids,
        "radii": radii,
        "means2d": means2d,
        "depths": depths,
        "conics": conics,
        "opacities": opacities,
        "tile_width": tile_width,
        "tile_height": tile_height,
        "tiles_per_gauss": tiles_per_gauss,
        "isect_ids": isect_ids,
        "flatten_ids": flatten_ids,
        "isect_offsets": isect_offsets,
        "width": width,
        "height": height,
        "tile_size": tile_size,
        "n_batches": B,
        "n_cameras": C,
    }
    if render_extra_signals is not None:
        meta["render_extra_signals"] = render_extra_signals
    return render_colors, render_alphas, meta


# def rasterization_legacy_wrapper(
#     means: Tensor,  # [N, 3]
#     quats: Tensor,  # [N, 4]
#     scales: Tensor,  # [N, 3]
#     opacities: Tensor,  # [N]
#     colors: Tensor,  # [N, D] or [N, K, 3]
#     viewmats: Tensor,  # [C, 4, 4]
#     Ks: Tensor,  # [C, 3, 3]
#     width: int,
#     height: int,
#     near_plane: float = 0.01,
#     eps2d: float = 0.3,
#     sh_degree: Optional[int] = None,
#     tile_size: int = 16,
#     backgrounds: Optional[Tensor] = None,
#     **kwargs,
# ) -> Tuple[Tensor, Tensor, Dict]:
#     """Wrapper for old version gsplat.
#
#     .. warning::
#         This function exists for comparison purposes only, so we skip collecting
#         the intermediate variables and only return an empty dict.
#     """
#     from gsplat.cuda_legacy._wrapper import (
#         project_gaussians,
#         rasterize_gaussians,
#         spherical_harmonics,
#     )
#
#     assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3"
#     C = len(viewmats)
#
#     render_colors, render_alphas = [], []
#     for cid in range(C):
#         fx, fy = Ks[cid, 0, 0], Ks[cid, 1, 1]
#         cx, cy = Ks[cid, 0, 2], Ks[cid, 1, 2]
#         viewmat = viewmats[cid]
#
#         means2d, depths, radii, conics, _, num_tiles_hit, _ = project_gaussians(
#             means3d=means,
#             scales=scales,
#             glob_scale=1.0,
#             quats=quats,
#             viewmat=viewmat,
#             fx=fx,
#             fy=fy,
#             cx=cx,
#             cy=cy,
#             img_height=height,
#             img_width=width,
#             block_width=tile_size,
#             clip_thresh=near_plane,
#         )
#
#         if colors.dim() == 3:
#             c2w = viewmat.inverse()
#             viewdirs = means - c2w[:3, 3]
#             # viewdirs = F.normalize(viewdirs, dim=-1).detach()
#             if sh_degree is None:
#                 sh_degree = int(math.sqrt(colors.shape[1]) - 1)
#             colors = spherical_harmonics(sh_degree, viewdirs, colors)  # [N, 3]
#
#         background = (
#             backgrounds[cid]
#             if backgrounds is not None
#             else torch.zeros(colors.shape[-1], device=means.device)
#         )
#
#         render_colors_, render_alphas_ = rasterize_gaussians(
#             xys=means2d,
#             depths=depths,
#             radii=radii,
#             conics=conics,
#             num_tiles_hit=num_tiles_hit,
#             colors=colors,
#             opacity=opacities[..., None],
#             img_height=height,
#             img_width=width,
#             block_width=tile_size,
#             background=background,
#             return_alpha=True,
#         )
#         render_colors.append(render_colors_)
#         render_alphas.append(render_alphas_[..., None])
#
#     render_colors = torch.stack(render_colors, dim=0)
#     render_alphas = torch.stack(render_alphas, dim=0)
#     return render_colors, render_alphas, {}
def rasterization_inria_wrapper(
    means: Tensor,  # [..., N, 3]
    quats: Tensor,  # [..., N, 4]
    scales: Tensor,  # [..., N, 3]
    opacities: Tensor,  # [..., N]
    colors: Tensor,  # [..., N, D] or [..., N, K, 3]
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
    near_plane: float = 0.01,
    far_plane: float = 100.0,
    eps2d: float = 0.3,
    sh_degree: Optional[int] = None,
    backgrounds: Optional[Tensor] = None,
    **kwargs,
) -> Tuple[Tensor, Tensor, Dict]:
    """Wrapper for Inria's rasterization backend.

    .. warning::
        This function exists for comparison purposes only. Only the rendered
        image is returned.

    .. warning::
        Inria's CUDA backend has its own LICENSE, so this function should be
        used with respect to the original LICENSE at:
        https://github.com/graphdeco-inria/diff-gaussian-rasterization
    """
    from diff_gaussian_rasterization import (
        GaussianRasterizationSettings,
        GaussianRasterizer,
    )

    assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3"

    batch_dims = means.shape[:-2]
    num_batch_dims = len(batch_dims)
    N = means.shape[-2]
    B = math.prod(batch_dims)
    C = viewmats.shape[-3]
    I = B * C
    device = means.device
    channels = colors.shape[-1]

    assert means.shape == batch_dims + (N, 3), means.shape
    assert quats.shape == batch_dims + (N, 4), quats.shape
    assert scales.shape == batch_dims + (N, 3), scales.shape
    assert opacities.shape == batch_dims + (N,), opacities.shape
    assert viewmats.shape == batch_dims + (C, 4, 4), viewmats.shape
    assert Ks.shape == batch_dims + (C, 3, 3), Ks.shape

    if sh_degree is None:
        # treat colors as post-activation values, should be in shape [..., N, D] or [..., C, N, D]
        assert (
            colors.dim() == num_batch_dims + 2
            and colors.shape[:-1] == batch_dims + (N,)
        ) or (
            colors.dim() == num_batch_dims + 3
            and colors.shape[:-1] == batch_dims + (C, N)
        ), colors.shape
    else:
        # treat colors as SH coefficients, should be in shape [..., N, K, 3] or [..., C, N, K, 3]
        # Allowing for activating partial SH bands
        assert (
            colors.dim() == num_batch_dims + 3
            and colors.shape[:-2] == batch_dims + (N,)
            and colors.shape[-1] == 3
        ) or (
            colors.dim() == num_batch_dims + 4
            and colors.shape[:-2] == batch_dims + (C, N)
            and colors.shape[-1] == 3
        ), colors.shape
        assert (sh_degree + 1) ** 2 <= colors.shape[-2], colors.shape

    # flatten all batch dimensions
    means = means.reshape(B, N, 3)
    quats = quats.reshape(B, N, 4)
    scales = scales.reshape(B, N, 3)
    opacities = opacities.reshape(B, N)
    viewmats = viewmats.reshape(B, C, 4, 4)
    Ks = Ks.reshape(B, C, 3, 3)
    if colors.dim() == num_batch_dims + 2:
        colors = colors.reshape(B, N, -1)
    elif colors.dim() == num_batch_dims + 3:
        colors = colors.reshape(B, C, N, -1)

    # rasterization from inria does not do normalization internally
    quats = F.normalize(quats, dim=-1)  # [N, 4]

    render_colors = []
    for bid in range(B):
        for cid in range(C):
            FoVx = 2 * math.atan(width / (2 * Ks[bid, cid, 0, 0].item()))
            FoVy = 2 * math.atan(height / (2 * Ks[bid, cid, 1, 1].item()))
            tanfovx = math.tan(FoVx * 0.5)
            tanfovy = math.tan(FoVy * 0.5)

            world_view_transform = viewmats[bid, cid].transpose(0, 1)
            projection_matrix = get_projection_matrix(
                znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device
            ).transpose(0, 1)
            full_proj_transform = (
                world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))
            ).squeeze(0)
            camera_center = world_view_transform.inverse()[3, :3]

            background = (
                backgrounds[bid, cid]
                if backgrounds is not None
                else torch.zeros(3, device=device)
            )

            raster_settings = GaussianRasterizationSettings(
                image_height=height,
                image_width=width,
                tanfovx=tanfovx,
                tanfovy=tanfovy,
                bg=background,
                scale_modifier=1.0,
                viewmatrix=world_view_transform,
                projmatrix=full_proj_transform,
                sh_degree=0 if sh_degree is None else sh_degree,
                campos=camera_center,
                prefiltered=False,
                debug=False,
            )

            rasterizer = GaussianRasterizer(raster_settings=raster_settings)

            means2D = torch.zeros_like(means, requires_grad=True, device=device)

            render_colors_ = []
            for i in range(0, channels, 3):
                _colors = colors[bid, ..., i : i + 3]
                if _colors.shape[-1] < 3:
                    pad = torch.zeros(
                        _colors.shape[:-1] + (3 - _colors.shape[-1],), device=device
                    )
                    _colors = torch.cat([_colors, pad], dim=-1)
                _render_colors_, radii = rasterizer(
                    means3D=means[bid],
                    means2D=means2D[bid],
                    shs=_colors if colors.dim() == 4 else None,
                    colors_precomp=_colors if colors.dim() == 3 else None,
                    opacities=opacities[bid, :, None],
                    scales=scales[bid],
                    rotations=quats[bid],
                    cov3D_precomp=None,
                )
                if _colors.shape[-1] < 3:
                    _render_colors_ = _render_colors_[..., : _colors.shape[-1]]
                render_colors_.append(_render_colors_)
            render_colors_ = torch.cat(render_colors_, dim=-1)

            render_colors_ = render_colors_.permute(1, 2, 0)  # [H, W, 3]
            render_colors.append(render_colors_)
    render_colors = torch.stack(render_colors, dim=0)
    render_colors = render_colors.reshape(batch_dims + (height, width, channels))
    return render_colors, None, {}
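
# Usage sketch for the wrapper above (assumes the optional
# diff_gaussian_rasterization package is installed and a CUDA device is
# available); all tensor values are random toys for illustration:
def _example_inria_wrapper_usage() -> Tensor:
    device = torch.device("cuda")
    means = torch.randn(100, 3, device=device)
    quats = torch.randn(100, 4, device=device)
    scales = torch.rand(100, 3, device=device) * 0.1
    opacities = torch.rand(100, device=device)
    colors = torch.rand(100, 3, device=device)
    viewmats = torch.eye(4, device=device)[None]  # [1, 4, 4]
    Ks = torch.tensor(
        [[300.0, 0.0, 150.0], [0.0, 300.0, 100.0], [0.0, 0.0, 1.0]], device=device
    )[None]  # [1, 3, 3]
    # only the rendered image is meaningful; alphas and meta are None / {}
    render_colors, _, _ = rasterization_inria_wrapper(
        means, quats, scales, opacities, colors, viewmats, Ks, width=300, height=200
    )
    return render_colors
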
###### 2DGS ######

def rasterization_2dgs(
    means: Tensor,  # [..., N, 3]
    quats: Tensor,  # [..., N, 4]
    scales: Tensor,  # [..., N, 3]
    opacities: Tensor,  # [..., N]
    colors: Tensor,  # [..., (C,) N, D] or [..., (C,) N, K, 3]
    viewmats: Tensor,  # [..., C, 4, 4]
    Ks: Tensor,  # [..., C, 3, 3]
    width: int,
    height: int,
    near_plane: float = 0.01,
    far_plane: float = 1e10,
    radius_clip: float = 0.0,
    eps2d: float = 0.3,
    sh_degree: Optional[int] = None,
    packed: bool = False,
    tile_size: int = 16,
    backgrounds: Optional[Tensor] = None,
    render_mode: RenderMode = "RGB",
    sparse_grad: bool = False,
    absgrad: bool = False,
    distloss: bool = False,
    depth_mode: Literal["expected", "median"] = "expected",
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Dict]:
    """Rasterize a set of 2D Gaussians (N) to a batch of image planes (C).

    This function supports a handful of features, similar to the :func:`rasterization` function.

    .. warning::
        This function is currently not differentiable w.r.t. the camera intrinsics `Ks`.

    Args:
        means: The 3D centers of the Gaussians. [..., N, 3]
        quats: The quaternions of the Gaussians (wxyz convention). It's not required to
            be normalized. [..., N, 4]
        scales: The scales of the Gaussians. [..., N, 3]
        opacities: The opacities of the Gaussians. [..., N]
        colors: The colors of the Gaussians. [..., (C,) N, D] or [..., (C,) N, K, 3] for
            SH coefficients.
        viewmats: The world-to-cam transformation of the cameras. [..., C, 4, 4]
        Ks: The camera intrinsics. [..., C, 3, 3]
        width: The width of the image.
        height: The height of the image.
        near_plane: The near plane for clipping. Default is 0.01.
        far_plane: The far plane for clipping. Default is 1e10.
        radius_clip: Gaussians with a projected 2D radius smaller than or equal to this
            value will be skipped. This is extremely helpful for speeding up large scale
            scenes. Default is 0.0.
        eps2d: An epsilon added to the eigenvalues of projected 2D covariance matrices.
            This prevents projected Gaussians from becoming too small. For example,
            eps2d=0.3 leads to a minimal size of 3 pixel units. Default is 0.3.
        sh_degree: The SH degree to use, which can be smaller than the total
            number of bands. If set, the `colors` should be [(C,) N, K, 3] SH coefficients,
            else the `colors` should be [(C,) N, D] post-activation color values. Default
            is None.
        packed: Whether to use packed mode, which is more memory efficient but might or
            might not be as fast. Default is False.
        tile_size: The size of the tiles for rasterization. Default is 16.
            (Note: other values are not tested)
        backgrounds: The background colors. [C, D]. Default is None.
        render_mode: The rendering mode. Supported modes are "RGB", "d", "Ed", "D",
            "ED", "RGB-d", "RGB-Ed", "RGB+D", and "RGB+ED". "RGB" renders the colored
            image. Gaussian depth modes (D, ED, RGB+D, RGB+ED) use projection depth.
            Hit distance modes (d, Ed, RGB-d, RGB-Ed) compute along-ray distance.
            Expected modes (Ed, ED) are normalized by opacity. Default is "RGB".
        sparse_grad (Experimental): If true, the gradients for {means, quats, scales}
            will be stored in a COO sparse layout. This can be helpful for saving memory.
            Default is False.
        absgrad: If true, the absolute gradients of the projected 2D means
            will be computed during the backward pass, which could be accessed by
            `meta["means2d"].absgrad`. Default is False.
        distloss: If true, use distortion regularization to get better geometry detail.
        depth_mode: The render depth mode. Choose from expected depth and median depth.

    Returns:
        A tuple:

        **render_colors**: The rendered colors. [..., C, height, width, X].
        X depends on the `render_mode` and input `colors`. If `render_mode` is "RGB",
        X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is
        "RGB+D" or "RGB+ED", X is D+1.

        **render_alphas**: The rendered alphas. [..., C, height, width, 1].

        **render_normals**: The rendered normals. [..., C, height, width, 3].

        **surf_normals**: Surface normals from depth. [..., C, height, width, 3]

        **render_distort**: The rendered distortions. [..., C, height, width, 1].
        L1 version, different from the L2 version in the 2DGS paper.

        **render_median**: The rendered median depth. [..., C, height, width, 1].

        **meta**: A dictionary of intermediate results of the rasterization.

    Examples:

    .. code-block:: python

        >>> # define Gaussians
        >>> means = torch.randn((100, 3), device=device)
        >>> quats = torch.randn((100, 4), device=device)
        >>> scales = torch.rand((100, 3), device=device) * 0.1
        >>> colors = torch.rand((100, 3), device=device)
        >>> opacities = torch.rand((100,), device=device)
        >>> # define cameras
        >>> viewmats = torch.eye(4, device=device)[None, :, :]
        >>> Ks = torch.tensor([
        >>>    [300., 0., 150.], [0., 300., 100.], [0., 0., 1.]], device=device)[None, :, :]
        >>> width, height = 300, 200
        >>> # render
        >>> colors, alphas, normals, surf_normals, distort, median_depth, meta = rasterization_2dgs(
        >>>    means, quats, scales, opacities, colors, viewmats, Ks, width, height
        >>> )
        >>> print (colors.shape, alphas.shape)
        torch.Size([1, 200, 300, 3]) torch.Size([1, 200, 300, 1])
        >>> print (normals.shape, surf_normals.shape)
        torch.Size([1, 200, 300, 3]) torch.Size([1, 200, 300, 3])
        >>> print (distort.shape, median_depth.shape)
        torch.Size([1, 200, 300, 1]) torch.Size([1, 200, 300, 1])
        >>> print (meta.keys())
        dict_keys(['camera_ids', 'gaussian_ids', 'radii', 'means2d', 'depths',
        'ray_transforms', 'opacities', 'normals', 'tile_width', 'tile_height',
        'tiles_per_gauss', 'isect_ids', 'flatten_ids', 'isect_offsets', 'width',
        'height', 'tile_size', 'n_cameras', 'render_distort', 'gradient_2dgs'])
    """
    batch_dims = means.shape[:-2]
    num_batch_dims = len(batch_dims)
    B = math.prod(batch_dims)
    N = means.shape[-2]
    C = viewmats.shape[-3]
    I = B * C
    device = means.device
    channels = colors.shape[-1]

    assert means.shape == batch_dims + (N, 3), means.shape
    assert quats.shape == batch_dims + (N, 4), quats.shape
    assert scales.shape == batch_dims + (N, 3), scales.shape
    assert opacities.shape == batch_dims + (N,), opacities.shape
    assert viewmats.shape == batch_dims + (C, 4, 4), viewmats.shape
    assert Ks.shape == batch_dims + (C, 3, 3), Ks.shape
    if distloss:
        assert render_mode_has_depth(
            render_mode
        ), f"distloss requires depth rendering, but render mode is {render_mode}"

    if sh_degree is None:
        # treat colors as post-activation values, should be in shape [..., N, D] or [..., C, N, D]
        assert (
            colors.dim() == num_batch_dims + 2
            and colors.shape[:-1] == batch_dims + (N,)
        ) or (
            colors.dim() == num_batch_dims + 3
            and colors.shape[:-1] == batch_dims + (C, N)
        ), colors.shape
    else:
        # treat colors as SH coefficients, should be in shape [..., N, K, 3] or [..., C, N, K, 3]
        # Allowing for activating partial SH bands
        assert (
            colors.dim() == num_batch_dims + 3
            and colors.shape[:-2] == batch_dims + (N,)
            and colors.shape[-1] == 3
        ) or (
            colors.dim() == num_batch_dims + 4
            and colors.shape[:-2] == batch_dims + (C, N)
            and colors.shape[-1] == 3
        ), colors.shape
        assert (sh_degree + 1) ** 2 <= colors.shape[-2], colors.shape

    # Compute Ray-Splat intersection transformation.
    proj_results = fully_fused_projection_2dgs(
        means,
        quats,
        scales,
        viewmats,
        Ks,
        width,
        height,
        eps2d,
        near_plane,
        far_plane,
        radius_clip,
        packed,
        sparse_grad,
    )

    if packed:
        (
            batch_ids,
            camera_ids,
            gaussian_ids,
            radii,
            means2d,
            depths,
            ray_transforms,
            normals,
        ) = proj_results
        opacities = opacities.view(B, N)[batch_ids, gaussian_ids]
        image_ids = batch_ids * C + camera_ids
    else:
        radii, means2d, depths, ray_transforms, normals = proj_results
        opacities = torch.broadcast_to(
            opacities[..., None, :], batch_dims + (C, N)
        )  # [..., C, N]
        camera_ids, gaussian_ids = None, None
        image_ids = None

    densify = torch.zeros_like(
        means2d, dtype=means.dtype, requires_grad=True, device=device
    )

    # Identify intersecting tiles
    tile_width = math.ceil(width / float(tile_size))
    tile_height = math.ceil(height / float(tile_size))
    tiles_per_gauss, isect_ids, flatten_ids = isect_tiles(
        means2d,
        radii,
        depths,
        tile_size,
        tile_width,
        tile_height,
        packed=packed,
        n_images=I,
        image_ids=image_ids,
        gaussian_ids=gaussian_ids,
    )
    isect_offsets = isect_offset_encode(isect_ids, I, tile_width, tile_height)
    isect_offsets = isect_offsets.reshape(batch_dims + (C, tile_height, tile_width))

    # TODO: SH also supports N-D.
    # Compute the per-view colors
    # if not (
    #     colors.dim() == num_batch_dims + 3 and sh_degree is None
    # ):  # silently support [..., C, N, D] color.
    #     colors = (
    #         colors.view(B, N, -1)[batch_ids, gaussian_ids]
    #         if packed
    #         else colors[..., None, :, :].expand((-1,) * num_batch_dims + (C, -1, -1))
    #     )  # [nnz, D] or [..., C, N, 3]
    # else:
    #     if packed:
    #         colors = colors.view(B, C, N, -1)[batch_ids, camera_ids, gaussian_ids, :]
    if sh_degree is not None:  # SH coefficients
        camtoworlds = torch.inverse(viewmats)
        if packed:
            dirs = means[..., gaussian_ids, :] - camtoworlds[..., camera_ids, :3, 3]
        else:
            dirs = means[..., None, :, :] - camtoworlds[..., None, :3, 3]
        if colors.dim() == num_batch_dims + 3:
            # Turn [..., N, K, 3] into [..., C, N, K, 3]
            shs = torch.broadcast_to(
                colors[..., None, :, :, :], batch_dims + (C, N, -1, 3)
            )  # [..., C, N, K, 3]
        else:
            # colors is already [..., C, N, K, 3]
            shs = colors
        colors = spherical_harmonics(
            sh_degree, dirs, shs, masks=(radii > 0).all(dim=-1)
        )  # [nnz, D] or [..., C, N, 3]
        # make it apples-to-apples with Inria's CUDA backend.
        colors = torch.clamp_min(colors + 0.5, 0.0)

    # Rasterize to pixels
    if render_mode_has_depth_channel(render_mode) and render_mode_has_color(
        render_mode
    ):
        colors = torch.cat((colors, depths[..., None]), dim=-1)
        if backgrounds is not None:
            backgrounds = torch.cat(
                (backgrounds, torch.zeros_like(backgrounds[..., :1])), dim=-1
            )
    elif render_mode_has_only_depth_channel(render_mode):
        colors = depths[..., None]
    else:  # RGB
        pass
    (
        render_colors,
        render_alphas,
        render_normals,
        render_distort,
        render_median,
    ) = rasterize_to_pixels_2dgs(
        means2d,
        ray_transforms,
        colors,
        opacities,
        normals,
        densify,
        width,
        height,
        tile_size,
        isect_offsets,
        flatten_ids,
        backgrounds=backgrounds,
        packed=packed,
        absgrad=absgrad,
        distloss=distloss,
    )
    render_normals_from_depth = None
    if render_mode_has_expected_depth(render_mode):
        # normalize the accumulated depth to get the expected depth
        render_colors = torch.cat(
            [
                render_colors[..., :-1],
                render_colors[..., -1:] / render_alphas.clamp(min=1e-10),
            ],
            dim=-1,
        )
    if render_mode_has_depth(render_mode) and render_mode_has_color(render_mode):
        # render_depths = render_colors[..., -1:]
        if depth_mode == "expected":
            depth_for_normal = render_colors[..., -1:]
        elif depth_mode == "median":
            depth_for_normal = render_median
        render_normals_from_depth = depth_to_normal(
            depth_for_normal, torch.linalg.inv(viewmats), Ks
        ).squeeze(0)

    meta = {
        "camera_ids": camera_ids,
        "gaussian_ids": gaussian_ids,
        "radii": radii,
        "means2d": means2d,
        "depths": depths,
        "ray_transforms": ray_transforms,
        "opacities": opacities,
        "normals": normals,
        "tile_width": tile_width,
        "tile_height": tile_height,
        "tiles_per_gauss": tiles_per_gauss,
        "isect_ids": isect_ids,
        "flatten_ids": flatten_ids,
        "isect_offsets": isect_offsets,
        "width": width,
        "height": height,
        "tile_size": tile_size,
        "n_cameras": C,
        "render_distort": render_distort,
        "gradient_2dgs": densify,  # This holds the gradient used for densification for 2dgs
    }

    render_normals = torch.einsum(
        "...ij,...hwj->...hwi", torch.linalg.inv(viewmats)[..., :3, :3], render_normals
    )

    return (
        render_colors,
        render_alphas,
        render_normals,
        render_normals_from_depth,
        render_distort,
        render_median,
        meta,
    )
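
# Illustrative sketch (hypothetical helper, not part of the API): the
# camera-to-world rotation applied to rendered normals at the end of
# rasterization_2dgs above. Rasterized normals live in camera space;
# multiplying by the rotation block of the inverted viewmat expresses them
# in world space, exactly as the torch.einsum call above does.
def _example_normals_to_world(render_normals: Tensor, viewmats: Tensor) -> Tensor:
    # render_normals: [C, H, W, 3] camera-space normals
    # viewmats: [C, 4, 4] world-to-camera transforms
    rot_c2w = torch.linalg.inv(viewmats)[..., :3, :3]  # [C, 3, 3]
    return torch.einsum("...ij,...hwj->...hwi", rot_c2w, render_normals)
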

def rasterization_2dgs_inria_wrapper(
    means: Tensor,  # [N, 3]
    quats: Tensor,  # [N, 4]
    scales: Tensor,  # [N, 3]
    opacities: Tensor,  # [N]
    colors: Tensor,  # [N, D] or [N, K, 3]
    viewmats: Tensor,  # [C, 4, 4]
    Ks: Tensor,  # [C, 3, 3]
    width: int,
    height: int,
    near_plane: float = 0.01,
    far_plane: float = 100.0,
    eps2d: float = 0.3,
    sh_degree: Optional[int] = None,
    backgrounds: Optional[Tensor] = None,
    depth_ratio: int = 0,
    **kwargs,
) -> Tuple[Tuple, Dict]:
    """Wrapper for 2DGS's rasterization backend, which is based on Inria's backend.

    Install the 2DGS rasterization backend from
    https://github.com/hbb1/diff-surfel-rasterization

    Credit to Jeffrey Hu https://github.com/jefequien
    """
    from diff_surfel_rasterization import (
        GaussianRasterizationSettings,
        GaussianRasterizer,
    )

    assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3"
    C = len(viewmats)
    device = means.device
    channels = colors.shape[-1]

    # rasterization from inria does not do normalization internally
    quats = F.normalize(quats, dim=-1)  # [N, 4]
    scales = scales[:, :2]  # [N, 2]

    render_colors = []
    for cid in range(C):
        FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item()))
        FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item()))
        tanfovx = math.tan(FoVx * 0.5)
        tanfovy = math.tan(FoVy * 0.5)

        world_view_transform = viewmats[cid].transpose(0, 1)
        projection_matrix = get_projection_matrix(
            znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device
        ).transpose(0, 1)
        full_proj_transform = (
            world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))
        ).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]

        background = (
            backgrounds[cid]
            if backgrounds is not None
            else torch.zeros(3, device=device)
        )

        raster_settings = GaussianRasterizationSettings(
            image_height=height,
            image_width=width,
            tanfovx=tanfovx,
            tanfovy=tanfovy,
            bg=background,
            scale_modifier=1.0,
            viewmatrix=world_view_transform,
            projmatrix=full_proj_transform,
            sh_degree=0 if sh_degree is None else sh_degree,
            campos=camera_center,
            prefiltered=False,
            debug=False,
        )

        rasterizer = GaussianRasterizer(raster_settings=raster_settings)

        means2D = torch.zeros_like(means, requires_grad=True, device=device)

        render_colors_ = []
        for i in range(0, channels, 3):
            _colors = colors[..., i : i + 3]
            if _colors.shape[-1] < 3:
                pad = torch.zeros(
                    _colors.shape[0], 3 - _colors.shape[-1], device=device
                )
                _colors = torch.cat([_colors, pad], dim=-1)
            _render_colors_, radii, allmap = rasterizer(
                means3D=means,
                means2D=means2D,
                shs=_colors if colors.dim() == 3 else None,
                colors_precomp=_colors if colors.dim() == 2 else None,
                opacities=opacities[:, None],
                scales=scales,
                rotations=quats,
                cov3D_precomp=None,
            )
            if _colors.shape[-1] < 3:
                _render_colors_ = _render_colors_[:, :, : _colors.shape[-1]]
            render_colors_.append(_render_colors_)
        render_colors_ = torch.cat(render_colors_, dim=-1)
        render_colors_ = render_colors_.permute(1, 2, 0)  # [H, W, 3]
        render_colors.append(render_colors_)
    render_colors = torch.stack(render_colors, dim=0)

    # additional maps
    allmap = allmap.permute(1, 2, 0).unsqueeze(0)  # [1, H, W, C]
    render_depth_expected = allmap[..., 0:1]
    render_alphas = allmap[..., 1:2]
    render_normal = allmap[..., 2:5]
    render_depth_median = allmap[..., 5:6]
    render_dist = allmap[..., 6:7]

    render_normal = render_normal @ (world_view_transform[:3, :3].T)
    render_depth_expected = render_depth_expected / render_alphas
    render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0)
    render_depth_median = torch.nan_to_num(render_depth_median, 0, 0)

    # render_depth is either the median or the expected depth, selected by
    # setting depth_ratio to 1 or 0:
    # for bounded scenes, use median depth, i.e., depth_ratio = 1;
    # for unbounded scenes, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing.
    render_depth = (
        render_depth_expected * (1 - depth_ratio) + depth_ratio * render_depth_median
    )

    normals_surf = depth_to_normal(render_depth, torch.linalg.inv(viewmats), Ks)
    normals_surf = normals_surf * (render_alphas).detach()

    render_colors = torch.cat([render_colors, render_depth], dim=-1)

    meta = {
        "normals_rend": render_normal,
        "normals_surf": normals_surf,
        "render_distloss": render_dist,
        "means2d": means2D,
        "width": width,
        "height": height,
        "radii": radii.unsqueeze(0),
        "n_cameras": C,
        "gaussian_ids": None,
    }
    return (render_colors, render_alphas), meta
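
# Illustrative sketch (toy tensors, not part of the API): the depth_ratio
# blend used in the wrapper above. depth_ratio = 0 selects the expected depth
# (recommended for unbounded scenes), depth_ratio = 1 selects the median depth
# (recommended for bounded scenes), and intermediate values linearly
# interpolate between the two.
def _example_depth_ratio_blend(depth_ratio: float = 0.0) -> Tensor:
    # stand-ins for the expected and median depth maps, each [1, H, W, 1]
    render_depth_expected = torch.rand(1, 200, 300, 1)
    render_depth_median = torch.rand(1, 200, 300, 1)
    return (
        render_depth_expected * (1 - depth_ratio)
        + depth_ratio * render_depth_median
    )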