#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""The Holt-Winters model is a time series forecasting model that applies exponential smoothing three times (to the level, trend, and seasonal components); it extends the simple exponential smoothing forecast model.
More details about the different exponential smoothing models can be found here:
https://en.wikipedia.org/wiki/Exponential_smoothing
In this module we adopt the Holt-Winters implementation from the statsmodels package; full details can be found here:
https://www.statsmodels.org/dev/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.html
We wrap the corresponding API to fit the Kats development style.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import Any, Dict, List, Optional

import pandas as pd
from statsmodels.tsa.holtwinters import ExponentialSmoothing as HoltWinters

import kats.models.model as m
from kats.consts import Params, TimeSeriesData
from kats.utils.emp_confidence_int import EmpConfidenceInt
from kats.utils.parameter_tuning_utils import (
    get_default_holtwinters_parameter_search_space,
)
class HoltWintersParams(Params):
    """Parameter class for the HoltWinters model.

    Not all parameters from the statsmodels API have been implemented here; the
    full list of parameters can be found at:
    https://www.statsmodels.org/dev/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.html

    Attributes:
        trend: Optional; a string that specifies the type of trend component.
            Can be 'add' and 'mul' or 'additive' and 'multiplicative', or None
            for no trend component. Default is 'add'.
        damped: Optional; a boolean indicating whether the trend should be
            damped or not. Default is False.
        seasonal: Optional; a string that specifies the type of seasonal
            component. Can be 'add' and 'mul' or 'additive' and
            'multiplicative', or None for no seasonal component. Default is None.
        seasonal_periods: Optional; an integer that specifies the period for
            the seasonal component, e.g. 4 for quarterly data and 7 for weekly
            seasonality of daily data. Default is None.
    """

    __slots__ = ["trend", "damped", "seasonal", "seasonal_periods"]

    def __init__(
        self,
        trend: Optional[str] = "add",
        damped: bool = False,
        seasonal: Optional[str] = None,
        seasonal_periods: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.trend = trend
        self.damped = damped
        self.seasonal = seasonal
        self.seasonal_periods = seasonal_periods
        # Raises ValueError before anything is logged as "initialized".
        self.validate_params()
        logging.debug(
            "Initialized HoltWintersParams with parameters. "
            "trend:%s, damped:%s, seasonal:%s, seasonal_periods:%s",
            trend,
            damped,
            seasonal,
            seasonal_periods,
        )

    def validate_params(self) -> None:
        """Validate the values of the trend and seasonal component parameters.

        Args:
            None

        Returns:
            None

        Raises:
            ValueError: If ``trend`` or ``seasonal`` is not one of 'add', 'mul',
                'additive', 'multiplicative' or None.
        """
        valid_components = ("add", "mul", "additive", "multiplicative", None)
        # trend is checked before seasonal, so its error is reported first.
        for name, value in (("trend", self.trend), ("seasonal", self.seasonal)):
            if value not in valid_components:
                msg = f"{name} parameter is not valid! use 'add' or 'mul' instead!"
                logging.error(msg)
                raise ValueError(msg)
class HoltWintersModel(m.Model):
    """Model class for the HoltWinters model.

    Attributes:
        data: :class:`kats.consts.TimeSeriesData`, the input historical time
            series data. Only univariate data (``data.value`` is a
            ``pd.Series``) is accepted.
        params: The HoltWinters model parameters from HoltWintersParams.
    """

    def __init__(self, data: TimeSeriesData, params: HoltWintersParams) -> None:
        super().__init__(data, params)
        # Anything other than a single pd.Series of values is rejected.
        if not isinstance(self.data.value, pd.Series):
            msg = "Only support univariate time series, but get {type}.".format(
                type=type(self.data.value)
            )
            logging.error(msg)
            raise ValueError(msg)

    def fit(self, **kwargs) -> None:
        """Fit the model with the specified input parameters.

        Builds a statsmodels ``ExponentialSmoothing`` model from ``self.params``
        and stores the fitted results in ``self.model``.

        Args:
            kwargs: Accepted for API compatibility; only logged, not used.
        """
        logging.debug("Call fit() with parameters:{kwargs}".format(kwargs=kwargs))
        holtwinters = HoltWinters(
            self.data.value,
            trend=self.params.trend,
            damped=self.params.damped,
            seasonal=self.params.seasonal,
            seasonal_periods=self.params.seasonal_periods,
        )
        # pyre-fixme[16]: `HoltWintersModel` has no attribute `model`.
        self.model = holtwinters.fit()
        logging.info("Fitted HoltWinters.")

    # pyre-fixme[14]: `predict` overrides method defined in `Model` inconsistently.
    def predict(
        self, steps: int, include_history: bool = False, **kwargs
    ) -> pd.DataFrame:
        """Predict with the fitted HoltWinters model.

        If the ``alpha`` keyword argument is specified, an empirical confidence
        interval is computed through a K-fold cross validation and a linear
        regression model, and the forecast includes that interval; otherwise
        no confidence interval is included in the final forecast. Please refer
        to the 'emp_confidence_int' module for the full implementation of the
        empirical confidence interval computation.

        Args:
            steps: Number of future periods to forecast.
            include_history: If True, the in-sample predictions for the
                historical time stamps are prepended to the forecast.
            kwargs: Optional settings:
                freq: Frequency used to generate the future time stamps;
                    inferred from the history with ``pd.infer_freq`` when
                    not given.
                alpha: Triggers the empirical confidence interval with
                    confidence level ``1 - alpha``.
                error_methods, train_percentage, test_percentage,
                sliding_steps: Forwarded to ``EmpConfidenceInt`` when
                    ``alpha`` is given (defaults: ["mape"], 70, 10,
                    len(data) // 5).
                multi: Ignored; ``EmpConfidenceInt`` is always constructed
                    with ``multi=False`` here.

        Returns:
            A pd.DataFrame with the forecast: ``time`` and ``fcst`` columns
            without ``alpha``, or the DataFrame returned by
            ``EmpConfidenceInt.get_eci`` with ``alpha``. With
            ``include_history`` the historical ``time``/``fcst`` rows come
            first.
        """
        logging.debug(
            "Call predict() with parameters. "
            "steps:{steps}, kwargs:{kwargs}".format(steps=steps, kwargs=kwargs)
        )
        if "freq" not in kwargs:
            # pyre-fixme[16]: `HoltWintersModel` has no attribute `freq`.
            # pyre-fixme[16]: `HoltWintersModel` has no attribute `data`.
            self.freq = pd.infer_freq(self.data.time)
        else:
            self.freq = kwargs["freq"]
        if self.freq is None:
            # pd.date_range falls back to its daily default when freq is None,
            # which silently produces wrong time stamps for non-daily data.
            logging.warning(
                "Could not determine the frequency of the time series; "
                "pass `freq` explicitly to control the forecast time stamps."
            )
        last_date = self.data.time.max()
        dates = pd.date_range(start=last_date, periods=steps + 1, freq=self.freq)
        # pyre-fixme[16]: `HoltWintersModel` has no attribute `dates`.
        self.dates = dates[dates != last_date]  # Return correct number of periods
        # pyre-fixme[16]: `HoltWintersModel` has no attribute `include_history`.
        self.include_history = include_history

        if "alpha" in kwargs:
            # pyre-fixme[16]: `HoltWintersModel` has no attribute `alpha`.
            self.alpha = kwargs["alpha"]
            # build empirical CI
            error_methods = kwargs.get("error_methods", ["mape"])
            train_percentage = kwargs.get("train_percentage", 70)
            test_percentage = kwargs.get("test_percentage", 10)
            sliding_steps = kwargs.get("sliding_steps", len(self.data) // 5)
            # EmpConfidenceInt is always run with multi=False; a caller's
            # `multi` kwarg is ignored, so log the value actually used.
            multi = False
            eci = EmpConfidenceInt(
                error_methods=error_methods,
                data=self.data,
                params=self.params,
                train_percentage=train_percentage,
                test_percentage=test_percentage,
                sliding_steps=sliding_steps,
                model_class=HoltWintersModel,
                confidence_level=1 - self.alpha,
                multi=multi,
            )
            logging.debug(
                "Use EmpConfidenceInt for CI with parameters: "
                f"error_methods = {error_methods}, "
                f"train_percentage = {train_percentage}, "
                f"test_percentage = {test_percentage}, "
                f"sliding_steps = {sliding_steps}, "
                f"confidence_level = {1 - self.alpha}, multi={multi}."
            )
            fcst = eci.get_eci(steps=steps)
            # pyre-fixme[16]: `HoltWintersModel` has no attribute `y_fcst`.
            self.y_fcst = fcst["fcst"]
        else:
            # pyre-fixme[16]: `HoltWintersModel` has no attribute `model`.
            fcst = self.model.forecast(steps)
            self.y_fcst = fcst
            fcst = pd.DataFrame({"time": self.dates, "fcst": fcst})
        logging.info("Generated forecast data from Holt-Winters model.")
        logging.debug("Forecast data: {fcst}".format(fcst=fcst))

        if include_history:
            # statsmodels' `end` is inclusive: the last historical observation
            # is len - 1. Using len would append an extra one-step-ahead value
            # that has no matching historical time stamp.
            history_fcst = self.model.predict(
                start=0, end=len(self.data.time) - 1
            )
            # pyre-fixme[16]: `HoltWintersModel` has no attribute `fcst_df`.
            self.fcst_df = pd.concat(
                [
                    pd.DataFrame(
                        {
                            "time": self.data.time,
                            "fcst": history_fcst,
                        }
                    ),
                    fcst,
                ]
            )
        else:
            self.fcst_df = fcst
        logging.debug("Return forecast data: {fcst_df}".format(fcst_df=self.fcst_df))
        return self.fcst_df

    def plot(self):
        """Plot forecast results from the HoltWinters model.

        Must be called after ``predict``, which sets ``fcst_df`` and
        ``include_history``.
        """
        logging.info("Generating chart for forecast result from HoltWinters model.")
        m.Model.plot(self.data, self.fcst_df, include_history=self.include_history)

    def __str__(self) -> str:
        return "HoltWinters"

    @staticmethod
    def get_parameter_search_space() -> List[Dict[str, Any]]:
        """Get the default HoltWinters parameter search space.

        Args:
            None

        Returns:
            A list of dictionaries describing the default HoltWinters
            parameter search space.
        """
        return get_default_holtwinters_parameter_search_space()