# Source code for kats.models.nowcasting.model_io
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pickle
from typing import Any
class SimplePickleSerializer:
    """Thin wrapper around :mod:`pickle` for saving and restoring objects.

    ``serialize`` maps ``None`` to an empty bytes object, and ``deserialize``
    maps empty/``None`` input back to ``None``, so a missing model round-trips
    without raising.
    """

    def _jdefault(self, o: Any) -> Any:
        """Reduce ``o`` to a simpler representation.

        Sets become lists, booleans become ``"true"``/``"false"``, ints and
        floats become their ``str`` form, and anything else is replaced by
        its ``__dict__``.

        NOTE(review): the name suggests a ``json.dumps(default=...)`` hook,
        but nothing in the visible code calls it -- confirm before removing.

        Args:
            o: The value to convert.

        Returns:
            The converted value as described above.

        Raises:
            AttributeError: If ``o`` is none of the handled types and has no
                ``__dict__`` attribute.
        """
        if isinstance(o, set):
            return list(o)
        # bool must be tested before int: bool is a subclass of int, and
        # booleans are rendered in lowercase rather than as "True"/"False".
        if isinstance(o, bool):
            return str(o).lower()
        if isinstance(o, (int, float)):
            return str(o)
        return o.__dict__

    def serialize(self, obj: Any) -> bytes:
        """Pickle ``obj`` into a bytes object.

        Args:
            obj: The object to be saved. Usually it is an sklearn model.

        Returns:
            The pickled representation of ``obj``, or ``b""`` when ``obj`` is
            ``None``. No compression is applied.
        """
        if obj is None:
            return b""
        return pickle.dumps(obj)

    def deserialize(self, serialized_data: bytes) -> Any:
        """Restore an object from its pickled bytes.

        Args:
            serialized_data: Bytes previously produced by ``serialize``.

        Returns:
            The unpickled object (usually an sklearn model), or ``None`` when
            ``serialized_data`` is ``None`` or empty.
        """
        # Covers both None and b"", the encoding serialize() uses for None.
        if not serialized_data:
            return None
        # SECURITY: pickle.loads can execute arbitrary code while loading.
        # Only pass data that comes from a trusted source.
        return pickle.loads(serialized_data)
def serialize_for_zippy(input: Any) -> bytes:
    """Serialize a model to bytes using ``SimplePickleSerializer``.

    Args:
        input: The object to serialize. Usually it is an sklearn model.

    Returns:
        The pickled bytes of ``input``, or ``b""`` when ``input`` is ``None``.
        No compression is applied.
    """
    serializer = SimplePickleSerializer()
    return serializer.serialize(input)
def deserialize_from_zippy(input: bytes, use_case_id=None) -> Any:
    """Deserialize a model from bytes using ``SimplePickleSerializer``.

    The data is unpickled, so ``input`` must come from a trusted source:
    unpickling can execute arbitrary code.

    Args:
        input: Pickled bytes of a model, as produced by
            ``serialize_for_zippy``.
        use_case_id: Not used by this function; kept in the signature so
            existing callers that pass it continue to work.

    Returns:
        The restored object (usually an sklearn model), or ``None`` when
        ``input`` is ``None`` or empty.
    """
    serializer = SimplePickleSerializer()
    return serializer.deserialize(input)