Source code for gsplat.cuda._wrapper

from typing import Callable, Optional, Tuple, Any
import warnings
from typing_extensions import Literal

import torch
from torch import Tensor


def _make_lazy_cuda_func(name: str) -> Callable:
    def call_cuda(*args, **kwargs):
        # pylint: disable=import-outside-toplevel
        from ._backend import _C

        return getattr(_C, name)(*args, **kwargs)

    return call_cuda


def selective_adam_update(
    param: Tensor,
    param_grad: Tensor,
    exp_avg: Tensor,
    exp_avg_sq: Tensor,
    tiles_touched: Tensor,
    lr: float,
    b1: float,
    b2: float,
    eps: float,
    N: int,
    M: int,
) -> None:
    _make_lazy_cuda_func("selective_adam_update")(
        param, param_grad, exp_avg, exp_avg_sq, tiles_touched, lr, b1, b2, eps, N, M
    )


def _make_lazy_cuda_obj(name: str) -> Any:
    # pylint: disable=import-outside-toplevel
    from ._backend import _C

    obj = _C
    for name_split in name.split("."):
        obj = getattr(_C, name_split)
    return obj


[docs] def spherical_harmonics( degrees_to_use: int, dirs: Tensor, # [..., 3] coeffs: Tensor, # [..., K, 3] masks: Optional[Tensor] = None, ) -> Tensor: """Computes spherical harmonics. Args: degrees_to_use: The degree to be used. dirs: Directions. [..., 3] coeffs: Coefficients. [..., K, 3] masks: Optional boolen masks to skip some computation. [...,] Default: None. Returns: Spherical harmonics. [..., 3] """ assert (degrees_to_use + 1) ** 2 <= coeffs.shape[-2], coeffs.shape assert dirs.shape[:-1] == coeffs.shape[:-2], (dirs.shape, coeffs.shape) assert dirs.shape[-1] == 3, dirs.shape assert coeffs.shape[-1] == 3, coeffs.shape if masks is not None: assert masks.shape == dirs.shape[:-1], masks.shape masks = masks.contiguous() return _SphericalHarmonics.apply( degrees_to_use, dirs.contiguous(), coeffs.contiguous(), masks )
[docs] def quat_scale_to_covar_preci( quats: Tensor, # [N, 4], scales: Tensor, # [N, 3], compute_covar: bool = True, compute_preci: bool = True, triu: bool = False, ) -> Tuple[Optional[Tensor], Optional[Tensor]]: """Converts quaternions and scales to covariance and precision matrices. Args: quats: Quaternions (No need to be normalized). [N, 4] scales: Scales. [N, 3] compute_covar: Whether to compute covariance matrices. Default: True. If False, the returned covariance matrices will be None. compute_preci: Whether to compute precision matrices. Default: True. If False, the returned precision matrices will be None. triu: If True, the return matrices will be upper triangular. Default: False. Returns: A tuple: - **Covariance matrices**. If `triu` is True the returned shape is [N, 6], otherwise [N, 3, 3]. - **Precision matrices**. If `triu` is True the returned shape is [N, 6], otherwise [N, 3, 3]. """ assert quats.dim() == 2 and quats.size(1) == 4, quats.size() assert scales.dim() == 2 and scales.size(1) == 3, scales.size() quats = quats.contiguous() scales = scales.contiguous() covars, precis = _QuatScaleToCovarPreci.apply( quats, scales, compute_covar, compute_preci, triu ) return covars if compute_covar else None, precis if compute_preci else None
def persp_proj( means: Tensor, # [C, N, 3] covars: Tensor, # [C, N, 3, 3] Ks: Tensor, # [C, 3, 3] width: int, height: int, ) -> Tuple[Tensor, Tensor]: """Perspective projection on Gaussians. DEPRECATED: please use `proj` with `ortho=False` instead. Args: means: Gaussian means. [C, N, 3] covars: Gaussian covariances. [C, N, 3, 3] Ks: Camera intrinsics. [C, 3, 3] width: Image width. height: Image height. Returns: A tuple: - **Projected means**. [C, N, 2] - **Projected covariances**. [C, N, 2, 2] """ warnings.warn( "persp_proj is deprecated and will be removed in a future release. " "Use proj with ortho=False instead.", DeprecationWarning, ) return proj(means, covars, Ks, width, height, ortho=False)
[docs] def proj( means: Tensor, # [C, N, 3] covars: Tensor, # [C, N, 3, 3] Ks: Tensor, # [C, 3, 3] width: int, height: int, camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", ) -> Tuple[Tensor, Tensor]: """Projection of Gaussians (perspective or orthographic). Args: means: Gaussian means. [C, N, 3] covars: Gaussian covariances. [C, N, 3, 3] Ks: Camera intrinsics. [C, 3, 3] width: Image width. height: Image height. Returns: A tuple: - **Projected means**. [C, N, 2] - **Projected covariances**. [C, N, 2, 2] """ C, N, _ = means.shape assert means.shape == (C, N, 3), means.size() assert covars.shape == (C, N, 3, 3), covars.size() assert Ks.shape == (C, 3, 3), Ks.size() means = means.contiguous() covars = covars.contiguous() Ks = Ks.contiguous() return _Proj.apply(means, covars, Ks, width, height, camera_model)
[docs] def world_to_cam( means: Tensor, # [N, 3] covars: Tensor, # [N, 3, 3] viewmats: Tensor, # [C, 4, 4] ) -> Tuple[Tensor, Tensor]: """Transforms Gaussians from world to camera coordinate system. Args: means: Gaussian means. [N, 3] covars: Gaussian covariances. [N, 3, 3] viewmats: World-to-camera transformation matrices. [C, 4, 4] Returns: A tuple: - **Gaussian means in camera coordinate system**. [C, N, 3] - **Gaussian covariances in camera coordinate system**. [C, N, 3, 3] """ C = viewmats.size(0) N = means.size(0) assert means.size() == (N, 3), means.size() assert covars.size() == (N, 3, 3), covars.size() assert viewmats.size() == (C, 4, 4), viewmats.size() means = means.contiguous() covars = covars.contiguous() viewmats = viewmats.contiguous() return _WorldToCam.apply(means, covars, viewmats)
[docs] def fully_fused_projection( means: Tensor, # [N, 3] covars: Optional[Tensor], # [N, 6] or None quats: Optional[Tensor], # [N, 4] or None scales: Optional[Tensor], # [N, 3] or None viewmats: Tensor, # [C, 4, 4] Ks: Tensor, # [C, 3, 3] width: int, height: int, eps2d: float = 0.3, near_plane: float = 0.01, far_plane: float = 1e10, radius_clip: float = 0.0, packed: bool = False, sparse_grad: bool = False, calc_compensations: bool = False, camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Projects Gaussians to 2D. This function fuse the process of computing covariances (:func:`quat_scale_to_covar_preci()`), transforming to camera space (:func:`world_to_cam()`), and projection (:func:`proj()`). .. note:: During projection, we ignore the Gaussians that are outside of the camera frustum. So not all the elements in the output tensors are valid. The output `radii` could serve as an indicator, in which zero radii means the corresponding elements are invalid in the output tensors and will be ignored in the next rasterization process. If `packed=True`, the output tensors will be packed into a flattened tensor, in which all elements are valid. In this case, a `camera_ids` tensor and `gaussian_ids` tensor will be returned to indicate the row (camera) and column (Gaussian) indices of the packed flattened tensor, which is essentially following the COO sparse tensor format. .. note:: This functions supports projecting Gaussians with either covariances or {quaternions, scales}, which will be converted to covariances internally in a fused CUDA kernel. Either `covars` or {`quats`, `scales`} should be provided. Args: means: Gaussian means. [N, 3] covars: Gaussian covariances (flattened upper triangle). [N, 6] Optional. quats: Quaternions (No need to be normalized). [N, 4] Optional. scales: Scales. [N, 3] Optional. viewmats: Camera-to-world matrices. [C, 4, 4] Ks: Camera intrinsics. [C, 3, 3] width: Image width. height: Image height. eps2d: A epsilon added to the 2D covariance for numerical stability. Default: 0.3. near_plane: Near plane distance. Default: 0.01. far_plane: Far plane distance. Default: 1e10. radius_clip: Gaussians with projected radii smaller than this value will be ignored. Default: 0.0. packed: If True, the output tensors will be packed into a flattened tensor. Default: False. sparse_grad: This is only effective when `packed` is True. If True, during backward the gradients of {`means`, `covars`, `quats`, `scales`} will be a sparse Tensor in COO layout. Default: False. calc_compensations: If True, a view-dependent opacity compensation factor will be computed, which is useful for anti-aliasing. Default: False. Returns: A tuple: If `packed` is True: - **camera_ids**. The row indices of the projected Gaussians. Int32 tensor of shape [nnz]. - **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz]. - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [nnz]. - **means**. Projected Gaussian means in 2D. [nnz, 2] - **depths**. The z-depth of the projected Gaussians. [nnz] - **conics**. Inverse of the projected covariances. Return the flattend upper triangle with [nnz, 3] - **compensations**. The view-dependent opacity compensation factor. [nnz] If `packed` is False: - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [C, N]. - **means**. Projected Gaussian means in 2D. [C, N, 2] - **depths**. The z-depth of the projected Gaussians. [C, N] - **conics**. Inverse of the projected covariances. Return the flattend upper triangle with [C, N, 3] - **compensations**. The view-dependent opacity compensation factor. [C, N] """ C = viewmats.size(0) N = means.size(0) assert means.size() == (N, 3), means.size() assert viewmats.size() == (C, 4, 4), viewmats.size() assert Ks.size() == (C, 3, 3), Ks.size() means = means.contiguous() if covars is not None: assert covars.size() == (N, 6), covars.size() covars = covars.contiguous() else: assert quats is not None, "covars or quats is required" assert scales is not None, "covars or scales is required" assert quats.size() == (N, 4), quats.size() assert scales.size() == (N, 3), scales.size() quats = quats.contiguous() scales = scales.contiguous() if sparse_grad: assert packed, "sparse_grad is only supported when packed is True" viewmats = viewmats.contiguous() Ks = Ks.contiguous() if packed: return _FullyFusedProjectionPacked.apply( means, covars, quats, scales, viewmats, Ks, width, height, eps2d, near_plane, far_plane, radius_clip, sparse_grad, calc_compensations, camera_model, ) else: return _FullyFusedProjection.apply( means, covars, quats, scales, viewmats, Ks, width, height, eps2d, near_plane, far_plane, radius_clip, calc_compensations, camera_model, )
[docs] @torch.no_grad() def isect_tiles( means2d: Tensor, # [C, N, 2] or [nnz, 2] radii: Tensor, # [C, N] or [nnz] depths: Tensor, # [C, N] or [nnz] tile_size: int, tile_width: int, tile_height: int, sort: bool = True, packed: bool = False, n_cameras: Optional[int] = None, camera_ids: Optional[Tensor] = None, gaussian_ids: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor]: """Maps projected Gaussians to intersecting tiles. Args: means2d: Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True. radii: Maximum radii of the projected Gaussians. [C, N] if packed is False, [nnz] if packed is True. depths: Z-depth of the projected Gaussians. [C, N] if packed is False, [nnz] if packed is True. tile_size: Tile size. tile_width: Tile width. tile_height: Tile height. sort: If True, the returned intersections will be sorted by the intersection ids. Default: True. packed: If True, the input tensors are packed. Default: False. n_cameras: Number of cameras. Required if packed is True. camera_ids: The row indices of the projected Gaussians. Required if packed is True. gaussian_ids: The column indices of the projected Gaussians. Required if packed is True. Returns: A tuple: - **Tiles per Gaussian**. The number of tiles intersected by each Gaussian. Int32 [C, N] if packed is False, Int32 [nnz] if packed is True. - **Intersection ids**. Each id is an 64-bit integer with the following information: camera_id (Xc bits) | tile_id (Xt bits) | depth (32 bits). Xc and Xt are the maximum number of bits required to represent the camera and tile ids, respectively. Int64 [n_isects] - **Flatten ids**. The global flatten indices in [C * N] or [nnz] (packed). [n_isects] """ if packed: nnz = means2d.size(0) assert means2d.shape == (nnz, 2), means2d.size() assert radii.shape == (nnz,), radii.size() assert depths.shape == (nnz,), depths.size() assert camera_ids is not None, "camera_ids is required if packed is True" assert gaussian_ids is not None, "gaussian_ids is required if packed is True" assert n_cameras is not None, "n_cameras is required if packed is True" camera_ids = camera_ids.contiguous() gaussian_ids = gaussian_ids.contiguous() C = n_cameras else: C, N, _ = means2d.shape assert means2d.shape == (C, N, 2), means2d.size() assert radii.shape == (C, N), radii.size() assert depths.shape == (C, N), depths.size() tiles_per_gauss, isect_ids, flatten_ids = _make_lazy_cuda_func("isect_tiles")( means2d.contiguous(), radii.contiguous(), depths.contiguous(), camera_ids, gaussian_ids, C, tile_size, tile_width, tile_height, sort, True, # DoubleBuffer: memory efficient radixsort ) return tiles_per_gauss, isect_ids, flatten_ids
[docs] @torch.no_grad() def isect_offset_encode( isect_ids: Tensor, n_cameras: int, tile_width: int, tile_height: int ) -> Tensor: """Encodes intersection ids to offsets. Args: isect_ids: Intersection ids. [n_isects] n_cameras: Number of cameras. tile_width: Tile width. tile_height: Tile height. Returns: Offsets. [C, tile_height, tile_width] """ return _make_lazy_cuda_func("isect_offset_encode")( isect_ids.contiguous(), n_cameras, tile_width, tile_height )
[docs] def rasterize_to_pixels( means2d: Tensor, # [C, N, 2] or [nnz, 2] conics: Tensor, # [C, N, 3] or [nnz, 3] colors: Tensor, # [C, N, channels] or [nnz, channels] opacities: Tensor, # [C, N] or [nnz] image_width: int, image_height: int, tile_size: int, isect_offsets: Tensor, # [C, tile_height, tile_width] flatten_ids: Tensor, # [n_isects] backgrounds: Optional[Tensor] = None, # [C, channels] masks: Optional[Tensor] = None, # [C, tile_height, tile_width] packed: bool = False, absgrad: bool = False, ) -> Tuple[Tensor, Tensor]: """Rasterizes Gaussians to pixels. Args: means2d: Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True. conics: Inverse of the projected covariances with only upper triangle values. [C, N, 3] if packed is False, [nnz, 3] if packed is True. colors: Gaussian colors or ND features. [C, N, channels] if packed is False, [nnz, channels] if packed is True. opacities: Gaussian opacities that support per-view values. [C, N] if packed is False, [nnz] if packed is True. image_width: Image width. image_height: Image height. tile_size: Tile size. isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width] flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects] backgrounds: Background colors. [C, channels]. Default: None. masks: Optional tile mask to skip rendering GS to masked tiles. [C, tile_height, tile_width]. Default: None. packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False. absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False. Returns: A tuple: - **Rendered colors**. [C, image_height, image_width, channels] - **Rendered alphas**. [C, image_height, image_width, 1] """ C = isect_offsets.size(0) device = means2d.device if packed: nnz = means2d.size(0) assert means2d.shape == (nnz, 2), means2d.shape assert conics.shape == (nnz, 3), conics.shape assert colors.shape[0] == nnz, colors.shape assert opacities.shape == (nnz,), opacities.shape else: N = means2d.size(1) assert means2d.shape == (C, N, 2), means2d.shape assert conics.shape == (C, N, 3), conics.shape assert colors.shape[:2] == (C, N), colors.shape assert opacities.shape == (C, N), opacities.shape if backgrounds is not None: assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape backgrounds = backgrounds.contiguous() if masks is not None: assert masks.shape == isect_offsets.shape, masks.shape masks = masks.contiguous() # Pad the channels to the nearest supported number if necessary channels = colors.shape[-1] if channels > 513 or channels == 0: # TODO: maybe worth to support zero channels? raise ValueError(f"Unsupported number of color channels: {channels}") if channels not in ( 1, 2, 3, 4, 5, 8, 9, 16, 17, 32, 33, 64, 65, 128, 129, 256, 257, 512, 513, ): padded_channels = (1 << (channels - 1).bit_length()) - channels colors = torch.cat( [ colors, torch.zeros(*colors.shape[:-1], padded_channels, device=device), ], dim=-1, ) if backgrounds is not None: backgrounds = torch.cat( [ backgrounds, torch.zeros( *backgrounds.shape[:-1], padded_channels, device=device ), ], dim=-1, ) else: padded_channels = 0 tile_height, tile_width = isect_offsets.shape[1:3] assert ( tile_height * tile_size >= image_height ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}" assert ( tile_width * tile_size >= image_width ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" render_colors, render_alphas = _RasterizeToPixels.apply( means2d.contiguous(), conics.contiguous(), colors.contiguous(), opacities.contiguous(), backgrounds, masks, image_width, image_height, tile_size, isect_offsets.contiguous(), flatten_ids.contiguous(), absgrad, ) if padded_channels > 0: render_colors = render_colors[..., :-padded_channels] return render_colors, render_alphas
[docs] @torch.no_grad() def rasterize_to_indices_in_range( range_start: int, range_end: int, transmittances: Tensor, # [C, image_height, image_width] means2d: Tensor, # [C, N, 2] conics: Tensor, # [C, N, 3] opacities: Tensor, # [C, N] image_width: int, image_height: int, tile_size: int, isect_offsets: Tensor, # [C, tile_height, tile_width] flatten_ids: Tensor, # [n_isects] ) -> Tuple[Tensor, Tensor, Tensor]: """Rasterizes a batch of Gaussians to images but only returns the indices. .. note:: This function supports iterative rasterization, in which each call of this function will rasterize a batch of Gaussians from near to far, defined by `[range_start, range_end)`. If a one-step full rasterization is desired, set `range_start` to 0 and `range_end` to a really large number, e.g, 1e10. Args: range_start: The start batch of Gaussians to be rasterized (inclusive). range_end: The end batch of Gaussians to be rasterized (exclusive). transmittances: Currently transmittances. [C, image_height, image_width] means2d: Projected Gaussian means. [C, N, 2] conics: Inverse of the projected covariances with only upper triangle values. [C, N, 3] opacities: Gaussian opacities that support per-view values. [C, N] image_width: Image width. image_height: Image height. tile_size: Tile size. isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width] flatten_ids: The global flatten indices in [C * N] from `isect_tiles()`. [n_isects] Returns: A tuple: - **Gaussian ids**. Gaussian ids for the pixel intersection. A flattened list of shape [M]. - **Pixel ids**. pixel indices (row-major). A flattened list of shape [M]. - **Camera ids**. Camera indices. A flattened list of shape [M]. """ C, N, _ = means2d.shape assert conics.shape == (C, N, 3), conics.shape assert opacities.shape == (C, N), opacities.shape assert isect_offsets.shape[0] == C, isect_offsets.shape tile_height, tile_width = isect_offsets.shape[1:3] assert ( tile_height * tile_size >= image_height ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}" assert ( tile_width * tile_size >= image_width ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" out_gauss_ids, out_indices = _make_lazy_cuda_func("rasterize_to_indices_in_range")( range_start, range_end, transmittances.contiguous(), means2d.contiguous(), conics.contiguous(), opacities.contiguous(), image_width, image_height, tile_size, isect_offsets.contiguous(), flatten_ids.contiguous(), ) out_pixel_ids = out_indices % (image_width * image_height) out_camera_ids = out_indices // (image_width * image_height) return out_gauss_ids, out_pixel_ids, out_camera_ids
class _QuatScaleToCovarPreci(torch.autograd.Function): """Converts quaternions and scales to covariance and precision matrices.""" @staticmethod def forward( ctx, quats: Tensor, # [N, 4], scales: Tensor, # [N, 3], compute_covar: bool = True, compute_preci: bool = True, triu: bool = False, ) -> Tuple[Tensor, Tensor]: covars, precis = _make_lazy_cuda_func("quat_scale_to_covar_preci_fwd")( quats, scales, compute_covar, compute_preci, triu ) ctx.save_for_backward(quats, scales) ctx.compute_covar = compute_covar ctx.compute_preci = compute_preci ctx.triu = triu return covars, precis @staticmethod def backward(ctx, v_covars: Tensor, v_precis: Tensor): quats, scales = ctx.saved_tensors compute_covar = ctx.compute_covar compute_preci = ctx.compute_preci triu = ctx.triu if compute_covar and v_covars.is_sparse: v_covars = v_covars.to_dense() if compute_preci and v_precis.is_sparse: v_precis = v_precis.to_dense() v_quats, v_scales = _make_lazy_cuda_func("quat_scale_to_covar_preci_bwd")( quats, scales, v_covars.contiguous() if compute_covar else None, v_precis.contiguous() if compute_preci else None, triu, ) return v_quats, v_scales, None, None, None class _Proj(torch.autograd.Function): """Perspective fully_fused_projection on Gaussians.""" @staticmethod def forward( ctx, means: Tensor, # [C, N, 3] covars: Tensor, # [C, N, 3, 3] Ks: Tensor, # [C, 3, 3] width: int, height: int, camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", ) -> Tuple[Tensor, Tensor]: camera_model_type = _make_lazy_cuda_obj( f"CameraModelType.{camera_model.upper()}" ) means2d, covars2d = _make_lazy_cuda_func("proj_fwd")( means, covars, Ks, width, height, camera_model_type, ) ctx.save_for_backward(means, covars, Ks) ctx.width = width ctx.height = height ctx.camera_model_type = camera_model_type return means2d, covars2d @staticmethod def backward(ctx, v_means2d: Tensor, v_covars2d: Tensor): means, covars, Ks = ctx.saved_tensors width = ctx.width height = ctx.height camera_model_type = ctx.camera_model_type v_means, v_covars = _make_lazy_cuda_func("proj_bwd")( means, covars, Ks, width, height, camera_model_type, v_means2d.contiguous(), v_covars2d.contiguous(), ) return v_means, v_covars, None, None, None, None class _WorldToCam(torch.autograd.Function): """Transforms Gaussians from world to camera space.""" @staticmethod def forward( ctx, means: Tensor, # [N, 3] covars: Tensor, # [N, 3, 3] viewmats: Tensor, # [C, 4, 4] ) -> Tuple[Tensor, Tensor]: means_c, covars_c = _make_lazy_cuda_func("world_to_cam_fwd")( means, covars, viewmats ) ctx.save_for_backward(means, covars, viewmats) return means_c, covars_c @staticmethod def backward(ctx, v_means_c: Tensor, v_covars_c: Tensor): means, covars, viewmats = ctx.saved_tensors v_means, v_covars, v_viewmats = _make_lazy_cuda_func("world_to_cam_bwd")( means, covars, viewmats, v_means_c.contiguous(), v_covars_c.contiguous(), ctx.needs_input_grad[0], ctx.needs_input_grad[1], ctx.needs_input_grad[2], ) if not ctx.needs_input_grad[0]: v_means = None if not ctx.needs_input_grad[1]: v_covars = None if not ctx.needs_input_grad[2]: v_viewmats = None return v_means, v_covars, v_viewmats class _FullyFusedProjection(torch.autograd.Function): """Projects Gaussians to 2D.""" @staticmethod def forward( ctx, means: Tensor, # [N, 3] covars: Tensor, # [N, 6] or None quats: Tensor, # [N, 4] or None scales: Tensor, # [N, 3] or None viewmats: Tensor, # [C, 4, 4] Ks: Tensor, # [C, 3, 3] width: int, height: int, eps2d: float, near_plane: float, far_plane: float, radius_clip: float, calc_compensations: bool, camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: camera_model_type = _make_lazy_cuda_obj( f"CameraModelType.{camera_model.upper()}" ) # "covars" and {"quats", "scales"} are mutually exclusive radii, means2d, depths, conics, compensations = _make_lazy_cuda_func( "fully_fused_projection_fwd" )( means, covars, quats, scales, viewmats, Ks, width, height, eps2d, near_plane, far_plane, radius_clip, calc_compensations, camera_model_type, ) if not calc_compensations: compensations = None ctx.save_for_backward( means, covars, quats, scales, viewmats, Ks, radii, conics, compensations ) ctx.width = width ctx.height = height ctx.eps2d = eps2d ctx.camera_model_type = camera_model_type return radii, means2d, depths, conics, compensations @staticmethod def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): ( means, covars, quats, scales, viewmats, Ks, radii, conics, compensations, ) = ctx.saved_tensors width = ctx.width height = ctx.height eps2d = ctx.eps2d camera_model_type = ctx.camera_model_type if v_compensations is not None: v_compensations = v_compensations.contiguous() v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func( "fully_fused_projection_bwd" )( means, covars, quats, scales, viewmats, Ks, width, height, eps2d, camera_model_type, radii, conics, compensations, v_means2d.contiguous(), v_depths.contiguous(), v_conics.contiguous(), v_compensations, ctx.needs_input_grad[4], # viewmats_requires_grad ) if not ctx.needs_input_grad[0]: v_means = None if not ctx.needs_input_grad[1]: v_covars = None if not ctx.needs_input_grad[2]: v_quats = None if not ctx.needs_input_grad[3]: v_scales = None if not ctx.needs_input_grad[4]: v_viewmats = None return ( v_means, v_covars, v_quats, v_scales, v_viewmats, None, None, None, None, None, None, None, None, None, None, ) class _RasterizeToPixels(torch.autograd.Function): """Rasterize gaussians""" @staticmethod def forward( ctx, means2d: Tensor, # [C, N, 2] conics: Tensor, # [C, N, 3] colors: Tensor, # [C, N, D] opacities: Tensor, # [C, N] backgrounds: Tensor, # [C, D], Optional masks: Tensor, # [C, tile_height, tile_width], Optional width: int, height: int, tile_size: int, isect_offsets: Tensor, # [C, tile_height, tile_width] flatten_ids: Tensor, # [n_isects] absgrad: bool, ) -> Tuple[Tensor, Tensor]: render_colors, render_alphas, last_ids = _make_lazy_cuda_func( "rasterize_to_pixels_fwd" )( means2d, conics, colors, opacities, backgrounds, masks, width, height, tile_size, isect_offsets, flatten_ids, ) ctx.save_for_backward( means2d, conics, colors, opacities, backgrounds, masks, isect_offsets, flatten_ids, render_alphas, last_ids, ) ctx.width = width ctx.height = height ctx.tile_size = tile_size ctx.absgrad = absgrad # double to float render_alphas = render_alphas.float() return render_colors, render_alphas @staticmethod def backward( ctx, v_render_colors: Tensor, # [C, H, W, 3] v_render_alphas: Tensor, # [C, H, W, 1] ): ( means2d, conics, colors, opacities, backgrounds, masks, isect_offsets, flatten_ids, render_alphas, last_ids, ) = ctx.saved_tensors width = ctx.width height = ctx.height tile_size = ctx.tile_size absgrad = ctx.absgrad ( v_means2d_abs, v_means2d, v_conics, v_colors, v_opacities, ) = _make_lazy_cuda_func("rasterize_to_pixels_bwd")( means2d, conics, colors, opacities, backgrounds, masks, width, height, tile_size, isect_offsets, flatten_ids, render_alphas, last_ids, v_render_colors.contiguous(), v_render_alphas.contiguous(), absgrad, ) if absgrad: means2d.absgrad = v_means2d_abs if ctx.needs_input_grad[4]: v_backgrounds = (v_render_colors * (1.0 - render_alphas).float()).sum( dim=(1, 2) ) else: v_backgrounds = None return ( v_means2d, v_conics, v_colors, v_opacities, v_backgrounds, None, None, None, None, None, None, None, ) class _FullyFusedProjectionPacked(torch.autograd.Function): """Projects Gaussians to 2D. Return packed tensors.""" @staticmethod def forward( ctx, means: Tensor, # [N, 3] covars: Tensor, # [N, 6] or None quats: Tensor, # [N, 4] or None scales: Tensor, # [N, 3] or None viewmats: Tensor, # [C, 4, 4] Ks: Tensor, # [C, 3, 3] width: int, height: int, eps2d: float, near_plane: float, far_plane: float, radius_clip: float, sparse_grad: bool, calc_compensations: bool, camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: camera_model_type = _make_lazy_cuda_obj( f"CameraModelType.{camera_model.upper()}" ) ( indptr, camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations, ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd")( means, covars, # optional quats, # optional scales, # optional viewmats, Ks, width, height, eps2d, near_plane, far_plane, radius_clip, calc_compensations, camera_model_type, ) if not calc_compensations: compensations = None ctx.save_for_backward( camera_ids, gaussian_ids, means, covars, quats, scales, viewmats, Ks, conics, compensations, ) ctx.width = width ctx.height = height ctx.eps2d = eps2d ctx.sparse_grad = sparse_grad ctx.camera_model_type = camera_model_type return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations @staticmethod def backward( ctx, v_camera_ids, v_gaussian_ids, v_radii, v_means2d, v_depths, v_conics, v_compensations, ): ( camera_ids, gaussian_ids, means, covars, quats, scales, viewmats, Ks, conics, compensations, ) = ctx.saved_tensors width = ctx.width height = ctx.height eps2d = ctx.eps2d sparse_grad = ctx.sparse_grad camera_model_type = ctx.camera_model_type if v_compensations is not None: v_compensations = v_compensations.contiguous() v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func( "fully_fused_projection_packed_bwd" )( means, covars, quats, scales, viewmats, Ks, width, height, eps2d, camera_model_type, camera_ids, gaussian_ids, conics, compensations, v_means2d.contiguous(), v_depths.contiguous(), v_conics.contiguous(), v_compensations, ctx.needs_input_grad[4], # viewmats_requires_grad sparse_grad, ) if not ctx.needs_input_grad[0]: v_means = None else: if sparse_grad: # TODO: gaussian_ids is duplicated so not ideal. # An idea is to directly set the attribute (e.g., .sparse_grad) of # the tensor but this requires the tensor to be leaf node only. And # a customized optimizer would be needed in this case. v_means = torch.sparse_coo_tensor( indices=gaussian_ids[None], # [1, nnz] values=v_means, # [nnz, 3] size=means.size(), # [N, 3] is_coalesced=len(viewmats) == 1, ) if not ctx.needs_input_grad[1]: v_covars = None else: if sparse_grad: v_covars = torch.sparse_coo_tensor( indices=gaussian_ids[None], # [1, nnz] values=v_covars, # [nnz, 6] size=covars.size(), # [N, 6] is_coalesced=len(viewmats) == 1, ) if not ctx.needs_input_grad[2]: v_quats = None else: if sparse_grad: v_quats = torch.sparse_coo_tensor( indices=gaussian_ids[None], # [1, nnz] values=v_quats, # [nnz, 4] size=quats.size(), # [N, 4] is_coalesced=len(viewmats) == 1, ) if not ctx.needs_input_grad[3]: v_scales = None else: if sparse_grad: v_scales = torch.sparse_coo_tensor( indices=gaussian_ids[None], # [1, nnz] values=v_scales, # [nnz, 3] size=scales.size(), # [N, 3] is_coalesced=len(viewmats) == 1, ) if not ctx.needs_input_grad[4]: v_viewmats = None return ( v_means, v_covars, v_quats, v_scales, v_viewmats, None, None, None, None, None, None, None, None, None, None, ) class _SphericalHarmonics(torch.autograd.Function): """Spherical Harmonics""" @staticmethod def forward( ctx, sh_degree: int, dirs: Tensor, coeffs: Tensor, masks: Tensor ) -> Tensor: colors = _make_lazy_cuda_func("compute_sh_fwd")(sh_degree, dirs, coeffs, masks) ctx.save_for_backward(dirs, coeffs, masks) ctx.sh_degree = sh_degree ctx.num_bases = coeffs.shape[-2] return colors @staticmethod def backward(ctx, v_colors: Tensor): dirs, coeffs, masks = ctx.saved_tensors sh_degree = ctx.sh_degree num_bases = ctx.num_bases compute_v_dirs = ctx.needs_input_grad[1] v_coeffs, v_dirs = _make_lazy_cuda_func("compute_sh_bwd")( num_bases, sh_degree, dirs, coeffs, masks, v_colors.contiguous(), compute_v_dirs, ) if not compute_v_dirs: v_dirs = None return None, v_dirs, v_coeffs, None ###### 2DGS ######
[docs] def fully_fused_projection_2dgs( means: Tensor, # [N, 3] quats: Tensor, # [N, 4] scales: Tensor, viewmats: Tensor, Ks: Tensor, width: int, height: int, eps2d: float = 0.3, near_plane: float = 0.01, far_plane: float = 1e10, radius_clip: float = 0.0, packed: bool = False, sparse_grad: bool = False, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Prepare Gaussians for rasterization This function prepares ray-splat intersection matrices, computes per splat bounding box and 2D means in image space. Args: means: Gaussian means. [N, 3] quats: Quaternions (No need to be normalized). [N, 4]. scales: Scales. [N, 3]. viewmats: Camera-to-world matrices. [C, 4, 4] Ks: Camera intrinsics. [C, 3, 3] width: Image width. height: Image height. near_plane: Near plane distance. Default: 0.01. far_plane: Far plane distance. Default: 200. radius_clip: Gaussians with projected radii smaller than this value will be ignored. Default: 0.0. packed: If True, the output tensors will be packed into a flattened tensor. Default: False. sparse_grad (Experimental): This is only effective when `packed` is True. If True, during backward the gradients of {`means`, `covars`, `quats`, `scales`} will be a sparse Tensor in COO layout. Default: False. Returns: A tuple: If `packed` is True: - **camera_ids**. The row indices of the projected Gaussians. Int32 tensor of shape [nnz]. - **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz]. - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [nnz]. - **means**. Projected Gaussian means in 2D. [nnz, 2] - **depths**. The z-depth of the projected Gaussians. [nnz] - **ray_transforms**. transformation matrices that transforms xy-planes in pixel spaces into splat coordinates (WH)^T in equation (9) in paper [nnz, 3, 3] - **normals**. The normals in camera spaces. [nnz, 3] If `packed` is False: - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [C, N]. - **means**. Projected Gaussian means in 2D. [C, N, 2] - **depths**. The z-depth of the projected Gaussians. [C, N] - **ray_transforms**. transformation matrices that transforms xy-planes in pixel spaces into splat coordinates. - **normals**. The normals in camera spaces. [C, N, 3] """ C = viewmats.size(0) N = means.size(0) assert means.size() == (N, 3), means.size() assert viewmats.size() == (C, 4, 4), viewmats.size() assert Ks.size() == (C, 3, 3), Ks.size() means = means.contiguous() assert quats is not None, "quats is required" assert scales is not None, "scales is required" assert quats.size() == (N, 4), quats.size() assert scales.size() == (N, 3), scales.size() quats = quats.contiguous() scales = scales.contiguous() if sparse_grad: assert packed, "sparse_grad is only supported when packed is True" viewmats = viewmats.contiguous() Ks = Ks.contiguous() if packed: return _FullyFusedProjectionPacked2DGS.apply( means, quats, scales, viewmats, Ks, width, height, near_plane, far_plane, radius_clip, sparse_grad, ) else: return _FullyFusedProjection2DGS.apply( means, quats, scales, viewmats, Ks, width, height, eps2d, near_plane, far_plane, radius_clip, )
class _FullyFusedProjection2DGS(torch.autograd.Function): """Projects Gaussians to 2D.""" @staticmethod def forward( ctx, means: Tensor, quats: Tensor, scales: Tensor, viewmats: Tensor, Ks: Tensor, width: int, height: int, eps2d: float, near_plane: float, far_plane: float, radius_clip: float, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: radii, means2d, depths, ray_transforms, normals = _make_lazy_cuda_func( "fully_fused_projection_fwd_2dgs" )( means, quats, scales, viewmats, Ks, width, height, eps2d, near_plane, far_plane, radius_clip, ) ctx.save_for_backward( means, quats, scales, viewmats, Ks, radii, ray_transforms, normals, ) ctx.width = width ctx.height = height ctx.eps2d = eps2d return radii, means2d, depths, ray_transforms, normals @staticmethod def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals): ( means, quats, scales, viewmats, Ks, radii, ray_transforms, normals, ) = ctx.saved_tensors width = ctx.width height = ctx.height eps2d = ctx.eps2d v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func( "fully_fused_projection_bwd_2dgs" )( means, quats, scales, viewmats, Ks, width, height, radii, ray_transforms, v_means2d.contiguous(), v_depths.contiguous(), v_normals.contiguous(), v_ray_transforms.contiguous(), ctx.needs_input_grad[3], # viewmats_requires_grad ) if not ctx.needs_input_grad[0]: v_means = None if not ctx.needs_input_grad[1]: v_quats = None if not ctx.needs_input_grad[2]: v_scales = None if not ctx.needs_input_grad[3]: v_viewmats = None return ( v_means, v_quats, v_scales, v_viewmats, None, None, None, None, None, None, None, None, ) class _FullyFusedProjectionPacked2DGS(torch.autograd.Function): """Projects Gaussians to 2D. Return packed tensors.""" @staticmethod def forward( ctx, means: Tensor, # [N, 3] quats: Tensor, # [N, 4] scales: Tensor, # [N, 3] viewmats: Tensor, # [C, 4, 4] Ks: Tensor, # [C, 3, 3] width: int, height: int, near_plane: float, far_plane: float, radius_clip: float, sparse_grad: bool, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ( indptr, camera_ids, gaussian_ids, radii, means2d, depths, ray_transforms, normals, ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd_2dgs")( means, quats, scales, viewmats, Ks, width, height, near_plane, far_plane, radius_clip, ) ctx.save_for_backward( camera_ids, gaussian_ids, means, quats, scales, viewmats, Ks, ray_transforms, ) ctx.width = width ctx.height = height ctx.sparse_grad = sparse_grad return camera_ids, gaussian_ids, radii, means2d, depths, ray_transforms, normals @staticmethod def backward( ctx, v_camera_ids, v_gaussian_ids, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals, ): ( camera_ids, gaussian_ids, means, quats, scales, viewmats, Ks, ray_transforms, ) = ctx.saved_tensors width = ctx.width height = ctx.height sparse_grad = ctx.sparse_grad v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func( "fully_fused_projection_packed_bwd_2dgs" )( means, quats, scales, viewmats, Ks, width, height, camera_ids, gaussian_ids, ray_transforms, v_means2d.contiguous(), v_depths.contiguous(), v_ray_transforms.contiguous(), v_normals.contiguous(), ctx.needs_input_grad[4], # viewmats_requires_grad sparse_grad, ) if not ctx.needs_input_grad[0]: v_means = None else: if sparse_grad: # TODO: gaussian_ids is duplicated so not ideal. # An idea is to directly set the attribute (e.g., .sparse_grad) of # the tensor but this requires the tensor to be leaf node only. And # a customized optimizer would be needed in this case. v_means = torch.sparse_coo_tensor( indices=gaussian_ids[None], # [1, nnz] values=v_means, # [nnz, 3] size=means.size(), # [N, 3] is_coalesced=len(viewmats) == 1, ) if not ctx.needs_input_grad[1]: v_quats = None else: if sparse_grad: v_quats = torch.sparse_coo_tensor( indices=gaussian_ids[None], # [1, nnz] values=v_quats, # [nnz, 4] size=quats.size(), # [N, 4] is_coalesced=len(viewmats) == 1, ) if not ctx.needs_input_grad[2]: v_scales = None else: if sparse_grad: v_scales = torch.sparse_coo_tensor( indices=gaussian_ids[None], # [1, nnz] values=v_scales, # [nnz, 3] size=scales.size(), # [N, 3] is_coalesced=len(viewmats) == 1, ) if not ctx.needs_input_grad[4]: v_viewmats = None return ( v_means, v_quats, v_scales, v_viewmats, None, None, None, None, None, None, None, None, None, )
[docs] def rasterize_to_pixels_2dgs( means2d: Tensor, ray_transforms: Tensor, colors: Tensor, opacities: Tensor, normals: Tensor, densify: Tensor, image_width: int, image_height: int, tile_size: int, isect_offsets: Tensor, flatten_ids: Tensor, backgrounds: Optional[Tensor] = None, masks: Optional[Tensor] = None, packed: bool = False, absgrad: bool = False, distloss: bool = False, ) -> Tuple[Tensor, Tensor]: """Rasterize Gaussians to pixels. Args: means2d: Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True. ray_transforms: transformation matrices that transforms xy-planes in pixel spaces into splat coordinates. [C, N, 3, 3] if packed is False, [nnz, channels] if packed is True. colors: Gaussian colors or ND features. [C, N, channels] if packed is False, [nnz, channels] if packed is True. opacities: Gaussian opacities that support per-view values. [C, N] if packed is False, [nnz] if packed is True. normals: The normals in camera space. [C, N, 3] if packed is False, [nnz, 3] if packed is True. densify: Dummy variable to keep track of gradient for densification. [C, N, 2] if packed, [nnz, 3] if packed is True. tile_size: Tile size. isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width] flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects] backgrounds: Background colors. [C, channels]. Default: None. masks: Optional tile mask to skip rendering GS to masked tiles. [C, tile_height, tile_width]. Default: None. packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False. absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False. Returns: A tuple: - **Rendered colors**. [C, image_height, image_width, channels] - **Rendered alphas**. [C, image_height, image_width, 1] - **Rendered normals**. [C, image_height, image_width, 3] - **Rendered distortion**. [C, image_height, image_width, 1] - **Rendered median depth**.[C, image_height, image_width, 1] """ C = isect_offsets.size(0) device = means2d.device if packed: nnz = means2d.size(0) assert means2d.shape == (nnz, 2), means2d.shape assert ray_transforms.shape == (nnz, 3, 3), ray_transforms.shape assert colors.shape[0] == nnz, colors.shape assert opacities.shape == (nnz,), opacities.shape else: N = means2d.size(1) assert means2d.shape == (C, N, 2), means2d.shape assert ray_transforms.shape == (C, N, 3, 3), ray_transforms.shape assert colors.shape[:2] == (C, N), colors.shape assert opacities.shape == (C, N), opacities.shape if backgrounds is not None: assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape backgrounds = backgrounds.contiguous() # Pad the channels to the nearest supported number if necessary channels = colors.shape[-1] if channels > 512 or channels == 0: # TODO: maybe worth to support zero channels? raise ValueError(f"Unsupported number of color channels: {channels}") if channels not in (1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512): padded_channels = (1 << (channels - 1).bit_length()) - channels colors = torch.cat( [colors, torch.empty(*colors.shape[:-1], padded_channels, device=device)], dim=-1, ) if backgrounds is not None: backgrounds = torch.cat( [ backgrounds, torch.empty( *backgrounds.shape[:-1], padded_channels, device=device ), ], dim=-1, ) else: padded_channels = 0 tile_height, tile_width = isect_offsets.shape[1:3] assert ( tile_height * tile_size >= image_height ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}" assert ( tile_width * tile_size >= image_width ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" ( render_colors, render_alphas, render_normals, render_distort, render_median, ) = _RasterizeToPixels2DGS.apply( means2d.contiguous(), ray_transforms.contiguous(), colors.contiguous(), opacities.contiguous(), normals.contiguous(), densify.contiguous(), backgrounds, masks, image_width, image_height, tile_size, isect_offsets.contiguous(), flatten_ids.contiguous(), absgrad, distloss, ) if padded_channels > 0: render_colors = render_colors[..., :-padded_channels] return render_colors, render_alphas, render_normals, render_distort, render_median
[docs] @torch.no_grad() def rasterize_to_indices_in_range_2dgs( range_start: int, range_end: int, transmittances: Tensor, means2d: Tensor, ray_transforms: Tensor, opacities: Tensor, image_width: int, image_height: int, tile_size: int, isect_offsets: Tensor, flatten_ids: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: """Rasterizes a batch of Gaussians to images but only returns the indices. .. note:: This function supports iterative rasterization, in which each call of this function will rasterize a batch of Gaussians from near to far, defined by `[range_start, range_end)`. If a one-step full rasterization is desired, set `range_start` to 0 and `range_end` to a really large number, e.g, 1e10. Args: range_start: The start batch of Gaussians to be rasterized (inclusive). range_end: The end batch of Gaussians to be rasterized (exclusive). transmittances: Currently transmittances. [C, image_height, image_width] means2d: Projected Gaussian means. [C, N, 2] ray_transforms: transformation matrices that transforms xy-planes in pixel spaces into splat coordinates. [C, N, 3, 3] opacities: Gaussian opacities that support per-view values. [C, N] image_width: Image width. image_height: Image height. tile_size: Tile size. isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width] flatten_ids: The global flatten indices in [C * N] from `isect_tiles()`. [n_isects] Returns: A tuple: - **Gaussian ids**. Gaussian ids for the pixel intersection. A flattened list of shape [M]. - **Pixel ids**. pixel indices (row-major). A flattened list of shape [M]. - **Camera ids**. Camera indices. A flattened list of shape [M]. """ C, N, _ = means2d.shape assert ray_transforms.shape == (C, N, 3, 3), ray_transforms.shape assert opacities.shape == (C, N), opacities.shape assert isect_offsets.shape[0] == C, isect_offsets.shape tile_height, tile_width = isect_offsets.shape[1:3] assert ( tile_height * tile_size >= image_height ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}" assert ( tile_width * tile_size >= image_width ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" out_gauss_ids, out_indices = _make_lazy_cuda_func( "rasterize_to_indices_in_range_2dgs" )( range_start, range_end, transmittances.contiguous(), means2d.contiguous(), ray_transforms.contiguous(), opacities.contiguous(), image_width, image_height, tile_size, isect_offsets.contiguous(), flatten_ids.contiguous(), ) out_pixel_ids = out_indices % (image_width * image_height) out_camera_ids = out_indices // (image_width * image_height) return out_gauss_ids, out_pixel_ids, out_camera_ids
class _RasterizeToPixels2DGS(torch.autograd.Function): """Rasterize gaussians 2DGS""" @staticmethod def forward( ctx, means2d: Tensor, ray_transforms: Tensor, colors: Tensor, opacities: Tensor, normals: Tensor, densify: Tensor, backgrounds: Tensor, masks: Tensor, width: int, height: int, tile_size: int, isect_offsets: Tensor, flatten_ids: Tensor, absgrad: bool, distloss: bool, ) -> Tuple[Tensor, Tensor]: ( render_colors, render_alphas, render_normals, render_distort, render_median, last_ids, median_ids, ) = _make_lazy_cuda_func("rasterize_to_pixels_fwd_2dgs")( means2d, ray_transforms, colors, opacities, normals, backgrounds, masks, width, height, tile_size, isect_offsets, flatten_ids, ) ctx.save_for_backward( means2d, ray_transforms, colors, opacities, normals, densify, backgrounds, masks, isect_offsets, flatten_ids, render_colors, render_alphas, last_ids, median_ids, ) ctx.width = width ctx.height = height ctx.tile_size = tile_size ctx.absgrad = absgrad ctx.distloss = distloss # doubel to float render_alphas = render_alphas.float() return ( render_colors, render_alphas, render_normals, render_distort, render_median, ) @staticmethod def backward( ctx, v_render_colors: Tensor, v_render_alphas: Tensor, v_render_normals: Tensor, v_render_distort: Tensor, v_render_median: Tensor, ): ( means2d, ray_transforms, colors, opacities, normals, densify, backgrounds, masks, isect_offsets, flatten_ids, render_colors, render_alphas, last_ids, median_ids, ) = ctx.saved_tensors width = ctx.width height = ctx.height tile_size = ctx.tile_size absgrad = ctx.absgrad ( v_means2d_abs, v_means2d, v_ray_transforms, v_colors, v_opacities, v_normals, v_densify, ) = _make_lazy_cuda_func("rasterize_to_pixels_bwd_2dgs")( means2d, ray_transforms, colors, opacities, normals, densify, backgrounds, masks, width, height, tile_size, isect_offsets, flatten_ids, render_colors, render_alphas, last_ids, median_ids, v_render_colors.contiguous(), v_render_alphas.contiguous(), v_render_normals.contiguous(), v_render_distort.contiguous(), v_render_median.contiguous(), absgrad, ) torch.cuda.synchronize() if absgrad: means2d.absgrad = v_means2d_abs if ctx.needs_input_grad[6]: v_backgrounds = (v_render_colors * (1.0 - render_alphas).float()).sum( dim=(1, 2) ) else: v_backgrounds = None return ( v_means2d, v_ray_transforms, v_colors, v_opacities, v_normals, v_densify, v_backgrounds, None, None, None, None, None, None, None, None, )