import logging
from functools import lru_cache

import numpy as np
import torch

# Keep log output to warnings and above; darts logs verbosely at INFO by default.
logging.getLogger().setLevel(logging.WARNING)

from darts import TimeSeries
from darts.datasets import AirPassengersDataset
from darts.models import TransformerModel

from jackdaw_ml import loads, saves
from jackdaw_ml.artefact_decorator import artefacts
from jackdaw_ml.detectors.torch import TorchDetector, TorchSeqDetector

# As the class already exists, we can wrap it with `artefacts` and bind the result to a new name.
TransformerModelWithArtefacts = artefacts(
    {}, child_detectors=[TorchSeqDetector, TorchDetector]
)(TransformerModel)
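
# For a class defined in your own code, the same call can be written with decorator
# syntax, since applying `artefacts(...)` to the class is exactly what `@` does
# (a sketch; `MyModel` is hypothetical):
#
#   @artefacts({}, child_detectors=[TorchSeqDetector, TorchDetector])
#   class MyModel(TransformerModel):
#       ...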


@lru_cache(maxsize=1)
def example_data() -> TimeSeries:
    return AirPassengersDataset().load()


@lru_cache(maxsize=1)
def similar_example_data() -> TimeSeries:
    # A shorter slice of the same series; fitting on it builds a model with the same parameter shapes.
    return AirPassengersDataset().load()[0:64]


def np_float_equivalence(a: np.ndarray, b: np.ndarray) -> bool:
    # Are `a` and `b` within a reasonable distance of each other, accounting for internal machine error?
    # Sum the absolute differences so positive and negative errors can't cancel out.
    diff = np.sum(np.abs(a - b))
    return diff <= np.finfo(np.float32).eps
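
# For example (hypothetical values), an error below float32 machine epsilon passes:
#   np_float_equivalence(np.array([1.0]), np.array([1.0 + 1e-8]))  # True, as 1e-8 < ~1.19e-7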


def assert_model_equivalence(
    model: torch.nn.Module, model_two: torch.nn.Module
) -> bool:
    # True only if every tensor in `model`'s state_dict also exists in `model_two` with identical values.
    state_two = model_two.state_dict()
    for key, value in model.state_dict().items():
        if key not in state_two:
            return False
        if not torch.equal(value, state_two[key]):
            return False
    return True


def test_darts_torch_equivalence():
    m1 = TransformerModelWithArtefacts(32, 32)
    m2 = TransformerModelWithArtefacts(32, 32)
    # DARTS' Transformer is a complex case for Jackdaw. While Jackdaw can find the correct slots to
    # retrieve the model from, DARTS only builds the underlying torch model once data is provided,
    # because the Transformer needs the data to size its parameters.
    # There may be an easier way to initialise models on load without saving them via Pickle, but my
    # experience with the library ends here.
    m1.fit(example_data(), epochs=1)
    model_id = saves(m1)
    # Because of that lazy construction, we can 'trick' DARTS into building a model with similar
    # starting values, then override that model's parameters with our own.
    m2.fit(similar_example_data(), epochs=1)
    # Prove that our tricked model isn't the same as our saved model (yet)
    assert not assert_model_equivalence(m1.model, m2.model)
    # Load in the pre-existing model's artefacts
    loads(m2, model_id)
    assert assert_model_equivalence(m1.model, m2.model)


if __name__ == "__main__":
    m1 = TransformerModelWithArtefacts(32, 32)
    m1.fit(example_data(), epochs=1)
    model_id = saves(m1)
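    # A minimal round-trip sketch, mirroring the test above: rebuild a model with the right
    # parameter shapes by fitting on similar data, then load the saved artefacts back in.
    m2 = TransformerModelWithArtefacts(32, 32)
    m2.fit(similar_example_data(), epochs=1)
    loads(m2, model_id)
    assert assert_model_equivalence(m1.model, m2.model)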