from __future__ import annotations

import numpy as np
import torch
from einops import rearrange

import comfy.model_management
import nodes
from comfy_api.v3 import io

CAMERA_DICT = {
    "base_T_norm": 1.5,
    "base_angle": np.pi / 3,
    "Static": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, 0.0]},
    "Pan Up": {"angle": [0.0, 0.0, 0.0], "T": [0.0, -1.0, 0.0]},
    "Pan Down": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 1.0, 0.0]},
    "Pan Left": {"angle": [0.0, 0.0, 0.0], "T": [-1.0, 0.0, 0.0]},
    "Pan Right": {"angle": [0.0, 0.0, 0.0], "T": [1.0, 0.0, 0.0]},
    "Zoom In": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, 2.0]},
    "Zoom Out": {"angle": [0.0, 0.0, 0.0], "T": [0.0, 0.0, -2.0]},
    "Anti Clockwise (ACW)": {"angle": [0.0, 0.0, -1.0], "T": [0.0, 0.0, 0.0]},
    "ClockWise (CW)": {"angle": [0.0, 0.0, 1.0], "T": [0.0, 0.0, 0.0]},
}
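
# How the presets above are consumed (sketch of the scaling applied in
# get_camera_motion below): at frame i of n, the rotation angles are
# (i / n) * speed * base_angle * angle and the translation is
# (i / n) * speed * base_T_norm * T. For example, "Pan Right" at speed=1.0
# over n=81 frames ramps toward roughly 1.5 units of rightward translation
# with no rotation.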


def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device="cpu"):
    """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""

    def get_relative_pose(cam_params):
        """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        cam_to_origin = 0
        target_cam_c2w = np.array([[1, 0, 0, 0], [0, 1, 0, -cam_to_origin], [0, 0, 1, 0], [0, 0, 0, 1]])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        return np.array(ret_poses, dtype=np.float32)

    cam_params = [Camera(cam_param) for cam_param in cam_params]

    # Rescale the normalized focal lengths so the original pose aspect ratio
    # fits the sample resolution.
    sample_wh_ratio = width / height
    pose_wh_ratio = original_pose_width / original_pose_height

    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / width
    else:
        resized_ori_h = width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / height

    intrinsic = np.asarray(
        [[cam_param.fx * width, cam_param.fy * height, cam_param.cx * width, cam_param.cy * height] for cam_param in cam_params],
        dtype=np.float32,
    )

    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params)
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
    plucker_embedding = plucker_embedding[None]
    return rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
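
# Usage sketch (hypothetical values): each row of cam_params follows the
# CameraCtrl trajectory layout parsed by Camera below, i.e.
# [frame_idx, fx, fy, cx, cy, 0, 0, *flattened 4x4 camera matrix].
# >>> embedding = process_pose_params(cam_params, width=832, height=480)
# >>> embedding.shape  # (n_frame, height, width, 6) Plücker embedding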


class Camera:
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""

    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        c2w_mat = np.array(entry[7:]).reshape(4, 4)
        self.c2w_mat = c2w_mat
        self.w2c_mat = np.linalg.inv(c2w_mat)
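
# Entry layout assumed above (and produced by WanCameraEmbedding.execute
# below): entry[0] is a frame-index slot (zero here), entry[1:5] holds
# fx, fy, cx, cy as fractions of the image size, entry[5:7] is padding,
# and entry[7:] is a row-major 4x4 matrix treated as camera-to-world.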


def ray_condition(K, c2w, H, W, device):
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py"""
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B = K.shape[0]

    # Pixel-center coordinates over the H x W grid.
    j, i = torch.meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
        indexing="ij",
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B, V, 1

    # Unproject pixel centers to unit-depth camera-space directions.
    zs = torch.ones_like(i)  # [B, 1, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    # Rotate directions into world space; broadcast camera origins per pixel.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # Plücker coordinates: moment (o x d) concatenated with direction d.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker
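
# Math note: a ray through origin o with unit direction d has Plücker
# coordinates (o x d, d). The 6-vector is unchanged when o slides along the
# ray ((o + t * d) x d == o x d), which makes it a clean per-pixel encoding
# of camera pose for the model to condition on.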


def get_camera_motion(angle, T, speed, n=81):
    def compute_R_from_rad_angle(angles):
        # Euler angles (radians) -> rotation matrix, applied in Z @ Y @ X order.
        theta_x, theta_y, theta_z = angles
        Rx = np.array([[1, 0, 0], [0, np.cos(theta_x), -np.sin(theta_x)], [0, np.sin(theta_x), np.cos(theta_x)]])

        Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)], [0, 1, 0], [-np.sin(theta_y), 0, np.cos(theta_y)]])

        Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], [np.sin(theta_z), np.cos(theta_z), 0], [0, 0, 1]])

        return Rz @ Ry @ Rx

    # Linearly ramp rotation and translation from zero at frame 0 toward the
    # preset maxima, scaled by speed.
    RT = []
    for i in range(n):
        _angle = (i / n) * speed * CAMERA_DICT["base_angle"] * angle
        R = compute_R_from_rad_angle(_angle)
        _T = (i / n) * speed * CAMERA_DICT["base_T_norm"] * T.reshape(3, 1)
        _RT = np.concatenate([R, _T], axis=1)
        RT.append(_RT)
    return np.stack(RT)
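
# Usage sketch: an 81-frame "Pan Right" trajectory of 3x4 [R | T] matrices.
# >>> preset = CAMERA_DICT["Pan Right"]
# >>> RT = get_camera_motion(np.array(preset["angle"]), np.array(preset["T"]), speed=1.0, n=81)
# >>> RT.shape
# (81, 3, 4)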


class WanCameraEmbedding(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="WanCameraEmbedding_V3",
            category="camera",
            inputs=[
                io.Combo.Input(
                    "camera_pose",
                    options=[
                        "Static",
                        "Pan Up",
                        "Pan Down",
                        "Pan Left",
                        "Pan Right",
                        "Zoom In",
                        "Zoom Out",
                        "Anti Clockwise (ACW)",
                        "ClockWise (CW)",
                    ],
                    default="Static",
                ),
                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
                io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
                io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
                io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
                io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
            ],
            outputs=[
                io.WanCameraEmbedding.Output(display_name="camera_embedding"),
                io.Int.Output(display_name="width"),
                io.Int.Output(display_name="height"),
                io.Int.Output(display_name="length"),
            ],
        )

    @classmethod
    def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
        """
        Use the camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021).
        Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
        """
        angle = np.array(CAMERA_DICT[camera_pose]["angle"])
        T = np.array(CAMERA_DICT[camera_pose]["T"])
        RT = get_camera_motion(angle, T, speed, length)

        # Build CameraCtrl-style trajectory rows: intrinsics, two padding
        # zeros, then the flattened 4x4 camera matrix.
        trajs = []
        for cp in RT.tolist():
            traj = [fx, fy, cx, cy, 0, 0]
            traj.extend(cp[0])
            traj.extend(cp[1])
            traj.extend(cp[2])
            traj.extend([0, 0, 0, 1])
            trajs.append(traj)

        cam_params = np.array([[float(x) for x in pose] for pose in trajs])
        # Prepend the frame-index column expected by Camera's entry layout.
        cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
        control_camera_video = process_pose_params(cam_params, width=width, height=height)
        control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())

        # Repeat the first frame 4 times so the frame count becomes
        # length + 3 = 4 * ((length - 1) // 4 + 1), a multiple of 4 for the
        # packing below.
        control_camera_video = torch.concat(
            [torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), control_camera_video[:, :, 1:]], dim=2
        ).transpose(1, 2)

        # Pack each group of 4 frames into the channel dimension:
        # [b, f, c, h, w] -> [b, c * 4, f // 4, h, w]
        b, f, c, h, w = control_camera_video.shape
        control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
        control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)

        return io.NodeOutput(control_camera_video, width, height, length)


NODES_LIST = [
    WanCameraEmbedding,
]
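
# Shape sketch: with the defaults (width=832, height=480, length=81) the
# camera_embedding output is a tensor of shape [1, 24, 21, 480, 832], i.e.
# [batch, 6 Plücker channels * 4 packed frames, (length - 1) // 4 + 1, H, W].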