-
Notifications
You must be signed in to change notification settings - Fork 0
/
TestingModelwith1Query.py
46 lines (34 loc) · 1.22 KB
/
TestingModelwith1Query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Run a trained PPO agent on a single Tentris query and report its reward.

Loads the saved PPO model, wraps ``TentrisEnv`` (with flattened
observations) in a single-environment ``DummyVecEnv``, rolls out one
episode -- i.e. generates the plan for one query -- and prints the
accumulated reward.

The database and model locations default to the original hard-coded
values and can be overridden with the ``TENTRIS_DB_PATH`` and
``TENTRIS_MODEL_PATH`` environment variables.

Created on Tue Oct 22 14:41:05 2024
@author: sohail
"""
import os

import torch
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from gymnasium.wrappers import FlattenObservation

from tentrisEnv import TentrisEnv

# Path to the SQLite database (override with TENTRIS_DB_PATH).
db_path = os.environ.get(
    "TENTRIS_DB_PATH", "/home/sohail/CLionProjects/tentris/query_data.db"
)
# Location of the saved PPO model (override with TENTRIS_MODEL_PATH).
model_path = os.environ.get("TENTRIS_MODEL_PATH", "ppo_tentris")
# Pick the policy's most likely action instead of sampling one, so that
# repeated evaluations of the same query produce the same plan.
DETERMINISTIC = True
# Set to True to print every step of the rollout.
VERBOSE = False

# Load the trained PPO model
model = PPO.load(model_path)
# Initialize the environment: exactly one TentrisEnv inside a VecEnv
env = DummyVecEnv(
    [lambda: FlattenObservation(TentrisEnv(db_path, reset_on_init=False))]
)
try:
    # Reset the environment to start a new episode (i.e., load a new query)
    obs = env.reset()
    # Initialize variables to track the episode
    done = False
    total_reward = 0.0
    # Loop until the query plan is fully generated
    while not done:
        # Let the model select an action based on the current observation
        action, _ = model.predict(obs, deterministic=DETERMINISTIC)
        # VecEnv.step returns one entry per sub-environment; there is only
        # one here, so take element 0 to work with plain Python scalars.
        obs, rewards, dones, _ = env.step(action)
        reward = float(rewards[0])
        done = bool(dones[0])
        # Accumulate the total reward
        total_reward += reward
        if VERBOSE:
            print(f"Action: {action}, Reward: {reward}, Done: {done}")
finally:
    # Release the environment's resources even if the rollout raises.
    env.close()

# Episode is finished, print the total reward
print(f"Total Reward: {total_reward}")