Changes from 3 commits
22 changes: 20 additions & 2 deletions examples/hunyuanvideo-i2v/README.md
@@ -25,12 +25,13 @@ Here is the development plan of the project:

| MindSpore | Ascend Driver | Firmware | CANN toolkit/kernel |
|:---------:|:-------------:|:-----------:|:-------------------:|
| 2.5.0 | 24.1.RC2 | 7.5.0.2.220 | 8.0.RC3.beta1 |
| 2.6.0 | 24.1.RC3 | 7.5.T11.0.B088 | 8.1.RC1 |
| 2.7.0 | 24.1.RC3 | 7.5.T11.0.B088 | 8.2.RC1 |

</div>

1. Install
[CANN 8.0.RC3.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC3.beta1)
[8.1.RC1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.1.RC1) or [8.2.RC1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.2.RC1)
Collaborator

add CANN prefix

Collaborator Author

Fixed.

and MindSpore according to the [official instructions](https://www.mindspore.cn/install).
2. Install requirements
```shell
# (remaining lines of this hunk are collapsed in the diff view)
```
@@ -119,6 +120,23 @@ If you want to change to another prompt, please set `--prompt` to the new prompt.

To run image-to-video inference with LoRA weight, please refer to `scripts/hyvideo-i2v/run_sample_image2video_lora.sh`.

## Performance

### Inference

The following results were measured on Ascend Atlas 800T A2 machines with **MindSpore 2.7.0 in PyNative mode**.

| model | cards | batch size | resolution | num of frames | num of steps | step time (sec) |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
| HYVideo-T/2 | 1 | 1 | 720p | 129 | 50 | 86.02 |

The following results were measured on Ascend Atlas 800T A2 machines with **MindSpore 2.6.0 in PyNative mode**.

| model | cards | batch size | resolution | num of frames | num of steps | step time (sec) |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
| HYVideo-T/2 | 1 | 1 | 720p | 129 | 50 | 85.94 |
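
As a rough reading of the two tables, and assuming the reported step time is the average time per denoising step (the tables do not state this explicitly), the denoising loop for one 50-step sample takes roughly 72 minutes on either MindSpore version. The sketch below only does that arithmetic; it ignores text encoding, VAE decoding and warm-up, and `denoise_minutes` is a made-up helper name.

```python
def denoise_minutes(step_time_sec: float, num_steps: int) -> float:
    """Estimated wall time of the denoising loop in minutes (hypothetical helper)."""
    return step_time_sec * num_steps / 60.0


# Figures copied from the two tables above.
ms_270 = denoise_minutes(86.02, 50)
ms_260 = denoise_minutes(85.94, 50)
print(f"MindSpore 2.7.0: {ms_270:.1f} min, MindSpore 2.6.0: {ms_260:.1f} min")
```

The two versions differ by 0.08 s per step, which is under 0.1% of the step time.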



## Acknowledgements

4 changes: 2 additions & 2 deletions examples/hunyuanvideo-i2v/hyvideo/text_encoder/__init__.py
@@ -319,7 +319,7 @@ def encode(
 if model_return_dict:
     last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
 else:
-    last_hidden_state = outputs[2][-(hidden_state_skip_layer + 1)]
+    last_hidden_state = outputs[1][-(hidden_state_skip_layer + 1)]
Comment on lines 321 to +322
Contributor

high

The index `1` is used here to access the hidden states from the text encoder's output when `model_return_dict` is `False`. Please make sure this index is correct for every supported text encoder type, and add a comment explaining why index `1` is chosen and what it represents in the `outputs` tuple.

If the structure of the `outputs` tuple changes in a future version of the `transformers` library, this index could become invalid and cause runtime errors. Consider validating the length of the `outputs` tuple and raising an error if it does not match the expected length.

Suggested change
else:
last_hidden_state = outputs[2][-(hidden_state_skip_layer + 1)]
last_hidden_state = outputs[1][-(hidden_state_skip_layer + 1)]
last_hidden_state = outputs[1][-(hidden_state_skip_layer + 1)] # outputs[1] represents the hidden states
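
The length check requested in this comment could look like the sketch below. It is an illustration only, not code from this PR: `pick_hidden_state` and its parameters are hypothetical names, and the position of the hidden states in the tuple is passed in explicitly because the thread does not establish it for every encoder.

```python
from typing import Any, Sequence


def pick_hidden_state(outputs: Sequence[Any], hidden_states_idx: int, skip_layer: int = 0) -> Any:
    """Return the hidden state `skip_layer` layers before the last one.

    `outputs` is the tuple an encoder returns when it is not asked for a dict.
    `hidden_states_idx` is where the per-layer hidden states are ASSUMED to sit.
    """
    if len(outputs) <= hidden_states_idx:
        raise ValueError(
            f"expected at least {hidden_states_idx + 1} outputs, got {len(outputs)}"
        )
    hidden_states = outputs[hidden_states_idx]
    if not isinstance(hidden_states, (tuple, list)):
        raise TypeError(
            f"outputs[{hidden_states_idx}] is not a sequence of per-layer hidden states"
        )
    if not 0 <= skip_layer < len(hidden_states):
        raise IndexError(
            f"skip_layer={skip_layer} is out of range for {len(hidden_states)} layers"
        )
    return hidden_states[-(skip_layer + 1)]


# Stand-in values: strings instead of tensors.
fake_outputs = ("last_hidden_state", ("embeddings", "layer_1", "layer_2"))
print(pick_hidden_state(fake_outputs, hidden_states_idx=1, skip_layer=1))  # prints layer_1
```

Failing early with a clear message makes a changed tuple layout visible at the call site instead of surfacing later as a shape error.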

# last_hidden_state = outputs[0][-(hidden_state_skip_layer + 1)]
# Real last hidden state already has layer norm applied. So here we only apply it
# for intermediate layers.
@@ -331,7 +331,7 @@ def encode(
     outputs_hidden_states = outputs.hidden_states
 else:
     last_hidden_state = outputs[self.key_idx]
-    outputs_hidden_states = outputs[2] if len(outputs) >= 3 else None  # TODO: double-check if use t5
+    outputs_hidden_states = outputs[1] if len(outputs) >= 2 else None  # TODO: double-check if use t5
Contributor

high

As in the previous comment, index `1` is used here to access the hidden states. Add a comment explaining why index `1` is chosen and what it represents in the `outputs` tuple, and consider validating the length of the `outputs` tuple.

Suggested change
-outputs_hidden_states = outputs[1] if len(outputs) >= 2 else None  # TODO: double-check if use t5
+outputs_hidden_states = outputs[1] if len(outputs) >= 2 else None  # outputs[1] represents the hidden states


# Remove hidden states of instruction tokens, only keep prompt tokens.
if self.use_template:
78 changes: 78 additions & 0 deletions examples/hunyuanvideo-i2v/hyvideo/utils/modules_utils.py
@@ -5,6 +5,9 @@

import mindspore as ms
from mindspore import Parameter, Tensor, mint, nn
from mindspore.common.initializer import initializer

from mindone.diffusers.models.layers_compat import group_norm


class LayerNorm(nn.Cell):
@@ -36,3 +39,78 @@ def construct(self, x: Tensor):
x, self.normalized_shape, self.weight.to(x.dtype), self.bias.to(x.dtype), self.eps
)
return x


class GroupNorm(nn.Cell):
    r"""Applies Group Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by
    :attr:`num_groups`. The mean and standard-deviation are calculated
    separately over each group. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine transform parameter vectors of size :attr:`num_channels` if
    :attr:`affine` is ``True``.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = mint.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
        >>> m = GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent with LayerNorm)
        >>> m = GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)
    """
Comment on lines +44 to +86
Contributor

medium

The docstring for the `GroupNorm` class includes example usage with `mint.randn`. It would help to clarify that `mint` refers to `mindspore.mint`, to avoid confusion.


    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True, dtype=ms.float32):
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        weight = initializer("ones", num_channels, dtype=dtype)
        bias = initializer("zeros", num_channels, dtype=dtype)
        if self.affine:
            self.weight = Parameter(weight, name="weight")
            self.bias = Parameter(bias, name="bias")
        else:
            self.weight = None
            self.bias = None

    def construct(self, x: Tensor):
        if self.affine:
            x = group_norm(x, self.num_groups, self.weight.to(x.dtype), self.bias.to(x.dtype), self.eps)
        else:
            x = group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return x
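
To make the normalization formula in the docstring concrete, here is a dependency-free reference for a single sample of shape `(C, L)` without affine parameters. It is a teaching sketch only and is not how `group_norm` from `mindone` is implemented; `group_norm_reference` is a hypothetical name, and the biased (population) variance is used, as is standard for group normalization.

```python
import math
from typing import List


def group_norm_reference(x: List[List[float]], num_groups: int, eps: float = 1e-5) -> List[List[float]]:
    """Normalize one sample of shape (C, L) group by group, without gamma/beta."""
    num_channels = len(x)
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    out: List[List[float]] = []
    for g in range(num_groups):
        channels = x[g * per_group:(g + 1) * per_group]
        values = [v for ch in channels for v in ch]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
        scale = 1.0 / math.sqrt(var + eps)
        out.extend([[(v - mean) * scale for v in ch] for ch in channels])
    return out


# Four channels split into two groups of two channels each.
sample = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
normalized = group_norm_reference(sample, num_groups=2)
print(normalized)
```

Each group ends up with approximately zero mean and unit variance, so the second group, which is the first scaled by 10, normalizes to nearly the same values.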
@@ -20,15 +20,15 @@
 # ==============================================================================
 from typing import Optional, Tuple, Union

-from hyvideo.utils.modules_utils import LayerNorm
+from hyvideo.utils.modules_utils import GroupNorm, LayerNorm

 import mindspore as ms
 import mindspore.mint.nn.functional as F
 from mindspore import mint, nn, ops

 from mindone.diffusers.models.activations import get_activation
 from mindone.diffusers.models.attention_processor import Attention, SpatialNorm
-from mindone.diffusers.models.normalization import AdaGroupNorm, GroupNorm, RMSNorm
+from mindone.diffusers.models.normalization import AdaGroupNorm, RMSNorm
 from mindone.diffusers.utils import logging

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -412,7 +412,7 @@ def __init__(
 conv_3d_out_channels = conv_3d_out_channels or out_channels
 self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)

-self.nonlinearity = get_activation(non_linearity)()
+self.nonlinearity = get_activation(non_linearity)
Contributor

high

The `nonlinearity` attribute is assigned `get_activation(non_linearity)` itself, rather than an instance of the activation function, so the activation is never instantiated. This could lead to unexpected behavior during the forward pass. Instantiate the activation by calling `get_activation(non_linearity)()`.

Suggested change
-self.nonlinearity = get_activation(non_linearity)
+self.nonlinearity = get_activation(non_linearity)()
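
Whether the trailing `()` is needed depends on whether `get_activation` returns an activation class or a ready-made instance, which this thread does not settle. The sketch below uses stand-in factories to show both cases; `SiLU`, `get_activation_class`, `get_activation_instance` and `as_instance` are all hypothetical and are not the `mindone` API.

```python
import math


class SiLU:
    """Stand-in activation; not the MindSpore class."""

    def __call__(self, x: float) -> float:
        return x / (1.0 + math.exp(-x))


def get_activation_class(name: str):
    """Hypothetical factory that returns the class; the caller must instantiate it."""
    return {"silu": SiLU}[name]


def get_activation_instance(name: str):
    """Hypothetical factory that returns a ready-to-use instance."""
    return {"silu": SiLU}[name]()


def as_instance(obj):
    """Instantiate `obj` if it is still a class; otherwise return it unchanged."""
    return obj() if isinstance(obj, type) else obj


act_a = as_instance(get_activation_class("silu"))
act_b = as_instance(get_activation_instance("silu"))
print(act_a(0.0), act_b(0.0))  # prints 0.0 0.0
```

With the instance-returning factory, an extra `()` would call the activation with no input and raise a `TypeError`; with the class-returning factory, omitting `()` would store the class instead of a usable layer. Checking the return type of the installed `get_activation` resolves which form is correct.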


self.upsample = self.downsample = None
if self.up:
4 changes: 2 additions & 2 deletions examples/hunyuanvideo-i2v/requirements.txt
@@ -7,8 +7,8 @@ imageio
 imageio-ffmpeg
 safetensors
 mindcv==0.3.0
-tokenizers==0.20.3
-transformers==4.46.3
+tokenizers==0.21.4
+transformers==4.50.0
 gradio
 albumentations>=2.0
 ftfy
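
Because this change moves two exact pins, an environment built from the old file will silently keep the older versions. A small check like the sketch below can flag that; it is illustrative only, the helper names are hypothetical, and a fake lookup is used so the example runs without the packages installed (the default lookup is `importlib.metadata.version`).

```python
from importlib import metadata
from typing import Callable, Dict, List, Optional, Tuple


def parse_pins(lines: List[str]) -> Dict[str, str]:
    """Collect `name==version` pins; other requirement lines are ignored."""
    pins: Dict[str, str] = {}
    for line in lines:
        line = line.split("#", 1)[0].strip()
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def find_mismatches(
    pins: Dict[str, str], lookup: Callable[[str], str] = metadata.version
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs to (wanted, installed)."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, wanted in pins.items():
        try:
            installed: Optional[str] = lookup(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


requirements = ["mindcv==0.3.0", "tokenizers==0.21.4", "transformers==4.50.0", "gradio", "albumentations>=2.0"]
pins = parse_pins(requirements)
# Pretend environment that still has the old tokenizers pin.
fake_installed = {"mindcv": "0.3.0", "tokenizers": "0.20.3", "transformers": "4.50.0"}
print(find_mismatches(pins, lookup=fake_installed.__getitem__))
# prints {'tokenizers': ('0.21.4', '0.20.3')}
```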