diff --git a/simulation/play.py b/simulation/play.py new file mode 100644 index 0000000..6800ed6 --- /dev/null +++ b/simulation/play.py @@ -0,0 +1,39 @@ +import time +import numpy as np + +from src.robot.kinematics import Kinematics, BodyStateT +from src.robot.gait import GaitController, GaitStateT +from src.envs.quadruped_env import QuadrupedEnv + +env = QuadrupedEnv() + +leg_order = [3, 0, 4, 1, 5, 2] + +kinematics = Kinematics(0.605, 0.01, 1.112, 1.185, 2.075, 0.78) + +standby = kinematics.get_default_feet_pos() + +body_state = BodyStateT(omega=0, phi=0, psi=0, xm=0, ym=0, + zm=0, px=0, py=0, pz=0, feet=standby, default_feet=standby) + +gait_state = GaitStateT(step_height=30, step_x=0, step_z=0, + step_angle=0, step_velocity=1, step_depth=0.002) + +gait = GaitController(standby) + +dt = 1.0 / 240 +while True: + env.gui.update_gait_state(gait_state) + env.gui.update_body_state(body_state) + env.gui.update() + + gait.step(gait_state, body_state, dt) + angles = kinematics.inverse_kinematics(body_state) + print(angles) + joints = angles # angles.reshape(4, 3)[leg_order].flatten() + + _, _, done, truncated, _ = env.step(joints) + # if done or truncated: + # env.reset() + + time.sleep(dt) diff --git a/simulation/simulation/environment.py b/simulation/simulation/environment.py deleted file mode 100644 index 66a773f..0000000 --- a/simulation/simulation/environment.py +++ /dev/null @@ -1,56 +0,0 @@ -import pybullet as p -import pybullet_data -import numpy as np -from simulation.robot import QuadrupedRobot -from simulation.gui import GUI - - -roll_pitch_reward_weight = 0.1 - - -class QuadrupedEnv: - def __init__(self, urdf_path): - p.connect(p.GUI) - p.setAdditionalSearchPath(pybullet_data.getDataPath()) - p.setGravity(0, 0, -9.8) - p.setTimeStep(1 / 240) - self.urdf_path = urdf_path - self.numSolverIterations = 50 - - self.setupWorld() - self.gui = GUI(self.robot.robot_id) - - self.envStartState = p.saveState() - - def setupWorld(self): - p.resetSimulation() - p.setPhysicsEngineParameter(numSolverIterations=self.numSolverIterations) - p.setGravity(0, 0, -9.8) - - self.plane_id = p.loadURDF("plane.urdf") - - self.robot = QuadrupedRobot(self.urdf_path) - - def reset(self): - p.restoreState(self.envStartState) - return self.robot.get_observation() - - def step(self, action): - self.gui.update() - self.robot.apply_action(action) - p.stepSimulation() - obs = self.robot.get_observation() - reward = self.calculate_reward(obs) - done = self.is_done(obs) - return obs, reward, done - - def calculate_reward(self, observation): - reward = 0 - reward += ( - -(abs(observation[0]) + abs(observation[1])) * roll_pitch_reward_weight - ) - return reward - - def is_done(self, observation): - roll, pitch = observation[0:2] - return abs(roll) > 0.5 or abs(pitch) > 0.5 diff --git a/simulation/simulation/kinematic.py b/simulation/simulation/kinematic.py deleted file mode 100644 index 5415365..0000000 --- a/simulation/simulation/kinematic.py +++ /dev/null @@ -1,148 +0,0 @@ -import numpy as np - - -class Kinematic: - def __init__(self) -> None: - - self.l1 = 60.5 - self.l2 = 10 - self.l3 = 100.7 - self.l4 = 118.5 - - self.L = 207.5 - self.W = 78 - - def calculate_inverse_kinematics(self, omega, phi, psi, x, y, z, feet): - Tlf, Trf, Tlb, Trb = self.bodyIK(omega, phi, psi, x, y, z) - - Q = np.linalg.inv(Tlf).dot(feet[0])[:3] - - IK = self.legIK(*Q) - LF = ( - np.rad2deg(np.pi / 2 - IK[0]), - np.rad2deg(np.pi / 3 - IK[1]), - np.rad2deg(np.pi - IK[2]), - ) - - Q = self.Ix.dot(np.linalg.inv(Trf)).dot(feet[1])[:3] - - IK = self.legIK(*Q) - RF = ( - np.rad2deg(np.pi / 2 + IK[0]), - np.rad2deg(2 * np.pi / 3 + IK[1]), - np.rad2deg(IK[2]), - ) - - Q = np.linalg.inv(Tlb).dot(feet[2])[:3] - - IK = self.legIK(*Q) - LB = ( - np.rad2deg(np.pi / 2 + (IK[0])), - np.rad2deg(np.pi / 3 - IK[1]), - np.rad2deg(np.pi - IK[2]), - ) - - Q = self.Ix.dot(np.linalg.inv(Trb)).dot(feet[3])[:3] - - IK = self.legIK(*Q) - RB = ( - np.rad2deg(np.pi / 2 - IK[0]), - np.rad2deg(2 * np.pi / 3 + IK[1]), - np.rad2deg(IK[2]), - ) - return (LF, RF, LB, RB) - - def bodyIK(self, omega, phi, psi, xm, ym, zm): - sHp = np.sin(np.pi / 2) - cHp = np.cos(np.pi / 2) - Rx = np.array( - [ - [1, 0, 0, 0], - [0, np.cos(omega), -np.sin(omega), 0], - [0, np.sin(omega), np.cos(omega), 0], - [0, 0, 0, 1], - ] - ) - Ry = np.array( - [ - [np.cos(phi), 0, np.sin(phi), 0], - [0, 1, 0, 0], - [-np.sin(phi), 0, np.cos(phi), 0], - [0, 0, 0, 1], - ] - ) - Rz = np.array( - [ - [np.cos(psi), -np.sin(psi), 0, 0], - [np.sin(psi), np.cos(psi), 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1], - ] - ) - Rxyz = Rx @ Ry @ Rz - - T = np.array([[0, 0, 0, xm], [0, 0, 0, ym], [0, 0, 0, zm], [0, 0, 0, 0]]) - Tm = T + Rxyz - - return [ - Tm - @ np.array( - [ - [cHp, 0, sHp, self.L / 2], - [0, 1, 0, 0], - [-sHp, 0, cHp, self.W / 2], - [0, 0, 0, 1], - ] - ), - Tm - @ np.array( - [ - [cHp, 0, sHp, self.L / 2], - [0, 1, 0, 0], - [-sHp, 0, cHp, -self.W / 2], - [0, 0, 0, 1], - ] - ), - Tm - @ np.array( - [ - [cHp, 0, sHp, -self.L / 2], - [0, 1, 0, 0], - [-sHp, 0, cHp, self.W / 2], - [0, 0, 0, 1], - ] - ), - Tm - @ np.array( - [ - [cHp, 0, sHp, -self.L / 2], - [0, 1, 0, 0], - [-sHp, 0, cHp, -self.W / 2], - [0, 0, 0, 1], - ] - ), - ] - - def legIK(self, x, y, z): - """ - x/y/z=Position of the Foot in Leg-Space - - F=Length of shoulder-point to target-point on x/y only - G=length we need to reach to the point on x/y - H=3-Dimensional length we need to reach - """ - - F = np.sqrt(x**2 + y**2 - self.l1**2) - G = F - self.l2 - H = np.sqrt(G**2 + z**2) - - theta1 = -np.atan2(y, x) - np.atan2(F, -self.l1) - - D = (H**2 - self.l3**2 - self.l4**2) / (2 * self.l3 * self.l4) - theta3 = np.acos(D) - - theta2 = np.atan2(z, G) - np.atan2( - self.l4 * np.sin(theta3), self.l3 + self.l4 * np.cos(theta3) - ) - - return (theta1, theta2, theta3) diff --git a/simulation/src/envs/__init__.py b/simulation/src/envs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/simulation/src/envs/quadruped_env.py b/simulation/src/envs/quadruped_env.py new file mode 100644 index 0000000..01fa1f5 --- /dev/null +++ b/simulation/src/envs/quadruped_env.py @@ -0,0 +1,174 @@ +import gymnasium as gym +import pybullet as p +import pybullet_data +import numpy as np +from enum import Enum + +from src.utils.gui import GUI + + +class TerrainType(Enum): + FLAT = "flat" + PLANAR_REFLECTION = "planar_reflection" + TERRAIN = "terrain" + MAZE = "maze" + + +class QuadrupedRobot: + def __init__(self, urdf_path, position=[0, 0, 0.3], orientation=[0, 0, 0], use_fixed_base=False): + q_orientation = p.getQuaternionFromEuler(orientation) + self.robot_id = p.loadURDF( + urdf_path, position, q_orientation, useFixedBase=use_fixed_base) + + def get_observation(self): + position, orientation = p.getBasePositionAndOrientation(self.robot_id) + orientation = p.getEulerFromQuaternion(orientation) + velocity, angular_velocity = p.getBaseVelocity(self.robot_id) + joint_states = p.getJointStates( + self.robot_id, range(p.getNumJoints(self.robot_id))) + joint_positions = [state[0] for state in joint_states] + joint_velocities = [state[1] for state in joint_states] + return np.concatenate( + [ + position, + orientation, + velocity, + angular_velocity, + joint_positions, + joint_velocities, + ] + ) + + def apply_action(self, action): + for i, position in enumerate(action): + p.setJointMotorControl2( + bodyIndex=self.robot_id, + jointIndex=i, + controlMode=p.POSITION_CONTROL, + targetPosition=position, + force=50, # 343 # / 100 for newtons - Fix mass + positionGain=0.5, + maxVelocity=13.09, + ) + + +class QuadrupedEnv(gym.Env): + def __init__(self, terrain_type: TerrainType = TerrainType.FLAT, render_mode: str = "human"): + super().__init__() + if render_mode == "human": + p.connect(p.GUI) + else: + p.connect(p.DIRECT) + + self.observation_space = gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(48,)) + self.action_space = gym.spaces.Box(low=-1, high=1, shape=(18,)) + p.setAdditionalSearchPath(pybullet_data.getDataPath()) + + self.terrain_type = terrain_type + self.render_mode = render_mode + self.target_velocity = 0.5 + self.max_steps = float("inf") + self.current_step = 0 + + self._setup_world() + if render_mode == "human": + self.env_start_state = p.saveState() + + # env parameters + self._distance_limit = float("inf") + + def _setup_world(self): + self.robot = QuadrupedRobot("src/resources/spot.urdf") + self._load_terrain(self.terrain_type) + p.setGravity(0, 0, -9.8) + p.setTimeStep(1 / 240) + if self.render_mode == "human": + self.gui = GUI(self.robot.robot_id) + else: + self.gui = None + + def _load_terrain(self, terrain_type: TerrainType): + if terrain_type == TerrainType.FLAT: + self.terrain = p.loadURDF("plane.urdf") + elif terrain_type == TerrainType.PLANAR_REFLECTION: + p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 1) + p.configureDebugVisualizer(p.COV_ENABLE_PLANAR_REFLECTION, 1) + p.configureDebugVisualizer(p.COV_ENABLE_TINY_RENDERER, 0) + self.terrain = p.loadURDF( + "plane_transparent.urdf", useMaximalCoordinates=True) + elif terrain_type == TerrainType.TERRAIN: + terrainShape = p.createCollisionShape( + shapeType=p.GEOM_HEIGHTFIELD, meshScale=[0.1, 0.1, 24], fileName="heightmaps/wm_height_out.png" + ) + textureId = p.loadTexture("heightmaps/gimp_overlay_out.png") + self.terrain = p.createMultiBody(0, terrainShape) + p.changeVisualShape(self.terrain, -1, textureUniqueId=textureId) + elif terrain_type == TerrainType.MAZE: + terrainShape = p.createCollisionShape( + shapeType=p.GEOM_HEIGHTFIELD, meshScale=[1, 1, 3], fileName="heightmaps/Maze.png" + ) + textureId = p.loadTexture("heightmaps/Maze.png") + maze = p.createMultiBody(0, terrainShape) + self.terrain = [p.loadURDF("plane.urdf"), maze] + p.changeVisualShape(self.terrain[1], -1, textureUniqueId=textureId) + + def reset(self, *, seed: int | None = None): + super().reset(seed=seed) + if self.render_mode == "human": + p.restoreState(self.env_start_state) + else: + p.resetSimulation() + self._setup_world() + self.current_step = 0 + return self.robot.get_observation(), {} + + def step(self, action): + self.current_step += 1 + if self.gui: + self.gui.update() + self.robot.apply_action(action) + p.stepSimulation() + + obs = self.robot.get_observation() + reward = self.calculate_reward(obs) + done = self.is_done(obs) + truncated = self.current_step >= self.max_steps + + return obs, reward, done, truncated, {} + + def close(self): + pass + # p.disconnect() + + def calculate_reward(self, obs): + position = obs[:3] + velocity = obs[6:9] + angular_velocity = obs[9:12] + + forward_velocity = velocity[0] + velocity_reward = -abs(forward_velocity - self.target_velocity) + + height_penalty = -abs(position[2] - 0.3) + + angular_penalty = -np.sum(np.square(angular_velocity)) + + total_reward = velocity_reward + 0.1 * height_penalty + 0.01 * angular_penalty + return total_reward + + def is_done(self, obs): + position = obs[:3] + orientation = obs[3:6] + return self._is_fallen(orientation) or self._is_distance_limit_exceeded(position) + + def _is_distance_limit_exceeded(self, position): + distance = np.hypot(position[0], position[1]) + return distance > self._distance_limit + + def _is_fallen(self, orientation): + # orientation = self.spot.GetBaseOrientation() + # rot_mat = self._pybullet_client.getMatrixFromQuaternion(orientation) + # local_up = rot_mat[6:] + # pos = self.spot.GetBasePosition() + # return (np.dot(np.asarray([0, 0, 1]), np.asarray(local_up)) < 0.55) + return abs(orientation[0]) > 0.85 or abs(orientation[1]) > 0.85 diff --git a/simulation/simulation/robot.py b/simulation/src/envs/robot.py similarity index 100% rename from simulation/simulation/robot.py rename to simulation/src/envs/robot.py diff --git a/simulation/resources/__init__.py b/simulation/src/resources/__init__.py similarity index 100% rename from simulation/resources/__init__.py rename to simulation/src/resources/__init__.py diff --git a/simulation/resources/spot.urdf b/simulation/src/resources/spot.urdf similarity index 100% rename from simulation/resources/spot.urdf rename to simulation/src/resources/spot.urdf diff --git a/simulation/resources/stl/backpart.stl b/simulation/src/resources/stl/backpart.stl similarity index 100% rename from simulation/resources/stl/backpart.stl rename to simulation/src/resources/stl/backpart.stl diff --git a/simulation/resources/stl/foot.stl b/simulation/src/resources/stl/foot.stl similarity index 100% rename from simulation/resources/stl/foot.stl rename to simulation/src/resources/stl/foot.stl diff --git a/simulation/resources/stl/frontpart.stl b/simulation/src/resources/stl/frontpart.stl similarity index 100% rename from simulation/resources/stl/frontpart.stl rename to simulation/src/resources/stl/frontpart.stl diff --git a/simulation/resources/stl/larm.stl b/simulation/src/resources/stl/larm.stl similarity index 100% rename from simulation/resources/stl/larm.stl rename to simulation/src/resources/stl/larm.stl diff --git a/simulation/resources/stl/larm_cover.stl b/simulation/src/resources/stl/larm_cover.stl similarity index 100% rename from simulation/resources/stl/larm_cover.stl rename to simulation/src/resources/stl/larm_cover.stl diff --git a/simulation/resources/stl/lfoot.stl b/simulation/src/resources/stl/lfoot.stl similarity index 100% rename from simulation/resources/stl/lfoot.stl rename to simulation/src/resources/stl/lfoot.stl diff --git a/simulation/resources/stl/lshoulder.stl b/simulation/src/resources/stl/lshoulder.stl similarity index 100% rename from simulation/resources/stl/lshoulder.stl rename to simulation/src/resources/stl/lshoulder.stl diff --git a/simulation/resources/stl/mainbody.stl b/simulation/src/resources/stl/mainbody.stl similarity index 100% rename from simulation/resources/stl/mainbody.stl rename to simulation/src/resources/stl/mainbody.stl diff --git a/simulation/resources/stl/rarm.stl b/simulation/src/resources/stl/rarm.stl similarity index 100% rename from simulation/resources/stl/rarm.stl rename to simulation/src/resources/stl/rarm.stl diff --git a/simulation/resources/stl/rarm_cover.stl b/simulation/src/resources/stl/rarm_cover.stl similarity index 100% rename from simulation/resources/stl/rarm_cover.stl rename to simulation/src/resources/stl/rarm_cover.stl diff --git a/simulation/resources/stl/rfoot.stl b/simulation/src/resources/stl/rfoot.stl similarity index 100% rename from simulation/resources/stl/rfoot.stl rename to simulation/src/resources/stl/rfoot.stl diff --git a/simulation/resources/stl/rplidar_main.STL b/simulation/src/resources/stl/rplidar_main.STL similarity index 100% rename from simulation/resources/stl/rplidar_main.STL rename to simulation/src/resources/stl/rplidar_main.STL diff --git a/simulation/resources/stl/rshoulder.stl b/simulation/src/resources/stl/rshoulder.stl similarity index 100% rename from simulation/resources/stl/rshoulder.stl rename to simulation/src/resources/stl/rshoulder.stl diff --git a/simulation/src/robot/__init__.py b/simulation/src/robot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/simulation/src/robot/gait.py b/simulation/src/robot/gait.py new file mode 100644 index 0000000..7599920 --- /dev/null +++ b/simulation/src/robot/gait.py @@ -0,0 +1,120 @@ +import math +import numpy as np +from typing import TypedDict +from enum import Enum + +from src.robot.kinematics import BodyStateT + + +class GaitType(Enum): + TROT_GATE = 0 + CRAWL_GATE = 1 + + +default_offset = { + GaitType.TROT_GATE: [0, 0.5, 0.5, 0], + GaitType.CRAWL_GATE: [0, 1 / 4, 2 / 4, 3 / 4], +} + +default_stand_frac = { + GaitType.TROT_GATE: 3 / 4, + GaitType.CRAWL_GATE: 3 / 4, +} + + +class GaitStateT(TypedDict): + step_height: float + step_x: float + step_z: float + step_angle: float + step_velocity: float + step_depth: float + stand_frac: float + offset: list[float] + gait_type: GaitType + + +length_multipliers = np.array( + [-1.4, -1.0, -1.5, -1.5, -1.5, 0.0, 0.0, 0.0, 1.5, 1.5, 1.4, 1.0]) +height_profile = np.array( + [0.0, 0.0, 0.9, 0.9, 0.9, 0.9, 0.9, 1.1, 1.1, 1.1, 0.0, 0.0]) + + +def sine_curve(length, angle, height, phase): + x, z = length * (1 - 2 * phase) * np.cos(angle), length * \ + (1 - 2 * phase) * np.sin(angle) + y = height * np.cos(np.pi * (x + z) / (2 * length)) if length else 0 + return np.array([x, z, y]) + + +def yaw_arc(feet, current): + return ( + np.pi / 2 + + np.arctan2(feet[1], feet[0]) + + np.arctan2(np.linalg.norm(current[:2] - feet[:2]), np.linalg.norm(feet[:2])) + ) + + +def get_control_points(length, angle, height): + x_polar, z_polar = np.cos(angle), np.sin(angle) + + x = length * length_multipliers * x_polar + z = length * length_multipliers * z_polar + y = height * height_profile + return np.stack([x, z, y], axis=1) + + +def bezier_curve(length, angle, height, phase): + ctrl = get_control_points(length, angle, height) + n = len(ctrl) - 1 + coeffs = np.array([math.comb(n, i) * (phase**i) * + ((1 - phase) ** (n - i)) for i in range(n + 1)]) + return np.sum(ctrl * coeffs[:, None], axis=0) + + +class GaitController: + def __init__(self, default_position: np.ndarray): + self.default_position = default_position + self.phase = 0.0 + + def step(self, gait: GaitStateT, body: BodyStateT, dt: float): + step_x, step_z, angle = gait["step_x"], gait["step_z"], gait["step_angle"] + if not any((step_x, step_z, angle)): + body["feet"] = body["feet"] + \ + (self.default_position - body["feet"]) * dt * 10 + self.phase = 0.0 + return + + self._advance_phase(dt, gait["step_velocity"]) + + stand_fraction = gait["stand_frac"] + depth = gait["step_depth"] + height = gait["step_height"] + offsets = gait["offset"] + + length = np.hypot(step_x, step_z) + if step_x < 0: + length = -length + turn_amplitude = np.arctan2(step_z, length) * 2 + + new_feet = np.zeros_like(self.default_position) + + for i, (default_foot, current_foot) in enumerate(zip(self.default_position, body["feet"])): + phase = (self.phase + offsets[i]) % 1 + ph_norm, curve_fn, amp = self._phase_params( + phase, stand_fraction, depth, height) + delta_pos = curve_fn(length / 2, turn_amplitude, amp, ph_norm) + delta_rot = curve_fn(np.rad2deg(angle), yaw_arc( + default_foot, current_foot), amp, ph_norm) + new_feet[i][:2] = default_foot[:2] + delta_pos + delta_rot + # new_feet[i][3] = 1 + + body["feet"] = new_feet + + def _advance_phase(self, dt: float, velocity: float): + self.phase = (self.phase + dt * velocity) % 1 + + def _phase_params(self, phase: float, stand_frac: float, depth: float, height: float): + if phase < stand_frac: + return phase / stand_frac, sine_curve, -depth + return (phase - stand_frac) / (1 - stand_frac), bezier_curve, height diff --git a/simulation/src/robot/kinematics.py b/simulation/src/robot/kinematics.py new file mode 100644 index 0000000..5e31003 --- /dev/null +++ b/simulation/src/robot/kinematics.py @@ -0,0 +1,130 @@ +import numpy as np +from typing import TypedDict, List + +import config + + +class BodyStateT(TypedDict): + omega: float + phi: float + psi: float + xm: float + ym: float + zm: float + feet: List[List[float]] + default_feet: List[List[float]] + px: float + py: float + pz: float + + +def rot_x(theta): + c = np.cos(theta) + s = np.sin(theta) + return np.array([[1, 0, 0, 0], [0, c, -s, 0], [0, s, c, 0], [0, 0, 0, 1]]) + + +def rot_y(theta): + c = np.cos(theta) + s = np.sin(theta) + return np.array([[c, 0, s, 0], [0, 1, 0, 0], [-s, 0, c, 0], [0, 0, 0, 1]]) + + +def rot_z(theta): + c = np.cos(theta) + s = np.sin(theta) + return np.array([[c, -s, 0, 0], [s, c, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + + +def rot(omega, phi, psi): + return rot_z(psi) @ rot_y(phi) @ rot_x(omega) + + +def translation(x, y, z): + return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]) + + +def transformation(omega, phi, psi, x, y, z): + return rot(omega, phi, psi) @ translation(x, y, z) + + +def get_transformation_matrix(body_state): + omega, phi, psi = body_state["omega"], body_state["phi"], body_state["psi"] + xm, ym, zm = body_state["xm"], body_state["ym"], body_state["zm"] + + return transformation(omega, phi, psi, xm, ym, zm) + + +class Kinematics: + def __init__(self, l1, l2, l3, l4, length, width): + self.l1 = float(l1) + self.l2 = float(l2) + self.l3 = float(l3) + self.l4 = float(l4) + self.length = float(length) + self.width = float(width) + self.deg2rad = np.pi / 180 + + self.mount_offsets = np.array([ + [self.length / 2, 0, self.width / 2], + [self.length / 2, 0, -self.width / 2], + [-self.length / 2, 0, self.width / 2], + [-self.length / 2, 0, -self.width / 2] + ]) + + self.inv_mount_rot = np.array([ + [0, 0, -1], + [0, 1, 0], + [1, 0, 0] + ]) + + def get_default_feet_pos(self): + feet = self.mount_offsets.copy() + feet[:, 1] = -1 + feet[:, 2] += np.array([self.l1, -self.l1, self.l1, -self.l1]) + return feet + + def inverse_kinematics(self, body_state): + roll, pitch, yaw = np.deg2rad(body_state["omega"]), np.deg2rad( + body_state["phi"]), np.deg2rad(body_state["psi"]) + xm, ym, zm = body_state["xm"], body_state["ym"], body_state["zm"] + + rot = self._rotation_matrix(roll, pitch, yaw) + inv_rot = rot.T + inv_tr = - \ + inv_rot @ np.array([xm, ym, zm]) + + angles = [] + for idx, foot_world in enumerate(body_state["feet"]): + foot_body = inv_rot @ foot_world + inv_tr + foot_local = self.inv_mount_rot @ (foot_body - + self.mount_offsets[idx]) + x_local = -foot_local[0] if idx % 2 else foot_local[0] + angles.extend(self._leg_ik(x_local, foot_local[1], foot_local[2])) + return angles + + def _leg_ik(self, x, y, z): + f = np.sqrt(max(0.0, x*x + y*y - self.l1*self.l1)) + g = f - self.l2 + h = np.sqrt(g*g + z*z) + + t1 = -np.arctan2(y, x) - np.arctan2(f, -self.l1) + + d = (h*h - self.l3*self.l3 - self.l4*self.l4) / (2*self.l3*self.l4) + d = max(-1.0, min(1.0, d)) + + t3 = np.arccos(d) + t2 = np.arctan2(z, g) - np.arctan2(self.l4*np.sin(t3), + self.l3 + self.l4*np.cos(t3)) + + return t1, t2, t3 + + def _rotation_matrix(self, roll, pitch, yaw): + cr, sr = np.cos(roll), np.sin(roll) + cp, sp = np.cos(pitch), np.sin(pitch) + cy, sy = np.cos(yaw), np.sin(yaw) + return np.array([ + [cp*cy, -cp*sy, sp], + [sr*sp*cy + cy*cr, -sr*sp*sy + cr*cy, -sr*cp], + [sr*sy - sp*cr*cy, sr*cy + sp*sy*cr, cr*cp] + ]) diff --git a/simulation/src/utils/__init__.py b/simulation/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/simulation/src/utils/gui.py b/simulation/src/utils/gui.py new file mode 100644 index 0000000..d753a9c --- /dev/null +++ b/simulation/src/utils/gui.py @@ -0,0 +1,120 @@ +import pybullet as p +import numpy as np + +from src.robot.kinematics import BodyStateT +from src.robot.gait import GaitStateT, GaitType, default_stand_frac, default_offset + + +class GUI: + def __init__(self, bot): + self.robot = bot + self.c_yaw = 10 + self.c_pitch = -17 + self.c_distance = 5 + + self.x_slider = p.addUserDebugParameter("x", -50, 50, 0) + self.y_slider = p.addUserDebugParameter("y", -50, 50, 0) + self.z_slider = p.addUserDebugParameter("z", -50, 50, 0) + self.yaw_slider = p.addUserDebugParameter( + "yaw", -np.pi / 4, np.pi / 4, 0) + self.pitch_slider = p.addUserDebugParameter( + "pitch", -np.pi / 4, np.pi / 4, 0) + self.roll_slider = p.addUserDebugParameter( + "roll", -np.pi / 4, np.pi / 4, 0) + + self.pivot_x_slider = p.addUserDebugParameter("pivot x", -50, 50, 0) + self.pivot_y_slider = p.addUserDebugParameter("pivot y", -50, 50, 0) + self.pivot_z_slider = p.addUserDebugParameter("pivot z", -50, 50, 0) + + self.step_x_slider = p.addUserDebugParameter("Step x", -50, 50, 0) + self.step_z_slider = p.addUserDebugParameter("Step z", -50, 50, 0) + self.angle_slider = p.addUserDebugParameter( + "Angle", -np.pi / 4, np.pi / 4, 0) + self.step_height_slider = p.addUserDebugParameter( + "Step height", 0, 50, 15) + self.step_depth_slider = p.addUserDebugParameter( + "Step depth", 0, 0.01, 0.002) + self.speed_slider = p.addUserDebugParameter("Speed", 0, 2, 1) + self.stand_frac_slider = p.addUserDebugParameter( + "Stand frac", 0, 1, 0.5) + + self.gait_type_slider = p.addUserDebugParameter( + "Gait Type", 0, len(GaitType) - 1, 0) + + # self.gait_type_slider = p.addUserDebugParameter("Gait Type", 0, len(GaitType) - 1, 0, paramType=p.GUI_ENUM, + # enumNames=[g.value for g in GaitType]) + self.last_gait_type = GaitType.TROT_GATE + + def update_gait_state(self, gait_state: GaitStateT): + gait_state["step_x"] = p.readUserDebugParameter(self.step_x_slider) + gait_state["step_z"] = p.readUserDebugParameter(self.step_z_slider) + gait_state["step_angle"] = p.readUserDebugParameter(self.angle_slider) + gait_state["step_height"] = p.readUserDebugParameter( + self.step_height_slider) + gait_state["step_depth"] = p.readUserDebugParameter( + self.step_depth_slider) + gait_state["step_velocity"] = p.readUserDebugParameter( + self.speed_slider) + gait_state["stand_frac"] = p.readUserDebugParameter( + self.stand_frac_slider) + gait_state["offset"] = default_offset[self.last_gait_type] + + def update_body_state(self, body_state: BodyStateT): + body_state["xm"] = p.readUserDebugParameter(self.x_slider) + body_state["ym"] = p.readUserDebugParameter(self.y_slider) + body_state["zm"] = p.readUserDebugParameter(self.z_slider) + body_state["omega"] = p.readUserDebugParameter(self.roll_slider) + body_state["phi"] = p.readUserDebugParameter(self.pitch_slider) + body_state["psi"] = p.readUserDebugParameter(self.yaw_slider) + body_state["px"] = p.readUserDebugParameter(self.pivot_x_slider) + body_state["py"] = p.readUserDebugParameter(self.pivot_y_slider) + body_state["pz"] = p.readUserDebugParameter(self.pivot_z_slider) + + def update(self): + gait_type = GaitType( + int(p.readUserDebugParameter(self.gait_type_slider))) + if gait_type != self.last_gait_type: + self.last_gait_type = gait_type + p.removeUserDebugItem(self.stand_frac_slider) + self.stand_frac_slider = p.addUserDebugParameter( + "Stand frac", 0, 1, default_stand_frac[gait_type]) + + quadruped_pos, _ = p.getBasePositionAndOrientation(self.robot) + p.resetDebugVisualizerCamera( + cameraDistance=self.c_distance, + cameraYaw=self.c_yaw, + cameraPitch=self.c_pitch, + cameraTargetPosition=quadruped_pos, + ) + + keys = p.getKeyboardEvents() + if keys.get(ord("j")): + self.c_yaw += 0.1 + if keys.get(ord("k")): + self.c_yaw -= 0.1 + if keys.get(ord("m")): + self.c_pitch += 0.1 + if keys.get(ord("i")): + self.c_pitch -= 0.1 + + if keys.get(ord("q")) or keys.get(27): + p.disconnect() + exit() + + self.position = np.array( + [ + p.readUserDebugParameter(self.x_slider), + p.readUserDebugParameter(self.y_slider), + p.readUserDebugParameter(self.z_slider), + ] + ) + + self.orientation = np.array( + [ + p.readUserDebugParameter(self.roll_slider), + p.readUserDebugParameter(self.pitch_slider), + p.readUserDebugParameter(self.yaw_slider), + ] + ) + + return self.position, self.orientation diff --git a/simulation/src/utils/xacro.py b/simulation/src/utils/xacro.py new file mode 100644 index 0000000..688d66e --- /dev/null +++ b/simulation/src/utils/xacro.py @@ -0,0 +1,638 @@ +#! /usr/bin/env python +# Copyright (c) 2008, Willow Garage, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the Willow Garage, Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +# Author: Stuart Glaser +# Modified by Saul Reynolds-Haertle Oct 14 2012 to remove ROS dependencies. + + +import os.path +import sys +import os +import getopt +import subprocess +from xml.dom.minidom import parse, parseString +import xml.dom +import re +import string + + +class XacroException(Exception): + pass + + +def isnumber(x): + return hasattr(x, '__int__') + +# Better pretty printing of xml +# Taken from http://ronrothman.com/public/leftbraned/xml-dom-minidom-toprettyxml-and-silly-whitespace/ + + +def fixed_writexml(self, writer, indent="", addindent="", newl=""): + # indent = current indentation + # addindent = indentation to add to higher levels + # newl = newline string + writer.write(indent+"<" + self.tagName) + + attrs = self._get_attributes() + a_names = attrs.keys() + # a_names.sort() + sorted(a_names) + + for a_name in a_names: + writer.write(" %s=\"" % a_name) + xml.dom.minidom._write_data(writer, attrs[a_name].value) + writer.write("\"") + if self.childNodes: + if len(self.childNodes) == 1 \ + and self.childNodes[0].nodeType == xml.dom.minidom.Node.TEXT_NODE: + writer.write(">") + self.childNodes[0].writexml(writer, "", "", "") + writer.write("%s" % (self.tagName, newl)) + return + writer.write(">%s" % (newl)) + for node in self.childNodes: + if node.nodeType is not xml.dom.minidom.Node.TEXT_NODE: # 3: + node.writexml(writer, indent+addindent, addindent, newl) + # node.writexml(writer,indent+addindent,addindent,newl) + writer.write("%s%s" % (indent, self.tagName, newl)) + else: + writer.write("/>%s" % (newl)) + + +# replace minidom's function with ours +xml.dom.minidom.Element.writexml = fixed_writexml + + +class Table: + def __init__(self, parent=None): + self.parent = parent + self.table = {} + + def __getitem__(self, key): + if key in self.table: + return self.table[key] + elif self.parent: + return self.parent[key] + else: + raise KeyError(key) + + def __setitem__(self, key, value): + self.table[key] = value + + def __contains__(self, key): + return \ + key in self.table or \ + (self.parent and key in self.parent) + + +class QuickLexer(object): + def __init__(self, **res): + self.str = "" + self.top = None + self.res = [] + for k, v in res.items(): + self.__setattr__(k, len(self.res)) + self.res.append(v) + + def lex(self, str): + self.str = str + self.top = None + self.next() + + def peek(self): + return self.top + + def next(self): + result = self.top + self.top = None + for i in range(len(self.res)): + m = re.match(self.res[i], self.str) + if m: + self.top = (i, m.group(0)) + self.str = self.str[m.end():] + break + return result + + +def first_child_element(elt): + c = elt.firstChild + while c: + if c.nodeType == xml.dom.Node.ELEMENT_NODE: + return c + c = c.nextSibling + return None + + +def next_sibling_element(elt): + c = elt.nextSibling + while c: + if c.nodeType == xml.dom.Node.ELEMENT_NODE: + return c + c = c.nextSibling + return None + +# Pre-order traversal of the elements + + +def next_element(elt): + child = first_child_element(elt) + if child: + return child + while elt and elt.nodeType == xml.dom.Node.ELEMENT_NODE: + next = next_sibling_element(elt) + if next: + return next + elt = elt.parentNode + return None + +# Pre-order traversal of all the nodes + + +def next_node(node): + if node.firstChild: + return node.firstChild + while node: + if node.nextSibling: + return node.nextSibling + node = node.parentNode + return None + + +def child_elements(elt): + c = elt.firstChild + while c: + if c.nodeType == xml.dom.Node.ELEMENT_NODE: + yield c + c = c.nextSibling + + +all_includes = [] +# @throws XacroException if a parsing error occurs with an included document + + +def process_includes(doc, base_dir): + namespaces = {} + previous = doc.documentElement + elt = next_element(previous) + while elt: + if elt.tagName == 'include' or elt.tagName == 'xacro:include': + # print("elt.getAttribute('filename'):", elt.getAttribute('filename')) + filename = eval_text(elt.getAttribute('filename'), {}) + # print("filename:",filename) + if not os.path.isabs(filename): + filename = os.path.join(base_dir, filename) + f = None + try: + try: + f = open(filename) + except IOError as e: + print(elt) + raise XacroException( + "included file \"%s\" could not be opened: %s" % (filename, str(e))) + try: + global all_includes + all_includes.append(filename) + included = parse(f) + except Exception as e: + raise XacroException( + "included file [%s] generated an error during XML parsing: %s" % (filename, str(e))) + finally: + if f: + f.close() + + # Replaces the include tag with the elements of the included file + for c in child_elements(included.documentElement): + elt.parentNode.insertBefore(c.cloneNode(1), elt) + elt.parentNode.removeChild(elt) + elt = None + + # Grabs all the declared namespaces of the included document + for name, value in included.documentElement.attributes.items(): + if name.startswith('xmlns:'): + namespaces[name] = value + else: + previous = elt + + elt = next_element(previous) + + # Makes sure the final document declares all the namespaces of the included documents. + for k, v in namespaces.items(): + doc.documentElement.setAttribute(k, v) + +# Returns a dictionary: { macro_name => macro_xml_block } + + +def grab_macros(doc): + macros = {} + + previous = doc.documentElement + elt = next_element(previous) + while elt: + if elt.tagName == 'macro' or elt.tagName == 'xacro:macro': + name = elt.getAttribute('name') + + macros[name] = elt + macros['xacro:' + name] = elt + + elt.parentNode.removeChild(elt) + elt = None + else: + previous = elt + + elt = next_element(previous) + return macros + +# Returns a Table of the properties + + +def grab_properties(doc): + table = Table() + + previous = doc.documentElement + elt = next_element(previous) + while elt: + if elt.tagName == 'property' or elt.tagName == 'xacro:property': + name = elt.getAttribute('name') + value = None + + if elt.hasAttribute('value'): + value = elt.getAttribute('value') + else: + name = '**' + name + value = elt # debug + + bad = string.whitespace + "${}" + has_bad = False + for b in bad: + if b in name: + has_bad = True + break + + if has_bad: + sys.stderr.write('Property names may not have whitespace, ' + + '"{", "}", or "$" : "' + name + '"') + else: + table[name] = value + + elt.parentNode.removeChild(elt) + elt = None + else: + previous = elt + + elt = next_element(previous) + return table + + +def eat_ignore(lex): + while lex.peek() and lex.peek()[0] == lex.IGNORE: + lex.next() + + +def eval_lit(lex, symbols): + eat_ignore(lex) + if lex.peek()[0] == lex.NUMBER: + return float(lex.next()[1]) + if lex.peek()[0] == lex.SYMBOL: + try: + value = symbols[lex.next()[1]] + except KeyError as ex: + # sys.stderr.write("Could not find symbol: %s\n" % str(ex)) + raise XacroException("Property wasn't defined: %s" % str(ex)) + if not (isnumber(value) or isinstance(value, (str, str))): + print([value], isinstance(value, str), type(value)) + raise XacroException("WTF2") + try: + return int(value) + except: + try: + return float(value) + except: + return value + raise XacroException("Bad literal") + + +def eval_factor(lex, symbols): + eat_ignore(lex) + + neg = 1 + if lex.peek()[1] == '-': + lex.next() + neg = -1 + + if lex.peek()[0] in [lex.NUMBER, lex.SYMBOL]: + return neg * eval_lit(lex, symbols) + if lex.peek()[0] == lex.LPAREN: + lex.next() + eat_ignore(lex) + result = eval_expr(lex, symbols) + eat_ignore(lex) + if lex.next()[0] != lex.RPAREN: + raise XacroException("Unmatched left paren") + eat_ignore(lex) + return neg * result + + raise XacroException("Misplaced operator") + + +def eval_term(lex, symbols): + eat_ignore(lex) + + result = 0 + if lex.peek()[0] in [lex.NUMBER, lex.SYMBOL, lex.LPAREN] \ + or lex.peek()[1] == '-': + result = eval_factor(lex, symbols) + + eat_ignore(lex) + while lex.peek() and lex.peek()[1] in ['*', '/']: + op = lex.next()[1] + n = eval_factor(lex, symbols) + + if op == '*': + result = float(result) * float(n) + elif op == '/': + result = float(result) / float(n) + else: + raise XacroException("WTF") + eat_ignore(lex) + return result + + +def eval_expr(lex, symbols): + eat_ignore(lex) + + op = None + if lex.peek()[0] == lex.OP: + op = lex.next()[1] + if not op in ['+', '-']: + raise XacroException("Invalid operation. Must be '+' or '-'") + + result = eval_term(lex, symbols) + if op == '-': + result = -float(result) + + eat_ignore(lex) + while lex.peek() and lex.peek()[1] in ['+', '-']: + op = lex.next()[1] + n = eval_term(lex, symbols) + + if op == '+': + result = float(result) + float(n) + if op == '-': + result = float(result) - float(n) + eat_ignore(lex) + return result + + +def eval_extension(s): + # if s == '$(cwd)': + return os.getcwd() + # try: + # return substitution_args.resolve_args(s, context=substitution_args_context, resolve_anon=False) + # except substitution_args.ArgException as e: + # raise XacroException("Undefined substitution argument", exc=e) + # except ResourceNotFound as e: + # raise XacroException("resource not found:", exc=e) + + +def eval_text(text, symbols): + def handle_expr(s): + lex = QuickLexer(IGNORE=r"\s+", + NUMBER=r"(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?", + SYMBOL=r"[a-zA-Z_]\w*", + OP=r"[\+\-\*/^]", + LPAREN=r"\(", + RPAREN=r"\)") + lex.lex(s) + return eval_expr(lex, symbols) + + def handle_extension(s): + # print("handle_extension", s) + return eval_extension("$(%s)" % s) + + results = [] + lex = QuickLexer(DOLLAR_DOLLAR_BRACE=r"\$\$+\{", + EXPR=r"\$\{[^\}]*\}", + EXTENSION=r"\$\([^\)]*\)", + TEXT=r"([^\$]|\$[^{(]|\$$)+") + lex.lex(text) + while lex.peek(): + if lex.peek()[0] == lex.EXPR: + results.append(handle_expr(lex.next()[1][2:-1])) + # print("1:", results) + elif lex.peek()[0] == lex.EXTENSION: + results.append(handle_extension(lex.next()[1][2:-1])) + # print("2:",results) + elif lex.peek()[0] == lex.TEXT: + results.append(lex.next()[1]) + # print("3:",results) + elif lex.peek()[0] == lex.DOLLAR_DOLLAR_BRACE: + results.append(lex.next()[1][1:]) + # print("4:",results) + # print(results) + return ''.join(map(str, results)) + +# Expands macros, replaces properties, and evaluates expressions + + +def eval_all(root, macros, symbols): + # Evaluates the attributes for the root node + for at in root.attributes.items(): + result = eval_text(at[1], symbols) + root.setAttribute(at[0], result) + + previous = root + node = next_node(previous) + while node: + if node.nodeType == xml.dom.Node.ELEMENT_NODE: + if node.tagName in macros: + body = macros[node.tagName].cloneNode(deep=True) + params = body.getAttribute('params').split() + + # Expands the macro + scoped = Table(symbols) + for name, value in node.attributes.items(): + if not name in params: + raise XacroException("Invalid parameter \"%s\" while expanding macro \"%s\"" % + (str(name), str(node.tagName))) + params.remove(name) + scoped[name] = eval_text(value, symbols) + + # Pulls out the block arguments, in order + cloned = node.cloneNode(deep=True) + eval_all(cloned, macros, symbols) + block = cloned.firstChild + for param in params[:]: + if param[0] == '*': + while block and block.nodeType != xml.dom.Node.ELEMENT_NODE: + block = block.nextSibling + if not block: + raise XacroException( + "Not enough blocks while evaluating macro %s" % str(node.tagName)) + params.remove(param) + scoped[param] = block + block = block.nextSibling + + if params: + raise XacroException("Some parameters were not set for macro %s" % + str(node.tagName)) + eval_all(body, macros, scoped) + + # Replaces the macro node with the expansion + for e in list(child_elements(body)): # Ew + node.parentNode.insertBefore(e, node) + node.parentNode.removeChild(node) + + node = None + elif node.tagName == 'insert_block' or node.tagName == 'xacro:insert_block': + name = node.getAttribute('name') + + if ("**" + name) in symbols: + # Multi-block + block = symbols['**' + name] + + for e in list(child_elements(block)): + node.parentNode.insertBefore( + e.cloneNode(deep=True), node) + node.parentNode.removeChild(node) + elif ("*" + name) in symbols: + # Single block + block = symbols['*' + name] + + node.parentNode.insertBefore( + block.cloneNode(deep=True), node) + node.parentNode.removeChild(node) + else: + raise XacroException( + "Block \"%s\" was never declared" % name) + + node = None + else: + # Evals the attributes + for at in node.attributes.items(): + result = eval_text(at[1], symbols) + node.setAttribute(at[0], result) + previous = node + elif node.nodeType == xml.dom.Node.TEXT_NODE: + node.data = eval_text(node.data, symbols) + previous = node + else: + previous = node + + node = next_node(previous) + return macros + +# Expands everything except includes + + +def eval_self_contained(doc): + macros = grab_macros(doc) + symbols = grab_properties(doc) + eval_all(doc.documentElement, macros, symbols) + + +def print_usage(exit_code=0): + print("Usage: %s [-o ] " % 'xacro.py') + print(" %s --deps Prints dependencies" % 'xacro.py') + print(" %s --includes Only evalutes includes" % 'xacro.py') + sys.exit(exit_code) + + +def main(): + # print("dir:",os.path.dirname(sys.argv[2])) + # sys.exit(0) + try: + opts, args = getopt.gnu_getopt( + sys.argv[1:], "ho:", ['deps', 'includes']) + except getopt.GetoptError as err: + print(str(err)) + print_usage(2) + + just_deps = False + just_includes = False + + output = sys.stdout + for o, a in opts: + if o == '-h': + print_usage(0) + elif o == '-o': + output = open(a, 'w') + elif o == '--deps': + just_deps = True + elif o == '--includes': + just_includes = True + + if len(args) < 1: + print("No input given") + print_usage(2) + + f = open(args[0]) + # print(args[0]) + # sys.exit(0) + doc = None + try: + doc = parse(f) + except xml.parsers.expat.ExpatError: + sys.stderr.write("Expat parsing error. Check that:\n") + sys.stderr.write(" - Your XML is correctly formed\n") + sys.stderr.write(" - You have the xacro xmlns declaration: " + + "xmlns:xacro=\"http://www.ros.org/wiki/xacro\"\n") + sys.stderr.write("\n") + raise + finally: + f.close() + + # print(type(doc)) + # print(opts, args) + # print("dir:",os.path.dirname(sys.argv[2])) + # sys.exit(0) + process_includes(doc, os.path.dirname(sys.argv[2])) + if just_deps: + for inc in all_includes: + sys.stdout.write(inc + " ") + sys.stdout.write("\n") + elif just_includes: + doc.writexml(output) + print() + else: + eval_self_contained(doc) + banner = [xml.dom.minidom.Comment(c) for c in + [" %s " % ('='*83), + " | This document was autogenerated by xacro from %-30s | " % args[0], + " | EDITING THIS FILE BY HAND IS NOT RECOMMENDED %-30s | " % "", + " %s " % ('='*83)]] + first = doc.firstChild + for comment in banner: + doc.insertBefore(comment, first) + + output.write(doc.toprettyxml(indent=' ')) + # doc.writexml(output, newl = "\n") + print() + + +if __name__ == "__main__": + main() diff --git a/simulation/test.py b/simulation/test.py index 1c39574..89dfe25 100644 --- a/simulation/test.py +++ b/simulation/test.py @@ -1,4 +1,4 @@ -from simulation.environment import QuadrupedEnv +from src.envs.environment import QuadrupedEnv from training.model import SimpleNN import resources as resources diff --git a/simulation/train.py b/simulation/train.py index dbfc985..83f7cb6 100644 --- a/simulation/train.py +++ b/simulation/train.py @@ -1,10 +1,11 @@ -from simulation.environment import QuadrupedEnv +from src.envs.environment import QuadrupedEnv from training.trainer import Trainer import resources as resources render = True + def main(): env = QuadrupedEnv(resources.getDataPath() + "/spot.urdf") trainer = Trainer(env, render)