Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add gpustack image tools #13042

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/core/tools/provider/_position.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- cogview
- comfyui
- getimgai
- gpustack
- siliconflow
- spark
- stepfun
Expand Down
14 changes: 14 additions & 0 deletions api/core/tools/provider/builtin/gpustack/_assets/icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 26 additions & 0 deletions api/core/tools/provider/builtin/gpustack/gpustack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import requests

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class GPUStackProvider(BuiltinToolProviderController):
    """Built-in tool provider that talks to a GPUStack server."""

    def _validate_credentials(self, credentials: dict) -> None:
        """Validate credentials by listing models on the GPUStack server.

        Args:
            credentials: Provider credentials. Reads ``base_url``, ``api_key``
                and the optional ``tls_verify`` flag (defaults to ``True``).

        Raises:
            ToolProviderCredentialValidationError: If ``base_url`` or
                ``api_key`` is missing, the server cannot be reached, or the
                models endpoint answers with a status other than 200.
        """
        # `or ""` guards against keys that are present but set to None.
        # A trailing "/" and a trailing "/v1-openai" are stripped so that the
        # endpoint path can be appended uniformly below.
        base_url = (credentials.get("base_url") or "").removesuffix("/").removesuffix("/v1-openai")
        api_key = credentials.get("api_key") or ""
        tls_verify = credentials.get("tls_verify", True)

        if not base_url:
            raise ToolProviderCredentialValidationError("GPUStack base_url is required")
        if not api_key:
            raise ToolProviderCredentialValidationError("GPUStack api_key is required")

        headers = {
            "accept": "application/json",
            "authorization": f"Bearer {api_key}",
        }

        try:
            # Without a timeout an unreachable host would block validation
            # indefinitely; listing models is expected to be quick.
            response = requests.get(
                f"{base_url}/v1-openai/models",
                headers=headers,
                verify=tls_verify,
                timeout=10,
            )
        except requests.RequestException as e:
            # Surface connection/TLS/timeout problems as a validation error
            # rather than leaking a raw transport exception to the caller.
            raise ToolProviderCredentialValidationError(
                f"Failed to connect to GPUStack server: {e}"
            ) from e

        if response.status_code != 200:
            raise ToolProviderCredentialValidationError(
                f"Failed to validate GPUStack API key, status code: {response.status_code}-{response.text}"
            )
44 changes: 44 additions & 0 deletions api/core/tools/provider/builtin/gpustack/gpustack.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
identity:
author: gpustack
name: gpustack
label:
en_US: GPUStack
zh_Hans: GPUStack
description:
en_US: GPUStack is an open-source GPU cluster manager for running AI models, providing efficient resource management and model deployment capabilities.
zh_Hans: GPUStack 是一款开源的 GPU 集群管理工具,专为 AI 模型部署和运行而设计,提供高效的资源管理和模型部署能力。
icon: icon.svg
tags:
- image
credentials_for_provider:
base_url:
type: text-input
required: true
label:
en_US: Server URL
zh_Hans: 服务器 URL
placeholder:
en_US: http://your-server-address.com
help:
en_US: Please input GPUStack server's URL
zh_Hans: 请输入 GPUStack 服务器的 URL
api_key:
type: secret-input
required: true
label:
en_US: API Key
zh_Hans: API Key
placeholder:
en_US: Please input your GPUStack API Key
zh_Hans: 请输入你的 GPUStack API Key
url: https://docs.gpustack.ai/latest/user-guide/api-key-management/
tls_verify:
type: boolean
required: false
label:
en_US: TLS Verify
zh_Hans: 证书验证
help:
en_US: Whether to verify the TLS certificate of the GPUStack server.
zh_Hans: 是否验证 GPUStack 服务器的 TLS 证书。
default: true
47 changes: 47 additions & 0 deletions api/core/tools/provider/builtin/gpustack/tools/image_edit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import io
from typing import Any, Union

import requests

from core.file.enums import FileType
from core.file.file_manager import download
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

from .utils import get_base_url, get_common_params, handle_api_error, handle_image_response


class ImageEditTool(BuiltinTool):
    """Tool that sends an image and a prompt to GPUStack's image edit endpoint."""

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Edit an image through ``{base_url}/v1-openai/images/edits``.

        Args:
            user_id: Identifier of the invoking user (not used here).
            tool_parameters: Tool inputs. Reads ``image`` and ``strength``
                (default ``0.75``) directly; the remaining parameters are
                collected by ``get_common_params``.

        Returns:
            Whatever ``handle_image_response`` produces on success, or a text
            message describing the failure.
        """
        image = tool_parameters.get("image")
        # A missing image would otherwise raise AttributeError on `.type`
        # outside the try block below, so treat it like a non-image file.
        if image is None or image.type != FileType.IMAGE:
            return [self.create_text_message("Not a valid image file")]

        try:
            params = get_common_params(tool_parameters)
            params["strength"] = tool_parameters.get("strength", 0.75)

            # The upload is always labelled as PNG, regardless of the source
            # file's actual name or format.
            image_binary = io.BytesIO(download(image))
            files = {"image": ("image.png", image_binary, "image/png")}

            base_url = get_base_url(self.runtime.credentials["base_url"])
            response = requests.post(
                f"{base_url}/v1-openai/images/edits",
                headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
                data=params,
                files=files,
                verify=self.runtime.credentials.get("tls_verify", True),
            )

            if not response.ok:
                return self.create_text_message(handle_api_error(response))

            # Start from an empty list for the helper to work with.
            return handle_image_response([], response, self)

        except ValueError as e:
            # ValueError messages are returned to the user verbatim.
            return self.create_text_message(str(e))
        except Exception as e:
            return self.create_text_message(f"An error occurred: {str(e)}")
181 changes: 181 additions & 0 deletions api/core/tools/provider/builtin/gpustack/tools/image_edit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
identity:
name: image_edit
author: gpustack
label:
en_US: Image Edit
zh_Hans: 图片编辑
icon: icon.svg
description:
human:
en_US: Edit images with GPUStack's image editing model.
zh_Hans: 使用 GPUStack 的图像编辑模型编辑图片。
llm: This tool is used to edit an image.
parameters:
- name: image
type: file
required: true
label:
en_US: Image
zh_Hans: 图片
human_description:
en_US: The image to be edited.
zh_Hans: 要编辑的图片。
llm_description: The image to be edited.
form: llm
- name: prompt
type: string
required: true
label:
en_US: prompt
zh_Hans: 提示词
human_description:
en_US: The text prompt used to edit the image.
zh_Hans: 用于编辑图片的文字提示词
llm_description: this prompt text will be used to edit image.
form: llm
- name: model
type: string
required: true
label:
en_US: Model
zh_Hans: 模型
human_description:
en_US: Name of the image model running in GPUStack.
zh_Hans: 在 GPUStack 上运行的图像模型名称。
form: form
- name: cfg_scale
type: number
required: false
default: 4.5
label:
en_US: CFG Scale
human_description:
en_US: Classifier-free guidance scale, affecting the image's adherence to the prompt.
zh_Hans: 无分类器引导比例,影响图片对 Prompt 的贴合度。
form: form
- name: n
type: number
required: false
default: 1
label:
en_US: Number
zh_Hans: 数量
human_description:
en_US: Number of images to generate.
zh_Hans: 生成图片数量。
form: form
- name: size
type: string
required: true
default: "512x512"
label:
en_US: Image Size
zh_Hans: 图片尺寸
human_description:
en_US: The maximum size of the generated image is controlled by the deployment parameters of the model.
zh_Hans: 图片生成的最大尺寸受控于模型的部署参数。
form: form
- name: sample_method
type: select
required: true
default: euler
options:
- value: euler_a
label:
en_US: euler_a
- value: euler
label:
en_US: euler
- value: heun
label:
en_US: heun
- value: dpm2
label:
en_US: dpm2
- value: dpm++2s_a
label:
en_US: dpm++2s_a
- value: dpm++2m
label:
en_US: dpm++2m
- value: dpm++2mv2
label:
en_US: dpm++2mv2
- value: ipndm
label:
en_US: ipndm
- value: ipndm_v
label:
en_US: ipndm_v
- value: icm
label:
en_US: icm
label:
en_US: Sample Method
zh_Hans: 采样方法
human_description:
en_US: The sample method for the image generation model.
zh_Hans: 图像生成模型的采样方法。
form: form
- name: sampling_steps
type: number
required: false
default: 20
label:
en_US: Sampling Steps
zh_Hans: 采样步数
human_description:
en_US: Number of sampling steps to generate the image.
zh_Hans: 生成图片所需的采样步数。
form: form
- name: guidance
type: number
required: false
default: 4.5
label:
en_US: Guidance
human_description:
en_US: Guidance scale, affecting the quality and diversity of the image.
zh_Hans: 引导比例,影响图片的质量和多样性
form: form
- name: schedule_method
type: select
required: true
default: discrete
options:
- value: discrete
label:
en_US: discrete
- value: karras
label:
en_US: karras
- value: exponential
label:
en_US: exponential
- value: ays
label:
en_US: ays
- value: gits
label:
en_US: gits
label:
en_US: Schedule Method
zh_Hans: 调度方法
form: form
- name: strength
type: number
required: false
default: 0.75
label:
en_US: Strength
zh_Hans: 强度
human_description:
en_US: The higher the value, the greater the modification to the original image.
zh_Hans: 值越高,它对原图的修改越大。
form: form
- name: seed
type: number
required: false
label:
en_US: Seed
form: form
34 changes: 34 additions & 0 deletions api/core/tools/provider/builtin/gpustack/tools/text2image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

from .utils import get_base_url, get_common_params, handle_api_error, handle_image_response


class TextToImageTool(BuiltinTool):
    """Tool that requests image generation from a GPUStack server."""

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Generate images through ``{base_url}/v1-openai/images/generations``.

        Args:
            user_id: Identifier of the invoking user (not used here).
            tool_parameters: Tool inputs, turned into the request payload by
                ``get_common_params``.

        Returns:
            Whatever ``handle_image_response`` produces on success, or a text
            message describing the failure.
        """
        try:
            payload = get_common_params(tool_parameters)
            credentials = self.runtime.credentials
            endpoint = f"{get_base_url(credentials['base_url'])}/v1-openai/images/generations"
            auth_headers = {"Authorization": f"Bearer {credentials['api_key']}"}

            response = requests.post(
                endpoint,
                headers=auth_headers,
                json=payload,
                verify=credentials.get("tls_verify", True),
            )

            if response.ok:
                messages = []
                return handle_image_response(messages, response, self)

            # Non-2xx answer: report the API's error as text.
            return self.create_text_message(handle_api_error(response))

        except ValueError as err:
            # ValueError messages are returned to the user verbatim.
            return self.create_text_message(str(err))
        except Exception as err:
            return self.create_text_message(f"An error occurred: {str(err)}")
Loading
Loading