西湖大学赵世钰老师的强化学习课程代码复现。
复现1-8章算法,第9章算法无法收敛,止步于此
运行环境:
arguments5x5.py: 环境参数 (参考官方代码)
grid_world.py: 算法,画图代码(参考官方代码)
network.py: 基于numpy手写的神经网络框架(用于第八章 DQN 以及第九章中)
复现代码 :
main.ipynb
核心库 numpy 和 matplotlib,其他库可见每一章 grid_world.py
env = GridWorld() # 生成环境
env.grid_plot() # 画格子
env.plot_max_policy(policy_matrix) # 在格子上面画出策略的箭头
文件夹 Chapter8 中没有 DQN 算法,单独放在 Chapter8_DQN 中。
from grid_world import GridWorld
import numpy as np
import matplotlib.pyplot as plt
env = GridWorld() # 生成环境
# 运行策略
policy_matrix,episodes_len,total_rewards,errors = \
env.TD7_2_sarsa(epsilon=0.1,isExpectedSarsa=False,start_state=0, iterations=2000, gamma=0.9, Alpha=0.01)
# 画图
fig = plt.figure(num=1,figsize=(10,5)) # 画布 1
axs = fig.subplots(1,2) # 创建画布和轴
env.grid_plot(fig=fig,axs=axs[0])
env.plot_max_policy(policy_matrix)
axs[0].set_title('max_policy',y=1.1) # y 参数设置高度
line=axs[1].plot(errors, label='err')
axs[1].legend()
axs[1].set_xlabel('Iteration')
axs[1].set_ylabel('Convergence Error')
axs[1].set_title('Convergence Error')
axs[1].grid(True)
fig.tight_layout()
fig = plt.figure(num=2,figsize=(8,5)) # 画布 1
plt.subplot(2,1,1)
plt.plot(total_rewards, label='reward')
plt.ylabel('total rewards')
plt.grid(True)
plt.subplot(2,1,2)
plt.plot(episodes_len, label='len')
plt.legend()
plt.ylabel('episode length')
plt.xlabel('Episode index')
plt.grid(True)
plt.tight_layout()
plt.show()
