Source code for kats.utils.decomposition

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from kats.consts import TimeSeriesData
from statsmodels.tsa.seasonal import STL, seasonal_decompose


class TimeSeriesDecomposition:
    """Model class for Time Series Decomposition.

    This class provides utilities to decompose an input time series into
    trend, seasonal and residual components.

    Attributes:
        data: the input time series data as `TimeSeriesData`
        decomposition: `additive` or `multiplicative` decomposition
        method: `STL` decomposition or `seasonal_decompose`
        freq: frequency string inferred by `decomposer` (None until then)
        results: dictionary produced by the last `decomposer` call
            (None until then)

    Specific arguments to seasonal_decompose and STL functions can be
    passed via kwargs.
    """

    def __init__(
        self, data: "TimeSeriesData", decomposition="additive", method="STL", **kwargs
    ) -> None:
        """Validate the input and store the decomposition settings.

        Args:
            data: time series whose `value` attribute must be a `pd.Series`.
            decomposition: `additive` or `multiplicative`; any other value
                falls back to `additive`.
            method: `STL` or `seasonal_decompose`; any other value falls
                back to `STL`.
            **kwargs: optional overrides for the STL parameters stored
                below (`period` is also used by `seasonal_decompose`).

        Raises:
            ValueError: if `data.value` is not a `pd.Series`.
        """
        self.data = data
        if not isinstance(self.data.value, pd.Series):
            msg = "Only support univariate time series, but get {type}.".format(
                type=type(self.data.value)
            )
            logging.error(msg)
            raise ValueError(msg)

        if decomposition in ("additive", "multiplicative"):
            self.decomposition = decomposition
        else:
            logging.info("Invalid decomposition setting specified")
            logging.info("Defaulting to Additive Decomposition")
            self.decomposition = "additive"

        if method in ("STL", "seasonal_decompose"):
            self.method = method
        else:
            logging.info("Invalid decomposition method specified")
            logging.info("Possible Values: STL, seasonal_decompose")
            logging.info("Defaulting to STL")
            self.method = "STL"

        # Populated by decomposer(); defined here so that reading them
        # before a decomposition has been run does not raise AttributeError.
        self.freq = None
        self.results = None

        ## The following are params for the STL Module
        self.period = kwargs.get("period", None)
        self.seasonal = kwargs.get("seasonal", 7)
        self.trend = kwargs.get("trend", None)
        self.low_pass = kwargs.get("low_pass", None)
        self.seasonal_deg = kwargs.get("seasonal_deg", 1)
        self.trend_deg = kwargs.get("trend_deg", 1)
        self.low_pass_deg = kwargs.get("low_pass_deg", 1)
        self.robust = kwargs.get("robust", False)
        self.seasonal_jump = kwargs.get("seasonal_jump", 1)
        self.trend_jump = kwargs.get("trend_jump", 1)
        self.low_pass_jump = kwargs.get("low_pass_jump", 1)

    @staticmethod
    def _is_minute_freq(freq):
        """Return True when `freq` is a minute-based frequency alias.

        The alias is parsed into an offset instead of being searched for
        the letter "T": a substring test also matches anchored aliases
        such as "W-TUE" or "Q-OCT", and misses the "min" spelling.
        """
        if not freq:
            return False
        try:
            offset = pd.tseries.frequencies.to_offset(freq)
        except ValueError:
            # Alias not understood by this pandas version.
            return False
        return isinstance(offset, pd.offsets.Minute)

    @staticmethod
    def _warn_default_period(prefix):
        """Log the warnings emitted when the period is defaulted to 2."""
        logging.warning(prefix + " cannot handle sub day level granularity")
        logging.warning(
            "Please consider setting period yourself based on the input data"
        )
        logging.warning("Defaulting to a period of 2")

    def __clean_ts(self):
        """Internal function to clean the time series.

        Internal function to interpolate time series and infer frequency of
        time series required for decomposition.

        Returns:
            A single-column (`y`) DataFrame indexed by datetime. The inferred
            frequency is stored in `self.freq`.
        """
        original = pd.DataFrame(
            list(self.data.value), index=self.data.time, columns=["y"]
        )
        original.index = pd.to_datetime(original.index)

        if pd.infer_freq(original.index) is None:
            original = original.asfreq("D")
            logging.info("Setting frequency to Daily since it cannot be inferred")

        self.freq = pd.infer_freq(original.index)

        original = original.interpolate(
            method="polynomial", limit_direction="both", order=3
        )

        ## This is a hack since polynomial interpolation is not working here
        if original["y"].isna().any():
            original = original.interpolate(method="linear", limit_direction="both")

        return original

    def __decompose_seasonal(self, original):
        """Internal function to call seasonal_decompose to do the decomposition.

        An explicit `period` is always honoured. Without one, minute-level
        data falls back to a period of 2 (with warnings); any other
        frequency is left for seasonal_decompose to resolve.
        """
        period = self.period
        if period is None and self._is_minute_freq(self.freq):
            self._warn_default_period("Seasonal Decompose")
            period = 2

        if period is not None:
            result = seasonal_decompose(
                original, model=self.decomposition, period=period
            )
        else:
            result = seasonal_decompose(original, model=self.decomposition)

        return {
            "trend": result.trend,
            "seasonal": result.seasonal,
            "resid": result.resid,
        }

    def _fit_stl(self, series):
        """Fit STL on `series` using the parameters stored on the instance."""
        return STL(
            series,
            period=self.period,
            seasonal=self.seasonal,
            trend=self.trend,
            low_pass=self.low_pass,
            seasonal_deg=self.seasonal_deg,
            trend_deg=self.trend_deg,
            low_pass_deg=self.low_pass_deg,
            robust=self.robust,
            seasonal_jump=self.seasonal_jump,
            trend_jump=self.trend_jump,
            low_pass_jump=self.low_pass_jump,
        ).fit()

    def __decompose_STL(self, original):
        """Internal function to call STL to do the decomposition.

        The arguments to STL can be passed in the class via kwargs.
        Multiplicative decomposition is done by fitting STL on the log of
        the series and exponentiating each component.
        """
        if self.period is None and self._is_minute_freq(self.freq):
            self._warn_default_period("STL")
            self.period = 2

        if self.decomposition == "additive":
            result = self._fit_stl(original)
            return {
                "trend": result.trend,
                "seasonal": result.seasonal,
                "resid": result.resid,
            }

        if (original <= 0).values.any():
            logging.error(
                "Multiplicative seasonality is not appropriate "
                "for zero and negative values"
            )
        result = self._fit_stl(np.log(original))
        return {
            "trend": np.exp(result.trend),
            "seasonal": np.exp(result.seasonal),
            "resid": np.exp(result.resid),
        }

    def __decompose(self, original):
        """Run the configured method and wrap each component as TimeSeriesData."""
        if self.method == "STL":
            output = self.__decompose_STL(original)
        else:
            output = self.__decompose_seasonal(original)

        # Public result key -> key of the raw component; the residual is
        # exposed under "rem".
        components = (("trend", "trend"), ("seasonal", "seasonal"), ("rem", "resid"))
        return {
            public_key: TimeSeriesData(
                output[raw_key].reset_index(),
                time_col_name=self.data.time_col_name,
            )
            for public_key, raw_key in components
        }

    def decomposer(self):
        """Decompose the time series.

        Args:
            None.

        Returns:
            A dictionary with three time series for the three components:
            `trend` : Trend
            `seasonal` : Seasonality, and
            `rem` : Residual
        """
        original = self.__clean_ts()
        self.results = self.__decompose(original)
        return self.results

    def plot(self):
        """Plot the original time series and the three decomposed components.

        Returns:
            The matplotlib figure and the array of its four axes.

        Raises:
            ValueError: if `decomposer()` has not been called yet.
        """
        results = getattr(self, "results", None)
        if results is None:
            raise ValueError("Call decomposer() before plot().")

        fig, ax = plt.subplots(nrows=4, ncols=1, figsize=(20, 10), sharex=True)

        panels = (
            ("Original Time Series", self.data),
            ("Trend", results["trend"]),
            ("Seasonality", results["seasonal"]),
            ("Residual", results["rem"]),
        )
        for axis, (title, series) in zip(ax, panels):
            axis.plot(series.time.values, series.value.values, linewidth=3)
            axis.set_title(title)

        ax[3].set_xlabel("Time")
        plt.subplots_adjust(hspace=0.2)
        return fig, ax