-
Notifications
You must be signed in to change notification settings - Fork 177
/
value_iteration.py
118 lines (103 loc) · 4.43 KB
/
value_iteration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python
#MIT License
#Copyright (c) 2017 Massimiliano Patacchiola
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
#Implementation of the Value Iteration algorithm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
def return_state_utility(v, T, u, reward, gamma):
"""Return the utility of a single state.
This is an implementation of the Bellman equation.
"""
action_array = np.zeros(4)
for action in range(0, 4):
action_array[action] = np.sum(np.multiply(u, np.dot(v, T[:,:,action])))
return reward + gamma * np.max(action_array)
def generate_graph(utility_list):
"""Given a list of utility arrays (one for each iteration)
it generates a matplotlib graph and save it as 'output.jpg'
"""
name_list = ('(1,3)', '(2,3)', '(3,3)', '+1', '(1,2)', '#', '(3,2)', '-1', '(1,1)', '(2,1)', '(3,1)', '(4,1)')
color_list = ('cyan', 'teal', 'blue', 'green', 'magenta', 'black', 'yellow', 'red', 'brown', 'pink', 'gray', 'sienna')
counter = 0
index_vector = np.arange(len(utility_list))
for state in range(12):
state_list = list()
for utility_array in utility_list:
state_list.append(utility_array[state])
plt.plot(index_vector, state_list, color=color_list[state], label=name_list[state])
counter += 1
#Adjust the legend and the axis
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 0.4), ncol=3, fancybox=True, shadow=True)
plt.ylim((-1.1, +1.1))
plt.xlim((1, len(utility_list)-1))
plt.ylabel('Utility', fontsize=15)
plt.xlabel('Iterations', fontsize=15)
plt.savefig("./output.jpg", dpi=500)
def main():
#Change as you want
tot_states = 12
gamma = 0.999 #Discount factor
iteration = 0 #Iteration counter
epsilon = 0.01 #Stopping criteria small value
#List containing the data for each iteation
graph_list = list()
#Transition matrix loaded from file (It is too big to write here)
T = np.load("T.npy")
#Reward vector
r = np.array([-0.04, -0.04, -0.04, +1.0,
-0.04, 0.0, -0.04, -1.0,
-0.04, -0.04, -0.04, -0.04])
#Utility vectors
u = np.array([0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0])
u1 = np.array([0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0])
while True:
delta = 0
u = u1.copy()
iteration += 1
graph_list.append(u)
for s in range(tot_states):
reward = r[s]
v = np.zeros((1,tot_states))
v[0,s] = 1.0
u1[s] = return_state_utility(v, T, u, reward, gamma)
delta = max(delta, np.abs(u1[s] - u[s]))
#Stopping criteria
if delta < epsilon * (1 - gamma) / gamma:
print("=================== FINAL RESULT ==================")
print("Iterations: " + str(iteration))
print("Delta: " + str(delta))
print("Gamma: " + str(gamma))
print("Epsilon: " + str(epsilon))
print("===================================================")
print(u[0:4])
print(u[4:8])
print(u[8:12])
print("===================================================")
break
generate_graph(graph_list)
if __name__ == "__main__":
main()