Module diffq.uniform

Classic uniform quantization over n bits.

Source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Classic uniform quantization over n bits.
"""
from typing import Tuple
import torch

from .base import BaseQuantizer
from .utils import capture_init, simple_repr


def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
    """
    Quantize the given weights over `bits` bits.

    Returns:
        - quantized levels
        - (min, max) range.

    """
    assert (bits >= 1).all() and (bits <= 15).all()
    num_levels = (2 ** bits.float()).long()
    mn = p.min().item()
    mx = p.max().item()
    p = (p - mn) / (mx - mn)  # put p in [0, 1]
    unit = 1 / (num_levels - 1)  # quantization unit
    levels = (p / unit).round()
    if (bits <= 8).all():
        levels = levels.byte()
    else:
        levels = levels.short()
    return levels, (mn, mx)


def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float],
                       bits: torch.Tensor = torch.tensor(8.)):
    """
    Unquantize the weights from the levels and scale. Return a float32 tensor.
    """
    mn, mx = scales
    num_levels = 2 ** bits.float()
    unit = 1 / (num_levels - 1)
    levels = levels.float()
    p = levels * unit  # in [0, 1]
    return p * (mx - mn) + mn


class UniformQuantizer(BaseQuantizer):
    @capture_init
    def __init__(self, model: torch.nn.Module, bits: float = 8., min_size: float = 0.01,
                 float16: bool = False, qat: bool = False, exclude=[], detect_bound=True):
        """
        Args:
            model (torch.nn.Module): model to quantize
            bits (float): number of bits to quantize over.
            min_size (float): minimum size in MB of a parameter to be quantized.
            float16 (bool): if a layer is smaller than min_size, should it still be
                converted to float16?
            qat (bool): perform quantization-aware training (QAT).
            exclude (list[str]): list of patterns used to match parameters to exclude.
                For instance `['bias']` to exclude all bias terms.
            detect_bound (bool): if True, will detect bound parameters and reuse
                the same quantized tensor for both.
        """
        self.bits = float(bits)
        self.qat = qat

        super().__init__(model, min_size, float16, exclude, detect_bound)

    def __repr__(self):
        return simple_repr(self, )

    def _pre_forward_train(self):
        if self.qat:
            for qparam in self._qparams:
                if qparam.other is not None:
                    new_param = qparam.other.module._parameters[qparam.other.name]
                else:
                    quantized = self._quantize_param(qparam)
                    qvalue = self._unquantize_param(qparam, quantized)
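                    # Straight-through estimator: the forward pass uses the
                    # quantized value, while gradients flow to the float parameter.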
                    new_param = qparam.param + (qvalue - qparam.param).detach()
                qparam.module._parameters[qparam.name] = new_param
            return True
        return False

    def _post_forward_train(self):
        if self.qat:
            for qparam in self._qparams:
                qparam.module._parameters[qparam.name] = qparam.param
            return True
        return False

    def _quantize_param(self, qparam):
        levels, scales = uniform_quantize(qparam.param.data, torch.tensor(self.bits))
        return (levels, scales)

    def _unquantize_param(self, qparam, quantized):
        levels, scales = quantized
        return uniform_unquantize(levels, scales, torch.tensor(self.bits))

    def _bit_pack_param(self, qparam, quantized, pack_fn):
        levels, scales = quantized
        packed = pack_fn(levels, self.bits)
        return (packed, scales)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack bitpacked representation. Should be overriden
        """
        packed_levels, scales = packed
        levels = unpack_fn(
            packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
        return (levels, scales)

    def model_size(self):
        """
        Non-differentiable model size in MB.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            if qparam.other is None:  # if parameter is bound, count only one copy.
                subtotal += self.bits * qparam.param.numel() + 64  # 2 float for the overall scales
        subtotal /= 2**20 * 8  # bits to MegaBytes
        return total + subtotal

    def true_model_size(self):
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()

Functions

def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = tensor(8.))

Quantize the given weights over `bits` bits.

Returns

  • quantized levels
  • (min, max) range.
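
As a quick illustration (not part of the library source; values are made up), quantizing a random tensor to 8 bits yields integer levels stored as uint8, together with the (min, max) range needed to undo the scaling:

import torch
from diffq.uniform import uniform_quantize

w = torch.randn(4, 4)
levels, (mn, mx) = uniform_quantize(w, torch.tensor(8.))
print(levels.dtype)                          # torch.uint8, because bits <= 8
print(int(levels.min()), int(levels.max()))  # levels lie in [0, 255]
print(mn, mx)                                # original min/max of w, kept to reverse the mapping
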
def uniform_unquantize(levels: torch.Tensor, scales: Tuple[float, float], bits: torch.Tensor = tensor(8.))

Unquantize the weights from the levels and scale. Return a float32 tensor.

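A round trip through uniform_quantize and uniform_unquantize recovers the weights up to half a quantization step, i.e. (mx - mn) / (2 * (2**bits - 1)). The sketch below (illustrative values only) checks that bound:

import torch
from diffq.uniform import uniform_quantize, uniform_unquantize

w = torch.randn(1000)
bits = torch.tensor(8.)
levels, scales = uniform_quantize(w, bits)
w_hat = uniform_unquantize(levels, scales, bits)

mn, mx = scales
max_err = (w - w_hat).abs().max().item()
assert max_err <= (mx - mn) / (2 * (2 ** 8 - 1)) + 1e-6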

Classes

class UniformQuantizer (model: torch.nn.modules.module.Module, bits: float = 8.0, min_size: float = 0.01, float16: bool = False, qat: bool = False, exclude=[], detect_bound=True)

Args

model : torch.nn.Module
model to quantize
bits : float
number of bits to quantize over.
min_size : float
minimum size in MB of a parameter to be quantized.
float16 : bool
if a layer is smaller than min_size, should it still be converted to float16?
qat : bool
perform quantization-aware training (QAT).
exclude : list[str]
list of patterns used to match parameters to exclude. For instance ['bias'] to exclude all bias terms.
detect_bound : bool
if True, will detect bound parameters and reuse the same quantized tensor for both.
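
A minimal usage sketch on a toy model (the model and sizes are made up for illustration; it assumes the BaseQuantizer constructor hooks the model so that _pre_forward_train / _post_forward_train shown above run around each training forward):

import torch
from diffq.uniform import UniformQuantizer

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU(),
                            torch.nn.Linear(1024, 1024))
quantizer = UniformQuantizer(model, bits=4, qat=True)

# Train as usual: with qat=True the forward pass sees quantized weights,
# while gradients still reach the underlying float parameters.
x = torch.randn(8, 1024)
model(x).sum().backward()

print(quantizer.model_size())       # estimated size in MB
print(quantizer.true_model_size())  # true quantized size in MB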

Ancestors

diffq.base.BaseQuantizer

Methods

def model_size(self)

Non-differentiable model size in MB.

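As a worked instance of the formula above (illustrative numbers, not taken from any real model), a single 1024x1024 weight matrix quantized to 8 bits costs roughly a quarter of its float32 size:

numel = 1024 * 1024
bits = 8
size_mb = (bits * numel + 64) / (2 ** 20 * 8)  # + 64 bits for the (min, max) floats
print(round(size_mb, 3))                       # ~1.0 MB, versus 4.0 MB in float32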

Inherited members