import math
from dataclasses import dataclass
from typing import Any, Dict, Union
import torch
from torch import Tensor
from .base import Strategy
from .ops import inject_noise_to_position, relocate, sample_add
@dataclass
class MCMCStrategy(Strategy):
"""Strategy that follows the paper:
`3D Gaussian Splatting as Markov Chain Monte Carlo <https://arxiv.org/abs/2404.09591>`_
This strategy will:
- Periodically teleport GSs with low opacity to a place that has high opacity.
- Periodically introduce new GSs sampled based on the opacity distribution.
- Periodically perturb the GSs locations.
Args:
cap_max (int): Maximum number of GSs. Default to 1_000_000.
noise_lr (float): MCMC samping noise learning rate. Default to 5e5.
refine_start_iter (int): Start refining GSs after this iteration. Default to 500.
refine_stop_iter (int): Stop refining GSs after this iteration. Default to 25_000.
refine_every (int): Refine GSs every this steps. Default to 100.
min_opacity (float): GSs with opacity below this value will be pruned. Default to 0.005.
verbose (bool): Whether to print verbose information. Default to False.
Examples:
>>> from gsplat import MCMCStrategy, rasterization
>>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
>>> optimizers: Dict[str, torch.optim.Optimizer] = ...
>>> strategy = MCMCStrategy()
>>> strategy.check_sanity(params, optimizers)
>>> strategy_state = strategy.initialize_state()
>>> for step in range(1000):
... render_image, render_alpha, info = rasterization(...)
... loss = ...
... loss.backward()
... strategy.step_post_backward(params, optimizers, strategy_state, step, info, lr=1e-3)
"""
cap_max: int = 1_000_000
noise_lr: float = 5e5
refine_start_iter: int = 500
refine_stop_iter: int = 25_000
refine_every: int = 100
min_opacity: float = 0.005
verbose: bool = False
def initialize_state(self) -> Dict[str, Any]:
"""Initialize and return the running state for this strategy."""
        n_max = 51
        # Precompute a lower-triangular table of binomial coefficients C(n, k),
        # used by `relocate` and `sample_add` when redistributing opacity and
        # scale across the samples that replace a Gaussian.
        binoms = torch.zeros((n_max, n_max))
        for n in range(n_max):
            for k in range(n + 1):
                binoms[n, k] = math.comb(n, k)
return {"binoms": binoms}
def check_sanity(
self,
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
):
"""Sanity check for the parameters and optimizers.
Check if:
* `params` and `optimizers` have the same keys.
* Each optimizer has exactly one param_group, corresponding to each parameter.
* The following keys are present: {"means", "scales", "quats", "opacities"}.
Raises:
AssertionError: If any of the above conditions is not met.
.. note::
It is not required but highly recommended for the user to call this function
after initializing the strategy to ensure the convention of the parameters
and optimizers is as expected.
"""
super().check_sanity(params, optimizers)
# The following keys are required for this strategy.
for key in ["means", "scales", "quats", "opacities"]:
assert key in params, f"{key} is required in params but missing."
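
    # A minimal sketch of `params`/`optimizers` that pass this check, assuming
    # one Adam optimizer per parameter with a single param_group each (shapes
    # and learning rates are illustrative, not gsplat defaults):
    #
    # >>> N = 1000
    # >>> params = torch.nn.ParameterDict({
    # ...     "means": torch.nn.Parameter(torch.randn(N, 3)),
    # ...     "scales": torch.nn.Parameter(torch.zeros(N, 3)),
    # ...     "quats": torch.nn.Parameter(torch.rand(N, 4)),
    # ...     "opacities": torch.nn.Parameter(torch.zeros(N)),
    # ... })
    # >>> optimizers = {
    # ...     name: torch.optim.Adam([p], lr=1e-3) for name, p in params.items()
    # ... }
    # >>> MCMCStrategy().check_sanity(params, optimizers)  # passes silently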
# def step_pre_backward(
# self,
# params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
# optimizers: Dict[str, torch.optim.Optimizer],
# # state: Dict[str, Any],
# step: int,
# info: Dict[str, Any],
# ):
# """Callback function to be executed before the `loss.backward()` call."""
# pass
def step_post_backward(
self,
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
state: Dict[str, Any],
step: int,
info: Dict[str, Any],
lr: float,
):
"""Callback function to be executed after the `loss.backward()` call.
Args:
lr (float): Learning rate for "means" attribute of the GS.
"""
# move to the correct device
state["binoms"] = state["binoms"].to(params["means"].device)
binoms = state["binoms"]
if (
step < self.refine_stop_iter
and step > self.refine_start_iter
and step % self.refine_every == 0
):
# teleport GSs
n_relocated_gs = self._relocate_gs(params, optimizers, binoms)
if self.verbose:
print(f"Step {step}: Relocated {n_relocated_gs} GSs.")
# add new GSs
n_new_gs = self._add_new_gs(params, optimizers, binoms)
if self.verbose:
print(
f"Step {step}: Added {n_new_gs} GSs. "
f"Now having {len(params['means'])} GSs."
)
torch.cuda.empty_cache()
# add noise to GSs
inject_noise_to_position(
params=params, optimizers=optimizers, state={}, scaler=lr * self.noise_lr
)
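        # A rough sketch of the perturbation applied above, assuming (per the
        # paper) per-Gaussian noise that is damped for opaque GSs; the exact
        # gating and the covariance-shaped noise live in
        # `inject_noise_to_position`, and ``gate`` here is a hypothetical
        # simplification:
        #
        # >>> op = torch.sigmoid(params["opacities"].flatten())   # (N,)
        # >>> gate = 1.0 - op                                     # hypothetical damping
        # >>> noise = torch.randn_like(params["means"])           # (N, 3)
        # >>> params["means"].data += lr * self.noise_lr * gate[:, None] * noise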
@torch.no_grad()
def _relocate_gs(
self,
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
binoms: Tensor,
) -> int:
opacities = torch.sigmoid(params["opacities"].flatten())
dead_mask = opacities <= self.min_opacity
n_gs = dead_mask.sum().item()
if n_gs > 0:
relocate(
params=params,
optimizers=optimizers,
state={},
mask=dead_mask,
binoms=binoms,
min_opacity=self.min_opacity,
)
return n_gs
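
    # A minimal sketch of the sampling idea behind `relocate`, assuming (per
    # the paper) that each dead GS is teleported to the position of an alive
    # GS drawn with probability proportional to opacity:
    #
    # >>> alive_probs = opacities * (~dead_mask)      # zero out dead GSs
    # >>> targets = torch.multinomial(alive_probs, int(n_gs), replacement=True)
    # >>> params["means"].data[dead_mask] = params["means"].data[targets]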
@torch.no_grad()
def _add_new_gs(
self,
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
binoms: Tensor,
) -> int:
current_n_points = len(params["means"])
n_target = min(self.cap_max, int(1.05 * current_n_points))
n_gs = max(0, n_target - current_n_points)
if n_gs > 0:
sample_add(
params=params,
optimizers=optimizers,
state={},
n=n_gs,
binoms=binoms,
min_opacity=self.min_opacity,
)
return n_gs
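
    # With the 5% growth factor above, the GS count grows geometrically at each
    # refinement until it saturates at `cap_max`. For example, growing from
    # 100_000 to the default 1_000_000 cap takes about log(10) / log(1.05) ≈ 48
    # refinements, i.e. roughly 4_800 steps at the default `refine_every=100`.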