Module diffq.lsq
Learnt-Stepsize quantizer from [Esser et al. 2019] https://arxiv.org/abs/1902.08153.
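A minimal usage sketch (the model, batch and hyper-parameters below are placeholders, not part of diffq): create the optimizer first, then the quantizer, then register the learnt step sizes with `setup_optimizer()`. During training forwards, the `_pre_forward_train` hook transparently swaps each weight for its quantized version.

import torch
from diffq.lsq import LSQ

model = torch.nn.Linear(64, 64)  # stand-in for a real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

quantizer = LSQ(model, bits=4)
quantizer.setup_optimizer(optimizer)

x = torch.randn(8, 64)  # stand-in batch
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(quantizer.true_model_size())  # exact size of the quantized model, in MB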
Source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Learnt-Stepsize quantizer from [Esser et al. 2019] https://arxiv.org/abs/1902.08153.
"""
from dataclasses import dataclass
import typing as tp
import torch
from .base import BaseQuantizer
from .utils import capture_init, simple_repr
class LSQ(BaseQuantizer):
    """Implements weight-only quantization based on [Esser et al. 2019].
    https://arxiv.org/abs/1902.08153
    """
    @dataclass
    class _QuantizedParam(BaseQuantizer._QuantizedParam):
        scale: torch.nn.Parameter

    @capture_init
    def __init__(self, model: torch.nn.Module, bits: int = 8, min_size: float = 0.01,
                 float16: bool = False, suffix: str = "_lsq", exclude=[], detect_bound=True):
        assert 0 < bits <= 15
        self.suffix = suffix
        self._optimizer_setup = False
        self.bits = bits
        for name, _ in model.named_parameters():
            if name.endswith(suffix):
                raise RuntimeError("The model already has some LSQ scale parameters, "
                                   "maybe you applied LSQ twice to the same model?")
        super().__init__(model, min_size, float16, exclude, detect_bound)

    def _register_param(self, name, param, module, other):
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, scale=other.scale, other=other)
        # Initialize the step size as in [Esser et al. 2019]:
        # s = 2 * mean(|w|) / sqrt(2 ** (bits - 1)).
        scale = 2 * param.data.abs().mean() / (2 ** (self.bits - 1))**0.5
        scale = torch.nn.Parameter(scale)
        module.register_parameter(name + self.suffix, scale)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, scale=scale, other=None)
    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        params = [qp.scale for qp in self._qparams]

        for group in optimizer.param_groups:
            new_params = []
            for q in list(group["params"]):
                matched = False
                for p in params:
                    if p is q:
                        matched = True
                if not matched:
                    new_params.append(q)
            group["params"][:] = new_params
    def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs):
        """
        Set up the optimizer to tune the scale parameters.
        Following [Esser et al. 2019], we use the same LR and weight decay
        as the base optimizer, unless specified otherwise.

        Args:
            optimizer (torch.optim.Optimizer): optimizer to use.
            kwargs (dict): overrides for optimization parameters.
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True
        params = [qp.scale for qp in self._qparams]

        for group in optimizer.param_groups:
            for q in list(group["params"]):
                for p in params:
                    if p is q:
                        raise RuntimeError("You should create the optimizer "
                                           "before the quantizer!")

        group = {"params": params}
        group.update(kwargs)
        optimizer.add_param_group(group)
    def no_optimizer(self):
        """
        Call this if you do not want to use an optimizer.
        """
        self._optimizer_setup = True
    def model_size(self, exact=False):
        """
        Differentiable estimate of the model size.
        The size is returned in MB.

        If `exact` is True, then the output is no longer differentiable but
        reflects exactly an achievable size, even without compression,
        i.e. the same as returned by `naive_model_size()`.
        """
        total = super().model_size()
        subtotal = 0
        for qparam in self._qparams:
            # only count the first appearance of a Parameter
            if qparam.other is not None:
                continue
            bits = qparam.param.numel() * self.bits
            subtotal += bits
            subtotal += 1 * 32  # param scale
        subtotal /= 2 ** 20 * 8  # bits -> MegaBytes
        return total + subtotal
    def true_model_size(self):
        """
        Naive model size without zlib compression.
        """
        return self.model_size(exact=True).item()
    def _pre_forward_train(self):
        if not self._optimizer_setup:
            raise RuntimeError("You must call `setup_optimizer()` on your optimizer "
                               "before starting training.")
        for qparam in self._qparams:
            scale = qparam.scale
            quant, _ = quantize(qparam.param, scale, self.bits)
            # We bypass the checks by PyTorch on parameters being leafs
            qparam.module._parameters[qparam.name] = quant
        return True

    def _post_forward_train(self):
        for qparam in self._qparams:
            qparam.module._parameters[qparam.name] = qparam.param
        return True
    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        _, index = quantize(qparam.param, qparam.scale, self.bits)
        assert (index <= (2 ** (self.bits - 1) - 1)).all(), index.max()
        assert (index >= (-2 ** (self.bits - 1))).all(), index.min()
        return index.detach().short(), qparam.scale.detach()

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        index, scale = quantized
        return index.float() * scale
    def _bit_pack_param(self, qparam, quantized, pack_fn):
        levels, scale = quantized
        packed = pack_fn(levels + 2 ** (self.bits - 1))
        return (packed, scale)

    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack a bit-packed representation."""
        packed_levels, scale = packed
        levels = unpack_fn(
            packed_levels, qparam.param.numel()).to(qparam.param.device).view_as(qparam.param)
        levels -= 2 ** (self.bits - 1)
        return (levels, scale)
    def detach(self):
        super().detach()
        for qparam in self._qparams:
            delattr(qparam.module, qparam.name + self.suffix)

    def __repr__(self):
        return simple_repr(self)
def roundpass(x):
    return (x.round() - x).detach() + x


def gradscale(x, scale):
    return (x - x * scale).detach() + x * scale


def quantize(tensor, scale, bits):
    low = - 2 ** (bits - 1)
    high = 2 ** (bits - 1) - 1
    scale = gradscale(scale, 1 / (tensor.numel() * high)**0.5)
    index = tensor / scale
    index = index.clamp(low, high)
    index = roundpass(index)
    return index * scale, index
Functions
def gradscale(x, scale)
Straight-through gradient rescaling: the forward value is exactly `x`, but in the backward pass the gradient flowing into `x` is multiplied by `scale`. LSQ uses it to damp the gradient of the step size by 1 / sqrt(numel * Q_P), with Q_P = 2 ** (bits - 1) - 1, as recommended in [Esser et al. 2019].
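A quick sanity check of this behaviour (hypothetical snippet, not part of the module):

import torch
from diffq.lsq import gradscale

x = torch.tensor(2.0, requires_grad=True)
y = 3.0 * gradscale(x, 0.1)
y.backward()
print(y.item(), x.grad.item())  # 6.0 0.3: value unchanged, gradient scaled by 0.1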
def quantize(tensor, scale, bits)
Fake-quantizes `tensor` with step size `scale`: values are divided by `scale`, clamped to the signed range [-2 ** (bits - 1), 2 ** (bits - 1) - 1] and rounded with a straight-through estimator. Returns the dequantized tensor together with the integer levels.
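A worked example with made-up values, at 3 bits (levels in [-4, 3]):

import torch
from diffq.lsq import quantize

t = torch.tensor([0.26, -3.0])
dequant, levels = quantize(t, torch.tensor(0.5), bits=3)
# 0.26 / 0.5 = 0.52 rounds to level 1; -3.0 / 0.5 = -6 clamps to level -4
print(levels)   # tensor([ 1., -4.])
print(dequant)  # tensor([ 0.5000, -2.0000])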
def roundpass(x)
Rounds `x` to the nearest integer in the forward pass, while letting gradients pass through unchanged in the backward pass (a straight-through estimator).
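A hypothetical check that rounding is indeed bypassed by autograd:

import torch
from diffq.lsq import roundpass

x = torch.tensor([0.3, 1.7], requires_grad=True)
roundpass(x).sum().backward()
print(x.grad)  # tensor([1., 1.]): the gradient ignores the rounding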
Classes
class LSQ (model: torch.nn.modules.module.Module, bits: int = 8, min_size: float = 0.01, float16: bool = False, suffix: str = '_lsq', exclude=[], detect_bound=True)
Implements weight-only quantization based on [Esser et al. 2019]. https://arxiv.org/abs/1902.08153
Ancestors
diffq.base.BaseQuantizer
Methods
def clear_optimizer(self, optimizer: torch.optim.optimizer.Optimizer)
Removes the scale parameters from `optimizer`'s parameter groups, undoing `setup_optimizer()`.
def model_size(self, exact=False)
Differentiable estimate of the model size, returned in MB.

If `exact` is True, then the output is no longer differentiable but reflects exactly an achievable size, even without compression, i.e. the same as returned by `naive_model_size()`.
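The estimate counts `numel * bits` for each quantized parameter plus 32 bits for its fp32 scale. For instance, a hypothetical 1000-element weight at 8 bits contributes (1000 * 8 + 32) / (2 ** 20 * 8) ≈ 0.00096 MB.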
def no_optimizer(self)
Call this if you do not want to use an optimizer.
def setup_optimizer(self, optimizer: torch.optim.optimizer.Optimizer, **kwargs)
Set up the optimizer to tune the scale parameters. Following [Esser et al. 2019], we use the same LR and weight decay as the base optimizer, unless specified otherwise.

Args
    optimizer (torch.optim.Optimizer): optimizer to use.
    kwargs (dict): overrides for optimization parameters.
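For example (a hypothetical ordering sketch; `model` is a placeholder): the optimizer must exist before the quantizer, and kwargs override the hyper-parameters of the scale group:

import torch
from diffq.lsq import LSQ

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
quantizer = LSQ(model, bits=4)
quantizer.setup_optimizer(optimizer, lr=1e-4)  # scales get their own LR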
def true_model_size(self)
Naive model size without zlib compression.
Inherited members: see `diffq.base.BaseQuantizer`.