Module diffq.diffq

Differentiable quantizer based on scaled noise injection.

Expand source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Differentiable quantizer based on scaled noise injection.
"""
from dataclasses import dataclass
import math
import typing as tp

import torch

from .base import BaseQuantizer
from .uniform import uniform_quantize, uniform_unquantize
from .utils import capture_init, simple_repr


class DiffQuantizer(BaseQuantizer):
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        logit: torch.nn.Parameter

    @capture_init
    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
                 param="bits", noise="gaussian",
                 init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
                 exclude: tp.List[str] = [], detect_bound: bool = True):
        """
        Differentiable quantizer based on scaled noise injection.
        For every parameter `p` in the model, this introduces a number of bits parameter
        `b` with the same dimensions (when group_size = 1).
        Before each forward, `p` is replaced by `p + U`
        with U uniform iid noise with range [-d/2, d/2], with `d` the uniform quantization
        step for `b` bits.
        This noise approximates the quantization noise in a differentiable manner, both
        with respect to the unquantized parameter `p` and the number of bits `b`.

        At evaluation (as detected with `model.eval()`), the model is replaced
        by its true quantized version, and restored when going back to training.

        When doing actual quantization (for serialization, or evaluation),
        the number of bits is rounded to the nearest integer, and needs to be stored as well.
        This will cost a few bits per dimension. To reduce this cost, one can use `group_size`,
        which will use a single noise level for multiple weight entries.

        You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the
        model size in MB. You can then use this estimate as a penalty in your training loss.

        Args:
            model (torch.nn.Module): model to quantize
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            group_size (int): weight entries are grouped together to reduce the number
                of noise scales to store. This should divide the size of all parameters
                bigger than min_size.
            min_bits (float): minimal number of bits.
            max_bits (float): maximal number of bits.
            init_bits (float): initial number of bits.
            extra_bits (float): extra bits to add for actual quantization (before roundoff).
            suffix (str): suffix used for the name of the extra noise scale parameters.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both, as well as the same number of bits.

        .. warning::
            You must call `model.train()` and `model.eval()` for `DiffQuantizer` to work properly.

        """
        self.group_size = group_size
        self.min_bits = min_bits
        self.max_bits = max_bits
        self.init_bits = init_bits
        self.extra_bits = extra_bits
        self.suffix = suffix
        self.param = param
        self.noise = noise
        assert noise in ["gaussian", "uniform"]
        self._optimizer_setup = False

        self._min_noise = 1 / (2 ** self.max_bits - 1)
        self._max_noise = 1 / (2 ** self.min_bits - 1)

        assert group_size >= 0
        assert min_bits < init_bits < max_bits, \
               "init_bits must be between min_bits and max_bits excluded3"

        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some noise scales parameters, "
                                   "maybe you used twice a DiffQuantizer on the same model?.")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _get_bits(self, logit: torch.Tensor):
        if self.param == "noise":
            return torch.log2(1 + 1 / self._get_noise_scale(logit))
        else:
            t = torch.sigmoid(logit)
            return self.max_bits * t + (1 - t) * self.min_bits

    def _get_noise_scale(self, logit: torch.Tensor):
        if self.param == "noise":
            t = torch.sigmoid(logit)
            return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
        else:
            return 1 / (2 ** self._get_bits(logit) - 1)

    def _register_param(self, name, param, module, other):
        if other is not None:
            return self.__class__._QuantizedParam(
               name=name, param=param, module=module, logit=other.logit, other=other)
        assert self.group_size == 0 or param.numel() % self.group_size == 0
        # we want the initial number of bits to be init_bits.
        if self.param == "noise":
            noise_scale = 1 / (2 ** self.init_bits - 1)
            t = (math.log(noise_scale) - math.log(self._max_noise)) / (
                math.log(self._min_noise) - math.log(self._max_noise))
        else:
            t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
        assert 0 < t < 1
        logit = torch.logit(torch.tensor(float(t)))
        assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
        if self.group_size > 0:
            nparam = param.numel() // self.group_size
        else:
            nparam = 1
        logit = torch.nn.Parameter(
            torch.full(
                (nparam,),
                logit,
                device=param.device))
        module.register_parameter(name + self.suffix, logit)
        return self.__class__._QuantizedParam(
           name=name, param=param, module=module, logit=logit, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            new_params = []
            for q in list(group["params"]):
                matched = False
                for p in params:
                    if p is q:
                        matched = True
                if not matched:
                    new_params.append(q)
            group["params"][:] = new_params

    def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                        lr: float = 1e-3, **kwargs):
        """
        Setup the optimizer to tune the number of bits. In particular, this will deactivate
        weight decay for the bits parameters.

        Args:
            optimizer (torch.Optimizer): optimizer to use.
            lr (float): specific learning rate for the bits parameters. 1e-3
                works well for Adam.
            kwargs (dict): overrides for other optimization parameters for the bits.
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True

        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            for q in list(group["params"]):
                for p in params:
                    if p is q:
                        raise RuntimeError("You should create the optimizer "
                                           "before the quantizer!")

        group = {"params": params, "lr": lr, "weight_decay": 0}
        group.update(kwargs)
        optimizer.add_param_group(group)

    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def check_unused(self):
        for qparam in self._qparams:
            if qparam.other is not None:
                continue
            grad = qparam.param.grad
            if grad is None or (grad == 0).all():
                if qparam.logit.grad is not None:
                    qparam.logit.grad.data.zero_()

    def model_size(self, exact=False):
        """
        Differentiable estimate of the model size.
        The size is returned in MB.

        If `exact` is True, then the output is no longer differentiable but
        reflects exactly an achievable size, even without compression,
        i.e. the same as returned by `naive_model_size()`.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # only count the first appearance of a Parameter
            if qparam.other is not None:
                continue
            bits = self.extra_bits + self._get_bits(qparam.logit)
            if exact:
                bits = bits.round().clamp(1, 15)
            if self.group_size == 0:
                group_size = qparam.param.numel()
            else:
                group_size = self.group_size
            subtotal += group_size * bits.sum()
            subtotal += 2 * 32  # param scale

            # Number of bits to represent each number of bits
            bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
            subtotal += 8  # 8 bits for bits_bits
            subtotal += bits_bits * bits.numel()

        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            if qparam.other is not None:
                noisy = qparam.other.module._parameters[qparam.other.name]
            else:
                bits = self._get_bits(qparam.logit)[:, None]
                if self.group_size == 0:
                    p_flat = qparam.param.view(-1)
                else:
                    p_flat = qparam.param.view(-1, self.group_size)
                scale = p_flat.max() - p_flat.min()
                unit = 1 / (2**bits - 1)
                if self.noise == "uniform":
                    noise_source = (torch.rand_like(p_flat) - 0.5)
                elif self.noise == "gaussian":
                    noise_source = torch.randn_like(p_flat) / 2
                noise = scale * unit * noise_source
                noisy = p_flat + noise
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param)
        return True

    def _post_forward_train(self):
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        bits = self.extra_bits + self._get_bits(qparam.logit)
        bits = bits.round().clamp(1, 15)[:, None].byte()
        if self.group_size == 0:
            p = qparam.param.data.view(1, -1)
        else:
            p = qparam.param.data.view(-1, self.group_size)
        levels, scales = uniform_quantize(p, bits)
        return levels, scales, bits[:, 0]

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        levels, param_scale, bits = quantized
        if bits.dim() == 1:
            bits = bits[:, None]
        return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data)

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        levels, scales, bits = quantized
        all_packed = []
        for bit in range(1, 15):
            sub_levels = levels[bits == bit]
            if not sub_levels.numel():
                all_packed.append(None)
            else:
                packed = pack_fn(sub_levels, bit)
                all_packed.append(packed)
        packed_bits = pack_fn(bits - self.min_bits)
        return (all_packed, scales, packed_bits)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack bitpacked representation. Should be overriden.
        """
        packed_all_levels, scales, packed_bits = packed
        bits = unpack_fn(packed_bits, qparam.logit.numel()) + self.min_bits
        bits = bits.to(qparam.param.device)
        levels = torch.empty(qparam.logit.numel(), self.group_size,
                             dtype=torch.short, device=qparam.param.device)
        for idx, packed_levels in enumerate(packed_all_levels):
            bit = idx + 1
            if packed_levels is None:
                continue
            sub_levels = levels[bits == bit]
            levels[bits == bit] = unpack_fn(
                packed_levels, sub_levels.numel()).view_as(sub_levels).to(sub_levels)
        return (levels, scales, bits)

    def detach(self):
        super().detach()
        for qparam in self._qparams:
            delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)

Classes

class DiffQuantizer (model: torch.nn.modules.module.Module, min_size: float = 0.01, float16: bool = False, group_size: int = 1, min_bits: float = 2, max_bits: float = 15, param='bits', noise='gaussian', init_bits: float = 8, extra_bits: float = 0, suffix: str = '_diffq', exclude: List[str] = [], detect_bound: bool = True)

Differentiable quantizer based on scaled noise injection. For every parameter p in the model, this introduces a number of bits parameter b with the same dimensions (when group_size = 1). Before each forward, p is replaced by p + U with U uniform iid noise with range [-d/2, d/2], with d the uniform quantization step for b bits. This noise approximates the quantization noise in a differentiable manner, both with respect to the unquantized parameter p and the number of bits b.
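
As a rough, self-contained illustration of that noise model (a sketch, not taken from the library source; all names below are placeholders), for a single ungrouped weight tensor with uniform noise:

import torch

p = torch.randn(128)                           # a weight tensor
b = torch.full_like(p, 8.0)                    # per-entry number of bits (group_size = 1)
scale = p.max() - p.min()                      # dynamic range used by the quantizer
d = scale / (2 ** b - 1)                       # uniform quantization step for b bits
noisy_p = p + d * (torch.rand_like(p) - 0.5)   # p + U, with U ~ Uniform[-d/2, d/2]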

At evaluation (as detected with model.eval()), the model is replaced by its true quantized version, and restored when going back to training.

When doing actual quantization (for serialization, or evaluation), the number of bits is rounded to the nearest integer, and needs to be stored as well. This will cost a few bits per dimension. To reduce this cost, one can use group_size, which will use a single noise level for multiple weight entries.

You can use the DiffQuantizer.model_size() method to get a differentiable estimate of the model size in MB. You can then use this estimate as a penalty in your training loss.
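
For instance, a minimal training-loop sketch along those lines (dummy model and data; the penalty weight is a placeholder you would tune for your task):

import torch
from diffq.diffq import DiffQuantizer

model = torch.nn.Sequential(
    torch.nn.Linear(64, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # build the optimizer first
quantizer = DiffQuantizer(model)
quantizer.setup_optimizer(optimizer)   # adds the bits parameters with their own learning rate

penalty = 1e-3        # hypothetical weight trading accuracy against model size
model.train()         # noise injection only happens in training mode
for step in range(100):
    x = torch.randn(32, 64)                     # dummy data standing in for a real batch
    y = torch.randint(0, 10, (32,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss = loss + penalty * quantizer.model_size()   # differentiable size estimate in MB
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()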

Args

model : torch.nn.Module
model to quantize
min_size : float
minimum size in MB of a parameter to be quantized.
float16 : bool
if a layer is smaller than min_size, should we still do float16?
group_size : int
weight entries are grouped together to reduce the number of noise scales to store. This should divide the size of all parameters bigger than min_size.
min_bits : float
minimal number of bits.
max_bits : float
maximal number of bits.
init_bits : float
initial number of bits.
extra_bits : float
extra bits to add for actual quantization (before roundoff).
suffix : str
suffix used for the name of the extra noise scale parameters.
exclude : list[str]
list of patterns used to match parameters to exclude. For instance ['bias'] to exclude all bias terms.
detect_bound : bool
if True, will detect bound parameters and reuse the same quantized tensor for both, as well as the same number of bits.

Warning

You must call model.train() and model.eval() for DiffQuantizer to work properly.
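
A short sketch of that contract, assuming model and quantizer were set up as in the previous example:

model.train()                    # forwards see weights with pseudo-quantization noise
_ = model(torch.randn(32, 64))

model.eval()                     # weights are swapped for their truly quantized values
with torch.no_grad():
    _ = model(torch.randn(32, 64))

model.train()                    # float weights are restored for further training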

Expand source code
class DiffQuantizer(BaseQuantizer):
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        logit: torch.nn.Parameter

    @capture_init
    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 group_size: int = 1, min_bits: float = 2, max_bits: float = 15,
                 param="bits", noise="gaussian",
                 init_bits: float = 8, extra_bits: float = 0, suffix: str = "_diffq",
                 exclude: tp.List[str] = [], detect_bound: bool = True):
        """
        Differentiable quantizer based on scaled noise injection.
        For every parameter `p` in the model, this introduces a number of bits parameter
        `b` with the same dimensions (when group_size = 1).
        Before each forward, `p` is replaced by `p + U`
        with U uniform iid noise with range [-d/2, d/2], with `d` the uniform quantization
        step for `b` bits.
        This noise approximates the quantization noise in a differentiable manner, both
        with respect to the unquantized parameter `p` and the number of bits `b`.

        At evaluation (as detected with `model.eval()`), the model is replaced
        by its true quantized version, and restored when going back to training.

        When doing actual quantization (for serialization, or evaluation),
        the number of bits is rounded to the nearest integer, and needs to be stored as well.
        This will cost a few bits per dimension. To reduce this cost, one can use `group_size`,
        which will use a single noise level for multiple weight entries.

        You can use the `DiffQuantizer.model_size` method to get a differentiable estimate of the
        model size in MB. You can then use this estimate as a penalty in your training loss.

        Args:
            model (torch.nn.Module): model to quantize
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should we still do float16?
            group_size (int): weight entries are grouped together to reduce the number
                of noise scales to store. This should divide the size of all parameters
                bigger than min_size.
            min_bits (float): minimal number of bits.
            max_bits (float): maximal number of bits.
            init_bits (float): initial number of bits.
            extra_bits (float): extra bits to add for actual quantization (before roundoff).
            suffix (str): suffix used for the name of the extra noise scale parameters.
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both, as well as the same number of bits.

        .. warning::
            You must call `model.train()` and `model.eval()` for `DiffQuantizer` to work properly.

        """
        self.group_size = group_size
        self.min_bits = min_bits
        self.max_bits = max_bits
        self.init_bits = init_bits
        self.extra_bits = extra_bits
        self.suffix = suffix
        self.param = param
        self.noise = noise
        assert noise in ["gaussian", "uniform"]
        self._optimizer_setup = False

        self._min_noise = 1 / (2 ** self.max_bits - 1)
        self._max_noise = 1 / (2 ** self.min_bits - 1)

        assert group_size >= 0
        assert min_bits < init_bits < max_bits, \
               "init_bits must be between min_bits and max_bits excluded3"

        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some noise scales parameters, "
                                   "maybe you used twice a DiffQuantizer on the same model?.")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _get_bits(self, logit: torch.Tensor):
        if self.param == "noise":
            return torch.log2(1 + 1 / self._get_noise_scale(logit))
        else:
            t = torch.sigmoid(logit)
            return self.max_bits * t + (1 - t) * self.min_bits

    def _get_noise_scale(self, logit: torch.Tensor):
        if self.param == "noise":
            t = torch.sigmoid(logit)
            return torch.exp(t * math.log(self._min_noise) + (1 - t) * math.log(self._max_noise))
        else:
            return 1 / (2 ** self._get_bits(logit) - 1)

    def _register_param(self, name, param, module, other):
        if other is not None:
            return self.__class__._QuantizedParam(
               name=name, param=param, module=module, logit=other.logit, other=other)
        assert self.group_size == 0 or param.numel() % self.group_size == 0
        # we want the initial number of bits to be init_bits.
        if self.param == "noise":
            noise_scale = 1 / (2 ** self.init_bits - 1)
            t = (math.log(noise_scale) - math.log(self._max_noise)) / (
                math.log(self._min_noise) - math.log(self._max_noise))
        else:
            t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
        assert 0 < t < 1
        logit = torch.logit(torch.tensor(float(t)))
        assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
        if self.group_size > 0:
            nparam = param.numel() // self.group_size
        else:
            nparam = 1
        logit = torch.nn.Parameter(
            torch.full(
                (nparam,),
                logit,
                device=param.device))
        module.register_parameter(name + self.suffix, logit)
        return self.__class__._QuantizedParam(
           name=name, param=param, module=module, logit=logit, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            new_params = []
            for q in list(group["params"]):
                matched = False
                for p in params:
                    if p is q:
                        matched = True
                if not matched:
                    new_params.append(q)
            group["params"][:] = new_params

    def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                        lr: float = 1e-3, **kwargs):
        """
        Setup the optimizer to tune the number of bits. In particular, this will deactivate
        weight decay for the bits parameters.

        Args:
            optimizer (torch.Optimizer): optimizer to use.
            lr (float): specific learning rate for the bits parameters. 1e-3
                works well for Adam.
            kwargs (dict): overrides for other optimization parameters for the bits.
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True

        params = [qp.logit for qp in self._qparams]

        for group in optimizer.param_groups:
            for q in list(group["params"]):
                for p in params:
                    if p is q:
                        raise RuntimeError("You should create the optimizer "
                                           "before the quantizer!")

        group = {"params": params, "lr": lr, "weight_decay": 0}
        group.update(kwargs)
        optimizer.add_param_group(group)

    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def check_unused(self):
        for qparam in self._qparams:
            if qparam.other is not None:
                continue
            grad = qparam.param.grad
            if grad is None or (grad == 0).all():
                if qparam.logit.grad is not None:
                    qparam.logit.grad.data.zero_()

    def model_size(self, exact=False):
        """
        Differentiable estimate of the model size.
        The size is returned in MB.

        If `exact` is True, then the output is no longer differentiable but
        reflects exactly an achievable size, even without compression,
        i.e. the same as returned by `naive_model_size()`.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # only count the first appearance of a Parameter
            if qparam.other is not None:
                continue
            bits = self.extra_bits + self._get_bits(qparam.logit)
            if exact:
                bits = bits.round().clamp(1, 15)
            if self.group_size == 0:
                group_size = qparam.param.numel()
            else:
                group_size = self.group_size
            subtotal += group_size * bits.sum()
            subtotal += 2 * 32  # param scale

            # Number of bits to represent each number of bits
            bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
            subtotal += 8  # 8 bits for bits_bits
            subtotal += bits_bits * bits.numel()

        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            if qparam.other is not None:
                noisy = qparam.other.module._parameters[qparam.other.name]
            else:
                bits = self._get_bits(qparam.logit)[:, None]
                if self.group_size == 0:
                    p_flat = qparam.param.view(-1)
                else:
                    p_flat = qparam.param.view(-1, self.group_size)
                scale = p_flat.max() - p_flat.min()
                unit = 1 / (2**bits - 1)
                if self.noise == "uniform":
                    noise_source = (torch.rand_like(p_flat) - 0.5)
                elif self.noise == "gaussian":
                    noise_source = torch.randn_like(p_flat) / 2
                noise = scale * unit * noise_source
                noisy = p_flat + noise
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = noisy.view_as(qparam.param)
        return True

    def _post_forward_train(self):
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        bits = self.extra_bits + self._get_bits(qparam.logit)
        bits = bits.round().clamp(1, 15)[:, None].byte()
        if self.group_size == 0:
            p = qparam.param.data.view(1, -1)
        else:
            p = qparam.param.data.view(-1, self.group_size)
        levels, scales = uniform_quantize(p, bits)
        return levels, scales, bits[:, 0]

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        levels, param_scale, bits = quantized
        if bits.dim() == 1:
            bits = bits[:, None]
        return uniform_unquantize(levels, param_scale, bits).view_as(qparam.param.data)

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        levels, scales, bits = quantized
        all_packed = []
        for bit in range(1, 15):
            sub_levels = levels[bits == bit]
            if not sub_levels.numel():
                all_packed.append(None)
            else:
                packed = pack_fn(sub_levels, bit)
                all_packed.append(packed)
        packed_bits = pack_fn(bits - self.min_bits)
        return (all_packed, scales, packed_bits)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack bitpacked representation. Should be overriden.
        """
        packed_all_levels, scales, packed_bits = packed
        bits = unpack_fn(packed_bits, qparam.logit.numel()) + self.min_bits
        bits = bits.to(qparam.param.device)
        levels = torch.empty(qparam.logit.numel(), self.group_size,
                             dtype=torch.short, device=qparam.param.device)
        for idx, packed_levels in enumerate(packed_all_levels):
            bit = idx + 1
            if packed_levels is None:
                continue
            sub_levels = levels[bits == bit]
            levels[bits == bit] = unpack_fn(
                packed_levels, sub_levels.numel()).view_as(sub_levels).to(sub_levels)
        return (levels, scales, bits)

    def detach(self):
        super().detach()
        for qparam in self._qparams:
            delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)

Ancestors

diffq.base.BaseQuantizer

Methods

def check_unused(self)
Expand source code
def check_unused(self):
    for qparam in self._qparams:
        if qparam.other is not None:
            continue
        grad = qparam.param.grad
        if grad is None or (grad == 0).all():
            if qparam.logit.grad is not None:
                qparam.logit.grad.data.zero_()
def clear_optimizer(self, optimizer: torch.optim.optimizer.Optimizer)
Expand source code
def clear_optimizer(self, optimizer: torch.optim.Optimizer):
    params = [qp.logit for qp in self._qparams]

    for group in optimizer.param_groups:
        new_params = []
        for q in list(group["params"]):
            matched = False
            for p in params:
                if p is q:
                    matched = True
            if not matched:
                new_params.append(q)
        group["params"][:] = new_params
def model_size(self, exact=False)

Differentiable estimate of the model size. The size is returned in MB.

If exact is True, then the output is no longer differentiable but reflects exactly an achievable size, even without compression, i.e. the same as returned by naive_model_size().
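
For example, with a quantizer set up as in the class-level sketch above (names are placeholders), the two flavours compare as:

soft_mb = quantizer.model_size()         # differentiable torch tensor, usable as a loss penalty
exact_mb = quantizer.true_model_size()   # float: bits rounded, size achievable without compression
print(f"soft estimate: {soft_mb.item():.3f} MB / exact: {exact_mb:.3f} MB")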

Expand source code
def model_size(self, exact=False):
    """
    Differentiable estimate of the model size.
    The size is returned in MB.

    If `exact` is True, then the output is no longer differentiable but
    reflects exactly an achievable size, even without compression,
    i.e. the same as returned by `naive_model_size()`.
    """
    total = super().model_size()
    subtotal = 0
    for qparam in self._qparams:
        # only count the first appearance of a Parameter
        if qparam.other is not None:
            continue
        bits = self.extra_bits + self._get_bits(qparam.logit)
        if exact:
            bits = bits.round().clamp(1, 15)
        if self.group_size == 0:
            group_size = qparam.param.numel()
        else:
            group_size = self.group_size
        subtotal += group_size * bits.sum()
        subtotal += 2 * 32  # param scale

        # Number of bits to represent each number of bits
        bits_bits = math.ceil(math.log2(1 + (bits.max().round().item() - self.min_bits)))
        subtotal += 8  # 8 bits for bits_bits
        subtotal += bits_bits * bits.numel()

    subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
    return total + subtotal
def no_optimizer(self)

Call this if you do not want to use an optimizer.

Expand source code
def no_optimizer(self):
    """
    Call this if you do not want to use an optimizer.
    """
    self._optimizer_setup = True
def setup_optimizer(self, optimizer: torch.optim.optimizer.Optimizer, lr: float = 0.001, **kwargs)

Setup the optimizer to tune the number of bits. In particular, this will deactivate weight decay for the bits parameters.
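
A minimal sketch of the expected call order (model is assumed to be an existing torch.nn.Module, with imports as in the earlier example): the optimizer must be built from the model parameters before the quantizer is created, otherwise setup_optimizer() raises because the bits parameters would already be registered in it.

optimizer = torch.optim.Adam(model.parameters())
quantizer = DiffQuantizer(model)
quantizer.setup_optimizer(optimizer, lr=1e-3)   # bits parameters get their own lr, no weight decay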

Args

optimizer : torch.Optimizer
optimizer to use.
lr : float
specific learning rate for the bits parameters. 1e-3 works well for Adam.
kwargs : dict
overrides for other optimization parameters for the bits.
Expand source code
def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                    lr: float = 1e-3, **kwargs):
    """
    Setup the optimizer to tune the number of bits. In particular, this will deactivate
    weight decay for the bits parameters.

    Args:
        optimizer (torch.Optimizer): optimizer to use.
        lr (float): specific learning rate for the bits parameters. 1e-3
            works well for Adam.
        kwargs (dict): overrides for other optimization parameters for the bits.
    """
    assert not self._optimizer_setup
    self._optimizer_setup = True

    params = [qp.logit for qp in self._qparams]

    for group in optimizer.param_groups:
        for q in list(group["params"]):
            for p in params:
                if p is q:
                    raise RuntimeError("You should create the optimizer "
                                       "before the quantizer!")

    group = {"params": params, "lr": lr, "weight_decay": 0}
    group.update(kwargs)
    optimizer.add_param_group(group)
def true_model_size(self)

Naive model size without zlib compression.

Expand source code
def true_model_size(self):
    """
    Naive model size without zlib compression.
    """
    return self.model_size(exact=True).item()

Inherited members