diff --git a/modules/_all/import.go b/modules/_all/import.go index 5faf6e5..8833ca7 100644 --- a/modules/_all/import.go +++ b/modules/_all/import.go @@ -15,5 +15,6 @@ import ( _ "github.com/riking/marvin/modules/restart" _ "github.com/riking/marvin/modules/rss" _ "github.com/riking/marvin/modules/timedpin" + _ "github.com/riking/marvin/modules/usercache" _ "github.com/riking/marvin/modules/weblogin" ) diff --git a/modules/usercache/database.go b/modules/usercache/database.go new file mode 100644 index 0000000..29ab92c --- /dev/null +++ b/modules/usercache/database.go @@ -0,0 +1,111 @@ +package usercache + +import ( + "encoding/json" + + "github.com/riking/marvin/slack" + "github.com/riking/marvin/slack/rtm" +) + +const ( + sqlMigrate1 = `CREATE TABLE module_user_cache ( + user_id varchar(15) PRIMARY KEY NOT NULL, + data text + + UNIQUE(user_id) + )` + + sqlGetAllEntries = `SELECT * FROM module_user_cache` + + // $1 = slack.UserID + sqlGetEntry = `SELECT data FROM module_user_cache WHERE user_id = $1` + + // $1 = slack.UserID + // $2 = data (json encoded) + sqlUpsertEntry = `INSERT INTO module_user_cache (user_id,data) VALUES ($1, $2) + ON CONFLICT (user_id) DO UPDATE SET data = EXCLUDED.data` +) + +func (mod *UserCacheModule) GetEntry(userid slack.UserID) (slack.User, error) { + var entry slack.User + + var data string + stmt, err := mod.team.DB().Prepare(sqlGetEntry) + if err != nil { + return entry, nil + } + defer stmt.Close() + row := stmt.QueryRow(userid) + err = row.Scan(&data) + if err != nil { + return entry, nil + } + err = json.Unmarshal([]byte(userid), &entry) + if err != nil { + return entry, nil + } + return entry, nil +} + +func (mod *UserCacheModule) LoadEntries() error { + stmt, err := mod.team.DB().Query(sqlGetAllEntries) + if err != nil { + return err + } + + rtmClient := mod.team.GetRTMClient().(*rtm.Client) + + defer stmt.Close() + var arr = make([]*slack.User, 200) + for stmt.Next() { + var id string + var data string + var user *slack.User + + err = stmt.Scan(&id, &data) + if err != nil { + return err + } + err = json.Unmarshal([]byte(data), &user) + if err != nil { + continue + } + arr = append(arr, user) + if len(arr) >= 199 { + rtmClient.ReplaceManyUserObjects(arr, false) + arr = arr[:0] + } + } + if len(arr) >= 0 { + rtmClient.ReplaceManyUserObjects(arr, false) + arr = nil + } + + return stmt.Err() +} + +func (mod *UserCacheModule) UpdateEntry(userobject *slack.User) error { + return mod.UpdateEntries([]*slack.User{userobject}) +} + +func (mod *UserCacheModule) UpdateEntries(userobjects []*slack.User) error { + stmt, err := mod.team.DB().Prepare(sqlUpsertEntry) + if err != nil { + return err + } + + defer stmt.Close() + + for _, obj := range userobjects { + if obj != nil { + entrydata, err := json.Marshal(obj) + if err == nil { + _, err := stmt.Exec(obj.ID, entrydata) + if err != nil { + return err + } + } + } + } + return nil +} diff --git a/modules/usercache/usercache.go b/modules/usercache/usercache.go new file mode 100644 index 0000000..3d6d568 --- /dev/null +++ b/modules/usercache/usercache.go @@ -0,0 +1,88 @@ +package usercache + +import ( + "fmt" + "strconv" + "time" + + "github.com/riking/marvin" + "github.com/riking/marvin/slack" + "github.com/riking/marvin/slack/rtm" +) + +// interface duplicated in rtm package +type API interface { + marvin.Module + + GetEntry(userid slack.UserID) (slack.User, error) + LoadEntries() error + UpdateEntry(userobject *slack.User) error + UpdateEntries(userobjects []*slack.User) error +} + +var _ API = &UserCacheModule{} + +// --- +func init() { + marvin.RegisterModule(NewUserCacheModule) +} + +const Identifier = "usercache" + +type UserCacheModule struct { + team marvin.Team +} + +func NewUserCacheModule(t marvin.Team) marvin.Module { + mod := &UserCacheModule{ + team: t, + } + return mod +} + +func (mod *UserCacheModule) Identifier() marvin.ModuleID { + return Identifier +} + +func (mod *UserCacheModule) Load(t marvin.Team) { + t.DB().MustMigrate(Identifier, 1505192548, sqlMigrate1) + t.DB().SyntaxCheck(sqlGetAllEntries, sqlGetEntry, sqlUpsertEntry) + t.ModuleConfig(Identifier).Add("last-timestamp", "0") + t.ModuleConfig(Identifier).Add("delay", (72 * time.Hour).String()) +} + +func (mod *UserCacheModule) Enable(team marvin.Team) { + go func() { + fmt.Printf("Loading user cache entries....\n") + err := mod.LoadEntries() + if err != nil { + fmt.Printf("Error whilst updating entries: %s\n", err.Error()) + return + } + + fmt.Printf("Loaded all entries from the user cache.\n") + go mod.UpdateTask() + }() +} + +func (mod *UserCacheModule) Disable(t marvin.Team) { +} + +func (mod *UserCacheModule) UpdateTask() { + rtmClient := mod.team.GetRTMClient().(*rtm.Client) + + for { + timestr, _, _ := mod.team.ModuleConfig(Identifier).GetIsDefault("last-timestamp") + delaystr, _, _ := mod.team.ModuleConfig(Identifier).GetIsDefault("delay") + timeint, _ := strconv.ParseInt(timestr, 10, 64) + var timeres = time.Unix(timeint, 0) + delayres, err := time.ParseDuration(delaystr) + + if err != nil || timeres.Before(time.Now().Add(-delayres)) { + fmt.Printf("Repolling user list....\n") + go rtmClient.FillUsersList() + err = mod.team.ModuleConfig(Identifier).Set("last-timestamp", strconv.FormatInt(time.Now().Unix(), 10)) + } + time.Sleep(1 * time.Hour) + } +} diff --git a/slack/controller/channel_info.go b/slack/controller/channel_info.go index 241b289..71a339a 100644 --- a/slack/controller/channel_info.go +++ b/slack/controller/channel_info.go @@ -189,8 +189,15 @@ func (t *Team) cachedUserInfo(user slack.UserID) *slack.User { uID := slack.ParseUserMention(string(user)) if uID != "" { - go t.updateUserInfo(user) + // XXX HACK to get around unlocking. + t.client.MetadataLock.RUnlock() + defer t.client.MetadataLock.RLock() + user, err := t.updateUserInfo(user) + if err == nil { + return user + } } + return nil } diff --git a/slack/controller/team.go b/slack/controller/team.go index 8022a18..07ef4f3 100644 --- a/slack/controller/team.go +++ b/slack/controller/team.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "os" + "strconv" "strings" "sync" "time" @@ -318,7 +319,13 @@ func (t *Team) SlackAPIPostJSON(method string, form url.Values, result interface util.LogBadf("Slack API %s error: %s", method, err) util.LogBadf("Form for %s: %v", method, form) if slackResponse.SlackError == "ratelimited" { - time.Sleep(1*time.Second) + retryafter := resp.Header.Get("Retry-After") + intp, err := strconv.ParseInt(retryafter, 10, 64) + if err == nil { + time.Sleep(time.Duration(intp) * time.Second) + } else { + time.Sleep(1 * time.Second) + } } return errors.Wrapf(err, "Slack API %s", method) } diff --git a/slack/rtm/events.go b/slack/rtm/events.go index f3047a0..bd0f047 100644 --- a/slack/rtm/events.go +++ b/slack/rtm/events.go @@ -90,6 +90,13 @@ func (c *Client) onChannelJoin(msg slack.RTMRawMessage) { } func (c *Client) ReplaceUserObject(obj *slack.User) { + var cacheApi userCacheAPI + moduleCacheApi := c.team.GetModule("usercache") + if moduleCacheApi != nil { + cacheApi = moduleCacheApi.(userCacheAPI) + cacheApi.UpdateEntry(obj) + } + c.MetadataLock.Lock() defer c.MetadataLock.Unlock() @@ -103,7 +110,14 @@ func (c *Client) ReplaceUserObject(obj *slack.User) { c.Users = append(c.Users, obj) } -func (c *Client) ReplaceManyUserObjects(objs []*slack.User) { +func (c *Client) ReplaceManyUserObjects(objs []*slack.User, updateCache bool) { + var cacheApi userCacheAPI + moduleCacheApi := c.team.GetModule("usercache") + if moduleCacheApi != nil && updateCache { + cacheApi = moduleCacheApi.(userCacheAPI) + cacheApi.UpdateEntries(objs) + } + c.MetadataLock.Lock() defer c.MetadataLock.Unlock() diff --git a/slack/rtm/membership_info.go b/slack/rtm/membership_info.go index 542897d..08efd59 100644 --- a/slack/rtm/membership_info.go +++ b/slack/rtm/membership_info.go @@ -5,6 +5,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/riking/marvin" "github.com/riking/marvin/slack" "github.com/riking/marvin/util" ) @@ -16,6 +17,13 @@ type membershipRequest struct { C chan interface{} } +type userCacheAPI interface { + marvin.Module + + UpdateEntry(userobject *slack.User) error + UpdateEntries(userobjects []*slack.User) error +} + func (c *Client) membershipWorker() { for req := range c.membershipCh { req.C <- req.F(c.channelMembers) @@ -141,13 +149,12 @@ func (c *Client) ListIMs() []*slack.ChannelIM { func (c *Client) fetchTeamInfo() { go c.fillGroupList() - go c.fillUsersList() // TODO(kyork): list normal channels too // TODO(kyork): use the listChannels() from logger module } -func (c *Client) fillUsersList() { +func (c *Client) FillUsersList() { var response struct { slack.APIResponse Members []*slack.User @@ -165,16 +172,17 @@ func (c *Client) fillUsersList() { util.LogError(errors.Wrapf(err, "[%s] Could not retrieve users list", c.Team.Domain)) } - for response.PageInfo.NextCursor != "" { - c.ReplaceManyUserObjects(response.Members) - time.Sleep(2*time.Second) + c.ReplaceManyUserObjects(response.Members, true) + for response.PageInfo.NextCursor != "" { + time.Sleep(2 * time.Second) form.Set("cursor", response.PageInfo.NextCursor) err := c.team.SlackAPIPostJSON("users.list", form, &response) if err != nil { util.LogError(errors.Wrapf(err, "[%s] Could not retrieve users list", c.Team.Domain)) - break + continue } + c.ReplaceManyUserObjects(response.Members, true) } }