diff --git a/simulation/requirements.txt b/simulation/requirements.txt index 436d9c2..59c5e12 100644 --- a/simulation/requirements.txt +++ b/simulation/requirements.txt @@ -7,4 +7,8 @@ matplotlib pybullet websockets msgpack -asyncio \ No newline at end of file +asyncio +gymnasium +stable-baselines3[extra]>=2.0.0 +tensorboard +tqdm \ No newline at end of file diff --git a/simulation/src/envs/quadruped_env.py b/simulation/src/envs/quadruped_env.py index bfc81c9..e6f7657 100644 --- a/simulation/src/envs/quadruped_env.py +++ b/simulation/src/envs/quadruped_env.py @@ -41,7 +41,7 @@ class QuadrupedRobot: position, orientation = p.getBasePositionAndOrientation(self.robot_id) orientation = p.getEulerFromQuaternion(orientation) velocity, angular_velocity = p.getBaseVelocity(self.robot_id) - joint_states = p.getJointStates(self.robot_id, range(p.getNumJoints(self.robot_id))) + joint_states = p.getJointStates(self.robot_id, self.movable_joint_indices) joint_positions = [state[0] for state in joint_states] joint_velocities = [state[1] for state in joint_states] return np.concatenate( @@ -53,7 +53,7 @@ class QuadrupedRobot: joint_positions, joint_velocities, ] - ) + ).astype(np.float32) def apply_action(self, action): for i, position in enumerate(action): @@ -71,29 +71,35 @@ class QuadrupedRobot: class QuadrupedEnv(gym.Env): - def __init__(self, terrain_type: TerrainType = TerrainType.FLAT, render_mode: str = "human"): + def __init__( + self, + terrain_type: TerrainType = TerrainType.FLAT, + render_mode: str = "human", + target_velocity: float = 0.5, + max_steps: int = 1000, + distance_limit: float = 10.0, + ): super().__init__() if render_mode == "human": p.connect(p.GUI) else: p.connect(p.DIRECT) - self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(48,)) - self.action_space = gym.spaces.Box(low=-1, high=1, shape=(18,)) + self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(36,), dtype=np.float32) + self.action_space = gym.spaces.Box(low=-1, high=1, shape=(12,), dtype=np.float32) p.setAdditionalSearchPath(pybullet_data.getDataPath()) self.terrain_type = terrain_type self.render_mode = render_mode - self.target_velocity = 0.5 - self.max_steps = float("inf") + self.target_velocity = target_velocity + self.max_steps = max_steps self.current_step = 0 self._setup_world() if render_mode == "human": self.env_start_state = p.saveState() - # env parameters - self._distance_limit = float("inf") + self._distance_limit = distance_limit def _setup_world(self): self.robot = QuadrupedRobot("src/resources/spot.urdf")