⚡️ Makes training parallelized
This commit is contained in:
+35
-8
@@ -3,7 +3,7 @@ import os
|
||||
import gymnasium as gym
|
||||
from stable_baselines3 import PPO, SAC
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -37,13 +37,15 @@ def train_ppo(
|
||||
eval_freq=10000,
|
||||
save_freq=50000,
|
||||
terrain_type=TerrainType.FLAT,
|
||||
n_envs=8,
|
||||
use_gpu=True,
|
||||
):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
os.makedirs(f"{log_dir}/eval", exist_ok=True)
|
||||
|
||||
print("Creating training environment...")
|
||||
env = DummyVecEnv([make_env(terrain_type=terrain_type)])
|
||||
print(f"Creating {n_envs} parallel training environments...")
|
||||
env = SubprocVecEnv([make_env(terrain_type=terrain_type) for _ in range(n_envs)])
|
||||
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)
|
||||
|
||||
print("Creating evaluation environment...")
|
||||
@@ -66,7 +68,8 @@ def train_ppo(
|
||||
render=False,
|
||||
)
|
||||
|
||||
print("Creating PPO model...")
|
||||
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
||||
print(f"Creating PPO model on device: {device}")
|
||||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
@@ -82,6 +85,7 @@ def train_ppo(
|
||||
max_grad_norm=max_grad_norm,
|
||||
verbose=1,
|
||||
tensorboard_log=log_dir,
|
||||
device=device,
|
||||
policy_kwargs=dict(
|
||||
net_arch=[dict(pi=[256, 256], vf=[256, 256])],
|
||||
activation_fn=torch.nn.ReLU,
|
||||
@@ -119,13 +123,15 @@ def train_sac(
|
||||
eval_freq=10000,
|
||||
save_freq=50000,
|
||||
terrain_type=TerrainType.FLAT,
|
||||
n_envs=8,
|
||||
use_gpu=True,
|
||||
):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
os.makedirs(f"{log_dir}/eval", exist_ok=True)
|
||||
|
||||
print("Creating training environment...")
|
||||
env = DummyVecEnv([make_env(terrain_type=terrain_type)])
|
||||
print(f"Creating {n_envs} parallel training environments...")
|
||||
env = SubprocVecEnv([make_env(terrain_type=terrain_type) for _ in range(n_envs)])
|
||||
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)
|
||||
|
||||
print("Creating evaluation environment...")
|
||||
@@ -148,7 +154,8 @@ def train_sac(
|
||||
render=False,
|
||||
)
|
||||
|
||||
print("Creating SAC model...")
|
||||
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
||||
print(f"Creating SAC model on device: {device}")
|
||||
model = SAC(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
@@ -163,6 +170,7 @@ def train_sac(
|
||||
ent_coef=ent_coef,
|
||||
verbose=1,
|
||||
tensorboard_log=log_dir,
|
||||
device=device,
|
||||
policy_kwargs=dict(
|
||||
net_arch=dict(pi=[256, 256], qf=[256, 256]),
|
||||
activation_fn=torch.nn.ReLU,
|
||||
@@ -224,6 +232,17 @@ def main():
|
||||
default="logs",
|
||||
help="Directory to save logs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-envs",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of parallel environments (default: 8, max: 16)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu-only",
|
||||
action="store_true",
|
||||
help="Force CPU training even if GPU is available",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -235,13 +254,17 @@ def main():
|
||||
}
|
||||
terrain_type = terrain_map[args.terrain]
|
||||
|
||||
use_gpu = not args.cpu_only and torch.cuda.is_available()
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Training Configuration:")
|
||||
print(f" Algorithm: {args.algo}")
|
||||
print(f" Total timesteps: {args.timesteps:,}")
|
||||
print(f" Learning rate: {args.learning_rate}")
|
||||
print(f" Terrain: {args.terrain}")
|
||||
print(f" Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
|
||||
print(f" Parallel environments: {args.n_envs}")
|
||||
print(f" Device: {'CUDA (GPU)' if use_gpu else 'CPU'}")
|
||||
print(f" CPU cores available: {os.cpu_count()}")
|
||||
print(f"{'='*50}\n")
|
||||
|
||||
if args.algo == "ppo" or args.algo == "both":
|
||||
@@ -252,6 +275,8 @@ def main():
|
||||
save_dir=f"{args.save_dir}/ppo",
|
||||
log_dir=f"{args.log_dir}/ppo",
|
||||
terrain_type=terrain_type,
|
||||
n_envs=args.n_envs,
|
||||
use_gpu=use_gpu,
|
||||
)
|
||||
|
||||
if args.algo == "sac" or args.algo == "both":
|
||||
@@ -262,6 +287,8 @@ def main():
|
||||
save_dir=f"{args.save_dir}/sac",
|
||||
log_dir=f"{args.log_dir}/sac",
|
||||
terrain_type=terrain_type,
|
||||
n_envs=args.n_envs,
|
||||
use_gpu=use_gpu,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user