forked from dbt-labs/coalesce-2022-python-databricks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path forecast_train_py.py
42 lines (29 loc) · 940 Bytes
/
forecast_train_py.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import pyspark.pandas as ps
from datetime import datetime
from prophet import Prophet
from prophet.serialize import model_to_json
def model(dbt, session):
    """Train one Prophet forecasting model per location and persist it as JSON.

    dbt Python model entry point. Reads the upstream
    ``revenue_weekly_by_location`` model, fits a separate ``Prophet`` model on
    each location's weekly revenue, and returns one row per location holding
    the training timestamp and the serialized model.

    Args:
        dbt: dbt context object supplied by dbt (used for ``config`` and ``ref``).
        session: session object passed in by dbt (presumably the Spark
            session); not used here.

    Returns:
        A pandas-on-Spark DataFrame with columns ``trained_at``, ``location``
        and ``model`` (the ``model_to_json`` string for that location).
    """
    # comment this out to enable the model
    dbt.config(enabled=False)

    # dbt configuration
    dbt.config(materialized="incremental")

    # use current time as index (naive local time; same value on every row)
    trained_at = datetime.now()

    # get upstream data as a pandas-on-Spark DataFrame
    revenue = dbt.ref("revenue_weekly_by_location").pandas_api()

    # rename to match prophet's expected column names ("ds" = time, "y" = target)
    renames = {
        "date_week": "ds",
        "location_name": "location",
        "revenue": "y",
    }
    revenue = revenue.rename(columns=renames)

    # Prophet fits on in-memory pandas data, so collect to the driver once
    # instead of converting a separate Spark slice for every location.
    # NOTE(review): assumes this weekly-by-location table fits in driver memory.
    revenue = revenue.to_pandas()

    # get list of unique locations dynamically; null locations cannot be
    # matched by the per-location filter below and would break sorting.
    locations = sorted(revenue["location"].dropna().unique())

    # train the ML models per location (Prophet.fit returns the fitted model).
    # NOTE(review): a location with too little history may make fit() raise.
    models = [
        Prophet().fit(revenue.loc[revenue["location"] == location, ["ds", "y"]])
        for location in locations
    ]

    # persist models: one row per location, model stored as a JSON string
    df = ps.DataFrame(
        {
            "trained_at": [trained_at] * len(locations),
            "location": locations,
            "model": [model_to_json(fitted) for fitted in models],
        }
    )
    return df