Skip to content

Support Moore Threads MUSA GPU #3841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mmseg/apis/mmseg_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.device.utils import is_musa_available
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
Expand Down Expand Up @@ -81,7 +82,8 @@ def __init__(self,
super().__init__(
model=model, weights=weights, device=device, scope=scope)

if device == 'cpu' or not torch.cuda.is_available():
if device == 'cpu' or (not torch.cuda.is_available()
and not is_musa_available()):
self.model = revert_sync_batchnorm(self.model)

assert isinstance(self.visualizer, SegLocalVisualizer)
Expand Down
7 changes: 6 additions & 1 deletion mmseg/models/assigners/hungarian_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@

import torch
from mmengine import ConfigDict
from mmengine.device.utils import is_musa_available
from mmengine.structures import InstanceData
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast

if is_musa_available():
from torch_musa.core.amp import autocast
else:
from torch.cuda.amp import autocast

from mmseg.registry import TASK_UTILS
from .base_assigner import BaseAssigner
Expand Down
5 changes: 4 additions & 1 deletion mmseg/models/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from mmengine.device.utils import is_musa_available

from mmseg.registry import MODELS
from .utils import weight_reduce_loss
Expand Down Expand Up @@ -269,7 +270,9 @@ def forward(self,
reduction_override if reduction_override else self.reduction)
if self.use_sigmoid:
num_classes = pred.size(1)
if torch.cuda.is_available() and pred.is_cuda:
if (torch.cuda.is_available()
and pred.is_cuda) or (is_musa_available()
and pred.device == 'musa'):
if target.dim() == 1:
one_hot_target = F.one_hot(
target, num_classes=num_classes + 1)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_models/test_backbones/test_clip_text_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import Config
from mmengine.device.utils import is_musa_available
from mmengine.registry import init_default_scope

from mmseg.models.text_encoder import CLIPTextEncoder
Expand All @@ -23,6 +24,8 @@ def test_clip_text_encoder():
text_encoder = CLIPTextEncoder(**cfg)
if torch.cuda.is_available():
text_encoder = text_encoder.cuda()
elif is_musa_available():
text_encoder = text_encoder.musa()

with torch.no_grad():
class_embeds = text_encoder()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_models/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import torch
import torch.nn as nn
from mmengine.device.utils import is_musa_available
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import PixelData
Expand Down Expand Up @@ -190,6 +191,8 @@ def _test_encoder_decoder_forward(cfg_file):
# convert to cuda Tensor if applicable
if torch.cuda.is_available():
segmentor = segmentor.cuda()
elif is_musa_available():
segmentor = segmentor.musa()
else:
segmentor = revert_sync_batchnorm(segmentor)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_models/test_heads/test_ann_head.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.device.utils import is_musa_available

from mmseg.models.decode_heads import ANNHead
from .utils import to_cuda
from .utils import to_cuda, to_musa


def test_ann_head():
Expand All @@ -16,5 +17,7 @@ def test_ann_head():
project_channels=8)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 21, 21)
7 changes: 6 additions & 1 deletion tests/test_models/test_heads/test_apc_head.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.device.utils import is_musa_available

from mmseg.models.decode_heads import APCHead
from .utils import _conv_has_norm, to_cuda
from .utils import _conv_has_norm, to_cuda, to_musa


def test_apc_head():
Expand Down Expand Up @@ -34,6 +35,8 @@ def test_apc_head():
fusion=True)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.fusion is True
assert head.acm_modules[0].pool_scale == 1
assert head.acm_modules[1].pool_scale == 2
Expand All @@ -51,6 +54,8 @@ def test_apc_head():
fusion=False)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.fusion is False
assert head.acm_modules[0].pool_scale == 1
assert head.acm_modules[1].pool_scale == 2
Expand Down
9 changes: 8 additions & 1 deletion tests/test_models/test_heads/test_aspp_head.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.device.utils import is_musa_available

from mmseg.models.decode_heads import ASPPHead, DepthwiseSeparableASPPHead
from .utils import _conv_has_norm, to_cuda
from .utils import _conv_has_norm, to_cuda, to_musa


def test_aspp_head():
Expand All @@ -29,6 +30,8 @@ def test_aspp_head():
in_channels=8, channels=4, num_classes=19, dilations=(1, 12, 24))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.aspp_modules[0].conv.dilation == (1, 1)
assert head.aspp_modules[1].conv.dilation == (12, 12)
assert head.aspp_modules[2].conv.dilation == (24, 24)
Expand All @@ -49,6 +52,8 @@ def test_dw_aspp_head():
dilations=(1, 12, 24))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.c1_bottleneck is None
assert head.aspp_modules[0].conv.dilation == (1, 1)
assert head.aspp_modules[1].depthwise_conv.dilation == (12, 12)
Expand All @@ -67,6 +72,8 @@ def test_dw_aspp_head():
dilations=(1, 12, 24))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.c1_bottleneck.in_channels == 4
assert head.c1_bottleneck.out_channels == 2
assert head.aspp_modules[0].conv.dilation == (1, 1)
Expand Down
17 changes: 16 additions & 1 deletion tests/test_models/test_heads/test_decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import pytest
import torch
from mmengine.device.utils import is_musa_available
from mmengine.structures import PixelData

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.structures import SegDataSample
from .utils import to_cuda
from .utils import to_cuda, to_musa


@patch.multiple(BaseDecodeHead, __abstractmethods__=set())
Expand Down Expand Up @@ -70,6 +71,8 @@ def test_decode_head():
head = BaseDecodeHead(32, 16, num_classes=19)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.in_channels == 32
assert head.input_transform is None
transformed_inputs = head._transform_inputs(inputs)
Expand All @@ -84,6 +87,8 @@ def test_decode_head():
input_transform='resize_concat')
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.in_channels == 48
assert head.input_transform == 'resize_concat'
transformed_inputs = head._transform_inputs(inputs)
Expand All @@ -108,6 +113,8 @@ def test_decode_head():
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
loss = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)
assert 'loss_ce' in loss
Expand All @@ -128,6 +135,8 @@ def test_decode_head():
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)

loss = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)
Expand Down Expand Up @@ -155,6 +164,8 @@ def test_decode_head():
dict(type='CrossEntropyLoss', loss_name='loss_3')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
loss = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)
assert 'loss_1' in loss
Expand All @@ -176,6 +187,8 @@ def test_decode_head():
dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
loss_3 = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)

Expand All @@ -186,6 +199,8 @@ def test_decode_head():
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
loss = head.loss_by_feat(
seg_logits=inputs, batch_data_samples=data_samples)
assert 'loss_ce' in loss
Expand Down
7 changes: 6 additions & 1 deletion tests/test_models/test_heads/test_dm_head.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.device.utils import is_musa_available

from mmseg.models.decode_heads import DMHead
from .utils import _conv_has_norm, to_cuda
from .utils import _conv_has_norm, to_cuda, to_musa


def test_dm_head():
Expand Down Expand Up @@ -34,6 +35,8 @@ def test_dm_head():
fusion=True)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.fusion is True
assert head.dcm_modules[0].filter_size == 1
assert head.dcm_modules[1].filter_size == 3
Expand All @@ -51,6 +54,8 @@ def test_dm_head():
fusion=False)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
assert head.fusion is False
assert head.dcm_modules[0].filter_size == 1
assert head.dcm_modules[1].filter_size == 3
Expand Down
11 changes: 10 additions & 1 deletion tests/test_models/test_heads/test_dnl_head.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.device.utils import is_musa_available

from mmseg.models.decode_heads import DNLHead
from .utils import to_cuda
from .utils import to_cuda, to_musa


def test_dnl_head():
Expand All @@ -14,6 +15,8 @@ def test_dnl_head():
inputs = [torch.randn(1, 8, 23, 23)]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 23, 23)

Expand All @@ -23,6 +26,8 @@ def test_dnl_head():
inputs = [torch.randn(1, 8, 23, 23)]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 23, 23)

Expand All @@ -31,6 +36,8 @@ def test_dnl_head():
inputs = [torch.randn(1, 8, 23, 23)]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 23, 23)

Expand All @@ -40,5 +47,7 @@ def test_dnl_head():
inputs = [torch.randn(1, 8, 23, 23)]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 23, 23)
5 changes: 4 additions & 1 deletion tests/test_models/test_heads/test_ema_head.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.device.utils import is_musa_available

from mmseg.models.decode_heads import EMAHead
from .utils import to_cuda
from .utils import to_cuda, to_musa


def test_emanet_head():
Expand All @@ -19,5 +20,7 @@ def test_emanet_head():
inputs = [torch.randn(1, 4, 23, 23)]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
elif is_musa_available():
head, inputs = to_musa(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 23, 23)
Loading