Skip to content

Reproduction of Zhao Shiyu of Reinforcement Learning Code

Notifications You must be signed in to change notification settings

TOMjacksmith/RL_code_study

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 

Repository files navigation

RL CODE 🎉️

简介

西湖大学赵世钰老师的强化学习课程代码复现。

赵世钰老师课程b站地址 💌

官方仓库

工作

复现1-8章算法,第9章算法无法收敛,止步于此

代码结构

运行环境:

arguments5x5.py:    环境参数 (参考官方代码)

grid_world.py:  算法,画图代码(参考官方代码)

network.py:  基于numpy手写的神经网络框架(用于第八章 DQN 以及第九章中)

复现代码 :

main.ipynb

说明

安装python库

核心库 numpy 和 matplotlib,其他库可见每一章 grid_world.py

画图

env = GridWorld() # 生成环境

env.grid_plot() # 画格子

env.plot_max_policy(policy_matrix) # 在格子上面画出策略的箭头

其他

文件夹 Chapter8 中没有 DQN 算法,单独放在 Chapter8_DQN 中。

代码示例

from grid_world import GridWorld
import numpy as np
import matplotlib.pyplot as plt

env = GridWorld() # 生成环境

# 运行策略
policy_matrix,episodes_len,total_rewards,errors = \
    env.TD7_2_sarsa(epsilon=0.1,isExpectedSarsa=False,start_state=0, iterations=2000, gamma=0.9, Alpha=0.01)

# 画图
fig = plt.figure(num=1,figsize=(10,5))  # 画布 1
axs = fig.subplots(1,2)   # 创建画布和轴
env.grid_plot(fig=fig,axs=axs[0])
env.plot_max_policy(policy_matrix)
axs[0].set_title('max_policy',y=1.1) # y 参数设置高度
line=axs[1].plot(errors, label='err')
axs[1].legend()
axs[1].set_xlabel('Iteration')
axs[1].set_ylabel('Convergence Error')
axs[1].set_title('Convergence Error')
axs[1].grid(True)
fig.tight_layout()

fig = plt.figure(num=2,figsize=(8,5))  # 画布 1
plt.subplot(2,1,1)
plt.plot(total_rewards, label='reward')
plt.ylabel('total rewards')
plt.grid(True)
plt.subplot(2,1,2)
plt.plot(episodes_len, label='len')
plt.legend()
plt.ylabel('episode length')
plt.xlabel('Episode index')
plt.grid(True)
plt.tight_layout()
plt.show()

运行结果: alt text

About

Reproduction of Zhao Shiyu of Reinforcement Learning Code

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published