【强化学习】SARAS代码实现

发布时间:2023年12月31日

前言

SARAS,假设环境状态和动作状态都是离散的。利用动作价值矩阵来进行行为的预测。其主要就是利用时序差分的思想,对动作价值矩阵进行更新。

代码实现

import gymnasium as gym
import numpy as np

class sarsa():
    def __init__(self, states_n, action_n, greedy_e=0.1):
        self.Q = np.zeros((states_n, action_n)) #动作价值矩阵
        self.greedy_e = greedy_e #随机探索的概率
        self.states_n = states_n #环境状态个数
        self.action_n = action_n #行动状态个数
        self.gamma=0.9 #价值衰减值
        self.lr=0.1 #学习率

    def predict(self, states):
        action_list=self.Q[states]#先拿出对应的行
        #再取出对应价值最大的行为,如果有重复则在重复项中随机选取,返回索引
        action=np.random.choice(np.flatnonzero(action_list==action_list.max()))
        return action
    def act(self, states):
        '''
            由对应环境产生对应的行动
            @param states: 当前环境
            @return: 行动动作
        '''
        if np.random.uniform() < self.greedy_e:#是否采取随即探索
            action = np.random.choice(np.arange(self.action_n))#随机探索
        else:
            action = self.predict(states) # 根据行动价值矩阵进行预测
        return action
    def learning(self,state,action,reward,next_state,next_action,does):
        '''
            学习更新参数
            @param state: 环境状态
            @param action: 采取的行动
            @param reward: 回报
            @param next_state: 采取行动后的下一个环境状态
            @param next_action: 下一个环境状态对应的行为
            @param does: 游戏是否结束
            @return:
        '''
        current_q=self.Q[state,action] #取出对应的行动价值
        if does: #查看是否已经完成游戏,完成则直接将当前回报作为下一个行动价值
            next_q=reward
        else:
            # 计算当前回报和下一个环境状态和下一个行动对应的价值,加和
            next_q=reward+self.gamma*self.Q[next_state,next_action]
        self.Q[state,action]+=self.lr*(next_q-current_q) #时序差分,更新行动价值矩阵


def train():
    env = gym.make("FrozenLake-v1", render_mode="human")#初始化游戏环境
    obs,info=env.reset()#重置位置
    agent=sarsa(env.observation_space.n,env.action_space.n)#初始化模型
    action = agent.act(obs)#预测行为
    num=0
    while True:
        num+=1
        # 由行为产生回报和下一个环境状态
        next_obs, reward, done, truncated, info = env.step(action)
        #预测下一个动作
        next_action=agent.act(obs)
        # 更新参数
        agent.learning(obs,action,reward,next_obs,next_action,done)
        obs=next_obs
        action=next_action
        # 判断游戏是否结束或者中断,是则重置游戏
        if done or truncated:
            obs, info = env.reset()
        if num % 100 == 0 :
            env.close()


if __name__ == '__main__':
    train()
文章来源:https://blog.csdn.net/sdksdf/article/details/135319571
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。