diff --git a/mmseg/apis/mmseg_inferencer.py b/mmseg/apis/mmseg_inferencer.py index 02a198b516..ed09d3f0f7 100644 --- a/mmseg/apis/mmseg_inferencer.py +++ b/mmseg/apis/mmseg_inferencer.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from mmcv.transforms import Compose +from mmengine.device.utils import is_musa_available from mmengine.infer.infer import BaseInferencer, ModelType from mmengine.model import revert_sync_batchnorm from mmengine.registry import init_default_scope @@ -81,7 +82,8 @@ def __init__(self, super().__init__( model=model, weights=weights, device=device, scope=scope) - if device == 'cpu' or not torch.cuda.is_available(): + if device == 'cpu' or (not torch.cuda.is_available() + and not is_musa_available()): self.model = revert_sync_batchnorm(self.model) assert isinstance(self.visualizer, SegLocalVisualizer) diff --git a/mmseg/models/assigners/hungarian_assigner.py b/mmseg/models/assigners/hungarian_assigner.py index 28868f0a04..80c48fd8c4 100644 --- a/mmseg/models/assigners/hungarian_assigner.py +++ b/mmseg/models/assigners/hungarian_assigner.py @@ -3,9 +3,14 @@ import torch from mmengine import ConfigDict +from mmengine.device.utils import is_musa_available from mmengine.structures import InstanceData from scipy.optimize import linear_sum_assignment -from torch.cuda.amp import autocast + +if is_musa_available(): + from torch_musa.core.amp import autocast +else: + from torch.cuda.amp import autocast from mmseg.registry import TASK_UTILS from .base_assigner import BaseAssigner diff --git a/mmseg/models/losses/focal_loss.py b/mmseg/models/losses/focal_loss.py index 6507ed7a91..ca57cd7b06 100644 --- a/mmseg/models/losses/focal_loss.py +++ b/mmseg/models/losses/focal_loss.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss +from mmengine.device.utils import is_musa_available from mmseg.registry import MODELS from .utils import weight_reduce_loss @@ -269,7 +270,9 @@ def forward(self, reduction_override if reduction_override else self.reduction) if self.use_sigmoid: num_classes = pred.size(1) - if torch.cuda.is_available() and pred.is_cuda: + if (torch.cuda.is_available() + and pred.is_cuda) or (is_musa_available() + and pred.device == 'musa'): if target.dim() == 1: one_hot_target = F.one_hot( target, num_classes=num_classes + 1) diff --git a/tests/test_models/test_backbones/test_clip_text_encoder.py b/tests/test_models/test_backbones/test_clip_text_encoder.py index ea06c5b5b3..3b62df5a83 100644 --- a/tests/test_models/test_backbones/test_clip_text_encoder.py +++ b/tests/test_models/test_backbones/test_clip_text_encoder.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch from mmengine import Config +from mmengine.device.utils import is_musa_available from mmengine.registry import init_default_scope from mmseg.models.text_encoder import CLIPTextEncoder @@ -23,6 +24,8 @@ def test_clip_text_encoder(): text_encoder = CLIPTextEncoder(**cfg) if torch.cuda.is_available(): text_encoder = text_encoder.cuda() + elif is_musa_available(): + text_encoder = text_encoder.musa() with torch.no_grad(): class_embeds = text_encoder() diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index bb3967f8dd..471bf66108 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -8,6 +8,7 @@ import pytest import torch import torch.nn as nn +from mmengine.device.utils import is_musa_available from mmengine.model.utils import revert_sync_batchnorm from mmengine.registry import init_default_scope from mmengine.structures import PixelData @@ -190,6 +191,8 @@ def _test_encoder_decoder_forward(cfg_file): # convert to cuda Tensor if applicable if torch.cuda.is_available(): segmentor = segmentor.cuda() + elif is_musa_available(): + segmentor = segmentor.musa() else: segmentor = revert_sync_batchnorm(segmentor) diff --git a/tests/test_models/test_heads/test_ann_head.py b/tests/test_models/test_heads/test_ann_head.py index c1e44bc685..fb40e8d9ae 100644 --- a/tests/test_models/test_heads/test_ann_head.py +++ b/tests/test_models/test_heads/test_ann_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import ANNHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_ann_head(): @@ -16,5 +17,7 @@ def test_ann_head(): project_channels=8) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 21, 21) diff --git a/tests/test_models/test_heads/test_apc_head.py b/tests/test_models/test_heads/test_apc_head.py index dc55ccc1d5..2ac3702a5a 100644 --- a/tests/test_models/test_heads/test_apc_head.py +++ b/tests/test_models/test_heads/test_apc_head.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import APCHead -from .utils import _conv_has_norm, to_cuda +from .utils import _conv_has_norm, to_cuda, to_musa def test_apc_head(): @@ -34,6 +35,8 @@ def test_apc_head(): fusion=True) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.fusion is True assert head.acm_modules[0].pool_scale == 1 assert head.acm_modules[1].pool_scale == 2 @@ -51,6 +54,8 @@ def test_apc_head(): fusion=False) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.fusion is False assert head.acm_modules[0].pool_scale == 1 assert head.acm_modules[1].pool_scale == 2 diff --git a/tests/test_models/test_heads/test_aspp_head.py b/tests/test_models/test_heads/test_aspp_head.py index db9e89324f..8626f67f32 100644 --- a/tests/test_models/test_heads/test_aspp_head.py +++ b/tests/test_models/test_heads/test_aspp_head.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import ASPPHead, DepthwiseSeparableASPPHead -from .utils import _conv_has_norm, to_cuda +from .utils import _conv_has_norm, to_cuda, to_musa def test_aspp_head(): @@ -29,6 +30,8 @@ def test_aspp_head(): in_channels=8, channels=4, num_classes=19, dilations=(1, 12, 24)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.aspp_modules[0].conv.dilation == (1, 1) assert head.aspp_modules[1].conv.dilation == (12, 12) assert head.aspp_modules[2].conv.dilation == (24, 24) @@ -49,6 +52,8 @@ def test_dw_aspp_head(): dilations=(1, 12, 24)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.c1_bottleneck is None assert head.aspp_modules[0].conv.dilation == (1, 1) assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12) @@ -67,6 +72,8 @@ def test_dw_aspp_head(): dilations=(1, 12, 24)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.c1_bottleneck.in_channels == 4 assert head.c1_bottleneck.out_channels == 2 assert head.aspp_modules[0].conv.dilation == (1, 1) diff --git a/tests/test_models/test_heads/test_decode_head.py b/tests/test_models/test_heads/test_decode_head.py index 88e6bed10f..5a2abcb270 100644 --- a/tests/test_models/test_heads/test_decode_head.py +++ b/tests/test_models/test_heads/test_decode_head.py @@ -3,11 +3,12 @@ import pytest import torch +from mmengine.device.utils import is_musa_available from mmengine.structures import PixelData from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmseg.structures import SegDataSample -from .utils import to_cuda +from .utils import to_cuda, to_musa @patch.multiple(BaseDecodeHead, __abstractmethods__=set()) @@ -70,6 +71,8 @@ def test_decode_head(): head = BaseDecodeHead(32, 16, num_classes=19) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.in_channels == 32 assert head.input_transform is None transformed_inputs = head._transform_inputs(inputs) @@ -84,6 +87,8 @@ def test_decode_head(): input_transform='resize_concat') if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.in_channels == 48 assert head.input_transform == 'resize_concat' transformed_inputs = head._transform_inputs(inputs) @@ -108,6 +113,8 @@ def test_decode_head(): type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) loss = head.loss_by_feat( seg_logits=inputs, batch_data_samples=data_samples) assert 'loss_ce' in loss @@ -128,6 +135,8 @@ def test_decode_head(): ]) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) loss = head.loss_by_feat( seg_logits=inputs, batch_data_samples=data_samples) @@ -155,6 +164,8 @@ def test_decode_head(): dict(type='CrossEntropyLoss', loss_name='loss_3'))) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) loss = head.loss_by_feat( seg_logits=inputs, batch_data_samples=data_samples) assert 'loss_1' in loss @@ -176,6 +187,8 @@ def test_decode_head(): dict(type='CrossEntropyLoss', loss_name='loss_ce'))) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) loss_3 = head.loss_by_feat( seg_logits=inputs, batch_data_samples=data_samples) @@ -186,6 +199,8 @@ def test_decode_head(): loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'))) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) loss = head.loss_by_feat( seg_logits=inputs, batch_data_samples=data_samples) assert 'loss_ce' in loss diff --git a/tests/test_models/test_heads/test_dm_head.py b/tests/test_models/test_heads/test_dm_head.py index a922ff7295..d5279e2fc7 100644 --- a/tests/test_models/test_heads/test_dm_head.py +++ b/tests/test_models/test_heads/test_dm_head.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import DMHead -from .utils import _conv_has_norm, to_cuda +from .utils import _conv_has_norm, to_cuda, to_musa def test_dm_head(): @@ -34,6 +35,8 @@ def test_dm_head(): fusion=True) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.fusion is True assert head.dcm_modules[0].filter_size == 1 assert head.dcm_modules[1].filter_size == 3 @@ -51,6 +54,8 @@ def test_dm_head(): fusion=False) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.fusion is False assert head.dcm_modules[0].filter_size == 1 assert head.dcm_modules[1].filter_size == 3 diff --git a/tests/test_models/test_heads/test_dnl_head.py b/tests/test_models/test_heads/test_dnl_head.py index 720cb07fc6..2ce52ad60d 100644 --- a/tests/test_models/test_heads/test_dnl_head.py +++ b/tests/test_models/test_heads/test_dnl_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import DNLHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_dnl_head(): @@ -14,6 +15,8 @@ def test_dnl_head(): inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) @@ -23,6 +26,8 @@ def test_dnl_head(): inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) @@ -31,6 +36,8 @@ def test_dnl_head(): inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) @@ -40,5 +47,7 @@ def test_dnl_head(): inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) diff --git a/tests/test_models/test_heads/test_ema_head.py b/tests/test_models/test_heads/test_ema_head.py index 1811cd2bb2..dd8ae6759a 100644 --- a/tests/test_models/test_heads/test_ema_head.py +++ b/tests/test_models/test_heads/test_ema_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import EMAHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_emanet_head(): @@ -19,5 +20,7 @@ def test_emanet_head(): inputs = [torch.randn(1, 4, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) diff --git a/tests/test_models/test_heads/test_fcn_head.py b/tests/test_models/test_heads/test_fcn_head.py index 664b543e07..854723b04c 100644 --- a/tests/test_models/test_heads/test_fcn_head.py +++ b/tests/test_models/test_heads/test_fcn_head.py @@ -2,10 +2,11 @@ import pytest import torch from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.device.utils import is_musa_available from mmengine.utils.dl_utils.parrots_wrapper import SyncBatchNorm from mmseg.models.decode_heads import DepthwiseSeparableFCNHead, FCNHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_fcn_head(): @@ -36,6 +37,8 @@ def test_fcn_head(): in_channels=8, channels=4, num_classes=19, concat_input=False) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available: + head, inputs = to_musa(head, inputs) assert len(head.convs) == 2 assert not head.concat_input and not hasattr(head, 'conv_cat') outputs = head(inputs) @@ -47,6 +50,8 @@ def test_fcn_head(): in_channels=8, channels=4, num_classes=19, concat_input=True) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert len(head.convs) == 2 assert head.concat_input assert head.conv_cat.in_channels == 12 @@ -58,6 +63,8 @@ def test_fcn_head(): head = FCNHead(in_channels=8, channels=4, num_classes=19) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) for i in range(len(head.convs)): assert head.convs[i].kernel_size == (3, 3) assert head.convs[i].padding == 1 @@ -69,6 +76,8 @@ def test_fcn_head(): head = FCNHead(in_channels=8, channels=4, num_classes=19, kernel_size=1) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) for i in range(len(head.convs)): assert head.convs[i].kernel_size == (1, 1) assert head.convs[i].padding == 0 @@ -80,6 +89,8 @@ def test_fcn_head(): head = FCNHead(in_channels=8, channels=4, num_classes=19, num_convs=1) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert len(head.convs) == 1 outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) @@ -94,6 +105,8 @@ def test_fcn_head(): concat_input=False) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert isinstance(head.convs, torch.nn.Identity) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) diff --git a/tests/test_models/test_heads/test_gc_head.py b/tests/test_models/test_heads/test_gc_head.py index c62ac9ae74..f25aeea0b7 100644 --- a/tests/test_models/test_heads/test_gc_head.py +++ b/tests/test_models/test_heads/test_gc_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import GCHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_gc_head(): @@ -12,5 +13,7 @@ def test_gc_head(): inputs = [torch.randn(1, 4, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) diff --git a/tests/test_models/test_heads/test_ham_head.py b/tests/test_models/test_heads/test_ham_head.py index f802d2d8db..5ac9e5a00a 100644 --- a/tests/test_models/test_heads/test_ham_head.py +++ b/tests/test_models/test_heads/test_ham_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import LightHamHead -from .utils import _conv_has_norm, to_cuda +from .utils import _conv_has_norm, to_cuda, to_musa ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) @@ -38,6 +39,8 @@ def test_ham_head(): ] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.in_channels == [16, 32, 64] assert head.hamburger.ham_in.in_channels == 64 outputs = head(inputs) diff --git a/tests/test_models/test_heads/test_isa_head.py b/tests/test_models/test_heads/test_isa_head.py index b177f6d23e..816284509b 100644 --- a/tests/test_models/test_heads/test_isa_head.py +++ b/tests/test_models/test_heads/test_isa_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import ISAHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_isa_head(): @@ -16,5 +17,7 @@ def test_isa_head(): down_factor=(8, 8)) if torch.cuda.is_available(): isa_head, inputs = to_cuda(isa_head, inputs) + elif is_musa_available(): + isa_head, inputs = to_musa(isa_head, inputs) output = isa_head(inputs) assert output.shape == (1, isa_head.num_classes, 23, 23) diff --git a/tests/test_models/test_heads/test_mask2former_head.py b/tests/test_models/test_heads/test_mask2former_head.py index 45b353d441..f8f3ff080e 100644 --- a/tests/test_models/test_heads/test_mask2former_head.py +++ b/tests/test_models/test_heads/test_mask2former_head.py @@ -1,12 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch from mmengine import Config +from mmengine.device.utils import is_musa_available from mmengine.structures import PixelData from mmseg.models.decode_heads import Mask2FormerHead from mmseg.structures import SegDataSample from mmseg.utils import SampleList -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_mask2former_head(): @@ -141,6 +142,10 @@ def test_mask2former_head(): head, inputs = to_cuda(head, inputs) for data_sample in data_samples: data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda() + elif is_musa_available(): + head, inputs = to_musa(head, inputs) + for data_sample in data_samples: + data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.musa() loss_dict = head.loss(inputs, data_samples, None) assert isinstance(loss_dict, dict) diff --git a/tests/test_models/test_heads/test_nl_head.py b/tests/test_models/test_heads/test_nl_head.py index d4ef0b9db3..73e98e95c8 100644 --- a/tests/test_models/test_heads/test_nl_head.py +++ b/tests/test_models/test_heads/test_nl_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import NLHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_nl_head(): @@ -12,5 +13,7 @@ def test_nl_head(): inputs = [torch.randn(1, 8, 23, 23)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 23, 23) diff --git a/tests/test_models/test_heads/test_ocr_head.py b/tests/test_models/test_heads/test_ocr_head.py index 5e5d669b14..cdb0b9f6d8 100644 --- a/tests/test_models/test_heads/test_ocr_head.py +++ b/tests/test_models/test_heads/test_ocr_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import FCNHead, OCRHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_ocr_head(): @@ -14,6 +15,9 @@ def test_ocr_head(): if torch.cuda.is_available(): head, inputs = to_cuda(ocr_head, inputs) head, inputs = to_cuda(fcn_head, inputs) + elif is_musa_available(): + head, inputs = to_musa(ocr_head, inputs) + head, inputs = to_musa(fcn_head, inputs) prev_output = fcn_head(inputs) output = ocr_head(inputs, prev_output) assert output.shape == (1, ocr_head.num_classes, 23, 23) diff --git a/tests/test_models/test_heads/test_psa_head.py b/tests/test_models/test_heads/test_psa_head.py index 34f592b026..0b627fe706 100644 --- a/tests/test_models/test_heads/test_psa_head.py +++ b/tests/test_models/test_heads/test_psa_head.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import PSAHead -from .utils import _conv_has_norm, to_cuda +from .utils import _conv_has_norm, to_cuda, to_musa def test_psa_head(): @@ -37,6 +38,8 @@ def test_psa_head(): in_channels=4, channels=2, num_classes=19, mask_size=(13, 13)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 13, 13) @@ -50,6 +53,8 @@ def test_psa_head(): shrink_factor=1) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 13, 13) @@ -63,6 +68,8 @@ def test_psa_head(): psa_softmax=True) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 13, 13) @@ -76,6 +83,8 @@ def test_psa_head(): psa_type='collect') if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 13, 13) @@ -90,6 +99,8 @@ def test_psa_head(): psa_type='collect') if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 13, 13) @@ -105,6 +116,8 @@ def test_psa_head(): compact=True) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 13, 13) @@ -118,5 +131,7 @@ def test_psa_head(): psa_type='distribute') if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 13, 13) diff --git a/tests/test_models/test_heads/test_psp_head.py b/tests/test_models/test_heads/test_psp_head.py index fde4087c8e..c5167d5a61 100644 --- a/tests/test_models/test_heads/test_psp_head.py +++ b/tests/test_models/test_heads/test_psp_head.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import PSPHead -from .utils import _conv_has_norm, to_cuda +from .utils import _conv_has_norm, to_cuda, to_musa def test_psp_head(): @@ -29,6 +30,8 @@ def test_psp_head(): in_channels=4, channels=2, num_classes=19, pool_scales=(1, 2, 3)) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) assert head.psp_modules[0][0].output_size == 1 assert head.psp_modules[1][0].output_size == 2 assert head.psp_modules[2][0].output_size == 3 diff --git a/tests/test_models/test_heads/test_san_head.py b/tests/test_models/test_heads/test_san_head.py index af85a6e2ca..b808f477ca 100644 --- a/tests/test_models/test_heads/test_san_head.py +++ b/tests/test_models/test_heads/test_san_head.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch from mmengine import Config +from mmengine.device.utils import is_musa_available from mmengine.structures import PixelData from mmseg.models.decode_heads import SideAdapterCLIPHead from mmseg.structures import SegDataSample -from .utils import list_to_cuda +from .utils import list_to_cuda, list_to_musa def test_san_head(): @@ -113,6 +114,11 @@ def test_san_head(): data = list_to_cuda([inputs, clip_feature, class_embed]) for data_sample in data_samples: data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda() + elif is_musa_available(): + head = head.musa() + data = list_to_musa([inputs, clip_feature, class_embed]) + for data_sample in data_samples: + data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.musa() else: data = [inputs, clip_feature, class_embed] diff --git a/tests/test_models/test_heads/test_segmenter_mask_head.py b/tests/test_models/test_heads/test_segmenter_mask_head.py index 7b681ac15c..4642951033 100644 --- a/tests/test_models/test_heads/test_segmenter_mask_head.py +++ b/tests/test_models/test_heads/test_segmenter_mask_head.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import SegmenterMaskTransformerHead -from .utils import _conv_has_norm, to_cuda +from .utils import _conv_has_norm, to_cuda, to_musa def test_segmenter_mask_transformer_head(): @@ -20,5 +21,7 @@ def test_segmenter_mask_transformer_head(): inputs = [torch.randn(1, 2, 32, 32)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 32, 32) diff --git a/tests/test_models/test_heads/test_setr_mla_head.py b/tests/test_models/test_heads/test_setr_mla_head.py index 301bc0bff4..c8f281dbc1 100644 --- a/tests/test_models/test_heads/test_setr_mla_head.py +++ b/tests/test_models/test_heads/test_setr_mla_head.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import SETRMLAHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_setr_mla_head(capsys): @@ -47,6 +48,8 @@ def test_setr_mla_head(capsys): ] if torch.cuda.is_available(): head, x = to_cuda(head, x) + elif is_musa_available(): + head, x = to_musa(head, x) out = head(x) assert out.shape == (1, head.num_classes, h * 4, w * 4) @@ -59,5 +62,7 @@ def test_setr_mla_head(capsys): ] if torch.cuda.is_available(): head, x = to_cuda(head, x) + elif is_musa_available: + head, x = to_musa(head, x) out = head(x) assert out.shape == (1, head.num_classes, h * 4, w * 8) diff --git a/tests/test_models/test_heads/test_setr_up_head.py b/tests/test_models/test_heads/test_setr_up_head.py index a05192229c..ab4f287701 100644 --- a/tests/test_models/test_heads/test_setr_up_head.py +++ b/tests/test_models/test_heads/test_setr_up_head.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import SETRUPHead -from .utils import to_cuda +from .utils import to_cuda, to_musa def test_setr_up_head(capsys): @@ -45,6 +46,8 @@ def test_setr_up_head(capsys): x = [torch.randn(1, 4, h, w)] if torch.cuda.is_available(): head, x = to_cuda(head, x) + elif is_musa_available(): + head, x = to_musa(head, x) out = head(x) assert out.shape == (1, head.num_classes, h * 4, w * 4) @@ -52,5 +55,7 @@ def test_setr_up_head(capsys): x = [torch.randn(1, 4, h, w * 2)] if torch.cuda.is_available(): head, x = to_cuda(head, x) + elif is_musa_available(): + head, x = to_musa(head, x) out = head(x) assert out.shape == (1, head.num_classes, h * 4, w * 8) diff --git a/tests/test_models/test_heads/test_uper_head.py b/tests/test_models/test_heads/test_uper_head.py index 09456a80c4..1afed44c3d 100644 --- a/tests/test_models/test_heads/test_uper_head.py +++ b/tests/test_models/test_heads/test_uper_head.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.decode_heads import UPerHead -from .utils import _conv_has_norm, to_cuda +from .utils import _conv_has_norm, to_cuda, to_musa def test_uper_head(): @@ -31,5 +32,7 @@ def test_uper_head(): in_channels=[4, 2], channels=2, num_classes=19, in_index=[-2, -1]) if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) + elif is_musa_available(): + head, inputs = to_musa(head, inputs) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 45, 45) diff --git a/tests/test_models/test_heads/utils.py b/tests/test_models/test_heads/utils.py index 7282340155..71fe313777 100644 --- a/tests/test_models/test_heads/utils.py +++ b/tests/test_models/test_heads/utils.py @@ -29,3 +29,20 @@ def list_to_cuda(data): return data else: return data.cuda() + + +def to_musa(module, data): + module = module.to('musa') + if isinstance(data, list): + for i in range(len(data)): + data[i] = data[i].to('musa') + return module, data + + +def list_to_musa(data): + if isinstance(data, list): + for i in range(len(data)): + data[i] = list_to_musa(data[i]) + return data + else: + return data.to('musa') diff --git a/tests/test_models/test_necks/test_ic_neck.py b/tests/test_models/test_necks/test_ic_neck.py index 3d13008b5f..3a920b779f 100644 --- a/tests/test_models/test_necks/test_ic_neck.py +++ b/tests/test_models/test_necks/test_ic_neck.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.device.utils import is_musa_available from mmseg.models.necks import ICNeck from mmseg.models.necks.ic_neck import CascadeFeatureFusion -from ..test_heads.utils import _conv_has_norm, to_cuda +from ..test_heads.utils import _conv_has_norm, to_cuda, to_musa def test_ic_neck(): @@ -28,6 +29,8 @@ def test_ic_neck(): align_corners=False) if torch.cuda.is_available(): neck, inputs = to_cuda(neck, inputs) + elif is_musa_available(): + neck, inputs = to_musa(neck, inputs) outputs = neck(inputs) assert outputs[0].shape == (1, 4, 16, 32) diff --git a/tests/test_models/test_segmentors/utils.py b/tests/test_models/test_segmentors/utils.py index ac31e2b277..e97e2b67ba 100644 --- a/tests/test_models/test_segmentors/utils.py +++ b/tests/test_models/test_segmentors/utils.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmengine.device.utils import is_musa_available from mmengine.optim import OptimWrapper from mmengine.structures import PixelData from torch import nn @@ -114,6 +115,8 @@ def _segmentor_forward_train_test(segmentor): # convert to cuda Tensor if applicable if torch.cuda.is_available(): segmentor = segmentor.cuda() + elif is_musa_available(): + segmentor = segmentor.musa() # check data preprocessor if not hasattr(segmentor, @@ -164,6 +167,8 @@ def _segmentor_predict(segmentor): # convert to cuda Tensor if applicable if torch.cuda.is_available(): segmentor = segmentor.cuda() + elif is_musa_available(): + segmentor = segmentor.musa() # check data preprocessor if not hasattr(segmentor, diff --git a/tests/test_visualization/test_local_visualizer.py b/tests/test_visualization/test_local_visualizer.py index e3b2a88cfb..38afa7ffe4 100644 --- a/tests/test_visualization/test_local_visualizer.py +++ b/tests/test_visualization/test_local_visualizer.py @@ -8,6 +8,7 @@ import mmcv import numpy as np import torch +from mmengine.device.utils import is_musa_available from mmengine.structures import PixelData from mmseg.structures import SegDataSample @@ -73,6 +74,8 @@ def test_add_datasample_forward(gt_sem_seg): if torch.cuda.is_available(): test_add_datasample_forward(gt_sem_seg.cuda()) + elif is_musa_available(): + test_add_datasample_forward(gt_sem_seg.musa()) test_add_datasample_forward(gt_sem_seg) def test_cityscapes_add_datasample(self): @@ -149,6 +152,8 @@ def test_cityscapes_add_datasample_forward(gt_sem_seg): if torch.cuda.is_available(): test_cityscapes_add_datasample_forward(gt_sem_seg.cuda()) + elif is_musa_available(): + test_cityscapes_add_datasample_forward(gt_sem_seg.musa()) test_cityscapes_add_datasample_forward(gt_sem_seg) def _assert_image_and_shape(self, out_file, out_shape): @@ -210,4 +215,6 @@ def test_add_datasample_forward_depth(gt_depth_map): if torch.cuda.is_available(): test_add_datasample_forward_depth(gt_depth_map.cuda()) + elif is_musa_available(): + test_add_datasample_forward_depth(gt_depth_map.musa()) test_add_datasample_forward_depth(gt_depth_map)