Reinforcement Learning Using TensorFlow.js

06 Aug 2023  Amiya Pattanaik  6 min read

Introduction

Reinforcement Learning (RL) is the science of decision making. It is about learning the optimal behavior in an environment to obtain maximum reward. This optimal behavior is learned through interactions with the environment and observations of how it responds, similar to children exploring the world around them and learning the actions that help them achieve a goal.

In RL, the data comes from the learning system itself: it is gathered through trial and error as the agent interacts with its environment, rather than being supplied up front as it would be in supervised or unsupervised machine learning.

RL uses algorithms that learn from outcomes and decide which action to take next. After each action, the algorithm receives feedback that helps it determine whether the choice it made was correct, neutral or incorrect. It is a good technique to use for automated systems that have to make a lot of small decisions without human guidance.

In this blog we’ll be using TensorFlow.js to implement a basic reinforcement learning algorithm and apply it to a simple Gridworld environment.

Prerequisites

  • Basic understanding of JavaScript and machine learning concepts.
  • Familiarity with TensorFlow.js library.

1. Setting Up the Environment:

First, we need to include TensorFlow.js in our HTML file, index.html. Please note that index.html also references reinforcement-learning.js, which contains the actual reinforcement learning code.

index.html

<!DOCTYPE html>
<html>
<head>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

  <style>
    /* Add any necessary styling here */
  </style>
</head>
<body>
  <!-- Explicit size so the 5x5 grid renders as square cells -->
  <canvas id="gridworld" width="400" height="400"></canvas>
  <script src="reinforcement-learning.js"></script>
</body>
</html>

2. Creating the Gridworld:

  • Define the grid, agent, and goal (each cell is flattened into an integer state index; see the sketch after this list).
  • Set up the rendering function to visualize the environment.
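
Throughout the code in this post, each grid cell is flattened into a single integer state index using row * gridWidth + col. Here is a minimal sketch of that mapping (the helper names gridWidth, stateFor, rowOf and colOf are only illustrative and don't appear in the final code):

const gridWidth = 5;                                   // number of columns in the grid
const stateFor = (row, col) => row * gridWidth + col;  // cell -> state index (0..24)
const rowOf = (state) => Math.floor(state / gridWidth);
const colOf = (state) => state % gridWidth;

console.log(stateFor(4, 0)); // 20 -> the agent's starting cell
console.log(stateFor(0, 4)); // 4  -> the goal cell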

3. Q-Learning Implementation:

  • Q-learning algorithm and how it’s applied:
    • Initialize Q-table and hyperparameters.
    • Implement the Q-learning update equation (shown below).
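
For reference, the update implemented by the qLearningUpdate function further down is the standard Q-learning rule, where alpha is the learning rate and gamma is the discount factor. A minimal sketch of the formula as a standalone function (qUpdate is only an illustrative name):

// Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
function qUpdate(currentQ, reward, maxNextQ, alpha, gamma) {
  return currentQ + alpha * (reward + gamma * maxNextQ - currentQ);
}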

4. Interaction with the Environment:

  • Agent interacts with the environment and learns:
    • Implement the epsilon-greedy exploration strategy (see the sketch after this list).
    • Define the takeAction function to execute actions.
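
The idea behind epsilon-greedy is simple: with probability epsilon the agent picks a random action (exploration), otherwise it picks the action with the highest Q-value for the current state (exploitation). A minimal sketch, assuming qRow holds the Q-values of the current state (chooseAction is only an illustrative name; the full epsilonGreedy function appears in the complete code below):

function chooseAction(qRow, epsilon, numActions) {
  if (Math.random() < epsilon) {
    return Math.floor(Math.random() * numActions); // explore: random action
  }
  return qRow.indexOf(Math.max(...qRow));          // exploit: best-known action
}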

5. Training the Agent:

  • Train the agent over multiple episodes:
    • Iterate through episodes and interactions.
    • Update the exploration rate to balance exploration and exploitation (see the decay example below).
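
With the values used below (initial exploration rate of 1.0, decay of 0.995 per episode, 500 episodes), the exploration rate ends up around 0.995^500 ≈ 0.08, so by the end of training the agent mostly exploits its learned Q-values. A quick sketch of that schedule:

let epsilon = 1.0;
const decay = 0.995;
for (let episode = 0; episode < 500; episode++) {
  epsilon *= decay;
}
console.log(epsilon.toFixed(3)); // ≈ 0.082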

6. Visualizing the Learning:

  • Visualize the agent’s learning progress:
    • Update the rendering function to show the agent’s movements and Q-values.

These are the steps we need for reinforcement learning. Here is the complete code for the steps mentioned above.

Code Snippet

reinforcement-learning.js

const canvas = document.getElementById('gridworld');
const ctx = canvas.getContext('2d');

const grid = [
  ['EMPTY', 'EMPTY', 'EMPTY', 'EMPTY', 'GOAL'],
  ['EMPTY', 'WALL', 'EMPTY', 'WALL', 'EMPTY'],
  ['EMPTY', 'EMPTY', 'EMPTY', 'EMPTY', 'EMPTY'],
  ['WALL', 'EMPTY', 'WALL', 'EMPTY', 'WALL'],
  ['AGENT', 'EMPTY', 'EMPTY', 'EMPTY', 'EMPTY'],
];

const agent = { x: 0, y: 4 };
const goal = { x: 4, y: 0 };

const numStates = grid.length * grid[0].length;
const numActions = 4; // Up, Down, Left, Right
let QTable = tf.zeros([numStates, numActions]);

const learningRate = 0.1;
const discountFactor = 0.9;
let explorationRate = 1.0;
const explorationDecay = 0.995;

function isTerminalState(state) {
  const row = Math.floor(state / grid[0].length);
  const col = state % grid[0].length;
  return (row === goal.y && col === goal.x);
}

// Function to transition to the next state and calculate reward
function takeAction(state, action) {
  const row = Math.floor(state / grid[0].length);
  const col = state % grid[0].length;

  let newRow = row;
  let newCol = col;

  if (action === 0) {
    newRow -= 1; // Up
  } else if (action === 1) {
    newRow += 1; // Down
  } else if (action === 2) {
    newCol -= 1; // Left
  } else if (action === 3) {
    newCol += 1; // Right
  }

  // Check if the new position is within the grid boundaries
  if (newRow >= 0 && newRow < grid.length && newCol >= 0 && newCol < grid[0].length) {
    const nextState = newRow * grid[0].length + newCol;
    let reward = -1; // Default reward for non-goal states

    if (isTerminalState(nextState)) {
      reward = 10; // Reward for reaching the goal
    } else if (grid[newRow][newCol] === 'WALL') {
      reward = -5; // Penalty for hitting a wall
    }

    return { nextState, reward };
  } else {
    // Invalid action, stay in the current state with a penalty
    return { nextState: state, reward: -10 };
  }
}

function qLearningUpdate(state, action, reward, nextState) {
  const qValues = QTable.arraySync();
  const currentQ = qValues[state][action];
  const maxNextQ = Math.max(...qValues[nextState]);
  const updatedQ = currentQ + learningRate * (reward + discountFactor * maxNextQ - currentQ);

  qValues[state][action] = updatedQ;

  // Replace the Q-table with a new tensor holding the updated values
  QTable.dispose(); // Dispose of the old tensor to avoid leaking memory
  QTable = tf.tensor2d(qValues);
}

function epsilonGreedy(state) {
  if (Math.random() < explorationRate) {
    // Explore: pick a random action
    return Math.floor(Math.random() * numActions);
  }
  // Exploit: pick the action with the highest Q-value for the current state
  const stateQValues = QTable.arraySync()[state];
  return stateQValues.indexOf(Math.max(...stateQValues));
}

const numEpisodes = 500;
for (let episode = 0; episode < numEpisodes; episode++) {
  let state = agent.y * grid[0].length + agent.x;

  while (!isTerminalState(state)) {
    const action = epsilonGreedy(state);
    const { nextState, reward } = takeAction(state, action);
    qLearningUpdate(state, action, reward, nextState);
    state = nextState;
  }

  explorationRate *= explorationDecay;
  console.log(`Episode: ${episode}`);
}

// Rendering logic
function render() {
  ctx.clearRect(0, 0, canvas.width, canvas.height);

  const cellWidth = canvas.width / grid[0].length;
  const cellHeight = canvas.height / grid.length;

  // Render the grid, agent, and goal
  for (let row = 0; row < grid.length; row++) {
    for (let col = 0; col < grid[row].length; col++) {
      let color = 'white';
      if (grid[row][col] === 'WALL') {
        color = 'black';
      } else if (row === agent.y && col === agent.x) {
        color = 'blue';
      } else if (row === goal.y && col === goal.x) {
        color = 'green';
      }
      ctx.fillStyle = color;
      ctx.fillRect(col * cellWidth, row * cellHeight, cellWidth, cellHeight);
    }
  }

  // Render Q-values as text
  ctx.fillStyle = 'black'; // dark text so the values are visible on the white cells
  ctx.font = '12px Arial';
  const allQValues = QTable.arraySync();
  for (let state = 0; state < numStates; state++) {
    const qValues = allQValues[state].map((q) => q.toFixed(1));
    const row = Math.floor(state / grid[0].length);
    const col = state % grid[0].length;
    ctx.fillText(`Q(${row},${col}): ${qValues}`, col * cellWidth, row * cellHeight + cellHeight - 5);
  }

  requestAnimationFrame(render);
}

render();

Conclusion

This lets you follow the progress of training in the browser and in the browser's console: if the episode numbers keep increasing, the learning process is progressing. You can also try reducing the number of episodes to see whether training completes faster and the results appear in the browser sooner.

Please note that the provided code is for educational purposes and might not be suitable for all environments or problem domains. You should customize and adjust the logic according to your specific requirements and problem space.

Thank you for taking the time to read this blog post to the end. We look forward to your comments and contributions.

Amiya Pattanaik

Amiya is a Product Engineering Director focused on Product Development, Quality Engineering & User Experience. He writes about his experiences here.