Module diffq.lsq

Learnt-Stepsize quantizer from [Esser et al. 2019] https://arxiv.org/abs/1902.08153.

Expand source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Learnt-Stepsize quantizer from [Esser et al. 2019] https://arxiv.org/abs/1902.08153.
"""
from dataclasses import dataclass
import typing as tp

import torch

from .base import BaseQuantizer
from .utils import capture_init, simple_repr


class LSQ(BaseQuantizer):
    """Implements weight only quantization based on [Esser et al. 2019].
    https://arxiv.org/abs/1902.08153

    Each quantized parameter gets a learnt scalar step size (`scale`), registered
    on its module as ``<param name><suffix>``. A parameter shared by several
    modules reuses the scale created for its first occurrence.
    """
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        # Learnt step size. For a shared parameter, every occurrence holds the
        # same Parameter object.
        scale: torch.nn.Parameter

    @capture_init
    def __init__(self, model: torch.nn.Module, bits: int = 8, min_size: float = 0.01,
                 float16: bool = False, suffix: str = "_lsq", exclude=[], detect_bound=True):
        """
        Args:
            model: model whose parameters get quantized.
            bits: number of bits per quantized weight, in ``[1, 15]``.
            min_size, float16, exclude, detect_bound: forwarded unchanged to
                ``BaseQuantizer.__init__``.
            suffix: appended to a parameter's name to register its scale.

        Raises:
            RuntimeError: if `model` already has a parameter ending with `suffix`.
        """
        assert 0 < bits <= 15
        # Set before calling the base constructor; presumably it calls
        # `_register_param`, which reads `bits` and `suffix`.
        self.suffix = suffix
        self._optimizer_setup = False
        self.bits = bits

        # A parameter already ending with `suffix` means LSQ was applied before.
        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some noise scales parameters, "
                                   "maybe you used twice a LSQ on the same model?.")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _register_param(self, name, param, module, other):
        """Create the quantization state for `param` (attribute `name` of `module`).

        If `other` is not None, `param` is shared with an already registered
        parameter and reuses its scale; nothing new is registered. Otherwise a
        new scalar scale Parameter is registered on `module` as
        ``name + self.suffix``.
        """
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, scale=other.scale, other=other)
        # Initial step size 2 * mean(|w|) / sqrt(2 ** (bits - 1)). Esser et al.
        # use sqrt(Q_P) with Q_P = 2 ** (bits - 1) - 1; this is a close variant.
        # NOTE(review): an all-zero `param` yields scale == 0, which makes
        # `quantize` divide by zero; confirm such params are excluded.
        scale = 2 * param.data.abs().mean() / (2 ** (self.bits - 1))**0.5
        scale = torch.nn.Parameter(scale)
        module.register_parameter(name + self.suffix, scale)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, scale=scale, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        """Remove this quantizer's scale parameters from every param group of
        `optimizer`, in place. The remaining parameters keep their order.
        """
        # Match by identity, as an id-set instead of a nested scan.
        scale_ids = {id(qp.scale) for qp in self._qparams}
        for group in optimizer.param_groups:
            group["params"][:] = [q for q in group["params"] if id(q) not in scale_ids]

    def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs):
        """
        Set up the optimizer to tune the scale parameters.
        Following [Esser et al. 2019], we use the same LR and weight decay
        as the base optimizer, unless specified otherwise.

        Args:
            optimizer (torch.Optimizer): optimizer to use.
            kwargs (dict): overrides for optimization parameters

        Raises:
            RuntimeError: if a scale is already part of `optimizer`, i.e. the
                optimizer was created after the quantizer.
        """
        assert not self._optimizer_setup

        # Shared parameters reuse the same scale Parameter. Add each one only
        # once: PyTorch warns about duplicate parameters within a param group.
        params = []
        scale_ids = set()
        for qp in self._qparams:
            if id(qp.scale) not in scale_ids:
                scale_ids.add(id(qp.scale))
                params.append(qp.scale)

        for group in optimizer.param_groups:
            if any(id(q) in scale_ids for q in group["params"]):
                raise RuntimeError("You should create the optimizer "
                                   "before the quantizer!")

        group = {"params": params}
        group.update(kwargs)
        optimizer.add_param_group(group)
        # Only mark as done once the group was actually added, so a failed
        # call can be retried.
        self._optimizer_setup = True

    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def model_size(self, exact=False):
        """
        Estimate of the model size, returned in MB.

        The result is ``super().model_size()`` plus, for each quantized
        parameter (shared parameters counted once), `self.bits` bits per weight
        and 32 bits for its scale. Since the bit width is fixed, this is always
        exact; `exact` is accepted for interface compatibility and has no effect.
        `true_model_size()` returns the same value as a Python float.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # Only count the first appearance of a shared Parameter.
            if qparam.other is not None:
                continue
            subtotal += qparam.param.numel() * self.bits  # quantized weights
            subtotal += 32  # the scale
        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        # NOTE(review): `.item()` assumes `super().model_size()` returns a tensor.
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        """Replace each quantized parameter by its fake-quantized value
        (``index * scale`` from `quantize`) before a training forward pass.

        Raises:
            RuntimeError: if neither `setup_optimizer()` nor `no_optimizer()`
                was called.
        """
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            scale = qparam.scale
            quant, _ = quantize(qparam.param, scale, self.bits)
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = quant
        return True

    def _post_forward_train(self):
        """Put back the original Parameters replaced in `_pre_forward_train`."""
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        """Return ``(indices as int16, detached scale)`` for `qparam`."""
        _, index = quantize(qparam.param, qparam.scale, self.bits)
        assert (index <= (2 ** (self.bits - 1) - 1)).all(), index.max()
        assert (index >= (-2 ** (self.bits - 1))).all(), index.min()
        return index.detach().short(), qparam.scale.detach()

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        """Rebuild float weights from the ``(indices, scale)`` pair."""
        index, scale = quantized
        return index.float() * scale

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        """Bit-pack the indices with `pack_fn`, keeping the scale as is.

        Indices lie in ``[-2 ** (bits - 1), 2 ** (bits - 1) - 1]`` and are
        shifted into ``[0, 2 ** bits - 1]`` before packing.
        """
        levels, scale = quantized
        packed = pack_fn(levels + 2 ** (self.bits - 1))
        return (packed, scale)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Inverse of `_bit_pack_param`: unpack the indices with `unpack_fn`,
        reshape them like `qparam.param` and undo the shift.
        """
        packed_levels, scale = packed
        levels = unpack_fn(
            packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
        # NOTE(review): in-place shift assumes `unpack_fn` returns a signed dtype.
        levels -= 2 ** (self.bits - 1)
        return (levels, scale)

    def detach(self):
        """Detach the quantizer and delete the scale Parameters it registered."""
        super().detach()
        for qparam in self._qparams:
            # Only the first occurrence of a shared parameter registered a scale
            # (see `_register_param`); later occurrences have nothing to delete.
            if qparam.other is None:
                delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)


def roundpass(x):
    """Round `x` in the forward pass, pass the gradient straight through.

    The output equals ``x.round()``, but only the undetached ``x`` term carries
    gradient, so d(out)/dx == 1.
    """
    rounding_offset = x.round() - x
    return x + rounding_offset.detach()


def gradscale(x, scale):
    """Return a tensor equal to `x` whose gradient is multiplied by `scale`."""
    scaled = x * scale
    return scaled + (x - scaled).detach()


def quantize(tensor, scale, bits):
    """Quantize `tensor` with step size `scale` onto signed `bits`-bit levels.

    Forward: ``tensor / scale`` is clamped to
    ``[-2 ** (bits - 1), 2 ** (bits - 1) - 1]`` and rounded.
    Backward: rounding uses a straight-through estimator (`roundpass`), and the
    gradient w.r.t. `scale` is multiplied by ``1 / sqrt(numel * Q_P)`` as in
    [Esser et al. 2019], with ``Q_P = 2 ** (bits - 1) - 1``.

    Returns:
        ``(index * scale, index)``: the de-quantized tensor and the rounded indices.
    """
    low = - 2 ** (bits - 1)
    high = 2 ** (bits - 1) - 1
    # For bits == 1, Q_P == 0 and the paper's factor divides by zero; likewise
    # for an empty tensor. Use 1 in those cases; unchanged for bits >= 2.
    grad_factor = 1 / (max(tensor.numel(), 1) * max(high, 1))**0.5
    scale = gradscale(scale, grad_factor)

    index = tensor / scale
    index = index.clamp(low, high)
    index = roundpass(index)
    return index * scale, index

Functions

def gradscale(x, scale)
Expand source code
def gradscale(x, scale):
    """Return a tensor equal to `x` whose gradient is multiplied by `scale`."""
    scaled = x * scale
    return scaled + (x - scaled).detach()
def quantize(tensor, scale, bits)
Expand source code
def quantize(tensor, scale, bits):
    """Quantize `tensor` with step size `scale` onto signed `bits`-bit levels.

    Forward: ``tensor / scale`` is clamped to
    ``[-2 ** (bits - 1), 2 ** (bits - 1) - 1]`` and rounded.
    Backward: rounding uses a straight-through estimator (`roundpass`), and the
    gradient w.r.t. `scale` is multiplied by ``1 / sqrt(numel * Q_P)`` as in
    [Esser et al. 2019], with ``Q_P = 2 ** (bits - 1) - 1``.

    Returns:
        ``(index * scale, index)``: the de-quantized tensor and the rounded indices.
    """
    low = - 2 ** (bits - 1)
    high = 2 ** (bits - 1) - 1
    # For bits == 1, Q_P == 0 and the paper's factor divides by zero; likewise
    # for an empty tensor. Use 1 in those cases; unchanged for bits >= 2.
    grad_factor = 1 / (max(tensor.numel(), 1) * max(high, 1))**0.5
    scale = gradscale(scale, grad_factor)

    index = tensor / scale
    index = index.clamp(low, high)
    index = roundpass(index)
    return index * scale, index
def roundpass(x)
Expand source code
def roundpass(x):
    """Round `x` in the forward pass, pass the gradient straight through.

    The output equals ``x.round()``, but only the undetached ``x`` term carries
    gradient, so d(out)/dx == 1.
    """
    rounding_offset = x.round() - x
    return x + rounding_offset.detach()

Classes

class LSQ (model: torch.nn.modules.module.Module, bits: int = 8, min_size: float = 0.01, float16: bool = False, suffix: str = '_lsq', exclude=[], detect_bound=True)

Implements weight-only quantization based on [Esser et al. 2019] (https://arxiv.org/abs/1902.08153).

Expand source code
class LSQ(BaseQuantizer):
    """Implements weight only quantization based on [Esser et al. 2019].
    https://arxiv.org/abs/1902.08153

    Each quantized parameter gets a learnt scalar step size (`scale`), registered
    on its module as ``<param name><suffix>``. A parameter shared by several
    modules reuses the scale created for its first occurrence.
    """
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        # Learnt step size. For a shared parameter, every occurrence holds the
        # same Parameter object.
        scale: torch.nn.Parameter

    @capture_init
    def __init__(self, model: torch.nn.Module, bits: int = 8, min_size: float = 0.01,
                 float16: bool = False, suffix: str = "_lsq", exclude=[], detect_bound=True):
        """
        Args:
            model: model whose parameters get quantized.
            bits: number of bits per quantized weight, in ``[1, 15]``.
            min_size, float16, exclude, detect_bound: forwarded unchanged to
                ``BaseQuantizer.__init__``.
            suffix: appended to a parameter's name to register its scale.

        Raises:
            RuntimeError: if `model` already has a parameter ending with `suffix`.
        """
        assert 0 < bits <= 15
        # Set before calling the base constructor; presumably it calls
        # `_register_param`, which reads `bits` and `suffix`.
        self.suffix = suffix
        self._optimizer_setup = False
        self.bits = bits

        # A parameter already ending with `suffix` means LSQ was applied before.
        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some noise scales parameters, "
                                   "maybe you used twice a LSQ on the same model?.")

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _register_param(self, name, param, module, other):
        """Create the quantization state for `param` (attribute `name` of `module`).

        If `other` is not None, `param` is shared with an already registered
        parameter and reuses its scale; nothing new is registered. Otherwise a
        new scalar scale Parameter is registered on `module` as
        ``name + self.suffix``.
        """
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, scale=other.scale, other=other)
        # Initial step size 2 * mean(|w|) / sqrt(2 ** (bits - 1)). Esser et al.
        # use sqrt(Q_P) with Q_P = 2 ** (bits - 1) - 1; this is a close variant.
        # NOTE(review): an all-zero `param` yields scale == 0, which makes
        # `quantize` divide by zero; confirm such params are excluded.
        scale = 2 * param.data.abs().mean() / (2 ** (self.bits - 1))**0.5
        scale = torch.nn.Parameter(scale)
        module.register_parameter(name + self.suffix, scale)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, scale=scale, other=None)

    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        """Remove this quantizer's scale parameters from every param group of
        `optimizer`, in place. The remaining parameters keep their order.
        """
        # Match by identity, as an id-set instead of a nested scan.
        scale_ids = {id(qp.scale) for qp in self._qparams}
        for group in optimizer.param_groups:
            group["params"][:] = [q for q in group["params"] if id(q) not in scale_ids]

    def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs):
        """
        Set up the optimizer to tune the scale parameters.
        Following [Esser et al. 2019], we use the same LR and weight decay
        as the base optimizer, unless specified otherwise.

        Args:
            optimizer (torch.Optimizer): optimizer to use.
            kwargs (dict): overrides for optimization parameters

        Raises:
            RuntimeError: if a scale is already part of `optimizer`, i.e. the
                optimizer was created after the quantizer.
        """
        assert not self._optimizer_setup

        # Shared parameters reuse the same scale Parameter. Add each one only
        # once: PyTorch warns about duplicate parameters within a param group.
        params = []
        scale_ids = set()
        for qp in self._qparams:
            if id(qp.scale) not in scale_ids:
                scale_ids.add(id(qp.scale))
                params.append(qp.scale)

        for group in optimizer.param_groups:
            if any(id(q) in scale_ids for q in group["params"]):
                raise RuntimeError("You should create the optimizer "
                                   "before the quantizer!")

        group = {"params": params}
        group.update(kwargs)
        optimizer.add_param_group(group)
        # Only mark as done once the group was actually added, so a failed
        # call can be retried.
        self._optimizer_setup = True

    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True

    def model_size(self, exact=False):
        """
        Estimate of the model size, returned in MB.

        The result is ``super().model_size()`` plus, for each quantized
        parameter (shared parameters counted once), `self.bits` bits per weight
        and 32 bits for its scale. Since the bit width is fixed, this is always
        exact; `exact` is accepted for interface compatibility and has no effect.
        `true_model_size()` returns the same value as a Python float.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # Only count the first appearance of a shared Parameter.
            if qparam.other is not None:
                continue
            subtotal += qparam.param.numel() * self.bits  # quantized weights
            subtotal += 32  # the scale
        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        # NOTE(review): `.item()` assumes `super().model_size()` returns a tensor.
        return self.model_size(exact=True).item()

    def _pre_forward_train(self):
        """Replace each quantized parameter by its fake-quantized value
        (``index * scale`` from `quantize`) before a training forward pass.

        Raises:
            RuntimeError: if neither `setup_optimizer()` nor `no_optimizer()`
                was called.
        """
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            scale = qparam.scale
            quant, _ = quantize(qparam.param, scale, self.bits)
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = quant
        return True

    def _post_forward_train(self):
        """Put back the original Parameters replaced in `_pre_forward_train`."""
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        """Return ``(indices as int16, detached scale)`` for `qparam`."""
        _, index = quantize(qparam.param, qparam.scale, self.bits)
        assert (index <= (2 ** (self.bits - 1) - 1)).all(), index.max()
        assert (index >= (-2 ** (self.bits - 1))).all(), index.min()
        return index.detach().short(), qparam.scale.detach()

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        """Rebuild float weights from the ``(indices, scale)`` pair."""
        index, scale = quantized
        return index.float() * scale

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        """Bit-pack the indices with `pack_fn`, keeping the scale as is.

        Indices lie in ``[-2 ** (bits - 1), 2 ** (bits - 1) - 1]`` and are
        shifted into ``[0, 2 ** bits - 1]`` before packing.
        """
        levels, scale = quantized
        packed = pack_fn(levels + 2 ** (self.bits - 1))
        return (packed, scale)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Inverse of `_bit_pack_param`: unpack the indices with `unpack_fn`,
        reshape them like `qparam.param` and undo the shift.
        """
        packed_levels, scale = packed
        levels = unpack_fn(
            packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
        # NOTE(review): in-place shift assumes `unpack_fn` returns a signed dtype.
        levels -= 2 ** (self.bits - 1)
        return (levels, scale)

    def detach(self):
        """Detach the quantizer and delete the scale Parameters it registered."""
        super().detach()
        for qparam in self._qparams:
            # Only the first occurrence of a shared parameter registered a scale
            # (see `_register_param`); later occurrences have nothing to delete.
            if qparam.other is None:
                delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)

Ancestors

diffq.base.BaseQuantizer

Methods

def clear_optimizer(self, optimizer: torch.optim.optimizer.Optimizer)
Expand source code
def clear_optimizer(self, optimizer: torch.optim.Optimizer):
    """Remove this quantizer's scale parameters from every param group of
    `optimizer`, in place. The remaining parameters keep their order.
    """
    # Match by identity, as an id-set instead of a nested scan.
    scale_ids = {id(qp.scale) for qp in self._qparams}
    for group in optimizer.param_groups:
        group["params"][:] = [q for q in group["params"] if id(q) not in scale_ids]
def model_size(self, exact=False)

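Estimate of the model size, returned in MB. It is the size reported by the base quantizer plus, for each quantized parameter (shared parameters counted once), `bits` bits per weight and 32 bits for its scale.

The `exact` argument has no effect for LSQ: because the bit width is fixed, the estimate is always exact. `true_model_size()` returns the same value as a Python float.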

Expand source code
def model_size(self, exact=False):
    """
    Estimate of the model size, returned in MB.

    The result is ``super().model_size()`` plus, for each quantized
    parameter (shared parameters counted once), `self.bits` bits per weight
    and 32 bits for its scale. Since the bit width is fixed, this is always
    exact; `exact` is accepted for interface compatibility and has no effect.
    `true_model_size()` returns the same value as a Python float.
    """
    total = super().model_size()
    subtotal = 0
    for qparam in self._qparams:
        # Only count the first appearance of a shared Parameter.
        if qparam.other is not None:
            continue
        subtotal += qparam.param.numel() * self.bits  # quantized weights
        subtotal += 32  # the scale
    subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
    return total + subtotal
def no_optimizer(self)

Call this if you do not want to use an optimizer.

Expand source code
def no_optimizer(self):
    """
    Mark the optimizer as set up without registering any scale parameters.

    Use this when the scales should not be trained; it satisfies the check
    done before training forward passes.
    """
    setattr(self, "_optimizer_setup", True)
def setup_optimizer(self, optimizer: torch.optim.optimizer.Optimizer, **kwargs)

Set up the optimizer to tune the scale parameters. Following [Esser et al. 2019], we use the same LR and weight decay as the base optimizer, unless specified otherwise.

Args

optimizer : torch.Optimizer
optimizer to use.
kwargs : dict
overrides for optimization parameters
Expand source code
def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs):
    """
    Set up the optimizer to tune the scale parameters.
    Following [Esser et al. 2019], we use the same LR and weight decay
    as the base optimizer, unless specified otherwise.

    Args:
        optimizer (torch.Optimizer): optimizer to use.
        kwargs (dict): overrides for optimization parameters

    Raises:
        RuntimeError: if a scale is already part of `optimizer`, i.e. the
            optimizer was created after the quantizer.
    """
    assert not self._optimizer_setup

    # Shared parameters reuse the same scale Parameter. Add each one only
    # once: PyTorch warns about duplicate parameters within a param group.
    params = []
    scale_ids = set()
    for qp in self._qparams:
        if id(qp.scale) not in scale_ids:
            scale_ids.add(id(qp.scale))
            params.append(qp.scale)

    for group in optimizer.param_groups:
        if any(id(q) in scale_ids for q in group["params"]):
            raise RuntimeError("You should create the optimizer "
                               "before the quantizer!")

    group = {"params": params}
    group.update(kwargs)
    optimizer.add_param_group(group)
    # Only mark as done once the group was actually added, so a failed
    # call can be retried.
    self._optimizer_setup = True
def true_model_size(self)

Naive model size without zlib compression.

Expand source code
def true_model_size(self):
    """
    Return the exact model size in MB as a Python float, without zlib compression.
    """
    size_mb = self.model_size(exact=True)
    return size_mb.item()

Inherited members