feat: support modifying document metadata via the dataset API #13116

Open · wants to merge 2 commits into base: main
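This PR teaches the dataset service API's document create and update endpoints to accept two new optional JSON fields, doc_type and doc_metadata, which must be supplied together and are validated against DocumentService.DOCUMENT_METADATA_SCHEMA. A minimal client sketch of the new request shape (base URL, endpoint path, API key, and schema keys are illustrative assumptions, not taken from this diff):

import requests

API_BASE = "https://api.example.com/v1"  # assumption: your service API base URL
headers = {
    "Authorization": "Bearer dataset-xxxxxx",  # placeholder dataset API key
    "Content-Type": "application/json",
}
payload = {
    "name": "annual_report.txt",
    "text": "Revenue grew 12% year over year...",
    "indexing_technique": "high_quality",
    # new in this PR: both fields must be sent together
    "doc_type": "book",  # must be a key of DOCUMENT_METADATA_SCHEMA
    "doc_metadata": {"title": "Annual Report 2024", "author": "Finance Team"},
}
# endpoint name is an assumption modeled on typical create-by-text routes
resp = requests.post(
    f"{API_BASE}/datasets/<dataset_id>/document/create-by-text",
    headers=headers,
    json=payload,
)
print(resp.status_code, resp.json())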
98 changes: 98 additions & 0 deletions api/controllers/service_api/dataset/document.py
@@ -18,6 +18,7 @@
from controllers.service_api.dataset.error import (
    ArchivedDocumentImmutableError,
    DocumentIndexingError,
    InvalidMetadataError,
)
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from core.errors.error import ProviderTokenNotInitError
@@ -50,6 +51,9 @@ def post(self, tenant_id, dataset_id):
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("doc_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("doc_metadata", type=dict, required=False, nullable=True, location="json")

args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@@ -61,6 +65,28 @@ def post(self, tenant_id, dataset_id):
if not dataset.indexing_technique and not args["indexing_technique"]:
    raise ValueError("indexing_technique is required.")

# Validate metadata if provided
if args.get("doc_type") or args.get("doc_metadata"):
    if not args.get("doc_type") or not args.get("doc_metadata"):
        raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")

    if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
        raise InvalidMetadataError(
            "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
        )

    if not isinstance(args["doc_metadata"], dict):
        raise InvalidMetadataError("doc_metadata must be a dictionary")

    # Validate metadata fields against the schema for this doc_type
    if args["doc_type"] != "others":
        metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
        for key, value in args["doc_metadata"].items():
            if key in metadata_schema and not isinstance(value, metadata_schema[key]):
                raise InvalidMetadataError(f"Invalid type for metadata field {key}")

    # pack the validated pair so the service layer can parse it into a MetaDataConfig
    args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}

text = args.get("text")
name = args.get("name")
if text is None or name is None:
Expand Down Expand Up @@ -107,6 +133,8 @@ def post(self, tenant_id, dataset_id, document_id):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("doc_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("doc_metadata", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@@ -115,6 +143,29 @@ def post(self, tenant_id, dataset_id, document_id):
if not dataset:
    raise ValueError("Dataset does not exist.")

# Validate metadata if provided
if args.get("doc_type") or args.get("doc_metadata"):
    if not args.get("doc_type") or not args.get("doc_metadata"):
        raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")

    if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
        raise InvalidMetadataError(
            "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
        )

    if not isinstance(args["doc_metadata"], dict):
        raise InvalidMetadataError("doc_metadata must be a dictionary")

    # Validate metadata fields against the schema for this doc_type
    if args["doc_type"] != "others":
        metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
        for key, value in args["doc_metadata"].items():
            if key in metadata_schema and not isinstance(value, metadata_schema[key]):
                raise InvalidMetadataError(f"Invalid type for metadata field {key}")

    # pack the validated pair so the service layer can parse it into a MetaDataConfig
    args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}

if args["text"]:
text = args.get("text")
name = args.get("name")
@@ -161,6 +212,30 @@ def post(self, tenant_id, dataset_id):
args["doc_form"] = "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"

# Validate metadata if provided
if args.get("doc_type") or args.get("doc_metadata"):
    if not args.get("doc_type") or not args.get("doc_metadata"):
        raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")

    if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
        raise InvalidMetadataError(
            "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
        )

    if not isinstance(args["doc_metadata"], dict):
        raise InvalidMetadataError("doc_metadata must be a dictionary")

    # Validate metadata fields against the schema for this doc_type
    if args["doc_type"] != "others":
        metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
        for key, value in args["doc_metadata"].items():
            if key in metadata_schema and not isinstance(value, metadata_schema[key]):
                raise InvalidMetadataError(f"Invalid type for metadata field {key}")

    # pack the validated pair so the service layer can parse it into a MetaDataConfig
    args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}

# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@@ -228,6 +303,29 @@ def post(self, tenant_id, dataset_id, document_id):
if "doc_language" not in args:
args["doc_language"] = "English"

# Validate metadata if provided
if args.get("doc_type") or args.get("doc_metadata"):
    if not args.get("doc_type") or not args.get("doc_metadata"):
        raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")

    if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
        raise InvalidMetadataError(
            "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
        )

    if not isinstance(args["doc_metadata"], dict):
        raise InvalidMetadataError("doc_metadata must be a dictionary")

    # Validate metadata fields against the schema for this doc_type
    if args["doc_type"] != "others":
        metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
        for key, value in args["doc_metadata"].items():
            if key in metadata_schema and not isinstance(value, metadata_schema[key]):
                raise InvalidMetadataError(f"Invalid type for metadata field {key}")

    # pack the validated pair so the service layer can parse it into a MetaDataConfig
    args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}

# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
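Downstream of the controller changes above, the parsed args (including the packed args["metadata"]) flow into the service layer below. A hedged sketch of that hand-off (the real call site is abridged; the function name comes from the service code in this diff, the argument list and return shape are assumptions):

knowledge_config = KnowledgeConfig(**args)  # args["metadata"] parses into a MetaDataConfig
documents, batch = DocumentService.save_document_with_dataset_id(
    dataset=dataset,
    knowledge_config=knowledge_config,
    account=current_user,  # assumption: the authenticated account
)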
15 changes: 15 additions & 0 deletions api/services/dataset_service.py
@@ -42,6 +42,7 @@
from services.entities.knowledge_entities.knowledge_entities import (
    ChildChunkUpdateArgs,
    KnowledgeConfig,
    MetaDataConfig,
    RerankingModel,
    RetrievalModel,
    SegmentUpdateArgs,
@@ -894,6 +895,9 @@ def save_document_with_dataset_id(
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
if knowledge_config.metadata:
    document.doc_type = knowledge_config.metadata.doc_type
    document.doc_metadata = knowledge_config.metadata.doc_metadata
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
@@ -910,6 +914,7 @@ def save_document_with_dataset_id(
    account,
    file_name,
    batch,
    knowledge_config.metadata,
)
db.session.add(document)
db.session.flush()
@@ -965,6 +970,7 @@ def save_document_with_dataset_id(
    account,
    page.page_name,
    batch,
    knowledge_config.metadata,
)
db.session.add(document)
db.session.flush()
@@ -1005,6 +1011,7 @@ def save_document_with_dataset_id(
    account,
    document_name,
    batch,
    knowledge_config.metadata,
)
db.session.add(document)
db.session.flush()
@@ -1042,6 +1049,7 @@ def build_document(
    account: Account,
    name: str,
    batch: str,
    metadata: Optional[MetaDataConfig] = None,
):
    document = Document(
        tenant_id=dataset.tenant_id,
@@ -1057,6 +1065,9 @@ def build_document(
        doc_form=document_form,
        doc_language=document_language,
    )
    if metadata is not None:
        document.doc_metadata = metadata.doc_metadata
        document.doc_type = metadata.doc_type
    return document

@staticmethod
@@ -1169,6 +1180,10 @@ def update_document_with_dataset_id(
# update document name
if document_data.name:
    document.name = document_data.name
# update doc_type and doc_metadata if provided
if document_data.metadata is not None:
    document.doc_metadata = document_data.metadata.doc_metadata
    document.doc_type = document_data.metadata.doc_type
# update document to be waiting
document.indexing_status = "waiting"
document.completed_at = None
api/services/entities/knowledge_entities/knowledge_entities.py
@@ -93,6 +93,11 @@ class RetrievalModel(BaseModel):
    score_threshold: Optional[float] = None


class MetaDataConfig(BaseModel):
    doc_type: str
    doc_metadata: dict


class KnowledgeConfig(BaseModel):
    original_document_id: Optional[str] = None
    duplicate: bool = True
@@ -105,6 +110,7 @@ class KnowledgeConfig(BaseModel):
    embedding_model: Optional[str] = None
    embedding_model_provider: Optional[str] = None
    name: Optional[str] = None
    metadata: Optional[MetaDataConfig] = None


class SegmentUpdateArgs(BaseModel):
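Taken together, a self-contained sketch of how the controllers' packed dict round-trips through the new pydantic models (uses only the definitions added in this diff; the sample values are illustrative):

from typing import Optional
from pydantic import BaseModel

class MetaDataConfig(BaseModel):
    doc_type: str
    doc_metadata: dict

class KnowledgeConfig(BaseModel):
    original_document_id: Optional[str] = None
    duplicate: bool = True
    name: Optional[str] = None
    metadata: Optional[MetaDataConfig] = None

# mirrors what the controllers build in args["metadata"]
config = KnowledgeConfig(
    name="annual_report.txt",
    metadata={"doc_type": "book", "doc_metadata": {"title": "Annual Report 2024"}},
)
assert config.metadata is not None
print(config.metadata.doc_type)      # book
print(config.metadata.doc_metadata)  # {'title': 'Annual Report 2024'}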