Add files via upload

diffq/ts_export.py  (Normal file, 209 lines)

@@ -0,0 +1,209 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""TorchScript export support.
We have to do a lot of black magic for TorchScript to be happy
because we cannot dynamically allocate new weights when loading the model.

Here is how it works:
- we generate code in a temporary python file for the given model that explicitly
  overrides all the weights on the first forward from their packed version.
  This is because TorchScript does not let us iterate over parameters in a generic manner.
- we zero out all the original weights. We cannot simply remove those weights
  because TorchScript won't let us recreate them.
- A TorchScript file is just a zip file, but stored without compression.
  In order to remove the cost of storing the zeroed out weights, we unzip the file,
  and zip it again with compression.
"""
import importlib
import os
from pathlib import Path
import random
import sys
import typing as tp
import tempfile
import zipfile

import torch
from torch import jit

from .diffq import DiffQuantizer
from .uniform import uniform_unquantize
from .torch_pack import unpack

_DiffQPacked = tp.Tuple[
    tp.List[tp.Optional[torch.Tensor]], tp.Tuple[float, float],
    torch.Tensor, tp.List[int]]
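# Each _DiffQPacked entry, as consumed by _unpack_param below, holds:
#   - a list with one packed levels tensor per candidate bit-width (entry
#     `idx` corresponds to `idx + 1` bits, None when no group uses it),
#   - the pair of floats describing the quantization range, forwarded
#     untouched to uniform_unquantize,
#   - the packed tensor of per-group bit-widths,
#   - the original shape of the parameter, as a list of ints.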

# This is the template for the generated class.
TEMPLATE = '''
import typing as tp
import torch
from torch import jit

from diffq.ts_export import _unpack_param, _DiffQPacked

from {module} import {klass}


class DiffQTSModel(torch.nn.Module):
    def __init__(self, model: {klass}, group_size: int, min_bits: int,
                 packed: tp.List[_DiffQPacked]):
        super().__init__()
        self.group_size = group_size
        self.min_bits = min_bits
        self.model = model
        self._unpacked = False
        self._packed = packed

    @jit.export
    def unpack(self):
        """
        Unpack the weights, automatically called on the first forward,
        or explicitly."""
        if self._unpacked:
            return
{unpack_assigns}
        self._unpacked = True

    def forward(self, x: torch.Tensor):
        self.unpack()
        return self.model.forward(x)
'''

# Those are the assignments for each quantized weight.
UNPACK_ASSIGN = (' ' * 8) + ('self.model{full_name}.data[:] = '
                             '_unpack_param(self._packed[{index}], '
                             'group_size=self.group_size, min_bits=self.min_bits)')
UNPACK_ASSIGN_SAME = (' ' * 8) + 'self.model{full_name} = self.model{other_name}'
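# For instance, with full_name=".encoder[0].weight" and index=0 (an
# illustrative path, not taken from a real model), UNPACK_ASSIGN expands to
# the single generated line:
#     self.model.encoder[0].weight.data[:] = _unpack_param(self._packed[0], group_size=self.group_size, min_bits=self.min_bits)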


def export(quantizer: DiffQuantizer, path: tp.Union[str, Path]):
    """Export the given quantized model to the given path.
    We must save the quantized model ourselves, as we need to recompress
    the zip archive afterwards.
    """
    packed: tp.List[_DiffQPacked] = []
    uniq_name = ''.join([random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(12)])
    with tempfile.TemporaryDirectory() as tmpdir:
        sys.path.insert(0, tmpdir)
        try:
            code = _codegen(quantizer)
            with open(Path(tmpdir) / f'{uniq_name}.py', 'w') as f:
                f.write(code)
            module = importlib.import_module(uniq_name)
            ts_klass = module.DiffQTSModel
            state = quantizer.get_quantized_state(packed=True, torch_pack=True)
            quantized = state["quantized"]
            for qparam in quantizer._qparams:
                if qparam.other is None:
                    levels, scales, bits = quantized.pop(0)
                    size = qparam.param.size()
                    packed.append((levels, scales, bits, list(size)))
                    qparam.param.data.zero_()
            quantizer.detach()
            ts_premodel = ts_klass(quantizer.model, quantizer.group_size,
                                   quantizer.min_bits, packed)
            ts_model = jit.script(ts_premodel)
            if path is not None:
                jit.save(ts_model, path)
                recompress(path)
        finally:
            sys.path.pop(0)

    return ts_model


def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor:
    """Function called from TorchScript on the first forward to decode the
    packed weights to FP32.
    """
    packed_all_levels, scales, packed_bits, shape = packed
    numel = 1
    for dim in shape:
        numel *= dim
    bits = unpack(packed_bits, numel // group_size) + min_bits
    levels = torch.empty(bits.numel(), group_size, dtype=torch.short)
    for idx, packed_levels in enumerate(packed_all_levels):
        bit = idx + 1
        if packed_levels is not None:
            sub_levels = levels[bits == bit]
            levels[bits == bit] = unpack(packed_levels, sub_levels.numel()).view_as(sub_levels)
    bits = bits[:, None]
    unquant = uniform_unquantize(levels, scales, bits)
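    # Expand the flat (groups, group_size) tensor back to the parameter's
    # original shape, one branch per supported rank.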
    if len(shape) == 4:
        return unquant.view(shape[0], shape[1], shape[2], shape[3])
    elif len(shape) == 3:
        return unquant.view(shape[0], shape[1], shape[2])
    elif len(shape) == 2:
        return unquant.view(shape[0], shape[1])
    elif len(shape) == 1:
        return unquant.view(shape[0])
    else:
        raise RuntimeError("Invalid number of dims.")


def recompress(path: tp.Union[str, Path]):
    """After having saved the TorchScript file, this will recompress it
    to make sure all the zeroed out parameters don't actually take any space.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(path) as zipin:
            zipin.extractall(tmpdir)
        with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED,
                             compresslevel=1) as zipout:
            for root, folders, files in os.walk(tmpdir):
                for file in files:
                    fp = Path(root) / file
                    name = fp.relative_to(tmpdir)
                    zipout.write(fp, name)


def _get_full_name_access(full_name):
    # When generating code, we need to handle attributes vs. indexing.
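    # e.g. "encoder.0.conv.weight" becomes ".encoder[0].conv.weight", so that
    # prefixing it with "self.model" yields valid generated code.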
    parts = []
    for part in full_name.split("."):
        try:
            index = int(part)
        except ValueError:
            parts.append("." + part)
        else:
            parts.append(f"[{index}]")
    return "".join(parts)


def _codegen(quantizer: DiffQuantizer):
    # Generates the code for the given quantizer
    module = quantizer.model.__class__.__module__
    klass = quantizer.model.__class__.__name__
    model = quantizer.model

    assert not quantizer.float16
    names = {}
    for mod_name, mod in model.named_modules():
        names[mod] = mod_name
    unpack_assigns = []

    index = 0
    for qparam in quantizer._qparams:
        mod_name = names[qparam.module]
        if mod_name == '':
            full_name = qparam.name
        else:
            full_name = mod_name + '.' + qparam.name
        full_name = _get_full_name_access(full_name)
        if qparam.other is None:
            unpack_assigns.append(UNPACK_ASSIGN.format(full_name=full_name, index=index))
            index += 1
        else:
            # `names` maps module objects to their names, so resolve the tied
            # parameter's module first, then append the parameter name,
            # mirroring the `full_name` construction above.
            other_name = names[qparam.other.module]
            if other_name == '':
                other_name = qparam.other.name
            else:
                other_name = other_name + '.' + qparam.other.name
            other_name = _get_full_name_access(other_name)
            unpack_assigns.append(
                UNPACK_ASSIGN_SAME.format(full_name=full_name, other_name=other_name))

    return TEMPLATE.format(
        module=module,
        klass=klass,
        unpack_assigns='\n'.join(unpack_assigns))