kats.models.reconciliation.thm module
This module contains the TemporalHierarchicalModel class.
- class kats.models.reconciliation.thm.TemporalHierarchicalModel(data: kats.consts.TimeSeriesData, baseModels: List[kats.models.reconciliation.base_models.BaseTHModel])
Bases: object
Temporal hierarchical model class.
This framework combines the base models of different temporal aggregation levels to generate reconciled forecasts. The class provides the methods fit, get_S, get_W, predict, and median_validation.
- data
A TimeSeriesData object storing the time series data for level 1 (i.e., the most disaggregate level).
- baseModels
A list of BaseTHModel objects representing the base models for the different levels.
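The higher levels of the hierarchy are built from the level-1 data. The following is a minimal sketch, assuming that level k denotes non-overlapping sums of k consecutive level-1 periods and that an incomplete leading block is dropped; the helper temporal_aggregate is hypothetical and not part of Kats.

```python
from typing import Dict

import numpy as np


def temporal_aggregate(values: np.ndarray, k: int) -> np.ndarray:
    """Sum a level-1 series over non-overlapping blocks of k periods.

    Assumption: leading observations are dropped so that the most recent
    complete blocks are kept.
    """
    values = np.asarray(values, dtype=float)
    n_blocks = len(values) // k
    if n_blocks == 0:
        raise ValueError("series is shorter than one aggregation block")
    trimmed = values[len(values) - n_blocks * k:]
    # each row of the reshaped array is one level-k block
    return trimmed.reshape(n_blocks, k).sum(axis=1)


# A 24-period level-1 series aggregated to levels 3 and 12
level1 = np.arange(1, 25)
hierarchy: Dict[int, np.ndarray] = {k: temporal_aggregate(level1, k) for k in [1, 3, 12]}
# hierarchy[12] holds two block sums: 1 + ... + 12 and 13 + ... + 24
```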
- fit() → None
Fit all base models.
If a base model only has residuals and forecasts, store them.
- get_S() → numpy.ndarray
Calculate the S matrix (the summing matrix of the temporal hierarchy).
- Returns
A np.array representing the S matrix.
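As a rough illustration, the sketch below builds a summing matrix S for levels that all divide m = max(levels), stacking rows from the most aggregate level down to level 1. The function name temporal_summing_matrix and this row ordering are assumptions for illustration; Kats may order the rows differently.

```python
from typing import List

import numpy as np


def temporal_summing_matrix(levels: List[int]) -> np.ndarray:
    """Build S for a temporal hierarchy whose levels all divide m = max(levels).

    Rows run from the most aggregate level (one row for k = m) down to level 1
    (m rows, the identity); each row sums the level-1 periods in one level-k block.
    """
    if 1 not in levels:
        raise ValueError("level 1 (the bottom level) is required")
    m = max(levels)
    if any(m % k != 0 for k in levels):
        raise ValueError("every level must divide the largest level")
    blocks = [
        # m // k rows for level k, each covering k consecutive level-1 periods
        np.kron(np.eye(m // k), np.ones((1, k)))
        for k in sorted(set(levels), reverse=True)
    ]
    return np.vstack(blocks)


S = temporal_summing_matrix([1, 2, 4])  # 1 + 2 + 4 = 7 rows, 4 columns
```

Each column of S corresponds to one level-1 period, so multiplying S by a vector of level-1 values yields the values of every node in the hierarchy.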
- get_W(method: str = 'struc', eps: float = 1e-05) → numpy.ndarray
Calculate the W matrix, which weights the base forecasts during reconciliation.
- Parameters
method – Reconciliation method for temporal hierarchical model. Valid methods include ‘struc’, ‘svar’, ‘hvar’, ‘mint_sample’, and ‘mint_shrink’.
eps – A small constant added to W for numerical stability.
- Returns
W matrix. (If W is a diagonal matrix, only its diagonal elements are returned.)
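For the 'struc' method, the usual structural-scaling choice is a diagonal W whose entries are the row sums of S, i.e. the number of level-1 periods each node aggregates. This minimal sketch, with a hypothetical helper structural_weights, returns only the diagonal (mirroring the return convention above) and, as an assumption, adds eps to every entry.

```python
import numpy as np


def structural_weights(S: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Diagonal of W under structural scaling: row sums of S, plus eps (assumption)."""
    return np.asarray(S, dtype=float).sum(axis=1) + eps


# S for levels [1, 2, 4]: one level-4 total, two level-2 sums, four level-1 periods
S = np.vstack([
    np.ones((1, 4)),
    np.kron(np.eye(2), np.ones((1, 2))),
    np.eye(4),
])
w = structural_weights(S)  # approx. [4, 2, 2, 1, 1, 1, 1]
```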
- median_validation(steps, dist_metric: str = 'mae', threshold: float = 5.0) → List[int]
Filter out bad forecasts based on the median forecasts.
This function detects the levels whose forecasts deviate greatly from the median forecasts, which is a strong indication of bad forecasts.
- Parameters
steps – The number of level-1 forecasts to generate for validation.
dist_metric – The distance metric used to measure the distance between the base forecasts and the median forecasts.
threshold – The deviance threshold. A forecast whose distance from the median forecast is greater than threshold * std is treated as a bad forecast. Default is 5.0.
- Returns
A list of integers representing the levels whose forecasts are bad.
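The sketch below shows one way such a median check could work. It is hypothetical: the helper flag_bad_levels is not part of Kats, it assumes every level's forecast has already been mapped onto a common grid of equal length, it only implements the 'mae' distance, and it takes std as a caller-supplied scale because the description above does not say which standard deviation is used.

```python
from typing import Dict, List

import numpy as np


def flag_bad_levels(
    fcsts: Dict[int, np.ndarray], std: float, threshold: float = 5.0
) -> List[int]:
    """Return the levels whose forecasts are far from the element-wise median forecast."""
    levels = sorted(fcsts)
    stacked = np.vstack([np.asarray(fcsts[k], dtype=float) for k in levels])
    # element-wise median across levels
    median_fcst = np.median(stacked, axis=0)
    # mean absolute error between each level's forecast and the median forecast
    mae = np.mean(np.abs(stacked - median_fcst), axis=1)
    return [k for k, d in zip(levels, mae) if d > threshold * std]


fcsts = {
    1: np.array([10.0, 11.0, 12.0]),
    2: np.array([10.5, 11.0, 12.5]),
    3: np.array([10.0, 11.5, 12.0]),
    4: np.array([30.0, 31.0, 32.0]),  # far from the others
}
bad = flag_bad_levels(fcsts, std=1.0)  # [4]
```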
- predict(steps: int, method='struc', freq: Optional[str] = None, origin_fcst: bool = False, fcst_levels: Optional[List[int]] = None, last_timestamp: Optional[pandas._libs.tslibs.timestamps.Timestamp] = None) → Dict[str, Dict[int, pandas.core.frame.DataFrame]]
Generate reconciled forecasts (with time index).
- Parameters
steps – The number of forecasts needed for level 1.
method – The name of the reconciliation method. Can be ‘bu’ (bottom-up), ‘median’, ‘struc’ (structure-variance), ‘svar’, ‘hvar’, ‘mint_shrink’ or ‘mint_sample’.
freq – The frequency of the time series at level 1. If None, then we infer the frequency via ts.infer_freq_robust().
origin_fcst – Whether or not to return the forecasts of base models.
fcst_levels – The levels to generate forecasts for. Default is None, which generates forecasts for all the levels of the base models.
- Returns
A dictionary of forecasts keyed by strings. Each value is itself a dictionary whose key is the level and whose value is a pandas.DataFrame storing the forecasts with their time index.
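For the methods other than 'bu' and 'median', reconciliation in temporal hierarchies is commonly written as y_tilde = S (S' W^-1 S)^-1 S' W^-1 y_hat, where y_hat stacks the base forecasts of all levels. The following is a minimal NumPy sketch of that formula and of bottom-up aggregation, assuming the rows of S and y_hat are ordered with level 1 last; reconcile and bottom_up are hypothetical helpers, not the Kats implementation.

```python
import numpy as np


def reconcile(S: np.ndarray, W: np.ndarray, base_fcst: np.ndarray) -> np.ndarray:
    """Compute y_tilde = S (S' W^-1 S)^-1 S' W^-1 y_hat.

    W is either a full (n_nodes x n_nodes) matrix or a 1-D array holding its diagonal.
    """
    W = np.asarray(W, dtype=float)
    W_inv = np.diag(1.0 / W) if W.ndim == 1 else np.linalg.inv(W)
    # G maps all base forecasts to reconciled level-1 forecasts
    G = np.linalg.solve(S.T @ W_inv @ S, S.T @ W_inv)
    return S @ (G @ np.asarray(base_fcst, dtype=float))


def bottom_up(S: np.ndarray, base_fcst: np.ndarray) -> np.ndarray:
    """Sum the level-1 base forecasts (assumed to be the last rows) up the hierarchy."""
    n_bottom = S.shape[1]
    return S @ np.asarray(base_fcst, dtype=float)[-n_bottom:]


# Levels [1, 2] with m = 2: one level-2 row, then the level-1 identity
S = np.array([[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
W = S.sum(axis=1)                   # structural scaling: diagonal [2, 1, 1]
y_hat = np.array([10.0, 4.0, 5.0])  # incoherent: 4 + 5 != 10
y_rec = reconcile(S, W, y_hat)      # approx. [9.5, 4.25, 5.25], and 4.25 + 5.25 = 9.5
y_bu = bottom_up(S, y_hat)          # [9.0, 4.0, 5.0]
```

Because G is applied by matrix multiplication, base_fcst may also be a 2-D array with one column per forecast cycle.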