Documentation Index
Fetch the complete documentation index at: https://wb-21fd5541-john-wbdocs-2044-rename-serverless-products.mintlify.app/llms.txt
Use this file to discover all available pages before exploring further.
Stable Baselines 3 (SB3) est un ensemble d’implémentations fiables d’algorithmes d’apprentissage par renforcement en PyTorch. L’intégration SB3 de W&B :
- Enregistre des métriques telles que les pertes et les retours d’épisode.
- Téléverse des vidéos d’agents en train de jouer.
- Enregistre le modèle entraîné.
- Journalise les hyperparamètres du modèle.
- Journalise les histogrammes de gradients du modèle.
Consultez un exemple de run d’entraînement SB3.
Journalisez vos expériences SB3
from wandb.integration.sb3 import WandbCallback
model.learn(..., callback=WandbCallback())
Arguments de WandbCallback
| Argument | Utilisation |
|---|
verbose | Niveau de verbosité de la sortie de sb3 |
model_save_path | Chemin vers le dossier où le modèle sera enregistré. La valeur par défaut est None, donc le modèle n’est pas enregistré |
model_save_freq | Fréquence d’enregistrement du modèle |
gradient_save_freq | Fréquence d’enregistrement des gradients. La valeur par défaut est 0, donc les gradients ne sont pas enregistrés |
L’intégration W&B SB3 utilise les journaux générés par TensorBoard pour journaliser vos métriques
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback
config = {
"policy_type": "MlpPolicy",
"total_timesteps": 25000,
"env_name": "CartPole-v1",
}
run = wandb.init(
project="sb3",
config=config,
sync_tensorboard=True, # téléverse automatiquement les métriques TensorBoard de sb3
monitor_gym=True, # téléverse automatiquement les vidéos des agents pendant la partie
save_code=True, # facultatif
)
def make_env():
env = gym.make(config["env_name"])
env = Monitor(env) # enregistre des statistiques comme les retours cumulés
return env
env = DummyVecEnv([make_env])
env = VecVideoRecorder(
env,
f"videos/{run.id}",
record_video_trigger=lambda x: x % 2000 == 0,
video_length=200,
)
model = PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
total_timesteps=config["total_timesteps"],
callback=WandbCallback(
gradient_save_freq=100,
model_save_path=f"models/{run.id}",
verbose=2,
),
)
run.finish()