From 1fbddd483c1bffb74a0a814fd838c9e8e9362a71 Mon Sep 17 00:00:00 2001 From: Rune Harlyk Date: Fri, 10 Oct 2025 21:14:59 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Adds=20option=20to=20control=20sim?= =?UTF-8?q?=20using=20web=20app?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- simulation/play.py | 192 ++++++++++++++++++++++---------- simulation/src/controllers.py | 199 ++++++++++++++++++++++++++++++++++ 2 files changed, 331 insertions(+), 60 deletions(-) create mode 100644 simulation/src/controllers.py diff --git a/simulation/play.py b/simulation/play.py index d73ba3a..34dc553 100644 --- a/simulation/play.py +++ b/simulation/play.py @@ -1,81 +1,153 @@ +#!/usr/bin/env python3 import time import numpy as np import pybullet as p +import asyncio +import argparse +import sys +from typing import Optional from src.robot.kinematics import Kinematics, BodyStateT, KinConfig from src.robot.gait import GaitController, GaitStateT, GaitType, default_offset, default_stand_frac from src.envs.quadruped_env import QuadrupedEnv, TerrainType +from src.controllers import Controller, GUIController, WebSocketController -print("Initializing Spot Micro simulation...") -try: - env = QuadrupedEnv(terrain_type=TerrainType.FLAT) - print("Environment created successfully") - print(f"Robot ID: {env.robot.robot_id}") - print(f"Number of joints: {env.robot.get_observation().shape[0]}") - # Print joint names - print("\nJoint names:") - num_joints = p.getNumJoints(env.robot.robot_id) - for i in range(num_joints): - joint_info = p.getJointInfo(env.robot.robot_id, i) - joint_name = joint_info[1].decode("utf-8") - joint_type = joint_info[2] - print(f"Joint {i}: {joint_name} (type: {joint_type})") +class SpotMicroSimulation: + def __init__( + self, controller: Controller, env: Optional[QuadrupedEnv] = None, terrain_type: TerrainType = TerrainType.FLAT + ): + print("Initializing Spot Micro simulation...") + try: + if env is not None: + self.env = env + print("Using existing environment") + else: + self.env = QuadrupedEnv(terrain_type=terrain_type) + print("Environment created successfully") - print("Simulation ready! Use the GUI sliders to control the robot.") -except Exception as e: - print(f"Error creating environment: {e}") - import traceback + print(f"Robot ID: {self.env.robot.robot_id}") + print(f"Number of joints: {self.env.robot.get_observation().shape[0]}") - traceback.print_exc() - exit(1) + num_joints = p.getNumJoints(self.env.robot.robot_id) + print("\nJoint names:") + for i in range(num_joints): + joint_info = p.getJointInfo(self.env.robot.robot_id, i) + joint_name = joint_info[1].decode("utf-8") + joint_type = joint_info[2] + print(f"Joint {i}: {joint_name} (type: {joint_type})") -joint_directions = np.array([-1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1]) + print("Simulation ready!") + except Exception as e: + print(f"Error creating environment: {e}") + import traceback -kinematics = Kinematics() + traceback.print_exc() + sys.exit(1) -standby = KinConfig.default_feet_positions[:4, :3] + self.controller = controller + self.joint_directions = np.array([-1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1]) + self.kinematics = Kinematics() -body_state = BodyStateT( - omega=0, - phi=0, - psi=0, - xm=0, - ym=KinConfig.default_body_height, - zm=0, - px=0, - py=0, - pz=0, - feet=standby, - default_feet=standby, -) + standby = KinConfig.default_feet_positions[:4, :3] -gait_state = GaitStateT( - step_height=KinConfig.default_step_height, - step_x=0, - step_z=0, - step_angle=0, - step_velocity=1, - step_depth=KinConfig.default_step_depth, - stand_frac=default_stand_frac[GaitType.TROT_GATE], - offset=default_offset[GaitType.TROT_GATE], - gait_type=GaitType.TROT_GATE, -) + self.body_state = BodyStateT( + omega=0, + phi=0, + psi=0, + xm=0, + ym=KinConfig.default_body_height, + zm=0, + px=0, + py=0, + pz=0, + feet=standby, + default_feet=standby, + ) -gait = GaitController(standby) + self.gait_state = GaitStateT( + step_height=KinConfig.default_step_height, + step_x=0, + step_z=0, + step_angle=0, + step_velocity=1, + step_depth=KinConfig.default_step_depth, + stand_frac=default_stand_frac[GaitType.TROT_GATE], + offset=default_offset[GaitType.TROT_GATE], + gait_type=GaitType.TROT_GATE, + ) -dt = 1.0 / 240 -while True: - env.gui.update_gait_state(gait_state) - env.gui.update_body_state(body_state) - env.gui.update() + self.gait = GaitController(standby) + self.dt = 1.0 / 240 - gait.step(gait_state, body_state, dt) - joints = kinematics.inverse_kinematics(body_state) - joints = joints * joint_directions + def step(self): + self.controller.update(self.body_state, self.gait_state, self.dt) - _, _, done, truncated, _ = env.step(joints) - # if done or truncated: - # env.reset() + self.gait.step(self.gait_state, self.body_state, self.dt) + joints = self.kinematics.inverse_kinematics(self.body_state) + joints = joints * self.joint_directions - time.sleep(dt) + _, _, done, truncated, _ = self.env.step(joints) + + return joints, done, truncated + + def run_sync(self): + try: + while self.controller.is_running(): + joints, done, truncated = self.step() + time.sleep(self.dt) + except KeyboardInterrupt: + print("\n[*] Shutting down...") + + async def run_async(self): + try: + while self.controller.is_running(): + joints, done, truncated = self.step() + + if isinstance(self.controller, WebSocketController): + await self.controller.broadcast_angles(joints) + + await asyncio.sleep(self.dt) + except KeyboardInterrupt: + print("\n[*] Shutting down...") + + +def main(): + parser = argparse.ArgumentParser(description="Spot Micro Interactive Control Server") + parser.add_argument("--port", type=int, default=8765, help="WebSocket server port (default: 8765)") + parser.add_argument("--mode", choices=["gui", "websocket"], default="gui", help="Control mode (default: gui)") + parser.add_argument("--terrain", choices=["flat", "maze", "terrain"], default="flat", help="Terrain type") + + args = parser.parse_args() + + terrain_map = {"flat": TerrainType.FLAT, "maze": TerrainType.MAZE, "terrain": TerrainType.TERRAIN} + terrain_type = terrain_map.get(args.terrain, TerrainType.FLAT) + + if args.mode == "websocket": + controller = WebSocketController(port=args.port) + sim = SpotMicroSimulation(controller, terrain_type=terrain_type) + + async def run(): + server = await controller.start_server() + try: + await sim.run_async() + except KeyboardInterrupt: + print("\n[!] Shutting down server...") + controller.running = False + server.close() + await server.wait_closed() + print("[+] Server stopped") + + asyncio.run(run()) + else: + from src.envs.quadruped_env import QuadrupedEnv + + env = QuadrupedEnv(terrain_type=terrain_type) + controller = GUIController(env) + sim = SpotMicroSimulation(controller, env=env) + print("Use the GUI sliders to control the robot.") + sim.run_sync() + + +if __name__ == "__main__": + main() diff --git a/simulation/src/controllers.py b/simulation/src/controllers.py new file mode 100644 index 0000000..ca2ac97 --- /dev/null +++ b/simulation/src/controllers.py @@ -0,0 +1,199 @@ +import time +import asyncio +import websockets +import json +import numpy as np +from typing import Dict, Any +from abc import ABC, abstractmethod + +from src.robot.kinematics import BodyStateT, KinConfig +from src.robot.gait import GaitStateT, GaitType, default_offset, default_stand_frac + + +class Controller(ABC): + @abstractmethod + def update(self, body_state: BodyStateT, gait_state: GaitStateT, dt: float): + pass + + @abstractmethod + def is_running(self) -> bool: + pass + + +class GUIController(Controller): + def __init__(self, env): + self.env = env + + def update(self, body_state: BodyStateT, gait_state: GaitStateT, dt: float): + self.env.gui.update_gait_state(gait_state) + self.env.gui.update_body_state(body_state) + self.env.gui.update() + + def is_running(self) -> bool: + return True + + +class WebSocketController(Controller): + def __init__(self, port: int = 8765): + self.port = port + self.running = False + self.connected_clients = set() + + self.cmd_lx = 0.0 + self.cmd_ly = 0.0 + self.cmd_rx = 0.0 + self.cmd_ry = 0.0 + self.cmd_h = 0.0 + self.cmd_s = 0.0 + self.cmd_s1 = 0.0 + + self.motion_mode = "rest" + self.current_gait_type = GaitType.TROT_GATE + + self.last_broadcast_time = time.time() + + print(f"[*] WebSocket Controller initialized") + print(f" Port: {port}") + + def handle_input(self, input_data: list): + if len(input_data) >= 7: + self.cmd_lx = float(input_data[0]) + self.cmd_ly = float(input_data[1]) + self.cmd_rx = float(input_data[2]) + self.cmd_ry = float(input_data[3]) + self.cmd_h = float(input_data[4]) + self.cmd_s = float(input_data[5]) + self.cmd_s1 = float(input_data[6]) + + def handle_mode(self, mode: int): + modes = {0: "deactivated", 1: "idle", 2: "calibration", 3: "rest", 4: "stand", 5: "walk"} + if mode in modes: + self.motion_mode = modes[mode] + print(f"[*] Mode changed to: {self.motion_mode}") + else: + print(f"[!] Invalid mode: {mode}") + + def handle_walk_gait(self, gait: int): + if gait == 0: + self.current_gait_type = GaitType.TROT_GATE + print(f"[*] Gait changed to: TROT") + return default_offset[GaitType.TROT_GATE], default_stand_frac[GaitType.TROT_GATE] + elif gait == 1: + self.current_gait_type = GaitType.CRAWL_GATE + print(f"[*] Gait changed to: CRAWL") + return default_offset[GaitType.CRAWL_GATE], default_stand_frac[GaitType.CRAWL_GATE] + else: + print(f"[!] Invalid gait: {gait}") + return None, None + + def update(self, body_state: BodyStateT, gait_state: GaitStateT, dt: float): + if self.motion_mode == "walk": + body_state["ym"] = KinConfig.min_body_height + self.cmd_h * KinConfig.body_height_range + body_state["psi"] = self.cmd_ry * KinConfig.max_pitch + + gait_state["step_height"] = ( + self.cmd_s1 * KinConfig.max_step_height if self.cmd_s1 != 0 else KinConfig.default_step_height + ) + gait_state["step_x"] = self.cmd_ly * KinConfig.max_step_length + gait_state["step_z"] = -self.cmd_lx * KinConfig.max_step_length + gait_state["step_velocity"] = self.cmd_s if self.cmd_s != 0 else 1.0 + gait_state["step_angle"] = self.cmd_rx + gait_state["step_depth"] = KinConfig.default_step_depth + + elif self.motion_mode == "stand": + body_state["ym"] = KinConfig.min_body_height + self.cmd_h * KinConfig.body_height_range + body_state["xm"] = self.cmd_ly * KinConfig.max_body_shift_x + body_state["zm"] = self.cmd_lx * KinConfig.max_body_shift_z + body_state["phi"] = self.cmd_rx * KinConfig.max_roll + body_state["psi"] = self.cmd_ry * KinConfig.max_pitch + + elif self.motion_mode in ["rest", "idle", "calibration"]: + gait_state["step_x"] = 0.0 + gait_state["step_z"] = 0.0 + gait_state["step_angle"] = 0.0 + + async def handle_client(self, websocket, path): + client_addr = websocket.remote_address + print(f"[+] Client connected: {client_addr}") + self.connected_clients.add(websocket) + + try: + async for message in websocket: + try: + data = json.loads(message) + + if not isinstance(data, list) or len(data) < 1: + continue + + msg_type = data[0] + + if msg_type == 0: + if len(data) >= 2: + event = data[1] + print(f"[*] Client subscribed to: {event}") + + elif msg_type == 1: + if len(data) >= 2: + event = data[1] + print(f"[*] Client unsubscribed from: {event}") + + elif msg_type == 2: + if len(data) >= 3: + event = data[1] + payload = data[2] + await self.handle_event(websocket, event, payload) + + elif msg_type == 4: + await websocket.send(json.dumps([4])) + + except json.JSONDecodeError: + print(f"[!] Invalid JSON from {client_addr}") + except Exception as e: + print(f"[!] Error handling message: {e}") + + except websockets.exceptions.ConnectionClosed: + print(f"[-] Client disconnected: {client_addr}") + finally: + self.connected_clients.discard(websocket) + + async def handle_event(self, websocket, event: str, data: Any): + if event == "input": + if isinstance(data, list) and len(data) >= 7: + self.handle_input(data) + + elif event == "mode": + self.handle_mode(data) + + elif event == "walk_gait": + self.handle_walk_gait(data) + + async def broadcast_angles(self, joint_angles: np.ndarray): + if self.connected_clients and time.time() - self.last_broadcast_time > 0.1: + state_message = json.dumps([2, "angles", joint_angles.tolist()]) + + disconnected = set() + for client in self.connected_clients: + try: + await client.send(state_message) + except websockets.exceptions.ConnectionClosed: + disconnected.add(client) + + self.connected_clients -= disconnected + self.last_broadcast_time = time.time() + + async def start_server(self): + print(f"[*] Starting WebSocket server on port {self.port}") + self.running = True + + server = await websockets.serve(self.handle_client, "0.0.0.0", self.port, ping_interval=20, ping_timeout=10) + + print(f"[+] Server running on ws://0.0.0.0:{self.port}") + print(f"[*] Connect from same PC: ws://localhost:{self.port}") + print(f"[*] Connect from network: ws://:{self.port}") + print("[*] Ready for controller connections!") + print("[*] Use the controller app to connect and control the robot") + + return server + + def is_running(self) -> bool: + return self.running