minor change in comments notebook

This commit is contained in:
Tristan de Waard 2025-06-13 11:15:18 +02:00
parent 6629a163c6
commit de6ddfa29d

View File

@ -957,7 +957,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "7f54afc1",
"metadata": {},
"outputs": [],
@ -998,6 +998,7 @@
"\n",
" return np.array(data), np.array(labels).reshape(-1, 1)\n",
"\n",
"\n",
"# Reshape to have all features and timesteps ina single row\n",
"def reshaping(X):\n",
" reshaped_x = X.reshape(-1, X.shape[1] * X.shape[2])\n",
@ -1209,7 +1210,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"id": "cc4975ce",
"metadata": {},
"outputs": [
@ -1245,7 +1246,6 @@
}
],
"source": [
"\n",
"df_energy.index.freq = \"h\"\n",
"y_train = df_energy[\"price actual\"].iloc[:train_cutoff]\n",
"y_test = df_energy[\"price actual\"].iloc[train_cutoff:]\n",
@ -1387,21 +1387,39 @@
],
"source": [
"import warnings\n",
"\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"\n",
"\n",
"# history size (number of hours)\n",
"\n",
"\n",
"hist_size = 24\n",
"\n",
"\n",
"\n",
"# Generate training & test datasets\n",
"X_train, y_train, X_test, y_test, train_cutoff, scaler_y, scaler_X = get_train_test(hist_size=hist_size)\n",
"\n",
"\n",
"X_train, y_train, X_test, y_test, train_cutoff, scaler_y, scaler_X = get_train_test(\n",
" hist_size=hist_size\n",
")\n",
"\n",
"\n",
"\n",
"X_train_arima = reshaping(X_train)\n",
"\n",
"\n",
"X_test_arima = reshaping(X_test)\n",
"\n",
"\n",
"\n",
"# Fit AutoARIMA model (with exogenous variables)\n",
"autoarima_model = auto_arima(\n",
" y_train,\n",
"\n",
" # X=X_train_arima, # with extra features this takes a loooong time\n",
" start_p=0,\n",
" start_q=0,\n",
@ -1419,43 +1437,82 @@
" D=None,\n",
" trace=True,\n",
" error_action=\"ignore\",\n",
"\n",
" suppress_warnings=True,\n",
" stepwise=True,\n",
" n_jobs=-1,\n",
" method=\"nm\",\n",
" maxiter=100\n",
" maxiter=100,\n",
")\n",
"\n",
"\n",
"\n",
"# TODO : Try out different values for m (seasonality) -- can influence comp. cost\n",
"\n",
"\n",
"# TODO : Perform feature engineering, e.g. rolling means\n",
"\n",
"\n",
"# TODO : Try out different history sizes\n",
"\n",
"\n",
"\n",
"# Forecast\n",
"y_pred = autoarima_model.predict(n_periods=len(y_test))\n",
"\n",
"\n",
"\n",
"# Inverse transform the predictions and actual values\n",
"\n",
"\n",
"y_pred_actual = y_pred.reshape(-1, 1)\n",
"\n",
"\n",
"y_test_inv = y_test.reshape(-1, 1)\n",
"\n",
"\n",
"\n",
"# Calculate MAE\n",
"\n",
"\n",
"mae = mean_absolute_error(y_pred_actual, y_test_inv)\n",
"\n",
"\n",
"\n",
"# Plot the results\n",
"\n",
"\n",
"fig, axes = plt.subplots(figsize=(14, 6))\n",
"\n",
"\n",
"axes.plot(\n",
" df_energy[\"price actual\"].iloc[train_cutoff:].index[: len(y_test_inv)],\n",
" y_test_inv,\n",
" \"k.-\",\n",
" label=\"Actual\",\n",
")\n",
"\n",
"\n",
"\n",
"axes.plot(\n",
" df_energy[\"price actual\"].iloc[train_cutoff:].index[: len(y_pred_actual)],\n",
" y_pred_actual,\n",
"\n",
" \"b.-\",\n",
" label=\"Predicted\",\n",
")\n",
"\n",
"\n",
"\n",
"axes.set(xlabel=\"Date\", ylabel=\"Price actual\")\n",
"\n",
"\n",
"axes.set_title(f\"SARIMAX Predictions vs Actual Prices (MAE={mae:.3f})\")\n",
"\n",
"\n",
"axes.legend()\n",
"\n",
"\n",
"plt.show()"
]
},
@ -1565,9 +1622,12 @@
" verbose=False,\n",
")\n",
"\n",
"# perform predictions\n",
"y_pred = xgb_model.predict(X_test_xgb)\n",
"y_pred_actual = scaler_y.inverse_transform(y_pred.reshape(-1, 1))\n",
"y_test_inv = scaler_y.inverse_transform(y_test)\n",
"\n",
"# Calculate MAE\n",
"mae = mean_absolute_error(y_pred_actual, y_test_inv)\n",
"\n",
"\n",
@ -1662,7 +1722,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": null,
"id": "f531a91a",
"metadata": {},
"outputs": [
@ -38792,7 +38852,6 @@
}
],
"source": [
"\n",
"# Initialize LIME explainer\n",
"lime_explainer = lime.lime_tabular.LimeTabularExplainer(\n",
" training_data=np.array(X_train_xgb),\n",
@ -38810,7 +38869,9 @@
" predict_fn=xgb_model.predict,\n",
" num_features=len(X_train_xgb.columns),\n",
" )\n",
" display(HTML(explanation.as_html()))"
" display(HTML(explanation.as_html()))\n",
" # Or, you can save the explanation to an HTML file:\n",
" # explanation.save_to_file(f\"lime_explanation_{i}.html\")"
]
},
{
@ -39033,7 +39094,7 @@
"n_features = X_train.shape[2] # Number of features in the input\n",
"\n",
"\n",
"# Build LSTM model\n",
"# Build a simple LSTM model\n",
"model = Sequential()\n",
"model.add(Input(shape=(hist_size, n_features)))\n",
"model.add(LSTM(50, activation=\"relu\"))\n",
@ -39046,7 +39107,7 @@
"# fit the model\n",
"model.fit(X_train, y_train, epochs=50)\n",
"\n",
"\n",
"# Predict on the test set\n",
"y_pred = model.predict(X_test)\n",
"\n",
"\n",
@ -39141,7 +39202,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": null,
"id": "7a3da2a1",
"metadata": {},
"outputs": [
@ -40179,7 +40240,7 @@
"source": [
"from ShapTime import ShapleyValues, TimeImportance, TimeHeatmap\n",
"\n",
"\n",
"# calculate ShapTime values -- beware this can take a long time\n",
"Tn = 10 # number of time steps\n",
"shap_values = ShapleyValues(model, X_train, Tn)\n",
"time_columns = [f\"Time Step {i+1}\" for i in range(Tn)]\n",
@ -40188,7 +40249,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": null,
"id": "a1c887c3",
"metadata": {},
"outputs": [
@ -40215,8 +40276,6 @@
],
"source": [
"TimeImportance(Tn, shap_values, time_columns)\n",
"# ax = TimeHeatmap(Tn, shap_values, time_columns)\n",
"\n",
"\n",
"# Normalize SHAP values for color mapping\n",
"norm = plt.Normalize(shap_values.min(), shap_values.max())\n",
@ -40435,7 +40494,9 @@
"# Prepare training data\n",
"df_train = df_energy.iloc[:train_cutoff].copy()\n",
"df_train = df_train.rename(columns={\"price actual\": \"y\"}) # Prophet requires 'y'\n",
"df_train[\"ds\"] = df_train.index.to_series().dt.tz_localize(None) # Prophet requires 'ds' without timezone\n",
"df_train[\"ds\"] = df_train.index.to_series().dt.tz_localize(\n",
" None\n",
") # Prophet requires 'ds' without timezone\n",
"\n",
"# Initialize and fit the Prophet model\n",
"prophet_model = prophet.Prophet()\n",
@ -40445,7 +40506,9 @@
"\n",
"# Create future dataframe for prediction\n",
"df_future = df_energy.iloc[train_cutoff:].copy()\n",
"df_future[\"ds\"] = df_future.index.to_series().dt.tz_localize(None) # Prophet requires 'ds' without timezone\n",
"df_future[\"ds\"] = df_future.index.to_series().dt.tz_localize(\n",
" None\n",
") # Prophet requires 'ds' without timezone\n",
"\n",
"# Make forecasts\n",
"df_forecast = prophet_model.predict(df_future)\n",
@ -40490,7 +40553,7 @@
"axes[1].legend()\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n"
"plt.show()"
]
}
],