DEV Community


Reinforcement Learning with TF2 and Gym (part I)

jemaloQiu
Updated on ・3 min read

Today I start experimenting with Reinforcement Learning (RL) using TensorFlow 2.0. Like many people in this field, I use the well-known CartPole game from OpenAI Gym (see the figure below) to implement and test my RL algorithm. The official introduction to CartPole can be found on this page.

[Figure: the CartPole game]

This post presents a rough skeleton of my implementation. It covers the main training loop and the basic concepts of TF2, RL and the CartPole game. The details of the algorithm will be presented in the next post.

Package preparation

Make sure that TensorFlow (>= 2.0), NumPy and Gym are installed.

Now import the required modules.

import os
# Hide TensorFlow's INFO and WARNING log messages; this must be set before tensorflow is imported
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"

import tensorflow as tf

# Stop immediately if TensorFlow 1.x is installed, since the code below needs TF2
if tf.__version__.startswith("1."):
    raise RuntimeError("Error!! You are using tensorflow-v1")

import numpy as np
import gym
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
import tensorflow.keras.backend as K


Policy Gradient network

I will use the Policy Gradient method to solve this RL problem, so I first define a class for it. For the moment, this class is almost empty; I will fill in the required functionality later.

This class wraps a Keras Sequential neural network. The network takes the state observation vector as input and outputs the action to take. In practice, it outputs a probability for each action, and an action index is chosen from these probabilities. The class should also provide the main functions for:

- calculating gradients
- applying gradient updates to the model
- calculating rewards
- saving and reloading the model

# A class for the Policy Gradient neural network (skeleton, to be completed)
class PolicyGradientNet:

    def __init__(self):
        self.model = None       # Keras Sequential policy network
        self.gradients = []     # gradients to be applied to the model
        self.optimizer = None

    def create_model(self):
        self.model = Sequential()
        ## TODO: complete the neural network model

        # Adam's first argument is the learning rate, not the model variables;
        # the variables are supplied later, when the gradients are applied
        self.optimizer = keras.optimizers.Adam(learning_rate=0.01)

    def calc_grad(self):
        ### TODO: calculate the gradients
        pass

    def update_gradient(self):
        ### TODO: apply the new gradients to the model
        pass

    def get_gradients(self):
        return self.gradients

    def get_variable(self):
        return self.model.variables

    def calc_rewards(self):
        pass

    def save_model(self):
        pass

    def load_model(self):
        pass
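The skeleton leaves `create_model`, `calc_grad` and `update_gradient` empty, and the actual algorithm is deferred to the next post. As an illustration only, here is one common way to fill them for CartPole using the REINFORCE form of policy gradient. This is a minimal sketch, and several choices are assumptions rather than something taken from this post:

- one hidden layer of 16 ReLU units;
- a softmax output with one probability per action;
- the loss `-sum_t log pi(a_t | s_t) * G_t`, where `G_t` is the return that followed step t;
- gradients averaged over several episodes before a single optimizer step.

All function names below are hypothetical.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

STATE_DIM = 4   # CartPole observation: x, x_dot, theta, theta_dot
N_ACTIONS = 2   # 0 = push cart left, 1 = push cart right


def create_policy_model(hidden_units=16):
    """Policy network: state vector -> probability of each action (sizes are assumptions)."""
    return keras.Sequential([
        keras.Input(shape=(STATE_DIM,)),
        keras.layers.Dense(hidden_units, activation="relu"),
        keras.layers.Dense(N_ACTIONS, activation="softmax"),
    ])


def choose_action(model, state):
    """Sample an action index from the probabilities predicted for one state."""
    batch = np.asarray(state, dtype=np.float32)[np.newaxis, :]  # shape (1, STATE_DIM)
    probs = model(batch).numpy()[0].astype(np.float64)
    probs /= probs.sum()  # remove float32 rounding so the probabilities sum to 1
    return int(np.random.choice(N_ACTIONS, p=probs))


def calc_grad(model, states, actions, returns):
    """REINFORCE gradient of -sum_t log pi(a_t | s_t) * G_t for one episode."""
    states = tf.convert_to_tensor(states, dtype=tf.float32)
    returns = tf.convert_to_tensor(returns, dtype=tf.float32)
    with tf.GradientTape() as tape:
        probs = model(states)                           # shape (T, N_ACTIONS)
        mask = tf.one_hot(actions, N_ACTIONS)           # 1.0 at the action taken
        chosen = tf.reduce_sum(probs * mask, axis=1)    # pi(a_t | s_t)
        loss = -tf.reduce_sum(tf.math.log(chosen + 1e-8) * returns)
    return tape.gradient(loss, model.trainable_variables)


def apply_gradients(optimizer, model, grad_buffer):
    """Average the gradients of several episodes and take one optimizer step."""
    n = len(grad_buffer)
    # zip(*grad_buffer) groups together the gradients of the same variable
    mean_grads = [tf.add_n(list(per_var)) / n for per_var in zip(*grad_buffer)]
    optimizer.apply_gradients(zip(mean_grads, model.trainable_variables))
```

In the main loop, one would record the states, actions and rewards of each episode. When `done` is true, compute the return of every step and append `calc_grad(...)` to a buffer. Every `update_step` episodes, call `apply_gradients` and clear the buffer. Averaging over a few episodes smooths the very noisy single-episode gradient estimate.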

Cartpole preparation

This part of the code creates the Gym environment and prints its action and observation spaces for inspection.


ENV_SEED = 1024  ## Reproducibility of the game
NP_SEED = 1024  ## Reproducibility of numpy random

env = gym.make('CartPole-v0')
env.seed(ENV_SEED)  
np.random.seed(NP_SEED)


### The Discrete space allows a fixed range of non-negative numbers, so in this case valid actions are either 0 or 1. 
### The Box space represents an n-dimensional box, so valid observations will be an array of 4 numbers. 
print(env.action_space)
print(env.observation_space)
### We can also check the Box’s bounds:
print(env.observation_space.high)
print(env.observation_space.low)

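The setup above follows the classic Gym API. Gym 0.26 and later (and its successor, Gymnasium) changed this API in three ways:

- `env.reset()` returns a pair `(observation, info)`.
- Seeding is done through `env.reset(seed=...)`.
- `env.step()` returns five values, splitting `done` into `terminated` and `truncated`.

If you run this code with a newer release, a small adapter like the sketch below keeps the rest of the main loop unchanged. The helper names are hypothetical and are not part of Gym.

```python
def unpack_reset(result):
    """Return only the observation, whichever reset() API produced `result`."""
    # Gym >= 0.26 / Gymnasium: reset() -> (observation, info_dict)
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    # Classic Gym: reset() -> observation
    return result


def unpack_step(result):
    """Return (observation, reward, done, info) from either step() API."""
    if len(result) == 4:
        # Classic Gym: (observation, reward, done, info)
        observation, reward, done, info = result
    elif len(result) == 5:
        # Gym >= 0.26 / Gymnasium: (observation, reward, terminated, truncated, info)
        observation, reward, terminated, truncated, info = result
        done = bool(terminated or truncated)
    else:
        raise ValueError(f"unexpected step() result with {len(result)} values")
    return observation, reward, done, info
```

With a newer release, the loop would call `unpack_reset(env.reset(seed=ENV_SEED))` and `unpack_step(env.step(a))`. Rendering is then requested when the environment is created, e.g. `gym.make('CartPole-v0', render_mode="human")`.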

Main loop

This is the main training loop. Since the PolicyGradientNet class is not finished yet, I leave TODO markers at the places where the RL algorithm code still has to be implemented.

Each episode consists of many interaction steps between the environment and the agent:

1. Given the current state, the agent computes an action (0 or 1, i.e. push the cart left or right) that is applied to the cart.
2. The environment updates accordingly and returns a new observed state.
3. From that state, the agent computes the next action, and so on until the episode ends.

Below is a screenshot of the official presentation of the agent-env interaction:
[Figure: the agent-environment interaction loop, from the official Gym documentation]

The current code for this part:

# first, create an instance of PolicyGradientNet
agent = PolicyGradientNet()

update_step = 5     # number of episodes between two updates of the network
limit_train = 1000  # total number of training episodes
i = 0               # episode counter
max_step = 0        # longest episode (in steps) seen so far

while i < limit_train:
    step = 0
    state_current = env.reset()  ## reset the game to its initial state

    while True:

        env.render()  ## draw the current frame
        step += 1

        ## calculate an action
        ## TODO: use agent.model to choose the action; for now it is picked uniformly at random
        a = np.random.choice([0, 1], p=[0.5, 0.5])

        ## env.step() returns: observation (object), reward (float), done (boolean), info (dict)
        ## check more info at https://gym.openai.com/docs/
        state_obs, reward, done, info = env.step(a)
        state_current = state_obs  ## update the current state vector
        ## state_obs consists of: cart position, cart velocity,
        ## pole angle (from vertical) and pole angular velocity
        x, x_prime, theta, theta_prime = state_obs

        if done:  # done == True means the episode has terminated
            ## TODO: launch agent training
            if max_step < step:
                max_step = step
            print("Step: ", step)

            ## TODO: update gradients
            break
    if i % 100 == 0:
        print("Max step is {} until episode {}  ".format(max_step, i))

    i += 1
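At the "TODO: launch agent training" point, a policy gradient agent typically needs, for every step of the finished episode, the return that followed that step. CartPole gives a reward of +1 for every step, so this is also a natural job for the empty `calc_rewards` method. The sketch below is one common way to compute these returns. It assumes a discount factor `gamma = 0.99` and per-episode normalization; both are typical choices, not taken from this post, and the function name is hypothetical.

```python
import numpy as np


def discounted_returns(rewards, gamma=0.99, normalize=True):
    """Compute G_t = r_t + gamma * G_{t+1} for every step t of one episode."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    # Walk backwards so each step accumulates all (discounted) rewards that follow it
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    if normalize and len(returns) > 1:
        # Zero mean / unit variance: reduces the variance of the gradient estimate
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns
```

For example, `discounted_returns([1.0, 1.0, 1.0], gamma=0.5, normalize=False)` gives `[1.75, 1.5, 1.0]`: early actions, which kept the pole up for longer, receive larger returns. After normalization, actions from the early part of an episode are reinforced and those just before the pole falls are discouraged.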

Now the simulation looks something like this:
[Animation: the CartPole simulation driven by random actions]

The full program will be presented in the next post.

Discussion (4)

Otu Michael
if tf.__version__.startswith("1."):    
    print("Error!! You are using tensorflow-v1")

At this point the code will still run. Does the version of tensorflow affect the functionality of the app?

jemaloQiu Author

Yes, I only check whether the version requirement is satisfied. If you are using TensorFlow 1.x, this program will report errors and halt; in that case, this message gives a hint for debugging.

Otu Michael

but this doesn't halt.. throw an error??

jemaloQiu Author

Agree, I have updated my code.
