# Source code for gsplat.cuda._torch_impl_2dgs

from typing import Optional, Tuple

import torch
from torch import Tensor

from gsplat.cuda._torch_impl import _quat_scale_to_matrix


def _fully_fused_projection_2dgs(
    means: Tensor,  # [N, 3]
    quats: Tensor,  # [N, 4]
    scales: Tensor,  # [N, 3]
    viewmats: Tensor,  # [C, 4, 4]
    Ks: Tensor,  # [C, 3, 3]
    width: int,
    height: int,
    near_plane: float = 0.01,
    far_plane: float = 1e10,
    eps: float = 1e-6,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection_2dgs()`

    .. note::

        This is a minimal implementation of fully fused version, which has more
        arguments. Not all arguments are supported.

    Returns:
        A tuple:

        **radii**: Integer screen-space radius per Gaussian (0 when culled). [C, N]

        **means2d**: Projected 2D Gaussian centers. [C, N, 2]

        **depths**: Camera-space z depth of each Gaussian center. [C, N]

        **M**: Ray transform matrices, ``(WH)^T`` in the paper's notation. [C, N, 3, 3]

        **normals**: Camera-facing splat normals. [C, N, 3]
    """
    R_cw = viewmats[:, :3, :3]  # [C, 3, 3]
    t_cw = viewmats[:, :3, 3]  # [C, 3]
    # Gaussian centers in each camera frame.
    means_c = torch.einsum("cij,nj->cni", R_cw, means) + t_cw[:, None, :]  # (C, N, 3)
    RS_wl = _quat_scale_to_matrix(quats, scales)
    RS_cl = torch.einsum("cij,njk->cnik", R_cw, RS_wl)  # [C, N, 3, 3]

    # compute normals: third column of the rotated-scaled splat frame
    normals = RS_cl[..., 2]  # [C, N, 3]
    C, N, _ = normals.shape
    # Flip normals that point away from the camera (positive dot product with
    # the view ray means the normal already faces the camera).
    cos = -normals.reshape((C * N, 1, 3)) @ means_c.reshape((C * N, 3, 1))
    cos = cos.reshape((C, N, 1))
    # Build the sign on the same device/dtype as the inputs; the previous code
    # created CPU scalar tensors, which fails for CUDA inputs. Also multiply
    # out-of-place instead of mutating a view of `RS_cl` in place.
    multiplier = torch.where(cos > 0, torch.ones_like(cos), -torch.ones_like(cos))
    normals = normals * multiplier

    # ray transform matrix, omitting the z rotation
    T_cl = torch.cat([RS_cl[..., :2], means_c[..., None]], dim=-1)  # [C, N, 3, 3]
    T_sl = torch.einsum("cij,cnjk->cnik", Ks[:, :3, :3], T_cl)  # [C, N, 3, 3]
    # in paper notation M = (WH)^T
    # later h_u = M @ h_x, h_v = M @ h_y
    M = torch.transpose(T_sl, -1, -2)  # [C, N, 3, 3]

    # compute the AABB of gaussian
    test = torch.tensor([1.0, 1.0, -1.0], device=means.device).reshape(1, 1, 3)
    d = (M[..., 2] * M[..., 2] * test).sum(dim=-1, keepdim=True)  # [C, N, 1]
    # Guard against a degenerate homogeneous denominator.
    valid = torch.abs(d) > eps
    f = torch.where(valid, test / d, torch.zeros_like(test)).unsqueeze(
        -1
    )  # (C, N, 3, 1)
    means2d = (M[..., :2] * M[..., 2:3] * f).sum(dim=-2)  # [C, N, 2]
    extents = torch.sqrt(
        means2d**2 - (M[..., :2] * M[..., :2] * f).sum(dim=-2)
    )  # [C, N, 2]

    depths = means_c[..., 2]  # [C, N]
    # 3-sigma radius of the screen-space footprint.
    radius = torch.ceil(3.0 * torch.max(extents, dim=-1).values)  # (C, N)

    # Cull Gaussians with degenerate denominators or outside the depth range.
    valid = valid.squeeze(-1) & (depths > near_plane) & (depths < far_plane)
    radius[~valid] = 0.0

    # Cull Gaussians whose AABB does not overlap the image.
    inside = (
        (means2d[..., 0] + radius > 0)
        & (means2d[..., 0] - radius < width)
        & (means2d[..., 1] + radius > 0)
        & (means2d[..., 1] - radius < height)
    )
    radius[~inside] = 0.0
    radii = radius.int()
    return radii, means2d, depths, M, normals


def accumulate_2dgs(
    means2d: Tensor,  # [C, N, 2]
    ray_transforms: Tensor,  # [C, N, 3, 3]
    opacities: Tensor,  # [C, N]
    colors: Tensor,  # [C, N, channels]
    normals: Tensor,  # [C, N, 3]
    gaussian_ids: Tensor,  # [M]
    pixel_ids: Tensor,  # [M]
    camera_ids: Tensor,  # [M]
    image_width: int,
    image_height: int,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Alpha compositing for 2DGS.

    .. warning::

        This function requires the `nerfacc` package to be installed. Please
        install it using the following command `pip install nerfacc`.

    Args:
        means2d: Gaussian means in 2D. [C, N, 2]
        ray_transforms: transformation matrices that transform rays in pixel
            space into splat's local frame. [C, N, 3, 3]
        opacities: Per-view Gaussian opacities (for example, when antialiasing
            is enabled, Gaussian in each view would efficiently have different
            opacity). [C, N]
        colors: Per-view Gaussian colors. Supports N-D features. [C, N, channels]
        normals: Per-view Gaussian normals. [C, N, 3]
        gaussian_ids: Collection of Gaussian indices to be rasterized. A
            flattened list of shape [M].
        pixel_ids: Collection of pixel indices (row-major) to be rasterized. A
            flattened list of shape [M].
        camera_ids: Collection of camera indices to be rasterized. A flattened
            list of shape [M].
        image_width: Image width.
        image_height: Image height.

    Returns:
        A tuple:

        **renders**: Accumulated colors. [C, image_height, image_width, channels]

        **alphas**: Accumulated opacities. [C, image_height, image_width, 1]

        **normals**: Accumulated normals. [C, image_height, image_width, 3]
    """
    try:
        from nerfacc import accumulate_along_rays, render_weight_from_alpha
    except ImportError:
        raise ImportError("Please install nerfacc package: pip install nerfacc")

    C, N = means2d.shape[:2]
    channels = colors.shape[-1]

    # Pixel centers (+0.5 puts the sample in the middle of the pixel).
    pixel_ids_x = pixel_ids % image_width + 0.5
    pixel_ids_y = pixel_ids // image_width + 0.5
    pixel_coords = torch.stack([pixel_ids_x, pixel_ids_y], dim=-1)  # [M, 2]
    deltas = pixel_coords - means2d[camera_ids, gaussian_ids]  # [M, 2]

    # Ray-splat intersection: h_u and h_v are two homogeneous planes whose
    # cross product yields the intersection point in splat-local (u, v) space.
    M = ray_transforms[camera_ids, gaussian_ids]  # [M, 3, 3]
    h_u = -M[..., 0, :3] + M[..., 2, :3] * pixel_ids_x[..., None]  # [M, 3]
    h_v = -M[..., 1, :3] + M[..., 2, :3] * pixel_ids_y[..., None]  # [M, 3]
    tmp = torch.cross(h_u, h_v, dim=-1)
    us = tmp[..., 0] / tmp[..., 2]
    vs = tmp[..., 1] / tmp[..., 2]

    # Use the min of the exact 3D splat distance and a 2D screen-space
    # low-pass filter so each splat covers at least ~1 pixel.
    sigmas_3d = us**2 + vs**2  # [M]
    sigmas_2d = 2 * (deltas[..., 0] ** 2 + deltas[..., 1] ** 2)
    sigmas = 0.5 * torch.minimum(sigmas_3d, sigmas_2d)  # [M]

    # Clamp alpha below 1 to keep transmittance strictly positive.
    alphas = torch.clamp_max(
        opacities[camera_ids, gaussian_ids] * torch.exp(-sigmas), 0.999
    )

    # Flatten (camera, pixel) into a single ray index for nerfacc.
    indices = camera_ids * image_height * image_width + pixel_ids
    total_pixels = C * image_height * image_width

    weights, trans = render_weight_from_alpha(
        alphas, ray_indices=indices, n_rays=total_pixels
    )
    renders = accumulate_along_rays(
        weights,
        colors[camera_ids, gaussian_ids],
        ray_indices=indices,
        n_rays=total_pixels,
    ).reshape(C, image_height, image_width, channels)
    alphas = accumulate_along_rays(
        weights, None, ray_indices=indices, n_rays=total_pixels
    ).reshape(C, image_height, image_width, 1)
    renders_normal = accumulate_along_rays(
        weights,
        normals[camera_ids, gaussian_ids],
        ray_indices=indices,
        n_rays=total_pixels,
    ).reshape(C, image_height, image_width, 3)

    return renders, alphas, renders_normal
def _rasterize_to_pixels_2dgs(
    means2d: Tensor,  # [C, N, 2]
    ray_transforms: Tensor,  # [C, N, 3, 3]
    colors: Tensor,  # [C, N, channels]
    normals: Tensor,  # [C, N, 3]
    opacities: Tensor,  # [C, N]
    image_width: int,
    image_height: int,
    tile_size: int,
    isect_offsets: Tensor,  # [C, tile_height, tile_width]
    flatten_ids: Tensor,  # [n_isects]
    backgrounds: Optional[Tensor] = None,  # [C, channels]
    batch_per_iter: int = 100,
):
    """Pytorch implementation of `gsplat.cuda._wrapper.rasterize_to_pixels_2dgs()`.

    This function rasterizes 2D Gaussians to pixels in a Pytorch-friendly way.
    It iteratively accumulates the renderings within each batch of Gaussians.
    The iterations are controlled by `batch_per_iter`.

    .. note::

        This is a minimal implementation of the fully fused version, which has
        more arguments. Not all arguments are supported.

    .. note::

        This function relies on Pytorch's autograd for the backpropagation. It
        is much slower than our fully fused rasterization implementation and
        consumes much more GPU memory. But it could serve as a playground for
        new ideas or debugging, as no backward implementation is needed.

    .. warning::

        This function requires the `nerfacc` package to be installed. Please
        install it using the following command `pip install nerfacc`.
    """
    from ._wrapper import rasterize_to_indices_in_range_2dgs

    C, N = means2d.shape[:2]
    n_isects = len(flatten_ids)
    device = means2d.device

    render_colors = torch.zeros(
        (C, image_height, image_width, colors.shape[-1]), device=device
    )
    render_alphas = torch.zeros((C, image_height, image_width, 1), device=device)
    render_normals = torch.zeros((C, image_height, image_width, 3), device=device)

    # Split Gaussians into batches and iteratively accumulate the renderings
    block_size = tile_size * tile_size
    isect_offsets_fl = torch.cat(
        [isect_offsets.flatten(), torch.tensor([n_isects], device=device)]
    )
    # The deepest per-tile intersection count bounds how many batches we need.
    max_range = (isect_offsets_fl[1:] - isect_offsets_fl[:-1]).max().item()
    num_batches = (max_range + block_size - 1) // block_size
    for step in range(0, num_batches, batch_per_iter):
        # Transmittance remaining after everything composited so far.
        transmittances = 1.0 - render_alphas[..., 0]

        # Find the M intersections between pixels and gaussians.
        # Each intersection corresponds to a tuple (gs_id, pixel_id, camera_id)
        gs_ids, pixel_ids, camera_ids = rasterize_to_indices_in_range_2dgs(
            step,
            step + batch_per_iter,
            transmittances,
            means2d,
            ray_transforms,
            opacities,
            image_width,
            image_height,
            tile_size,
            isect_offsets,
            flatten_ids,
        )  # [M], [M]
        if len(gs_ids) == 0:
            # All pixels saturated or no Gaussians left in range.
            break

        # Accumulate the renderings within this batch of Gaussians.
        renders_step, accs_step, renders_normal_step = accumulate_2dgs(
            means2d,
            ray_transforms,
            opacities,
            colors,
            normals,
            gs_ids,
            pixel_ids,
            camera_ids,
            image_width,
            image_height,
        )
        # Front-to-back compositing: scale this batch by the remaining
        # transmittance before adding it in.
        render_colors = render_colors + renders_step * transmittances[..., None]
        render_alphas = render_alphas + accs_step * transmittances[..., None]
        render_normals = (
            render_normals + renders_normal_step * transmittances[..., None]
        )

    if backgrounds is not None:
        render_colors = render_colors + backgrounds[:, None, None, :] * (
            1.0 - render_alphas
        )

    return render_colors, render_alphas, render_normals