diff --git a/src/pytorch_kinematics/mjcf.py b/src/pytorch_kinematics/mjcf.py index 2825e01..3dc1857 100644 --- a/src/pytorch_kinematics/mjcf.py +++ b/src/pytorch_kinematics/mjcf.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Optional, Dict import mujoco from mujoco._structs import _MjModelBodyViews as MjModelBodyViews @@ -57,7 +57,7 @@ def _build_chain_recurse(m, parent_frame, parent_body): parent_frame.children = parent_frame.children + [site_frame, ] -def build_chain_from_mjcf(data, body: Union[None, str, int] = None): +def build_chain_from_mjcf(data, body: Union[None, str, int] = None, assets:Optional[Dict[str,bytes]]=None): """ Build a Chain object from MJCF data. @@ -73,7 +73,7 @@ def build_chain_from_mjcf(data, body: Union[None, str, int] = None): chain.Chain Chain object created from MJCF. """ - m = mujoco.MjModel.from_xml_string(data) + m = mujoco.MjModel.from_xml_string(data, assets=assets) if body is None: root_body = m.body(0) else: