Densification
In gsplat, the densification and pruning of Gaussians during training is abstracted into a strategy. A strategy is a class that defines how the Gaussian parameters, along with their optimizers, are updated (split, pruned, etc.) during training.
The following example shows a training workflow that uses DefaultStrategy:
from typing import Dict

import torch
from gsplat import DefaultStrategy, rasterization

# Define Gaussian parameters and optimizers
params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
optimizers: Dict[str, torch.optim.Optimizer] = ...

# Initialize the strategy
strategy = DefaultStrategy()

# Check the sanity of the parameters and optimizers
strategy.check_sanity(params, optimizers)

# Initialize the strategy state
strategy_state = strategy.initialize_state()

# Training loop
for step in range(1000):
    # Forward pass
    render_image, render_alpha, info = rasterization(...)

    # Pre-backward step
    strategy.step_pre_backward(params, optimizers, strategy_state, step, info)

    # Compute the loss
    loss = ...

    # Backward pass
    loss.backward()

    # Post-backward step
    strategy.step_post_backward(params, optimizers, strategy_state, step, info)
A strategy updates the Gaussian parameters and their optimizers in place, so it expects both in a specific format. The Gaussians must be defined as either a Dict of torch.nn.Parameter or a torch.nn.ParameterDict with at least the following keys: {“means”, “scales”, “quats”, “opacities”}. Any number of extra attributes is supported on top of these. The optimizers must be given as a Dict of torch.optim.Optimizer with the same keys as the parameters, and each optimizer should correspond to exactly one learnable parameter.
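To illustrate why the parameters and optimizers have to be updated together, here is a minimal sketch of an in-place pruning step written in plain PyTorch. The helper `prune_in_place` and the toy setup are hypothetical and are not gsplat's implementation; the sketch assumes one Adam optimizer with a single parameter per key, as described above.

```python
import torch


def prune_in_place(params, optimizers, keep):
    """Keep only the Gaussians selected by the boolean mask `keep` (shape [N]).

    Hypothetical helper for illustration; not part of gsplat's API.
    """
    for name, old in list(params.items()):
        optimizer = optimizers[name]
        new = torch.nn.Parameter(old.detach()[keep], requires_grad=old.requires_grad)

        # Move the optimizer state from the old tensor to the new one, slicing
        # any per-element statistics (Adam's exp_avg / exp_avg_sq) with the mask.
        state = optimizer.state.pop(old, {})
        for key, value in state.items():
            if torch.is_tensor(value) and value.shape == old.shape:
                state[key] = value[keep]
        if state:
            optimizer.state[new] = state

        # Point the optimizer and the parameter container at the new tensor.
        optimizer.param_groups[0]["params"] = [new]
        params[name] = new


# Toy setup: 5 Gaussians, one Adam optimizer per parameter.
params = {
    "means": torch.nn.Parameter(torch.rand(5, 3)),
    "opacities": torch.nn.Parameter(torch.rand(5)),
}
optimizers = {k: torch.optim.Adam([p], lr=1e-3) for k, p in params.items()}

# Take one step so that Adam has state to carry over.
sum(p.sum() for p in params.values()).backward()
for optimizer in optimizers.values():
    optimizer.step()

keep = torch.tensor([True, False, True, True, False])
prune_in_place(params, optimizers, keep)
print({k: tuple(p.shape) for k, p in params.items()})
# → {'means': (3, 3), 'opacities': (3,)}
```

Because each optimizer owns exactly one parameter, the state that belongs to a given attribute can be found and resized without any bookkeeping across parameter groups, which is one plausible reason for the one-optimizer-per-parameter requirement.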
For example, the following is a valid format for the parameters and the optimizers that can be used with our strategies:
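import torch
from torch.nn import Parameter

N = 100
params = torch.nn.ParameterDict({
    # Required attributes
    "means": Parameter(torch.rand(N, 3)), "scales": Parameter(torch.rand(N, 3)),
    "quats": Parameter(torch.rand(N, 4)), "opacities": Parameter(torch.rand(N)),
    # Extra attributes
    "colors": Parameter(torch.rand(N, 25, 3)), "features1": Parameter(torch.rand(N, 128)),
    "features2": Parameter(torch.rand(N, 64)),
})
optimizers = {k: torch.optim.Adam([p], lr=1e-3) for k, p in params.items()}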
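The format rules above can also be written down as an explicit check. The sketch below is a hypothetical, duck-typed helper (`check_format` and `REQUIRED_KEYS` are names introduced here for illustration) and is not the implementation of `strategy.check_sanity`. It relies only on the `param_groups` attribute that `torch.optim.Optimizer` exposes, so it has no import-time dependency on PyTorch.

```python
from typing import Any, Mapping

# Keys every strategy expects, as listed in the text above.
REQUIRED_KEYS = {"means", "scales", "quats", "opacities"}


def check_format(params: Mapping[str, Any], optimizers: Mapping[str, Any]) -> None:
    """Raise ValueError if `params` / `optimizers` break the expected layout.

    Hypothetical helper for illustration; not gsplat's `check_sanity`.
    """
    missing = REQUIRED_KEYS - set(params.keys())
    if missing:
        raise ValueError(f"params is missing required keys: {sorted(missing)}")

    if set(optimizers.keys()) != set(params.keys()):
        raise ValueError("optimizers must have the same keys as params")

    for name, optimizer in optimizers.items():
        # `param_groups` is a list of dicts, each holding a "params" list.
        owned = [p for group in optimizer.param_groups for p in group["params"]]
        if len(owned) != 1:
            raise ValueError(f"optimizer '{name}' must hold exactly one parameter")
        if owned[0] is not params[name]:
            raise ValueError(f"optimizer '{name}' does not optimize params['{name}']")
```

A call such as `check_format(params, optimizers)` before the training loop returns silently when the layout is valid and raises a `ValueError` naming the first violated rule otherwise.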
Below are the strategies that are currently implemented in gsplat: