🤖 Adds plane
This commit is contained in:
@@ -4,6 +4,8 @@ import numpy as np
|
|||||||
from simulation.robot import QuadrupedRobot
|
from simulation.robot import QuadrupedRobot
|
||||||
|
|
||||||
|
|
||||||
|
roll_pitch_reward_weight = 0.1
|
||||||
|
|
||||||
class QuadrupedEnv:
|
class QuadrupedEnv:
|
||||||
def __init__(self, urdf_path):
|
def __init__(self, urdf_path):
|
||||||
p.connect(p.GUI)
|
p.connect(p.GUI)
|
||||||
@@ -11,11 +13,15 @@ class QuadrupedEnv:
|
|||||||
p.setGravity(0, 0, -9.8)
|
p.setGravity(0, 0, -9.8)
|
||||||
p.setTimeStep(1 / 240)
|
p.setTimeStep(1 / 240)
|
||||||
|
|
||||||
|
self.plane_id = p.loadURDF("plane.urdf")
|
||||||
|
|
||||||
self.robot = QuadrupedRobot(urdf_path)
|
self.robot = QuadrupedRobot(urdf_path)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
p.resetSimulation()
|
p.resetSimulation()
|
||||||
|
p.setGravity(0, 0, -9.8)
|
||||||
|
self.plane_id = p.loadURDF("plane.urdf")
|
||||||
self.robot.load()
|
self.robot.load()
|
||||||
return self.robot.get_observation()
|
return self.robot.get_observation()
|
||||||
|
|
||||||
@@ -28,8 +34,11 @@ class QuadrupedEnv:
|
|||||||
return obs, reward, done
|
return obs, reward, done
|
||||||
|
|
||||||
def calculate_reward(self, observation):
|
def calculate_reward(self, observation):
|
||||||
# Define your reward function here
|
reward = 0
|
||||||
return 0
|
reward += (
|
||||||
|
-(abs(observation[0]) + abs(observation[1])) * roll_pitch_reward_weight
|
||||||
|
)
|
||||||
|
return reward
|
||||||
|
|
||||||
def is_done(self, observation):
|
def is_done(self, observation):
|
||||||
roll, pitch = observation[0:2]
|
roll, pitch = observation[0:2]
|
||||||
|
|||||||
+2
-1
@@ -3,10 +3,11 @@ from training.trainer import Trainer
|
|||||||
|
|
||||||
import resources as resources
|
import resources as resources
|
||||||
|
|
||||||
|
render = True
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
env = QuadrupedEnv(resources.getDataPath() + "/spot.urdf")
|
env = QuadrupedEnv(resources.getDataPath() + "/spot.urdf")
|
||||||
trainer = Trainer(env)
|
trainer = Trainer(env, render)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from time import sleep
|
||||||
import torch
|
import torch
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import pybullet as p
|
import pybullet as p
|
||||||
@@ -11,8 +12,10 @@ Experience = namedtuple("Experience", ["observation", "action", "reward", "log_p
|
|||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(self, env):
|
|
||||||
|
def __init__(self, env, render):
|
||||||
self.env = env
|
self.env = env
|
||||||
|
self.should_render = render
|
||||||
self.model = SimpleNN(
|
self.model = SimpleNN(
|
||||||
input_size=env.robot.get_observation().shape[0],
|
input_size=env.robot.get_observation().shape[0],
|
||||||
output_size=p.getNumJoints(env.robot.robot_id),
|
output_size=p.getNumJoints(env.robot.robot_id),
|
||||||
@@ -30,6 +33,9 @@ class Trainer:
|
|||||||
observation, reward, done = self.env.step(action)
|
observation, reward, done = self.env.step(action)
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
|
|
||||||
|
if self.should_render:
|
||||||
|
sleep(0.005)
|
||||||
|
|
||||||
# Train the neural network
|
# Train the neural network
|
||||||
# loss = self.compute_loss(observation, action, reward)
|
# loss = self.compute_loss(observation, action, reward)
|
||||||
# self.optimizer.zero_grad()
|
# self.optimizer.zero_grad()
|
||||||
@@ -42,9 +48,7 @@ class Trainer:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
observation_tensor = torch.tensor(observation, dtype=torch.float32)
|
observation_tensor = torch.tensor(observation, dtype=torch.float32)
|
||||||
action = self.model(observation_tensor)
|
action = self.model(observation_tensor)
|
||||||
return np.array(
|
return action.numpy()
|
||||||
[-0.4, -1.5, 6, 0.4, -1.5, 6, -0.4, -1.5, 6, 0.4, -1.5, 6]
|
|
||||||
) # action.numpy()
|
|
||||||
|
|
||||||
def compute_loss(self, observation, action, reward):
|
def compute_loss(self, observation, action, reward):
|
||||||
# Define your loss function here
|
# Define your loss function here
|
||||||
|
|||||||
Reference in New Issue
Block a user