lightning/pl_examples/domain_templates/reinforce_learn_Qnet.py

377 lines
12 KiB
Python

Example: Simple RL example using DQN/Lightning (#1232) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * CI: split tests-examples (#990) * CI: split tests-examples * tests without template * comment depends * CircleCI typo * add doctest * update test req. * CI tests * setup macOS * longer train * lover pred acc * fix model * rename default model * lower tests acc * typo * imports * fix test optimizer * update calls * fix Win * lower Drone image * fix call * pytorch image * fix test * add dev image * add dev image * update image * drone volume * lint * update test notes * rename tests/models >> tests/base * group models * conftest * optim imports * typos * fix import * fix tests * install AMP * tests * fix import * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * updated example image * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename * Disable validation when val_percent_check=0 (#1251) * fix disable validation * add test * update changelog * update docs for val_percent_check * make "fast training" docs consistent * calling self.forward() -> self() (#1211) * self.forward() -> self() * update changelog Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Fix requirements-extra.txt Trains package to release version (#1229) * Fix requirement-extra use released Trains 
package * Update README.md add Trains and links to the external Visualization section Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Remove unnecessary parameters to super() in documentation and source code (#1240) Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * update deprecation warning (#1258) * update docs for progress bat values (#1253) * lower timeouts for inactive issues (#1250) * update contrib list (#1241) Co-authored-by: William Falcon <waf2107@columbia.edu> * Fix outdated docs (#1227) * Fix typo (#1224) * drop unused Tox (#1242) * system info (#1234) * system info * update big info * test script * update config * rename script * import path * Changed smoothing in tqdm to decrease variability of time remaining between training / eval (#1194) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename Co-authored-by: Donal Byrne <Donal.Byrne@xperi.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Martin.B 
<51887611+bmartinn@users.noreply.github.com> Co-authored-by: Tyler Yep <tyep@stanford.edu> Co-authored-by: Shunta Komatsu <59395084+skmatz@users.noreply.github.com> Co-authored-by: Jack Pertschuk <jackpertschuk@gmail.com>
2020-03-28 20:10:53 +00:00
"""
Deep Reinforcement Learning: Deep Q-network (DQN)
This example is based on https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-
Second-Edition/blob/master/Chapter06/02_dqn_pong.py
The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the
classic CartPole environment.
To run the template, execute:
python reinforce_learn_Qnet.py
After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up TensorBoard to
see the metrics:
tensorboard --logdir default
"""
import argparse
from collections import OrderedDict, deque, namedtuple
from typing import Tuple, List
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
import pytorch_lightning as pl
2020-03-28 20:10:53 +00:00
class DQN(nn.Module):
    """
    Simple MLP network

    Args:
        obs_size: observation/state size of the environment
        n_actions: number of discrete actions available in the environment
        hidden_size: size of hidden layers
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

    def forward(self, x):
        return self.net(x.float())


# Named tuple for storing experience steps gathered in training
Experience = namedtuple(
    'Experience', field_names=['state', 'action', 'reward',
                               'done', 'new_state'])


class ReplayBuffer:
    """
    Replay Buffer for storing past experiences allowing the agent to learn from them

    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """
        Add experience to the buffer

        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
        # Draw batch_size distinct experiences uniformly at random
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])

        # np.bool was removed in NumPy 1.24; the builtin bool is equivalent here
        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=bool), np.array(next_states))


class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ReplayBuffer
    which will be updated with new experiences during training

    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Tuple:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]


class Agent:
    """
    Base Agent class handling the interaction with the environment

    Args:
        env: training environment
        replay_buffer: replay buffer storing experiences
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        self.env = env
        self.replay_buffer = replay_buffer
        # reset() already resets the environment and sets self.state
        self.reset()

    def reset(self) -> None:
        """Resets the environment and updates the state"""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """
        Using the given network, decide what action to carry out
        using an epsilon-greedy policy

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            action
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = torch.tensor([self.state])

            if device not in ['cpu']:
                state = state.cuda(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') -> Tuple[float, bool]:
        """
        Carries out a single interaction step between the agent and the environment

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            reward, done
        """
        action = self.get_action(net, epsilon, device)

        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)

        exp = Experience(self.state, action, reward, done, new_state)

        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done


class DQNLightning(pl.LightningModule):
    """ Basic DQN Model """

    def __init__(self,
                 replay_size,
                 warm_start_steps: int,
                 gamma: float,
                 eps_start: float,
                 eps_end: float,
                 eps_last_frame: int,
                 sync_rate,
                 lr: float,
                 episode_length,
                 batch_size, **kwargs) -> None:
        super().__init__()
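        self.replay_size = replay_size
        self.warm_start_steps = warm_start_steps
        self.gamma = gamma
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_last_frame = eps_last_frame
        self.sync_rate = sync_rate
        self.lr = lr
        self.episode_length = episode_length
        self.batch_size = batch_size

        # `env` is not a named parameter, so self.env does not exist yet at this point.
        # This assumes the environment id is passed as the `env` keyword argument.
        self.env = gym.make(kwargs['env'])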
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)
self.buffer = ReplayBuffer(self.replay_size)
Example: Simple RL example using DQN/Lightning (#1232) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * CI: split tests-examples (#990) * CI: split tests-examples * tests without template * comment depends * CircleCI typo * add doctest * update test req. * CI tests * setup macOS * longer train * lover pred acc * fix model * rename default model * lower tests acc * typo * imports * fix test optimizer * update calls * fix Win * lower Drone image * fix call * pytorch image * fix test * add dev image * add dev image * update image * drone volume * lint * update test notes * rename tests/models >> tests/base * group models * conftest * optim imports * typos * fix import * fix tests * install AMP * tests * fix import * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * updated example image * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename * Disable validation when val_percent_check=0 (#1251) * fix disable validation * add test * update changelog * update docs for val_percent_check * make "fast training" docs consistent * calling self.forward() -> self() (#1211) * self.forward() -> self() * update changelog Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Fix requirements-extra.txt Trains package to release version (#1229) * Fix requirement-extra use released Trains 
package * Update README.md add Trains and links to the external Visualization section Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Remove unnecessary parameters to super() in documentation and source code (#1240) Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * update deprecation warning (#1258) * update docs for progress bat values (#1253) * lower timeouts for inactive issues (#1250) * update contrib list (#1241) Co-authored-by: William Falcon <waf2107@columbia.edu> * Fix outdated docs (#1227) * Fix typo (#1224) * drop unused Tox (#1242) * system info (#1234) * system info * update big info * test script * update config * rename script * import path * Changed smoothing in tqdm to decrease variability of time remaining between training / eval (#1194) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename Co-authored-by: Donal Byrne <Donal.Byrne@xperi.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Martin.B 
<51887611+bmartinn@users.noreply.github.com> Co-authored-by: Tyler Yep <tyep@stanford.edu> Co-authored-by: Shunta Komatsu <59395084+skmatz@users.noreply.github.com> Co-authored-by: Jack Pertschuk <jackpertschuk@gmail.com>
2020-03-28 20:10:53 +00:00
self.agent = Agent(self.env, self.buffer)
self.total_reward = 0
self.episode_reward = 0
replace Hparams by init args (#1896) * remove the need for hparams * remove the need for hparams * remove the need for hparams * remove the need for hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * finished moco * basic * testing * todo * recurse * hparams * persist * hparams * chlog * tests * tests * tests * tests * tests * tests * review * saving * tests * tests * tests * docs * finished moco * hparams * review * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * hparams * overwrite * transform * transform * transform * transform * cleaning * cleaning * tests * examples * examples * examples * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * chp key * tests * Apply suggestions from code review * class * updated docs * updated docs * updated docs * updated docs * save * wip * fix * flake8 Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
2020-05-24 22:59:08 +00:00
self.populate(self.warm_start_steps)
Example: Simple RL example using DQN/Lightning (#1232) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * CI: split tests-examples (#990) * CI: split tests-examples * tests without template * comment depends * CircleCI typo * add doctest * update test req. * CI tests * setup macOS * longer train * lover pred acc * fix model * rename default model * lower tests acc * typo * imports * fix test optimizer * update calls * fix Win * lower Drone image * fix call * pytorch image * fix test * add dev image * add dev image * update image * drone volume * lint * update test notes * rename tests/models >> tests/base * group models * conftest * optim imports * typos * fix import * fix tests * install AMP * tests * fix import * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * updated example image * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename * Disable validation when val_percent_check=0 (#1251) * fix disable validation * add test * update changelog * update docs for val_percent_check * make "fast training" docs consistent * calling self.forward() -> self() (#1211) * self.forward() -> self() * update changelog Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Fix requirements-extra.txt Trains package to release version (#1229) * Fix requirement-extra use released Trains 
package * Update README.md add Trains and links to the external Visualization section Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Remove unnecessary parameters to super() in documentation and source code (#1240) Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * update deprecation warning (#1258) * update docs for progress bat values (#1253) * lower timeouts for inactive issues (#1250) * update contrib list (#1241) Co-authored-by: William Falcon <waf2107@columbia.edu> * Fix outdated docs (#1227) * Fix typo (#1224) * drop unused Tox (#1242) * system info (#1234) * system info * update big info * test script * update config * rename script * import path * Changed smoothing in tqdm to decrease variability of time remaining between training / eval (#1194) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename Co-authored-by: Donal Byrne <Donal.Byrne@xperi.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Martin.B 
<51887611+bmartinn@users.noreply.github.com> Co-authored-by: Tyler Yep <tyep@stanford.edu> Co-authored-by: Shunta Komatsu <59395084+skmatz@users.noreply.github.com> Co-authored-by: Jack Pertschuk <jackpertschuk@gmail.com>
2020-03-28 20:10:53 +00:00
    def populate(self, steps: int = 1000) -> None:
        """
        Carries out several random steps through the environment to initially fill
        up the replay buffer with experiences

        Args:
            steps: number of random steps to populate the buffer with
        """
        for _ in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)
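As an aside on why `populate` passes `epsilon=1.0`: assuming `Agent.play_step` follows the usual epsilon-greedy rule (its body is not shown here, so this is an assumption), an epsilon of 1.0 makes every warm-up action random. The helper below is a hypothetical, standalone sketch of that rule using plain Python lists in place of tensors. It is not the example's own `Agent` code.

```python
import random
from typing import Sequence


def epsilon_greedy(q_values: Sequence[float], epsilon: float, rng: random.Random) -> int:
    """Hypothetical helper: random action with probability epsilon, else the greedy one."""
    # rng.random() lies in [0, 1), so epsilon=1.0 always explores
    # and epsilon=0.0 always exploits.
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    # Greedy choice: index of the largest Q-value.
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
greedy_action = epsilon_greedy([0.1, 0.9, 0.3], epsilon=0.0, rng=rng)
warmup_actions = [epsilon_greedy([0.0, 0.0, 10.0], epsilon=1.0, rng=rng) for _ in range(200)]
```

With `epsilon=1.0` the Q-values are ignored entirely, which is why an untrained `self.net` can safely be used to fill the buffer.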
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Passes a state `x` through the network and returns the `q_values` of each action

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output
    def dqn_mse_loss(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """
        Calculates the MSE loss using a mini batch from the replay buffer

        Args:
            batch: current mini batch of replay data

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        # Q-values predicted by the online network for the actions actually taken
        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            # best next-state value according to the target network; zero for terminal states
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)
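The target used in `dqn_mse_loss` can be checked by hand on a tiny batch. The sketch below mirrors the same arithmetic (reward plus discounted best next-state value, with terminal states contributing zero) using plain Python lists instead of tensors. The function names `td_targets` and `mse` are hypothetical, chosen only for this illustration.

```python
from typing import List


def td_targets(rewards: List[float], next_q: List[List[float]],
               dones: List[bool], gamma: float) -> List[float]:
    """Hypothetical helper: one-step TD target r + gamma * max_a Q_target(s', a)."""
    targets = []
    for reward, q_row, done in zip(rewards, next_q, dones):
        # Terminal transitions have no future value, matching next_state_values[dones] = 0.0
        best_next = 0.0 if done else max(q_row)
        targets.append(reward + gamma * best_next)
    return targets


def mse(predicted: List[float], target: List[float]) -> float:
    """Mean squared error between predicted Q-values and TD targets."""
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(predicted)


targets = td_targets(
    rewards=[1.0, 1.0],
    next_q=[[0.5, 2.0], [3.0, 1.0]],
    dones=[False, True],
    gamma=0.5,
)
loss = mse([2.0, 0.0], targets)
```

The second transition is terminal, so its target is just the reward even though the target network's next-state values are non-zero.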
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
"""
Carries out a single step through the environment to update the replay buffer.
2020-05-07 13:25:54 +00:00
Then calculates loss based on the minibatch received
Example: Simple RL example using DQN/Lightning (#1232) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * CI: split tests-examples (#990) * CI: split tests-examples * tests without template * comment depends * CircleCI typo * add doctest * update test req. * CI tests * setup macOS * longer train * lover pred acc * fix model * rename default model * lower tests acc * typo * imports * fix test optimizer * update calls * fix Win * lower Drone image * fix call * pytorch image * fix test * add dev image * add dev image * update image * drone volume * lint * update test notes * rename tests/models >> tests/base * group models * conftest * optim imports * typos * fix import * fix tests * install AMP * tests * fix import * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * updated example image * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename * Disable validation when val_percent_check=0 (#1251) * fix disable validation * add test * update changelog * update docs for val_percent_check * make "fast training" docs consistent * calling self.forward() -> self() (#1211) * self.forward() -> self() * update changelog Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Fix requirements-extra.txt Trains package to release version (#1229) * Fix requirement-extra use released Trains 
package * Update README.md add Trains and links to the external Visualization section Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Remove unnecessary parameters to super() in documentation and source code (#1240) Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * update deprecation warning (#1258) * update docs for progress bat values (#1253) * lower timeouts for inactive issues (#1250) * update contrib list (#1241) Co-authored-by: William Falcon <waf2107@columbia.edu> * Fix outdated docs (#1227) * Fix typo (#1224) * drop unused Tox (#1242) * system info (#1234) * system info * update big info * test script * update config * rename script * import path * Changed smoothing in tqdm to decrease variability of time remaining between training / eval (#1194) * Example: Simple RL example using DQN/Lightning * DQN RL Agent using Lightning * Uses Iterable Dataset for Replay Buffer * Buffer is populated by agent as training is carried out, updating the dataset * Applied autopep8 fixes * * Updated line length from 120 to 110 * Update pl_examples/domain_templates/dqn.py simplify get_device method Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pl_examples/domain_templates/dqn.py Re-ordered imports Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Clean up * added module docstring * renamed variables to be more descriptive * Added missing docstrings and type annotations * Added gym to example requirements * Added note to changelog * update types * rename script * Update CHANGELOG.md Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * another rename Co-authored-by: Donal Byrne <Donal.Byrne@xperi.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu> Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Martin.B 
<51887611+bmartinn@users.noreply.github.com> Co-authored-by: Tyler Yep <tyep@stanford.edu> Co-authored-by: Shunta Komatsu <59395084+skmatz@users.noreply.github.com> Co-authored-by: Jack Pertschuk <jackpertschuk@gmail.com>
2020-03-28 20:10:53 +00:00
Args:
batch: current mini batch of replay data
nb_batch: batch number
Returns:
Training loss and log metrics
"""
device = self.get_device(batch)
replace Hparams by init args (#1896) * remove the need for hparams * remove the need for hparams * remove the need for hparams * remove the need for hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * replace self.hparams * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * fixed * finished moco * basic * testing * todo * recurse * hparams * persist * hparams * chlog * tests * tests * tests * tests * tests * tests * review * saving * tests * tests * tests * docs * finished moco * hparams * review * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * hparams * overwrite * transform * transform * transform * transform * cleaning * cleaning * tests * examples * examples * examples * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * chp key * tests * Apply suggestions from code review * class * updated docs * updated docs * updated docs * updated docs * save * wip * fix * flake8 Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
2020-05-24 22:59:08 +00:00
epsilon = max(self.eps_end, self.eps_start -
self.global_step + 1 / self.eps_last_frame)
        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # Hard update of target network: copy the online weights every sync_rate steps
        if self.global_step % self.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        log = {'total_reward': torch.tensor(self.total_reward).to(device),
               'reward': torch.tensor(reward).to(device),
               'steps': torch.tensor(self.global_step).to(device)}

        return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer"""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.buffer, self.episode_length)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            sampler=None,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0].device.index if self.on_gpu else 'cpu'


def main(args) -> None:
    model = DQNLightning(**vars(args))

    trainer = pl.Trainer(
        gpus=1,
        distributed_backend='dp',
        early_stop_callback=False,
        val_check_interval=100
    )

    trainer.fit(model)


if __name__ == '__main__':
    # fix the random seeds for reproducibility
    torch.manual_seed(0)
    np.random.seed(0)

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--sync_rate", type=int, default=10,
                        help="how many frames between target network updates")
    parser.add_argument("--replay_size", type=int, default=1000,
                        help="capacity of the replay buffer")
    parser.add_argument("--warm_start_size", type=int, default=1000,
                        help="how many samples do we use to fill our buffer at the start of training")
    parser.add_argument("--eps_last_frame", type=int, default=1000,
                        help="frame at which epsilon stops decaying")
    parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
    parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
    parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
    parser.add_argument("--max_episode_reward", type=int, default=200,
                        help="max episode reward in the environment")
    parser.add_argument("--warm_start_steps", type=int, default=1000,
                        help="number of random steps used to pre-fill the replay buffer before training")

    args = parser.parse_args()

    main(args)