diff --git a/README.md b/README.md
index 27716ae..ebcaaff 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,11 @@
-
-
-
+

ComfyUI-DyPE

ComfyUI-DyPE Banner

-  A ComfyUI custom node that implements DyPE (Dynamic Position Extrapolation), enabling Diffusion Transformers (like FLUX and Qwen Image) to generate ultra-high-resolution images (4K and beyond) with exceptional coherence and detail.
+  A ComfyUI custom node that implements DyPE (Dynamic Position Extrapolation), enabling Diffusion Transformers (like FLUX, Qwen Image, and Z-Image) to generate ultra-high-resolution images (4K and beyond) with exceptional coherence and detail.

 Report Bug
@@ -44,7 +42,7 @@ It works by taking advantage of the spectral progression inherent to the diffusi
 This node provides a seamless, "plug-and-play" integration of DyPE into your workflow.
 
 **✨ Key Features:**
-* **Multi-Architecture Support:** Now supports **FLUX** (Standard), **Nunchaku** (Quantized Flux), and **Qwen Image**.
+* **Multi-Architecture Support:** Supports **FLUX** (Standard), **Nunchaku** (Quantized Flux), **Qwen Image**, and **Z-Image** (Lumina 2).
 * **High-Resolution Generation:** Push models to 4096x4096 and beyond.
 * **Single-Node Integration:** Simply place the `DyPE for FLUX` node after your model loader to patch the model. No complex workflow changes required.
 * **Full Compatibility:** Works seamlessly with your existing ComfyUI workflows, samplers, schedulers, and other optimization nodes.
@@ -84,7 +82,7 @@ Alternatively, to install manually:
 Using the node is straightforward and designed for minimal workflow disruption.
 
-1. **Load Your Model:** Use your preferred loader (e.g., `Load Checkpoint` for Flux, `Nunchaku Flux DiT Loader`, or a Qwen loader).
+1. **Load Your Model:** Use your preferred loader (e.g., `Load Checkpoint` for Flux, `Nunchaku Flux DiT Loader`, a Qwen loader, or the `ZImage` loader).
 2. **Add the DyPE Node:** Add the `DyPE for FLUX` node to your graph (found under `model_patches/unet`).
 3. **Connect the Model:** Connect the `MODEL` output from your loader to the `model` input of the DyPE node.
 4. **Set Resolution:** Set the `width` and `height` on the DyPE node to match the resolution of your `Empty Latent Image`.
@@ -98,12 +96,13 @@ Using the node is straightforward and designed for minimal workflow disruption.
 #### 1. Model Configuration
 
 * **`model_type`**:
-    * **`auto`**: Attempts to automatically detect the model architecture (Flux, Nunchaku, or Qwen). Recommended.
+    * **`auto`**: Attempts to automatically detect the model architecture. Recommended.
     * **`flux`**: Forces Standard Flux logic.
    * **`nunchaku`**: Forces Nunchaku (Quantized Flux) logic.
    * **`qwen`**: Forces Qwen Image logic.
+    * **`zimage`**: Forces Z-Image (Lumina 2) logic.
 * **`base_resolution`**: The native resolution the model was trained on.
-    * Flux: `1024`
+    * Flux / Z-Image: `1024`
    * Qwen: `1328` (Recommended setting for Qwen models)
 
 #### 2. Method Selection
@@ -118,6 +117,10 @@ Using the node is straightforward and designed for minimal workflow disruption.
    * **Anisotropic (High-Res):** Scales Height and Width independently. Can cause geometric stretching if the aspect ratio differs significantly from the training data.
    * **Isotropic (Stable Default):** Scales both dimensions based on the largest axis.
    * *Note: `vision_yarn` automatically handles this balance internally, so this switch is ignored when `vision_yarn` is selected.*
+
+> [!TIP]
+> **Z-Image Usage:** Z-Image models have a very low RoPE base frequency (`theta=256`). This makes anisotropic scaling unstable (vertical stretching). The node automatically detects this and forces isotropic behavior in `vision_yarn` mode for Z-Image. We recommend using `vision_yarn` or `ntk` for Z-Image.
+
 #### 3. Dynamic Control
 * **`enable_dype`**: Enables or disables the **dynamic, time-aware** component of DyPE.
    * **Enabled (True):** Both the noise schedule and RoPE will be dynamically adjusted throughout sampling. This is the full DyPE algorithm.
@@ -135,6 +138,9 @@ Using the node is straightforward and designed for minimal workflow disruption.
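For intuition, `base_shift` / `max_shift` are the endpoints of a linear schedule over the image token count, which the patcher converts into the sampling shift (mu). A minimal sketch of that calculation, mirroring the logic in `src/patch_utils.py` below; `estimate_dype_shift` is an illustrative name, not part of the node:

```python
import math

def estimate_dype_shift(width, height, base_shift=0.5, max_shift=1.15,
                        base_resolution=1024, patch_size=2):
    # Token count of the target image at latent stride 8 and the given patch size.
    latent_h, latent_w = height // 8, width // 8
    padded_h = math.ceil(latent_h / patch_size) * patch_size
    padded_w = math.ceil(latent_w / patch_size) * patch_size
    image_seq_len = (padded_h // patch_size) * (padded_w // patch_size)

    # Token count at the training resolution (e.g. 1024px -> 64x64 = 4096 tokens).
    base_patches = (base_resolution // 8) // 2
    base_seq_len = base_patches * base_patches

    if image_seq_len <= base_seq_len:
        return base_shift
    # Linear interpolation between the endpoints; because the node uses the target
    # length itself as the maximum, any above-base resolution lands on max_shift.
    slope = (max_shift - base_shift) / (image_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return max(0.0, image_seq_len * slope + intercept)

print(estimate_dype_shift(1024, 1024))  # 0.5  (base_shift)
print(estimate_dype_shift(4096, 4096))  # 1.15 (max_shift)
```

In practice this means `base_shift` governs native-resolution renders while `max_shift` takes over as soon as the latent grid exceeds the training grid.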
## Changelog +#### v2.2 +* **Z-Image Support:** Added experimental support for **Z-Image (Lumina 2)** architecture. + #### v2.1 * **New Architecture Support:** Added support for **Qwen Image** and **Nunchaku** (Quantized Flux) models. * **Modular Architecture:** Refactored codebase into a modular adapter pattern (`src/models/`) to ensure stability and easier updates for future models. diff --git a/__init__.py b/__init__.py index 642b7e7..853c3a2 100644 --- a/__init__.py +++ b/__init__.py @@ -31,7 +31,7 @@ def define_schema(cls) -> io.Schema: ), io.Combo.Input( "model_type", - options=["auto", "flux", "nunchaku", "qwen"], + options=["auto", "flux", "nunchaku", "qwen", "zimage", "z_image"], default="auto", tooltip="Specify the model architecture. 'auto' usually works", ), @@ -108,4 +108,4 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: return [DyPE_FLUX] async def comfy_entrypoint() -> DyPEExtension: - return DyPEExtension() \ No newline at end of file + return DyPEExtension() diff --git a/example_workflows/DyPE-ZIT-workflow.json b/example_workflows/DyPE-ZIT-workflow.json new file mode 100644 index 0000000..36b261f --- /dev/null +++ b/example_workflows/DyPE-ZIT-workflow.json @@ -0,0 +1,760 @@ +{ + "id": "9ae6082b-c7f4-433c-9971-7a8f65a3ea65", + "revision": 0, + "last_node_id": 49, + "last_link_id": 47, + "nodes": [ + { + "id": 35, + "type": "MarkdownNote", + "pos": [ + -390, + 270 + ], + "size": [ + 490, + 400 + ], + "flags": { + "collapsed": false + }, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [], + "title": "Model link", + "properties": {}, + "widgets_values": [ + "## Report issue\n\nIf you found any issues when running this workflow, [report template issue here](https://github.com/Comfy-Org/workflow_templates/issues)\n\n\n## Model links\n\n**text_encoders**\n\n- [qwen_3_4b.safetensors](https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors)\n\n**diffusion_models**\n\n- [z_image_turbo_bf16.safetensors](https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/diffusion_models/z_image_turbo_bf16.safetensors)\n\n**vae**\n\n- [ae.safetensors](https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors)\n\n\nModel Storage Location\n\n```\nšŸ“‚ ComfyUI/\nā”œā”€ā”€ šŸ“‚ models/\n│ ā”œā”€ā”€ šŸ“‚ text_encoders/\n│ │ └── qwen_3_4b.safetensors\n│ ā”œā”€ā”€ šŸ“‚ diffusion_models/\n│ │ └── z_image_turbo_bf16.safetensors\n│ └── šŸ“‚ vae/\n│ └── ae.safetensors\n```\n\n" + ], + "color": "#432", + "bgcolor": "#653" + }, + { + "id": 44, + "type": "KSampler", + "pos": [ + 900.2638101844517, + 375.2545050421948 + ], + "size": [ + 317.95033878771574, + 474 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 40 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 41 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 42 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 43 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "slot_index": 0, + "links": [ + 38 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "KSampler", + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [ + 410513707389275, + "fixed", + 9, + 1, + "res_multistep", + "simple", + 1 + ] + }, + { + "id": 46, + "type": 
"UNETLoader", + "pos": [ + 130.2638101844517, + 305.2545050421948 + ], + "size": [ + 270, + 82 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 46 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.73", + "Node name for S&R": "UNETLoader", + "models": [ + { + "name": "z_image_turbo_bf16.safetensors", + "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/diffusion_models/z_image_turbo_bf16.safetensors", + "directory": "diffusion_models" + } + ], + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [ + "z_image_turbo_bf16.safetensors", + "default" + ] + }, + { + "id": 39, + "type": "CLIPLoader", + "pos": [ + 130.2638101844517, + 435.2545050421948 + ], + "size": [ + 270, + 106 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 44 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.73", + "Node name for S&R": "CLIPLoader", + "models": [ + { + "name": "qwen_3_4b.safetensors", + "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors", + "directory": "text_encoders" + } + ], + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [ + "Qwen_3-4B-abliterated.safetensors", + "lumina2", + "default" + ] + }, + { + "id": 40, + "type": "VAELoader", + "pos": [ + 130.2638101844517, + 585.2545050421948 + ], + "size": [ + 270, + 58 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 39 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.73", + "Node name for S&R": "VAELoader", + "models": [ + { + "name": "ae.safetensors", + "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors", + "directory": "vae" + } + ], + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [ + "FLUX1\\ae.safetensors" + ] + }, + { + "id": 9, + "type": "SaveImage", + "pos": [ + 1240, + 260 + ], + "size": [ + 389.6898473104816, + 590.9817412927072 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 45 + } + ], + "outputs": [], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "SaveImage", + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [ + "z-image" + ] + }, + { + "id": 41, + "type": "EmptySD3LatentImage", + "pos": [ + 137.14793402245462, + 741.155182617626 + ], + "size": [ + 260, + 110 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "slot_index": 0, + "links": [ + 43 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "EmptySD3LatentImage", + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": 
"Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [ + 2048, + 3072, + 1 + ] + }, + { + "id": 47, + "type": "ModelSamplingAuraFlow", + "pos": [ + 900.2638101844517, + 265.2545050421948 + ], + "size": [ + 310, + 60 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 47 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "slot_index": 0, + "links": [ + 40 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "ModelSamplingAuraFlow", + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [ + 3 + ] + }, + { + "id": 45, + "type": "CLIPTextEncode", + "pos": [ + 450.2638101844517, + 305.2545050421948 + ], + "size": [ + 407.0496612122845, + 203.79758162535745 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 44 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 36, + 41 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.73", + "Node name for S&R": "CLIPTextEncode", + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [ + "Latina female with thick wavy hair, harbor boats and pastel houses behind. Breezy seaside light, warm tones, cinematic close-up." + ], + "color": "#232", + "bgcolor": "#353" + }, + { + "id": 42, + "type": "ConditioningZeroOut", + "pos": [ + 648.4624550335893, + 557.0851941424083 + ], + "size": [ + 197.712890625, + 26 + ], + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 36 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 42 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.73", + "Node name for S&R": "ConditioningZeroOut", + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [] + }, + { + "id": 49, + "type": "MarkdownNote", + "pos": [ + -383.3530660677599, + 714.9195027807318 + ], + "size": [ + 481.8661447405933, + 328.1129300408643 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [], + "outputs": [], + "title": "DyPE", + "properties": { + "ue_properties": { + "widget_ue_connectable": {}, + "version": "7.1", + "input_ue_unconnectable": {} + } + }, + "widgets_values": [ + "### Node Inputs\n\n* **`method`**: The extrapolation strategy.\n * **`vision_yarn`:** Best for 4K+ and non-square images. Fixes stretching and artifacts automatically.\n * **`yarn` / `ntk`**: Legacy methods. Use `yarn` with `yarn_alt_scaling` for manual control.\n* **`enable_dype`**: Toggles the dynamic (time-aware) algorithm. Keep enabled for best quality.\n* **`dype_scale`**: Magnitude of the frequency shift. 
Default: `2.0`.\n* **`dype_exponent`**: Aggressiveness of the dynamic schedule.\n * `2.0`: Best for **4K**.\n * `1.0`: Best for **2K**.\n * `0.5`: Best for **~1.5K**.\n* **`yarn_alt_scaling`**: *(Legacy `yarn` only)* Switch between Anisotropic (Sharper, potential stretching) and Isotropic (Softer, accurate geometry).\n* **`base_shift` / `max_shift`**: Advanced noise schedule controls. Defaults (`0.5` / `1.15`) are optimized for FLUX.\n\n\n## Join\n\n### [TokenDiffusion](https://t.me/TokenDiff) - AI for every home, creativity for every mind!\n\n### [TokenDiff Community Hub](https://t.me/TokenDiff_hub) - Questions, help, and thoughtful discussion. " + ], + "color": "#322", + "bgcolor": "#533" + }, + { + "id": 43, + "type": "VAEDecode", + "pos": [ + 1006.7658843684605, + 897.0716563819947 + ], + "size": [ + 210, + 46 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 38 + }, + { + "name": "vae", + "type": "VAE", + "link": 39 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "slot_index": 0, + "links": [ + 45 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "VAEDecode", + "enableTabs": false, + "tabWidth": 65, + "tabXOffset": 10, + "hasSecondTab": false, + "secondTabText": "Send Back", + "secondTabOffset": 80, + "secondTabWidth": 65 + }, + "widgets_values": [] + }, + { + "id": 48, + "type": "DyPE_FLUX", + "pos": [ + 487.4055670233965, + 706.7051732742663 + ], + "size": [ + 342.5881686404282, + 326.75987991084776 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 46 + } + ], + "outputs": [ + { + "name": "Patched Model", + "type": "MODEL", + "links": [ + 47 + ] + } + ], + "properties": { + "cnr_id": "ComfyUI-DyPE", + "ver": "7dbea3b189ef5d75ead0f7bb5c237eee180ac19a", + "Node name for S&R": "DyPE_FLUX" + }, + "widgets_values": [ + 2048, + 2048, + "auto", + "ntk", + true, + true, + 1024, + 1, + 2, + 2, + 0.5, + 1.15 + ], + "color": "#332922", + "bgcolor": "#593930" + } + ], + "links": [ + [ + 36, + 45, + 0, + 42, + 0, + "CONDITIONING" + ], + [ + 38, + 44, + 0, + 43, + 0, + "LATENT" + ], + [ + 39, + 40, + 0, + 43, + 1, + "VAE" + ], + [ + 40, + 47, + 0, + 44, + 0, + "MODEL" + ], + [ + 41, + 45, + 0, + 44, + 1, + "CONDITIONING" + ], + [ + 42, + 42, + 0, + 44, + 2, + "CONDITIONING" + ], + [ + 43, + 41, + 0, + 44, + 3, + "LATENT" + ], + [ + 44, + 39, + 0, + 45, + 0, + "CLIP" + ], + [ + 45, + 43, + 0, + 9, + 0, + "IMAGE" + ], + [ + 46, + 46, + 0, + 48, + 0, + "MODEL" + ], + [ + 47, + 48, + 0, + 47, + 0, + "MODEL" + ] + ], + "groups": [ + { + "id": 2, + "title": "Step2 - Image size", + "bounding": [ + 120.2638101844517, + 665.2545050421948, + 290, + 200 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 3, + "title": "Step3 - Prompt", + "bounding": [ + 430.2638101844517, + 235.2545050421948, + 450, + 363.9631189996389 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 4, + "title": "Step1 - Load models", + "bounding": [ + 120.2638101844517, + 235.2545050421948, + 290, + 413.6 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 5, + "title": "Model Patching", + "bounding": [ + 436.9001236929169, + 612.8672841136505, + 443.02476814406367, + 434.8989256271392 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + } + ], + "config": {}, + "extra": { + "ds": { + "scale": 0.8403573356722741, + "offset": [ + 481.6702101560164, + 
-113.59117086923438
+      ]
+    },
+    "frontendVersion": "1.33.10",
+    "VHS_latentpreview": false,
+    "VHS_latentpreviewrate": 0,
+    "VHS_MetadataImage": true,
+    "VHS_KeepIntermediate": true,
+    "workflowRendererVersion": "LG"
+  },
+  "version": 0.4
+}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 592955a..0c1d70c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "ComfyUI-DyPE"
-description = "DyPE. Artifact-free 4K+ image generation. Qwen, Flux, Nunchacku"
-version = "2.1.0"
+description = "DyPE. Artifact-free 4K+ image generation. Qwen, Flux, Nunchaku, Z-Image"
+version = "2.2.0"
 license = {file = "LICENSE"}
 dependencies = ["torch"]
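The time-aware half of DyPE hinges on a blend factor that decays from 1 at the start of sampling to 0 at the end; `src/base.py` below uses this curve for the attention-scale interpolation, and the Z-Image patch reuses it for position scaling. A standalone sketch of the schedule (the function name is illustrative):

```python
import math

def dype_blend_factor(current_sigma, dype_start_sigma=1.0, dype_exponent=2.0):
    # High sigma (early, noisy steps): factor near 1 -> full extrapolation.
    # Low sigma (late, detail steps): factor near 0 -> back to trained positions.
    if current_sigma > dype_start_sigma:
        t_norm = 1.0
    else:
        t_norm = current_sigma / dype_start_sigma
    return math.pow(t_norm, dype_exponent)

for sigma in (1.0, 0.5, 0.1):
    print(f"{sigma:.1f} -> {dype_blend_factor(sigma):.2f}")  # 1.00, 0.25, 0.01
```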
diff --git a/src/base.py b/src/base.py
index 8fa9443..72b68eb 100644
--- a/src/base.py
+++ b/src/base.py
@@ -9,7 +9,7 @@ class DyPEBasePosEmbed(nn.Module):
     Handles the calculation of DyPE scaling factors and raw (cos, sin) components.
     Subclasses must implement `forward` to format the output for specific model architectures.
     """
-    def __init__(self, theta: int, axes_dim: list[int], method: str = 'yarn', yarn_alt_scaling: bool = False, dype: bool = True, dype_scale: float = 2.0, dype_exponent: float = 2.0, base_resolution: int = 1024, dype_start_sigma: float = 1.0):
+    def __init__(self, theta: int, axes_dim: list[int], method: str = 'yarn', yarn_alt_scaling: bool = False, dype: bool = True, dype_scale: float = 2.0, dype_exponent: float = 2.0, base_resolution: int = 1024, dype_start_sigma: float = 1.0, base_patches: int | None = None):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
@@ -19,14 +19,18 @@ def __init__(self, theta: int, axes_dim: list[int], method: str = 'yarn', yarn_a
         self.dype_scale = dype_scale
         self.dype_exponent = dype_exponent
         self.base_resolution = base_resolution
-        self.dype_start_sigma = max(0.001, min(1.0, dype_start_sigma)) # Clamp 0.001-1.0
+
+        is_z_image = self.__class__.__name__.lower().endswith("zimage")
+        if is_z_image:
+            self.dype_start_sigma = 1.0
+        else:
+            self.dype_start_sigma = max(0.001, min(1.0, dype_start_sigma)) # Clamp 0.001-1.0
         self.current_timestep = 1.0
 
-        # Dynamic Base Patches: (Resolution // 8) // 2
-        # Flux (1024) -> 128 -> 64
-        # Qwen (1328) -> 166 -> 83
-        self.base_patches = (self.base_resolution // 8) // 2
+        # Dynamic Base Patches: configurable per-model to align with native patch grids.
+        # Flux/Qwen default: (Resolution // 8) // 2
+        self.base_patches = base_patches if base_patches is not None else (self.base_resolution // 8) // 2
 
     def set_timestep(self, timestep: float):
         self.current_timestep = timestep
@@ -61,6 +65,9 @@ def _calc_vision_yarn_components(self, pos: torch.Tensor, freqs_dtype: torch.dty
         t_factor = math.pow(t_norm, self.dype_exponent)
         current_mscale = mscale_end + (mscale_start - mscale_end) * t_factor
 
+        # Low Theta Heuristic (Z-Image / Lumina)
+        force_isotropic = self.theta < 1000.0
+
         for i in range(n_axes):
             axis_pos = pos[..., i]
             axis_dim = self.axes_dim[i]
@@ -86,7 +93,12 @@ def _calc_vision_yarn_components(self, pos: torch.Tensor, freqs_dtype: torch.dty
 
             if i > 0:
                 scale_local = max(1.0, current_patches / self.base_patches)
-                dype_kwargs['linear_scale'] = scale_local
+
+                # Apply Low Theta protection
+                if force_isotropic:
+                    dype_kwargs['linear_scale'] = 1.0
+                else:
+                    dype_kwargs['linear_scale'] = scale_local
 
             if scale_global > 1.0:
                 cos, sin = get_1d_dype_yarn_pos_embed(
@@ -120,7 +132,10 @@ def _calc_yarn_components(self, pos: torch.Tensor, freqs_dtype: torch.dtype):
 
         needs_extrapolation = (max_current_patches > self.base_patches)
 
-        if needs_extrapolation and self.yarn_alt_scaling:
+        force_isotropic = self.theta < 1000.0
+        use_anisotropic = self.yarn_alt_scaling and not force_isotropic
+
+        if needs_extrapolation and use_anisotropic:
             for i in range(n_axes):
                 axis_pos = pos[..., i]
                 axis_dim = self.axes_dim[i]
@@ -208,7 +223,7 @@ def get_components(self, pos: torch.Tensor, freqs_dtype: torch.dtype):
             return self._calc_vision_yarn_components(pos, freqs_dtype)
         elif self.method == 'yarn':
             return self._calc_yarn_components(pos, freqs_dtype)
-        else:
+        else: # 'ntk' or 'base'
             return self._calc_ntk_components(pos, freqs_dtype)
 
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
diff --git a/src/models/nunchaku.py b/src/models/nunchaku.py
index a5306f5..8c6e980 100644
--- a/src/models/nunchaku.py
+++ b/src/models/nunchaku.py
@@ -60,7 +60,9 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
             rope_i = self._axis_rope_from_cos_sin(cos, sin)
             emb_parts.append(rope_i)
 
-        emb = torch.cat(emb_parts, dim=-3)
+        # Concatenate along axis dimension: dim = -3
+        # shape: (B, M, D_total//2, 1, 2)
+        emb = torch.cat(emb_parts, dim=-3)
         out = emb.unsqueeze(1).to(ids.device)
         return out
\ No newline at end of file
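`PosEmbedZImage` below emits RoPE entries as explicit 2x2 rotation matrices rather than Flux-style (cos, sin) pairs. What each (2, 2) entry does to a consecutive feature pair can be checked in isolation; a small self-contained sketch (the numbers are arbitrary):

```python
import torch

# One RoPE entry: [[cos, -sin], [sin, cos]] rotates a feature pair by theta.
theta = torch.tensor(0.3)
row1 = torch.stack([theta.cos(), -theta.sin()])
row2 = torch.stack([theta.sin(), theta.cos()])
matrix = torch.stack([row1, row2], dim=-2)  # same stacking order as PosEmbedZImage.forward

pair = torch.tensor([1.0, 0.0])
print(matrix @ pair)  # tensor([0.9553, 0.2955]) == (cos 0.3, sin 0.3)
```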
diff --git a/src/models/z_image.py b/src/models/z_image.py
new file mode 100644
index 0000000..5dbde9c
--- /dev/null
+++ b/src/models/z_image.py
@@ -0,0 +1,26 @@
+import torch
+from ..base import DyPEBasePosEmbed
+
+class PosEmbedZImage(DyPEBasePosEmbed):
+    """
+    DyPE Implementation for Z-Image / NextDiT models.
+
+    Output Format matches `EmbedND`: (B, 1, L, D/2, 2, 2)
+    """
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        pos = ids.float()
+        freqs_dtype = torch.bfloat16 if pos.device.type == 'cuda' else torch.float32
+
+        components = self.get_components(pos, freqs_dtype)
+
+        emb_parts = []
+        for cos, sin in components:
+            cos_reshaped = cos.view(*cos.shape[:-1], -1, 2)[..., :1]
+            sin_reshaped = sin.view(*sin.shape[:-1], -1, 2)[..., :1]
+            row1 = torch.cat([cos_reshaped, -sin_reshaped], dim=-1)
+            row2 = torch.cat([sin_reshaped, cos_reshaped], dim=-1)
+            matrix = torch.stack([row1, row2], dim=-2)
+            emb_parts.append(matrix)
+
+        emb = torch.cat(emb_parts, dim=-3)
+        return emb.unsqueeze(1).to(ids.device)
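`patch_utils.py` below derives the model's native token grid: for Z-Image it prefers the transformer's `axes_lens` attribute and falls back to the resolution-based default otherwise. The derivation in isolation; the `axes_lens` values here are hypothetical, and index 0 appears to be the caption/sequence axis in this 3-axis layout:

```python
def derive_base_grid(axes_lens, base_resolution=1024):
    # Mirrors apply_dype_to_model: axes_lens[1] / axes_lens[2] are taken as the
    # native height/width token counts of the training grid.
    default_patches = (base_resolution // 8) // 2
    if isinstance(axes_lens, (list, tuple)) and len(axes_lens) >= 3:
        h, w = int(axes_lens[1]), int(axes_lens[2])
        return max(h, w), h * w  # (base_patches, base_seq_len)
    return default_patches, default_patches * default_patches

print(derive_base_grid([300, 64, 64]))  # hypothetical Z-Image grid -> (64, 4096)
print(derive_base_grid(None))           # fallback                  -> (64, 4096)
```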
Assuming Standard Flux.") pass else: - raise ValueError("The provided model is not a compatible model.") + raise ValueError("The provided model is not a compatible model.") + + new_dype_params = (width, height, base_shift, max_shift, method, yarn_alt_scaling, base_resolution, dype_start_sigma, is_nunchaku, is_qwen, is_zimage) - new_dype_params = (width, height, base_shift, max_shift, method, yarn_alt_scaling, base_resolution, dype_start_sigma, is_nunchaku, is_qwen) - should_patch_schedule = True if hasattr(m.model, "_dype_params"): if m.model._dype_params == new_dype_params: should_patch_schedule = False - - if enable_dype and should_patch_schedule: - patch_size = 2 # Default Flux/Qwen - try: - if is_nunchaku: - patch_size = m.model.diffusion_model.model.config.patch_size - else: - patch_size = m.model.diffusion_model.patch_size - except: + else: pass + base_patch_h_tokens = None + base_patch_w_tokens = None + default_base_patches = (base_resolution // 8) // 2 + default_base_seq_len = default_base_patches * default_base_patches + + if is_zimage: + axes_lens = getattr(m.model.diffusion_model, "axes_lens", None) + if isinstance(axes_lens, (list, tuple)) and len(axes_lens) >= 3: + base_patch_h_tokens = int(axes_lens[1]) + base_patch_w_tokens = int(axes_lens[2]) + + patch_size = 2 # Default Flux/Qwen + try: + if is_nunchaku: + patch_size = m.model.diffusion_model.model.config.patch_size + else: + patch_size = m.model.diffusion_model.patch_size + except: + pass + + if base_patch_h_tokens is not None and base_patch_w_tokens is not None: + derived_base_patches = max(base_patch_h_tokens, base_patch_w_tokens) + derived_base_seq_len = base_patch_h_tokens * base_patch_w_tokens + else: + derived_base_patches = default_base_patches + derived_base_seq_len = default_base_seq_len + + if enable_dype and should_patch_schedule: try: - if isinstance(m.model.model_sampling, model_sampling.ModelSamplingFlux) or is_qwen: + if isinstance(m.model.model_sampling, model_sampling.ModelSamplingFlux) or is_qwen or is_zimage: latent_h, latent_w = height // 8, width // 8 padded_h, padded_w = math.ceil(latent_h / patch_size) * patch_size, math.ceil(latent_w / patch_size) * patch_size image_seq_len = (padded_h // patch_size) * (padded_w // patch_size) - - base_patches = (base_resolution // 8) // 2 - base_seq_len = base_patches * base_patches + + base_seq_len = derived_base_seq_len max_seq_len = image_seq_len if max_seq_len <= base_seq_len: - dype_shift = base_shift + dype_shift = base_shift else: slope = (max_shift - base_shift) / (max_seq_len - base_seq_len) intercept = base_shift - slope * base_seq_len dype_shift = image_seq_len * slope + intercept - + dype_shift = max(0.0, dype_shift) - # print(f"[DyPE DEBUG] Calculated dype_shift (mu): {dype_shift:.4f} for resolution {width}x{height} (Base: {base_resolution})") class DypeModelSamplingFlux(model_sampling.ModelSamplingFlux, model_sampling.CONST): pass new_model_sampler = DypeModelSamplingFlux(m.model.model_config) new_model_sampler.set_parameters(shift=dype_shift) - + m.add_object_patch("model_sampling", new_model_sampler) m.model._dype_params = new_dype_params except: @@ -98,41 +124,173 @@ class DefaultModelSamplingFlux(model_sampling.ModelSamplingFlux, model_sampling. 
@@ -98,41 +124,173 @@ class DefaultModelSamplingFlux(model_sampling.ModelSamplingFlux, model_sampling.
         if is_nunchaku:
             orig_embedder = m.model.diffusion_model.model.pos_embed
             target_patch_path = "diffusion_model.model.pos_embed"
+        elif is_zimage:
+            orig_embedder = m.model.diffusion_model.rope_embedder
+            target_patch_path = "diffusion_model.rope_embedder"
         else:
             orig_embedder = m.model.diffusion_model.pe_embedder
             target_patch_path = "diffusion_model.pe_embedder"
-
+
         theta, axes_dim = orig_embedder.theta, orig_embedder.axes_dim
     except AttributeError:
-        raise ValueError("The provided model is not a compatible FLUX/Qwen model structure.")
+        raise ValueError("The provided model is not a compatible FLUX/Qwen/Z-Image model structure.")
 
     embedder_cls = PosEmbedFlux
     if is_nunchaku:
         embedder_cls = PosEmbedNunchaku
     elif is_qwen:
         embedder_cls = PosEmbedQwen
+    elif is_zimage:
+        embedder_cls = PosEmbedZImage
+
+    embedder_base_patches = derived_base_patches if is_zimage else None
 
     new_pe_embedder = embedder_cls(
-        theta, axes_dim, method, yarn_alt_scaling, enable_dype,
-        dype_scale, dype_exponent, base_resolution, dype_start_sigma
+        theta, axes_dim, method, yarn_alt_scaling, enable_dype,
+        dype_scale, dype_exponent, base_resolution, dype_start_sigma, embedder_base_patches
     )
-
+
     m.add_object_patch(target_patch_path, new_pe_embedder)
-
+
+    if is_zimage:
+        original_patchify_and_embed = getattr(m.model.diffusion_model, "patchify_and_embed", None)
+        if original_patchify_and_embed is not None:
+            m.model.diffusion_model._dype_original_patchify_and_embed = original_patchify_and_embed
+
+        base_hw_override = None
+        if base_patch_h_tokens is not None and base_patch_w_tokens is not None:
+            base_hw_override = (base_patch_h_tokens, base_patch_w_tokens)
+        elif derived_base_patches is not None:
+            base_hw_override = (derived_base_patches, derived_base_patches)
+
+        if base_hw_override is not None:
+            m.model.diffusion_model._dype_base_hw = base_hw_override
+
+        def dype_patchify_and_embed(self, x, cap_feats, cap_mask, t, num_tokens, transformer_options={}):
+            pH = pW = self.patch_size
+            x_tensor = x if isinstance(x, torch.Tensor) else torch.stack(list(x), dim=0)
+            device = x_tensor.device
+
+            B, C, H, W = x_tensor.shape
+            if (H % pH != 0) or (W % pW != 0):
+                x_tensor = comfy.ldm.common_dit.pad_to_patch_size(x_tensor, (pH, pW))
+                B, C, H, W = x_tensor.shape
+
+            bsz = x_tensor.shape[0]
+
+            if self.pad_tokens_multiple is not None:
+                pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
+                if pad_extra > 0:
+                    cap_pad = self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)
+                    cap_feats = torch.cat((cap_feats, cap_pad), dim=1)
+
+            cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
+            cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
+
+            x_emb = self.x_embedder(x_tensor.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+            rope_options = transformer_options.get("rope_options", None)
+            base_hw = getattr(self, "_dype_base_hw", None)
+            default_h_scale = 1.0
+            default_w_scale = 1.0
+
+            rope_embedder = getattr(self, "rope_embedder", None)
+            dype_blend_factor = None
+            if rope_embedder is not None:
+                dype_start_sigma = getattr(rope_embedder, "dype_start_sigma", None)
+                dype_exponent = getattr(rope_embedder, "dype_exponent", None)
+                current_timestep = getattr(rope_embedder, "current_timestep", None)
+
+                if all(value is not None for value in (dype_start_sigma, dype_exponent, current_timestep)) and dype_start_sigma > 0:
+                    if current_timestep > dype_start_sigma:
+                        t_norm = 1.0
+                    else:
+                        t_norm = current_timestep / dype_start_sigma
+
+                    dype_blend_factor = math.pow(t_norm, dype_exponent)
+
+            H_tokens, W_tokens = H // pH, W // pW
+            if base_hw is not None and len(base_hw) == 2 and base_hw[0] > 0 and base_hw[1] > 0:
+                default_h_scale = H_tokens / base_hw[0]
+                default_w_scale = W_tokens / base_hw[1]
+
+            def _blend_scale(default_scale: float) -> float:
+                if dype_blend_factor is None:
+                    return default_scale
+                return 1.0 + (default_scale - 1.0) * dype_blend_factor
+
+            h_scale = rope_options.get("scale_y", _blend_scale(default_h_scale)) if rope_options is not None else _blend_scale(default_h_scale)
+            w_scale = rope_options.get("scale_x", _blend_scale(default_w_scale)) if rope_options is not None else _blend_scale(default_w_scale)
+
+            h_start = rope_options.get("shift_y", 0.0) if rope_options is not None else 0.0
+            w_start = rope_options.get("shift_x", 0.0) if rope_options is not None else 0.0
+
+            x_pos_ids = torch.zeros((bsz, x_emb.shape[1], 3), dtype=torch.float32, device=device)
+            x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
+            x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
+            x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
+
+            if self.pad_tokens_multiple is not None:
+                pad_extra = (-x_emb.shape[1]) % self.pad_tokens_multiple
+                if pad_extra > 0:
+                    pad_token = self.x_pad_token.to(device=x_emb.device, dtype=x_emb.dtype, copy=True).unsqueeze(0).repeat(x_emb.shape[0], pad_extra, 1)
+                    x_emb = torch.cat((x_emb, pad_token), dim=1)
+                    x_pos_ids = F.pad(x_pos_ids, (0, 0, 0, pad_extra))
+
+            freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
+
+            for layer in self.context_refiner:
+                cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
+
+            padded_img_mask = None
+            for layer in self.noise_refiner:
+                x_emb = layer(x_emb, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
+
+            padded_full_embed = torch.cat((cap_feats, x_emb), dim=1)
+            mask = None
+            img_sizes = [(H, W)] * bsz
+            l_effective_cap_len = [cap_feats.shape[1]] * bsz
+            return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
+
+        m.add_object_patch("diffusion_model.patchify_and_embed", types.MethodType(dype_patchify_and_embed, m.model.diffusion_model))
+
+        if original_patchify_and_embed is not None:
+            m.model._dype_zimage_override_active = True
+            m.model._dype_zimage_step_count = 0
 
     sigma_max = m.model.model_sampling.sigma_max.item()
-
+
     def dype_wrapper_function(model_function, args_dict):
+        current_sigma = None
         timestep_tensor = args_dict.get("timestep")
         if timestep_tensor is not None and timestep_tensor.numel() > 0:
             current_sigma = timestep_tensor.flatten()[0].item()
-
+
             if sigma_max > 0:
                 normalized_timestep = min(max(current_sigma / sigma_max, 0.0), 1.0)
                 new_pe_embedder.set_timestep(normalized_timestep)
-
+
         input_x, c = args_dict.get("input"), args_dict.get("c", {})
-        return model_function(input_x, args_dict.get("timestep"), **c)
+        output = model_function(input_x, args_dict.get("timestep"), **c)
+
+        if getattr(m.model, "_dype_zimage_override_active", False):
+            current_step = getattr(m.model, "_dype_zimage_step_count", 0) + 1
+            m.model._dype_zimage_step_count = current_step
+
+            if current_sigma is not None and current_sigma <= dype_start_sigma:
+                original_fn = getattr(m.model.diffusion_model, "_dype_original_patchify_and_embed", None)
+                if original_fn is not None:
+                    m.model.diffusion_model.patchify_and_embed = original_fn
+
+                if hasattr(m.model.diffusion_model, "_dype_base_hw"):
+                    delattr(m.model.diffusion_model, "_dype_base_hw")
+
+                new_pe_embedder.base_patches = default_base_patches
+
+                m.model._dype_zimage_override_active = False
+
+        return output
 
     m.set_model_unet_function_wrapper(dype_wrapper_function)
-
-    return m
\ No newline at end of file
+
+    return m
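End to end, callers still funnel into the same entry point. A hedged usage sketch, where `model` is a placeholder for a `ModelPatcher` obtained from a loader node and the values mirror the README defaults:

```python
from src.patch_utils import apply_dype_to_model

patched = apply_dype_to_model(
    model,                    # placeholder: a comfy.model_patcher.ModelPatcher
    model_type="zimage",      # "z_image" normalizes to the same branch; "auto" also works
    width=2048, height=2048,  # must match the Empty Latent Image
    method="vision_yarn",     # recommended for Z-Image (isotropic scaling is forced internally)
    yarn_alt_scaling=False,
    enable_dype=True,
    dype_scale=2.0,
    dype_exponent=2.0,
    base_shift=0.5,
    max_shift=1.15,
    base_resolution=1024,
    dype_start_sigma=1.0,
)
```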