Source code for kats.models.model
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Generic, Optional, TypeVar
import pandas as pd
from kats.consts import TimeSeriesData
from matplotlib import pyplot as plt
# Type variable for a model's parameter object; the ``Model`` base class is
# generic over it and receives a value of this type as ``params``.
ParamsType = TypeVar("ParamsType")
class Model(Generic[ParamsType]):
    """Base forecasting model.

    This is the parent class for all forecasting models in Kats.

    Attributes:
        data: `TimeSeriesData` object the model operates on, or None.
        params: model parameters.
    """

    # Every attribute assigned in ``__init__`` must be listed here: ``Generic``
    # provides no instance ``__dict__``, so an unlisted attribute would make
    # direct instantiation of ``Model`` fail with AttributeError. (Subclasses
    # that do not declare ``__slots__`` still get a ``__dict__`` as before.)
    __slots__ = ["data", "params", "__type__"]

    def __init__(
        self,
        data: Optional[TimeSeriesData],
        params: ParamsType,
        validate_frequency: bool = False,
        validate_dimension: bool = False,
    ) -> None:
        """Store the data and parameters, validating the data if present.

        Args:
            data: the time series to model, or None.
            params: model parameters.
            validate_frequency: passed to ``data.validate_data`` to request
                a frequency check.
            validate_dimension: passed to ``data.validate_data`` to request
                a dimension check.
        """
        self.data = data
        self.params = params
        self.__type__ = "model"
        if data is not None:
            data.validate_data(validate_frequency, validate_dimension)

    def setup_data(self):
        """Abstract method to set up the dataset.

        Declaration only; concrete models override it. The base version does
        nothing and returns None.
        """
        pass

    def fit(self):
        """Abstract method to fit the model.

        Declaration only; concrete models override it. The base version does
        nothing and returns None.
        """
        pass

    def predict(self, *_args, **_kwargs):
        """Abstract method to produce predictions.

        Declaration only; concrete models override it. Any arguments are
        accepted and ignored, and the base version returns None.
        """
        pass

    @staticmethod
    def plot(
        data: TimeSeriesData,
        fcst: pd.DataFrame,
        include_history=False,
    ) -> None:
        """Plot the historical data together with a forecast.

        This method provides the plotting functionality for all forecasting
        models. History is drawn in black. The forecast is drawn in blue,
        with a shaded band when ``fcst`` has ``fcst_lower`` and ``fcst_upper``
        columns.

        Args:
            data: `TimeSeriesData`, the historical time series data set.
            fcst: forecasted results from forecasting models; must have a
                ``fcst`` column (and a ``time`` column when
                ``include_history`` is True).
            include_history: if True, ``fcst`` already covers the history and
                its own ``time`` column is used for the x axis. Otherwise the
                forecast timestamps are generated after the last timestamp of
                ``data`` using the frequency inferred from ``data.time``.
        """
        logging.info("Generating chart for forecast result.")
        fig = plt.figure(facecolor="w", figsize=(10, 6))
        ax = fig.add_subplot(111)
        ax.plot(pd.to_datetime(data.time), data.value, "k")

        if include_history:
            fcst_dates = fcst.time
        else:
            last_date = data.time.max()
            steps = fcst.shape[0]
            freq = pd.infer_freq(data.time)
            dates = pd.date_range(start=last_date, periods=steps + 1, freq=freq)
            # ``date_range`` only starts exactly at ``last_date`` when it lies
            # on the frequency's anchor; otherwise it rolls forward. Drop
            # ``last_date`` if present and keep exactly ``steps`` timestamps so
            # the x values always match the number of forecast rows.
            fcst_dates = dates[dates != last_date][:steps].to_pydatetime()

        ax.plot(fcst_dates, fcst.fcst, ls="-", c="#4267B2")
        if "fcst_lower" in fcst.columns and "fcst_upper" in fcst.columns:
            ax.fill_between(
                fcst_dates,
                fcst.fcst_lower,
                fcst.fcst_upper,
                color="#4267B2",
                alpha=0.2,
            )
        ax.grid(True, which="major", c="gray", ls="-", lw=1, alpha=0.2)
        ax.set_xlabel(xlabel="time")
        ax.set_ylabel(ylabel="y")
        fig.tight_layout()

    @staticmethod
    def get_parameter_search_space():
        """Query the default parameter search space.

        Abstract method to be implemented by downstream forecasting models;
        the base version returns None.
        """
        pass