Skip to content

Commit

Permalink
fix(seed): generate random seed per-request if -1 is set (#1952)
Browse files Browse the repository at this point in the history
* fix(seed): generate random seed per-request if -1 is set

Also update ci with new workflows and allow the aio tests to run with an
api key

Signed-off-by: Ettore Di Giacinto <[email protected]>

* docs(openvino): Add OpenVINO example

Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Apr 3, 2024
1 parent 93cfec3 commit ff77d3b
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 7 deletions.
19 changes: 19 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
enhancements:
- head-branch: ['^feature', 'feature']

kind/documentation:
- any:
- changed-files:
- any-glob-to-any-file: 'docs/*'
- changed-files:
- any-glob-to-any-file: '*.md'

examples:
- any:
- changed-files:
- any-glob-to-any-file: 'examples/*'

ci:
- any:
- changed-files:
- any-glob-to-any-file: '.github/*'
12 changes: 12 additions & 0 deletions .github/workflows/labeler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: "Pull Request Labeler"
on:
- pull_request_target

jobs:
labeler:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v5
27 changes: 27 additions & 0 deletions .github/workflows/secscan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: "Security Scan"

# Run workflow each time code is pushed to your repository and on a schedule.
# The scheduled workflow runs every at 00:00 on Sunday UTC time.
on:
push:
schedule:
- cron: '0 0 * * 0'

jobs:
tests:
runs-on: ubuntu-latest
env:
GO111MODULE: on
steps:
- name: Checkout Source
uses: actions/checkout@v3
- name: Run Gosec Security Scanner
uses: securego/gosec@master
with:
# we let the report trigger content trigger a failure using the GitHub Security features.
args: '-no-fail -fmt sarif -out results.sarif ./...'
- name: Upload SARIF file
uses: github/codeql-action/upload-sarif@v2
with:
# Path to SARIF file relative to the root of the repository
sarif_file: results.sarif
15 changes: 12 additions & 3 deletions core/backend/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package backend

import (
"math/rand"
"os"
"path/filepath"

Expand Down Expand Up @@ -33,12 +34,20 @@ func modelOpts(c config.BackendConfig, so *config.ApplicationConfig, opts []mode
return opts
}

func getSeed(c config.BackendConfig) int32 {
seed := int32(*c.Seed)
if seed == config.RAND_SEED {
seed = rand.Int31()
}

return seed
}

func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
}

return &pb.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
Expand All @@ -54,7 +63,7 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
ContextSize: int32(*c.ContextSize),
Seed: int32(*c.Seed),
Seed: getSeed(c),
NBatch: int32(b),
NoMulMatQ: c.NoMulMatQ,
DraftModel: c.DraftModel,
Expand Down Expand Up @@ -129,7 +138,7 @@ func gRPCPredictOpts(c config.BackendConfig, modelPath string) *pb.PredictOption
NKeep: int32(c.Keep),
Batch: int32(c.Batch),
IgnoreEOS: c.IgnoreEOS,
Seed: int32(*c.Seed),
Seed: getSeed(c),
FrequencyPenalty: float32(c.FrequencyPenalty),
MLock: *c.MMlock,
MMap: *c.MMap,
Expand Down
7 changes: 5 additions & 2 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"io/fs"
"math/rand"
"os"
"path/filepath"
"sort"
Expand All @@ -20,6 +19,10 @@ import (
"github.com/charmbracelet/glamour"
)

const (
RAND_SEED = -1
)

type BackendConfig struct {
schema.PredictionOptions `yaml:"parameters"`
Name string `yaml:"name"`
Expand Down Expand Up @@ -218,7 +221,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {

if cfg.Seed == nil {
// random number generator seed
defaultSeed := int(rand.Int31())
defaultSeed := RAND_SEED
cfg.Seed = &defaultSeed
}

Expand Down
32 changes: 32 additions & 0 deletions docs/content/docs/features/text-generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ The backend will automatically download the required files in order to run the m
| Type | Description |
| --- | --- |
| `AutoModelForCausalLM` | `AutoModelForCausalLM` is a model that can be used to generate sequences. |
| `OVModelForCausalLM` | for OpenVINO models |
| N/A | Defaults to `AutoModel` |


Expand All @@ -324,4 +325,35 @@ curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d
"prompt": "Hello, my name is",
"temperature": 0.1, "top_p": 0.1
}'
```

#### Examples

##### OpenVINO

A model configuration file for openvion and starling model:

```yaml
name: starling-openvino
backend: transformers
parameters:
model: fakezeta/Starling-LM-7B-beta-openvino-int8
context_size: 8192
threads: 6
f16: true
type: OVModelForCausalLM
stopwords:
- <|end_of_turn|>
- <|endoftext|>
prompt_cache_path: "cache"
prompt_cache_all: true
template:
chat_message: |
{{if eq .RoleName "system"}}{{.Content}}<|end_of_turn|>{{end}}{{if eq .RoleName "assistant"}}<|end_of_turn|>GPT4 Correct Assistant: {{.Content}}<|end_of_turn|>{{end}}{{if eq .RoleName "user"}}GPT4 Correct User: {{.Content}}{{end}}
chat: |
{{.Input}}<|end_of_turn|>GPT4 Correct Assistant:
completion: |
{{.Input}}
```
5 changes: 3 additions & 2 deletions tests/e2e-aio/e2e_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var containerImageTag = os.Getenv("LOCALAI_IMAGE_TAG")
var modelsDir = os.Getenv("LOCALAI_MODELS_DIR")
var apiPort = os.Getenv("LOCALAI_API_PORT")
var apiEndpoint = os.Getenv("LOCALAI_API_ENDPOINT")
var apiKey = os.Getenv("LOCALAI_API_KEY")

func TestLocalAI(t *testing.T) {
RegisterFailHandler(Fail)
Expand All @@ -38,11 +39,11 @@ var _ = BeforeSuite(func() {
var defaultConfig openai.ClientConfig
if apiEndpoint == "" {
startDockerImage()
defaultConfig = openai.DefaultConfig("")
defaultConfig = openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://localhost:" + apiPort + "/v1"
} else {
fmt.Println("Default ", apiEndpoint)
defaultConfig = openai.DefaultConfig("")
defaultConfig = openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = apiEndpoint
}

Expand Down

0 comments on commit ff77d3b

Please sign in to comment.