diff --git a/src/fastapi_cli/cli.py b/src/fastapi_cli/cli.py index 28afa297..73a94f61 100644 --- a/src/fastapi_cli/cli.py +++ b/src/fastapi_cli/cli.py @@ -126,6 +126,9 @@ def _run( module_data = import_data.module_data import_string = import_data.import_string + openapi_url = import_data.openapi_url + docs_url = import_data.docs_url + redoc_url = import_data.redoc_url toolkit.print(f"Importing from {module_data.extra_sys_path}") toolkit.print_line() @@ -152,15 +155,30 @@ def _run( ) url = f"http://{host}:{port}" - url_docs = f"{url}/docs" + docs_str = "" + + if openapi_url and (docs_url or redoc_url): + if docs_url: + docs_str += f"[link={url}{docs_url}]{url}{docs_url}[/]" + + if docs_url and redoc_url: + docs_str += " or " + + if redoc_url: + docs_str += f"[link={url}{redoc_url}]{url}{redoc_url}[/]" toolkit.print_line() toolkit.print( f"Server started at [link={url}]{url}[/]", - f"Documentation at [link={url_docs}]{url_docs}[/]", tag="server", ) + if docs_str: + toolkit.print( + f"Documentation at {docs_str}", + tag="server", + ) + if command == "dev": toolkit.print_line() toolkit.print( diff --git a/src/fastapi_cli/discover.py b/src/fastapi_cli/discover.py index b174f8fb..bb7d2f2b 100644 --- a/src/fastapi_cli/discover.py +++ b/src/fastapi_cli/discover.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from logging import getLogger from pathlib import Path -from typing import List, Union +from typing import List, Tuple, Union from fastapi_cli.exceptions import FastAPICLIException @@ -45,12 +45,16 @@ class ModuleData: def get_module_data_from_path(path: Path) -> ModuleData: use_path = path.resolve() module_path = use_path + if use_path.is_file() and use_path.stem == "__init__": module_path = use_path.parent + module_paths = [module_path] extra_sys_path = module_path.parent + for parent in module_path.parents: init_path = parent / "__init__.py" + if init_path.is_file(): module_paths.insert(0, parent) extra_sys_path = parent.parent @@ -58,6 +62,7 @@ def get_module_data_from_path(path: Path) -> ModuleData: break module_str = ".".join(p.stem for p in module_paths) + return ModuleData( module_import_str=module_str, extra_sys_path=extra_sys_path.resolve(), @@ -65,7 +70,9 @@ def get_module_data_from_path(path: Path) -> ModuleData: ) -def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) -> str: +def get_app_infos( + *, mod_data: ModuleData, app_name: Union[str, None] = None +) -> Tuple[str, Union[str, None], Union[str, None], Union[str, None]]: try: mod = importlib.import_module(mod_data.module_import_str) except (ImportError, ValueError) as e: @@ -74,32 +81,41 @@ def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) -> "Ensure all the package directories have an [blue]__init__.py[/blue] file" ) raise + if not FastAPI: # type: ignore[truthy-function] raise FastAPICLIException( "Could not import FastAPI, try running 'pip install fastapi'" ) from None + object_names = dir(mod) object_names_set = set(object_names) + if app_name: if app_name not in object_names_set: raise FastAPICLIException( f"Could not find app name {app_name} in {mod_data.module_import_str}" ) + app = getattr(mod, app_name) + if not isinstance(app, FastAPI): raise FastAPICLIException( f"The app name {app_name} in {mod_data.module_import_str} doesn't seem to be a FastAPI app" ) - return app_name + + return app_name, app.openapi_url, app.docs_url, app.redoc_url + for preferred_name in ["app", "api"]: if preferred_name in object_names_set: obj = getattr(mod, preferred_name) if isinstance(obj, FastAPI): - return preferred_name + return preferred_name, obj.openapi_url, obj.docs_url, obj.redoc_url + for name in object_names: obj = getattr(mod, name) if isinstance(obj, FastAPI): - return name + return name, obj.openapi_url, obj.docs_url, obj.redoc_url + raise FastAPICLIException("Could not find FastAPI app in module, try using --app") @@ -108,6 +124,9 @@ class ImportData: app_name: str module_data: ModuleData import_string: str + openapi_url: Union[str, None] = None + docs_url: Union[str, None] = None + redoc_url: Union[str, None] = None def get_import_data( @@ -121,14 +140,22 @@ def get_import_data( if not path.exists(): raise FastAPICLIException(f"Path does not exist {path}") + mod_data = get_module_data_from_path(path) sys.path.insert(0, str(mod_data.extra_sys_path)) - use_app_name = get_app_name(mod_data=mod_data, app_name=app_name) + use_app_name, openapi_url, docs_url, redoc_url = get_app_infos( + mod_data=mod_data, app_name=app_name + ) import_string = f"{mod_data.module_import_str}:{use_app_name}" return ImportData( - app_name=use_app_name, module_data=mod_data, import_string=import_string + app_name=use_app_name, + module_data=mod_data, + import_string=import_string, + openapi_url=openapi_url, + docs_url=docs_url, + redoc_url=redoc_url, ) @@ -144,12 +171,21 @@ def get_import_data_from_import_string(import_string: str) -> ImportData: sys.path.insert(0, str(here)) + module_data = ModuleData( + module_import_str=module_str, + extra_sys_path=here, + module_paths=[], + ) + + _, openapi_url, docs_url, redoc_url = get_app_infos( + mod_data=module_data, app_name=app_name + ) + return ImportData( app_name=app_name, - module_data=ModuleData( - module_import_str=module_str, - extra_sys_path=here, - module_paths=[], - ), + module_data=module_data, import_string=import_string, + openapi_url=openapi_url, + docs_url=docs_url, + redoc_url=redoc_url, ) diff --git a/tests/assets/single_file_docs.py b/tests/assets/single_file_docs.py new file mode 100644 index 00000000..d074804a --- /dev/null +++ b/tests/assets/single_file_docs.py @@ -0,0 +1,48 @@ +from fastapi import FastAPI + +no_openapi = FastAPI(openapi_url=None) + + +@no_openapi.get("/") +def no_openapi_root(): + return {"message": "single file no_openapi"} + + +none_docs = FastAPI(docs_url=None, redoc_url=None) + + +@none_docs.get("/") +def none_docs_root(): + return {"message": "single file none_docs"} + + +no_docs = FastAPI(docs_url=None) + + +@no_docs.get("/") +def no_docs_root(): + return {"message": "single file no_docs"} + + +no_redoc = FastAPI(redoc_url=None) + + +@no_redoc.get("/") +def no_redoc_root(): + return {"message": "single file no_redoc"} + + +full_docs = FastAPI() + + +@full_docs.get("/") +def full_docs_root(): + return {"message": "single file full_docs"} + + +custom_docs = FastAPI(docs_url="/custom-docs-url", redoc_url="/custom-redoc-url") + + +@custom_docs.get("/") +def custom_docs_root(): + return {"message": "single file custom_docs"} diff --git a/tests/test_cli.py b/tests/test_cli.py index b87a811a..c457c602 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -349,6 +349,84 @@ def test_run_env_vars_and_args() -> None: assert "Documentation at http://0.0.0.0:8080/docs" in result.output +def test_no_openapi() -> None: + with changing_dir(assets_path): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke( + app, ["dev", "single_file_docs.py", "--app", "no_openapi"] + ) + assert result.exit_code == 0, result.output + assert mock_run.called + + assert "http://127.0.0.1:8000/docs" not in result.output + assert "http://127.0.0.1:8000/redoc" not in result.output + + +def test_none_docs() -> None: + with changing_dir(assets_path): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke( + app, ["dev", "single_file_docs.py", "--app", "none_docs"] + ) + assert result.exit_code == 0, result.output + assert mock_run.called + + assert "http://127.0.0.1:8000/docs" not in result.output + assert "http://127.0.0.1:8000/redoc" not in result.output + + +def test_no_docs() -> None: + with changing_dir(assets_path): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke( + app, ["dev", "single_file_docs.py", "--app", "no_docs"] + ) + assert result.exit_code == 0, result.output + assert mock_run.called + + assert "http://127.0.0.1:8000/redoc" in result.output + assert "http://127.0.0.1:8000/docs" not in result.output + + +def test_no_redoc() -> None: + with changing_dir(assets_path): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke( + app, ["dev", "single_file_docs.py", "--app", "no_redoc"] + ) + assert result.exit_code == 0, result.output + assert mock_run.called + + assert "http://127.0.0.1:8000/docs" in result.output + assert "http://127.0.0.1:8000/redocs" not in result.output + + +def test_full_docs() -> None: + with changing_dir(assets_path): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke( + app, ["dev", "single_file_docs.py", "--app", "full_docs"] + ) + assert result.exit_code == 0, result.output + assert mock_run.called + + assert "http://127.0.0.1:8000/docs" in result.output + assert "http://127.0.0.1:8000/redoc" in result.output + + +def test_custom_docs() -> None: + with changing_dir(assets_path): + with patch.object(uvicorn, "run") as mock_run: + result = runner.invoke( + app, ["dev", "single_file_docs.py", "--app", "custom_docs"] + ) + assert result.exit_code == 0, result.output + assert mock_run.called + + assert "http://127.0.0.1:8000/custom-docs-url" in result.output + assert "http://127.0.0.1:8000/custom-redoc-url" in result.output + + def test_run_error() -> None: with changing_dir(assets_path): result = runner.invoke(app, ["run", "non_existing_file.py"]) diff --git a/tests/test_discover.py b/tests/test_discover.py index b1052050..99b5a4fe 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -7,18 +7,20 @@ get_import_data_from_import_string, ) from fastapi_cli.exceptions import FastAPICLIException +from tests.utils import changing_dir assets_path = Path(__file__).parent / "assets" def test_get_import_data_from_import_string_valid() -> None: - result = get_import_data_from_import_string("module.submodule:app") + with changing_dir(assets_path): + result = get_import_data_from_import_string("package.mod.app:app") assert isinstance(result, ImportData) assert result.app_name == "app" - assert result.import_string == "module.submodule:app" - assert result.module_data.module_import_str == "module.submodule" - assert result.module_data.extra_sys_path == Path(".").resolve() + assert result.import_string == "package.mod.app:app" + assert result.module_data.module_import_str == "package.mod.app" + assert result.module_data.extra_sys_path == Path(assets_path).resolve() assert result.module_data.module_paths == []