Module diffq.ts_export

TorchScript export support. We have to do a lot of black magic for TorchScript to be happy because we cannot dynamically allocate new weights when loading the model.

Here is how it works:
- we generate code in a temporary Python file for the given model that explicitly overrides all the weights from their packed version on the first forward. This is because TorchScript does not let us iterate over parameters in a generic manner.
- we zero out all the original weights. We cannot simply remove those weights because TorchScript won't let us recreate them.
- a TorchScript file is just a zip file, but stored without compression. To remove the cost of storing the zeroed-out weights, we unzip the file and zip it again with compression.

Expand source code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""TorchScript export support.
We have to do a lot of black magic for TorchScript to be happy
because we cannot dynamically allocate new weights when loading the model.

Here is how it works:
- we generate code in a temporary python file for the given model that explicitly
    overrides all the weights on the first forward from their packed version.
    This is because TorchScript does not let us iterate over parameters in a generic manner.
- we zero out all the original weights. We cannot simply remove those weights
    because TorchScript won't let us recreate them.
- A TorchScript file is just a zip file, but stored without compression.
    In order to remove the cost of storing the zeroed out weights, we unzip the file,
    and zip it again with compression.
"""
import importlib
import os
from pathlib import Path
import random
import sys
import typing as tp
import tempfile
import zipfile

import torch
from torch import jit

from .diffq import DiffQuantizer
from .uniform import uniform_unquantize
from .torch_pack import unpack

# Packed representation of one quantized parameter, built by `export` and
# decoded by `_unpack_param`, as a 4-tuple:
#   - packed levels, one entry per bit width: entry `i` holds the groups that
#     use `i + 1` bits (None entries are skipped when decoding),
#   - the scales, passed unchanged to `uniform_unquantize`,
#   - the packed per-group bit widths, stored minus `min_bits`,
#   - the original shape of the parameter.
# The generated TorchScript code imports this alias, so it must stay a type
# that TorchScript can resolve.
_DiffQPacked = tp.Tuple[
    tp.List[tp.Optional[torch.Tensor]], tp.Tuple[float, float],
    torch.Tensor, tp.List[int]]

# This is the template for the generated class.
# `_codegen` fills it with `str.format`:
#   {module}, {klass}: module path and class name of the exported model,
#   {unpack_assigns}: one statement per quantized parameter (built from
#       UNPACK_ASSIGN / UNPACK_ASSIGN_SAME), already indented by 8 spaces so
#       that they land inside `unpack`.
# Because the template goes through `str.format`, any literal brace added to
# it would have to be doubled.
# The wrapper keeps the packed weights as a module attribute and decodes them
# lazily: `forward` calls `unpack`, which only does work the first time.
TEMPLATE = '''
import typing as tp
import torch
from torch import jit

from diffq.ts_export import _unpack_param, _DiffQPacked

from {module} import {klass}


class DiffQTSModel(torch.nn.Module):
    def __init__(self, model: {klass}, group_size: int, min_bits: int,
                 packed: tp.List[_DiffQPacked]):
        super().__init__()
        self.group_size = group_size
        self.min_bits = min_bits
        self.model = model
        self._unpacked = False
        self._packed = packed

    @jit.export
    def unpack(self):
        """
        Unpack the weights, automatically called on the first forward,
        or explicitely."""
        if self._unpacked:
            return
{unpack_assigns}
        self._unpacked = True

    def forward(self, x: torch.Tensor):
        self.unpack()
        return self.model.forward(x)
'''

# those are the assignments for each quantized weight.
# `{full_name}` is an access path as returned by `_get_full_name_access`
# (e.g. `.layers[0].weight`), relative to `self.model`.
# UNPACK_ASSIGN decodes entry `{index}` of `self._packed` and copies it into
# the existing parameter tensor (slice assignment on `.data`, so no new
# parameter is created, which TorchScript would not allow).
# UNPACK_ASSIGN_SAME is used for a parameter whose `qparam.other` is set: it
# re-binds it to the other parameter instead of decoding anything.
UNPACK_ASSIGN = (' ' * 8) + ('self.model{full_name}.data[:] = '
                             '_unpack_param(self._packed[{index}], '
                             'group_size=self.group_size, min_bits=self.min_bits)')
UNPACK_ASSIGN_SAME = (' ' * 8) + 'self.model{full_name} = self.model{other_name}'


def export(quantizer: DiffQuantizer, path: tp.Optional[tp.Union[str, Path]]):
    """Export the given quantized model to TorchScript.

    We must save the quantized model ourselves, as we need to recompress
    the zip archive afterwards (see `recompress`).

    Args:
        quantizer: the quantizer wrapping the model to export.
        path: where to write the TorchScript archive. If None, nothing is
            written and the scripted model is only returned.

    Returns:
        The scripted `DiffQTSModel` wrapping `quantizer.model`.

    .. warning:: This will completely destroy the model and the quantizer
        (weights are zeroed out and the quantizer is detached), so you probably
        want to call this only once at the end of training.
    """
    packed: tp.List[_DiffQPacked] = []
    # The generated module needs a fresh importable name for every export.
    uniq_name = ''.join([random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(12)])
    with tempfile.TemporaryDirectory() as tmpdir:
        sys.path.insert(0, tmpdir)
        try:
            code = _codegen(quantizer)
            with open(Path(tmpdir) / f'{uniq_name}.py', 'w') as f:
                f.write(code)
            # The source file was created at runtime; make sure the import
            # system does not rely on stale finder caches.
            importlib.invalidate_caches()
            module = importlib.import_module(uniq_name)
            ts_klass = module.DiffQTSModel
            state = quantizer.get_quantized_state(packed=True, torch_pack=True)
            # Entries are consumed in order, one per parameter owning its data,
            # matching the `self._packed[index]` indices written by `_codegen`.
            quantized = iter(state["quantized"])
            for qparam in quantizer._qparams:
                if qparam.other is None:
                    levels, scales, bits = next(quantized)
                    packed.append((levels, scales, bits, list(qparam.param.size())))
                    # Zeroed weights become nearly free after `recompress`; the
                    # real values are restored by `unpack` on the first forward.
                    qparam.param.data.zero_()
            quantizer.detach()
            ts_premodel = ts_klass(quantizer.model, quantizer.group_size,
                                   quantizer.min_bits, packed)
            ts_model = jit.script(ts_premodel)
            if path is not None:
                jit.save(ts_model, path)
                recompress(path)
        finally:
            # Remove our entry by value rather than by position: importing the
            # model's module from the generated code may have changed sys.path.
            if tmpdir in sys.path:
                sys.path.remove(tmpdir)
            # The generated module's file is deleted along with tmpdir, so do
            # not leave a dangling entry in the module cache.
            sys.modules.pop(uniq_name, None)

    return ts_model


def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor:
    """Function called from TorchScript on the first forward to decode the
    packed weights to FP32.

    This function is compiled by TorchScript (it is called from the generated
    `DiffQTSModel.unpack`), so it must stay within the TorchScript subset.

    Args:
        packed: `(packed_all_levels, scales, packed_bits, shape)` as built by
            `export`. `packed_all_levels[i]` holds the packed levels of the
            groups using `i + 1` bits (None entries are skipped),
            `packed_bits` the packed per-group bit widths minus `min_bits`,
            and `shape` the original shape of the parameter.
        group_size: number of weights sharing a single bit width.
        min_bits: offset added back to the unpacked bit widths.

    Returns:
        The dequantized tensor, viewed with the original `shape`.
    """
    packed_all_levels, scales, packed_bits, shape = packed
    numel = 1
    for dim in shape:
        numel *= dim
    # One bit width per group of `group_size` weights.
    bits = unpack(packed_bits, numel // group_size) + min_bits
    levels = torch.empty(bits.numel(), group_size, dtype=torch.short)
    for idx, packed_levels in enumerate(packed_all_levels):
        bit = idx + 1
        if packed_levels is not None:
            # All groups quantized with `bit` bits were packed together.
            mask = bits == bit
            sub_levels = levels[mask]
            levels[mask] = unpack(packed_levels, sub_levels.numel()).view_as(sub_levels)
    bits = bits[:, None]  # broadcast each group's bit width over its values
    unquant = uniform_unquantize(levels, scales, bits)
    # `Tensor.view` accepts a List[int] in TorchScript, so any rank works.
    return unquant.view(shape)


def recompress(path: tp.Union[str, Path]):
    """After having saved the torchscript file, this will recompress it
    to make sure all the zeroed out parameters don't actually take any space.

    The archive is rewritten with DEFLATE (level 1), keeping the original
    entry names and their order. The new archive is written to a temporary
    file next to `path` and then moved over it with `os.replace`, so a failure
    midway leaves the original file untouched.

    Args:
        path: path to an existing zip archive, e.g. one written by `jit.save`.
    """
    path = Path(path)
    # Create the output in the destination directory so that `os.replace` is
    # a rename within a single filesystem.
    fd, tmp_out = tempfile.mkstemp(prefix=path.name + '.', suffix='.tmp', dir=path.parent)
    os.close(fd)
    try:
        # mkstemp creates the file with mode 0o600; keep the original permissions.
        os.chmod(tmp_out, os.stat(path).st_mode & 0o7777)
        with tempfile.TemporaryDirectory() as tmpdir:
            with zipfile.ZipFile(path) as zipin, \
                    zipfile.ZipFile(tmp_out, "w", compression=zipfile.ZIP_DEFLATED,
                                    compresslevel=1) as zipout:
                for info in zipin.infolist():
                    if info.is_dir():
                        continue
                    # Go through disk so large entries are streamed instead of
                    # loaded in memory, and `write` picks ZIP64 when needed.
                    extracted = zipin.extract(info, tmpdir)
                    zipout.write(extracted, info.filename)
                    os.remove(extracted)
        os.replace(tmp_out, path)
    except BaseException:
        if os.path.exists(tmp_out):
            os.remove(tmp_out)
        raise


def _get_full_name_access(full_name):
    # When generating code, we need to handle attributes vs. indexing.
    parts = []
    for part in full_name.split("."):
        try:
            index = int(part)
        except ValueError:
            parts.append("." + part)
        else:
            parts.append(f"[{index}]")
    return "".join(parts)


def _codegen(quantizer: DiffQuantizer):
    """Generate the source code of the TorchScript wrapper for `quantizer`.

    The result is `TEMPLATE` filled with the module and class of
    `quantizer.model`, plus one statement per quantized parameter. A parameter
    that owns its data (`qparam.other is None`) is decoded from
    `self._packed[index]`, with indices following the order of
    `quantizer._qparams`. A parameter with `qparam.other` set is re-bound to
    that other parameter.

    Args:
        quantizer: the quantizer whose model is exported. float16 mode is not
            supported.

    Returns:
        The Python source of a module defining `DiffQTSModel`.
    """
    model = quantizer.model
    module = model.__class__.__module__
    klass = model.__class__.__name__

    assert not quantizer.float16, "TorchScript export does not support float16."
    # Map each submodule object to its dotted name inside `model` ('' for the root).
    names = {mod: mod_name for mod_name, mod in model.named_modules()}

    def access(mod, param_name):
        # Access path (e.g. `.layers[0].weight`) of a parameter relative to the model.
        mod_name = names[mod]
        full_name = param_name if mod_name == '' else mod_name + '.' + param_name
        return _get_full_name_access(full_name)

    unpack_assigns = []
    index = 0
    for qparam in quantizer._qparams:
        full_name = access(qparam.module, qparam.name)
        if qparam.other is None:
            unpack_assigns.append(UNPACK_ASSIGN.format(full_name=full_name, index=index))
            index += 1
        else:
            # `names` is keyed by module objects, so the shared parameter's path
            # is rebuilt from its own module and attribute name.
            other_name = access(qparam.other.module, qparam.other.name)
            unpack_assigns.append(
                UNPACK_ASSIGN_SAME.format(full_name=full_name, other_name=other_name))

    return TEMPLATE.format(
        module=module,
        klass=klass,
        unpack_assigns='\n'.join(unpack_assigns))

Functions

def export(quantizer: DiffQuantizer, path: Union[str, pathlib.Path])

Export the given quantized model to the given path. We must save the quantized model ourselves, as we need to recompress the zip archive afterwards.

Warning: This will completely destroy the model and the quantizer, so you probably want to call this only once at the end of training.

Expand source code
def export(quantizer: DiffQuantizer, path: tp.Optional[tp.Union[str, Path]]):
    """Export the given quantized model to TorchScript.

    We must save the quantized model ourselves, as we need to recompress
    the zip archive afterwards (see `recompress`).

    Args:
        quantizer: the quantizer wrapping the model to export.
        path: where to write the TorchScript archive. If None, nothing is
            written and the scripted model is only returned.

    Returns:
        The scripted `DiffQTSModel` wrapping `quantizer.model`.

    .. warning:: This will completely destroy the model and the quantizer
        (weights are zeroed out and the quantizer is detached), so you probably
        want to call this only once at the end of training.
    """
    packed: tp.List[_DiffQPacked] = []
    # The generated module needs a fresh importable name for every export.
    uniq_name = ''.join([random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(12)])
    with tempfile.TemporaryDirectory() as tmpdir:
        sys.path.insert(0, tmpdir)
        try:
            code = _codegen(quantizer)
            with open(Path(tmpdir) / f'{uniq_name}.py', 'w') as f:
                f.write(code)
            # The source file was created at runtime; make sure the import
            # system does not rely on stale finder caches.
            importlib.invalidate_caches()
            module = importlib.import_module(uniq_name)
            ts_klass = module.DiffQTSModel
            state = quantizer.get_quantized_state(packed=True, torch_pack=True)
            # Entries are consumed in order, one per parameter owning its data,
            # matching the `self._packed[index]` indices written by `_codegen`.
            quantized = iter(state["quantized"])
            for qparam in quantizer._qparams:
                if qparam.other is None:
                    levels, scales, bits = next(quantized)
                    packed.append((levels, scales, bits, list(qparam.param.size())))
                    # Zeroed weights become nearly free after `recompress`; the
                    # real values are restored by `unpack` on the first forward.
                    qparam.param.data.zero_()
            quantizer.detach()
            ts_premodel = ts_klass(quantizer.model, quantizer.group_size,
                                   quantizer.min_bits, packed)
            ts_model = jit.script(ts_premodel)
            if path is not None:
                jit.save(ts_model, path)
                recompress(path)
        finally:
            # Remove our entry by value rather than by position: importing the
            # model's module from the generated code may have changed sys.path.
            if tmpdir in sys.path:
                sys.path.remove(tmpdir)
            # The generated module's file is deleted along with tmpdir, so do
            # not leave a dangling entry in the module cache.
            sys.modules.pop(uniq_name, None)

    return ts_model
def recompress(path: Union[str, pathlib.Path])

After having saved the torchscript file, this will recompress it to make sure all the zeroed out parameters don't actually take any space.

Expand source code
def recompress(path: tp.Union[str, Path]):
    """After having saved the torchscript file, this will recompress it
    to make sure all the zeroed out parameters don't actually take any space.

    The archive is rewritten with DEFLATE (level 1), keeping the original
    entry names and their order. The new archive is written to a temporary
    file next to `path` and then moved over it with `os.replace`, so a failure
    midway leaves the original file untouched.

    Args:
        path: path to an existing zip archive, e.g. one written by `jit.save`.
    """
    path = Path(path)
    # Create the output in the destination directory so that `os.replace` is
    # a rename within a single filesystem.
    fd, tmp_out = tempfile.mkstemp(prefix=path.name + '.', suffix='.tmp', dir=path.parent)
    os.close(fd)
    try:
        # mkstemp creates the file with mode 0o600; keep the original permissions.
        os.chmod(tmp_out, os.stat(path).st_mode & 0o7777)
        with tempfile.TemporaryDirectory() as tmpdir:
            with zipfile.ZipFile(path) as zipin, \
                    zipfile.ZipFile(tmp_out, "w", compression=zipfile.ZIP_DEFLATED,
                                    compresslevel=1) as zipout:
                for info in zipin.infolist():
                    if info.is_dir():
                        continue
                    # Go through disk so large entries are streamed instead of
                    # loaded in memory, and `write` picks ZIP64 when needed.
                    extracted = zipin.extract(info, tmpdir)
                    zipout.write(extracted, info.filename)
                    os.remove(extracted)
        os.replace(tmp_out, path)
    except BaseException:
        if os.path.exists(tmp_out):
            os.remove(tmp_out)
        raise