Source code for gsplat.compression.png_compression

import json
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor

from gsplat.compression.sort import sort_splats
from gsplat.utils import inverse_log_transform, log_transform


[docs] @dataclass class PngCompression: """Uses quantization and sorting to compress splats into PNG files and uses K-means clustering to compress the spherical harmonic coefficents. .. warning:: This class requires the `imageio <https://pypi.org/project/imageio/>`_, `plas <https://github.com/fraunhoferhhi/PLAS.git>`_ and `torchpq <https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install>`_ packages to be installed. .. warning:: This class might throw away a few lowest opacities splats if the number of splats is not a square number. .. note:: The splats parameters are expected to be pre-activation values. It expects the following fields in the splats dictionary: "means", "scales", "quats", "opacities", "sh0", "shN". More fields can be added to the dictionary, but they will only be compressed using NPZ compression. References: - `Compact 3D Scene Representation via Self-Organizing Gaussian Grids <https://arxiv.org/abs/2312.13299>`_ - `Making Gaussian Splats more smaller <https://aras-p.info/blog/2023/09/27/Making-Gaussian-Splats-more-smaller/>`_ Args: use_sort (bool, optional): Whether to sort splats before compression. Defaults to True. verbose (bool, optional): Whether to print verbose information. Default to True. """ use_sort: bool = True verbose: bool = True def _get_compress_fn(self, param_name: str) -> Callable: compress_fn_map = { "means": _compress_png_16bit, "scales": _compress_png, "quats": _compress_png, "opacities": _compress_png, "sh0": _compress_png, "shN": _compress_kmeans, } if param_name in compress_fn_map: return compress_fn_map[param_name] else: return _compress_npz def _get_decompress_fn(self, param_name: str) -> Callable: decompress_fn_map = { "means": _decompress_png_16bit, "scales": _decompress_png, "quats": _decompress_png, "opacities": _decompress_png, "sh0": _decompress_png, "shN": _decompress_kmeans, } if param_name in decompress_fn_map: return decompress_fn_map[param_name] else: return _decompress_npz
[docs] def compress(self, compress_dir: str, splats: Dict[str, Tensor]) -> None: """Run compression Args: compress_dir (str): directory to save compressed files splats (Dict[str, Tensor]): Gaussian splats to compress """ # Param-specific preprocessing splats["means"] = log_transform(splats["means"]) splats["quats"] = F.normalize(splats["quats"], dim=-1) n_gs = len(splats["means"]) n_sidelen = int(n_gs**0.5) n_crop = n_gs - n_sidelen**2 if n_crop != 0: splats = _crop_n_splats(splats, n_crop) print( f"Warning: Number of Gaussians was not square. Removed {n_crop} Gaussians." ) if self.use_sort: splats = sort_splats(splats) meta = {} for param_name in splats.keys(): compress_fn = self._get_compress_fn(param_name) kwargs = { "n_sidelen": n_sidelen, "verbose": self.verbose, } meta[param_name] = compress_fn( compress_dir, param_name, splats[param_name], **kwargs ) with open(os.path.join(compress_dir, "meta.json"), "w") as f: json.dump(meta, f)
[docs] def decompress(self, compress_dir: str) -> Dict[str, Tensor]: """Run decompression Args: compress_dir (str): directory that contains compressed files Returns: Dict[str, Tensor]: decompressed Gaussian splats """ with open(os.path.join(compress_dir, "meta.json"), "r") as f: meta = json.load(f) splats = {} for param_name, param_meta in meta.items(): decompress_fn = self._get_decompress_fn(param_name) splats[param_name] = decompress_fn(compress_dir, param_name, param_meta) # Param-specific postprocessing splats["means"] = inverse_log_transform(splats["means"]) return splats
def _crop_n_splats(splats: Dict[str, Tensor], n_crop: int) -> Dict[str, Tensor]: opacities = splats["opacities"] keep_indices = torch.argsort(opacities, descending=True)[:-n_crop] for k, v in splats.items(): splats[k] = v[keep_indices] return splats def _compress_png( compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs ) -> Dict[str, Any]: """Compress parameters with 8-bit quantization and lossless PNG compression. Args: compress_dir (str): compression directory param_name (str): parameter field name params (Tensor): parameters n_sidelen (int): image side length Returns: Dict[str, Any]: metadata """ import imageio.v2 as imageio if torch.numel == 0: meta = { "shape": list(params.shape), "dtype": str(params.dtype).split(".")[1], } return meta grid = params.reshape((n_sidelen, n_sidelen, -1)) mins = torch.amin(grid, dim=(0, 1)) maxs = torch.amax(grid, dim=(0, 1)) grid_norm = (grid - mins) / (maxs - mins) img_norm = grid_norm.detach().cpu().numpy() img = (img_norm * (2**8 - 1)).round().astype(np.uint8) img = img.squeeze() imageio.imwrite(os.path.join(compress_dir, f"{param_name}.png"), img) meta = { "shape": list(params.shape), "dtype": str(params.dtype).split(".")[1], "mins": mins.tolist(), "maxs": maxs.tolist(), } return meta def _decompress_png(compress_dir: str, param_name: str, meta: Dict[str, Any]) -> Tensor: """Decompress parameters from PNG file. Args: compress_dir (str): compression directory param_name (str): parameter field name meta (Dict[str, Any]): metadata Returns: Tensor: parameters """ import imageio.v2 as imageio if not np.all(meta["shape"]): params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"])) return meta img = imageio.imread(os.path.join(compress_dir, f"{param_name}.png")) img_norm = img / (2**8 - 1) grid_norm = torch.tensor(img_norm) mins = torch.tensor(meta["mins"]) maxs = torch.tensor(meta["maxs"]) grid = grid_norm * (maxs - mins) + mins params = grid.reshape(meta["shape"]) params = params.to(dtype=getattr(torch, meta["dtype"])) return params def _compress_png_16bit( compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs ) -> Dict[str, Any]: """Compress parameters with 16-bit quantization and PNG compression. Args: compress_dir (str): compression directory param_name (str): parameter field name params (Tensor): parameters n_sidelen (int): image side length Returns: Dict[str, Any]: metadata """ import imageio.v2 as imageio if torch.numel == 0: meta = { "shape": list(params.shape), "dtype": str(params.dtype).split(".")[1], } return meta grid = params.reshape((n_sidelen, n_sidelen, -1)) mins = torch.amin(grid, dim=(0, 1)) maxs = torch.amax(grid, dim=(0, 1)) grid_norm = (grid - mins) / (maxs - mins) img_norm = grid_norm.detach().cpu().numpy() img = (img_norm * (2**16 - 1)).round().astype(np.uint16) img_l = img & 0xFF img_u = (img >> 8) & 0xFF imageio.imwrite( os.path.join(compress_dir, f"{param_name}_l.png"), img_l.astype(np.uint8) ) imageio.imwrite( os.path.join(compress_dir, f"{param_name}_u.png"), img_u.astype(np.uint8) ) meta = { "shape": list(params.shape), "dtype": str(params.dtype).split(".")[1], "mins": mins.tolist(), "maxs": maxs.tolist(), } return meta def _decompress_png_16bit( compress_dir: str, param_name: str, meta: Dict[str, Any] ) -> Tensor: """Decompress parameters from PNG files. Args: compress_dir (str): compression directory param_name (str): parameter field name meta (Dict[str, Any]): metadata Returns: Tensor: parameters """ import imageio.v2 as imageio if not np.all(meta["shape"]): params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"])) return meta img_l = imageio.imread(os.path.join(compress_dir, f"{param_name}_l.png")) img_u = imageio.imread(os.path.join(compress_dir, f"{param_name}_u.png")) img_u = img_u.astype(np.uint16) img = (img_u << 8) + img_l img_norm = img / (2**16 - 1) grid_norm = torch.tensor(img_norm) mins = torch.tensor(meta["mins"]) maxs = torch.tensor(meta["maxs"]) grid = grid_norm * (maxs - mins) + mins params = grid.reshape(meta["shape"]) params = params.to(dtype=getattr(torch, meta["dtype"])) return params def _compress_npz( compress_dir: str, param_name: str, params: Tensor, **kwargs ) -> Dict[str, Any]: """Compress parameters with numpy's NPZ compression.""" npz_dict = {"arr": params.detach().cpu().numpy()} save_fp = os.path.join(compress_dir, f"{param_name}.npz") os.makedirs(os.path.dirname(save_fp), exist_ok=True) np.savez_compressed(save_fp, **npz_dict) meta = { "shape": params.shape, "dtype": str(params.dtype).split(".")[1], } return meta def _decompress_npz(compress_dir: str, param_name: str, meta: Dict[str, Any]) -> Tensor: """Decompress parameters with numpy's NPZ compression.""" arr = np.load(os.path.join(compress_dir, f"{param_name}.npz"))["arr"] params = torch.tensor(arr) params = params.reshape(meta["shape"]) params = params.to(dtype=getattr(torch, meta["dtype"])) return params def _compress_kmeans( compress_dir: str, param_name: str, params: Tensor, n_clusters: int = 65536, quantization: int = 6, verbose: bool = True, **kwargs, ) -> Dict[str, Any]: """Run K-means clustering on parameters and save centroids and labels to a npz file. .. warning:: TorchPQ must installed to use K-means clustering. Args: compress_dir (str): compression directory param_name (str): parameter field name params (Tensor): parameters to compress n_clusters (int): number of K-means clusters quantization (int): number of bits in quantization verbose (bool, optional): Whether to print verbose information. Default to True. Returns: Dict[str, Any]: metadata """ try: from torchpq.clustering import KMeans except: raise ImportError( "Please install torchpq with 'pip install torchpq' to use K-means clustering" ) if torch.numel == 0: meta = { "shape": list(params.shape), "dtype": str(params.dtype).split(".")[1], } return meta kmeans = KMeans(n_clusters=n_clusters, distance="manhattan", verbose=verbose) x = params.reshape(params.shape[0], -1).permute(1, 0).contiguous() labels = kmeans.fit(x) labels = labels.detach().cpu().numpy() centroids = kmeans.centroids.permute(1, 0) mins = torch.min(centroids) maxs = torch.max(centroids) centroids_norm = (centroids - mins) / (maxs - mins) centroids_norm = centroids_norm.detach().cpu().numpy() centroids_quant = ( (centroids_norm * (2**quantization - 1)).round().astype(np.uint8) ) labels = labels.astype(np.uint16) npz_dict = { "centroids": centroids_quant, "labels": labels, } np.savez_compressed(os.path.join(compress_dir, f"{param_name}.npz"), **npz_dict) meta = { "shape": list(params.shape), "dtype": str(params.dtype).split(".")[1], "mins": mins.tolist(), "maxs": maxs.tolist(), "quantization": quantization, } return meta def _decompress_kmeans( compress_dir: str, param_name: str, meta: Dict[str, Any], **kwargs ) -> Tensor: """Decompress parameters from K-means compression. Args: compress_dir (str): compression directory param_name (str): parameter field name meta (Dict[str, Any]): metadata Returns: Tensor: parameters """ if not np.all(meta["shape"]): params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"])) return meta npz_dict = np.load(os.path.join(compress_dir, f"{param_name}.npz")) centroids_quant = npz_dict["centroids"] labels = npz_dict["labels"] centroids_norm = centroids_quant / (2 ** meta["quantization"] - 1) centroids_norm = torch.tensor(centroids_norm) mins = torch.tensor(meta["mins"]) maxs = torch.tensor(meta["maxs"]) centroids = centroids_norm * (maxs - mins) + mins params = centroids[labels] params = params.reshape(meta["shape"]) params = params.to(dtype=getattr(torch, meta["dtype"])) return params