Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 56 additions & 78 deletions providers/luadns/luadnsProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,84 +134,58 @@ func (l *luadnsProvider) GetZoneRecordsCorrections(dc *models.DomainConfig, reco
return nil, 0, err
}

var corrs []*models.Correction

changes, actualChangeCount, err := diff2.ByRecord(records, dc, nil)
changes, actualChangeCount, err := diff2.ByRecordSet(records, dc, nil)
if err != nil {
return nil, 0, err
}

for _, change := range changes {
msg := change.Msgs[0]
switch change.Type {
case diff2.REPORT:
corrs = []*models.Correction{{Msg: change.MsgsJoined}}
corrections = append(corrections, &models.Correction{Msg: change.MsgsJoined})
case diff2.CREATE:
corrs = l.makeCreateCorrection(change.New[0], zone, msg)
req := recordsToNative(change.New)
corrections = append(corrections, &models.Correction{
F: func() error {
if err := l.rateLimiter.Wait(l.ctx); err != nil {
return err
}
_, err := l.provider.CreateManyRecords(l.ctx, zone, req)
return err
},
Msg: change.MsgsJoined,
})
case diff2.CHANGE:
corrs = l.makeChangeCorrection(change.Old[0], change.New[0], zone, msg)
req := recordsToNative(change.New)
corrections = append(corrections, &models.Correction{
F: func() error {
if err := l.rateLimiter.Wait(l.ctx); err != nil {
return err
}
_, err := l.provider.UpdateManyRecords(l.ctx, zone, req)
return err
},
Msg: change.MsgsJoined,
})
case diff2.DELETE:
corrs = l.makeDeleteCorrection(change.Old[0], zone, msg)
req := recordsToNative(change.Old)
corrections = append(corrections, &models.Correction{
F: func() error {
if err := l.rateLimiter.Wait(l.ctx); err != nil {
return err
}
_, err := l.provider.DeleteManyRecords(l.ctx, zone, req)
return err
},
Msg: change.MsgsJoined,
})
default:
panic(fmt.Sprintf("unhandled inst.Type %s", change.Type))
}
corrections = append(corrections, corrs...)
}
return corrections, actualChangeCount, nil
}

func (l *luadnsProvider) makeCreateCorrection(newrec *models.RecordConfig, zone *api.Zone, msg string) []*models.Correction {
req := recordsToNative(newrec)
return []*models.Correction{{
Msg: msg,
F: func() error {
if err := l.rateLimiter.Wait(l.ctx); err != nil {
return err
}
_, err := l.provider.CreateRecord(l.ctx, zone, req)
if err != nil {
return err
}
return nil
},
}}
}

func (l *luadnsProvider) makeChangeCorrection(oldrec *models.RecordConfig, newrec *models.RecordConfig, zone *api.Zone, msg string) []*models.Correction {
recordID := oldrec.Original.(*api.Record).ID
req := recordsToNative(newrec)
return []*models.Correction{{
Msg: fmt.Sprintf("%s, LuaDNS ID: %d", msg, recordID),
F: func() error {
if err := l.rateLimiter.Wait(l.ctx); err != nil {
return err
}
_, err := l.provider.UpdateRecord(l.ctx, zone, recordID, req)
if err != nil {
return err
}
return nil
},
}}
}

func (l *luadnsProvider) makeDeleteCorrection(deleterec *models.RecordConfig, zone *api.Zone, msg string) []*models.Correction {
recordID := deleterec.Original.(*api.Record).ID
return []*models.Correction{{
Msg: fmt.Sprintf("%s, LuaDNS ID: %d", msg, recordID),
F: func() error {
if err := l.rateLimiter.Wait(l.ctx); err != nil {
return err
}
_, err := l.provider.DeleteRecord(l.ctx, zone, recordID)
if err != nil {
return err
}
return nil
},
}}
}

// EnsureZoneExists creates a zone if it does not exist.
func (l *luadnsProvider) EnsureZoneExists(domain string, metadata map[string]string) error {
if l.zones == nil {
Expand Down Expand Up @@ -291,25 +265,29 @@ func nativeToRecord(domain string, r *api.Record) (*models.RecordConfig, error)
return rc, err
}

func recordsToNative(rc *models.RecordConfig) *api.Record {
r := &api.Record{
Name: rc.GetLabelFQDN() + ".",
Type: rc.Type,
TTL: rc.TTL,
}
switch rtype := rc.Type; rtype {
case "TXT":
r.Content = rc.GetTargetTXTJoined()
case "HTTPS":
content := fmt.Sprintf("%d %s %s", rc.SvcPriority, rc.GetTargetField(), rc.SvcParams)
if rc.SvcParams == "" {
content = content[:len(content)-1]
func recordsToNative(rc []*models.RecordConfig) []*api.RR {
var rrs []*api.RR
for _, rec := range rc {
r := &api.RR{
Name: rec.GetLabelFQDN() + ".",
Type: rec.Type,
TTL: rec.TTL,
}
r.Content = content
default:
r.Content = rc.GetTargetCombined()
switch rtype := rec.Type; rtype {
case "TXT":
r.Content = rec.GetTargetTXTJoined()
case "HTTPS":
content := fmt.Sprintf("%d %s %s", rec.SvcPriority, rec.GetTargetField(), rec.SvcParams)
if rec.SvcParams == "" {
content = content[:len(content)-1]
}
r.Content = content
default:
r.Content = rec.GetTargetCombined()
}
rrs = append(rrs, r)
}
return r
return rrs
}

func checkNS(dc *models.DomainConfig) {
Expand Down
Loading