Skip to content

Inconsistency between VAE Decoder SDF output and paper description #189

@zhozhh

Description

@zhozhh

Hi @HuiwenShi @Zeqiang-Lai,

According to the paper, the VAE should generate SDF values based on grid point queries, which are then used for mesh extraction via Marching Cubes. This implies that the SDF should represent a continuous distance field where values vary proportionally to the distance from the object surface. The script hy3dshape/tools/watertight/watertight_and_sample.py follows this expected behavior.

However, the released pre-trained VAE models seem to exhibit inconsistent behavior. When performing a VAE reconstruction on a standard cube, the SDF values sampled from space are predominantly clamped at 1.0 or -1.0. Values between these two only appear in areas extremely close to the object surface.

What is the exact definition of the model output — is it actually a TSDF with a scale transformation?

Here is a visualization of the SDF decoded from a [-0.7, 0.7] cube (it looks nearly binary):

Image

The following code snippet demonstrates this behavior. In the majority of the sampled volume, the output is nearly binary (1.0 or -1.0).

import torch
from hy3dshape.surface_loaders import SharpEdgeSurfaceLoader
from hy3dshape.models.autoencoders import ShapeVAE
from hy3dshape.pipelines import export_to_trimesh
import numpy as np
import trimesh
from typing import Callable, Tuple, List, Union
from tqdm import tqdm
from einops import repeat

def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a dense grid of query points spanning an axis-aligned bbox.

    Produces ``octree_resolution + 1`` evenly spaced samples per axis
    (endpoints included) as float32.

    Returns:
        xyz: array of shape (R+1, R+1, R+1, 3) with the grid coordinates,
            where R = octree_resolution.
        grid_size: per-axis point counts, ``[R+1, R+1, R+1]``.
        length: per-axis extent, ``bbox_max - bbox_min``.
    """
    length = bbox_max - bbox_min
    points_per_axis = int(octree_resolution) + 1

    # One linspace per axis, endpoints included.
    axes = [
        np.linspace(bbox_min[d], bbox_max[d], points_per_axis, dtype=np.float32)
        for d in range(3)
    ]
    mesh = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack(mesh, axis=-1)
    grid_size = [points_per_axis] * 3

    return xyz, grid_size, length


class SaveVanillaVolumeDecoder:
    """Dense-grid volume decoder that also probes the SDF along the z-axis.

    Besides decoding a dense (R+1)^3 logit grid from the latents, it appends
    EXTRANUM extra query points on the line x = y = 0 with z in [-1, 1] and
    writes the decoded values of those probes to ``grid_logits.txt`` (one
    "index, z, value" line per probe) so the 1-D SDF profile can be inspected.
    """

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Decode latents into a dense logit volume.

        Args:
            latents: (B, ...) latent codes consumed by ``geo_decoder``.
            geo_decoder: callable taking ``queries`` (B, P, 3) and ``latents``
                and returning per-point logits (B, P, 1) — assumed from the
                trailing singleton dim sliced below; confirm against the model.
            bounds: scalar half-extent, or [xmin, ymin, zmin, xmax, ymax, zmax].
            num_chunks: query points decoded per step to bound peak memory.
            octree_resolution: grid resolution R; the grid has (R+1)^3 points.
            enable_pbar: show a tqdm progress bar during decoding.

        Returns:
            float32 tensor of shape (B, R+1, R+1, R+1).
        """
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # Extra probe points along x = y = 0, z in [-1, 1], appended after the
        # grid so they go through the same chunked decoding pass.
        EXTRANUM = 1024
        # indexing="ij" matches torch.meshgrid's legacy default and silences
        # the deprecation warning without changing the output.
        xs, ys, zs = torch.meshgrid(
            torch.tensor(0.0),
            torch.tensor(0.0),
            torch.linspace(-1.0, 1.0, EXTRANUM),
            indexing="ij",
        )

        xyz_samples_new = torch.stack((xs, ys, zs), dim=-1).reshape(-1, 3).to(device, dtype=dtype)
        xyz_samples = torch.cat([xyz_samples, xyz_samples_new], dim=0)

        # 2. latents to 3d volume, decoded chunk-by-chunk to bound memory
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_all = torch.cat(batch_logits, dim=1)
        print(grid_all.shape)
        # Split the dense grid from the trailing probe points.
        grid_logits = grid_all[:, :-EXTRANUM, :]
        probe = grid_all[0, -EXTRANUM:, :]
        print(probe.shape)

        # Dump "index, z, logit" lines. Join once instead of the quadratic
        # ``str +=`` accumulation, which is O(n^2) over 1024 iterations.
        lines = [
            f"{i}, {2.0 * i / EXTRANUM - 1.0}, {probe[i].item()}\n"
            for i in range(0, EXTRANUM)
        ]
        with open("grid_logits.txt", "w") as f:
            f.write("".join(lines))

        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits


# Load the pre-trained Hunyuan3D-2.1 shape VAE with fp16 weights.
vae = ShapeVAE.from_pretrained(
    'tencent/Hunyuan3D-2.1',
    use_safetensors=False,
    variant='fp16',
)


# Surface sampler: 81920 uniform points, no extra sharp-edge points.
loader = SharpEdgeSurfaceLoader(
    num_sharp_points=0,
    num_uniform_points=81920,
)
# Swap in the instrumented decoder so the z-axis SDF probe is written to
# grid_logits.txt during mesh extraction.
vae.volume_decoder = SaveVanillaVolumeDecoder()

# Axis-aligned cube with extents 2.0, i.e. corners at [-1, 1]^3.
mesh_demo = trimesh.creation.box(extents=[2.0, 2.0, 2.0])
surface = loader(mesh_demo).to('cuda', dtype=torch.float16)
# Scale the xyz channels only; assumes the first 3 channels of the loader
# output are positions (remaining channels, e.g. normals, untouched).
surface[:, :, :3] = surface[:, :, :3] * 0.8 # normalize the cube to [-0.8, 0.8]

with torch.no_grad():
    # Round-trip through the VAE, then extract a mesh with marching cubes
    # at level 0.0 on a 256^3 grid over [-1.01, 1.01]^3.
    latents = vae.encode(surface, sample_posterior=True)
    latents = vae.decode(latents)
    mesh = vae.latents2mesh(
        latents,
        output_type='trimesh',
        bounds=1.01,
        mc_level=0.0,
        num_chunks=20000,
        octree_resolution=256,
        mc_algo='mc',
        enable_pbar=True
    )
    mesh = export_to_trimesh(mesh)[0]
    # Reconstructed bounds should be ~[-0.8, 0.8] if the SDF is well behaved.
    print(mesh.bounds)
    mesh.export("cube.obj")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions