diff --git a/cmd/cmd.go b/cmd/cmd.go index 3bb8b06ec10..dc288e43906 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -680,6 +680,17 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { return err } + // Unload the model if it's running before deletion + opts := &runOptions{ + Model: args[0], + KeepAlive: &api.Duration{Duration: 0}, + } + if err := loadOrUnloadModel(cmd, opts); err != nil { + if !strings.Contains(err.Error(), "not found") { + return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err) + } + } + for _, name := range args { req := api.DeleteRequest{Name: name} if err := client.Delete(cmd.Context(), &req); err != nil { diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 0f8863cc770..9d23f3e963e 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -2,11 +2,17 @@ package cmd import ( "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" "os" "path/filepath" + "strings" "testing" "github.com/google/go-cmp/cmp" + "github.com/spf13/cobra" "github.com/ollama/ollama/api" ) @@ -204,3 +210,63 @@ Weigh anchor! } }) } + +func TestDeleteHandler(t *testing.T) { + stopped := false + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/delete" && r.Method == http.MethodDelete { + var req api.DeleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if req.Name == "test-model" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusNotFound) + } + return + } + if r.URL.Path == "/api/generate" && r.Method == http.MethodPost { + var req api.GenerateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if req.Model == "test-model" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(api.GenerateResponse{ + Done: true, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + stopped = true + return + } else { + w.WriteHeader(http.StatusNotFound) + if err := json.NewEncoder(w).Encode(api.GenerateResponse{ + Done: false, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } + } + })) + + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(context.TODO()) + if err := DeleteHandler(cmd, []string{"test-model"}); err != nil { + t.Fatalf("DeleteHandler failed: %v", err) + } + if !stopped { + t.Fatal("Model was not stopped before deletion") + } + + err := DeleteHandler(cmd, []string{"test-model-not-found"}) + if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") { + t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) + } +} diff --git a/server/routes.go b/server/routes.go index 6bd3a93f5a9..23f9dbfd9b8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -693,7 +693,12 @@ func (s *Server) DeleteHandler(c *gin.Context) { m, err := ParseNamedManifest(n) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + switch { + case os.IsNotExist(err): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } return } diff --git a/server/routes_test.go b/server/routes_test.go index bffcea205d9..f7a7a22bead 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -15,9 +15,6 @@ import ( "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" @@ -30,24 +27,47 @@ func createTestFile(t *testing.T, name string) string { t.Helper() f, err := os.CreateTemp(t.TempDir(), name) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } defer f.Close() err = binary.Write(f, binary.LittleEndian, []byte("GGUF")) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to write to file: %v", err) + } err = binary.Write(f, binary.LittleEndian, uint32(3)) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to write to file: %v", err) + } err = binary.Write(f, binary.LittleEndian, uint64(0)) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to write to file: %v", err) + } err = binary.Write(f, binary.LittleEndian, uint64(0)) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to write to file: %v", err) + } return f.Name() } +// equalStringSlices checks if two slices of strings are equal. +func equalStringSlices(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + func Test_Routes(t *testing.T) { type testCase struct { Name string @@ -64,12 +84,16 @@ func Test_Routes(t *testing.T) { r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) modelfile, err := parser.ParseFile(r) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to parse file: %v", err) + } fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status) } err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to create model: %v", err) + } } testCases := []testCase{ @@ -81,10 +105,17 @@ func Test_Routes(t *testing.T) { }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, "application/json; charset=utf-8", contentType) + if contentType != "application/json; charset=utf-8" { + t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType) + } body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body)) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version) + if string(body) != expectedBody { + t.Errorf("expected body %s, got %s", expectedBody, string(body)) + } }, }, { @@ -93,17 +124,24 @@ func Test_Routes(t *testing.T) { Path: "/api/tags", Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, "application/json; charset=utf-8", contentType) + if contentType != "application/json; charset=utf-8" { + t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType) + } body, err := io.ReadAll(resp.Body) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } var modelList api.ListResponse err = json.Unmarshal(body, &modelList) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } - assert.NotNil(t, modelList.Models) - assert.Empty(t, len(modelList.Models)) + if modelList.Models == nil || len(modelList.Models) != 0 { + t.Errorf("expected empty model list, got %v", modelList.Models) + } }, }, { @@ -112,16 +150,23 @@ func Test_Routes(t *testing.T) { Path: "/v1/models", Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, "application/json", contentType) + if contentType != "application/json" { + t.Errorf("expected content type application/json, got %s", contentType) + } body, err := io.ReadAll(resp.Body) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } var modelList openai.ListCompletion err = json.Unmarshal(body, &modelList) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } - assert.Equal(t, "list", modelList.Object) - assert.Empty(t, modelList.Data) + if modelList.Object != "list" || len(modelList.Data) != 0 { + t.Errorf("expected empty model list, got %v", modelList.Data) + } }, }, { @@ -133,18 +178,92 @@ func Test_Routes(t *testing.T) { }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, "application/json; charset=utf-8", contentType) + if contentType != "application/json; charset=utf-8" { + t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType) + } body, err := io.ReadAll(resp.Body) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } - assert.NotContains(t, string(body), "expires_at") + if strings.Contains(string(body), "expires_at") { + t.Errorf("response body should not contain 'expires_at'") + } var modelList api.ListResponse err = json.Unmarshal(body, &modelList) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } + + if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" { + t.Errorf("expected model 'test-model:latest', got %v", modelList.Models) + } + }, + }, + { + Name: "Delete Model Handler", + Method: http.MethodDelete, + Path: "/api/delete", + Setup: func(t *testing.T, req *http.Request) { + createTestModel(t, "model-to-delete") - assert.Len(t, modelList.Models, 1) - assert.Equal(t, "test-model:latest", modelList.Models[0].Name) + deleteReq := api.DeleteRequest{ + Name: "model-to-delete", + } + jsonData, err := json.Marshal(deleteReq) + if err != nil { + t.Fatalf("failed to marshal delete request: %v", err) + } + + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status code 200, got %d", resp.StatusCode) + } + + // Verify the model was deleted + _, err := GetModel("model-to-delete") + if err == nil || !os.IsNotExist(err) { + t.Errorf("expected model to be deleted, got error %v", err) + } + }, + }, + { + Name: "Delete Non-existent Model", + Method: http.MethodDelete, + Path: "/api/delete", + Setup: func(t *testing.T, req *http.Request) { + deleteReq := api.DeleteRequest{ + Name: "non-existent-model", + } + jsonData, err := json.Marshal(deleteReq) + if err != nil { + t.Fatalf("failed to marshal delete request: %v", err) + } + + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + if resp.StatusCode != http.StatusNotFound { + t.Errorf("expected status code 404, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + var errorResp map[string]string + err = json.Unmarshal(body, &errorResp) + if err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } + + if !strings.Contains(errorResp["error"], "not found") { + t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"]) + } }, }, { @@ -153,17 +272,23 @@ func Test_Routes(t *testing.T) { Path: "/v1/models", Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, "application/json", contentType) + if contentType != "application/json" { + t.Errorf("expected content type application/json, got %s", contentType) + } body, err := io.ReadAll(resp.Body) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } var modelList openai.ListCompletion err = json.Unmarshal(body, &modelList) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } - assert.Len(t, modelList.Data, 1) - assert.Equal(t, "test-model:latest", modelList.Data[0].Id) - assert.Equal(t, "library", modelList.Data[0].OwnedBy) + if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" { + t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data) + } }, }, { @@ -180,20 +305,32 @@ func Test_Routes(t *testing.T) { Stream: &stream, } jsonData, err := json.Marshal(createReq) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to marshal create request: %v", err) + } req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, "application/json", contentType) + if contentType != "application/json" { + t.Errorf("expected content type application/json, got %s", contentType) + } _, err := io.ReadAll(resp.Body) - require.NoError(t, err) - assert.Equal(t, 200, resp.StatusCode) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + if resp.StatusCode != http.StatusOK { // Updated line + t.Errorf("expected status code 200, got %d", resp.StatusCode) + } model, err := GetModel("t-bone") - require.NoError(t, err) - assert.Equal(t, "t-bone:latest", model.ShortName) + if err != nil { + t.Fatalf("failed to get model: %v", err) + } + if model.ShortName != "t-bone:latest" { + t.Errorf("expected model name 't-bone:latest', got %s", model.ShortName) + } }, }, { @@ -207,14 +344,20 @@ func Test_Routes(t *testing.T) { Destination: "beefsteak", } jsonData, err := json.Marshal(copyReq) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to marshal copy request: %v", err) + } req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { model, err := GetModel("beefsteak") - require.NoError(t, err) - assert.Equal(t, "beefsteak:latest", model.ShortName) + if err != nil { + t.Fatalf("failed to get model: %v", err) + } + if model.ShortName != "beefsteak:latest" { + t.Errorf("expected model name 'beefsteak:latest', got %s", model.ShortName) + } }, }, { @@ -225,18 +368,26 @@ func Test_Routes(t *testing.T) { createTestModel(t, "show-model") showReq := api.ShowRequest{Model: "show-model"} jsonData, err := json.Marshal(showReq) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to marshal show request: %v", err) + } req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, "application/json; charset=utf-8", contentType) + if contentType != "application/json; charset=utf-8" { + t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType) + } body, err := io.ReadAll(resp.Body) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } var showResp api.ShowResponse err = json.Unmarshal(body, &showResp) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } var params []string paramsSplit := strings.Split(showResp.Parameters, "\n") @@ -250,8 +401,16 @@ func Test_Routes(t *testing.T) { "stop \"foo\"", "top_p 0.9", } - assert.Equal(t, expectedParams, params) - assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0") + if !equalStringSlices(params, expectedParams) { + t.Errorf("expected parameters %v, got %v", expectedParams, params) + } + paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64) + if !ok { + t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"]) + } + if math.Abs(paramCount) > 1e-9 { + t.Errorf("expected parameter count to be 0, got %f", paramCount) + } }, }, { @@ -260,16 +419,23 @@ func Test_Routes(t *testing.T) { Path: "/v1/models/show-model", Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") - assert.Equal(t, "application/json", contentType) + if contentType != "application/json" { + t.Errorf("expected content type application/json, got %s", contentType) + } body, err := io.ReadAll(resp.Body) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } var retrieveResp api.RetrieveModelResponse err = json.Unmarshal(body, &retrieveResp) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } - assert.Equal(t, "show-model", retrieveResp.Id) - assert.Equal(t, "library", retrieveResp.OwnedBy) + if retrieveResp.Id != "show-model" || retrieveResp.OwnedBy != "library" { + t.Errorf("expected model 'show-model' owned by 'library', got %v", retrieveResp) + } }, }, } @@ -286,14 +452,18 @@ func Test_Routes(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { u := httpSrv.URL + tc.Path req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } if tc.Setup != nil { tc.Setup(t, req) } resp, err := httpSrv.Client().Do(req) - require.NoError(t, err) + if err != nil { + t.Fatalf("failed to do request: %v", err) + } defer resp.Body.Close() if tc.Expected != nil {