Source code for kats.models.ensemble.ensemble

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Ensemble techniques for forecasting

This module is the foundation for a set of ensemble techniques, including weighted
averaging, median ensembling and an STL-based ensembling method. It defines the parent
class for all ensemble models.
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from multiprocessing import Pool, cpu_count
from typing import List

import pandas as pd
from kats.consts import Params, TimeSeriesData
from kats.models import (
    model,
    arima,
    holtwinters,
    linear_model,
    prophet,
    quadratic_model,
    sarima,
)


# Registry of the base models an ensemble may use: maps a model name to the
# Kats model class that is instantiated and fitted for it. Keys are lower-case
# because BaseEnsemble.fit() looks a model up with `model_name.lower()`.
BASE_MODELS = {
    "arima": arima.ARIMAModel,
    "holtwinters": holtwinters.HoltWintersModel,
    "sarima": sarima.SARIMAModel,
    "prophet": prophet.ProphetModel,
    "linear": linear_model.LinearModel,
    "quadratic": quadratic_model.QuadraticModel,
}


class BaseModelParams:
    """Ensemble parameter class

    Describes one base model of an ensemble. This class contains two
    attributes:

    Attributes:
        model_name: model name (str); the ensemble looks this name up in
            ``BASE_MODELS``
        model_params: the parameter object of the base model, handed to the
            model constructor when the ensemble fits it
    """

    def __init__(self, model_name: str, model_params: Params, **kwargs) -> None:
        # Extra keyword arguments are accepted but ignored.
        self.model_name = model_name
        self.model_params = model_params
        logging.debug(
            "Initialized Base Model parameters: "
            "Model name:{model_name},"
            "model_params:{model_params}".format(
                model_name=model_name, model_params=model_params
            )
        )

    def validate_params(self) -> None:
        """Placeholder for validation; only logs that it is not implemented."""
        logging.info("Method validate_params() is not implemented.")
class EnsembleParams:
    """Container for the parameters of an ensemble.

    Attributes:
        models: list of BaseModelParams, one entry per base model
    """

    # Restrict instances to the single `models` attribute.
    __slots__ = ["models"]

    def __init__(self, models: List[BaseModelParams]) -> None:
        self.models = models
class BaseEnsemble:
    """Base ensemble class

    Implement parent class for ensemble: validates the requested base models,
    fits each of them on the same data and collects their predictions.

    Attributes:
        data: the time series every base model is fitted on
        params: the ensemble parameters listing the base models
        fitted: dict of model name -> fitted model object; only set by `fit`
    """

    def __init__(self, data: TimeSeriesData, params: EnsembleParams) -> None:
        """Store the inputs and validate them.

        Args:
            data: the input time series; its `value` must be a pd.Series
            params: the ensemble parameters; every model name must be a key
                of BASE_MODELS (compared case-insensitively)

        Raises:
            ValueError: if a model name is not supported or the data is not
                a univariate time series
        """
        self.data = data
        self.params = params

        for m in params.models:
            name = m.model_name
            # Compare in lower case because fit() resolves the model class
            # with `model_name.lower()`; both steps must accept the same names.
            if not isinstance(name, str) or name.lower() not in BASE_MODELS:
                msg = (
                    "Model {model_name} is not supported. "
                    "Only support {models}.".format(
                        model_name=name, models=list(BASE_MODELS)
                    )
                )
                logging.error(msg)
                raise ValueError(msg)

        if not isinstance(self.data.value, pd.Series):
            msg = "Only support univariate time series, but get {type}.".format(
                type=type(self.data.value)
            )
            logging.error(msg)
            raise ValueError(msg)

    def fit(self):
        """Fit method for ensemble model

        This method fits each individual model for ensembling in a process
        pool and stores a dict of model name and fitted obj in `self.fitted`,
        such as {'m1': fitted_m1_obj, 'm2': fitted_m2_obj}
        """
        # Use at most one worker per known base model and at least one worker.
        num_process = max(1, min(len(BASE_MODELS), (cpu_count() - 1) // 2))

        pool = Pool(processes=num_process, maxtasksperchild=1000)
        try:
            pending = {
                m.model_name: pool.apply_async(
                    self._fit_single,
                    args=(BASE_MODELS[m.model_name.lower()], m.model_params),
                )
                for m in self.params.models
            }
        finally:
            # Always release the workers, even if submitting a task failed.
            pool.close()
            pool.join()

        # get() re-raises any exception raised inside a worker.
        self.fitted = {name: res.get() for name, res in pending.items()}

    def _fit_single(self, model_func: model.Model, model_param: Params):
        """Private method to fit individual model

        Args:
            model_func: the callable model function
            model_param: the Kats model parameter class

        Returns:
            The model object after its `fit()` has been called
        """
        # get the model function call
        # pyre-fixme[29]: `Model` is not a function.
        m = model_func(params=model_param, data=self.data)
        m.fit()
        return m

    def _predict_all(self, steps: int, **kwargs):
        """Private method to predict with all individual fitted models

        Requires `fit` to have been called, since it reads `self.fitted`.

        Args:
            steps: the length of forecasting horizon
            **kwargs: forwarded unchanged to each model's `predict`

        Returns:
            A dict of model name -> the value returned by that model's
            `predict`
        """
        # pyre-fixme[16]: `BaseEnsemble` has no attribute `fitted`.
        return {
            model_name: model_fitted.predict(steps, **kwargs)
            for model_name, model_fitted in self.fitted.items()
        }

    def plot(self):
        """Plot method for ensemble model (not implemented yet)"""

    def __str__(self):
        """Get the class name as a string

        Args:
            None

        Returns:
            Model name as a string
        """
        return "Ensemble"