toddlerbot.locomotion package¶
Submodules¶
toddlerbot.locomotion.cartwheel_env module¶
Cartwheel locomotion environment for ToddlerBot.
This module provides the CartwheelEnv class for training ToddlerBot in cartwheel movements. The environment extends MJXEnv with cartwheel-specific motion references and command sampling.
toddlerbot.locomotion.crawl_env module¶
Crawling locomotion environment for ToddlerBot.
This module provides the CrawlEnv class for training ToddlerBot in crawling movements. The environment extends MJXEnv with crawl-specific motion references and command sampling.
toddlerbot.locomotion.mjx_config module¶
Configuration classes for MJX environments.
This module defines configuration dataclasses for MJX-based locomotion environments, including simulation parameters, terrain settings, observation configurations, rewards, actions, domain randomization, and noise settings.
- class toddlerbot.locomotion.mjx_config.MJXConfig¶
Bases:
object
Configuration class for the MJX environment.
- class ActionConfig(action_parts: List[str] = <factory>, action_scale: float = 0.25, contact_force_threshold: float = 1.0, n_steps_delay: int = 1, n_frames: int = 4, cycle_time: float = 0.72)¶
Bases:
object
- action_parts: List[str]¶
- action_scale: float = 0.25¶
- contact_force_threshold: float = 1.0¶
- cycle_time: float = 0.72¶
- n_frames: int = 4¶
- n_steps_delay: int = 1¶
- class CommandsConfig(resample_time: float = 3.0, zero_chance: float = 0.2, turn_chance: float = 0.2, command_obs_indices: List[int] = <factory>, command_range: List[List[float]] = <factory>, deadzone: List[float] = <factory>)¶
Bases:
object
- command_obs_indices: List[int]¶
- command_range: List[List[float]]¶
- deadzone: List[float]¶
- resample_time: float = 3.0¶
- turn_chance: float = 0.2¶
- zero_chance: float = 0.2¶
- class DomainRandConfig(add_domain_rand: bool = True, rand_init_state_indices: List[int] = <factory>, add_push: bool = False, add_head_pose: bool = False, backlash_activation: float = 0.1, backlash_range: List[float] = <factory>, torso_roll_range: List[float] = <factory>, torso_pitch_range: List[float] = <factory>, arm_joint_pos_range: List[float] = <factory>, friction_range: List[float] = <factory>, damping_range: List[float] = <factory>, armature_range: List[float] = <factory>, frictionloss_range: List[float] = <factory>, body_mass_range: List[float] = <factory>, hand_mass_range: List[float] = <factory>, other_mass_range: List[float] = <factory>, kp_range: List[float] = <factory>, kd_range: List[float] = <factory>, tau_max_range: List[float] = <factory>, q_dot_tau_max_range: List[float] = <factory>, q_dot_max_range: List[float] = <factory>, kd_min_range: List[float] = <factory>, tau_brake_max_range: List[float] = <factory>, tau_q_dot_max_range: List[float] = <factory>, passive_active_ratio_range: List[float] = <factory>, push_interval_s: float = 2.0, push_duration_s: float = 0.2, push_torso_range: List[float] = <factory>, push_other_range: List[float] = <factory>)¶
Bases:
object
- add_domain_rand: bool = True¶
- add_head_pose: bool = False¶
- add_push: bool = False¶
- arm_joint_pos_range: List[float]¶
- armature_range: List[float]¶
- backlash_activation: float = 0.1¶
- backlash_range: List[float]¶
- body_mass_range: List[float]¶
- damping_range: List[float]¶
- friction_range: List[float]¶
- frictionloss_range: List[float]¶
- hand_mass_range: List[float]¶
- kd_min_range: List[float]¶
- kd_range: List[float]¶
- kp_range: List[float]¶
- other_mass_range: List[float]¶
- passive_active_ratio_range: List[float]¶
- push_duration_s: float = 0.2¶
- push_interval_s: float = 2.0¶
- push_other_range: List[float]¶
- push_torso_range: List[float]¶
- q_dot_max_range: List[float]¶
- q_dot_tau_max_range: List[float]¶
- rand_init_state_indices: List[int]¶
- tau_brake_max_range: List[float]¶
- tau_max_range: List[float]¶
- tau_q_dot_max_range: List[float]¶
- torso_pitch_range: List[float]¶
- torso_roll_range: List[float]¶
- class NoiseConfig(level: float = 0.05, dof_pos: float = 1.0, dof_vel: float = 2.0, gyro_fc: float = 0.35, gyro_std: float = 0.25, gyro_bias_walk_std: float = 0.0002, gyro_white_std: float = 0.0, quat_fc: float = 0.25, quat_std: float = 0.1, quat_bias_walk_std: float = 0.0001, quat_white_std: float = 0.0, gyro_amp_min: float = 0.8, gyro_amp_max: float = 2.0, quat_amp_min: float = 0.8, quat_amp_max: float = 1.2)¶
Bases:
object
- dof_pos: float = 1.0¶
- dof_vel: float = 2.0¶
- gyro_amp_max: float = 2.0¶
- gyro_amp_min: float = 0.8¶
- gyro_bias_walk_std: float = 0.0002¶
- gyro_fc: float = 0.35¶
- gyro_std: float = 0.25¶
- gyro_white_std: float = 0.0¶
- level: float = 0.05¶
- quat_amp_max: float = 1.2¶
- quat_amp_min: float = 0.8¶
- quat_bias_walk_std: float = 0.0001¶
- quat_fc: float = 0.25¶
- quat_std: float = 0.1¶
- quat_white_std: float = 0.0¶
- class ObsConfig(frame_stack: int = 15, c_frame_stack: int = 15, num_single_obs: int = 84, num_single_privileged_obs: int = 151)¶
Bases:
object
- c_frame_stack: int = 15¶
- frame_stack: int = 15¶
- num_single_obs: int = 84¶
- num_single_privileged_obs: int = 151¶
- class ObsScales(lin_vel: float = 2.0, ang_vel: float = 1.0, dof_pos: float = 1.0, dof_vel: float = 0.05, quat: float = 1.0, actuator_force: float = 0.1)¶
Bases:
object
- actuator_force: float = 0.1¶
- ang_vel: float = 1.0¶
- dof_pos: float = 1.0¶
- dof_vel: float = 0.05¶
- lin_vel: float = 2.0¶
- quat: float = 1.0¶
- class RewardScales(torso_pos_xy: float | Dict[str, float] = 0.0, torso_pos_z: float | Dict[str, float] = 0.0, torso_quat: float | Dict[str, float] = 0.0, torso_roll: float | Dict[str, float] = 0.0, torso_pitch: float | Dict[str, float] = 0.0, lin_vel_xy: float | Dict[str, float] = 0.0, lin_vel_z: float | Dict[str, float] = 0.0, ang_vel_xy: float | Dict[str, float] = 0.0, ang_vel_z: float | Dict[str, float] = 0.0, motor_pos: float | Dict[str, float] = 0.0, motor_torque: float | Dict[str, float] = 0.0, energy: float | Dict[str, float] = 0.0, action_rate: float | Dict[str, float] = 0.0, feet_contact: float | Dict[str, float] = 0.0, contact_number: float | Dict[str, float] = 0.0, collision: float | Dict[str, float] = 0.0, survival: float | Dict[str, float] = 0.0, feet_air_time: float | Dict[str, float] = 0.0, feet_distance: float | Dict[str, float] = 0.0, feet_slip: float | Dict[str, float] = 0.0, feet_clearance: float | Dict[str, float] = 0.0, stand_still: float | Dict[str, float] = 0.0, align_ground: float | Dict[str, float] = 0.0, body_quat: float | Dict[str, float] = 0.0, site_pos: float | Dict[str, float] = 0.0, body_lin_vel: float | Dict[str, float] = 0.0, body_ang_vel: float | Dict[str, float] = 0.0)¶
Bases:
object
- action_rate: float | Dict[str, float] = 0.0¶
- align_ground: float | Dict[str, float] = 0.0¶
- ang_vel_xy: float | Dict[str, float] = 0.0¶
- ang_vel_z: float | Dict[str, float] = 0.0¶
- body_ang_vel: float | Dict[str, float] = 0.0¶
- body_lin_vel: float | Dict[str, float] = 0.0¶
- body_quat: float | Dict[str, float] = 0.0¶
- collision: float | Dict[str, float] = 0.0¶
- contact_number: float | Dict[str, float] = 0.0¶
- energy: float | Dict[str, float] = 0.0¶
- feet_air_time: float | Dict[str, float] = 0.0¶
- feet_clearance: float | Dict[str, float] = 0.0¶
- feet_contact: float | Dict[str, float] = 0.0¶
- feet_distance: float | Dict[str, float] = 0.0¶
- feet_slip: float | Dict[str, float] = 0.0¶
- lin_vel_xy: float | Dict[str, float] = 0.0¶
- lin_vel_z: float | Dict[str, float] = 0.0¶
- motor_pos: float | Dict[str, float] = 0.0¶
- motor_torque: float | Dict[str, float] = 0.0¶
- reset()¶
Reset all reward scales to zero.
- site_pos: float | Dict[str, float] = 0.0¶
- stand_still: float | Dict[str, float] = 0.0¶
- survival: float | Dict[str, float] = 0.0¶
- torso_pitch: float | Dict[str, float] = 0.0¶
- torso_pos_xy: float | Dict[str, float] = 0.0¶
- torso_pos_z: float | Dict[str, float] = 0.0¶
- torso_quat: float | Dict[str, float] = 0.0¶
- torso_roll: float | Dict[str, float] = 0.0¶
- class RewardsConfig(healthy_z_range: List[float] = <factory>, pos_tracking_sigma: float = 200.0, rot_tracking_sigma: float = 20.0, lin_vel_tracking_sigma: float = 200.0, ang_vel_tracking_sigma: float = 0.5, min_feet_y_dist: float = 0.07, max_feet_y_dist: float = 0.13, torso_roll_range: List[float] = <factory>, torso_pitch_range: List[float] = <factory>, add_regularization: bool = True, use_exp_reward: bool = True, leg_weight: float = 1.0, arm_weight: float = 1.0, neck_weight: float = 0.2, waist_weight: float = 0.2)¶
Bases:
object
- add_regularization: bool = True¶
- ang_vel_tracking_sigma: float = 0.5¶
- arm_weight: float = 1.0¶
- healthy_z_range: List[float]¶
- leg_weight: float = 1.0¶
- lin_vel_tracking_sigma: float = 200.0¶
- max_feet_y_dist: float = 0.13¶
- min_feet_y_dist: float = 0.07¶
- neck_weight: float = 0.2¶
- pos_tracking_sigma: float = 200.0¶
- rot_tracking_sigma: float = 20.0¶
- torso_pitch_range: List[float]¶
- torso_roll_range: List[float]¶
- use_exp_reward: bool = True¶
- waist_weight: float = 0.2¶
- class SimConfig(timestep: float = 0.005, solver: int = 2, self_contact_pairs: List[List[str]] | None = None)¶
Bases:
object
- self_contact_pairs: List[List[str]] | None = None¶
- solver: int = 2¶
- timestep: float = 0.005¶
- class TerrainConfig(tile_width: float = 4.0, tile_length: float = 4.0, resolution_per_meter: int = 16, random_spawn: bool = False, manual_map: List[List[str]] = <factory>, robot_collision_geom_names: Optional[List[str]] = <factory>)¶
Bases:
object
- manual_map: List[List[str]]¶
- random_spawn: bool = False¶
- resolution_per_meter: int = 16¶
- robot_collision_geom_names: List[str] | None¶
- tile_length: float = 4.0¶
- tile_width: float = 4.0¶
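The nested dataclasses above can be instantiated and overridden independently when composing an MJXConfig. A minimal sketch of this pattern; the override values are illustrative, not recommended settings:

from toddlerbot.locomotion.mjx_config import MJXConfig

# Nested config dataclasses are regular dataclasses with defaults,
# so individual fields can be overridden at construction time.
action_cfg = MJXConfig.ActionConfig(action_scale=0.3, cycle_time=0.6)
noise_cfg = MJXConfig.NoiseConfig(level=0.1, gyro_std=0.3)

cfg = MJXConfig()  # full configuration with all defaults
print(action_cfg.n_frames, noise_cfg.quat_std)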
toddlerbot.locomotion.mjx_env module¶
Base MJX environment for ToddlerBot locomotion tasks.
This module provides the core MJXEnv class that serves as the base for all locomotion environments in ToddlerBot. It handles physics simulation, observation generation, reward computation, and environment management using MuJoCo and JAX.
- class toddlerbot.locomotion.mjx_env.MJXEnv(name: str, robot: Robot, cfg: MJXConfig, motion_ref: MotionReference, fixed_base: bool = False, add_domain_rand: bool = True, **kwargs: Any)¶
Bases:
PipelineEnv
Base MuJoCo-JAX environment for ToddlerBot locomotion tasks.
- property action_size: int¶
Returns the dimensionality of the action space.
Overrides the default action size to provide the number of actions specific to this environment.
- Returns:
The number of actions.
- Return type:
int
- pipeline_step(state: State, action: Array) State ¶
Executes a pipeline step by applying a control action to the system state.
This function iteratively applies a control action to the system’s state over a specified number of frames. It uses a controller to compute control signals based on the current state and action, and updates the pipeline state accordingly.
- Parameters:
state (State) – The current state of the system, containing information required for control computations.
action (jax.Array) – The control action to be applied to the system.
- Returns:
The updated state of the system after applying the control action over the specified number of frames.
- Return type:
base.State
- render(states: List[State], height: int = 240, width: int = 320, camera: str | None = None)¶
Renders a trajectory using the MuJoCo renderer.
- reset(rng: Array) State ¶
Resets the environment state and initializes various components for a new episode.
This function splits the input random number generator (RNG) into multiple streams for different components, initializes the state information dictionary, and sets up the initial positions, velocities, and commands for the environment. It also applies domain randomization if enabled and prepares observation histories.
- Parameters:
rng (jax.Array) – The random number generator state for initializing the environment.
- Returns:
The initialized state of the environment, including pipeline state, observations, rewards, and other relevant information.
- Return type:
State
- step(state: State, action: Array) State ¶
Advances the simulation by one time step, updating the state based on the given action.
This function updates the state of the simulation by processing the given action, applying filters, and incorporating domain randomization if enabled. It computes the motor targets, updates the pipeline state, and checks for termination conditions. Additionally, it calculates rewards and updates various state information, including contact forces, stance masks, and command resampling.
- Parameters:
state (State) – The current state of the simulation, containing information about the system’s dynamics and metadata.
action (jax.Array) – The action to be applied at this time step, influencing the system’s behavior.
- Returns:
The updated state after applying the action and advancing the simulation by one step.
- Return type:
State
- visualize_force_arrow(renderer: Renderer, state: State, push_id: int, push_force: Array, vis_scale: float = 0.05)¶
Visualize an applied push force as an arrow in the rendered scene.
- visualize_world_axes(renderer: Renderer, origin: ndarray = array([0., 0., 0.]), axis_len: float = 20.0, alpha: float = 0.5)¶
Draw the world coordinate axes in the renderer for spatial reference.
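A rollout with an MJXEnv is typically driven through jit-compiled reset and step, following the Brax PipelineEnv interface. A minimal sketch, assuming env is an already-constructed MJXEnv instance (see get_env_class and get_env_config below) and using a zero action as a placeholder policy:

import jax
import jax.numpy as jnp

# env: an MJXEnv instance constructed elsewhere (assumed)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

state = jit_reset(jax.random.PRNGKey(0))
for _ in range(100):
    action = jnp.zeros(env.action_size)  # placeholder policy: zero action
    state = jit_step(state, action)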
- toddlerbot.locomotion.mjx_env.get_env_class(env_name: str) Type[MJXEnv] ¶
Returns the environment class associated with the given environment name.
- Parameters:
env_name (str) – The name of the environment to retrieve.
- Returns:
The class of the specified environment.
- Return type:
Type[MJXEnv]
- Raises:
ValueError – If the environment name is not found in the registry.
- toddlerbot.locomotion.mjx_env.get_env_config(env: str)¶
Retrieves and parses the configuration for a specified environment.
- Parameters:
env (str) – The name of the environment for which to retrieve the configuration.
- Returns:
An instance of MJXConfig initialized with the parsed configuration.
- Return type:
MJXConfig
- Raises:
FileNotFoundError – If the configuration file for the specified environment does not exist.
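These two helpers are the usual entry point for building an environment by name. A minimal sketch, assuming "walk" is a registered environment name, that Robot is importable from toddlerbot.sim.robot and constructible from a robot name, and that the returned class follows the WalkEnv constructor signature; these are assumptions not confirmed by this page:

from toddlerbot.locomotion.mjx_env import get_env_class, get_env_config
from toddlerbot.sim.robot import Robot  # assumed import path

cfg = get_env_config("walk")      # raises FileNotFoundError if no config file exists
EnvClass = get_env_class("walk")  # raises ValueError if "walk" is not registered
robot = Robot("toddlerbot")       # hypothetical robot name
env = EnvClass("walk", robot, cfg)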
toddlerbot.locomotion.on_policy_runner module¶
On-policy reinforcement learning runner for ToddlerBot training.
This module provides the OnPolicyRunner class for training and evaluating policies using on-policy algorithms like PPO. It manages the training loop, checkpointing, logging, and policy evaluation with support for distributed training.
- class toddlerbot.locomotion.on_policy_runner.OnPolicyRunner(env: VecEnv, train_cfg: dict, run_name: str | None = None, progress_fn: callable | None = None, render_fn: callable | None = None, restore_params=None, device='cpu')¶
Bases:
object
On-policy runner for training and evaluation.
- eval_mode()¶
Set all models to evaluation mode.
- get_inference_policy(device=None)¶
Get inference policy for evaluation with optional normalization.
- learn(num_learning_iterations: int, init_at_random_ep_len: bool = False)¶
Run the training loop for the given number of learning iterations.
- load(loaded_dict: dict, load_optimizer: bool = True)¶
Load model checkpoint and optionally resume training state.
- save(path: str, infos=None)¶
Save model checkpoint including policy, optimizer, and training state.
- train_mode()¶
Set all models to training mode.
- update_episode_metrics(metrics, dones, info_loss)¶
Update episode metrics and log training progress.
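A minimal sketch of the runner's lifecycle, assuming env is a VecEnv-compatible wrapper (such as RSLRLWrapper below) and train_cfg is the runner configuration dictionary produced by load_runner_config; the run name, path, and iteration count are illustrative:

from toddlerbot.locomotion.on_policy_runner import OnPolicyRunner

runner = OnPolicyRunner(env, train_cfg, run_name="walk_test", device="cpu")
runner.learn(num_learning_iterations=1000)
runner.save("results/walk_test/model.pt")        # hypothetical checkpoint path
policy = runner.get_inference_policy(device="cpu")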
toddlerbot.locomotion.ppo_config module¶
PPO configuration settings for ToddlerBot training.
This module defines the PPOConfig dataclass containing hyperparameters and configuration settings for Proximal Policy Optimization (PPO) training.
- class toddlerbot.locomotion.ppo_config.PPOConfig(wandb_project: str = 'ToddlerBot', wandb_entity: str = 'toddlerbot', policy_hidden_layer_sizes: Tuple[int, ...] = (512, 256, 128), value_hidden_layer_sizes: Tuple[int, ...] = (512, 256, 128), use_rnn: bool = False, rnn_type: str = 'lstm', rnn_hidden_size: int = 512, rnn_num_layers: int = 1, activation: str = 'elu', distribution_type: str = 'normal', noise_std_type: str = 'log', init_noise_std: float = 0.5, num_timesteps: int = 500000000, num_evals: int = 100, episode_length: int = 1000, unroll_length: int = 20, num_updates_per_batch: int = 4, discounting: float = 0.97, gae_lambda: float = 0.95, max_grad_norm: float = 1.0, normalize_advantage: bool = True, normalize_observation: bool = False, learning_rate: float = 3e-05, entropy_cost: float = 0.001, clipping_epsilon: float = 0.2, num_envs: int = 1024, render_nums: int = 20, batch_size: int = 256, num_minibatches: int = 4, seed: int = 0)¶
Bases:
object
Data class for storing PPO hyperparameters.
- activation: str = 'elu'¶
- batch_size: int = 256¶
- clipping_epsilon: float = 0.2¶
- discounting: float = 0.97¶
- distribution_type: str = 'normal'¶
- entropy_cost: float = 0.001¶
- episode_length: int = 1000¶
- gae_lambda: float = 0.95¶
- init_noise_std: float = 0.5¶
- learning_rate: float = 3e-05¶
- max_grad_norm: float = 1.0¶
- noise_std_type: str = 'log'¶
- normalize_advantage: bool = True¶
- normalize_observation: bool = False¶
- num_envs: int = 1024¶
- num_evals: int = 100¶
- num_minibatches: int = 4¶
- num_timesteps: int = 500000000¶
- num_updates_per_batch: int = 4¶
- policy_hidden_layer_sizes: Tuple[int, ...] = (512, 256, 128)¶
- render_nums: int = 20¶
- rnn_hidden_size: int = 512¶
- rnn_num_layers: int = 1¶
- rnn_type: str = 'lstm'¶
- seed: int = 0¶
- unroll_length: int = 20¶
- use_rnn: bool = False¶
- value_hidden_layer_sizes: Tuple[int, ...] = (512, 256, 128)¶
- wandb_entity: str = 'toddlerbot'¶
- wandb_project: str = 'ToddlerBot'¶
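Because PPOConfig is a plain dataclass, experiment variants are conveniently expressed with dataclasses.replace; the overridden values below are illustrative only:

from dataclasses import replace
from toddlerbot.locomotion.ppo_config import PPOConfig

base_cfg = PPOConfig()
small_run = replace(base_cfg, num_envs=256, num_timesteps=50_000_000, use_rnn=True)
print(small_run.learning_rate, small_run.rnn_type)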
toddlerbot.locomotion.rsl_rl_wrapper module¶
RSL-RL wrapper for MJX environments.
This module provides the RSLRLWrapper class that adapts MJX environments for use with the RSL-RL reinforcement learning framework, handling JAX-PyTorch tensor conversions and environment interface compatibility.
- class toddlerbot.locomotion.rsl_rl_wrapper.RSLRLWrapper(env: MJXEnv, device: device, train_cfg: PPOConfig, eval: bool = False)¶
Bases:
VecEnv
Wrapper to adapt MJX environments for RSL-RL training framework.
- get_observations() tuple[Tensor, dict] ¶
Get current observations and convert from JAX to PyTorch format.
- reset()¶
Reset all environments and return initial states.
- step(actions: Tensor) tuple[Tensor, Tensor, Tensor, dict] ¶
Execute actions in environment and return step results.
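A minimal interaction sketch with the wrapper, assuming env is an MJXEnv and train_cfg is a PPOConfig constructed elsewhere; the zero-action policy is a placeholder:

import torch
from toddlerbot.locomotion.rsl_rl_wrapper import RSLRLWrapper

# env: MJXEnv instance; train_cfg: PPOConfig instance (both assumed available)
wrapped = RSLRLWrapper(env, torch.device("cpu"), train_cfg, eval=True)
obs, extras = wrapped.get_observations()
actions = torch.zeros(obs.shape[0], env.action_size)  # zero action for every env
obs, rewards, dones, infos = wrapped.step(actions)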
toddlerbot.locomotion.train_mjx module¶
Training script for ToddlerBot locomotion policies using MJX.
This module provides training functionality for ToddlerBot using both JAX (Brax) and PyTorch (RSL-RL) backends. It supports various locomotion tasks including walking, crawling, and cartwheel movements with configurable environments.
- class toddlerbot.locomotion.train_mjx.Tee(log_path)¶
Bases:
object
Custom stdout/stderr redirection class for logging output to both console and file.
- close()¶
- fileno()¶
- flush()¶
- isatty()¶
- write(message)¶
Write message to both terminal and log file.
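Tee mirrors output to both the terminal and a log file; one common pattern is to assign an instance to sys.stdout, although the training script may install it automatically. A minimal sketch with an illustrative log path:

import sys
from toddlerbot.locomotion.train_mjx import Tee

tee = Tee("results/train.log")  # hypothetical log path
sys.stdout = tee
print("training started")        # written to both the terminal and the log file
tee.flush()
sys.stdout = sys.__stdout__      # restore the original stdout before closing
tee.close()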
- toddlerbot.locomotion.train_mjx.domain_randomize(sys: System, rng: Array, friction_range: List[float], damping_range: List[float], armature_range: List[float], frictionloss_range: List[float], body_mass_attr_range: Dict[str, Array | ndarray[Any, dtype[float32]]] | None) Tuple[System, System] ¶
Randomizes the physical parameters of a system within specified ranges.
- Parameters:
sys (base.System) – The system whose parameters are to be randomized.
rng (jax.Array) – Random number generator state.
friction_range (List[float]) – Range for randomizing friction values.
damping_range (List[float]) – Range for randomizing damping values.
armature_range (List[float]) – Range for randomizing armature values.
frictionloss_range (List[float]) – Range for randomizing friction loss values.
body_mass_attr_range (Optional[Dict[str, jax.Array | npt.NDArray[np.float32]]]) – Optional dictionary specifying ranges for body mass attributes.
- Returns:
A tuple containing the randomized system and the in_axes configuration for JAX transformations.
- Return type:
Tuple[base.System, base.System]
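A minimal sketch of calling domain_randomize directly, assuming sys is the brax base.System of an already-built environment and that rng is a batch of per-environment keys (an assumption about the expected shape); the ranges are illustrative:

import functools
import jax
from toddlerbot.locomotion.train_mjx import domain_randomize

rng = jax.random.split(jax.random.PRNGKey(0), 8)  # one key per environment (assumed convention)
randomize_fn = functools.partial(
    domain_randomize,
    friction_range=[0.5, 1.5],
    damping_range=[0.9, 1.1],
    armature_range=[0.9, 1.1],
    frictionloss_range=[0.9, 1.1],
    body_mass_attr_range=None,
)
randomized_sys, in_axes = randomize_fn(sys, rng)  # sys: brax base.System (assumed available)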
- toddlerbot.locomotion.train_mjx.dynamic_import_envs(env_package: str)¶
Import all modules from a specified package for environment registration.
- toddlerbot.locomotion.train_mjx.evaluate(env: MJXEnv, train_cfg: PPOConfig, policy_path: str, args: Namespace)¶
Evaluates a policy in a given environment using a specified network factory and logs the results.
- Parameters:
env (MJXEnv) – The environment in which the policy is evaluated.
train_cfg (PPOConfig) – The PPO training configuration used for evaluation.
policy_path (str) – Path to the saved policy parameters to load.
args (Namespace) – Parsed command-line arguments for the evaluation run.
- toddlerbot.locomotion.train_mjx.get_body_mass_attr_range(robot: Robot, body_mass_range: List[float], hand_mass_range: List[float], other_mass_range: List[float], num_envs: int)¶
Generates a range of body mass attributes for a robot across multiple environments.
This function modifies the body mass and inertia of a robot model based on specified ranges for different body parts (torso, end-effector, and others) and returns a dictionary containing the updated attributes for each environment.
- Parameters:
robot (Robot) – The robot object containing configuration and name.
body_mass_range (List[float]) – The range of mass deltas for the torso.
hand_mass_range (List[float]) – The range of mass deltas for the end-effector.
other_mass_range (List[float]) – The range of mass deltas for other body parts.
num_envs (int) – The number of environments to generate.
- Returns:
A dictionary with keys representing different body mass attributes and values as JAX arrays or NumPy arrays containing the attribute values across all environments.
- Return type:
Dict[str, jax.Array | npt.NDArray[np.float32]]
- toddlerbot.locomotion.train_mjx.load_jax_ckpt_to_torch(jax_params)¶
Convert JAX model parameters to PyTorch format for cross-framework compatibility.
- toddlerbot.locomotion.train_mjx.load_runner_config(train_cfg: PPOConfig)¶
Load and configure RSL-RL runner settings from PPO configuration.
- toddlerbot.locomotion.train_mjx.log_metrics(metrics: Dict[str, Any], num_steps: int, defined_metrics: List[str])¶
Process and log training metrics to Weights & Biases.
- toddlerbot.locomotion.train_mjx.main(args=None)¶
Trains or evaluates a policy for a specified robot and environment using PPO.
This function sets up the training or evaluation of a policy for a robot in a specified environment. It parses command-line arguments to configure the robot, environment, evaluation settings, and other parameters. It then loads configuration files, binds any overridden parameters, and initializes the environment and robot. Depending on the arguments, it either trains a new policy or evaluates an existing one.
- Parameters:
args (list, optional) – List of command-line arguments. If None, arguments are parsed from sys.argv.
- Raises:
FileNotFoundError – If a specified gin configuration file or evaluation run is not found.
- toddlerbot.locomotion.train_mjx.print_metrics(metrics: Dict[str, Any], time_elapsed: float, num_steps: int, num_total_steps: int, notes: str = '', width: int = 80, pad: int = 35)¶
Logs and formats metrics for display, including elapsed time and optional step information.
- Parameters:
metrics (Dict[str, Any]) – A dictionary containing metric names and their corresponding values.
time_elapsed (float) – The time elapsed since the start of the process.
num_steps (int) – The current number of steps completed.
num_total_steps (int) – The total number of steps to be completed.
notes (str, optional) – Additional notes to display alongside the metrics. Defaults to ''.
width (int, optional) – The width of the log display. Defaults to 80.
pad (int, optional) – The padding for metric names in the log display. Defaults to 35.
- Returns:
A dictionary containing the logged data, including time elapsed and processed metrics.
- Return type:
Dict[str, Any]
- toddlerbot.locomotion.train_mjx.render_video(env: MJXEnv, states: List[Any], video_dir: str, video_name: str, cameras: List[str] = ['perspective'], render_every: int = 2, height: int = 360, width: int = 640)¶
Renders and saves a video of the environment from multiple camera angles.
- Parameters:
env (MJXEnv) – The environment to render.
states (List[Any]) – A list of environment states to render.
video_dir (str) – Directory in which to save the rendered videos.
video_name (str) – Base name for the saved video files.
cameras (List[str], optional) – Camera names to render from. Defaults to ['perspective'].
render_every (int, optional) – Interval at which frames are rendered from the states. Defaults to 2.
height (int, optional) – The height of the rendered video frames. Defaults to 360.
width (int, optional) – The width of the rendered video frames. Defaults to 640.
- Creates:
A video file for each specified camera and, when multiple cameras are rendered, a final concatenated grid video, saved under video_dir with the given video_name.
- toddlerbot.locomotion.train_mjx.rollout(jit_reset, jit_step, inference_fn, train_cfg, use_torch, use_batch, rng)¶
Execute policy rollout for evaluation and video generation.
- toddlerbot.locomotion.train_mjx.train(env: MJXEnv, eval_env: MJXEnv, test_env: MJXEnv, train_cfg: PPOConfig, run_name: str, args: Namespace)¶
Trains a reinforcement learning agent using the Proximal Policy Optimization (PPO) algorithm.
This function sets up the training environment, initializes configurations, and manages the training process, including saving configurations, logging metrics, and handling checkpoints.
- Parameters:
env (MJXEnv) – The training environment.
eval_env (MJXEnv) – The evaluation environment.
test_env (MJXEnv) – The test environment.
train_cfg (PPOConfig) – Configuration settings for the PPO training process.
run_name (str) – Name of the training run, used for organizing results.
args (Namespace) – Parsed command-line arguments, including any checkpoint restore options.
toddlerbot.locomotion.walk_env module¶
Walking locomotion environment for ToddlerBot.
This module provides the WalkEnv class for training ToddlerBot in bipedal walking. The environment includes specialized reward functions for stability, gait patterns, and command following using ZMP (Zero Moment Point) reference trajectories.
- class toddlerbot.locomotion.walk_env.WalkEnv(name: str, robot: Robot, cfg: MJXConfig, ref_motion_type: str = 'zmp', fixed_base: bool = False, add_domain_rand: bool = True, **kwargs: Any)¶
Bases:
MJXEnv
Walk environment with ToddlerBot.
- render(states: List[State], height: int = 240, width: int = 320, camera: str | None = None)¶
Render environment states with path visualization and force arrows.
- visualize_path_frame(renderer, pos, rot, axis_len=0.2, alpha=0.5)¶
Visualize coordinate frame axes in the renderer for debugging.
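WalkEnv is constructed like any other MJXEnv subclass, with the ZMP-based motion reference selected through ref_motion_type. A minimal sketch reusing the robot and cfg objects from the get_env_class example above (assumed to be available):

from toddlerbot.locomotion.walk_env import WalkEnv

# robot and cfg as constructed in the earlier mjx_env example (assumed)
env = WalkEnv("walk", robot, cfg, ref_motion_type="zmp", fixed_base=False)
print(env.action_size)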
Module contents¶
Locomotion and movement control for ToddlerBot.
This package contains implementations for various locomotion strategies and movement patterns for the ToddlerBot humanoid robot, including:
Walking gaits using reinforcement learning (PPO) and classical control
Cartwheel and acrobatic movement patterns
Crawling and low-profile locomotion
Balance and stability control
Environment interaction and terrain adaptation
The locomotion modules support both MuJoCo-based training environments and real robot deployment, with configurable parameters for different robot configurations and movement requirements.
Key components include environment definitions, training configurations, reward functions, and policy interfaces for various locomotion behaviors.