From 097b08d0e52614a1e945d467142e8d566416e9ee Mon Sep 17 00:00:00 2001 From: Asghar Ghorbani Date: Sun, 5 Jan 2025 20:47:19 +0100 Subject: [PATCH] [Feat] File integrity check for downloaded models (#159) --- __mocks__/stores/hfStore.ts | 2 +- .../ModelsScreen/ModelCard/ModelCard.tsx | 58 ++++++++++++---- .../ModelCard/__tests__/ModelCard.test.tsx | 4 +- src/screens/ModelsScreen/ModelCard/styles.ts | 4 ++ src/store/HFStore.ts | 9 +-- src/store/ModelStore.ts | 55 ++++++++++++++- src/store/__tests__/HFStore.test.ts | 6 +- src/store/defaultModels.ts | 49 ++++++++++++- src/utils/index.ts | 69 +++++++++++++++++++ src/utils/types.ts | 6 ++ 10 files changed, 237 insertions(+), 25 deletions(-) diff --git a/__mocks__/stores/hfStore.ts b/__mocks__/stores/hfStore.ts index cf529d2..9280262 100644 --- a/__mocks__/stores/hfStore.ts +++ b/__mocks__/stores/hfStore.ts @@ -13,7 +13,7 @@ export const mockHFStore = { // Methods setSearchQuery: jest.fn(), fetchAndSetGGUFSpecs: jest.fn().mockResolvedValue(undefined), - fetchModelFileSizes: jest.fn().mockResolvedValue(undefined), + fetchModelFileDetails: jest.fn().mockResolvedValue(undefined), getModelById: jest.fn(id => mockHFStore.models.find(model => model.id === id), ), diff --git a/src/screens/ModelsScreen/ModelCard/ModelCard.tsx b/src/screens/ModelsScreen/ModelCard/ModelCard.tsx index ffd2fff..2d220bd 100644 --- a/src/screens/ModelsScreen/ModelCard/ModelCard.tsx +++ b/src/screens/ModelsScreen/ModelCard/ModelCard.tsx @@ -27,12 +27,16 @@ import {ModelSettings} from '../ModelSettings'; import {uiStore, modelStore} from '../../../store'; import {chatTemplates} from '../../../utils/chat'; -import {getModelDescription, L10nContext} from '../../../utils'; +import {Model, ModelOrigin, RootDrawerParamList} from '../../../utils/types'; +import { + getModelDescription, + L10nContext, + checkModelFileIntegrity, +} from '../../../utils'; import { COMPLETION_PARAMS_METADATA, validateCompletionSettings, } from '../../../utils/modelSettings'; -import {Model, ModelOrigin, RootDrawerParamList} from '../../../utils/types'; type ChatScreenNavigationProp = DrawerNavigationProp; @@ -52,6 +56,7 @@ export const ModelCard: React.FC = observer( const [snackbarVisible, setSnackbarVisible] = useState(false); // Snackbar visibility const [settingsModalVisible, setSettingsModalVisible] = useState(false); + const [integrityError, setIntegrityError] = useState(null); const {memoryWarning, shortMemoryWarning} = useMemoryCheck(model); const {isOk: storageOk, message: storageNOkMessage} = @@ -76,6 +81,17 @@ export const ModelCard: React.FC = observer( setTempCompletionSettings(model.completionSettings); }, [model]); + // Check integrity when model is downloaded + useEffect(() => { + if (isDownloaded) { + checkModelFileIntegrity(model, modelStore).then(({errorMessage}) => { + setIntegrityError(errorMessage); + }); + } else { + setIntegrityError(null); + } + }, [isDownloaded, model]); + const handleSettingsUpdate = useCallback((name: string, value: any) => { setTempChatTemplate(prev => { const newTemplate = @@ -286,20 +302,17 @@ export const ModelCard: React.FC = observer( ); } - const handlePress = () => { + const handlePress = async () => { if (isActiveModel) { modelStore.manualReleaseContext(); } else { - modelStore - .initContext(model) - .then(() => { - console.log('initialized'); - }) - .catch(e => { - console.log(`Error: ${e}`); - }); - if (uiStore.autoNavigatetoChat) { - navigation.navigate('Chat'); + try { + await modelStore.initContext(model); + if (uiStore.autoNavigatetoChat) { + navigation.navigate('Chat'); + } + } catch (e) { + console.log(`Error: ${e}`); } } }; @@ -310,6 +323,7 @@ export const ModelCard: React.FC = observer( icon={isActiveModel ? 'eject' : 'play-circle-outline'} mode="text" onPress={handlePress} + // disabled={!!integrityError} // for now integrity check is experimental. So won't disable the button style={styles.actionButton}> {isActiveModel ? l10n.offload : l10n.load} @@ -390,6 +404,24 @@ export const ModelCard: React.FC = observer( )} + {/* Display integrity warning if check fails */} + {integrityError && ( + + + + {integrityError} + + + )} + {isDownloading && ( <> { act(() => { fireEvent.press(getByTestId('load-button')); }); - expect(mockNavigate).toHaveBeenCalledWith('Chat'); + await waitFor(() => { + expect(mockNavigate).toHaveBeenCalledWith('Chat'); + }); }); it('handles model offload', async () => { diff --git a/src/screens/ModelsScreen/ModelCard/styles.ts b/src/screens/ModelsScreen/ModelCard/styles.ts index ff30cfd..0f8ecc3 100644 --- a/src/screens/ModelsScreen/ModelCard/styles.ts +++ b/src/screens/ModelsScreen/ModelCard/styles.ts @@ -115,8 +115,10 @@ export const createStyles = (theme: Theme) => flexDirection: 'row', alignItems: 'center', margin: 0, + marginTop: 8, }, warningContent: { + flex: 1, flexDirection: 'row', alignItems: 'center', }, @@ -127,6 +129,8 @@ export const createStyles = (theme: Theme) => warningText: { color: theme.colors.error, fontSize: 12, + flex: 1, + flexWrap: 'wrap', }, overlayButtons: { flex: 1, diff --git a/src/store/HFStore.ts b/src/store/HFStore.ts index 2d1e2e0..4941baf 100644 --- a/src/store/HFStore.ts +++ b/src/store/HFStore.ts @@ -62,6 +62,7 @@ class HFStore { ...file, size: details.size, oid: details.oid, + lfs: details.lfs, }; return { @@ -72,10 +73,10 @@ class HFStore { ); } - // Fetch the sizes of the model files - async fetchModelFileSizes(modelId: string) { + // Fetch the details (sizes, oid, lfs, ...) of the model files + async fetchModelFileDetails(modelId: string) { try { - console.log('Fetching model file sizes for', modelId); + console.log('Fetching model file details for', modelId); const fileDetails = await fetchModelFilesDetails(modelId); const model = this.models.find(m => m.id === modelId); @@ -103,7 +104,7 @@ class HFStore { async fetchModelData(modelId: string) { try { await this.fetchAndSetGGUFSpecs(modelId); - await this.fetchModelFileSizes(modelId); + await this.fetchModelFileDetails(modelId); } catch (error) { console.error('Error fetching model data:', error); } diff --git a/src/store/ModelStore.ts b/src/store/ModelStore.ts index 96747fb..6d05a80 100644 --- a/src/store/ModelStore.ts +++ b/src/store/ModelStore.ts @@ -8,10 +8,18 @@ import AsyncStorage from '@react-native-async-storage/async-storage'; import {computed, makeAutoObservable, ObservableMap, runInAction} from 'mobx'; import {CompletionParams, LlamaContext, initLlama} from '@pocketpalai/llama.rn'; +import {fetchModelFilesDetails} from '../api/hf'; + import {uiStore} from './UIStore'; import {chatSessionStore} from './ChatSessionStore'; import {defaultModels, MODEL_LIST_VERSION} from './defaultModels'; -import {deepMerge, formatBytes, hasEnoughSpace, hfAsModel} from '../utils'; +import { + deepMerge, + formatBytes, + getSHA256Hash, + hasEnoughSpace, + hfAsModel, +} from '../utils'; import { getHFDefaultSettings, @@ -410,7 +418,17 @@ class ModelStore { }; async checkFileExists(model: Model) { - const exists = await RNFS.exists(await this.getModelFullPath(model)); + const filePath = await this.getModelFullPath(model); + const exists = await RNFS.exists(filePath); + if (exists) { + // Only calculate hash if it's not already stored + if (!model.hash) { + const hash = await getSHA256Hash(filePath); + runInAction(() => { + model.hash = hash; + }); + } + } runInAction(() => { model.isDownloaded = exists; }); @@ -536,8 +554,12 @@ class ModelStore { const result = await ret.promise; if (result.statusCode === 200) { + // Calculate hash after successful download + const hash = await getSHA256Hash(downloadDest); + runInAction(() => { model.progress = 100; // Ensure progress is set to 100 upon completion + model.hash = hash; this.refreshDownloadStatuses(); }); @@ -1014,6 +1036,35 @@ class ModelStore { setIsStreaming(value: boolean) { this.isStreaming = value; } + + /** + * Fetches and updates model file details from HuggingFace. + * This is used when we need to get the lfs.oid for integrity checks. + * @param model - The model to update + * @returns Promise + */ + async fetchAndUpdateModelFileDetails(model: Model): Promise { + if (!model.hfModel?.id) { + return; + } + + try { + const fileDetails = await fetchModelFilesDetails(model.hfModel.id); + const matchingFile = fileDetails.find( + file => file.path === model.hfModelFile?.rfilename, + ); + + if (matchingFile && matchingFile.lfs) { + runInAction(() => { + if (model.hfModelFile) { + model.hfModelFile.lfs = matchingFile.lfs; + } + }); + } + } catch (error) { + console.error('Failed to fetch model file details:', error); + } + } } export const modelStore = new ModelStore(); diff --git a/src/store/__tests__/HFStore.test.ts b/src/store/__tests__/HFStore.test.ts index c2160d1..2deccb6 100644 --- a/src/store/__tests__/HFStore.test.ts +++ b/src/store/__tests__/HFStore.test.ts @@ -152,7 +152,7 @@ describe('HFStore', () => { }); }); - describe('fetchModelFileSizes', () => { + describe('fetchModelFileDetails', () => { it('should update model siblings with file sizes', async () => { hfStore.models = [mockHFModel1]; const fileDetails = [ @@ -161,7 +161,7 @@ describe('HFStore', () => { (fetchModelFilesDetails as jest.Mock).mockResolvedValueOnce(fileDetails); - await hfStore.fetchModelFileSizes(mockHFModel1.id); + await hfStore.fetchModelFileDetails(mockHFModel1.id); expect(hfStore.models[0].siblings[0].size).toBe(1111); expect(hfStore.models[0].siblings[0].oid).toBe('abc123'); @@ -171,7 +171,7 @@ describe('HFStore', () => { hfStore.models = []; (fetchModelFilesDetails as jest.Mock).mockResolvedValueOnce([]); - await hfStore.fetchModelFileSizes('non-existent-id'); + await hfStore.fetchModelFileDetails('non-existent-id'); expect(fetchModelFilesDetails).toHaveBeenCalled(); // Should not throw error diff --git a/src/store/defaultModels.ts b/src/store/defaultModels.ts index 32e8677..44f3e62 100644 --- a/src/store/defaultModels.ts +++ b/src/store/defaultModels.ts @@ -2,7 +2,7 @@ import {Model, ModelOrigin} from '../utils/types'; import {chatTemplates, defaultCompletionParams} from '../utils/chat'; import {Platform} from 'react-native'; -export const MODEL_LIST_VERSION = 10; +export const MODEL_LIST_VERSION = 11; const iosOnlyModels: Model[] = []; @@ -48,6 +48,11 @@ const crossPlatformModels: Model[] = [ url: 'https://huggingface.co/bartowski/gemma-2-2b-it-GGUF/resolve/main/gemma-2-2b-it-Q6_K.gguf', size: 2151393120, oid: '72f2510b5868d1141617aa16cfc4c4a61ec77262', + lfs: { + oid: 'f82c5c2230a8b452221706461eb93203443373625d96a05912d4f96c845c2775', + size: 2151393120, + pointerSize: 135, + }, canFitInStorage: true, }, }, @@ -88,6 +93,11 @@ const crossPlatformModels: Model[] = [ url: 'https://huggingface.co/TheDrummer/Gemmasutra-Mini-2B-v1-GGUF/resolve/main/Gemmasutra-Mini-2B-v1-Q6_K.gguf', size: 2151393152, oid: '05521bb238e46ebd8fb5dacf044ba14f7c15f73e', + lfs: { + oid: '34bdca7d62ae0b15366a6f3d7f457d6d8ef96343e72c5e4555b6475c4a78e839', + size: 2151393152, + pointerSize: 135, + }, canFitInStorage: true, }, }, @@ -127,6 +137,11 @@ const crossPlatformModels: Model[] = [ url: 'https://huggingface.co/MaziyarPanahi/Phi-3.5-mini-instruct-GGUF/resolve/main/Phi-3.5-mini-instruct.Q4_K_M.gguf', size: 2393232608, oid: 'a2b0f35b7504ba395e886fadd5ebc61236b9f5ec', + lfs: { + oid: '3f68916e850b107d8641d18bcd5548f0d66beef9e0a9077fe84ef28943eb7e88', + size: 2393232608, + pointerSize: 135, + }, canFitInStorage: true, }, }, @@ -161,6 +176,18 @@ const crossPlatformModels: Model[] = [ temperature: 0.5, stop: ['<|im_end|>'], }, + hfModelFile: { + rfilename: 'qwen2.5-1.5b-instruct-q8_0.gguf', + url: 'https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GGUF/resolve/main/qwen2.5-1.5b-instruct-q8_0.gguf', + size: 1894532128, + oid: '1ec6832f8c80d58e2efa88832420ec7856e8e7c6', + lfs: { + oid: 'd7efb072e7724d25048a4fda0a3e10b04bdef5d06b1403a1c93bd9f1240a63c8', + size: 1894532128, + pointerSize: 135, + }, + canFitInStorage: true, + }, }, { id: 'Qwen/Qwen2.5-3B-Instruct-GGUF/qwen2.5-3b-instruct-q5_k_m.gguf', @@ -197,6 +224,11 @@ const crossPlatformModels: Model[] = [ url: 'https://huggingface.co/Qwen/Qwen2.5-3B-Instruct-GGUF/resolve/main/qwen2.5-3b-instruct-q5_k_m.gguf', size: 2438740384, oid: 'ffee048cd9cd76e7e4848d17fb96892023e8eca1', + lfs: { + oid: '2c63dde5f2c9ab1fd64d47dee2d34dade6ba9ff62442d1d20b5342310c982081', + size: 2438740384, + pointerSize: 135, + }, canFitInStorage: true, }, }, @@ -237,6 +269,11 @@ const crossPlatformModels: Model[] = [ url: 'https://huggingface.co/hugging-quants/Llama-3.2-1B-Instruct-Q8_0-GGUF/resolve/main/llama-3.2-1b-instruct-q8_0.gguf', size: 1321079200, oid: '4d5402369568f0bd157d8454270821341e833722', + lfs: { + oid: 'ba345c83bf5cc679c653b853c46517eea5a34f03ed2205449db77184d9ae62a9', + size: 1321079200, + pointerSize: 135, + }, canFitInStorage: true, }, }, @@ -275,6 +312,11 @@ const crossPlatformModels: Model[] = [ url: 'https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf', size: 2643853856, oid: '47d12cf8883aaa6a6cd0b47975cc026980a3af9d', + lfs: { + oid: '1771887c15fc3d327cfee6fd593553b2126e88834bf48eae50e709d3f70dd998', + size: 2643853856, + pointerSize: 135, + }, canFitInStorage: true, }, }, @@ -314,6 +356,11 @@ const crossPlatformModels: Model[] = [ url: 'https://huggingface.co/bartowski/SmolLM2-1.7B-Instruct-GGUF/resolve/main/SmolLM2-1.7B-Instruct-Q8_0.gguf', size: 1820414944, oid: 'c06316819523138df0346459118248997dac5089', + lfs: { + oid: '0c6e8955788b1253f418c354a4bdc4cf507b8cfe49c48bb10c7c58ae713cfa2a', + size: 1820414944, + pointerSize: 135, + }, canFitInStorage: true, }, }, diff --git a/src/utils/index.ts b/src/utils/index.ts index 22839d7..f49f9c7 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -6,6 +6,7 @@ import dayjs from 'dayjs'; import {MD3Theme} from 'react-native-paper'; import DeviceInfo from 'react-native-device-info'; import Blob from 'react-native/Libraries/Blob/Blob'; +import * as RNFS from '@dr.pogodin/react-native-fs'; import {l10n} from './l10n'; import {getHFDefaultSettings} from './chat'; @@ -493,3 +494,71 @@ export function hfAsModel( return _model; } export const randId = () => Math.random().toString(36).substring(2, 11); + +export const getSHA256Hash = async (filePath: string): Promise => { + try { + const hash = await RNFS.hash(filePath, 'sha256'); + return hash; + } catch (error) { + console.error('Error generating SHA256 hash:', error); + throw error; + } +}; + +/** + * Checks if a model's file integrity is valid by comparing its hash with the expected hash from HuggingFace. + * For HF models, it will automatically fetch missing file details if needed. + * We assume lfs.oid is the hash of the file. + * @param model - The model to check integrity for + * @param modelStore - The model store instance for updating model details + * @returns An object containing the integrity check result and any error message + */ +export const checkModelFileIntegrity = async ( + model: Model, + modelStore: any, +): Promise<{ + isValid: boolean; + errorMessage: string | null; +}> => { + if (!model.hash) { + // Unsure if this is needed. As modelstore will fetch the details if needed. + return { + isValid: true, + errorMessage: null, + }; + } + + // For HF models, if we don't have lfs.oid, fetch it + if (model.origin === ModelOrigin.HF && !model.hfModelFile?.lfs?.oid) { + await modelStore.fetchAndUpdateModelFileDetails(model); + } + + if (model.hash && model.hfModelFile?.lfs?.oid) { + if (model.hash !== model.hfModelFile.lfs.oid) { + try { + const filePath = await modelStore.getModelFullPath(model); + const fileStats = await RNFS.stat(filePath); + const actualSize = formatBytes(fileStats.size, 2); + const expectedSize = formatBytes(model.hfModelFile.lfs.size, 2); + return { + isValid: false, + errorMessage: + `Model file corrupted (${actualSize} vs ${expectedSize}). ` + + 'Please delete and redownload.', + }; + } catch (error) { + console.error('Error getting file size:', error); + return { + isValid: false, + errorMessage: + 'Model file corrupted. Please delete and redownload the model.', + }; + } + } + } + + return { + isValid: true, + errorMessage: null, + }; +}; diff --git a/src/utils/types.ts b/src/utils/types.ts index fbeb40a..0c9e955 100644 --- a/src/utils/types.ts +++ b/src/utils/types.ts @@ -309,6 +309,7 @@ export interface Model { completionSettings: CompletionParams; hfModelFile?: ModelFile; hfModel?: HuggingFaceModel; + hash?: string; } export type RootDrawerParamList = { @@ -327,6 +328,11 @@ export interface ModelFile { size?: number; url?: string; oid?: string; + lfs?: { + oid: string; + size: number; + pointerSize: number; + }; canFitInStorage?: boolean; }