Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion assets/requirements/macos.compiled
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This file was autogenerated by uv via the following command:
# uv pip compile assets/ComfyUI/requirements.txt assets/ComfyUI/manager_requirements.txt --emit-index-annotation --emit-index-url --index-strategy unsafe-best-match --python-platform aarch64-apple-darwin --python-version 3.12 --override assets/override.txt --index-url https://pypi.org/simple -o assets/requirements/macos.compiled
--index-url https://pypi.org/simple
--extra-index-url https://download.pytorch.org/whl/cu129

aiohappyeyeballs==2.4.3
# via aiohttp
Expand Down
79 changes: 67 additions & 12 deletions src/virtualEnvironment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type TorchPackageVersions = Record<TorchPackageName, string | undefined>;

const TORCH_PACKAGE_NAMES: TorchPackageName[] = ['torch', 'torchaudio', 'torchvision'];

const TORCH_MIRROR_HOSTNAME = 'download.pytorch.org';
const TORCH_MIRROR_PATHS = [/^\/whl\/cu\d+$/, /^\/whl\/nightly\/cu\d+$/, /^\/whl\/nightly\/cpu$/];

export function getPipInstallArgs(config: PipInstallConfig): string[] {
const installArgs = ['pip', 'install'];

Expand Down Expand Up @@ -117,6 +120,55 @@ function fixDeviceMirrorMismatch(device: TorchDeviceType, mirror: string | undef
return mirror;
}

export function getTorchInstallConfig({
packages,
torchMirror,
pypiMirror,
fallbackIndexUrls,
upgradePackages,
}: {
packages: string[];
torchMirror: string;
pypiMirror?: string;
fallbackIndexUrls?: string[];
upgradePackages?: boolean;
}): PipInstallConfig {
const prerelease = torchMirror.includes('nightly');

if (!isTorchIndexUrl(torchMirror)) {
return {
packages,
indexUrl: torchMirror,
prerelease,
upgradePackages,
};
}

const primaryIndex = pypiMirror ?? TorchMirrorUrl.Default;
const extraIndexUrls = [torchMirror, ...(fallbackIndexUrls ?? [])].filter(
(url, index, urls) => url !== primaryIndex && urls.indexOf(url) === index
);

return {
packages,
indexUrl: primaryIndex,
extraIndexUrls: extraIndexUrls.length > 0 ? extraIndexUrls : undefined,
prerelease,
upgradePackages,
};
}

function isTorchIndexUrl(mirrorUrl: string): boolean {
try {
const parsed = new URL(mirrorUrl);
if (parsed.hostname !== TORCH_MIRROR_HOSTNAME) return false;
const normalizedPath = parsed.pathname.replace(/\/+$/, '');
return TORCH_MIRROR_PATHS.some((pattern) => pattern.test(normalizedPath));
} catch {
return false;
}
}

/**
* Manages a virtual Python environment using uv.
*
Expand Down Expand Up @@ -639,11 +691,12 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor {
}

const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice);
const config: PipInstallConfig = {
const config = getTorchInstallConfig({
packages: ['torch', 'torchvision', 'torchaudio'],
indexUrl: torchMirror,
prerelease: torchMirror.includes('nightly'),
};
torchMirror,
pypiMirror: this.pypiMirror,
fallbackIndexUrls: this.getPypiFallbackIndexUrls(),
});

const installArgs = getPipInstallArgs(config);

Expand All @@ -669,11 +722,12 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor {
}

const torchMirror = this.torchMirror || getDefaultTorchMirror(this.selectedDevice);
const config: PipInstallConfig = {
const config = getTorchInstallConfig({
packages: NVIDIA_TORCH_PACKAGES,
indexUrl: torchMirror,
prerelease: torchMirror.includes('nightly'),
};
torchMirror,
pypiMirror: this.pypiMirror,
fallbackIndexUrls: this.getPypiFallbackIndexUrls(),
});

const installArgs = getPipInstallArgs(config);
log.info('Installing recommended NVIDIA PyTorch packages.', { installedVersions });
Expand All @@ -685,12 +739,13 @@ export class VirtualEnvironment implements HasTelemetry, PythonExecutor {
exitCode: pinnedExitCode,
});

const fallbackConfig: PipInstallConfig = {
const fallbackConfig = getTorchInstallConfig({
packages: ['torch', 'torchvision', 'torchaudio'],
indexUrl: torchMirror,
prerelease: torchMirror.includes('nightly'),
torchMirror,
pypiMirror: this.pypiMirror,
fallbackIndexUrls: this.getPypiFallbackIndexUrls(),
upgradePackages: true,
};
});
const fallbackArgs = getPipInstallArgs(fallbackConfig);
const { exitCode: fallbackExitCode } = await this.runUvCommandAsync(fallbackArgs, callbacks);
if (fallbackExitCode !== 0) {
Expand Down
72 changes: 71 additions & 1 deletion tests/unit/virtualEnvironment.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { test as baseTest, describe, expect, vi } from 'vitest';

import { TorchMirrorUrl } from '@/constants';
import type { ITelemetry } from '@/services/telemetry';
import { VirtualEnvironment, getPipInstallArgs } from '@/virtualEnvironment';
import { VirtualEnvironment, getPipInstallArgs, getTorchInstallConfig } from '@/virtualEnvironment';

vi.mock('@sentry/electron/main', () => ({
init: vi.fn(),
Expand Down Expand Up @@ -179,6 +179,76 @@ describe('VirtualEnvironment', () => {
});
});

describe('getTorchInstallConfig', () => {
test('uses PyPI as the primary index with torch mirrors as extra indexes', () => {
const config = getTorchInstallConfig({
packages: ['torch'],
torchMirror: TorchMirrorUrl.Cuda,
pypiMirror: TorchMirrorUrl.Default,
fallbackIndexUrls: ['https://mirror.example/simple/', TorchMirrorUrl.Default],
});

expect(config).toEqual({
packages: ['torch'],
indexUrl: TorchMirrorUrl.Default,
extraIndexUrls: [TorchMirrorUrl.Cuda, 'https://mirror.example/simple/'],
prerelease: false,
upgradePackages: undefined,
});
});

test('uses custom mirror as the primary index when it is not a known torch mirror', () => {
const config = getTorchInstallConfig({
packages: ['torch'],
torchMirror: 'https://custom.example/simple/',
pypiMirror: TorchMirrorUrl.Default,
fallbackIndexUrls: ['https://mirror.example/simple/'],
});

expect(config).toEqual({
packages: ['torch'],
indexUrl: 'https://custom.example/simple/',
extraIndexUrls: undefined,
prerelease: false,
upgradePackages: undefined,
});
});

test('marks nightly mirrors as prerelease and keeps PyPI primary', () => {
const config = getTorchInstallConfig({
packages: ['torch'],
torchMirror: TorchMirrorUrl.NightlyCpu,
pypiMirror: 'https://pypi.example/simple/',
fallbackIndexUrls: ['https://mirror.example/simple/'],
});

expect(config).toEqual({
packages: ['torch'],
indexUrl: 'https://pypi.example/simple/',
extraIndexUrls: [TorchMirrorUrl.NightlyCpu, 'https://mirror.example/simple/'],
prerelease: true,
upgradePackages: undefined,
});
});

test('treats older CUDA mirrors as torch indexes', () => {
const config = getTorchInstallConfig({
packages: ['torch'],
torchMirror: 'https://download.pytorch.org/whl/cu118',
pypiMirror: TorchMirrorUrl.Default,
fallbackIndexUrls: ['https://mirror.example/simple/'],
});

expect(config).toEqual({
packages: ['torch'],
indexUrl: TorchMirrorUrl.Default,
extraIndexUrls: ['https://download.pytorch.org/whl/cu118', 'https://mirror.example/simple/'],
prerelease: false,
upgradePackages: undefined,
});
});
});

describe('hasRequirements', () => {
test('returns OK when all packages are installed', async ({ virtualEnv }) => {
mockSpawnOutputOnce('Would make no changes\n');
Expand Down
Loading