Densification

In gsplat, we abstract the densification and pruning process of Gaussian training into a strategy. A strategy is a class that defines how the Gaussian parameters (along with their optimizers) should be updated (split, pruned, etc.) during training.

An example training workflow using DefaultStrategy looks like this:

import torch
from typing import Dict

from gsplat import DefaultStrategy, rasterization

# Define Gaussian parameters and optimizers
params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
optimizers: Dict[str, torch.optim.Optimizer] = ...

# Initialize the strategy
strategy = DefaultStrategy()

# Check the sanity of the parameters and optimizers
strategy.check_sanity(params, optimizers)

# Initialize the strategy state
strategy_state = strategy.initialize_state()

# Training loop
for step in range(1000):
    # Forward pass
    render_image, render_alpha, info = rasterization(...)

    # Pre-backward step
    strategy.step_pre_backward(params, optimizers, strategy_state, step, info)

    # Compute the loss
    loss = ...

    # Backward pass
    loss.backward()

    # Post-backward step
    strategy.step_post_backward(params, optimizers, strategy_state, step, info)

A strategy updates the Gaussian parameters as well as the optimizers in place, so it has specific expectations on the format of the parameters and the optimizers. It is designed to work with Gaussians defined as either a Dict of torch.nn.Parameter or a torch.nn.ParameterDict with at least the following keys: {“means”, “scales”, “quats”, “opacities”}. On top of these attributes, an arbitrary number of extra attributes is supported. Besides the parameters, it also expects a Dict of torch.optim.Optimizer with the same keys as the parameters, where each optimizer corresponds to exactly one learnable parameter.

For example, the following is a valid format for the parameters and the optimizers that can be used with our strategies:

N = 100
params = torch.nn.ParameterDict({
    "means": torch.nn.Parameter(torch.randn(N, 3)),
    "scales": torch.nn.Parameter(torch.rand(N, 3)),
    "quats": torch.nn.Parameter(torch.randn(N, 4)),
    "opacities": torch.nn.Parameter(torch.rand(N)),
    # Extra attributes beyond the required keys are also supported:
    "colors": torch.nn.Parameter(torch.randn(N, 25, 3)),
    "features1": torch.nn.Parameter(torch.randn(N, 128)),
    "features2": torch.nn.Parameter(torch.randn(N, 64)),
})
optimizers = {k: torch.optim.Adam([p], lr=1e-3) for k, p in params.items()}
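
By contrast, a single optimizer shared by all parameters would violate this convention, since each optimizer must own exactly one parameter:

# This would NOT pass check_sanity(): one optimizer owns every parameter.
bad_optimizer = torch.optim.Adam(params.values(), lr=1e-3)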

Below are the strategies that are currently implemented in gsplat:

class DefaultStrategy(prune_opa: float = 0.005, grow_grad2d: float = 0.0002, grow_scale3d: float = 0.01, grow_scale2d: float = 0.05, prune_scale3d: float = 0.1, prune_scale2d: float = 0.15, refine_scale2d_stop_iter: int = 0, refine_start_iter: int = 500, refine_stop_iter: int = 15000, reset_every: int = 3000, refine_every: int = 100, pause_refine_after_reset: int = 0, absgrad: bool = False, revised_opacity: bool = False, verbose: bool = False, key_for_gradient: typing_extensions.Literal['means2d', 'gradient_2dgs'] = 'means2d')[source]

A default strategy that follows the original 3DGS paper:

3D Gaussian Splatting for Real-Time Radiance Field Rendering

The strategy will:

  • Periodically duplicate GSs with high image plane gradients and small scales.

  • Periodically split GSs with high image plane gradients and large scales (see the selection sketch after this list).

  • Periodically prune GSs with low opacity.

  • Periodically reset GSs to a lower opacity.
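
As an illustrative sketch of the duplicate/split/prune selection above (not the library's actual implementation; the tensors grads2d, scales3d, opacities and the scalar scene_scale are hypothetical):

# Hypothetical per-GS tensors: grads2d (N,) accumulated image-plane
# gradient norms, scales3d (N, 3) per-axis 3D scales, opacities (N,).
is_grad_high = grads2d > strategy.grow_grad2d
is_small = scales3d.max(dim=-1).values <= strategy.grow_scale3d * scene_scale
dup_mask = is_grad_high & is_small      # duplicate: high gradient, small scale
split_mask = is_grad_high & ~is_small   # split: high gradient, large scale
prune_mask = opacities < strategy.prune_opa  # prune: nearly transparent GSs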

If absgrad=True, it will use the absolute gradients instead of average gradients for GS duplicating & splitting, following the AbsGS paper:

AbsGS: Recovering Fine Details for 3D Gaussian Splatting

Using absolute gradients typically leads to better results but requires setting grow_grad2d to a higher value, e.g., 0.0008. Also, the rasterization() function should be called with absgrad=True so that the absolute gradients are computed.
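
For example, enabling absolute gradients on both sides:

# absgrad must be enabled in both the strategy and the rasterizer.
strategy = DefaultStrategy(absgrad=True, grow_grad2d=0.0008)
render_image, render_alpha, info = rasterization(..., absgrad=True)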

Parameters:
  • prune_opa (float) – GSs with opacity below this value will be pruned. Default is 0.005.

  • grow_grad2d (float) – GSs with image plane gradient above this value will be split/duplicated. Default is 0.0002.

  • grow_scale3d (float) – GSs with 3d scale (normalized by scene_scale) below this value will be duplicated; those above will be split. Default is 0.01.

  • grow_scale2d (float) – GSs with 2d scale (normalized by image resolution) above this value will be split. Default is 0.05.

  • prune_scale3d (float) – GSs with 3d scale (normalized by scene_scale) above this value will be pruned. Default is 0.1.

  • prune_scale2d (float) – GSs with 2d scale (normalized by image resolution) above this value will be pruned. Default is 0.15.

  • refine_scale2d_stop_iter (int) – Stop refining GSs based on 2d scale after this iteration. Default is 0. Set to a positive value to enable this feature.

  • refine_start_iter (int) – Start refining GSs after this iteration. Default is 500.

  • refine_stop_iter (int) – Stop refining GSs after this iteration. Default is 15_000.

  • reset_every (int) – Reset opacities every this many steps. Default is 3000.

  • refine_every (int) – Refine GSs every this many steps. Default is 100.

  • pause_refine_after_reset (int) – Pause refining GSs for this many steps after an opacity reset. Default is 0 (no pause); one might want to set this to the number of images in the training set.

  • absgrad (bool) – Use absolute gradients for GS splitting. Default is False.

  • revised_opacity (bool) – Whether to use revised opacity heuristic from arXiv:2404.06109 (experimental). Default is False.

  • verbose (bool) – Whether to print verbose information. Default is False.

  • key_for_gradient (str) – Which variable to use for the densification strategy. 3DGS uses the “means2d” gradient, while 2DGS uses a similar gradient stored in the variable “gradient_2dgs” (see the snippet after this list).
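
For example:

# 3DGS (default): densify based on the "means2d" gradient.
strategy = DefaultStrategy()
# 2DGS: densify based on the "gradient_2dgs" gradient instead.
strategy = DefaultStrategy(key_for_gradient="gradient_2dgs")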

Examples

>>> from gsplat import DefaultStrategy, rasterization
>>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
>>> optimizers: Dict[str, torch.optim.Optimizer] = ...
>>> strategy = DefaultStrategy()
>>> strategy.check_sanity(params, optimizers)
>>> strategy_state = strategy.initialize_state()
>>> for step in range(1000):
...     render_image, render_alpha, info = rasterization(...)
...     strategy.step_pre_backward(params, optimizers, strategy_state, step, info)
...     loss = ...
...     loss.backward()
...     strategy.step_post_backward(params, optimizers, strategy_state, step, info)
check_sanity(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer])[source]

Sanity check for the parameters and optimizers.

Check if:
  • params and optimizers have the same keys.

  • Each optimizer has exactly one param_group, corresponding to each parameter.

  • The following keys are present: {“means”, “scales”, “quats”, “opacities”}.

Raises:

AssertionError – If any of the above conditions is not met.

Note

Calling this function is not required but highly recommended after initializing the strategy, to ensure the parameters and optimizers follow the expected convention.
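
For example, a parameter dict missing one of the required keys fails the check (a hypothetical snippet):

params = torch.nn.ParameterDict({
    "means": torch.nn.Parameter(torch.randn(100, 3)),
    "scales": torch.nn.Parameter(torch.rand(100, 3)),
    "quats": torch.nn.Parameter(torch.randn(100, 4)),
    # "opacities" is missing.
})
optimizers = {k: torch.optim.Adam([p], lr=1e-3) for k, p in params.items()}
strategy.check_sanity(params, optimizers)  # raises AssertionError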

initialize_state(scene_scale: float = 1.0) → Dict[str, Any][source]

Initialize and return the running state for this strategy.

The returned state should be passed to the step_pre_backward() and step_post_backward() functions.
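
Since thresholds such as grow_scale3d and prune_scale3d are normalized by scene_scale, it is worth passing the actual extent of your scene rather than relying on the default of 1.0 (a minimal sketch; the value 3.0 is a placeholder one would estimate, e.g., from the training camera positions):

strategy_state = strategy.initialize_state(scene_scale=3.0)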

step_post_backward(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer], state: Dict[str, Any], step: int, info: Dict[str, Any], packed: bool = False)[source]

Callback function to be executed after the loss.backward() call.
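
The packed flag should match how the forward pass was performed (a sketch, assuming rasterization() is called with packed=True):

render_image, render_alpha, info = rasterization(..., packed=True)
loss = ...
loss.backward()
strategy.step_post_backward(params, optimizers, strategy_state, step, info, packed=True)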

step_pre_backward(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer], state: Dict[str, Any], step: int, info: Dict[str, Any])[source]

Callback function to be executed before the loss.backward() call.

class MCMCStrategy(cap_max: int = 1000000, noise_lr: float = 500000.0, refine_start_iter: int = 500, refine_stop_iter: int = 25000, refine_every: int = 100, min_opacity: float = 0.005, verbose: bool = False)[source]

Strategy that follows the paper:

3D Gaussian Splatting as Markov Chain Monte Carlo

This strategy will:

  • Periodically teleport GSs with low opacity to a place that has high opacity.

  • Periodically introduce new GSs sampled based on the opacity distribution (see the sketch after this list).

  • Periodically perturb the GSs locations.
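
As an illustrative sketch of the opacity-based sampling above (not the library's implementation; all names are hypothetical):

import torch

# Hypothetical (N,) opacity tensor; draw indices of existing GSs to
# clone, with probability proportional to opacity.
opacities = torch.rand(100)
n_new = 10
sampled_idx = torch.multinomial(opacities, num_samples=n_new, replacement=True)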

Parameters:
  • cap_max (int) – Maximum number of GSs. Defaults to 1_000_000.

  • noise_lr (float) – MCMC sampling noise learning rate. Defaults to 5e5.

  • refine_start_iter (int) – Start refining GSs after this iteration. Defaults to 500.

  • refine_stop_iter (int) – Stop refining GSs after this iteration. Defaults to 25_000.

  • refine_every (int) – Refine GSs every this many steps. Defaults to 100.

  • min_opacity (float) – GSs with opacity below this value will be pruned. Defaults to 0.005.

  • verbose (bool) – Whether to print verbose information. Defaults to False.

Examples

>>> from gsplat import MCMCStrategy, rasterization
>>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
>>> optimizers: Dict[str, torch.optim.Optimizer] = ...
>>> strategy = MCMCStrategy()
>>> strategy.check_sanity(params, optimizers)
>>> strategy_state = strategy.initialize_state()
>>> for step in range(1000):
...     render_image, render_alpha, info = rasterization(...)
...     loss = ...
...     loss.backward()
...     strategy.step_post_backward(params, optimizers, strategy_state, step, info, lr=1e-3)
check_sanity(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer])[source]

Sanity check for the parameters and optimizers.

Check if:
  • params and optimizers have the same keys.

  • Each optimizer has exactly one param_group, corresponding to each parameter.

  • The following keys are present: {“means”, “scales”, “quats”, “opacities”}.

Raises:

AssertionError – If any of the above conditions is not met.

Note

Calling this function is not required but highly recommended after initializing the strategy, to ensure the parameters and optimizers follow the expected convention.

initialize_state() → Dict[str, Any][source]

Initialize and return the running state for this strategy.

step_post_backward(params: Dict[str, Parameter] | ParameterDict, optimizers: Dict[str, Optimizer], state: Dict[str, Any], step: int, info: Dict[str, Any], lr: float)[source]

Callback function to be executed after the loss.backward() call.

Parameters:

lr (float) – Learning rate for the “means” attribute of the GSs.
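
For example, when the learning rate of the “means” optimizer is scheduled, the current value can be read back from the optimizer and forwarded to the strategy (a sketch, reusing the optimizers dict from above):

means_lr = optimizers["means"].param_groups[0]["lr"]
strategy.step_post_backward(params, optimizers, strategy_state, step, info, lr=means_lr)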