Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(improve-api-endpoints): Added Datasets and Annotation APIs #12237

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions api/controllers/service_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
api = ExternalApi(bp)

from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, segment, upload_file
from .app import annotation, app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, segment
from .workspace import models
107 changes: 107 additions & 0 deletions api/controllers/service_api/app/annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden

from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
)
from libs.login import current_user
from models.model import App, EndUser
from services.annotation_service import AppAnnotationService


class AnnotationReplyActionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, action):
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200


class AnnotationReplyActionStatusApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, job_id, action):
job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job is not exist.")

job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()

return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200


class AnnotationListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)

annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation


class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()

annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def delete(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()

annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 200


api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>")
api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
api.add_resource(AnnotationListApi, "/apps/annotations")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/<uuid:annotation_id>")
150 changes: 148 additions & 2 deletions api/controllers/service_api/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask import request
from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import Forbidden, NotFound

import services.dataset_service
from controllers.service_api import api
Expand All @@ -11,7 +11,7 @@
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetService
from services.dataset_service import DatasetPermissionService, DatasetService


def _validate_name(name):
Expand All @@ -20,6 +20,12 @@ def _validate_name(name):
return name


def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description


class DatasetListApi(DatasetApiResource):
"""Resource for datasets."""

Expand Down Expand Up @@ -135,6 +141,145 @@ def post(self, tenant_id):
class DatasetApi(DatasetApiResource):
"""Resource for dataset."""

def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})

# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)

embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)

model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")

if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
else:
data["embedding_available"] = False
else:
data["embedding_available"] = True

if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})

return data, 200

def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")

parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")

parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)

parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)

parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()

# check embedding model setting
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)

# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)

dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)

if dataset is None:
raise NotFound("Dataset not found.")

result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id

if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)

partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})

return result_data, 200

def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
Expand All @@ -155,6 +300,7 @@ def delete(self, _, dataset_id):

try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
else:
raise NotFound("Dataset not found.")
Expand Down
21 changes: 21 additions & 0 deletions api/controllers/service_api/workspace/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore

from controllers.service_api import api
from controllers.service_api.wraps import validate_dataset_token
from core.model_runtime.utils.encoders import jsonable_encoder
from services.model_provider_service import ModelProviderService


class ModelProviderAvailableModelApi(Resource):
@validate_dataset_token
def get(self, _, model_type):
tenant_id = current_user.current_tenant_id

model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

return jsonable_encoder({"data": models})


api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
21 changes: 21 additions & 0 deletions api/controllers/service_api/wraps.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ def decorated_view(*args, **kwargs):
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")

tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")

kwargs["app_model"] = app_model

if fetch_user_arg:
Expand Down
Loading