Sarsa
This code uses the CMAC function approximator (the `tiles` module) discussed previously.
#!/usr/bin/python
"""
Author: Jeremy M. Stober
Program: SARSA.PY
Date: Thursday, April 10 2008
Description: TD(lambda) with CMAC function approximation. Includes a
mountain car demonstration.
"""

import os
import sys
import getopt
import pdb
import random as prandom

from numpy import *  # array, exp, cos, add, searchsorted, flatnonzero, ...
from tiles import CMAC


def flip(p):
    """Return True with probability ``p`` (a biased coin flip)."""
    return prandom.random() < p


class Sarsa(object):
    """SARSA(lambda) learner with one CMAC approximator per discrete action.

    Actions are the integers ``0 .. nactions - 1``.  ``nlevels``, ``quantize``
    and ``alpha`` are passed unchanged to every ``CMAC`` (presumably the
    number of tilings, the quantization width and the learning rate -- TODO
    confirm against tiles.CMAC).  ``gamma`` is the discount factor, ``ld``
    the trace-decay parameter lambda, and ``epsilon`` the exploration rate
    used by ``greedy`` (``softmax`` reuses it as its temperature).
    """

    def __init__(self, nactions, nlevels, quantize, alpha, gamma, ld, epsilon):
        self.nactions = nactions
        self.gamma = gamma  # discount
        self.epsilon = epsilon  # exploration rate / softmax temperature
        self.ld = ld  # lambda

        # Eligibility trace, newest entry first.  Each entry is the mutable
        # list [quantized state, action, eligibility weight].
        self.trace = []
        # Once the oldest entry's weight decays below this it is dropped.
        self.threshold = 1e-2

        # Your favorite function approximation goes here.
        self.cmacs = [CMAC(nlevels, quantize, alpha) for _ in range(nactions)]

    def _values(self, statevec):
        """Return an array holding every action's estimated value at ``statevec``."""
        return array([cmac.eval(statevec) for cmac in self.cmacs])

    def argmax(self, statevec):
        """Return the index of the highest-valued action, breaking ties at random."""
        values = self._values(statevec)
        best = flatnonzero(values == values.max())
        if len(best) > 1:
            # Several actions share the maximum: pick one uniformly.
            return prandom.choice(best)
        return best[0]

    def softmax(self, statevec):
        """Sample an action from a Boltzmann distribution over action values.

        ``self.epsilon`` is the temperature.  A non-positive temperature is
        treated as its zero-temperature limit, i.e. ``argmax``.
        """
        if self.epsilon <= 0:
            return self.argmax(statevec)
        values = self._values(statevec)
        # Shift by the max before exponentiating so large values cannot
        # overflow; the shift cancels out in the normalization.
        prefs = exp((values - values.max()) / self.epsilon)
        probs = prefs / prefs.sum()
        bins = add.accumulate(probs)
        index = searchsorted(bins, prandom.random())
        # Rounding can leave bins[-1] slightly below 1.0, which would make
        # searchsorted return nactions; clamp to the last valid action.
        if index >= self.nactions:
            index = self.nactions - 1
        return index

    def greedy(self, statevec):
        """Epsilon-greedy selection: a uniformly random action with probability
        ``self.epsilon``, otherwise the best action per ``argmax``."""
        if flip(self.epsilon):
            return prandom.randrange(self.nactions)
        return self.argmax(statevec)

    def reset(self):
        """Clear the eligibility trace (call between episodes)."""
        self.trace = []

    def save(self, filep):
        """Pickle the whole agent to the already-open binary file ``filep``."""
        import pickle
        pickle.dump(self, filep, pickle.HIGHEST_PROTOCOL)

    def train(self, statevec, reward, terminal=False):
        """Perform one SARSA(lambda) step and return the next action.

        ``reward`` is the reward for arriving in ``statevec`` from the most
        recently recorded state/action (it is ignored when the trace is
        empty, i.e. on the first step of an episode).  The TD error is
        ``reward + gamma * Q(statevec, next_action) - Q(previous)``, and every
        trace entry is updated in proportion to its eligibility weight.

        If ``terminal`` is True, ``statevec`` is treated as absorbing: the
        successor value is taken to be zero, the trace is cleared afterwards,
        and ``None`` is returned instead of an action.
        """
        if terminal:
            caction = None
            qval = 0.0  # no successor state to bootstrap from
        else:
            caction = self.greedy(statevec)  # or softmax
            qval = self.cmacs[caction].eval(statevec)

        if self.trace:
            # The previous state and action sit at the front of the trace.
            pstate, paction, _ = self.trace[0]
            delta = (reward + (self.gamma * qval)
                     - self.cmacs[paction].eval(pstate, quantized=True))

            # Apply the TD error to every traced (state, action), scaled by
            # its eligibility, then decay that eligibility.  (NOTE(review):
            # presumably CMAC.difference nudges the stored value by the given
            # amount -- confirm in tiles.py.)
            for entry in self.trace:
                coords, action, weight = entry
                self.cmacs[action].difference(coords, delta * weight, quantized=True)
                entry[2] *= (self.gamma * self.ld)

            # Drop the oldest entry once its eligibility is negligible.
            if self.trace[-1][2] < self.threshold:
                self.trace.pop()

        if terminal:
            self.reset()
            return None

        # Record the new state and action at the front of the trace.
        self.trace.insert(0, [self.cmacs[caction].quantize(statevec), caction, 1.0])
        return caction


class MountainCar(object):
    """Minimal mountain-car simulator.

    The car starts at position -0.5 with zero velocity.  Velocity is clipped
    to +/- ``maxv``, the car stops dead at the left wall ``maxl``, and any
    position beyond ``goalr`` is the goal.
    """

    def __init__(self):
        self.maxv = 0.07   # speed limit
        self.maxl = -1.5   # position of the left wall
        self.goalr = 0.45  # goal threshold
        self.reset()

    def action(self, force):
        """Advance the car one time step under ``force`` ({-1, 0, +1} in ``main``)."""
        # velocity + force + gravity
        self.velocity = (self.velocity + (0.001 * force)
                         + (-0.0025 * cos(3.0 * self.position)))
        self.velocity = self.velocity * 0.999  # friction

        if self.velocity > self.maxv:
            self.velocity = self.maxv
        elif self.velocity < -self.maxv:
            self.velocity = -self.maxv

        self.position = self.position + self.velocity
        if self.position < self.maxl:
            # hit the left wall
            self.position = self.maxl
            self.velocity = 0.0

    def reward(self):
        """Return 100 once past the goal, otherwise -1 per step."""
        if self.position > self.goalr:
            return 100
        return -1

    def reset(self):
        """Put the car back at its starting position, at rest."""
        self.position = -0.5
        self.velocity = 0.0


def main():
    """Run 100 mountain-car episodes, printing "<episode> <steps>" for each."""
    mc = MountainCar()
    sarsa = Sarsa(3, 9, 0.01, 0.1, 1.0, 0.9, 0.0)
    for i in range(100):
        counter = 0
        while mc.reward() < 0:
            action = sarsa.train(array([mc.position, mc.velocity]), mc.reward())
            mc.action(action - 1)  # action index {0, 1, 2} -> force {-1, 0, +1}
            counter += 1
        # The loop exits as soon as the goal is reached, so hand the goal
        # reward to the learner explicitly; otherwise it is never observed.
        sarsa.train(array([mc.position, mc.velocity]), mc.reward(), terminal=True)
        print("%d %d" % (i, counter))
        mc.reset()
        sarsa.reset()


if __name__ == "__main__":
    main()
Entries (RSS)
[...] finally got around to putting up my code for a sarsa agent. A couple of things to [...]