🎨 Update observation space to match real world
This commit is contained in:
+22
-4
@@ -15,7 +15,11 @@ from src.controllers import Controller, GUIController, WebSocketController
|
||||
|
||||
class SpotMicroSimulation:
|
||||
def __init__(
|
||||
self, controller: Controller, env: Optional[QuadrupedEnv] = None, terrain_type: TerrainType = TerrainType.FLAT
|
||||
self,
|
||||
controller: Controller,
|
||||
env: Optional[QuadrupedEnv] = None,
|
||||
terrain_type: TerrainType = TerrainType.FLAT,
|
||||
dt: float = 1.0 / 240,
|
||||
):
|
||||
print("Initializing Spot Micro simulation...")
|
||||
try:
|
||||
@@ -23,7 +27,7 @@ class SpotMicroSimulation:
|
||||
self.env = env
|
||||
print("Using existing environment")
|
||||
else:
|
||||
self.env = QuadrupedEnv(terrain_type=terrain_type)
|
||||
self.env = QuadrupedEnv(terrain_type=terrain_type, dt=dt)
|
||||
print("Environment created successfully")
|
||||
|
||||
print(f"Robot ID: {self.env.robot.robot_id}")
|
||||
@@ -78,7 +82,7 @@ class SpotMicroSimulation:
|
||||
)
|
||||
|
||||
self.gait = GaitController(standby)
|
||||
self.dt = 1.0 / 240
|
||||
self.dt = dt
|
||||
|
||||
def step(self):
|
||||
self.controller.update(self.body_state, self.gait_state, self.dt)
|
||||
@@ -87,10 +91,24 @@ class SpotMicroSimulation:
|
||||
joints = self.kinematics.inverse_kinematics(self.body_state)
|
||||
joints = joints * self.joint_directions
|
||||
|
||||
_, _, done, truncated, _ = self.env.step(joints)
|
||||
obs, _, done, truncated, _ = self.env.step(joints)
|
||||
|
||||
self._print_mpu6050_data(obs)
|
||||
|
||||
return joints, done, truncated
|
||||
|
||||
def _print_mpu6050_data(self, obs):
|
||||
accel = obs[0:3]
|
||||
gyro = obs[3:6]
|
||||
heading = obs[6]
|
||||
altitude = obs[7]
|
||||
|
||||
print(
|
||||
f"MPU6050: Accel({accel[0]:8.3f}, {accel[1]:8.3f}, {accel[2]:8.3f}) "
|
||||
f"Gyro({gyro[0]:8.3f}, {gyro[1]:8.3f}, {gyro[2]:8.3f}) "
|
||||
f"Mag({heading:8.3f}) Baro({altitude:8.3f})"
|
||||
)
|
||||
|
||||
def run_sync(self):
|
||||
try:
|
||||
while self.controller.is_running():
|
||||
|
||||
@@ -38,22 +38,31 @@ class QuadrupedRobot:
|
||||
return [p.getJointInfo(self.robot_id, idx)[1].decode("utf-8") for idx in self.movable_joint_indices]
|
||||
|
||||
def get_observation(self):
|
||||
position, orientation = p.getBasePositionAndOrientation(self.robot_id)
|
||||
orientation = p.getEulerFromQuaternion(orientation)
|
||||
velocity, angular_velocity = p.getBaseVelocity(self.robot_id)
|
||||
joint_states = p.getJointStates(self.robot_id, self.movable_joint_indices)
|
||||
joint_positions = [state[0] for state in joint_states]
|
||||
joint_velocities = [state[1] for state in joint_states]
|
||||
return np.concatenate(
|
||||
[
|
||||
position,
|
||||
orientation,
|
||||
velocity,
|
||||
angular_velocity,
|
||||
joint_positions,
|
||||
joint_velocities,
|
||||
]
|
||||
).astype(np.float32)
|
||||
pos_w, quat_wb = p.getBasePositionAndOrientation(self.robot_id)
|
||||
v_w, w_w = p.getBaseVelocity(self.robot_id)
|
||||
|
||||
R = np.array(p.getMatrixFromQuaternion(quat_wb), dtype=np.float32).reshape(3, 3)
|
||||
|
||||
if hasattr(self, "prev_velocity") and self.prev_velocity is not None:
|
||||
dt = 1.0 / 240.0
|
||||
accel_world = (v_w - self.prev_velocity) / dt
|
||||
else:
|
||||
accel_world = np.array([0.0, 0.0, 0.0])
|
||||
|
||||
accel_body = R.T @ np.asarray(accel_world, dtype=np.float32)
|
||||
gravity_body = R.T @ np.array([0, 0, -9.81], dtype=np.float32)
|
||||
accel_body += gravity_body
|
||||
|
||||
gyro_body = np.degrees(R.T @ np.asarray(w_w, dtype=np.float32))
|
||||
|
||||
euler = p.getEulerFromQuaternion(quat_wb)
|
||||
heading = np.degrees(euler[2])
|
||||
|
||||
altitude = np.array([pos_w[2]], dtype=np.float32)
|
||||
|
||||
self.prev_velocity = np.array(v_w)
|
||||
|
||||
return np.concatenate([accel_body, gyro_body, [heading], altitude]).astype(np.float32)
|
||||
|
||||
def apply_action(self, action):
|
||||
for i, position in enumerate(action):
|
||||
@@ -78,6 +87,7 @@ class QuadrupedEnv(gym.Env):
|
||||
target_velocity: float = 0.5,
|
||||
max_steps: int = 1000,
|
||||
distance_limit: float = 10.0,
|
||||
dt: float = 1.0 / 240,
|
||||
):
|
||||
super().__init__()
|
||||
if render_mode == "human":
|
||||
@@ -85,7 +95,7 @@ class QuadrupedEnv(gym.Env):
|
||||
else:
|
||||
p.connect(p.DIRECT)
|
||||
|
||||
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(36,), dtype=np.float32)
|
||||
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(12,), dtype=np.float32)
|
||||
p.setAdditionalSearchPath(pybullet_data.getDataPath())
|
||||
|
||||
@@ -93,9 +103,10 @@ class QuadrupedEnv(gym.Env):
|
||||
self.render_mode = render_mode
|
||||
self.target_velocity = target_velocity
|
||||
self.max_steps = max_steps
|
||||
self.prev_accel = np.zeros(3)
|
||||
self.estimated_velocity = np.zeros(3)
|
||||
self.current_step = 0
|
||||
|
||||
self.prev_velocity = None
|
||||
self.dt = dt
|
||||
|
||||
self._setup_world()
|
||||
if render_mode == "human":
|
||||
@@ -107,7 +118,7 @@ class QuadrupedEnv(gym.Env):
|
||||
self.robot = QuadrupedRobot("src/resources/spot.urdf")
|
||||
self._load_terrain(self.terrain_type)
|
||||
p.setGravity(0, 0, -9.8)
|
||||
p.setTimeStep(1 / 240)
|
||||
p.setTimeStep(self.dt)
|
||||
if self.render_mode == "human":
|
||||
self.gui = GUI(self.robot.robot_id)
|
||||
else:
|
||||
@@ -157,7 +168,7 @@ class QuadrupedEnv(gym.Env):
|
||||
|
||||
obs = self.robot.get_observation()
|
||||
reward = self.calculate_reward(obs)
|
||||
done = self.is_done(obs)
|
||||
done = self.is_done()
|
||||
truncated = self.current_step >= self.max_steps
|
||||
|
||||
return obs, reward, done, truncated, {}
|
||||
@@ -167,48 +178,42 @@ class QuadrupedEnv(gym.Env):
|
||||
# p.disconnect()
|
||||
|
||||
def calculate_reward(self, obs):
|
||||
position = obs[:3]
|
||||
orientation = obs[3:6]
|
||||
velocity = obs[6:9]
|
||||
angular_velocity = obs[9:12]
|
||||
accel = obs[0:3]
|
||||
gyro = obs[3:6]
|
||||
heading = obs[6]
|
||||
altitude = obs[7]
|
||||
|
||||
forward_velocity = velocity[0]
|
||||
self.estimated_velocity = self.estimated_velocity + self.prev_accel * self.dt
|
||||
|
||||
self.prev_accel = accel.copy()
|
||||
|
||||
forward_velocity = self.estimated_velocity[0]
|
||||
velocity_reward = -abs(forward_velocity - self.target_velocity)
|
||||
|
||||
height_penalty = -abs(position[2] - 0.3) * 0.5
|
||||
height_penalty = -abs(altitude - 0.3) * 0.5
|
||||
|
||||
roll, pitch, yaw = orientation
|
||||
orientation_penalty = -(abs(roll) + abs(pitch)) * 1.0
|
||||
orientation_penalty = -(abs(gyro[0]) + abs(gyro[1])) * 0.1
|
||||
|
||||
angular_penalty = -np.sum(np.square(angular_velocity)) * 0.05
|
||||
angular_penalty = -np.sum(np.square(gyro)) * 0.01
|
||||
|
||||
sideways_velocity_penalty = -abs(velocity[1]) * 0.3
|
||||
lateral_acc_penalty = -abs(accel[1]) * 0.01
|
||||
|
||||
if self.prev_velocity is not None:
|
||||
dt = 1.0 / 240.0
|
||||
acceleration = (velocity - self.prev_velocity) / dt
|
||||
lateral_acc_penalty = -abs(acceleration[1]) * 0.01
|
||||
vertical_acc_penalty = -abs(acceleration[2]) * 0.01
|
||||
else:
|
||||
lateral_acc_penalty = 0
|
||||
vertical_acc_penalty = 0
|
||||
|
||||
self.prev_velocity = velocity.copy()
|
||||
vertical_acc_penalty = -abs(accel[2] + 9.81) * 0.01
|
||||
|
||||
total_reward = (
|
||||
velocity_reward
|
||||
+ height_penalty
|
||||
+ orientation_penalty
|
||||
+ angular_penalty
|
||||
+ sideways_velocity_penalty
|
||||
+ lateral_acc_penalty
|
||||
+ vertical_acc_penalty
|
||||
)
|
||||
return total_reward
|
||||
|
||||
def is_done(self, obs):
|
||||
position = obs[:3]
|
||||
orientation = obs[3:6]
|
||||
def is_done(self):
|
||||
position, orientation = p.getBasePositionAndOrientation(self.robot.robot_id)
|
||||
orientation = p.getEulerFromQuaternion(orientation)
|
||||
|
||||
return self._is_fallen(orientation) or self._is_distance_limit_exceeded(position)
|
||||
|
||||
def _is_distance_limit_exceeded(self, position):
|
||||
@@ -216,9 +221,4 @@ class QuadrupedEnv(gym.Env):
|
||||
return distance > self._distance_limit
|
||||
|
||||
def _is_fallen(self, orientation):
|
||||
# orientation = self.spot.GetBaseOrientation()
|
||||
# rot_mat = self._pybullet_client.getMatrixFromQuaternion(orientation)
|
||||
# local_up = rot_mat[6:]
|
||||
# pos = self.spot.GetBasePosition()
|
||||
# return (np.dot(np.asarray([0, 0, 1]), np.asarray(local_up)) < 0.55)
|
||||
return abs(orientation[0]) > 0.85 or abs(orientation[1]) > 0.85
|
||||
|
||||
Reference in New Issue
Block a user