Fine-tune Qwen3-Embedding for code embeddings using Amazon SageMaker #4846

Open · wants to merge 3 commits
@@ -0,0 +1,153 @@
# inference.py - Custom inference script for embedding operations
import json
import torch
Amazon CodeGuru Reviewer (automated recommendation):

The code uses broad library imports instead of importing the specific required classes, which consumes unnecessary memory and makes maintenance harder by obscuring actual library usage. To optimize performance and improve clarity, use targeted imports with the 'from library import specific_class' syntax. Learn more: https://docs.python.org/3/tutorial/modules.html
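A minimal sketch of what this suggestion could look like for the torch usage in this script; the import aliases are illustrative and not part of the PR:

# Import only the torch symbols this script actually uses, instead of the whole package.
from torch import device as torch_device
from torch.cuda import is_available as cuda_is_available

device = torch_device("cuda" if cuda_is_available() else "cpu")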

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EmbeddingInferenceHandler:
"""
Custom inference handler that supports both encoding and similarity operations.
"""

def __init__(self):
self.model = None
self.device = None

def model_fn(self, model_dir):
"""
Load the fine-tuned model from the model directory.
"""
logger.info(f"Loading model from: {model_dir}")
Amazon CodeGuru Reviewer (automated recommendation):

Log Injection occurs when unsafe input is written to log files without proper sanitization. This can allow attackers to manipulate log entries, potentially leading to security issues such as log forging or cross-site scripting. To prevent this, sanitize user input before logging: remove or encode newline characters, use string encoding functions, and leverage the logging library's built-in sanitization features when available. Learn more: https://cwe.mitre.org/data/definitions/117.html
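One possible mitigation, sketched with an illustrative helper name that is not part of the PR, is to strip newlines and truncate request-derived values before they reach the logger:

import logging

logger = logging.getLogger(__name__)


def sanitize_for_log(value, max_len=200):
    """Remove newlines and truncate before writing untrusted input to logs."""
    return str(value).replace("\r", "\\r").replace("\n", "\\n")[:max_len]


# Usage inside model_fn, for example:
# logger.info("Loading model from: %s", sanitize_for_log(model_dir))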


        # Detect device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Load the sentence transformer model
        self.model = SentenceTransformer(model_dir, device=self.device)

        return self.model

    def input_fn(self, request_body, content_type="application/json"):
Amazon CodeGuru Reviewer (automated recommendation):

This code relies on client-controlled input (e.g., cookies, URL parameters, or headers) to determine user roles, which is vulnerable to manipulation. An attacker could potentially elevate their privileges by tampering with these inputs. To fix this, enforce role-based checks using server-side session data or an external authentication service, and avoid relying on any user-controlled data for role validation. Learn more about authorization vulnerabilities from OWASP: https://owasp.org/Top10/A01_2021-Broken_Access_Control/

"""
Parse input data for inference.
"""
if content_type == "application/json":
logger.info(f"Inference request: {request_body}")
Amazon CodeGuru Reviewer: Log Injection, same recommendation as above; sanitize request-derived values before logging them. Learn more: https://cwe.mitre.org/data/definitions/117.html

            data = json.loads(request_body)
        else:
            logger.warning(f"Wrong content type: {content_type}")
            raise ValueError(f"Unsupported content type: {content_type}")

        return data

    def predict_fn(self, data, model):
        """
        Perform inference based on the operation type.
        """
        operation = data.get("operation", "encode")
        logger.info(f"Prediction input: {data}")
Amazon CodeGuru Reviewer: Log Injection, same recommendation as above; sanitize request-derived values before logging them. Learn more: https://cwe.mitre.org/data/definitions/117.html

if operation == "encode":
return self._encode_operation(data, model)
elif operation == "similarity":
return self._similarity_operation(data, model)
else:
raise ValueError(f"Unsupported operation: {operation}")

def _encode_operation(self, data, model):
"""
Encode text inputs to embeddings.
"""
inputs = data.get("inputs", [])
if not inputs:
raise ValueError("No inputs provided for encoding")

# Get target dimension (default to 512 for Matryoshka)
target_dim = data.get("dimension", 512)

# Encode inputs
embeddings = model.encode(
inputs,
batch_size=data.get("batch_size", 32),
show_progress_bar=False,
normalize_embeddings=data.get("normalize", True),
)

# Truncate to target dimension if specified
if target_dim and target_dim < embeddings.shape[1]:
embeddings = embeddings[:, :target_dim]

return {
"embeddings": embeddings.tolist(),
"dimension": embeddings.shape[1],
"num_texts": len(inputs),
}

def _similarity_operation(self, data, model):
"""
Calculate similarity between text pairs.
"""
text1 = data.get("text1")
text2 = data.get("text2")

if not text1 or not text2:
raise ValueError("Both text1 and text2 required for similarity")

# Get target dimension
target_dim = data.get("dimension", 512)

# Encode both texts
embeddings = model.encode([text1, text2], normalize_embeddings=True)

# Truncate if needed
if target_dim and target_dim < embeddings.shape[1]:
embeddings = embeddings[:, :target_dim]

# Calculate cosine similarity
similarity = cosine_similarity(embeddings[0].reshape(1, -1), embeddings[1].reshape(1, -1))[
0
][0]

return {
"similarity": float(similarity),
"dimension": embeddings.shape[1],
"text1_embedding": embeddings[0].tolist(),
"text2_embedding": embeddings[1].tolist(),
}

def output_fn(self, prediction, accept="application/json"):
Amazon CodeGuru Reviewer: Broken Access Control, same recommendation as above; derive roles from server-side session data or an external authentication service, never from client-controlled input. Learn more: https://owasp.org/Top10/A01_2021-Broken_Access_Control/

"""
Format the prediction output.
"""
if accept == "application/json":
return json.dumps(prediction), "application/json"
else:
raise ValueError(f"Unsupported accept type: {accept}")


# Global handler instance
handler = EmbeddingInferenceHandler()


# SageMaker inference functions
def model_fn(model_dir):
return handler.model_fn(model_dir)


def input_fn(request_body, content_type):
return handler.input_fn(request_body, content_type)


def predict_fn(data, model):
logger.info(f"predict_fn data: {data}")
Amazon CodeGuru Reviewer: Log Injection, same recommendation as above; sanitize request-derived values before logging them. Learn more: https://cwe.mitre.org/data/definitions/117.html

    return handler.predict_fn(data, model)


def output_fn(prediction, accept):
    return handler.output_fn(prediction, accept)
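For reference, a minimal sketch of how a deployed endpoint backed by this handler could be invoked; the endpoint name is a placeholder rather than something defined in this PR:

import json

import boto3

runtime = boto3.client("sagemaker-runtime")

# Payload shape expected by predict_fn above; an "encode" request works analogously
# with {"operation": "encode", "inputs": [...], "dimension": 512}.
payload = {
    "operation": "similarity",
    "text1": "def add(a, b): return a + b",
    "text2": "Python function that adds two numbers",
    "dimension": 512,
}

response = runtime.invoke_endpoint(
    EndpointName="qwen3-code-embedding-endpoint",  # placeholder endpoint name
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(json.loads(response["Body"].read())["similarity"])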
@@ -0,0 +1,2 @@
sentence-transformers>=4.1.0
transformers>=4.51.0
@@ -0,0 +1,82 @@
# Create code embedding training dataset
from datasets import Dataset


def create_code_embedding_dataset():
"""
Create a comprehensive dataset with code-description pairs for embedding fine-tuning.
This dataset follows contrastive learning principles for code embeddings.
"""
code_pairs = [
{
"text1": "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",
"text2": "recursive function to calculate fibonacci numbers in Python",
},
{
"text1": "SELECT * FROM users WHERE age > 18 AND status = 'active'",
"text2": "SQL query to find all active adult users from database",
},
{
"text1": "class Stack: def __init__(self): self.items = []",
"text2": "stack data structure implementation with initialization method",
},
{
"text1": "for i in range(len(arr)): for j in range(len(arr)-1-i): if arr[j] > arr[j+1]: arr[j], arr[j+1] = arr[j+1], arr[j]",
"text2": "bubble sort algorithm implementation using nested loops",
},
{
"text1": "import pandas as pd; df = pd.read_csv('data.csv')",
"text2": "load CSV file into pandas DataFrame for data analysis",
},
{
"text1": "def quicksort(arr): return [] if not arr else quicksort([x for x in arr[1:] if x <= arr[0]]) + [arr[0]] + quicksort([x for x in arr[1:] if x > arr[0]])",
"text2": "quicksort algorithm implementation using list comprehension",
},
{
"text1": "try: result = func() except Exception as e: print(f'Error: {e}')",
"text2": "error handling with try-except block and formatted output",
},
{"text1": "lambda x: x ** 2", "text2": "lambda function to calculate square of a number"},
{
"text1": "def binary_search(arr, target): left, right = 0, len(arr) - 1",
"text2": "binary search algorithm initialization with left and right pointers",
},
{
"text1": "class LinkedList: def __init__(self): self.head = None",
"text2": "linked list data structure class definition with head pointer",
},
]

# Expand dataset with more programming language examples
additional_pairs = [
{
"text1": "function addNumbers(a, b) { return a + b; }",
"text2": "JavaScript function to add two numbers and return result",
},
{
"text1": "public class Calculator { private int result; }",
"text2": "Java class definition for calculator with private result field",
},
{
"text1": "def __str__(self): return f'{self.name}: {self.value}'",
"text2": "Python string representation method for object display",
},
{
"text1": "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100))",
"text2": "SQL table creation statement with primary key and name field",
},
{
"text1": "const result = await fetch('/api/data').then(res => res.json())",
"text2": "JavaScript async API call with JSON response parsing",
},
]

# Combine both lists
code_pairs.extend(additional_pairs)

# Transform combined list to the format for Dataset.from_dict()
dataset_dict = {
"text1": [pair["text1"] for pair in code_pairs],
"text2": [pair["text2"] for pair in code_pairs],
}
return Dataset.from_dict(dataset_dict)
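For context, a rough sketch of how this dataset could feed a sentence-transformers fine-tuning run; the base model ID, loss choice, and output path are assumptions rather than part of this PR:

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss

train_dataset = create_code_embedding_dataset()

# Contrastive setup: (text1, text2) rows are positive pairs, other in-batch rows act as negatives.
model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
loss = MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
model.save("qwen3-code-embedding-finetuned")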