diff --git a/simulation/simulation/environment.py b/simulation/simulation/environment.py index d5805f9..7927558 100644 --- a/simulation/simulation/environment.py +++ b/simulation/simulation/environment.py @@ -6,23 +6,29 @@ from simulation.robot import QuadrupedRobot roll_pitch_reward_weight = 0.1 + class QuadrupedEnv: def __init__(self, urdf_path): p.connect(p.GUI) p.setAdditionalSearchPath(pybullet_data.getDataPath()) p.setGravity(0, 0, -9.8) p.setTimeStep(1 / 240) + self.urdf_path = urdf_path - self.plane_id = p.loadURDF("plane.urdf") + self.setupWorld() - self.robot = QuadrupedRobot(urdf_path) - self.reset() + self.envStartState = p.saveState() - def reset(self): + def setupWorld(self): p.resetSimulation() p.setGravity(0, 0, -9.8) + self.plane_id = p.loadURDF("plane.urdf") - self.robot.load() + + self.robot = QuadrupedRobot(self.urdf_path) + + def reset(self): + p.restoreState(self.envStartState) return self.robot.get_observation() def step(self, action): diff --git a/simulation/simulation/robot.py b/simulation/simulation/robot.py index 8dee7a2..0be046a 100644 --- a/simulation/simulation/robot.py +++ b/simulation/simulation/robot.py @@ -5,7 +5,7 @@ import numpy as np class QuadrupedRobot: def __init__(self, urdf_path): self.urdf_path = urdf_path - self.robot_id = None + self.load() def load(self): position = [0, 0, 0.3]