125 changes: 124 additions & 1 deletion vscode/src/edit/adapters/smart-apply-custom.test.ts
@@ -1,4 +1,7 @@
import { describe, expect, it } from 'vitest'
import { TokenCounterUtils, ps } from '@sourcegraph/cody-shared'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import * as vscode from 'vscode'
import { CustomModelSelectionProvider } from '../prompt/smart-apply/selection/custom-model'
import type { ModelParametersInput } from './base'
import { SmartApplyCustomModelParameterProvider } from './smart-apply-custom'

@@ -105,3 +108,123 @@ describe('SmartApplyCustomModelParameterProvider', () => {
})
})
})

describe('CustomModelSelectionProvider token count validation', () => {
const mockChatClient = {
getCompletion: vi.fn().mockResolvedValue('ENTIRE_FILE'),
} as any

beforeEach(() => {
vi.spyOn(vscode, 'Range').mockImplementation(() => ({}) as vscode.Range)
})

afterEach(() => {
vi.restoreAllMocks()
})

it('should throw an error when token count exceeds context window input limit', async () => {
// Mock token count to exceed limit
vi.spyOn(TokenCounterUtils, 'countPromptString').mockImplementation(async () => 1500)

const mockDocumentText = 'mock document text that exceeds token limit'
const mockDocument = {
lineCount: 100,
uri: { fsPath: 'test.ts' },
getText: vi.fn().mockReturnValue(mockDocumentText),
} as unknown as vscode.TextDocument

// Context window input (1000) is smaller than the mocked token count (1500)
const contextWindow = { input: 1000, output: 500 }

const provider = new CustomModelSelectionProvider({ shouldAlwaysUseEntireFile: false })

await expect(
provider.getSelectedText({
instruction: ps`test instruction`,
replacement: ps`test replacement`,
document: mockDocument,
model: 'gpt-4',
chatClient: {} as any,
contextWindow,
codyApiVersion: 1,
})
).rejects.toThrow("The amount of text in this document exceeds Cody's current capacity.")

// Verify the token counter was called with the correct document text
expect(TokenCounterUtils.countPromptString).toHaveBeenCalledTimes(1)
const calledWith = vi.mocked(TokenCounterUtils.countPromptString).mock.calls[0][0]
expect(calledWith).toBeDefined()
expect(calledWith.toString()).toContain(mockDocumentText)
})

it('should return ENTIRE_FILE when token count is within limits and below threshold', async () => {
// Mock token count to be within limit but below threshold
vi.spyOn(TokenCounterUtils, 'countPromptString').mockImplementation(async () => 800)

const mockDocumentText = 'mock document text within token limit'
const mockDocument = {
lineCount: 100,
uri: { fsPath: 'test.ts' },
getText: vi.fn().mockReturnValue(mockDocumentText),
} as unknown as vscode.TextDocument

const contextWindow = { input: 1000, output: 500 }

const provider = new CustomModelSelectionProvider({ shouldAlwaysUseEntireFile: false })

const result = await provider.getSelectedText({
instruction: ps`test instruction`,
replacement: ps`test replacement`,
document: mockDocument,
model: 'gpt-4',
chatClient: mockChatClient,
contextWindow,
codyApiVersion: 1,
})

// Verify result is ENTIRE_FILE when token count is below threshold
expect(result).toBe('ENTIRE_FILE')

// Verify the token counter was called with the correct document text
expect(TokenCounterUtils.countPromptString).toHaveBeenCalledTimes(1)
const calledWith = vi.mocked(TokenCounterUtils.countPromptString).mock.calls[0][0]
expect(calledWith).toBeDefined()
expect(calledWith.toString()).toContain(mockDocumentText)
})

it('should always return ENTIRE_FILE when shouldAlwaysUseEntireFile is true regardless of token count', async () => {
// Mock token count to be above threshold but within limit
vi.spyOn(TokenCounterUtils, 'countPromptString').mockImplementation(async () => 15000)

const mockDocumentText = 'mock document text above threshold'
const mockDocument = {
lineCount: 100,
uri: { fsPath: 'test.ts' },
getText: vi.fn().mockReturnValue(mockDocumentText),
} as unknown as vscode.TextDocument

const contextWindow = { input: 20000, output: 500 } // Large enough to not trigger error

// Create provider instance with shouldAlwaysUseEntireFile set to true
const provider = new CustomModelSelectionProvider({ shouldAlwaysUseEntireFile: true })

const result = await provider.getSelectedText({
instruction: ps`test instruction`,
replacement: ps`test replacement`,
document: mockDocument,
model: 'gpt-4',
chatClient: mockChatClient,
contextWindow,
codyApiVersion: 1,
})

// Verify result is ENTIRE_FILE when shouldAlwaysUseEntireFile is true
expect(result).toBe('ENTIRE_FILE')

// Verify the token counter was called with the correct document text
expect(TokenCounterUtils.countPromptString).toHaveBeenCalledTimes(1)
const calledWith = vi.mocked(TokenCounterUtils.countPromptString).mock.calls[0][0]
expect(calledWith).toBeDefined()
expect(calledWith.toString()).toContain(mockDocumentText)
})
})
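
Taken together, the three tests above pin down the provider's selection behavior. A minimal summary sketch, assuming only that FULL_FILE_REWRITE_TOKEN_TOKEN_LIMIT lies somewhere between 800 and 15000 tokens:

// tokenCount 1500,  input 1000,  shouldAlwaysUseEntireFile false -> throws "exceeds Cody's current capacity"
// tokenCount 800,   input 1000,  shouldAlwaysUseEntireFile false -> 'ENTIRE_FILE' (below the rewrite threshold)
// tokenCount 15000, input 20000, shouldAlwaysUseEntireFile true  -> 'ENTIRE_FILE' (flag forces the full file)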
77 changes: 68 additions & 9 deletions vscode/src/edit/output/response-transformer.test.ts
@@ -16,7 +16,12 @@ describe('Smart Apply Response Extraction', () => {
const text = `<${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>const x = 1;</${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>`
const task = createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault)

const result = responseTransformer(text, task, true)
// When isMessageInProgress is true, the original text is returned without processing
const resultInProgress = responseTransformer(text, task, true)
expect(resultInProgress).toBe(text)

// When isMessageInProgress is false, the text is processed
const result = responseTransformer(text, task, false)
expect(result).toBe('const x = 1;')
})

@@ -48,7 +53,12 @@ describe('Smart Apply Response Extraction', () => {
const text = `<${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>outer <${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>inner</${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}> content</${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>`
const task = createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault)

const result = responseTransformer(text, task, true)
// When isMessageInProgress is true, the original text is returned without processing
const resultInProgress = responseTransformer(text, task, true)
expect(resultInProgress).toBe(text)

// When isMessageInProgress is false, the text is processed
const result = responseTransformer(text, task, false)
expect(result).toBe(
`outer <${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>inner</${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}> content`
)
@@ -58,23 +68,72 @@ describe('Smart Apply Response Extraction', () => {
const text = `<${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}></${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>`
const task = createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault)

const result = responseTransformer(text, task, true)
// When isMessageInProgress is true, the original text is returned without processing
const resultInProgress = responseTransformer(text, task, true)
expect(resultInProgress).toBe(text)

// When isMessageInProgress is false, the text is processed
const result = responseTransformer(text, task, false)
expect(result).toBe('')
})

it('should add newline for smartApply without empty selection range', () => {
const text = `<${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>const x = 1;</${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>`
it('should not add newline for smartApply with empty selection range', () => {
const text = 'const x = 1;'
const task = {
...createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault),
mode: 'insert',
selectionRange: { isEmpty: false },
original: 'existing content',
selectionRange: { isEmpty: true },
original: '',
fixupFile: { uri: {} as any },
} as any

const result = responseTransformer(text, task, false)
expect(result).toBe('const x = 1;\n')
expect(result.endsWith('\n')).toBe(true)
expect(result).toBe('const x = 1;')
expect(result.endsWith('\n')).toBe(false)
})

it('should preserve newlines based on original text', () => {
const text = `<${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>\nconst x = 1;\n</${SMART_APPLY_CUSTOM_PROMPT_TOPICS.FINAL_CODE}>`

// Test 1: Original has no newlines, result should have no newlines
const task1 = {
...createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault),
original: 'const y = 2;',
} as any
const result1 = responseTransformer(text, task1, false)
expect(result1).toBe('const x = 1;')
expect(result1.startsWith('\n')).toBe(false)
expect(result1.endsWith('\n')).toBe(false)

// Test 2: Original has starting newline, result should have starting newline
const task2 = {
...createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault),
original: '\nconst y = 2;',
} as any
const result2 = responseTransformer(text, task2, false)
expect(result2).toBe('\nconst x = 1;')
expect(result2.startsWith('\n')).toBe(true)
expect(result2.endsWith('\n')).toBe(false)

// Test 3: Original has ending newline, result should have ending newline
const task3 = {
...createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault),
original: 'const y = 2;\n',
} as any
const result3 = responseTransformer(text, task3, false)
expect(result3).toBe('const x = 1;\n')
expect(result3.startsWith('\n')).toBe(false)
expect(result3.endsWith('\n')).toBe(true)

// Test 4: Original has both newlines, result should have both newlines
const task4 = {
...createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault),
original: '\nconst y = 2;\n',
} as any
const result4 = responseTransformer(text, task4, false)
expect(result4).toBe('\nconst x = 1;\n')
expect(result4.startsWith('\n')).toBe(true)
expect(result4.endsWith('\n')).toBe(true)
})
})

63 changes: 48 additions & 15 deletions vscode/src/edit/output/response-transformer.ts
@@ -38,6 +38,15 @@ const MARKDOWN_CODE_BLOCK_REGEX = new RegExp(
const LEADING_SPACES_AND_NEW_LINES = /^\s*\n/
const LEADING_SPACES = /^[ ]+/

const SMART_APPLY_MODEL_SET = new Set(Object.values(SMART_APPLY_MODEL_IDENTIFIERS))

/**
* Checks if the task is using a smart apply custom model
*/
function taskUsesSmartApplyCustomModel(task: FixupTask): boolean {
return SMART_APPLY_MODEL_SET.has(task.model)
}
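
// Illustrative: in the tests above, tasks built with
// createTask('smartApply', SMART_APPLY_MODEL_IDENTIFIERS.FireworksQwenCodeDefault)
// take the custom-model path here.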

/**
* Strips the text of any unnecessary content.
* This includes:
@@ -62,10 +71,7 @@ function stripText(text: string, task: FixupTask): string {
}

function extractSmartApplyCustomModelResponse(text: string, task: FixupTask): string {
if (
task.intent !== 'smartApply' ||
!Object.values(SMART_APPLY_MODEL_IDENTIFIERS).includes(task.model)
) {
if (!taskUsesSmartApplyCustomModel(task)) {
return text
}

@@ -86,15 +92,31 @@ }
}

/**
 * Regular expression to detect potential HTML entities.
 * Checks for named (&name;), decimal (&#digits;), or hex (&#xhex;) entities.
 */
/**
 * Preserves or removes newlines at the start and end of the text based on the original text.
 * If the original text doesn't start/end with a newline, the corresponding newline in the updated text is removed.
 */
function trimLLMNewlines(text: string, original: string): string {
let result = text

// Handle starting newline
if (result.match(/^\r?\n/) && !original.match(/^\r?\n/)) {
result = result.replace(/^\r?\n/, '')
}
if (result.match(/\r?\n$/) && !original.match(/\r?\n$/)) {
result = result.replace(/\r?\n$/, '')
}

return result
}
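
// Illustrative only (mirrors the response-transformer tests above): with an updated
// text of '\nconst x = 1;\n', each newline survives exactly when the original had it:
//   trimLLMNewlines('\nconst x = 1;\n', 'const y = 2;')     -> 'const x = 1;'
//   trimLLMNewlines('\nconst x = 1;\n', '\nconst y = 2;')   -> '\nconst x = 1;'
//   trimLLMNewlines('\nconst x = 1;\n', 'const y = 2;\n')   -> 'const x = 1;\n'
//   trimLLMNewlines('\nconst x = 1;\n', '\nconst y = 2;\n') -> '\nconst x = 1;\n'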

/**
* Regular expression to detect the *few* entities we actually care about.
* We purposefully limit the named-entity part to the common escaping
* sequences that LLMs emit in source code:
* &lt; &gt; &amp; &quot; &apos;
* Everything else (e.g. &nbsp;, &curren;, &copy;, …) is ignored so that we
* don't accidentally alter code like "&current_value;".
*/
const POTENTIAL_HTML_ENTITY_REGEX = /&(?:(?:lt|gt|amp|quot|apos)|#\d+|#x[0-9a-fA-F]+);/
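
// Illustrative only: what the narrowed regex matches vs. skips:
//   POTENTIAL_HTML_ENTITY_REGEX.test('if (a &lt; b)')    -> true  (allow-listed name)
//   POTENTIAL_HTML_ENTITY_REGEX.test('&#60; and &#x3C;') -> true  (decimal / hex forms)
//   POTENTIAL_HTML_ENTITY_REGEX.test('&current_value;')  -> false (name not allow-listed)
//   POTENTIAL_HTML_ENTITY_REGEX.test('&nbsp;')           -> false (name not allow-listed)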

/**
Expand All @@ -107,8 +129,19 @@ export function responseTransformer(
task: FixupTask,
isMessageInProgress: boolean
): string {
const updatedText = extractSmartApplyCustomModelResponse(text, task)
const strippedText = stripText(updatedText, task)
// For smart apply custom models, return the raw text while the message is still streaming or when the task is in insert mode
if (taskUsesSmartApplyCustomModel(task)) {
if (isMessageInProgress || task.mode === 'insert') {
return text
}

const updatedText = extractSmartApplyCustomModelResponse(text, task)

// Preserve newlines only if they were in the original text
return trimLLMNewlines(updatedText, task.original)
}

const strippedText = stripText(text, task)

// Trim leading spaces
// - For `add` insertions, the LLM will attempt to continue the code from the position of the cursor, we handle the `insertionPoint`
7 changes: 2 additions & 5 deletions vscode/src/edit/prompt/smart-apply/selection/custom-model.ts
@@ -58,20 +58,17 @@ export class CustomModelSelectionProvider implements SmartApplySelectionProvider
chatClient,
contextWindow,
}: SelectionPromptProviderArgs): Promise<string> {
if (this.shouldAlwaysUseEntireFile) {
return 'ENTIRE_FILE'
}

const documentRange = new vscode.Range(0, 0, document.lineCount - 1, 0)
const documentText = PromptString.fromDocumentText(document, documentRange)
const tokenCount = await TokenCounterUtils.countPromptString(documentText)

if (tokenCount > contextWindow.input) {
throw new Error("The amount of text in this document exceeds Cody's current capacity.")
}
if (tokenCount < FULL_FILE_REWRITE_TOKEN_TOKEN_LIMIT) {
if (tokenCount < FULL_FILE_REWRITE_TOKEN_TOKEN_LIMIT || this.shouldAlwaysUseEntireFile) {
return 'ENTIRE_FILE'
}
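
// Note: after this reordering, shouldAlwaysUseEntireFile no longer bypasses the
// token-count check above; an oversized document throws even when the flag is set.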

const { prefix, messages } = await this.getPrompt(
instruction,
replacement,