-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
28 lines (21 loc) · 930 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
from classification_module import ClassificationLightningModule
import hydra
from omegaconf import DictConfig
import pytorch_lightning as pl
from pytorch_lightning.logging import TensorBoardLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from argparse import Namespace
import yaml
import time
@hydra.main(config_path="config/config.yaml")
def test_app(cfg):
    """Run the Lightning test loop on a trained ClassificationLightningModule.

    The checkpoint file is resolved relative to the ``DATASET_FOLDER``
    environment variable.

    By default this is the epoch-103 checkpoint of the
    ``cross_subject/models_high/efficient/augment_translate/transform_300``
    run. Set the optional ``test_checkpoint`` config key to evaluate a
    different checkpoint instead. Because the path is combined with
    ``os.path.join``, an absolute ``test_checkpoint`` is used as-is.

    Args:
        cfg: Hydra-composed configuration (from ``config/config.yaml``).
            It is printed, then unpacked into an ``argparse.Namespace``.

    Raises:
        KeyError: if the ``DATASET_FOLDER`` environment variable is not set.
    """
    print(cfg.pretty())
    hparams = Namespace(**cfg)

    # Previously the only (hard-coded) choice; kept as the default so that
    # existing invocations evaluate exactly the same checkpoint.
    default_checkpoint = (
        "ntu_reindexing_testreindexing/cross_subject/models_high/efficient/"
        "augment_translate/transform_300/_ckpt_epoch_103.ckpt"
    )
    checkpoint = getattr(hparams, "test_checkpoint", None) or default_checkpoint

    # os.path.join inserts the path separator that plain "+" concatenation
    # silently dropped when DATASET_FOLDER has no trailing slash.
    checkpoint_path = os.path.join(os.environ["DATASET_FOLDER"], checkpoint)
    model = ClassificationLightningModule.load_from_checkpoint(checkpoint_path)

    # gpus=-1 asks Lightning to use every visible GPU.
    trainer = Trainer(gpus=-1)
    # Presumably evaluates via the module's own test dataloader — confirm in
    # ClassificationLightningModule.
    trainer.test(model)
if __name__ == "__main__":
    # Script entry point. The @hydra.main decorator on test_app supplies the
    # ``cfg`` argument, so the function is invoked here with no arguments.
    test_app()