From 174d77a9fd99473146169e9039a6d056bc0bcb50 Mon Sep 17 00:00:00 2001 From: Rune Harlyk Date: Fri, 10 Oct 2025 20:02:18 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Updates=20training=20script=20with?= =?UTF-8?q?=20stablebaseline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- simulation/train.py | 267 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 260 insertions(+), 7 deletions(-) diff --git a/simulation/train.py b/simulation/train.py index 83f7cb6..88594aa 100644 --- a/simulation/train.py +++ b/simulation/train.py @@ -1,15 +1,268 @@ -from src.envs.environment import QuadrupedEnv -from training.trainer import Trainer +import argparse +import os +import gymnasium as gym +from stable_baselines3 import PPO, SAC +from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback +from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize +from stable_baselines3.common.monitor import Monitor +import numpy as np +import torch -import resources as resources +from src.envs.quadruped_env import QuadrupedEnv, TerrainType -render = True + +def make_env(terrain_type=TerrainType.FLAT, render_mode="rgb_array"): + def _init(): + env = QuadrupedEnv(terrain_type=terrain_type, render_mode=render_mode) + env = Monitor(env) + return env + + return _init + + +def train_ppo( + total_timesteps=1_000_000, + learning_rate=3e-4, + n_steps=2048, + batch_size=64, + n_epochs=10, + gamma=0.99, + gae_lambda=0.95, + clip_range=0.2, + ent_coef=0.01, + vf_coef=0.5, + max_grad_norm=0.5, + save_dir="models/ppo", + log_dir="logs/ppo", + eval_freq=10000, + save_freq=50000, + terrain_type=TerrainType.FLAT, +): + os.makedirs(save_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) + os.makedirs(f"{log_dir}/eval", exist_ok=True) + + print("Creating training environment...") + env = DummyVecEnv([make_env(terrain_type=terrain_type)]) + env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0) + + print("Creating evaluation environment...") + eval_env = DummyVecEnv([make_env(terrain_type=terrain_type)]) + eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False, clip_obs=10.0) + + checkpoint_callback = CheckpointCallback( + save_freq=save_freq, + save_path=save_dir, + name_prefix="ppo_checkpoint", + save_vecnormalize=True, + ) + + eval_callback = EvalCallback( + eval_env, + best_model_save_path=f"{save_dir}/best", + log_path=f"{log_dir}/eval", + eval_freq=eval_freq, + deterministic=True, + render=False, + ) + + print("Creating PPO model...") + model = PPO( + "MlpPolicy", + env, + learning_rate=learning_rate, + n_steps=n_steps, + batch_size=batch_size, + n_epochs=n_epochs, + gamma=gamma, + gae_lambda=gae_lambda, + clip_range=clip_range, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + verbose=1, + tensorboard_log=log_dir, + policy_kwargs=dict( + net_arch=[dict(pi=[256, 256], vf=[256, 256])], + activation_fn=torch.nn.ReLU, + ), + ) + + print(f"Starting PPO training for {total_timesteps} timesteps...") + model.learn( + total_timesteps=total_timesteps, + callback=[checkpoint_callback, eval_callback], + progress_bar=True, + ) + + print(f"Saving final model to {save_dir}/ppo_final") + model.save(f"{save_dir}/ppo_final") + env.save(f"{save_dir}/ppo_final_vecnormalize.pkl") + + print("Training complete!") + return model, env + + +def train_sac( + total_timesteps=1_000_000, + learning_rate=3e-4, + buffer_size=300000, + learning_starts=10000, + batch_size=256, + tau=0.005, + gamma=0.99, + train_freq=1, + gradient_steps=1, + ent_coef="auto", + save_dir="models/sac", + log_dir="logs/sac", + eval_freq=10000, + save_freq=50000, + terrain_type=TerrainType.FLAT, +): + os.makedirs(save_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) + os.makedirs(f"{log_dir}/eval", exist_ok=True) + + print("Creating training environment...") + env = DummyVecEnv([make_env(terrain_type=terrain_type)]) + env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0) + + print("Creating evaluation environment...") + eval_env = DummyVecEnv([make_env(terrain_type=terrain_type)]) + eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False, clip_obs=10.0) + + checkpoint_callback = CheckpointCallback( + save_freq=save_freq, + save_path=save_dir, + name_prefix="sac_checkpoint", + save_vecnormalize=True, + ) + + eval_callback = EvalCallback( + eval_env, + best_model_save_path=f"{save_dir}/best", + log_path=f"{log_dir}/eval", + eval_freq=eval_freq, + deterministic=True, + render=False, + ) + + print("Creating SAC model...") + model = SAC( + "MlpPolicy", + env, + learning_rate=learning_rate, + buffer_size=buffer_size, + learning_starts=learning_starts, + batch_size=batch_size, + tau=tau, + gamma=gamma, + train_freq=train_freq, + gradient_steps=gradient_steps, + ent_coef=ent_coef, + verbose=1, + tensorboard_log=log_dir, + policy_kwargs=dict( + net_arch=dict(pi=[256, 256], qf=[256, 256]), + activation_fn=torch.nn.ReLU, + ), + ) + + print(f"Starting SAC training for {total_timesteps} timesteps...") + model.learn( + total_timesteps=total_timesteps, + callback=[checkpoint_callback, eval_callback], + progress_bar=True, + ) + + print(f"Saving final model to {save_dir}/sac_final") + model.save(f"{save_dir}/sac_final") + env.save(f"{save_dir}/sac_final_vecnormalize.pkl") + + print("Training complete!") + return model, env def main(): - env = QuadrupedEnv(resources.getDataPath() + "/spot.urdf") - trainer = Trainer(env, render) - trainer.train() + parser = argparse.ArgumentParser(description="Train quadruped robot with RL") + parser.add_argument( + "--algo", + type=str, + default="ppo", + choices=["ppo", "sac", "both"], + help="Algorithm to use (ppo, sac, or both)", + ) + parser.add_argument( + "--timesteps", + type=int, + default=1_000_000, + help="Total timesteps for training", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=3e-4, + help="Learning rate", + ) + parser.add_argument( + "--terrain", + type=str, + default="flat", + choices=["flat", "planar_reflection", "terrain", "maze"], + help="Terrain type", + ) + parser.add_argument( + "--save-dir", + type=str, + default="models", + help="Directory to save models", + ) + parser.add_argument( + "--log-dir", + type=str, + default="logs", + help="Directory to save logs", + ) + + args = parser.parse_args() + + terrain_map = { + "flat": TerrainType.FLAT, + "planar_reflection": TerrainType.PLANAR_REFLECTION, + "terrain": TerrainType.TERRAIN, + "maze": TerrainType.MAZE, + } + terrain_type = terrain_map[args.terrain] + + print(f"\n{'='*50}") + print(f"Training Configuration:") + print(f" Algorithm: {args.algo}") + print(f" Total timesteps: {args.timesteps:,}") + print(f" Learning rate: {args.learning_rate}") + print(f" Terrain: {args.terrain}") + print(f" Device: {'cuda' if torch.cuda.is_available() else 'cpu'}") + print(f"{'='*50}\n") + + if args.algo == "ppo" or args.algo == "both": + print("\n=== Training PPO ===\n") + train_ppo( + total_timesteps=args.timesteps, + learning_rate=args.learning_rate, + save_dir=f"{args.save_dir}/ppo", + log_dir=f"{args.log_dir}/ppo", + terrain_type=terrain_type, + ) + + if args.algo == "sac" or args.algo == "both": + print("\n=== Training SAC ===\n") + train_sac( + total_timesteps=args.timesteps, + learning_rate=args.learning_rate, + save_dir=f"{args.save_dir}/sac", + log_dir=f"{args.log_dir}/sac", + terrain_type=terrain_type, + ) if __name__ == "__main__":