Skip to content

Commit 105c0df

Browse files
committedApr 5, 2023
add get-models api
1 parent fb1bb56 commit 105c0df

10 files changed

+166
-24
lines changed
 

‎chat.go

+31-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"io/ioutil"
8+
"log"
89
"net/http"
910
"strings"
1011

@@ -16,9 +17,11 @@ import (
1617
// ChatText chat reply with text format
1718
type ChatText struct {
1819
data string // event data
19-
ConversationID string // conversation context id
20-
MessageID string // current message id, can used as next chat's parent_message_id
21-
Content string // text content
20+
ConversationID string `json:"conversation_id"` // conversation context id
21+
MessageID string `json:"message_id"` // current message id, can used as next chat's parent_message_id
22+
Content string `json:"content"` // text content
23+
Model string `json:"model"` // chat model
24+
CreatedAt int64 `json:"created_at"` // message create_time
2225
}
2326

2427
// ChatStream chat reply with sream
@@ -27,9 +30,16 @@ type ChatStream struct {
2730
Err error // error message
2831
}
2932

33+
// ChatText raw data
34+
func (c *ChatText) Raw() string {
35+
return c.data
36+
}
37+
3038
// ChatText format to string
3139
func (c *ChatText) String() string {
32-
return c.data
40+
b, _ := json.Marshal(c)
41+
42+
return string(b)
3343
}
3444

3545
// GetChatText will return text message
@@ -72,6 +82,19 @@ func (c *Client) GetChatStream(message string, args ...string) (*ChatStream, err
7282
return nil, fmt.Errorf("send message failed: %v", err)
7383
}
7484

85+
contentType := resp.Header.Get("Content-Type")
86+
// not event-strem response
87+
if !strings.HasPrefix(contentType, "text/event-stream") {
88+
defer resp.Body.Close()
89+
90+
body, _ := ioutil.ReadAll(resp.Body)
91+
if c.opts.Debug {
92+
log.Printf("http response info: %s\n", body)
93+
}
94+
95+
return nil, fmt.Errorf("response failed: [%s] %s", resp.Status, body)
96+
}
97+
7598
chatStream := &ChatStream{
7699
Stream: make(chan *ChatText),
77100
Err: nil,
@@ -119,6 +142,8 @@ func (c *Client) parseChatText(text string) (*ChatText, error) {
119142
conversationID := res.Get("conversation_id").String()
120143
messageID := res.Get("message.id").String()
121144
content := res.Get("message.content.parts.0").String()
145+
model := res.Get("message.metadata.model_slug").String()
146+
createdAt := res.Get("message.create_time").Int()
122147

123148
if conversationID == "" || messageID == "" {
124149
return nil, fmt.Errorf("invalid chat text")
@@ -129,6 +154,8 @@ func (c *Client) parseChatText(text string) (*ChatText, error) {
129154
ConversationID: conversationID,
130155
MessageID: messageID,
131156
Content: content,
157+
Model: model,
158+
CreatedAt: createdAt,
132159
}, nil
133160
}
134161

‎client.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ const (
1212
BASE_URI = "https://chat.openai.com"
1313
AUTH_SESSION_URI = "https://chat.openai.com/api/auth/session"
1414
CONVERSATION_URI = "https://chat.openai.com/backend-api/conversation"
15+
GET_MODELS_URI = "https://chat.openai.com/backend-api/models"
1516
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36"
1617
EOF_TEXT = "[DONE]"
1718
)
@@ -81,9 +82,16 @@ func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
8182
resp, err := c.httpCli.Do(req)
8283

8384
if c.opts.Debug {
84-
respInfo, _ := httputil.DumpResponse(resp, true)
85+
respInfo, _ := httputil.DumpResponse(resp, false)
8586
log.Printf("http response info: \n%s\n", respInfo)
8687
}
8788

8889
return resp, err
8990
}
91+
92+
// WithModel: set chat model
93+
func (c *Client) WithModel(model string) *Client {
94+
c.opts.Model = model
95+
96+
return c
97+
}

‎examples/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
init_test.go

‎examples/chat_test.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
)
77

88
func ExampleClient_GetChatText() {
9+
cli := getClient()
10+
911
message := "say hello to me"
1012

1113
log.Printf("start get chat text")
@@ -23,6 +25,8 @@ func ExampleClient_GetChatText() {
2325
}
2426

2527
func ExampleClient_GetContinuousChatText() {
28+
cli := getClient()
29+
2630
message := "say hello to me"
2731

2832
log.Printf("start get chat text")
@@ -54,18 +58,20 @@ func ExampleClient_GetContinuousChatText() {
5458
}
5559

5660
func ExampleClient_GetChatStream() {
61+
cli := getClient()
62+
5763
message := "say hello to me"
5864

5965
log.Printf("start get chat stream")
6066

61-
stream, err := cli.GetChatStream(message)
67+
stream, err := cli.WithModel("gpt-4").GetChatStream(message)
6268
if err != nil {
6369
log.Fatalf("get chat stream failed: %v\n", err)
6470
}
6571

6672
var answer string
6773
for text := range stream.Stream {
68-
log.Printf("stream text: %s\n", text.Content)
74+
log.Printf("stream text: %s\n", text)
6975
answer = text.Content
7076
}
7177

‎examples/client_test.go

+15-11
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,25 @@ import (
88
chatgpt "github.com/chatgp/chatgpt-go"
99
)
1010

11-
// chatgpt client
12-
var cli *chatgpt.Client
11+
var (
12+
debug bool
13+
accessToken string
14+
sessionToken string
15+
cfValue string
16+
puid string
17+
)
1318

1419
func ExampleNewClient() {
15-
fmt.Printf("%T", cli)
20+
fmt.Printf("%T", getClient())
1621

1722
// Output: *chatgpt.Client
1823
}
1924

20-
func init() {
21-
token := `copy-from-cookies`
22-
cfValue := "copy-from-cookies"
23-
puid := "copy-from-cookies"
24-
25+
func getClient() *chatgpt.Client {
2526
cookies := []*http.Cookie{
2627
{
2728
Name: "__Secure-next-auth.session-token",
28-
Value: token,
29+
Value: sessionToken,
2930
},
3031
{
3132
Name: "cf_clearance",
@@ -37,9 +38,12 @@ func init() {
3738
},
3839
}
3940

40-
cli = chatgpt.NewClient(
41-
chatgpt.WithDebug(false),
41+
cli := chatgpt.NewClient(
4242
chatgpt.WithTimeout(60*time.Second),
43+
chatgpt.WithDebug(debug),
44+
chatgpt.WithAccessToken(accessToken),
4345
chatgpt.WithCookies(cookies),
4446
)
47+
48+
return cli
4549
}

‎examples/models_test.go

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package examples
2+
3+
import (
4+
"fmt"
5+
"log"
6+
)
7+
8+
func ExampleClient_GetModels() {
9+
cli := getClient()
10+
11+
res, cookies, err := cli.GetModels()
12+
13+
if err != nil {
14+
log.Fatalf("get models failed: %v\n", err)
15+
}
16+
17+
for _, v := range res.Get("models").Array() {
18+
log.Printf("model: %s\n", v.String())
19+
}
20+
21+
for _, v := range cookies {
22+
log.Printf("cookie: %s, %s, expires: %v\n", v.Name, v.Value, v.Expires)
23+
}
24+
25+
fmt.Println("xxx")
26+
// Output: xxx
27+
}

‎examples/session_test.go

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package examples
2+
3+
import (
4+
"fmt"
5+
"log"
6+
)
7+
8+
func ExampleClient_AuthSession() {
9+
cli := getClient()
10+
11+
res, err := cli.AuthSession()
12+
13+
if err != nil {
14+
log.Fatalf("auth session failed: %v\n", err)
15+
}
16+
17+
log.Printf("session: %s\n", res)
18+
19+
fmt.Println("xxx")
20+
// Output: xxx
21+
}

‎models.go

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package chatgpt
2+
3+
import (
4+
"fmt"
5+
"io/ioutil"
6+
"log"
7+
"net/http"
8+
9+
"github.com/tidwall/gjson"
10+
)
11+
12+
// GetModels to get all availabel model s
13+
func (c *Client) GetModels() (*gjson.Result, []*http.Cookie, error) {
14+
req, err := http.NewRequest(http.MethodGet, GET_MODELS_URI, nil)
15+
if err != nil {
16+
return nil, nil, fmt.Errorf("new request failed: %v", err)
17+
}
18+
19+
accessToken, err := c.getAccessToken()
20+
if err != nil {
21+
return nil, nil, fmt.Errorf("get accessToken failed: %v", err)
22+
}
23+
24+
bearerToken := fmt.Sprintf("Bearer %s", accessToken)
25+
req.Header.Set("Authorization", bearerToken)
26+
27+
resp, err := c.doRequest(req)
28+
29+
if err != nil {
30+
return nil, nil, fmt.Errorf("do request failed: %v", err)
31+
}
32+
defer resp.Body.Close()
33+
34+
body, err := ioutil.ReadAll(resp.Body)
35+
if err != nil {
36+
return nil, nil, fmt.Errorf("read response body failed: %v", err)
37+
}
38+
39+
if c.opts.Debug {
40+
log.Printf("http response info: %s\n", body)
41+
}
42+
43+
res := gjson.ParseBytes(body)
44+
45+
return &res, resp.Cookies(), nil
46+
}

‎session.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ package chatgpt
33
import (
44
"fmt"
55
"io/ioutil"
6+
"log"
67
"net/http"
78

89
"github.com/tidwall/gjson"
910
)
1011

11-
// authSession will check if session is expired and return a new accessToken
12-
func (c *Client) authSession() (*gjson.Result, error) {
12+
// AuthSession will check if session is expired and return a new accessToken
13+
func (c *Client) AuthSession() (*gjson.Result, error) {
1314
req, err := http.NewRequest(http.MethodGet, AUTH_SESSION_URI, nil)
1415
if err != nil {
1516
return nil, fmt.Errorf("new request failed: %v", err)
@@ -27,10 +28,11 @@ func (c *Client) authSession() (*gjson.Result, error) {
2728
return nil, fmt.Errorf("read response body failed: %v", err)
2829
}
2930

30-
res := gjson.ParseBytes(body)
31-
if res.String() == "" {
32-
return nil, fmt.Errorf("parse response body failed")
31+
if c.opts.Debug {
32+
log.Printf("http response info: %s\n", body)
3333
}
3434

35+
res := gjson.ParseBytes(body)
36+
3537
return &res, nil
3638
}

‎token.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func (c *Client) getAccessToken() (string, error) {
1111
}
1212

1313
// fetch new accessToken
14-
res, err := c.authSession()
14+
res, err := c.AuthSession()
1515
if err != nil {
1616
return "", fmt.Errorf("fetch new accessToken failed: %v", err)
1717
}

0 commit comments

Comments
 (0)
Please sign in to comment.