Commit 3305a09f authored by szymon's avatar szymon
Browse files

Bullet constraint solver as an option, not default due to memory problems

parent 2528b6ee
......@@ -11,10 +11,10 @@ typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Ma
class HexapodClosedLoopEnv : public HexapodEnv {
public:
explicit HexapodClosedLoopEnv(double reset_noise_scale, bool init_reset, bool observe_velocities = false,
explicit HexapodClosedLoopEnv(double reset_noise_scale, bool init_reset, bool use_bullet = false, bool observe_velocities = false,
float step_duration = 0.015,
float simulation_duration = 5, float min_action_value = -1, float max_action_value = 1) :
HexapodEnv(false, step_duration, simulation_duration, min_action_value, max_action_value),
HexapodEnv(false, use_bullet, step_duration, simulation_duration, min_action_value, max_action_value),
_reset_noise_scale{reset_noise_scale},
observe_velocities{observe_velocities},
observation_space_size{observe_velocities ? 36 : 18}{
......
......@@ -38,7 +38,7 @@ void load_and_init_robot2() {
class HexapodEnv : public Env
{
public:
explicit HexapodEnv(bool init_reset, float step_duration = 0.015, float simulation_duration = 5, float min_action_value = -1, float max_action_value = 1):
explicit HexapodEnv(bool init_reset, bool use_bullet = false, float step_duration = 0.015, float simulation_duration = 5, float min_action_value = -1, float max_action_value = 1):
Env(),
step_duration{step_duration},
simulation_duration{simulation_duration},
......@@ -55,8 +55,10 @@ public:
simulation->set_graphics(std::make_shared<robot_dart::graphics::Graphics>(simulation->world()));
std::static_pointer_cast<robot_dart::graphics::Graphics>(simulation->graphics())->look_at({0.5, 3., 0.75}, {0.5, 0., 0.2});
#endif
simulation->world()->getConstraintSolver()->setCollisionDetector(
dart::collision::BulletCollisionDetector::create());
if(use_bullet) {
simulation->world()->getConstraintSolver()->setCollisionDetector(
dart::collision::BulletCollisionDetector::create());
}
simulation->add_floor(20.);
......
......@@ -104,6 +104,8 @@ int main(int argc, char **argv)
args::Flag verbose(parser,"verbose", "output additional logs to the console",{'v',"verbose"});
args::Flag resume(parser,"resume", "flag signalling resuming",{'r',"resume"});
args::Flag use_bullet(parser,"use_bullet", "Replace default constraint solver with Bullet",{"bullet","use_bullet","bullet_solver"});
args::ValueFlag<double> duration(parser, "duration", "The total duration of played animation [seconds]", {"duration","du"},5.);
args::ValueFlag<int> threads(parser, "num threads", "Number of threads used in training", {'j',"jobs","threads","n_threads","num_threads","nt"},1);
......@@ -170,18 +172,18 @@ int main(int argc, char **argv)
for (int i =0; i<threads.Get(); ++i){
//TODO: environment selection should be recoverable from serialization as well
if(closed_loop){
envs.push_back(std::make_shared<HexapodClosedLoopEnv>(reset_noise_scale.Get(),!multi_env));
envs.push_back(std::make_shared<HexapodClosedLoopEnv>(reset_noise_scale.Get(),!multi_env, use_bullet));
} else {
envs.push_back(std::make_shared<HexapodEnv>(!multi_env));
envs.push_back(std::make_shared<HexapodEnv>(!multi_env, use_bullet));
}
}
wrapped_env = std::make_unique<VecEnv>(envs);
} else {
//TODO: environment selection should be recoverable from serialization as well
if(closed_loop){
wrapped_env = std::make_unique<HexapodClosedLoopEnv>(reset_noise_scale.Get(),!multi_env);
wrapped_env = std::make_unique<HexapodClosedLoopEnv>(reset_noise_scale.Get(),!multi_env, use_bullet);
} else {
wrapped_env = std::make_unique<HexapodEnv>(!multi_env);
wrapped_env = std::make_unique<HexapodEnv>(!multi_env, use_bullet);
}
}
......
......@@ -48,6 +48,7 @@ void simulate_steps(const int steps, const int num_threads){
}
}
//TODO: fix seed
//this test is non-deterministic
//if it fails, there is a problem somewhere
//if it succeeds it doesn't give a guarantee of correctness
......
......@@ -36,11 +36,9 @@ def build(bld):
bld.env.INCLUDES_PROTOBUF = '/workspace/include/'
bld.program(features = 'cxx',
#source = 'cpp/tf_exp.cpp',
source = 'ppo2.cpp',
includes = './cpp . ../../',
uselib = 'ROBOTDART ABSL TBB BOOST EIGEN PTHREAD MPI DART DART_GRAPHIC PROTOBUF TF',
#use = 'sferes2',
defines = [], # defines = ['GRAPHIC'],
target = 'ppo_cpp')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment