Fine-tune Qwen3-Embedding for code embeddings using Amazon SageMaker #4846
# inference.py - Custom inference script for embedding operations
import json
import logging

import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EmbeddingInferenceHandler:
    """
    Custom inference handler that supports both encoding and similarity operations.
    """

    def __init__(self):
        self.model = None
        self.device = None

    def model_fn(self, model_dir):
        """
        Load the fine-tuned model from the model directory.
        """
        logger.info("Loading model from: %s", model_dir)

        # Detect device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)

        # Load the sentence transformer model
        self.model = SentenceTransformer(model_dir, device=self.device)

        return self.model

    def input_fn(self, request_body, content_type="application/json"):
        """
        Parse input data for inference.
        """
        if content_type == "application/json":
            # Log only the request size; raw bodies are user-controlled and
            # unsafe to write to logs verbatim (log injection, CWE-117).
            logger.info("Received inference request (%d bytes)", len(request_body))
            data = json.loads(request_body)
        else:
            logger.warning("Unsupported content type: %s", content_type)
            raise ValueError(f"Unsupported content type: {content_type}")

        return data

    def predict_fn(self, data, model):
        """
        Perform inference based on the operation type.
        """
        operation = data.get("operation", "encode")
        # %r escapes control characters, so user-supplied values cannot forge log lines
        logger.info("Requested operation: %r", operation)

        if operation == "encode":
            return self._encode_operation(data, model)
        elif operation == "similarity":
            return self._similarity_operation(data, model)
        else:
            raise ValueError(f"Unsupported operation: {operation}")

    def _encode_operation(self, data, model):
        """
        Encode text inputs to embeddings.
        """
        inputs = data.get("inputs", [])
        if not inputs:
            raise ValueError("No inputs provided for encoding")

        # Get target dimension (default to 512 for Matryoshka)
        target_dim = data.get("dimension", 512)

        # Encode inputs
        embeddings = model.encode(
            inputs,
            batch_size=data.get("batch_size", 32),
            show_progress_bar=False,
            normalize_embeddings=data.get("normalize", True),
        )

        # Truncate to target dimension if specified
        if target_dim and target_dim < embeddings.shape[1]:
            embeddings = embeddings[:, :target_dim]

        return {
            "embeddings": embeddings.tolist(),
            "dimension": embeddings.shape[1],
            "num_texts": len(inputs),
        }

    def _similarity_operation(self, data, model):
        """
        Calculate similarity between text pairs.
        """
        text1 = data.get("text1")
        text2 = data.get("text2")

        if not text1 or not text2:
            raise ValueError("Both text1 and text2 required for similarity")

        # Get target dimension
        target_dim = data.get("dimension", 512)

        # Encode both texts
        embeddings = model.encode([text1, text2], normalize_embeddings=True)

        # Truncate if needed
        if target_dim and target_dim < embeddings.shape[1]:
            embeddings = embeddings[:, :target_dim]

        # Calculate cosine similarity
        similarity = cosine_similarity(
            embeddings[0].reshape(1, -1), embeddings[1].reshape(1, -1)
        )[0][0]

        return {
            "similarity": float(similarity),
            "dimension": embeddings.shape[1],
            "text1_embedding": embeddings[0].tolist(),
            "text2_embedding": embeddings[1].tolist(),
        }

    def output_fn(self, prediction, accept="application/json"):
        """
        Format the prediction output.
        """
        if accept == "application/json":
            return json.dumps(prediction), "application/json"
        else:
            raise ValueError(f"Unsupported accept type: {accept}")


# Global handler instance
handler = EmbeddingInferenceHandler()


# SageMaker inference functions
def model_fn(model_dir):
    return handler.model_fn(model_dir)


def input_fn(request_body, content_type):
    return handler.input_fn(request_body, content_type)


def predict_fn(data, model):
    return handler.predict_fn(data, model)


def output_fn(prediction, accept):
    return handler.output_fn(prediction, accept)
sentence-transformers>=4.1.0
transformers>=4.51.0
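With this environment in place, a deployed endpoint accepts a small JSON contract. The field names below come straight from `input_fn`/`predict_fn` in the handler above; the concrete text values are illustrative only:

```python
import json

# Two request shapes understood by predict_fn: "encode" and "similarity".
encode_request = {
    "operation": "encode",
    "inputs": ["def add(a, b): return a + b"],
    "dimension": 512,   # Matryoshka truncation target
    "normalize": True,
    "batch_size": 32,
}

similarity_request = {
    "operation": "similarity",
    "text1": "def add(a, b): return a + b",
    "text2": "function that adds two numbers",
    "dimension": 512,
}

# The endpoint receives these serialized, exactly as input_fn expects:
body = json.dumps(encode_request)
parsed = json.loads(body)
print(parsed["operation"], len(parsed["inputs"]))  # encode 1
```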
# Create code embedding training dataset
from datasets import Dataset


def create_code_embedding_dataset():
    """
    Create a comprehensive dataset with code-description pairs for embedding fine-tuning.
    This dataset follows contrastive learning principles for code embeddings.
    """
    code_pairs = [
        {
            "text1": "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",
            "text2": "recursive function to calculate fibonacci numbers in Python",
        },
        {
            "text1": "SELECT * FROM users WHERE age > 18 AND status = 'active'",
            "text2": "SQL query to find all active adult users from database",
        },
        {
            "text1": "class Stack: def __init__(self): self.items = []",
            "text2": "stack data structure implementation with initialization method",
        },
        {
            "text1": "for i in range(len(arr)): for j in range(len(arr)-1-i): if arr[j] > arr[j+1]: arr[j], arr[j+1] = arr[j+1], arr[j]",
            "text2": "bubble sort algorithm implementation using nested loops",
        },
        {
            "text1": "import pandas as pd; df = pd.read_csv('data.csv')",
            "text2": "load CSV file into pandas DataFrame for data analysis",
        },
        {
            "text1": "def quicksort(arr): return [] if not arr else quicksort([x for x in arr[1:] if x <= arr[0]]) + [arr[0]] + quicksort([x for x in arr[1:] if x > arr[0]])",
            "text2": "quicksort algorithm implementation using list comprehension",
        },
        {
            "text1": "try: result = func() except Exception as e: print(f'Error: {e}')",
            "text2": "error handling with try-except block and formatted output",
        },
        {"text1": "lambda x: x ** 2", "text2": "lambda function to calculate square of a number"},
        {
            "text1": "def binary_search(arr, target): left, right = 0, len(arr) - 1",
            "text2": "binary search algorithm initialization with left and right pointers",
        },
        {
            "text1": "class LinkedList: def __init__(self): self.head = None",
            "text2": "linked list data structure class definition with head pointer",
        },
    ]

    # Expand dataset with more programming language examples
    additional_pairs = [
        {
            "text1": "function addNumbers(a, b) { return a + b; }",
            "text2": "JavaScript function to add two numbers and return result",
        },
        {
            "text1": "public class Calculator { private int result; }",
            "text2": "Java class definition for calculator with private result field",
        },
        {
            "text1": "def __str__(self): return f'{self.name}: {self.value}'",
            "text2": "Python string representation method for object display",
        },
        {
            "text1": "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100))",
            "text2": "SQL table creation statement with primary key and name field",
        },
        {
            "text1": "const result = await fetch('/api/data').then(res => res.json())",
            "text2": "JavaScript async API call with JSON response parsing",
        },
    ]

    # Combine both lists
    code_pairs.extend(additional_pairs)

    # Transform combined list to the format for Dataset.from_dict()
    dataset_dict = {
        "text1": [pair["text1"] for pair in code_pairs],
        "text2": [pair["text2"] for pair in code_pairs],
    }
    return Dataset.from_dict(dataset_dict)
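The final reshaping step above, list of pair dicts into the column-oriented dict that `Dataset.from_dict()` consumes, can be checked without the `datasets` library at all; a pure-Python sketch of the same transform:

```python
def pairs_to_columns(pairs):
    """Reshape [{'text1': ..., 'text2': ...}, ...] into the column dict
    that Dataset.from_dict() expects: one list per column name."""
    return {
        "text1": [p["text1"] for p in pairs],
        "text2": [p["text2"] for p in pairs],
    }


sample = [
    {"text1": "lambda x: x ** 2", "text2": "lambda function to square a number"},
    {"text1": "SELECT * FROM users", "text2": "SQL query selecting all users"},
]
columns = pairs_to_columns(sample)
print(sorted(columns))        # ['text1', 'text2']
print(len(columns["text1"]))  # 2
```

Each column list must have the same length, one entry per pair, which is what keeps `text1[i]` aligned with its matching `text2[i]` for contrastive training.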