From 055d724e8c1c93a357c35f1e12e2ef84e67e3945 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Wed, 31 Jul 2024 13:11:00 +0200 Subject: [PATCH 01/19] feat: LCS-based movement implementation This new algorithm implements generation of move actions by utilizing a longest common substring algorithm (LCS). First, based on the position type, we generate a list describing the expected order of the elements. LCS algorithm is then used to find the longest sequence of items that is shared between the expected list, and an actual list (e.g. a list of entries from the server). Once longest sequence is known, we figure out the least amount of moves to translate existing list into its expected form, and those movements are returned back. --- assets/pango/movement/movement.go | 380 +++++++++++++++++++ assets/pango/movement/movement_suite_test.go | 18 + assets/pango/movement/movement_test.go | 147 +++++++ 3 files changed, 545 insertions(+) create mode 100644 assets/pango/movement/movement.go create mode 100644 assets/pango/movement/movement_suite_test.go create mode 100644 assets/pango/movement/movement_test.go diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go new file mode 100644 index 00000000..64036bcc --- /dev/null +++ b/assets/pango/movement/movement.go @@ -0,0 +1,380 @@ +package movement + +import ( + "fmt" + "log/slog" + "slices" +) + +var _ = slog.LevelDebug + +type Movable interface { + EntryName() string +} + +type MoveAction struct { + EntryName string + Where string + Destination string +} + +type Position interface { + Move(entries []Movable, existing []Movable) ([]MoveAction, error) +} + +type PositionTop struct{} + +type PositionBottom struct{} + +type PositionBefore struct { + Directly bool + Pivot Movable +} + +type PositionAfter struct { + Directly bool + Pivot Movable +} + +func removeEntriesFromExisting(entries []Movable, filterFn func(entry Movable) bool) []Movable { + entryNames := make(map[string]bool, len(entries)) + for _, elt := range entries { + entryNames[elt.EntryName()] = true + } + + filtered := make([]Movable, len(entries)) + copy(filtered, entries) + + filtered = slices.DeleteFunc(filtered, filterFn) + + return filtered +} + +func findPivotIdx(entries []Movable, pivot Movable) int { + return slices.IndexFunc(entries, func(entry Movable) bool { + if entry.EntryName() == pivot.EntryName() { + return true + } + + return false + }) + +} + +type movementType int + +const ( + movementBefore movementType = iota + movementAfter +) + +func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, direct bool, movement movementType) ([]MoveAction, error) { + existingLen := len(existing) + existingIdxMap := make(map[Movable]int, existingLen) + + for idx, elt := range existing { + existingIdxMap[elt] = idx + } + + pivotIdx := findPivotIdx(existing, pivot) + if pivotIdx == -1 { + return nil, fmt.Errorf("pivot point not found in the list of existing items") + } + + if !direct { + movementRequired := false + entriesLen := len(entries) + loop: + for i := 0; i < entriesLen; i++ { + + // For any given entry in the list of entries to move check if the entry + // index is at or after pivot point index, which will require movement + // set to be generated. + existingEntryIdx := existingIdxMap[entries[i]] + switch movement { + case movementBefore: + if existingEntryIdx >= pivotIdx { + movementRequired = true + break + } + case movementAfter: + if existingEntryIdx <= pivotIdx { + movementRequired = true + break + } + } + + if i == 0 { + continue + } + + // Check if the entries to be moved have the same order in the existing + // slice, and if not require a movement set to be generated. + switch movement { + case movementBefore: + if existingIdxMap[entries[i-1]] >= existingEntryIdx { + movementRequired = true + break loop + + } + case movementAfter: + if existingIdxMap[entries[i-1]] <= existingEntryIdx { + movementRequired = true + break loop + + } + + } + } + + if !movementRequired { + return nil, nil + } + } + + expected := make([]Movable, len(existing)) + + entriesIdxMap := make(map[Movable]int, len(entries)) + for idx, elt := range entries { + entriesIdxMap[elt] = idx + } + + filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { + _, ok := entriesIdxMap[entry] + return ok + }) + + filteredPivotIdx := findPivotIdx(filtered, pivot) + + switch movement { + case movementBefore: + expectedIdx := 0 + for ; expectedIdx < filteredPivotIdx; expectedIdx++ { + expected[expectedIdx] = filtered[expectedIdx] + } + + for _, elt := range entries { + expected[expectedIdx] = elt + expectedIdx++ + } + + expected[expectedIdx] = pivot + expectedIdx++ + + filteredLen := len(filtered) + for i := filteredPivotIdx + 1; i < filteredLen; i++ { + expected[expectedIdx] = filtered[i] + expectedIdx++ + } + } + + return GenerateMovements(existing, expected, entries) +} + +func (o PositionAfter) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + return processPivotMovement(entries, existing, o.Pivot, o.Directly, movementAfter) +} + +func (o PositionBefore) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + return processPivotMovement(entries, existing, o.Pivot, o.Directly, movementBefore) +} + +type Entry struct { + Element Movable + Expected int + Existing int +} + +type sequencePosition struct { + Start int + End int +} + +func longestCommonSubsequence(S []Movable, T []Movable) [][]Movable { + + r := len(S) + n := len(T) + + L := make([][]int, r) + for idx := range len(T) { + L[idx] = make([]int, n) + } + z := 0 + + var results [][]Movable + + for i := 0; i < r; i++ { + for j := 0; j < n; j++ { + if S[i].EntryName() == T[j].EntryName() { + if i == 0 || j == 0 { + L[i][j] = 1 + } else { + L[i][j] = L[i-1][j-1] + 1 + } + + if L[i][j] > z { + slog.Debug("L[i][j] > z", "L[i][j]", L[i][j], "z", z, "i-z", i-z, "i", i) + results = nil + results = append(results, S[i-z:i+1]) + z = L[i][j] + slog.Debug("L[i][j] > z", "results", results) + } else if L[i][j] == z { + results = append(results, S[i-z:i+1]) + slog.Debug("L[i][j] == z", "i-z", i, "i", i+1) + } + slog.Debug("Still", "results", results) + } else { + L[i][j] = 0 + } + } + } + + slog.Debug("commonSubsequence", "results", results) + + return results +} + +func GenerateMovements(existing []Movable, expected []Movable, entries []Movable) ([]MoveAction, error) { + if len(existing) != len(expected) { + return nil, fmt.Errorf("existing length != expected length: %d != %d", len(existing), len(expected)) + } + + common := longestCommonSubsequence(existing, expected) + + entriesIdxMap := make(map[Movable]int, len(entries)) + for idx, elt := range entries { + entriesIdxMap[elt] = idx + } + + var commonSequence []Movable + for _, elt := range common { + filtered := removeEntriesFromExisting(elt, func(elt Movable) bool { + _, ok := entriesIdxMap[elt] + return ok + }) + + if len(filtered) > len(commonSequence) { + commonSequence = filtered + } + + } + + existingIdxMap := make(map[Movable]int, len(existing)) + for idx, elt := range existing { + existingIdxMap[elt] = idx + } + + expectedIdxMap := make(map[Movable]int, len(expected)) + for idx, elt := range expected { + expectedIdxMap[elt] = idx + } + + commonLen := len(commonSequence) + commonIdxMap := make(map[Movable]int, len(commonSequence)) + for idx, elt := range commonSequence { + commonIdxMap[elt] = idx + } + + var movements []MoveAction + + var previous Movable + for _, elt := range entries { + slog.Debug("GenerateMovements", "elt", elt.EntryName(), "existingIdx", existingIdxMap[elt], "expectedIdx", expectedIdxMap[elt]) + if existingIdxMap[elt] == expectedIdxMap[elt] { + continue + } + + if expectedIdxMap[elt] == 0 { + movements = append(movements, MoveAction{ + EntryName: elt.EntryName(), + Destination: "top", + Where: "top", + }) + previous = elt + } else if len(commonSequence) > 0 { + if expectedIdxMap[elt] < expectedIdxMap[commonSequence[0]] { + if previous == nil { + previous = expected[0] + } + movements = append(movements, MoveAction{ + EntryName: elt.EntryName(), + Destination: previous.EntryName(), + Where: "after", + }) + previous = elt + } else if expectedIdxMap[elt] > expectedIdxMap[commonSequence[commonLen-1]] { + if previous == nil { + previous = commonSequence[commonLen-1] + } + movements = append(movements, MoveAction{ + EntryName: elt.EntryName(), + Destination: previous.EntryName(), + Where: "after", + }) + previous = elt + + } else if expectedIdxMap[elt] > expectedIdxMap[commonSequence[0]] { + if previous == nil { + previous = commonSequence[0] + } + movements = append(movements, MoveAction{ + EntryName: elt.EntryName(), + Destination: previous.EntryName(), + Where: "after", + }) + previous = elt + } + } else { + movements = append(movements, MoveAction{ + EntryName: elt.EntryName(), + Destination: previous.EntryName(), + Where: "after", + }) + previous = elt + } + + slog.Debug("GenerateMovements()", "existing", existingIdxMap[elt], "expected", expectedIdxMap[elt]) + } + + _ = previous + + slog.Debug("GenerateMovements()", "movements", movements) + + return movements, nil +} + +func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + entriesIdxMap := make(map[Movable]int, len(entries)) + for idx, elt := range entries { + entriesIdxMap[elt] = idx + } + + filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { + _, ok := entriesIdxMap[entry] + return ok + }) + + expected := append(entries, filtered...) + + return GenerateMovements(existing, expected, entries) +} + +func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + entriesIdxMap := make(map[Movable]int, len(entries)) + for idx, elt := range entries { + entriesIdxMap[elt] = idx + } + + filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { + _, ok := entriesIdxMap[entry] + return ok + }) + + expected := append(filtered, entries...) + + return GenerateMovements(existing, expected, entries) +} + +func MoveGroup(position Position, entries []Movable, existing []Movable) ([]MoveAction, error) { + return position.Move(entries, existing) +} diff --git a/assets/pango/movement/movement_suite_test.go b/assets/pango/movement/movement_suite_test.go new file mode 100644 index 00000000..b750b000 --- /dev/null +++ b/assets/pango/movement/movement_suite_test.go @@ -0,0 +1,18 @@ +package movement_test + +import ( + "log/slog" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestMovement(t *testing.T) { + handler := slog.NewTextHandler(GinkgoWriter, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + slog.SetDefault(slog.New(handler)) + RegisterFailHandler(Fail) + RunSpecs(t, "Movement Suite") +} diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go new file mode 100644 index 00000000..a4322591 --- /dev/null +++ b/assets/pango/movement/movement_test.go @@ -0,0 +1,147 @@ +package movement_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "movements/movement" +) + +type Mock struct { + Name string +} + +func (o Mock) EntryName() string { + return o.Name +} + +func asMovable(mocks []string) []movement.Movable { + var movables []movement.Movable + + for _, elt := range mocks { + movables = append(movables, Mock{elt}) + } + + return movables +} + +var _ = Describe("Movement", func() { + Context("With PositionTop used as position", func() { + Context("when existing positions matches expected", func() { + It("should generate no movements", func() { + expected := asMovable([]string{"A", "B", "C"}) + moves, err := movement.MoveGroup(movement.PositionTop{}, expected, expected) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(0)) + }) + }) + Context("when it has to move two elements", func() { + It("should generate three move actions", func() { + entries := asMovable([]string{"A", "B", "C"}) + existing := asMovable([]string{"D", "E", "A", "B", "C"}) + + moves, err := movement.MoveGroup(movement.PositionTop{}, entries, existing) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(3)) + + Expect(moves[0].EntryName).To(Equal("A")) + Expect(moves[0].Where).To(Equal("top")) + Expect(moves[0].Destination).To(Equal("top")) + + Expect(moves[1].EntryName).To(Equal("B")) + Expect(moves[1].Where).To(Equal("after")) + Expect(moves[1].Destination).To(Equal("A")) + + Expect(moves[2].EntryName).To(Equal("C")) + Expect(moves[2].Where).To(Equal("after")) + Expect(moves[2].Destination).To(Equal("B")) + }) + }) + Context("when expected order is reversed", func() { + It("should generate required move actions to converge lists", func() { + entries := asMovable([]string{"E", "D", "C", "B", "A"}) + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + moves, err := movement.MoveGroup(movement.PositionTop{}, entries, existing) + Expect(err).ToNot(HaveOccurred()) + + Expect(moves).To(HaveLen(4)) + }) + }) + }) + Context("With PositionBottom used as position", func() { + Context("when it needs to move one element", func() { + It("should generate a single move action", func() { + entries := asMovable([]string{"E"}) + existing := asMovable([]string{"A", "E", "B", "C", "D"}) + + moves, err := movement.MoveGroup(movement.PositionBottom{}, entries, existing) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].EntryName).To(Equal("E")) + Expect(moves[0].Where).To(Equal("after")) + Expect(moves[0].Destination).To(Equal("D")) + }) + }) + }) + + Context("With PositionBefore used as position", func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + + Context("when direct position relative to the pivot is not required", func() { + Context("and moved entries are already before pivot point", func() { + It("should not generate any move actions", func() { + entries := asMovable([]string{"A", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(0)) + }) + }) + Context("and moved entries are out of order", func() { + It("should generate only move commands to sort entries", func() { + // A B C D E -> A C B D E + entries := asMovable([]string{"C", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + Expect(moves[0].EntryName).To(Equal("C")) + Expect(moves[0].Where).To(Equal("after")) + Expect(moves[0].Destination).To(Equal("A")) + + Expect(moves[1].EntryName).To(Equal("B")) + Expect(moves[1].Where).To(Equal("after")) + Expect(moves[1].Destination).To(Equal("C")) + }) + }) + }) + Context("when direct position relative to the pivot is required", func() { + It("should generate required move actions", func() { + // A B C D E -> C A B D E + entries := asMovable([]string{"A", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: true, Pivot: Mock{"D"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].EntryName).To(Equal("A")) + Expect(moves[0].Where).To(Equal("after")) + Expect(moves[0].Destination).To(Equal("C")) + + Expect(moves[1].EntryName).To(Equal("B")) + Expect(moves[1].Where).To(Equal("after")) + Expect(moves[1].Destination).To(Equal("A")) + }) + }) + }) +}) From 90963b83bee9debf8a22f6ef591476ede80f3dd6 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Thu, 1 Aug 2024 09:13:23 +0200 Subject: [PATCH 02/19] Updated LCS implementation and added OptimizeMovements step Fixed some bugs in LCS implementation, but it needs to be be optimized - it takes ~25 seconds to find a common subsequence in two equal sequences of 50k elements. Added an OptimizeMovements step that is done after GenerateMovements returns a list of all move actions. It removes redundant actions that, after previous actions were applied, no longer have any effect. --- assets/pango/movement/movement.go | 380 +++++++++++++++++-------- assets/pango/movement/movement_test.go | 84 ++++-- 2 files changed, 325 insertions(+), 139 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 64036bcc..596636da 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -1,6 +1,7 @@ package movement import ( + "errors" "fmt" "log/slog" "slices" @@ -8,14 +9,23 @@ import ( var _ = slog.LevelDebug +type ActionWhereType string + +const ( + ActionWhereTop ActionWhereType = "top" + ActionWhereBottom ActionWhereType = "bottom" + ActionWhereBefore ActionWhereType = "before" + ActionWhereAfter ActionWhereType = "after" +) + type Movable interface { EntryName() string } type MoveAction struct { - EntryName string - Where string - Destination string + Movable Movable + Where ActionWhereType + Destination Movable } type Position interface { @@ -36,6 +46,14 @@ type PositionAfter struct { Pivot Movable } +func createIdxMapFor(entries []Movable) map[Movable]int { + entriesIdxMap := make(map[Movable]int, len(entries)) + for idx, elt := range entries { + entriesIdxMap[elt] = idx + } + return entriesIdxMap +} + func removeEntriesFromExisting(entries []Movable, filterFn func(entry Movable) bool) []Movable { entryNames := make(map[string]bool, len(entries)) for _, elt := range entries { @@ -68,17 +86,24 @@ const ( movementAfter ) -func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, direct bool, movement movementType) ([]MoveAction, error) { - existingLen := len(existing) - existingIdxMap := make(map[Movable]int, existingLen) +var ( + ErrSlicesNotEqualLength = errors.New("existing and expected slices length mismatch") + ErrPivotInEntries = errors.New("pivot element found in the entries slice") + ErrPivotNotInExisting = errors.New("pivot element not foudn in the existing slice") + ErrInvalidMovementPlan = errors.New("created movement plan is invalid") +) + +func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, direct bool, movement movementType) ([]Movable, []MoveAction, error) { + existingIdxMap := createIdxMapFor(existing) - for idx, elt := range existing { - existingIdxMap[elt] = idx + entriesPivotIdx := findPivotIdx(entries, pivot) + if entriesPivotIdx != -1 { + return nil, nil, ErrPivotInEntries } - pivotIdx := findPivotIdx(existing, pivot) - if pivotIdx == -1 { - return nil, fmt.Errorf("pivot point not found in the list of existing items") + existingPivotIdx := findPivotIdx(existing, pivot) + if existingPivotIdx == -1 { + return nil, nil, ErrPivotNotInExisting } if !direct { @@ -93,12 +118,12 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, existingEntryIdx := existingIdxMap[entries[i]] switch movement { case movementBefore: - if existingEntryIdx >= pivotIdx { + if existingEntryIdx >= existingPivotIdx { movementRequired = true break } case movementAfter: - if existingEntryIdx <= pivotIdx { + if existingEntryIdx <= existingPivotIdx { movementRequired = true break } @@ -128,16 +153,13 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, } if !movementRequired { - return nil, nil + return nil, nil, nil } } expected := make([]Movable, len(existing)) - entriesIdxMap := make(map[Movable]int, len(entries)) - for idx, elt := range entries { - entriesIdxMap[elt] = idx - } + entriesIdxMap := createIdxMapFor(entries) filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { _, ok := entriesIdxMap[entry] @@ -168,15 +190,26 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, } } - return GenerateMovements(existing, expected, entries) + actions, err := GenerateMovements(existing, expected, entries) + return expected, actions, err } func (o PositionAfter) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { - return processPivotMovement(entries, existing, o.Pivot, o.Directly, movementAfter) + expected, actions, err := processPivotMovement(entries, existing, o.Pivot, o.Directly, movementBefore) + if err != nil { + return nil, err + } + + return OptimizeMovements(existing, expected, actions, o), nil } func (o PositionBefore) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { - return processPivotMovement(entries, existing, o.Pivot, o.Directly, movementBefore) + expected, actions, err := processPivotMovement(entries, existing, o.Pivot, o.Directly, movementBefore) + if err != nil { + return nil, err + } + + return OptimizeMovements(existing, expected, actions, o), nil } type Entry struct { @@ -190,150 +223,270 @@ type sequencePosition struct { End int } -func longestCommonSubsequence(S []Movable, T []Movable) [][]Movable { - +func printLCSMatrix(S []Movable, T []Movable, L [][]int) { r := len(S) n := len(T) - L := make([][]int, r) - for idx := range len(T) { - L[idx] = make([]int, n) + line := " " + for _, elt := range S { + line += fmt.Sprintf("%s ", elt.EntryName()) } - z := 0 + slog.Debug("LCS", "line", line) - var results [][]Movable + line = " " + for _, elt := range L[0] { + line += fmt.Sprintf("%d ", elt) + } + slog.Debug("LCS", "line", line) + + for i := 1; i < r+1; i++ { + line = fmt.Sprintf("%s ", T[i-1].EntryName()) + for j := 0; j < n+1; j++ { + line += fmt.Sprintf("%d ", L[i][j]) + } + } + +} + +func LongestCommonSubstring(S []Movable, T []Movable) [][]Movable { + r := len(S) + n := len(T) + + L := make([][]int, r+1) + for idx := range r + 1 { + L[idx] = make([]int, n+1) + } - for i := 0; i < r; i++ { - for j := 0; j < n; j++ { - if S[i].EntryName() == T[j].EntryName() { - if i == 0 || j == 0 { - L[i][j] = 1 + for i := 1; i < r+1; i++ { + for j := 1; j < n+1; j++ { + if S[i-1].EntryName() == T[j-1].EntryName() { + if i == 1 { + L[j][i] = 1 + } else if j == 1 { + L[j][i] = 1 } else { - L[i][j] = L[i-1][j-1] + 1 + L[j][i] = L[j-1][i-1] + 1 } + } + } + } - if L[i][j] > z { - slog.Debug("L[i][j] > z", "L[i][j]", L[i][j], "z", z, "i-z", i-z, "i", i) - results = nil - results = append(results, S[i-z:i+1]) - z = L[i][j] - slog.Debug("L[i][j] > z", "results", results) - } else if L[i][j] == z { - results = append(results, S[i-z:i+1]) - slog.Debug("L[i][j] == z", "i-z", i, "i", i+1) + var results [][]Movable + var lcsList [][]Movable + + var entry []Movable + var index int + for i := r; i > 0; i-- { + for j := n; j > 0; j-- { + if S[i-1].EntryName() == T[j-1].EntryName() { + if L[j][i] >= index { + if len(entry) > 0 { + var entries []string + for _, elt := range entry { + entries = append(entries, elt.EntryName()) + } + + lcsList = append(lcsList, entry) + } + index = L[j][i] + entry = []Movable{S[i-1]} + } else if L[j][i] < index { + index = L[j][i] + entry = append(entry, S[i-1]) + } else { + entry = []Movable{} } - slog.Debug("Still", "results", results) - } else { - L[i][j] = 0 } } } - slog.Debug("commonSubsequence", "results", results) + if len(entry) > 0 { + lcsList = append(lcsList, entry) + } + + lcsLen := len(lcsList) + for idx := range lcsList { + elt := lcsList[lcsLen-idx-1] + if len(elt) > 1 { + slices.Reverse(elt) + results = append(results, elt) + } + } return results } +func updateSimulatedIdxMap(idxMap *map[Movable]int, moved Movable, startingIdx int, targetIdx int) { + slog.Debug("updateSimulatedIdxMap", "entries", idxMap) + for entry, idx := range *idxMap { + if entry == moved { + continue + } + + slog.Debug("updateSimulatedIdxMap", "entry", entry, "idx", idx, "startingIdx", startingIdx, "targetIdx", targetIdx) + if startingIdx > targetIdx && idx >= targetIdx { + (*idxMap)[entry] = idx + 1 + } else if startingIdx < targetIdx && idx >= startingIdx && idx <= targetIdx { + (*idxMap)[entry] = idx - 1 + } + } + slog.Debug("updateSimulatedIdxMap", "entries", idxMap) +} + +func OptimizeMovements(existing []Movable, expected []Movable, actions []MoveAction, position any) []MoveAction { + simulated := make([]Movable, len(existing)) + copy(simulated, existing) + + simulatedIdxMap := createIdxMapFor(simulated) + expectedIdxMap := createIdxMapFor(expected) + + optimized := make([]MoveAction, len(actions)) + + switch position.(type) { + case PositionBefore: + slog.Debug("OptimizeMovements()", "position", position, "type", fmt.Sprintf("%T", position)) + slices.Reverse(actions) + case PositionAfter: + slog.Debug("OptimizeMovements()", "position", position, "type", fmt.Sprintf("%T", position)) + default: + return actions + } + + optimizedIdx := 0 + for _, action := range actions { + currentIdx := simulatedIdxMap[action.Movable] + if currentIdx == expectedIdxMap[action.Movable] { + continue + } + + var targetIdx int + switch action.Where { + case ActionWhereTop: + targetIdx = 0 + case ActionWhereBottom: + targetIdx = len(simulated) - 1 + case ActionWhereBefore: + targetIdx = simulatedIdxMap[action.Destination] - 1 + case ActionWhereAfter: + targetIdx = simulatedIdxMap[action.Destination] + 1 + } + + slog.Debug("OptimizeMovements()", "action", action, "currentIdx", currentIdx, "targetIdx", targetIdx) + if targetIdx != currentIdx { + optimized[optimizedIdx] = action + optimizedIdx++ + simulatedIdxMap[action.Movable] = targetIdx + updateSimulatedIdxMap(&simulatedIdxMap, action.Movable, currentIdx, targetIdx) + } + } + + return optimized[:optimizedIdx] +} + func GenerateMovements(existing []Movable, expected []Movable, entries []Movable) ([]MoveAction, error) { if len(existing) != len(expected) { - return nil, fmt.Errorf("existing length != expected length: %d != %d", len(existing), len(expected)) + return nil, ErrSlicesNotEqualLength } - common := longestCommonSubsequence(existing, expected) + commonSequences := LongestCommonSubstring(existing, expected) - entriesIdxMap := make(map[Movable]int, len(entries)) - for idx, elt := range entries { - entriesIdxMap[elt] = idx - } + entriesIdxMap := createIdxMapFor(entries) - var commonSequence []Movable - for _, elt := range common { - filtered := removeEntriesFromExisting(elt, func(elt Movable) bool { + var common []Movable + for _, sequence := range commonSequences { + filtered := removeEntriesFromExisting(sequence, func(elt Movable) bool { _, ok := entriesIdxMap[elt] return ok }) - if len(filtered) > len(commonSequence) { - commonSequence = filtered + if len(filtered) > len(common) { + common = filtered } } + commonLen := len(common) - existingIdxMap := make(map[Movable]int, len(existing)) - for idx, elt := range existing { - existingIdxMap[elt] = idx - } + existingIdxMap := createIdxMapFor(existing) + expectedIdxMap := createIdxMapFor(expected) - expectedIdxMap := make(map[Movable]int, len(expected)) - for idx, elt := range expected { - expectedIdxMap[elt] = idx - } + var movements []MoveAction - commonLen := len(commonSequence) - commonIdxMap := make(map[Movable]int, len(commonSequence)) - for idx, elt := range commonSequence { - commonIdxMap[elt] = idx + var commonStartIdx, commonEndIdx int + if commonLen > 0 { + commonStartIdx = expectedIdxMap[common[0]] + commonEndIdx = expectedIdxMap[common[commonLen-1]] } - var movements []MoveAction - + slog.Debug("GenerateMovements()", "common", common, "commonStartIdx", commonStartIdx, "commonEndIdx", commonEndIdx) var previous Movable for _, elt := range entries { - slog.Debug("GenerateMovements", "elt", elt.EntryName(), "existingIdx", existingIdxMap[elt], "expectedIdx", expectedIdxMap[elt]) + slog.Debug("GenerateMovements()", "elt", elt, "existing", existingIdxMap[elt], "expected", expectedIdxMap[elt]) + // If existing index for the element matches the expected one, skip it over if existingIdxMap[elt] == expectedIdxMap[elt] { continue } + // Else, if expected index is 0, generate move action to move it to the top if expectedIdxMap[elt] == 0 { movements = append(movements, MoveAction{ - EntryName: elt.EntryName(), - Destination: "top", - Where: "top", + Movable: elt, + Destination: nil, + Where: ActionWhereTop, + }) + previous = elt + } else if len(common) == 0 { + // If, after filtering out all elements that cannot be moved, common sequence + // is empty we need to move everything element by element. + movements = append(movements, MoveAction{ + Movable: elt, + Destination: previous, + Where: ActionWhereAfter, }) previous = elt - } else if len(commonSequence) > 0 { - if expectedIdxMap[elt] < expectedIdxMap[commonSequence[0]] { + } else { + // Otherwise if there is some common sequence of elements between existing and expected + if expectedIdxMap[elt] <= commonStartIdx { + slog.Debug("GenerateMovements() HELP1") + // And the expected index of the element is lower than start of the common sequence if previous == nil { - previous = expected[0] + previous = common[0] } + + // Generate a movement action for the element to move it directly before the first + // element of the common sequence. movements = append(movements, MoveAction{ - EntryName: elt.EntryName(), - Destination: previous.EntryName(), - Where: "after", + Movable: elt, + Destination: previous, + Where: ActionWhereBefore, }) - previous = elt - } else if expectedIdxMap[elt] > expectedIdxMap[commonSequence[commonLen-1]] { + } else if expectedIdxMap[elt] > commonEndIdx { + slog.Debug("GenerateMovements() HELP2") + // If expected index of the element is larger than index of the last element of the common + // sequence if previous == nil { - previous = commonSequence[commonLen-1] + previous = common[commonLen-1] } + // Generate a move to move this element directly behind it. movements = append(movements, MoveAction{ - EntryName: elt.EntryName(), - Destination: previous.EntryName(), - Where: "after", + Movable: elt, + Destination: previous, + Where: ActionWhereAfter, }) previous = elt - } else if expectedIdxMap[elt] > expectedIdxMap[commonSequence[0]] { + } else if expectedIdxMap[elt] > expectedIdxMap[common[0]] { + slog.Debug("GenerateMovements() HELP2") if previous == nil { - previous = commonSequence[0] + previous = common[0] } movements = append(movements, MoveAction{ - EntryName: elt.EntryName(), - Destination: previous.EntryName(), - Where: "after", + Movable: elt, + Destination: previous, + Where: ActionWhereAfter, }) previous = elt } - } else { - movements = append(movements, MoveAction{ - EntryName: elt.EntryName(), - Destination: previous.EntryName(), - Where: "after", - }) - previous = elt } - - slog.Debug("GenerateMovements()", "existing", existingIdxMap[elt], "expected", expectedIdxMap[elt]) } _ = previous @@ -344,10 +497,7 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable } func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { - entriesIdxMap := make(map[Movable]int, len(entries)) - for idx, elt := range entries { - entriesIdxMap[elt] = idx - } + entriesIdxMap := createIdxMapFor(entries) filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { _, ok := entriesIdxMap[entry] @@ -356,14 +506,16 @@ func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, expected := append(entries, filtered...) - return GenerateMovements(existing, expected, entries) + actions, err := GenerateMovements(existing, expected, entries) + if err != nil { + return nil, err + } + + return OptimizeMovements(existing, expected, actions, o), nil } func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { - entriesIdxMap := make(map[Movable]int, len(entries)) - for idx, elt := range entries { - entriesIdxMap[elt] = idx - } + entriesIdxMap := createIdxMapFor(entries) filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { _, ok := entriesIdxMap[entry] @@ -372,7 +524,11 @@ func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveActio expected := append(filtered, entries...) - return GenerateMovements(existing, expected, entries) + actions, err := GenerateMovements(existing, expected, entries) + if err != nil { + return nil, err + } + return OptimizeMovements(existing, expected, actions, o), nil } func MoveGroup(position Position, entries []Movable, existing []Movable) ([]MoveAction, error) { diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go index a4322591..8e54d7a1 100644 --- a/assets/pango/movement/movement_test.go +++ b/assets/pango/movement/movement_test.go @@ -1,12 +1,16 @@ package movement_test import ( + "fmt" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "movements/movement" ) +var _ = fmt.Printf + type Mock struct { Name string } @@ -25,6 +29,35 @@ func asMovable(mocks []string) []movement.Movable { return movables } +var _ = Describe("LCS", func() { + Context("with two common substrings", func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + expected := asMovable([]string{"C", "A", "B", "D", "E"}) + It("should return two sequences of two elements", func() { + options := movement.LongestCommonSubstring(existing, expected) + Expect(options).To(HaveLen(2)) + + Expect(options[0]).To(HaveExactElements(asMovable([]string{"A", "B"}))) + Expect(options[1]).To(HaveExactElements(asMovable([]string{"D", "E"}))) + }) + }) + // Context("with one very large common substring", func() { + // It("should return one sequence of elements in a reasonable time", Label("benchmark"), func() { + // var elts []string + // elements := 50000 + // for idx := range elements { + // elts = append(elts, fmt.Sprintf("%d", idx)) + // } + // existing := asMovable(elts) + // expected := existing + + // options := movement.LongestCommonSubstring(existing, expected) + // Expect(options).To(HaveLen(1)) + // Expect(options[0]).To(HaveLen(elements)) + // }) + // }) +}) + var _ = Describe("Movement", func() { Context("With PositionTop used as position", func() { Context("when existing positions matches expected", func() { @@ -36,7 +69,7 @@ var _ = Describe("Movement", func() { }) }) Context("when it has to move two elements", func() { - It("should generate three move actions", func() { + FIt("should generate three move actions", func() { entries := asMovable([]string{"A", "B", "C"}) existing := asMovable([]string{"D", "E", "A", "B", "C"}) @@ -44,17 +77,17 @@ var _ = Describe("Movement", func() { Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(3)) - Expect(moves[0].EntryName).To(Equal("A")) - Expect(moves[0].Where).To(Equal("top")) - Expect(moves[0].Destination).To(Equal("top")) + Expect(moves[0].Movable.EntryName()).To(Equal("A")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereTop)) + Expect(moves[0].Destination).To(BeNil()) - Expect(moves[1].EntryName).To(Equal("B")) - Expect(moves[1].Where).To(Equal("after")) - Expect(moves[1].Destination).To(Equal("A")) + Expect(moves[1].Movable.EntryName()).To(Equal("B")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("A")) - Expect(moves[2].EntryName).To(Equal("C")) - Expect(moves[2].Where).To(Equal("after")) - Expect(moves[2].Destination).To(Equal("B")) + Expect(moves[2].Movable.EntryName()).To(Equal("C")) + Expect(moves[2].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[2].Destination.EntryName()).To(Equal("B")) }) }) Context("when expected order is reversed", func() { @@ -78,9 +111,9 @@ var _ = Describe("Movement", func() { Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(1)) - Expect(moves[0].EntryName).To(Equal("E")) - Expect(moves[0].Where).To(Equal("after")) - Expect(moves[0].Destination).To(Equal("D")) + Expect(moves[0].Movable.EntryName()).To(Equal("E")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("D")) }) }) }) @@ -111,14 +144,11 @@ var _ = Describe("Movement", func() { ) Expect(err).ToNot(HaveOccurred()) - Expect(moves).To(HaveLen(2)) - Expect(moves[0].EntryName).To(Equal("C")) - Expect(moves[0].Where).To(Equal("after")) - Expect(moves[0].Destination).To(Equal("A")) - - Expect(moves[1].EntryName).To(Equal("B")) - Expect(moves[1].Where).To(Equal("after")) - Expect(moves[1].Destination).To(Equal("C")) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].Movable.EntryName()).To(Equal("B")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("D")) }) }) }) @@ -134,13 +164,13 @@ var _ = Describe("Movement", func() { Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(2)) - Expect(moves[0].EntryName).To(Equal("A")) - Expect(moves[0].Where).To(Equal("after")) - Expect(moves[0].Destination).To(Equal("C")) + Expect(moves[0].Movable.EntryName()).To(Equal("A")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("C")) - Expect(moves[1].EntryName).To(Equal("B")) - Expect(moves[1].Where).To(Equal("after")) - Expect(moves[1].Destination).To(Equal("A")) + Expect(moves[1].Movable.EntryName()).To(Equal("B")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("A")) }) }) }) From 86a73c1fd4941b9e8d4709ef4ca168371fccaeea Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Thu, 1 Aug 2024 10:05:57 +0200 Subject: [PATCH 03/19] Some fixes around generation and optimization of move actions --- assets/pango/movement/movement.go | 113 ++++++++++++------------- assets/pango/movement/movement_test.go | 14 +-- 2 files changed, 60 insertions(+), 67 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 596636da..92113440 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -18,6 +18,13 @@ const ( ActionWhereAfter ActionWhereType = "after" ) +type entryPositionType int + +const ( + entryPositionBefore entryPositionType = iota + entryPositionAfter +) + type Movable interface { EntryName() string } @@ -190,17 +197,17 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, } } - actions, err := GenerateMovements(existing, expected, entries) + actions, err := GenerateMovements(existing, expected, entries, movement) return expected, actions, err } func (o PositionAfter) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { - expected, actions, err := processPivotMovement(entries, existing, o.Pivot, o.Directly, movementBefore) + expected, actions, err := processPivotMovement(entries, existing, o.Pivot, o.Directly, movementAfter) if err != nil { return nil, err } - return OptimizeMovements(existing, expected, actions, o), nil + return OptimizeMovements(existing, expected, entries, actions, o), nil } func (o PositionBefore) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { @@ -209,7 +216,7 @@ func (o PositionBefore) Move(entries []Movable, existing []Movable) ([]MoveActio return nil, err } - return OptimizeMovements(existing, expected, actions, o), nil + return OptimizeMovements(existing, expected, entries, actions, o), nil } type Entry struct { @@ -333,26 +340,24 @@ func updateSimulatedIdxMap(idxMap *map[Movable]int, moved Movable, startingIdx i slog.Debug("updateSimulatedIdxMap", "entries", idxMap) } -func OptimizeMovements(existing []Movable, expected []Movable, actions []MoveAction, position any) []MoveAction { +func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable, actions []MoveAction, position any) []MoveAction { simulated := make([]Movable, len(existing)) copy(simulated, existing) simulatedIdxMap := createIdxMapFor(simulated) expectedIdxMap := createIdxMapFor(expected) - optimized := make([]MoveAction, len(actions)) + var optimized []MoveAction switch position.(type) { case PositionBefore: slog.Debug("OptimizeMovements()", "position", position, "type", fmt.Sprintf("%T", position)) - slices.Reverse(actions) case PositionAfter: slog.Debug("OptimizeMovements()", "position", position, "type", fmt.Sprintf("%T", position)) default: return actions } - optimizedIdx := 0 for _, action := range actions { currentIdx := simulatedIdxMap[action.Movable] if currentIdx == expectedIdxMap[action.Movable] { @@ -366,6 +371,7 @@ func OptimizeMovements(existing []Movable, expected []Movable, actions []MoveAct case ActionWhereBottom: targetIdx = len(simulated) - 1 case ActionWhereBefore: + slog.Debug("OptimizeMovements()", "dest", action.Destination, "destIdx", simulatedIdxMap[action.Destination]) targetIdx = simulatedIdxMap[action.Destination] - 1 case ActionWhereAfter: targetIdx = simulatedIdxMap[action.Destination] + 1 @@ -373,17 +379,18 @@ func OptimizeMovements(existing []Movable, expected []Movable, actions []MoveAct slog.Debug("OptimizeMovements()", "action", action, "currentIdx", currentIdx, "targetIdx", targetIdx) if targetIdx != currentIdx { - optimized[optimizedIdx] = action - optimizedIdx++ + optimized = append(optimized, action) simulatedIdxMap[action.Movable] = targetIdx updateSimulatedIdxMap(&simulatedIdxMap, action.Movable, currentIdx, targetIdx) } } - return optimized[:optimizedIdx] + slog.Debug("OptimizeMovements()", "optimized", optimized) + + return optimized } -func GenerateMovements(existing []Movable, expected []Movable, entries []Movable) ([]MoveAction, error) { +func GenerateMovements(existing []Movable, expected []Movable, entries []Movable, movement movementType) ([]MoveAction, error) { if len(existing) != len(expected) { return nil, ErrSlicesNotEqualLength } @@ -417,6 +424,8 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable commonEndIdx = expectedIdxMap[common[commonLen-1]] } + slog.Debug("GenerateMovements()", "expected", expected) + slog.Debug("GenerateMovements()", "existing", existing) slog.Debug("GenerateMovements()", "common", common, "commonStartIdx", commonStartIdx, "commonEndIdx", commonEndIdx) var previous Movable for _, elt := range entries { @@ -426,17 +435,25 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable continue } - // Else, if expected index is 0, generate move action to move it to the top if expectedIdxMap[elt] == 0 { + slog.Debug("HELP1") movements = append(movements, MoveAction{ Movable: elt, Destination: nil, Where: ActionWhereTop, }) previous = elt - } else if len(common) == 0 { - // If, after filtering out all elements that cannot be moved, common sequence - // is empty we need to move everything element by element. + } else if expectedIdxMap[elt] == len(expectedIdxMap) { + slog.Debug("HELP2") + movements = append(movements, MoveAction{ + Movable: elt, + Destination: nil, + Where: ActionWhereBottom, + }) + previous = elt + } else if previous != nil { + slog.Debug("HELP3") + movements = append(movements, MoveAction{ Movable: elt, Destination: previous, @@ -444,48 +461,24 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable }) previous = elt } else { - // Otherwise if there is some common sequence of elements between existing and expected - if expectedIdxMap[elt] <= commonStartIdx { - slog.Debug("GenerateMovements() HELP1") - // And the expected index of the element is lower than start of the common sequence - if previous == nil { - previous = common[0] - } + slog.Debug("HELP4") + var where ActionWhereType - // Generate a movement action for the element to move it directly before the first - // element of the common sequence. - movements = append(movements, MoveAction{ - Movable: elt, - Destination: previous, - Where: ActionWhereBefore, - }) - } else if expectedIdxMap[elt] > commonEndIdx { - slog.Debug("GenerateMovements() HELP2") - // If expected index of the element is larger than index of the last element of the common - // sequence - if previous == nil { - previous = common[commonLen-1] - } - // Generate a move to move this element directly behind it. - movements = append(movements, MoveAction{ - Movable: elt, - Destination: previous, - Where: ActionWhereAfter, - }) - previous = elt - - } else if expectedIdxMap[elt] > expectedIdxMap[common[0]] { - slog.Debug("GenerateMovements() HELP2") - if previous == nil { - previous = common[0] - } - movements = append(movements, MoveAction{ - Movable: elt, - Destination: previous, - Where: ActionWhereAfter, - }) - previous = elt + switch movement { + case movementAfter: + previous = common[commonLen-1] + where = ActionWhereAfter + case movementBefore: + previous = common[0] + where = ActionWhereBefore } + + movements = append(movements, MoveAction{ + Movable: elt, + Destination: previous, + Where: where, + }) + previous = elt } } @@ -506,12 +499,12 @@ func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, expected := append(entries, filtered...) - actions, err := GenerateMovements(existing, expected, entries) + actions, err := GenerateMovements(existing, expected, entries, movementBefore) if err != nil { return nil, err } - return OptimizeMovements(existing, expected, actions, o), nil + return OptimizeMovements(existing, expected, entries, actions, o), nil } func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { @@ -524,11 +517,11 @@ func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveActio expected := append(filtered, entries...) - actions, err := GenerateMovements(existing, expected, entries) + actions, err := GenerateMovements(existing, expected, entries, movementAfter) if err != nil { return nil, err } - return OptimizeMovements(existing, expected, actions, o), nil + return OptimizeMovements(existing, expected, entries, actions, o), nil } func MoveGroup(position Position, entries []Movable, existing []Movable) ([]MoveAction, error) { diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go index 8e54d7a1..202f2a88 100644 --- a/assets/pango/movement/movement_test.go +++ b/assets/pango/movement/movement_test.go @@ -69,7 +69,7 @@ var _ = Describe("Movement", func() { }) }) Context("when it has to move two elements", func() { - FIt("should generate three move actions", func() { + It("should generate three move actions", func() { entries := asMovable([]string{"A", "B", "C"}) existing := asMovable([]string{"D", "E", "A", "B", "C"}) @@ -135,8 +135,8 @@ var _ = Describe("Movement", func() { }) }) Context("and moved entries are out of order", func() { - It("should generate only move commands to sort entries", func() { - // A B C D E -> A C B D E + It("should generate a single command to move B before D", func() { + // A B C D E -> A B C D E entries := asMovable([]string{"C", "B"}) moves, err := movement.MoveGroup( movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, @@ -147,8 +147,8 @@ var _ = Describe("Movement", func() { Expect(moves).To(HaveLen(1)) Expect(moves[0].Movable.EntryName()).To(Equal("B")) - Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) - Expect(moves[0].Destination.EntryName()).To(Equal("D")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("C")) }) }) }) @@ -165,8 +165,8 @@ var _ = Describe("Movement", func() { Expect(moves).To(HaveLen(2)) Expect(moves[0].Movable.EntryName()).To(Equal("A")) - Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) - Expect(moves[0].Destination.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("D")) Expect(moves[1].Movable.EntryName()).To(Equal("B")) Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) From 06847fb5bddfb8cdae77e36ae5de7667110518e6 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Thu, 1 Aug 2024 11:02:37 +0200 Subject: [PATCH 04/19] Implement support for PositionAfter and some clean-up --- assets/pango/movement/movement.go | 247 ++++++++++++++----------- assets/pango/movement/movement_test.go | 37 +++- 2 files changed, 173 insertions(+), 111 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 92113440..0cd18660 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -100,6 +100,8 @@ var ( ErrInvalidMovementPlan = errors.New("created movement plan is invalid") ) +// PositionBefore and PositionAfter are similar enough that we can generate expected sequences +// for both using the same code and some conditionals based on the given movement. func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, direct bool, movement movementType) ([]Movable, []MoveAction, error) { existingIdxMap := createIdxMapFor(existing) @@ -118,39 +120,41 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, entriesLen := len(entries) loop: for i := 0; i < entriesLen; i++ { - + existingEntryIdx := existingIdxMap[entries[i]] + slog.Debug("generate()", "i", i, "len(entries)", len(entries), "entry", entries[i], "existingEntryIdx", existingEntryIdx, "existingPivotIdx", existingPivotIdx) // For any given entry in the list of entries to move check if the entry // index is at or after pivot point index, which will require movement // set to be generated. - existingEntryIdx := existingIdxMap[entries[i]] + + // Then check if the entries to be moved have the same order in the existing + // slice, and if not require a movement set to be generated. switch movement { case movementBefore: if existingEntryIdx >= existingPivotIdx { movementRequired = true break } - case movementAfter: - if existingEntryIdx <= existingPivotIdx { - movementRequired = true - break - } - } - if i == 0 { - continue - } + if i == 0 { + continue + } - // Check if the entries to be moved have the same order in the existing - // slice, and if not require a movement set to be generated. - switch movement { - case movementBefore: if existingIdxMap[entries[i-1]] >= existingEntryIdx { movementRequired = true break loop } case movementAfter: - if existingIdxMap[entries[i-1]] <= existingEntryIdx { + if existingEntryIdx <= existingPivotIdx { + movementRequired = true + break + } + + if i == len(entries)-1 { + continue + } + + if existingIdxMap[entries[i+1]] < existingEntryIdx { movementRequired = true break loop @@ -195,6 +199,22 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, expected[expectedIdx] = filtered[i] expectedIdx++ } + case movementAfter: + expectedIdx := 0 + for ; expectedIdx < len(filtered); expectedIdx++ { + expected[expectedIdx] = filtered[expectedIdx] + } + + for _, elt := range entries { + expected[expectedIdx] = elt + expectedIdx++ + } + + filteredLen := len(filtered) + for i := filteredPivotIdx + 1; i < filteredLen-1; i++ { + expected[expectedIdx] = filtered[i] + expectedIdx++ + } } actions, err := GenerateMovements(existing, expected, entries, movement) @@ -230,99 +250,6 @@ type sequencePosition struct { End int } -func printLCSMatrix(S []Movable, T []Movable, L [][]int) { - r := len(S) - n := len(T) - - line := " " - for _, elt := range S { - line += fmt.Sprintf("%s ", elt.EntryName()) - } - slog.Debug("LCS", "line", line) - - line = " " - for _, elt := range L[0] { - line += fmt.Sprintf("%d ", elt) - } - slog.Debug("LCS", "line", line) - - for i := 1; i < r+1; i++ { - line = fmt.Sprintf("%s ", T[i-1].EntryName()) - for j := 0; j < n+1; j++ { - line += fmt.Sprintf("%d ", L[i][j]) - } - } - -} - -func LongestCommonSubstring(S []Movable, T []Movable) [][]Movable { - r := len(S) - n := len(T) - - L := make([][]int, r+1) - for idx := range r + 1 { - L[idx] = make([]int, n+1) - } - - for i := 1; i < r+1; i++ { - for j := 1; j < n+1; j++ { - if S[i-1].EntryName() == T[j-1].EntryName() { - if i == 1 { - L[j][i] = 1 - } else if j == 1 { - L[j][i] = 1 - } else { - L[j][i] = L[j-1][i-1] + 1 - } - } - } - } - - var results [][]Movable - var lcsList [][]Movable - - var entry []Movable - var index int - for i := r; i > 0; i-- { - for j := n; j > 0; j-- { - if S[i-1].EntryName() == T[j-1].EntryName() { - if L[j][i] >= index { - if len(entry) > 0 { - var entries []string - for _, elt := range entry { - entries = append(entries, elt.EntryName()) - } - - lcsList = append(lcsList, entry) - } - index = L[j][i] - entry = []Movable{S[i-1]} - } else if L[j][i] < index { - index = L[j][i] - entry = append(entry, S[i-1]) - } else { - entry = []Movable{} - } - } - } - } - - if len(entry) > 0 { - lcsList = append(lcsList, entry) - } - - lcsLen := len(lcsList) - for idx := range lcsList { - elt := lcsList[lcsLen-idx-1] - if len(elt) > 1 { - slices.Reverse(elt) - results = append(results, elt) - } - } - - return results -} - func updateSimulatedIdxMap(idxMap *map[Movable]int, moved Movable, startingIdx int, targetIdx int) { slog.Debug("updateSimulatedIdxMap", "entries", idxMap) for entry, idx := range *idxMap { @@ -399,6 +326,9 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable entriesIdxMap := createIdxMapFor(entries) + // LCS returns a list of longest common sequences found between existing and expected + // slices. We want to find the longest common sequence that doesn't intersect entries + // given by the user, as entries are moved in relation to the common sequence. var common []Movable for _, sequence := range commonSequences { filtered := removeEntriesFromExisting(sequence, func(elt Movable) bool { @@ -527,3 +457,102 @@ func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveActio func MoveGroup(position Position, entries []Movable, existing []Movable) ([]MoveAction, error) { return position.Move(entries, existing) } + +// Debug helper to print generated LCS matrix +func printLCSMatrix(S []Movable, T []Movable, L [][]int) { + r := len(S) + n := len(T) + + line := " " + for _, elt := range S { + line += fmt.Sprintf("%s ", elt.EntryName()) + } + slog.Debug("LCS", "line", line) + + line = " " + for _, elt := range L[0] { + line += fmt.Sprintf("%d ", elt) + } + slog.Debug("LCS", "line", line) + + for i := 1; i < r+1; i++ { + line = fmt.Sprintf("%s ", T[i-1].EntryName()) + for j := 0; j < n+1; j++ { + line += fmt.Sprintf("%d ", L[i][j]) + } + } + +} + +// LongestCommonSubstring implements dynamic programming variant of the algorithm +// +// See https://en.wikipedia.org/wiki/Longest_common_substring for the details. Our +// implementation is not optimal, as generation of the matrix can be done at the +// same time as finding LCSs, but it's easier to reason about for now. +func LongestCommonSubstring(S []Movable, T []Movable) [][]Movable { + r := len(S) + n := len(T) + + L := make([][]int, r+1) + for idx := range r + 1 { + L[idx] = make([]int, n+1) + } + + for i := 1; i < r+1; i++ { + for j := 1; j < n+1; j++ { + if S[i-1].EntryName() == T[j-1].EntryName() { + if i == 1 { + L[j][i] = 1 + } else if j == 1 { + L[j][i] = 1 + } else { + L[j][i] = L[j-1][i-1] + 1 + } + } + } + } + + var results [][]Movable + var lcsList [][]Movable + + var entry []Movable + var index int + for i := r; i > 0; i-- { + for j := n; j > 0; j-- { + if S[i-1].EntryName() == T[j-1].EntryName() { + if L[j][i] >= index { + if len(entry) > 0 { + var entries []string + for _, elt := range entry { + entries = append(entries, elt.EntryName()) + } + + lcsList = append(lcsList, entry) + } + index = L[j][i] + entry = []Movable{S[i-1]} + } else if L[j][i] < index { + index = L[j][i] + entry = append(entry, S[i-1]) + } else { + entry = []Movable{} + } + } + } + } + + if len(entry) > 0 { + lcsList = append(lcsList, entry) + } + + lcsLen := len(lcsList) + for idx := range lcsList { + elt := lcsList[lcsLen-idx-1] + if len(elt) > 1 { + slices.Reverse(elt) + results = append(results, elt) + } + } + + return results +} diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go index 202f2a88..09a6dd66 100644 --- a/assets/pango/movement/movement_test.go +++ b/assets/pango/movement/movement_test.go @@ -118,6 +118,39 @@ var _ = Describe("Movement", func() { }) }) + Context("With PositionAfter used as position", func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + Context("when direct position relative to the pivot is not required", func() { + It("should not generate any move actions", func() { + entries := asMovable([]string{"D", "E"}) + moves, err := movement.MoveGroup( + movement.PositionAfter{Directly: false, Pivot: Mock{"B"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(0)) + }) + Context("and moved entries are out of order", func() { + FIt("should generate a single command to move B before D", func() { + // A B C D E -> A B C E D + entries := asMovable([]string{"E", "D"}) + moves, err := movement.MoveGroup( + movement.PositionAfter{Directly: false, Pivot: Mock{"B"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].Movable.EntryName()).To(Equal("E")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("C")) + }) + }) + }) + + }) Context("With PositionBefore used as position", func() { existing := asMovable([]string{"A", "B", "C", "D", "E"}) @@ -135,8 +168,8 @@ var _ = Describe("Movement", func() { }) }) Context("and moved entries are out of order", func() { - It("should generate a single command to move B before D", func() { - // A B C D E -> A B C D E + FIt("should generate a single command to move B before D", func() { + // A B C D E -> A C B D E entries := asMovable([]string{"C", "B"}) moves, err := movement.MoveGroup( movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, From be9aa94ca4b9763dcde00fbc1ee674aca7443723 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Thu, 1 Aug 2024 11:04:54 +0200 Subject: [PATCH 05/19] Remove some debug logging --- assets/pango/movement/movement.go | 28 ++------------------------ assets/pango/movement/movement_test.go | 4 ++-- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 0cd18660..4625c84f 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -251,20 +251,17 @@ type sequencePosition struct { } func updateSimulatedIdxMap(idxMap *map[Movable]int, moved Movable, startingIdx int, targetIdx int) { - slog.Debug("updateSimulatedIdxMap", "entries", idxMap) for entry, idx := range *idxMap { if entry == moved { continue } - slog.Debug("updateSimulatedIdxMap", "entry", entry, "idx", idx, "startingIdx", startingIdx, "targetIdx", targetIdx) if startingIdx > targetIdx && idx >= targetIdx { (*idxMap)[entry] = idx + 1 } else if startingIdx < targetIdx && idx >= startingIdx && idx <= targetIdx { (*idxMap)[entry] = idx - 1 } } - slog.Debug("updateSimulatedIdxMap", "entries", idxMap) } func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable, actions []MoveAction, position any) []MoveAction { @@ -277,10 +274,7 @@ func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable var optimized []MoveAction switch position.(type) { - case PositionBefore: - slog.Debug("OptimizeMovements()", "position", position, "type", fmt.Sprintf("%T", position)) - case PositionAfter: - slog.Debug("OptimizeMovements()", "position", position, "type", fmt.Sprintf("%T", position)) + case PositionBefore, PositionAfter: default: return actions } @@ -298,13 +292,11 @@ func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable case ActionWhereBottom: targetIdx = len(simulated) - 1 case ActionWhereBefore: - slog.Debug("OptimizeMovements()", "dest", action.Destination, "destIdx", simulatedIdxMap[action.Destination]) targetIdx = simulatedIdxMap[action.Destination] - 1 case ActionWhereAfter: targetIdx = simulatedIdxMap[action.Destination] + 1 } - slog.Debug("OptimizeMovements()", "action", action, "currentIdx", currentIdx, "targetIdx", targetIdx) if targetIdx != currentIdx { optimized = append(optimized, action) simulatedIdxMap[action.Movable] = targetIdx @@ -312,8 +304,7 @@ func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable } } - slog.Debug("OptimizeMovements()", "optimized", optimized) - + slog.Debug("OptimiveMovements()", "optimized", optimized) return optimized } @@ -348,25 +339,14 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable var movements []MoveAction - var commonStartIdx, commonEndIdx int - if commonLen > 0 { - commonStartIdx = expectedIdxMap[common[0]] - commonEndIdx = expectedIdxMap[common[commonLen-1]] - } - - slog.Debug("GenerateMovements()", "expected", expected) - slog.Debug("GenerateMovements()", "existing", existing) - slog.Debug("GenerateMovements()", "common", common, "commonStartIdx", commonStartIdx, "commonEndIdx", commonEndIdx) var previous Movable for _, elt := range entries { - slog.Debug("GenerateMovements()", "elt", elt, "existing", existingIdxMap[elt], "expected", expectedIdxMap[elt]) // If existing index for the element matches the expected one, skip it over if existingIdxMap[elt] == expectedIdxMap[elt] { continue } if expectedIdxMap[elt] == 0 { - slog.Debug("HELP1") movements = append(movements, MoveAction{ Movable: elt, Destination: nil, @@ -374,7 +354,6 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable }) previous = elt } else if expectedIdxMap[elt] == len(expectedIdxMap) { - slog.Debug("HELP2") movements = append(movements, MoveAction{ Movable: elt, Destination: nil, @@ -382,8 +361,6 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable }) previous = elt } else if previous != nil { - slog.Debug("HELP3") - movements = append(movements, MoveAction{ Movable: elt, Destination: previous, @@ -391,7 +368,6 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable }) previous = elt } else { - slog.Debug("HELP4") var where ActionWhereType switch movement { diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go index 09a6dd66..33f16ade 100644 --- a/assets/pango/movement/movement_test.go +++ b/assets/pango/movement/movement_test.go @@ -132,7 +132,7 @@ var _ = Describe("Movement", func() { Expect(moves).To(HaveLen(0)) }) Context("and moved entries are out of order", func() { - FIt("should generate a single command to move B before D", func() { + It("should generate a single command to move B before D", func() { // A B C D E -> A B C E D entries := asMovable([]string{"E", "D"}) moves, err := movement.MoveGroup( @@ -168,7 +168,7 @@ var _ = Describe("Movement", func() { }) }) Context("and moved entries are out of order", func() { - FIt("should generate a single command to move B before D", func() { + It("should generate a single command to move B before D", func() { // A B C D E -> A C B D E entries := asMovable([]string{"C", "B"}) moves, err := movement.MoveGroup( From 2ef650f200cf2202a0c2522d458ee93801e2a615 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Thu, 1 Aug 2024 11:58:57 +0200 Subject: [PATCH 06/19] Fixes to PositionAfter expected sequence generation --- assets/pango/movement/movement.go | 56 ++++++++++++++++++-------- assets/pango/movement/movement_test.go | 25 +++++++++++- 2 files changed, 63 insertions(+), 18 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 4625c84f..633c8e9b 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -18,13 +18,6 @@ const ( ActionWhereAfter ActionWhereType = "after" ) -type entryPositionType int - -const ( - entryPositionBefore entryPositionType = iota - entryPositionAfter -) - type Movable interface { EntryName() string } @@ -179,6 +172,7 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, filteredPivotIdx := findPivotIdx(filtered, pivot) + slog.Debug("pivot()", "existing", existing, "filtered", filtered, "filteredPivotIdx", filteredPivotIdx) switch movement { case movementBefore: expectedIdx := 0 @@ -186,34 +180,61 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, expected[expectedIdx] = filtered[expectedIdx] } + slog.Debug("pivot()", "expected", expected) + for _, elt := range entries { expected[expectedIdx] = elt expectedIdx++ } + slog.Debug("pivot()", "expected", expected) + expected[expectedIdx] = pivot expectedIdx++ + slog.Debug("pivot()", "expected", expected) + filteredLen := len(filtered) for i := filteredPivotIdx + 1; i < filteredLen; i++ { expected[expectedIdx] = filtered[i] expectedIdx++ } + + slog.Debug("pivot()", "expected", expected) case movementAfter: + slog.Debug("pivot()", "filtered", filtered) expectedIdx := 0 - for ; expectedIdx < len(filtered); expectedIdx++ { + for ; expectedIdx < filteredPivotIdx+1; expectedIdx++ { expected[expectedIdx] = filtered[expectedIdx] } - for _, elt := range entries { - expected[expectedIdx] = elt - expectedIdx++ - } + if direct { + for _, elt := range entries { + expected[expectedIdx] = elt + expectedIdx++ + } - filteredLen := len(filtered) - for i := filteredPivotIdx + 1; i < filteredLen-1; i++ { - expected[expectedIdx] = filtered[i] - expectedIdx++ + slog.Debug("pivot()", "expected", expected) + + filteredLen := len(filtered) + for i := filteredPivotIdx + 1; i < filteredLen; i++ { + expected[expectedIdx] = filtered[i] + } + } else { + filteredLen := len(filtered) + for i := filteredPivotIdx + 1; i < filteredLen; i++ { + expected[expectedIdx] = filtered[i] + expectedIdx++ + } + + slog.Debug("pivot()", "expected", expected) + + for _, elt := range entries { + expected[expectedIdx] = elt + expectedIdx++ + } + + slog.Debug("pivot()", "expected", expected) } } @@ -264,7 +285,7 @@ func updateSimulatedIdxMap(idxMap *map[Movable]int, moved Movable, startingIdx i } } -func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable, actions []MoveAction, position any) []MoveAction { +func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable, actions []MoveAction, position Position) []MoveAction { simulated := make([]Movable, len(existing)) copy(simulated, existing) @@ -309,6 +330,7 @@ func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable } func GenerateMovements(existing []Movable, expected []Movable, entries []Movable, movement movementType) ([]MoveAction, error) { + slog.Debug("GenerateMovements()", "existing", existing, "expected", expected) if len(existing) != len(expected) { return nil, ErrSlicesNotEqualLength } diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go index 33f16ade..0703a54f 100644 --- a/assets/pango/movement/movement_test.go +++ b/assets/pango/movement/movement_test.go @@ -132,7 +132,7 @@ var _ = Describe("Movement", func() { Expect(moves).To(HaveLen(0)) }) Context("and moved entries are out of order", func() { - It("should generate a single command to move B before D", func() { + FIt("should generate a single command to move B before D", func() { // A B C D E -> A B C E D entries := asMovable([]string{"E", "D"}) moves, err := movement.MoveGroup( @@ -149,6 +149,27 @@ var _ = Describe("Movement", func() { }) }) }) + Context("when direct position relative to the pivot is required", func() { + It("should generate required move actions", func() { + // A B C D E -> C D A B E + entries := asMovable([]string{"A", "B"}) + moves, err := movement.MoveGroup( + movement.PositionAfter{Directly: true, Pivot: Mock{"D"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("A")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("D")) + + Expect(moves[1].Movable.EntryName()).To(Equal("B")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("A")) + }) + }) }) Context("With PositionBefore used as position", func() { @@ -204,6 +225,8 @@ var _ = Describe("Movement", func() { Expect(moves[1].Movable.EntryName()).To(Equal("B")) Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) Expect(moves[1].Destination.EntryName()).To(Equal("A")) + + Expect(true).To(BeFalse()) }) }) }) From d2e3eed5d737b3274c94542c97950c821dd040f2 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Fri, 2 Aug 2024 09:48:07 +0200 Subject: [PATCH 07/19] Rewrite GenerateMovements() to work without LCS algorithm --- assets/pango/movement/movement.go | 302 +++++++++++-------------- assets/pango/movement/movement_test.go | 209 +++++++++++++---- 2 files changed, 298 insertions(+), 213 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 633c8e9b..57a15045 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -2,7 +2,6 @@ package movement import ( "errors" - "fmt" "log/slog" "slices" ) @@ -30,6 +29,7 @@ type MoveAction struct { type Position interface { Move(entries []Movable, existing []Movable) ([]MoveAction, error) + GetExpected(entries []Movable, existing []Movable) ([]Movable, error) } type PositionTop struct{} @@ -79,14 +79,8 @@ func findPivotIdx(entries []Movable, pivot Movable) int { } -type movementType int - -const ( - movementBefore movementType = iota - movementAfter -) - var ( + errNoMovements = errors.New("no movements needed") ErrSlicesNotEqualLength = errors.New("existing and expected slices length mismatch") ErrPivotInEntries = errors.New("pivot element found in the entries slice") ErrPivotNotInExisting = errors.New("pivot element not foudn in the existing slice") @@ -95,17 +89,17 @@ var ( // PositionBefore and PositionAfter are similar enough that we can generate expected sequences // for both using the same code and some conditionals based on the given movement. -func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, direct bool, movement movementType) ([]Movable, []MoveAction, error) { +func getPivotMovement(entries []Movable, existing []Movable, pivot Movable, direct bool, movement ActionWhereType) ([]Movable, error) { existingIdxMap := createIdxMapFor(existing) entriesPivotIdx := findPivotIdx(entries, pivot) if entriesPivotIdx != -1 { - return nil, nil, ErrPivotInEntries + return nil, ErrPivotInEntries } existingPivotIdx := findPivotIdx(existing, pivot) if existingPivotIdx == -1 { - return nil, nil, ErrPivotNotInExisting + return nil, ErrPivotNotInExisting } if !direct { @@ -114,7 +108,6 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, loop: for i := 0; i < entriesLen; i++ { existingEntryIdx := existingIdxMap[entries[i]] - slog.Debug("generate()", "i", i, "len(entries)", len(entries), "entry", entries[i], "existingEntryIdx", existingEntryIdx, "existingPivotIdx", existingPivotIdx) // For any given entry in the list of entries to move check if the entry // index is at or after pivot point index, which will require movement // set to be generated. @@ -122,7 +115,7 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, // Then check if the entries to be moved have the same order in the existing // slice, and if not require a movement set to be generated. switch movement { - case movementBefore: + case ActionWhereBefore: if existingEntryIdx >= existingPivotIdx { movementRequired = true break @@ -137,7 +130,7 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, break loop } - case movementAfter: + case ActionWhereAfter: if existingEntryIdx <= existingPivotIdx { movementRequired = true break @@ -157,7 +150,7 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, } if !movementRequired { - return nil, nil, nil + return nil, errNoMovements } } @@ -172,37 +165,28 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, filteredPivotIdx := findPivotIdx(filtered, pivot) - slog.Debug("pivot()", "existing", existing, "filtered", filtered, "filteredPivotIdx", filteredPivotIdx) switch movement { - case movementBefore: + case ActionWhereBefore: expectedIdx := 0 for ; expectedIdx < filteredPivotIdx; expectedIdx++ { expected[expectedIdx] = filtered[expectedIdx] } - slog.Debug("pivot()", "expected", expected) - for _, elt := range entries { expected[expectedIdx] = elt expectedIdx++ } - slog.Debug("pivot()", "expected", expected) - expected[expectedIdx] = pivot expectedIdx++ - slog.Debug("pivot()", "expected", expected) - filteredLen := len(filtered) for i := filteredPivotIdx + 1; i < filteredLen; i++ { expected[expectedIdx] = filtered[i] expectedIdx++ } - slog.Debug("pivot()", "expected", expected) - case movementAfter: - slog.Debug("pivot()", "filtered", filtered) + case ActionWhereAfter: expectedIdx := 0 for ; expectedIdx < filteredPivotIdx+1; expectedIdx++ { expected[expectedIdx] = filtered[expectedIdx] @@ -214,8 +198,6 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, expectedIdx++ } - slog.Debug("pivot()", "expected", expected) - filteredLen := len(filtered) for i := filteredPivotIdx + 1; i < filteredLen; i++ { expected[expectedIdx] = filtered[i] @@ -227,23 +209,31 @@ func processPivotMovement(entries []Movable, existing []Movable, pivot Movable, expectedIdx++ } - slog.Debug("pivot()", "expected", expected) - for _, elt := range entries { expected[expectedIdx] = elt expectedIdx++ } - slog.Debug("pivot()", "expected", expected) } } - actions, err := GenerateMovements(existing, expected, entries, movement) - return expected, actions, err + return expected, nil +} + +func (o PositionAfter) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { + return getPivotMovement(entries, existing, o.Pivot, o.Directly, ActionWhereAfter) } func (o PositionAfter) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { - expected, actions, err := processPivotMovement(entries, existing, o.Pivot, o.Directly, movementAfter) + expected, err := o.GetExpected(entries, existing) + if err != nil { + if errors.Is(err, errNoMovements) { + return nil, nil + } + return nil, err + } + + actions, err := GenerateMovements(existing, expected, entries, ActionWhereAfter, o.Pivot, o.Directly) if err != nil { return nil, err } @@ -251,8 +241,22 @@ func (o PositionAfter) Move(entries []Movable, existing []Movable) ([]MoveAction return OptimizeMovements(existing, expected, entries, actions, o), nil } +func (o PositionBefore) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { + return getPivotMovement(entries, existing, o.Pivot, o.Directly, ActionWhereBefore) +} + func (o PositionBefore) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { - expected, actions, err := processPivotMovement(entries, existing, o.Pivot, o.Directly, movementBefore) + expected, err := o.GetExpected(entries, existing) + if err != nil { + if errors.Is(err, errNoMovements) { + return nil, nil + } + return nil, err + } + + slog.Debug("PositionBefore.Move()", "existing", existing, "expected", expected, "entries", entries) + + actions, err := GenerateMovements(existing, expected, entries, ActionWhereBefore, o.Pivot, o.Directly) if err != nil { return nil, err } @@ -325,99 +329,102 @@ func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable } } - slog.Debug("OptimiveMovements()", "optimized", optimized) + slog.Debug("OptimizeMovements()", "optimized", optimized) + return optimized } -func GenerateMovements(existing []Movable, expected []Movable, entries []Movable, movement movementType) ([]MoveAction, error) { - slog.Debug("GenerateMovements()", "existing", existing, "expected", expected) +func GenerateMovements(existing []Movable, expected []Movable, entries []Movable, movement ActionWhereType, pivot Movable, directly bool) ([]MoveAction, error) { if len(existing) != len(expected) { return nil, ErrSlicesNotEqualLength } - commonSequences := LongestCommonSubstring(existing, expected) - entriesIdxMap := createIdxMapFor(entries) - - // LCS returns a list of longest common sequences found between existing and expected - // slices. We want to find the longest common sequence that doesn't intersect entries - // given by the user, as entries are moved in relation to the common sequence. - var common []Movable - for _, sequence := range commonSequences { - filtered := removeEntriesFromExisting(sequence, func(elt Movable) bool { - _, ok := entriesIdxMap[elt] - return ok - }) - - if len(filtered) > len(common) { - common = filtered - } - - } - commonLen := len(common) - existingIdxMap := createIdxMapFor(existing) expectedIdxMap := createIdxMapFor(expected) var movements []MoveAction - var previous Movable for _, elt := range entries { + slog.Debug("GenerateMovements()", "elt", elt, "existing", existingIdxMap[elt], "expected", expectedIdxMap[elt], "len(expected)", len(expected)) // If existing index for the element matches the expected one, skip it over if existingIdxMap[elt] == expectedIdxMap[elt] { continue } - if expectedIdxMap[elt] == 0 { + if previous != nil { movements = append(movements, MoveAction{ Movable: elt, - Destination: nil, - Where: ActionWhereTop, + Destination: previous, + Where: ActionWhereAfter, }) previous = elt - } else if expectedIdxMap[elt] == len(expectedIdxMap) { + continue + } + if expectedIdxMap[elt] == 0 { movements = append(movements, MoveAction{ Movable: elt, Destination: nil, - Where: ActionWhereBottom, + Where: ActionWhereTop, }) previous = elt - } else if previous != nil { + } else if expectedIdxMap[elt] == len(expectedIdxMap)-1 { movements = append(movements, MoveAction{ Movable: elt, - Destination: previous, - Where: ActionWhereAfter, + Destination: nil, + Where: ActionWhereBottom, }) previous = elt } else { var where ActionWhereType + var pivot Movable switch movement { - case movementAfter: - previous = common[commonLen-1] + case ActionWhereBottom: + where = ActionWhereBottom + case ActionWhereAfter: + pivot = expected[expectedIdxMap[elt]-1] where = ActionWhereAfter - case movementBefore: - previous = common[0] + case ActionWhereTop: + pivot = existing[0] + where = ActionWhereBefore + case ActionWhereBefore: + eltExpectedIdx := expectedIdxMap[elt] + pivot = expected[eltExpectedIdx+1] where = ActionWhereBefore + // if previous was nil (we are processing the first element in entries set) + // and selected pivot is part of the entries set it means the order of entries + // changes between existing and expected sets. If direct move has been requested, + // we need to find the correct pivot point for the move. + if _, ok := entriesIdxMap[pivot]; ok && directly { + // The actual pivot for the move is the element that follows all elements + // from the existing set. + pivotIdx := eltExpectedIdx + len(entries) + if pivotIdx >= len(expected) { + // This should never happen as by definition there is at least + // element (pivot point) at the end of the expected slice. + return nil, ErrInvalidMovementPlan + } + pivot = expected[pivotIdx] + } } movements = append(movements, MoveAction{ Movable: elt, - Destination: previous, + Destination: pivot, Where: where, }) previous = elt } - } - _ = previous + } - slog.Debug("GenerateMovements()", "movements", movements) + slog.Debug("GeneraveMovements()", "movements", movements) return movements, nil } -func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { +func (o PositionTop) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { entriesIdxMap := createIdxMapFor(entries) filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { @@ -427,7 +434,15 @@ func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, expected := append(entries, filtered...) - actions, err := GenerateMovements(existing, expected, entries, movementBefore) + return expected, nil +} + +func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + expected, err := o.GetExpected(entries, existing) + if err != nil { + return nil, err + } + actions, err := GenerateMovements(existing, expected, entries, ActionWhereTop, nil, false) if err != nil { return nil, err } @@ -435,7 +450,7 @@ func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, return OptimizeMovements(existing, expected, entries, actions, o), nil } -func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { +func (o PositionBottom) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { entriesIdxMap := createIdxMapFor(entries) filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { @@ -445,112 +460,55 @@ func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveActio expected := append(filtered, entries...) - actions, err := GenerateMovements(existing, expected, entries, movementAfter) - if err != nil { - return nil, err - } - return OptimizeMovements(existing, expected, entries, actions, o), nil -} - -func MoveGroup(position Position, entries []Movable, existing []Movable) ([]MoveAction, error) { - return position.Move(entries, existing) + return expected, nil } -// Debug helper to print generated LCS matrix -func printLCSMatrix(S []Movable, T []Movable, L [][]int) { - r := len(S) - n := len(T) - - line := " " - for _, elt := range S { - line += fmt.Sprintf("%s ", elt.EntryName()) - } - slog.Debug("LCS", "line", line) - - line = " " - for _, elt := range L[0] { - line += fmt.Sprintf("%d ", elt) +func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { + slog.Debug("PositionBottom.Move())", "entries", entries, "existing", existing) + expected, err := o.GetExpected(entries, existing) + if err != nil { + return nil, err } - slog.Debug("LCS", "line", line) - for i := 1; i < r+1; i++ { - line = fmt.Sprintf("%s ", T[i-1].EntryName()) - for j := 0; j < n+1; j++ { - line += fmt.Sprintf("%d ", L[i][j]) - } + actions, err := GenerateMovements(existing, expected, entries, ActionWhereBottom, nil, false) + if err != nil { + return nil, err } - + return OptimizeMovements(existing, expected, entries, actions, o), nil } -// LongestCommonSubstring implements dynamic programming variant of the algorithm -// -// See https://en.wikipedia.org/wiki/Longest_common_substring for the details. Our -// implementation is not optimal, as generation of the matrix can be done at the -// same time as finding LCSs, but it's easier to reason about for now. -func LongestCommonSubstring(S []Movable, T []Movable) [][]Movable { - r := len(S) - n := len(T) - - L := make([][]int, r+1) - for idx := range r + 1 { - L[idx] = make([]int, n+1) - } - - for i := 1; i < r+1; i++ { - for j := 1; j < n+1; j++ { - if S[i-1].EntryName() == T[j-1].EntryName() { - if i == 1 { - L[j][i] = 1 - } else if j == 1 { - L[j][i] = 1 - } else { - L[j][i] = L[j-1][i-1] + 1 - } - } - } - } +type Movement struct { + Entries []Movable + Position Position +} - var results [][]Movable - var lcsList [][]Movable - - var entry []Movable - var index int - for i := r; i > 0; i-- { - for j := n; j > 0; j-- { - if S[i-1].EntryName() == T[j-1].EntryName() { - if L[j][i] >= index { - if len(entry) > 0 { - var entries []string - for _, elt := range entry { - entries = append(entries, elt.EntryName()) - } - - lcsList = append(lcsList, entry) - } - index = L[j][i] - entry = []Movable{S[i-1]} - } else if L[j][i] < index { - index = L[j][i] - entry = append(entry, S[i-1]) - } else { - entry = []Movable{} - } +func MoveGroups(existing []Movable, movements []Movement) ([]MoveAction, error) { + expected := existing + for idx := range len(movements) - 1 { + position := movements[idx].Position + entries := movements[idx].Entries + slog.Debug("MoveGroups()", "position", position, "existing", existing, "entries", entries) + result, err := position.GetExpected(entries, expected) + if err != nil { + if !errors.Is(err, errNoMovements) { + return nil, err } + continue } + expected = result } - if len(entry) > 0 { - lcsList = append(lcsList, entry) - } + entries := movements[len(movements)-1].Entries + position := movements[len(movements)-1].Position + slog.Debug("MoveGroups()", "position", position, "expected", expected, "entries", entries) + return position.Move(entries, expected) +} - lcsLen := len(lcsList) - for idx := range lcsList { - elt := lcsList[lcsLen-idx-1] - if len(elt) > 1 { - slices.Reverse(elt) - results = append(results, elt) - } - } +func MoveGroup(position Position, entries []Movable, existing []Movable) ([]MoveAction, error) { + return position.Move(entries, existing) +} - return results +type Move struct { + Position Position + Existing []Movable } diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go index 0703a54f..38162e04 100644 --- a/assets/pango/movement/movement_test.go +++ b/assets/pango/movement/movement_test.go @@ -29,39 +29,11 @@ func asMovable(mocks []string) []movement.Movable { return movables } -var _ = Describe("LCS", func() { - Context("with two common substrings", func() { - existing := asMovable([]string{"A", "B", "C", "D", "E"}) - expected := asMovable([]string{"C", "A", "B", "D", "E"}) - It("should return two sequences of two elements", func() { - options := movement.LongestCommonSubstring(existing, expected) - Expect(options).To(HaveLen(2)) - - Expect(options[0]).To(HaveExactElements(asMovable([]string{"A", "B"}))) - Expect(options[1]).To(HaveExactElements(asMovable([]string{"D", "E"}))) - }) - }) - // Context("with one very large common substring", func() { - // It("should return one sequence of elements in a reasonable time", Label("benchmark"), func() { - // var elts []string - // elements := 50000 - // for idx := range elements { - // elts = append(elts, fmt.Sprintf("%d", idx)) - // } - // existing := asMovable(elts) - // expected := existing - - // options := movement.LongestCommonSubstring(existing, expected) - // Expect(options).To(HaveLen(1)) - // Expect(options[0]).To(HaveLen(elements)) - // }) - // }) -}) - -var _ = Describe("Movement", func() { +var _ = Describe("MoveGroup()", func() { Context("With PositionTop used as position", func() { Context("when existing positions matches expected", func() { It("should generate no movements", func() { + // '(A B C) -> '(A B C) expected := asMovable([]string{"A", "B", "C"}) moves, err := movement.MoveGroup(movement.PositionTop{}, expected, expected) Expect(err).ToNot(HaveOccurred()) @@ -70,6 +42,7 @@ var _ = Describe("Movement", func() { }) Context("when it has to move two elements", func() { It("should generate three move actions", func() { + // '(D E A B C) -> '(A B C D E) entries := asMovable([]string{"A", "B", "C"}) existing := asMovable([]string{"D", "E", "A", "B", "C"}) @@ -92,6 +65,7 @@ var _ = Describe("Movement", func() { }) Context("when expected order is reversed", func() { It("should generate required move actions to converge lists", func() { + // '(A B C D E) -> '(E D C B A) entries := asMovable([]string{"E", "D", "C", "B", "A"}) existing := asMovable([]string{"A", "B", "C", "D", "E"}) moves, err := movement.MoveGroup(movement.PositionTop{}, entries, existing) @@ -101,9 +75,31 @@ var _ = Describe("Movement", func() { }) }) }) + Context("With PositionBottom used as position", func() { + Context("with non-consecutive entries", func() { + It("should generate two move actions", func() { + // '(A E B C D) -> '(A B D E C) + entries := asMovable([]string{"E", "C"}) + existing := asMovable([]string{"A", "E", "B", "C", "D"}) + + moves, err := movement.MoveGroup(movement.PositionBottom{}, entries, existing) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("E")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBottom)) + Expect(moves[0].Destination).To(BeNil()) + + Expect(moves[1].Movable.EntryName()).To(Equal("C")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("E")) + }) + }) + }) Context("With PositionBottom used as position", func() { Context("when it needs to move one element", func() { It("should generate a single move action", func() { + // '(A E B C D) -> '(A B C D E) entries := asMovable([]string{"E"}) existing := asMovable([]string{"A", "E", "B", "C", "D"}) @@ -112,8 +108,8 @@ var _ = Describe("Movement", func() { Expect(moves).To(HaveLen(1)) Expect(moves[0].Movable.EntryName()).To(Equal("E")) - Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) - Expect(moves[0].Destination.EntryName()).To(Equal("D")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBottom)) + Expect(moves[0].Destination).To(BeNil()) }) }) }) @@ -122,6 +118,7 @@ var _ = Describe("Movement", func() { existing := asMovable([]string{"A", "B", "C", "D", "E"}) Context("when direct position relative to the pivot is not required", func() { It("should not generate any move actions", func() { + // '(A B C D E) -> '(A B C D E) entries := asMovable([]string{"D", "E"}) moves, err := movement.MoveGroup( movement.PositionAfter{Directly: false, Pivot: Mock{"B"}}, @@ -132,8 +129,8 @@ var _ = Describe("Movement", func() { Expect(moves).To(HaveLen(0)) }) Context("and moved entries are out of order", func() { - FIt("should generate a single command to move B before D", func() { - // A B C D E -> A B C E D + It("should generate a single command to move B before D", func() { + // '(A B C D E) -> '(A B C E D) entries := asMovable([]string{"E", "D"}) moves, err := movement.MoveGroup( movement.PositionAfter{Directly: false, Pivot: Mock{"B"}}, @@ -151,7 +148,7 @@ var _ = Describe("Movement", func() { }) Context("when direct position relative to the pivot is required", func() { It("should generate required move actions", func() { - // A B C D E -> C D A B E + // '(A B C D E) -> '(C D A B E) entries := asMovable([]string{"A", "B"}) moves, err := movement.MoveGroup( movement.PositionAfter{Directly: true, Pivot: Mock{"D"}}, @@ -170,7 +167,27 @@ var _ = Describe("Movement", func() { Expect(moves[1].Destination.EntryName()).To(Equal("A")) }) }) + Context("when direct position relative to the pivot is required", func() { + It("should generate required move actions", func() { + // '(A B C D E) -> '(C D B A E) + entries := asMovable([]string{"B", "A"}) + moves, err := movement.MoveGroup( + movement.PositionAfter{Directly: true, Pivot: Mock{"D"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("B")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("D")) + Expect(moves[1].Movable.EntryName()).To(Equal("A")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("B")) + }) + }) }) Context("With PositionBefore used as position", func() { existing := asMovable([]string{"A", "B", "C", "D", "E"}) @@ -178,6 +195,7 @@ var _ = Describe("Movement", func() { Context("when direct position relative to the pivot is not required", func() { Context("and moved entries are already before pivot point", func() { It("should not generate any move actions", func() { + // '(A B C D E) -> '(A B C D E) entries := asMovable([]string{"A", "B"}) moves, err := movement.MoveGroup( movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, @@ -190,7 +208,7 @@ var _ = Describe("Movement", func() { }) Context("and moved entries are out of order", func() { It("should generate a single command to move B before D", func() { - // A B C D E -> A C B D E + // '(A B C D E) -> '(A C B D E) entries := asMovable([]string{"C", "B"}) moves, err := movement.MoveGroup( movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, @@ -200,15 +218,45 @@ var _ = Describe("Movement", func() { Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(1)) - Expect(moves[0].Movable.EntryName()).To(Equal("B")) - Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) - Expect(moves[0].Destination.EntryName()).To(Equal("C")) + Expect(moves[0].Movable.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("B")) + }) + }) + Context("and moved entries are out of order", func() { + It("should generate a single command to move B before D", func() { + // '(A B C D E) -> '(A B C D E) + entries := asMovable([]string{"A", "C"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(0)) + }) + }) + Context("and moved entries are out of order", func() { + It("should generate a single command to move B before D", func() { + // '(A B C D E) -> '(A C B D E) + entries := asMovable([]string{"A", "C", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + + Expect(moves[0].Movable.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("B")) }) }) }) Context("when direct position relative to the pivot is required", func() { It("should generate required move actions", func() { - // A B C D E -> C A B D E + // '(A B C D E) -> '(C A B D E) entries := asMovable([]string{"A", "B"}) moves, err := movement.MoveGroup( movement.PositionBefore{Directly: true, Pivot: Mock{"D"}}, @@ -225,9 +273,88 @@ var _ = Describe("Movement", func() { Expect(moves[1].Movable.EntryName()).To(Equal("B")) Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) Expect(moves[1].Destination.EntryName()).To(Equal("A")) + }) + }) + Context("when passing single Movement to MoveGroups()", func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + It("should return a set of move actions that describe it", func() { + // '(A B C D E) -> '(A D B C E) + entries := asMovable([]string{"B", "C"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: true, Pivot: Mock{"E"}}, + entries, existing) - Expect(true).To(BeFalse()) + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) }) }) }) }) + +var _ = Describe("MoveGroups()", Label("MoveGroups"), func() { + existing := asMovable([]string{"A", "B", "C", "D", "E"}) + Context("when passing single Movement to MoveGroups()", func() { + It("should return a set of move actions that describe it", func() { + // '(A B C D E) -> '(A D B C E) + entries := asMovable([]string{"B", "C"}) + movements := []movement.Movement{{ + Entries: entries, + Position: movement.PositionBefore{ + Directly: true, + Pivot: Mock{"E"}, + }}} + moves, err := movement.MoveGroups(existing, movements) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + }) + }) + // Context("when passing single Movement to MoveGroups()", func() { + // FIt("should return a set of move actions that describe it", func() { + // // '(A B C D E) -> '(A D B C E) -> '(D B C E A) + // movements := []movement.Movement{ + // { + // Entries: asMovable([]string{"B", "C"}), + // Position: movement.PositionBefore{ + // Directly: true, + // Pivot: Mock{"E"}}, + // }, + // { + // Entries: asMovable([]string{"A"}), + // Position: movement.PositionBottom{}, + // }, + // } + // moves, err := movement.MoveGroups(existing, movements) + + // Expect(err).ToNot(HaveOccurred()) + // Expect(moves).To(HaveLen(3)) + // }) + // }) +}) + +var _ = Describe("Movement benchmarks", func() { + BeforeEach(func() { + if !Label("benchmark").MatchesLabelFilter(GinkgoLabelFilter()) { + Skip("unless label 'benchmark' is specified.") + } + }) + Context("when moving only a few elements", func() { + It("should generate a simple sequence of actions", Label("benchmark"), func() { + var elts []string + elements := 50000 + for idx := range elements { + elts = append(elts, fmt.Sprintf("%d", idx)) + } + existing := asMovable(elts) + + entries := asMovable([]string{"90", "80", "70", "60", "50", "40"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: true, Pivot: Mock{"100"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(6)) + }) + }) +}) From b010b2c529d707c0820196f70d309baef3197b8b Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Fri, 2 Aug 2024 11:31:31 +0200 Subject: [PATCH 08/19] Fix edge cases in GenerateMovements() and OptimizeMovements() interaction Now GenerateMovements() generate all movements blindly, and depend on the optimization done by OptimizeMovements() to remove reduntant actions. --- assets/pango/movement/movement.go | 39 ++++++------------ assets/pango/movement/movement_test.go | 55 +++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 30 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 57a15045..52b57411 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -254,8 +254,6 @@ func (o PositionBefore) Move(entries []Movable, existing []Movable) ([]MoveActio return nil, err } - slog.Debug("PositionBefore.Move()", "existing", existing, "expected", expected, "entries", entries) - actions, err := GenerateMovements(existing, expected, entries, ActionWhereBefore, o.Pivot, o.Directly) if err != nil { return nil, err @@ -292,23 +290,11 @@ func updateSimulatedIdxMap(idxMap *map[Movable]int, moved Movable, startingIdx i func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable, actions []MoveAction, position Position) []MoveAction { simulated := make([]Movable, len(existing)) copy(simulated, existing) - simulatedIdxMap := createIdxMapFor(simulated) - expectedIdxMap := createIdxMapFor(expected) var optimized []MoveAction - - switch position.(type) { - case PositionBefore, PositionAfter: - default: - return actions - } - for _, action := range actions { currentIdx := simulatedIdxMap[action.Movable] - if currentIdx == expectedIdxMap[action.Movable] { - continue - } var targetIdx int switch action.Where { @@ -317,11 +303,13 @@ func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable case ActionWhereBottom: targetIdx = len(simulated) - 1 case ActionWhereBefore: - targetIdx = simulatedIdxMap[action.Destination] - 1 + targetIdx = simulatedIdxMap[action.Destination] case ActionWhereAfter: targetIdx = simulatedIdxMap[action.Destination] + 1 } + slog.Debug("OptimizeMovements()", "action", action, "currentIdx", currentIdx, "targetIdx", targetIdx) + if targetIdx != currentIdx { optimized = append(optimized, action) simulatedIdxMap[action.Movable] = targetIdx @@ -346,11 +334,7 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable var movements []MoveAction var previous Movable for _, elt := range entries { - slog.Debug("GenerateMovements()", "elt", elt, "existing", existingIdxMap[elt], "expected", expectedIdxMap[elt], "len(expected)", len(expected)) - // If existing index for the element matches the expected one, skip it over - if existingIdxMap[elt] == expectedIdxMap[elt] { - continue - } + slog.Debug("GeneraveMovements()", "elt", elt, "existing", existingIdxMap[elt], "expected", expectedIdxMap[elt]) if previous != nil { movements = append(movements, MoveAction{ @@ -392,10 +376,10 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable eltExpectedIdx := expectedIdxMap[elt] pivot = expected[eltExpectedIdx+1] where = ActionWhereBefore - // if previous was nil (we are processing the first element in entries set) - // and selected pivot is part of the entries set it means the order of entries - // changes between existing and expected sets. If direct move has been requested, - // we need to find the correct pivot point for the move. + // If previous was nil (we are processing the first element in entries set) + // and selected pivot is part of the entries set, it means the order of elements + // changes between existing adn expected sets. In this case the actual pivot + // is element from expected set that follows all moved elements. if _, ok := entriesIdxMap[pivot]; ok && directly { // The actual pivot for the move is the element that follows all elements // from the existing set. @@ -406,6 +390,7 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable return nil, ErrInvalidMovementPlan } pivot = expected[pivotIdx] + } } @@ -419,7 +404,7 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable } - slog.Debug("GeneraveMovements()", "movements", movements) + slog.Debug("GenerateMovements()", "movements", movements) return movements, nil } @@ -464,7 +449,6 @@ func (o PositionBottom) GetExpected(entries []Movable, existing []Movable) ([]Mo } func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { - slog.Debug("PositionBottom.Move())", "entries", entries, "existing", existing) expected, err := o.GetExpected(entries, existing) if err != nil { return nil, err @@ -472,6 +456,7 @@ func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveActio actions, err := GenerateMovements(existing, expected, entries, ActionWhereBottom, nil, false) if err != nil { + slog.Debug("PositionBottom()", "err", err) return nil, err } return OptimizeMovements(existing, expected, entries, actions, o), nil @@ -487,7 +472,6 @@ func MoveGroups(existing []Movable, movements []Movement) ([]MoveAction, error) for idx := range len(movements) - 1 { position := movements[idx].Position entries := movements[idx].Entries - slog.Debug("MoveGroups()", "position", position, "existing", existing, "entries", entries) result, err := position.GetExpected(entries, expected) if err != nil { if !errors.Is(err, errNoMovements) { @@ -500,7 +484,6 @@ func MoveGroups(existing []Movable, movements []Movement) ([]MoveAction, error) entries := movements[len(movements)-1].Entries position := movements[len(movements)-1].Position - slog.Debug("MoveGroups()", "position", position, "expected", expected, "entries", entries) return position.Move(entries, expected) } diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go index 38162e04..bc48e6de 100644 --- a/assets/pango/movement/movement_test.go +++ b/assets/pango/movement/movement_test.go @@ -71,6 +71,8 @@ var _ = Describe("MoveGroup()", func() { moves, err := movement.MoveGroup(movement.PositionTop{}, entries, existing) Expect(err).ToNot(HaveOccurred()) + // '((E 'top nil)(B 'after E)(C 'after B)(D 'after C)) + // 'A element stays in place Expect(moves).To(HaveLen(4)) }) }) @@ -189,9 +191,54 @@ var _ = Describe("MoveGroup()", func() { }) }) }) + + // '(A E B C D) -> '(A B C D E) => '(E 'bottom nil) / '(E 'after D) + + // PositionSomewhereBefore PositionDirectlyBefore + // '(C B 'before E, directly) + // '(A B C D E) -> '(A D C B E) -> '(B 'before E) + // '(A B C D E) -> '(A C B D E) -> '(B 'after C) + Context("With PositionBefore used as position", func() { existing := asMovable([]string{"A", "B", "C", "D", "E"}) + Context("when doing a direct move with entries reordering", func() { + It("should put reordered entries directly before pivot point", func() { + // '(A B C D E) -> '(A D C B E) + entries := asMovable([]string{"C", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: true, Pivot: Mock{"E"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(2)) + + Expect(moves[0].Movable.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("E")) + + Expect(moves[1].Movable.EntryName()).To(Equal("B")) + Expect(moves[1].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[1].Destination.EntryName()).To(Equal("C")) + }) + }) + Context("when doing a non direct move with entries reordering", func() { + It("should reorder entries in-place without moving them around", func() { + // '(A B C D E) -> '(A C B D E) + entries := asMovable([]string{"C", "B"}) + moves, err := movement.MoveGroup( + movement.PositionBefore{Directly: false, Pivot: Mock{"E"}}, + entries, existing, + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(moves).To(HaveLen(1)) + Expect(moves[0].Movable.EntryName()).To(Equal("C")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("B")) + }) + }) Context("when direct position relative to the pivot is not required", func() { Context("and moved entries are already before pivot point", func() { It("should not generate any move actions", func() { @@ -249,8 +296,8 @@ var _ = Describe("MoveGroup()", func() { Expect(moves).To(HaveLen(1)) Expect(moves[0].Movable.EntryName()).To(Equal("C")) - Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) - Expect(moves[0].Destination.EntryName()).To(Equal("B")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereAfter)) + Expect(moves[0].Destination.EntryName()).To(Equal("A")) }) }) }) @@ -355,6 +402,10 @@ var _ = Describe("Movement benchmarks", func() { Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(6)) + + Expect(moves[0].Movable.EntryName()).To(Equal("90")) + Expect(moves[0].Where).To(Equal(movement.ActionWhereBefore)) + Expect(moves[0].Destination.EntryName()).To(Equal("100")) }) }) }) From 583fc99c7dccbafc846b93e663540fd25b6f6b0d Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Tue, 3 Sep 2024 12:14:43 +0200 Subject: [PATCH 09/19] Update movement code to better fit terraform provider --- assets/pango/movement/movement.go | 218 +++++++++++++++++-------- assets/pango/movement/movement_test.go | 54 +++--- 2 files changed, 177 insertions(+), 95 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 52b57411..0e73c01d 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -11,8 +11,8 @@ var _ = slog.LevelDebug type ActionWhereType string const ( - ActionWhereTop ActionWhereType = "top" - ActionWhereBottom ActionWhereType = "bottom" + ActionWhereFirst ActionWhereType = "first" + ActionWhereLast ActionWhereType = "last" ActionWhereBefore ActionWhereType = "before" ActionWhereAfter ActionWhereType = "after" ) @@ -30,26 +30,85 @@ type MoveAction struct { type Position interface { Move(entries []Movable, existing []Movable) ([]MoveAction, error) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) + IsDirectly() bool + Where() ActionWhereType + PivotEntryName() string } -type PositionTop struct{} +type PositionFirst struct{} -type PositionBottom struct{} +func (o PositionFirst) IsDirectly() bool { + return false +} + +func (o PositionFirst) Where() ActionWhereType { + return ActionWhereFirst +} + +func (o PositionFirst) PivotEntryName() string { + return "" +} + +type PositionLast struct{} + +func (o PositionLast) IsDirectly() bool { + return false +} + +func (o PositionLast) Where() ActionWhereType { + return ActionWhereLast +} + +func (o PositionLast) PivotEntryName() string { + return "" +} type PositionBefore struct { Directly bool - Pivot Movable + Pivot string +} + +func (o PositionBefore) IsDirectly() bool { + return o.Directly +} + +func (o PositionBefore) Where() ActionWhereType { + return ActionWhereBefore +} + +func (o PositionBefore) PivotEntryName() string { + return o.Pivot } type PositionAfter struct { Directly bool - Pivot Movable + Pivot string +} + +func (o PositionAfter) IsDirectly() bool { + return o.Directly +} + +func (o PositionAfter) Where() ActionWhereType { + return ActionWhereAfter +} + +func (o PositionAfter) PivotEntryName() string { + return o.Pivot +} + +type entryWithIdx[E Movable] struct { + Entry E + Idx int } -func createIdxMapFor(entries []Movable) map[Movable]int { - entriesIdxMap := make(map[Movable]int, len(entries)) +func entriesByName[E Movable](entries []E) map[string]entryWithIdx[E] { + entriesIdxMap := make(map[string]entryWithIdx[E], len(entries)) for idx, elt := range entries { - entriesIdxMap[elt] = idx + entriesIdxMap[elt.EntryName()] = entryWithIdx[E]{ + Entry: elt, + Idx: idx, + } } return entriesIdxMap } @@ -68,15 +127,18 @@ func removeEntriesFromExisting(entries []Movable, filterFn func(entry Movable) b return filtered } -func findPivotIdx(entries []Movable, pivot Movable) int { - return slices.IndexFunc(entries, func(entry Movable) bool { - if entry.EntryName() == pivot.EntryName() { +func findPivotIdx(entries []Movable, pivot string) (int, Movable) { + var pivotEntry Movable + pivotIdx := slices.IndexFunc(entries, func(entry Movable) bool { + if entry.EntryName() == pivot { + pivotEntry = entry return true } return false }) + return pivotIdx, pivotEntry } var ( @@ -89,15 +151,15 @@ var ( // PositionBefore and PositionAfter are similar enough that we can generate expected sequences // for both using the same code and some conditionals based on the given movement. -func getPivotMovement(entries []Movable, existing []Movable, pivot Movable, direct bool, movement ActionWhereType) ([]Movable, error) { - existingIdxMap := createIdxMapFor(existing) +func getPivotMovement(entries []Movable, existing []Movable, pivot string, direct bool, movement ActionWhereType) ([]Movable, error) { + existingIdxMap := entriesByName(existing) - entriesPivotIdx := findPivotIdx(entries, pivot) + entriesPivotIdx, _ := findPivotIdx(entries, pivot) if entriesPivotIdx != -1 { return nil, ErrPivotInEntries } - existingPivotIdx := findPivotIdx(existing, pivot) + existingPivotIdx, _ := findPivotIdx(existing, pivot) if existingPivotIdx == -1 { return nil, ErrPivotNotInExisting } @@ -107,7 +169,7 @@ func getPivotMovement(entries []Movable, existing []Movable, pivot Movable, dire entriesLen := len(entries) loop: for i := 0; i < entriesLen; i++ { - existingEntryIdx := existingIdxMap[entries[i]] + existingEntryIdx := existingIdxMap[entries[i].EntryName()].Idx // For any given entry in the list of entries to move check if the entry // index is at or after pivot point index, which will require movement // set to be generated. @@ -125,7 +187,7 @@ func getPivotMovement(entries []Movable, existing []Movable, pivot Movable, dire continue } - if existingIdxMap[entries[i-1]] >= existingEntryIdx { + if existingIdxMap[entries[i-1].EntryName()].Idx >= existingEntryIdx { movementRequired = true break loop @@ -140,7 +202,7 @@ func getPivotMovement(entries []Movable, existing []Movable, pivot Movable, dire continue } - if existingIdxMap[entries[i+1]] < existingEntryIdx { + if existingIdxMap[entries[i+1].EntryName()].Idx < existingEntryIdx { movementRequired = true break loop @@ -156,14 +218,14 @@ func getPivotMovement(entries []Movable, existing []Movable, pivot Movable, dire expected := make([]Movable, len(existing)) - entriesIdxMap := createIdxMapFor(entries) + entriesIdxMap := entriesByName(entries) filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { - _, ok := entriesIdxMap[entry] + _, ok := entriesIdxMap[entry.EntryName()] return ok }) - filteredPivotIdx := findPivotIdx(filtered, pivot) + filteredPivotIdx, pivotEntry := findPivotIdx(filtered, pivot) switch movement { case ActionWhereBefore: @@ -177,7 +239,7 @@ func getPivotMovement(entries []Movable, existing []Movable, pivot Movable, dire expectedIdx++ } - expected[expectedIdx] = pivot + expected[expectedIdx] = pivotEntry expectedIdx++ filteredLen := len(filtered) @@ -273,16 +335,20 @@ type sequencePosition struct { End int } -func updateSimulatedIdxMap(idxMap *map[Movable]int, moved Movable, startingIdx int, targetIdx int) { - for entry, idx := range *idxMap { - if entry == moved { +func updateSimulatedIdxMap[E Movable](idxMap *map[string]entryWithIdx[E], moved Movable, startingIdx int, targetIdx int) { + for name, entry := range *idxMap { + if name == moved.EntryName() { continue } + idx := entry.Idx + if startingIdx > targetIdx && idx >= targetIdx { - (*idxMap)[entry] = idx + 1 + entry.Idx = idx + 1 + (*idxMap)[name] = entry } else if startingIdx < targetIdx && idx >= startingIdx && idx <= targetIdx { - (*idxMap)[entry] = idx - 1 + entry.Idx = idx - 1 + (*idxMap)[name] = entry } } } @@ -290,29 +356,31 @@ func updateSimulatedIdxMap(idxMap *map[Movable]int, moved Movable, startingIdx i func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable, actions []MoveAction, position Position) []MoveAction { simulated := make([]Movable, len(existing)) copy(simulated, existing) - simulatedIdxMap := createIdxMapFor(simulated) + simulatedIdxMap := entriesByName(simulated) var optimized []MoveAction for _, action := range actions { - currentIdx := simulatedIdxMap[action.Movable] + currentIdx := simulatedIdxMap[action.Movable.EntryName()].Idx var targetIdx int switch action.Where { - case ActionWhereTop: + case ActionWhereFirst: targetIdx = 0 - case ActionWhereBottom: + case ActionWhereLast: targetIdx = len(simulated) - 1 case ActionWhereBefore: - targetIdx = simulatedIdxMap[action.Destination] + targetIdx = simulatedIdxMap[action.Destination.EntryName()].Idx case ActionWhereAfter: - targetIdx = simulatedIdxMap[action.Destination] + 1 + targetIdx = simulatedIdxMap[action.Destination.EntryName()].Idx + 1 } slog.Debug("OptimizeMovements()", "action", action, "currentIdx", currentIdx, "targetIdx", targetIdx) if targetIdx != currentIdx { optimized = append(optimized, action) - simulatedIdxMap[action.Movable] = targetIdx + entry := simulatedIdxMap[action.Movable.EntryName()] + entry.Idx = targetIdx + simulatedIdxMap[action.Movable.EntryName()] = entry updateSimulatedIdxMap(&simulatedIdxMap, action.Movable, currentIdx, targetIdx) } } @@ -322,19 +390,20 @@ func OptimizeMovements(existing []Movable, expected []Movable, entries []Movable return optimized } -func GenerateMovements(existing []Movable, expected []Movable, entries []Movable, movement ActionWhereType, pivot Movable, directly bool) ([]MoveAction, error) { +func GenerateMovements(existing []Movable, expected []Movable, entries []Movable, movement ActionWhereType, pivot string, directly bool) ([]MoveAction, error) { if len(existing) != len(expected) { + slog.Error("GenerateMovements()", "len(existing)", len(existing), "len(expected)", len(expected)) return nil, ErrSlicesNotEqualLength } - entriesIdxMap := createIdxMapFor(entries) - existingIdxMap := createIdxMapFor(existing) - expectedIdxMap := createIdxMapFor(expected) + entriesIdxMap := entriesByName(entries) + existingIdxMap := entriesByName(existing) + expectedIdxMap := entriesByName(expected) var movements []MoveAction var previous Movable for _, elt := range entries { - slog.Debug("GeneraveMovements()", "elt", elt, "existing", existingIdxMap[elt], "expected", expectedIdxMap[elt]) + slog.Debug("GeneraveMovements()", "elt", elt, "existing", existingIdxMap[elt.EntryName()], "expected", expectedIdxMap[elt.EntryName()]) if previous != nil { movements = append(movements, MoveAction{ @@ -345,18 +414,18 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable previous = elt continue } - if expectedIdxMap[elt] == 0 { + if expectedIdxMap[elt.EntryName()].Idx == 0 { movements = append(movements, MoveAction{ Movable: elt, Destination: nil, - Where: ActionWhereTop, + Where: ActionWhereFirst, }) previous = elt - } else if expectedIdxMap[elt] == len(expectedIdxMap)-1 { + } else if expectedIdxMap[elt.EntryName()].Idx == len(expectedIdxMap)-1 { movements = append(movements, MoveAction{ Movable: elt, Destination: nil, - Where: ActionWhereBottom, + Where: ActionWhereLast, }) previous = elt } else { @@ -364,23 +433,23 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable var pivot Movable switch movement { - case ActionWhereBottom: - where = ActionWhereBottom + case ActionWhereLast: + where = ActionWhereLast case ActionWhereAfter: - pivot = expected[expectedIdxMap[elt]-1] + pivot = expected[expectedIdxMap[elt.EntryName()].Idx-1] where = ActionWhereAfter - case ActionWhereTop: + case ActionWhereFirst: pivot = existing[0] where = ActionWhereBefore case ActionWhereBefore: - eltExpectedIdx := expectedIdxMap[elt] + eltExpectedIdx := expectedIdxMap[elt.EntryName()].Idx pivot = expected[eltExpectedIdx+1] where = ActionWhereBefore - // If previous was nil (we are processing the first element in entries set) - // and selected pivot is part of the entries set, it means the order of elements - // changes between existing adn expected sets. In this case the actual pivot - // is element from expected set that follows all moved elements. - if _, ok := entriesIdxMap[pivot]; ok && directly { + // When entries are to be put directly before the pivot point, if previous was nil (we + // are processing the first element in entries set) and selected pivot is part of the + // entries set, we need to find the actual pivot, i.e. element of the expected list + // that directly follows all elements from the entries set. + if _, ok := entriesIdxMap[pivot.EntryName()]; ok && directly { // The actual pivot for the move is the element that follows all elements // from the existing set. pivotIdx := eltExpectedIdx + len(entries) @@ -409,11 +478,11 @@ func GenerateMovements(existing []Movable, expected []Movable, entries []Movable return movements, nil } -func (o PositionTop) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { - entriesIdxMap := createIdxMapFor(entries) +func (o PositionFirst) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { + entriesIdxMap := entriesByName(entries) filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { - _, ok := entriesIdxMap[entry] + _, ok := entriesIdxMap[entry.EntryName()] return ok }) @@ -422,12 +491,15 @@ func (o PositionTop) GetExpected(entries []Movable, existing []Movable) ([]Movab return expected, nil } -func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { +func (o PositionFirst) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { expected, err := o.GetExpected(entries, existing) if err != nil { return nil, err } - actions, err := GenerateMovements(existing, expected, entries, ActionWhereTop, nil, false) + + slog.Error("PositionFirst.Move()", "len(expected)", len(expected), "len(existing)", len(existing)) + + actions, err := GenerateMovements(existing, expected, entries, ActionWhereFirst, "", false) if err != nil { return nil, err } @@ -435,11 +507,11 @@ func (o PositionTop) Move(entries []Movable, existing []Movable) ([]MoveAction, return OptimizeMovements(existing, expected, entries, actions, o), nil } -func (o PositionBottom) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { - entriesIdxMap := createIdxMapFor(entries) +func (o PositionLast) GetExpected(entries []Movable, existing []Movable) ([]Movable, error) { + entriesIdxMap := entriesByName(entries) filtered := removeEntriesFromExisting(existing, func(entry Movable) bool { - _, ok := entriesIdxMap[entry] + _, ok := entriesIdxMap[entry.EntryName()] return ok }) @@ -448,15 +520,15 @@ func (o PositionBottom) GetExpected(entries []Movable, existing []Movable) ([]Mo return expected, nil } -func (o PositionBottom) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { +func (o PositionLast) Move(entries []Movable, existing []Movable) ([]MoveAction, error) { expected, err := o.GetExpected(entries, existing) if err != nil { return nil, err } - actions, err := GenerateMovements(existing, expected, entries, ActionWhereBottom, nil, false) + actions, err := GenerateMovements(existing, expected, entries, ActionWhereLast, "", false) if err != nil { - slog.Debug("PositionBottom()", "err", err) + slog.Debug("PositionLast()", "err", err) return nil, err } return OptimizeMovements(existing, expected, entries, actions, o), nil @@ -467,7 +539,7 @@ type Movement struct { Position Position } -func MoveGroups(existing []Movable, movements []Movement) ([]MoveAction, error) { +func MoveGroups[E Movable](existing []Movable, movements []Movement) ([]MoveAction, error) { expected := existing for idx := range len(movements) - 1 { position := movements[idx].Position @@ -487,8 +559,18 @@ func MoveGroups(existing []Movable, movements []Movement) ([]MoveAction, error) return position.Move(entries, expected) } -func MoveGroup(position Position, entries []Movable, existing []Movable) ([]MoveAction, error) { - return position.Move(entries, existing) +func MoveGroup[E Movable](position Position, entries []E, existing []E) ([]MoveAction, error) { + var movableEntries []Movable + for _, elt := range entries { + slog.Warn("MoveGroup", "entry.EntryName()", elt.EntryName()) + movableEntries = append(movableEntries, elt) + } + var movableExisting []Movable + for _, elt := range existing { + slog.Warn("MoveGroup", "existing.EntryName()", elt.EntryName()) + movableExisting = append(movableExisting, elt) + } + return position.Move(movableEntries, movableExisting) } type Move struct { diff --git a/assets/pango/movement/movement_test.go b/assets/pango/movement/movement_test.go index bc48e6de..9e5e9026 100644 --- a/assets/pango/movement/movement_test.go +++ b/assets/pango/movement/movement_test.go @@ -30,12 +30,12 @@ func asMovable(mocks []string) []movement.Movable { } var _ = Describe("MoveGroup()", func() { - Context("With PositionTop used as position", func() { + Context("With PositionFirst used as position", func() { Context("when existing positions matches expected", func() { It("should generate no movements", func() { // '(A B C) -> '(A B C) expected := asMovable([]string{"A", "B", "C"}) - moves, err := movement.MoveGroup(movement.PositionTop{}, expected, expected) + moves, err := movement.MoveGroup(movement.PositionFirst{}, expected, expected) Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(0)) }) @@ -46,12 +46,12 @@ var _ = Describe("MoveGroup()", func() { entries := asMovable([]string{"A", "B", "C"}) existing := asMovable([]string{"D", "E", "A", "B", "C"}) - moves, err := movement.MoveGroup(movement.PositionTop{}, entries, existing) + moves, err := movement.MoveGroup(movement.PositionFirst{}, entries, existing) Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(3)) Expect(moves[0].Movable.EntryName()).To(Equal("A")) - Expect(moves[0].Where).To(Equal(movement.ActionWhereTop)) + Expect(moves[0].Where).To(Equal(movement.ActionWhereFirst)) Expect(moves[0].Destination).To(BeNil()) Expect(moves[1].Movable.EntryName()).To(Equal("B")) @@ -68,7 +68,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(E D C B A) entries := asMovable([]string{"E", "D", "C", "B", "A"}) existing := asMovable([]string{"A", "B", "C", "D", "E"}) - moves, err := movement.MoveGroup(movement.PositionTop{}, entries, existing) + moves, err := movement.MoveGroup(movement.PositionFirst{}, entries, existing) Expect(err).ToNot(HaveOccurred()) // '((E 'top nil)(B 'after E)(C 'after B)(D 'after C)) @@ -77,19 +77,19 @@ var _ = Describe("MoveGroup()", func() { }) }) }) - Context("With PositionBottom used as position", func() { + Context("With PositionLast used as position", func() { Context("with non-consecutive entries", func() { It("should generate two move actions", func() { // '(A E B C D) -> '(A B D E C) entries := asMovable([]string{"E", "C"}) existing := asMovable([]string{"A", "E", "B", "C", "D"}) - moves, err := movement.MoveGroup(movement.PositionBottom{}, entries, existing) + moves, err := movement.MoveGroup(movement.PositionLast{}, entries, existing) Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(2)) Expect(moves[0].Movable.EntryName()).To(Equal("E")) - Expect(moves[0].Where).To(Equal(movement.ActionWhereBottom)) + Expect(moves[0].Where).To(Equal(movement.ActionWhereLast)) Expect(moves[0].Destination).To(BeNil()) Expect(moves[1].Movable.EntryName()).To(Equal("C")) @@ -98,19 +98,19 @@ var _ = Describe("MoveGroup()", func() { }) }) }) - Context("With PositionBottom used as position", func() { + Context("With PositionLast used as position", func() { Context("when it needs to move one element", func() { It("should generate a single move action", func() { // '(A E B C D) -> '(A B C D E) entries := asMovable([]string{"E"}) existing := asMovable([]string{"A", "E", "B", "C", "D"}) - moves, err := movement.MoveGroup(movement.PositionBottom{}, entries, existing) + moves, err := movement.MoveGroup(movement.PositionLast{}, entries, existing) Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(1)) Expect(moves[0].Movable.EntryName()).To(Equal("E")) - Expect(moves[0].Where).To(Equal(movement.ActionWhereBottom)) + Expect(moves[0].Where).To(Equal(movement.ActionWhereLast)) Expect(moves[0].Destination).To(BeNil()) }) }) @@ -123,7 +123,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A B C D E) entries := asMovable([]string{"D", "E"}) moves, err := movement.MoveGroup( - movement.PositionAfter{Directly: false, Pivot: Mock{"B"}}, + movement.PositionAfter{Directly: false, Pivot: "B"}, entries, existing, ) @@ -135,7 +135,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A B C E D) entries := asMovable([]string{"E", "D"}) moves, err := movement.MoveGroup( - movement.PositionAfter{Directly: false, Pivot: Mock{"B"}}, + movement.PositionAfter{Directly: false, Pivot: "B"}, entries, existing, ) @@ -153,7 +153,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(C D A B E) entries := asMovable([]string{"A", "B"}) moves, err := movement.MoveGroup( - movement.PositionAfter{Directly: true, Pivot: Mock{"D"}}, + movement.PositionAfter{Directly: true, Pivot: "D"}, entries, existing, ) @@ -174,7 +174,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(C D B A E) entries := asMovable([]string{"B", "A"}) moves, err := movement.MoveGroup( - movement.PositionAfter{Directly: true, Pivot: Mock{"D"}}, + movement.PositionAfter{Directly: true, Pivot: "D"}, entries, existing, ) @@ -206,7 +206,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A D C B E) entries := asMovable([]string{"C", "B"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: true, Pivot: Mock{"E"}}, + movement.PositionBefore{Directly: true, Pivot: "E"}, entries, existing, ) @@ -227,7 +227,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A C B D E) entries := asMovable([]string{"C", "B"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: false, Pivot: Mock{"E"}}, + movement.PositionBefore{Directly: false, Pivot: "E"}, entries, existing, ) @@ -245,7 +245,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A B C D E) entries := asMovable([]string{"A", "B"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, + movement.PositionBefore{Directly: false, Pivot: "D"}, entries, existing, ) @@ -258,7 +258,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A C B D E) entries := asMovable([]string{"C", "B"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, + movement.PositionBefore{Directly: false, Pivot: "D"}, entries, existing, ) @@ -275,7 +275,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A B C D E) entries := asMovable([]string{"A", "C"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, + movement.PositionBefore{Directly: false, Pivot: "D"}, entries, existing, ) @@ -288,7 +288,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A C B D E) entries := asMovable([]string{"A", "C", "B"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: false, Pivot: Mock{"D"}}, + movement.PositionBefore{Directly: false, Pivot: "D"}, entries, existing, ) @@ -306,7 +306,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(C A B D E) entries := asMovable([]string{"A", "B"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: true, Pivot: Mock{"D"}}, + movement.PositionBefore{Directly: true, Pivot: "D"}, entries, existing, ) @@ -328,7 +328,7 @@ var _ = Describe("MoveGroup()", func() { // '(A B C D E) -> '(A D B C E) entries := asMovable([]string{"B", "C"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: true, Pivot: Mock{"E"}}, + movement.PositionBefore{Directly: true, Pivot: "E"}, entries, existing) Expect(err).ToNot(HaveOccurred()) @@ -348,9 +348,9 @@ var _ = Describe("MoveGroups()", Label("MoveGroups"), func() { Entries: entries, Position: movement.PositionBefore{ Directly: true, - Pivot: Mock{"E"}, + Pivot: "E", }}} - moves, err := movement.MoveGroups(existing, movements) + moves, err := movement.MoveGroups[Mock](existing, movements) Expect(err).ToNot(HaveOccurred()) Expect(moves).To(HaveLen(2)) @@ -368,7 +368,7 @@ var _ = Describe("MoveGroups()", Label("MoveGroups"), func() { // }, // { // Entries: asMovable([]string{"A"}), - // Position: movement.PositionBottom{}, + // Position: movement.PositionLast{}, // }, // } // moves, err := movement.MoveGroups(existing, movements) @@ -396,7 +396,7 @@ var _ = Describe("Movement benchmarks", func() { entries := asMovable([]string{"90", "80", "70", "60", "50", "40"}) moves, err := movement.MoveGroup( - movement.PositionBefore{Directly: true, Pivot: Mock{"100"}}, + movement.PositionBefore{Directly: true, Pivot: "100"}, entries, existing, ) From 2fdaa61214045f2b62bbd8941efa87e457443e52 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Tue, 3 Sep 2024 13:31:49 +0200 Subject: [PATCH 10/19] Update SDK and Provider code to use new movement API --- pkg/translate/imports.go | 2 + templates/sdk/service.tmpl | 422 ++++--------------------------------- 2 files changed, 46 insertions(+), 378 deletions(-) diff --git a/pkg/translate/imports.go b/pkg/translate/imports.go index bb7aa9e5..4d2c65a2 100644 --- a/pkg/translate/imports.go +++ b/pkg/translate/imports.go @@ -43,6 +43,8 @@ func RenderImports(templateTypes ...string) (string, error) { manager.AddSdkImport("github.com/PaloAltoNetworks/pango/audit", "") case "rule": manager.AddSdkImport("github.com/PaloAltoNetworks/pango/rule", "") + case "movement": + manager.AddSdkImport("github.com/PaloAltoNetworks/pango/movement", "") case "version": manager.AddSdkImport("github.com/PaloAltoNetworks/pango/version", "") case "template": diff --git a/templates/sdk/service.tmpl b/templates/sdk/service.tmpl index 8f714cb9..768d25c2 100644 --- a/templates/sdk/service.tmpl +++ b/templates/sdk/service.tmpl @@ -2,13 +2,13 @@ package {{packageName .GoSdkPath}} {{- if .Entry}} {{- if $.Imports}} {{- if $.Spec.Params.uuid}} - {{renderImports "service" "filtering" "audit" "rule" "version"}} + {{renderImports "service" "filtering" "audit" "movement"}} {{- else}} {{renderImports "service" "filtering"}} {{- end}} {{- else}} {{- if $.Spec.Params.uuid}} - {{renderImports "service" "filtering" "audit" "rule" "version"}} + {{renderImports "service" "filtering" "audit" "movement"}} {{- else}} {{renderImports "service" "filtering"}} {{- end}} @@ -747,396 +747,62 @@ func (s *Service) RemoveFromImport(ctx context.Context, loc Location, entry Entr // MoveGroup arranges the given rules in the order specified. // Any rule with a UUID specified is ignored. // Only the rule names are considered for the purposes of the rule placement. - func (s *Service) MoveGroup(ctx context.Context, loc Location, position rule.Position, entries []*Entry) error { + func (s *Service) MoveGroup(ctx context.Context, loc Location, position movement.Position, entries []*Entry) error { if len(entries) == 0 { - return nil + return nil } - listing, err := s.List(ctx, loc, "get", "", "") + existing, err := s.List(ctx, loc, "get", "", "") if err != nil { - return err - } else if len(listing) == 0 { - return fmt.Errorf("no rules present") - } - - rp := make(map[string]int) - for idx, live := range listing { - rp[live.Name] = idx + return err + } else if len(existing) == 0 { + return fmt.Errorf("no rules present") } - vn := s.client.Versioning() - updates := xmlapi.NewMultiConfig(len(entries)) - - var ok, topDown bool - var otherIndex int - baseIndex := -1 - switch { - case position.First != nil && *position.First: - topDown, baseIndex, ok, err = s.moveTop(topDown, entries, baseIndex, ok, rp, loc, vn, updates) - if err != nil { - return err - } - case position.Last != nil && *position.Last: - baseIndex, ok, err = s.moveBottom(entries, baseIndex, ok, rp, listing, loc, vn, updates) - if err != nil { - return err - } - case position.SomewhereAfter != nil && *position.SomewhereAfter != "": - topDown, baseIndex, ok, otherIndex, err = s.moveSomewhereAfter(topDown, entries, baseIndex, ok, rp, otherIndex, position, loc, vn, updates) - if err != nil { - return err - } - case position.SomewhereBefore != nil && *position.SomewhereBefore != "": - baseIndex, ok, otherIndex, err = s.moveSomewhereBefore(entries, baseIndex, ok, rp, otherIndex, position, loc, vn, updates) - if err != nil { - return err - } - case position.DirectlyAfter != nil && *position.DirectlyAfter != "": - topDown, baseIndex, ok, otherIndex, err = s.moveDirectlyAfter(topDown, entries, baseIndex, ok, rp, otherIndex, position, loc, vn, updates) - if err != nil { - return err - } - case position.DirectlyBefore != nil && *position.DirectlyBefore != "": - baseIndex, ok, err = s.moveDirectlyBefore(entries, baseIndex, ok, rp, otherIndex, position, loc, vn, updates) - if err != nil { - return err - } - default: - topDown = true - target := entries[0] - - baseIndex, ok = rp[target.Name] - if !ok { - return fmt.Errorf("could not find rule %q for first positioning", target.Name) - } - } + movements, err := movement.MoveGroup(position, entries, existing) + if err != nil { + return err + } - var prevName, where string - if topDown { - prevName = entries[0].Name - where = "after" - } else { - prevName = entries[len(entries)-1].Name - where = "before" - } + updates := xmlapi.NewMultiConfig(len(movements)) + + for _, elt := range movements { + path, err := loc.XpathWithEntryName(s.client.Versioning(), elt.Movable.EntryName()) + if err != nil { + return err + } + + switch elt.Where { + case movement.ActionWhereFirst, movement.ActionWhereLast: + updates.Add(&xmlapi.Config{ + Action: "move", + Xpath: util.AsXpath(path), + Where: string(elt.Where), + Destination: string(elt.Where), + Target: s.client.GetTarget(), + }) + case movement.ActionWhereBefore, movement.ActionWhereAfter: + updates.Add(&xmlapi.Config{ + Action: "move", + Xpath: util.AsXpath(path), + Where: string(elt.Where), + Destination: elt.Destination.EntryName(), + Target: s.client.GetTarget(), + }) + } - for i := 1; i < len(entries); i++ { - err := s.moveRestOfRules(topDown, entries, i, baseIndex, rp, loc, vn, updates, where, prevName) - if err != nil { - return err - } - } + } if len(updates.Operations) > 0 { - _, _, _, err = s.client.MultiConfig(ctx, updates, false, nil) - return err - } - - return nil + _, _, _, err = s.client.MultiConfig(ctx, updates, false, nil) + return err } - func (s *Service) moveRestOfRules(topDown bool, entries []*Entry, i int, baseIndex int, rp map[string]int, loc Location, vn version.Number, updates *xmlapi.MultiConfig, where string, prevName string) error { - var target Entry - var desiredIndex int - if topDown { - target = *entries[i] - desiredIndex = baseIndex + i - } else { - target = *entries[len(entries)-1-i] - desiredIndex = baseIndex - i - } - - idx, ok := rp[target.Name] - if !ok { - return fmt.Errorf("rule %q not present", target.Name) - } - - if idx != desiredIndex { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return err - } - - if idx < desiredIndex { - for name, val := range rp { - if val > idx && val <= desiredIndex { - rp[name] = val - 1 - } - } - } else { - for name, val := range rp { - if val < idx && val >= desiredIndex { - rp[name] = val + 1 - } - } - } - rp[target.Name] = desiredIndex - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: where, - Destination: prevName, - Target: s.client.GetTarget(), - }) - } - - prevName = target.Name - return nil - } - - func (s *Service) moveDirectlyBefore(entries []*Entry, baseIndex int, ok bool, rp map[string]int, otherIndex int, position rule.Position, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (int, bool, error) { - target := entries[len(entries)-1] - - baseIndex, ok = rp[target.Name] - if !ok { - return 0, false, fmt.Errorf("could not find rule %q for initial positioning", target.Name) - } - - otherIndex, ok = rp[*position.DirectlyBefore] - if !ok { - return 0, false, fmt.Errorf("could not find referenced rule %q", *position.DirectlyBefore) - } - - if baseIndex+1 != otherIndex { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return 0, false, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = otherIndex - case val < baseIndex && val >= otherIndex: - rp[name] = val + 1 - } - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "before", - Destination: *position.DirectlyBefore, - Target: s.client.GetTarget(), - }) - - baseIndex = otherIndex - } - return baseIndex, ok, nil - } - - func (s *Service) moveDirectlyAfter(topDown bool, entries []*Entry, baseIndex int, ok bool, rp map[string]int, otherIndex int, position rule.Position, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (bool, int, bool, int, error) { - topDown = true - target := entries[0] - - baseIndex, ok = rp[target.Name] - if !ok { - return false, 0, false, 0, fmt.Errorf("could not find rule %q for initial positioning", target.Name) - } - - otherIndex, ok = rp[*position.DirectlyAfter] - if !ok { - return false, 0, false, 0, fmt.Errorf("could not find referenced rule %q for initial positioning", *position.DirectlyAfter) - } - - if baseIndex != otherIndex+1 { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return false, 0, false, 0, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = otherIndex - case val > baseIndex && val <= otherIndex: - rp[name] = otherIndex - 1 - } - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "after", - Destination: *position.DirectlyAfter, - Target: s.client.GetTarget(), - }) - - baseIndex = otherIndex - } - return topDown, baseIndex, ok, otherIndex, nil - } - - func (s *Service) moveSomewhereBefore(entries []*Entry, baseIndex int, ok bool, rp map[string]int, otherIndex int, position rule.Position, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (int, bool, int, error) { - target := entries[len(entries)-1] - - baseIndex, ok = rp[target.Name] - if !ok { - return 0, false, 0, fmt.Errorf("could not find rule %q for initial positioning", target.Name) - } - - otherIndex, ok = rp[*position.SomewhereBefore] - if !ok { - return 0, false, 0, fmt.Errorf("could not find referenced rule %q", *position.SomewhereBefore) - } - - if baseIndex > otherIndex { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return 0, false, 0, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = otherIndex - case val < baseIndex && val >= otherIndex: - rp[name] = val + 1 - } - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "before", - Destination: *position.SomewhereBefore, - Target: s.client.GetTarget(), - }) - - baseIndex = otherIndex - } - return baseIndex, ok, otherIndex, nil - } - - func (s *Service) moveSomewhereAfter(topDown bool, entries []*Entry, baseIndex int, ok bool, rp map[string]int, otherIndex int, position rule.Position, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (bool, int, bool, int, error) { - topDown = true - target := entries[0] - - baseIndex, ok = rp[target.Name] - if !ok { - return false, 0, false, 0, fmt.Errorf("could not find rule %q for initial positioning", target.Name) - } - - otherIndex, ok = rp[*position.SomewhereAfter] - if !ok { - return false, 0, false, 0, fmt.Errorf("could not find referenced rule %q for initial positioning", *position.SomewhereAfter) - } - - if baseIndex < otherIndex { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return false, 0, false, 0, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = otherIndex - case val > baseIndex && val <= otherIndex: - rp[name] = otherIndex - 1 - } - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "after", - Destination: *position.SomewhereAfter, - Target: s.client.GetTarget(), - }) - - baseIndex = otherIndex - } - return topDown, baseIndex, ok, otherIndex, nil - } - - func (s *Service) moveBottom(entries []*Entry, baseIndex int, ok bool, rp map[string]int, listing []*Entry, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (int, bool, error) { - target := entries[len(entries)-1] - - baseIndex, ok = rp[target.Name] - if !ok { - return 0, false, fmt.Errorf("could not find rule %q for last positioning", target.Name) - } - - if baseIndex != len(listing)-1 { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return 0, false, err - } - - for name, val := range rp { - switch { - case name == target.Name: - rp[name] = len(listing) - 1 - case val > baseIndex: - rp[name] = val - 1 - } - } - - // some versions of PAN-OS require that the destination always be set - var dst string - if !vn.Gte(util.FixedPanosVersionForMultiConfigMove) { - dst = "bottom" - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "bottom", - Destination: dst, - Target: s.client.GetTarget(), - }) - - baseIndex = len(listing) - 1 - } - return baseIndex, ok, nil - } - - func (s *Service) moveTop(topDown bool, entries []*Entry, baseIndex int, ok bool, rp map[string]int, loc Location, vn version.Number, updates *xmlapi.MultiConfig) (bool, int, bool, error) { - topDown = true - target := entries[0] - - baseIndex, ok = rp[target.Name] - if !ok { - return false, 0, false, fmt.Errorf("could not find rule %q for first positioning", target.Name) - } - - if baseIndex != 0 { - path, err := loc.XpathWithEntryName(vn, target.Name) - if err != nil { - return false, 0, false, err - } - - for name, val := range rp { - switch { - case name == entries[0].Name: - rp[name] = 0 - case val < baseIndex: - rp[name] = val + 1 - } - } - - // some versions of PAN-OS require that the destination always be set - var dst string - if !vn.Gte(util.FixedPanosVersionForMultiConfigMove) { - dst = "top" - } - - updates.Add(&xmlapi.Config{ - Action: "move", - Xpath: util.AsXpath(path), - Where: "top", - Destination: dst, - Target: s.client.GetTarget(), - }) + return nil +} - baseIndex = 0 - } - return topDown, baseIndex, ok, nil - } - // HitCount returns the hit count for the given rule. + // HITCOUNT returns the hit count for the given rule. func (s *Service) HitCount(ctx context.Context, loc Location, rules ...string) ([]util.HitCount, error) { switch { case loc.Vsys != nil: From 532d1bb101cb53e2e9ad23e82e3702ca6e286bbd Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Tue, 14 Jan 2025 15:40:40 +0100 Subject: [PATCH 11/19] Update terraform code to use new movement implementation --- assets/pango/movement/movement.go | 4 +- assets/terraform/internal/manager/uuid.go | 129 ++---------------- .../terraform/internal/provider/position.go | 35 ++--- pkg/translate/terraform_provider/template.go | 6 +- .../terraform_provider_file.go | 2 +- 5 files changed, 27 insertions(+), 149 deletions(-) diff --git a/assets/pango/movement/movement.go b/assets/pango/movement/movement.go index 0e73c01d..c097f43b 100644 --- a/assets/pango/movement/movement.go +++ b/assets/pango/movement/movement.go @@ -11,8 +11,8 @@ var _ = slog.LevelDebug type ActionWhereType string const ( - ActionWhereFirst ActionWhereType = "first" - ActionWhereLast ActionWhereType = "last" + ActionWhereFirst ActionWhereType = "top" + ActionWhereLast ActionWhereType = "bottom" ActionWhereBefore ActionWhereType = "before" ActionWhereAfter ActionWhereType = "after" ) diff --git a/assets/terraform/internal/manager/uuid.go b/assets/terraform/internal/manager/uuid.go index 07cf0ad5..289e5add 100644 --- a/assets/terraform/internal/manager/uuid.go +++ b/assets/terraform/internal/manager/uuid.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/types" sdkerrors "github.com/PaloAltoNetworks/pango/errors" - "github.com/PaloAltoNetworks/pango/rule" + "github.com/PaloAltoNetworks/pango/movement" "github.com/PaloAltoNetworks/pango/util" "github.com/PaloAltoNetworks/pango/version" "github.com/PaloAltoNetworks/pango/xmlapi" @@ -36,7 +36,7 @@ type SDKUuidService[E UuidObject, L UuidLocation] interface { Create(context.Context, L, E) (E, error) List(context.Context, L, string, string, string) ([]E, error) Delete(context.Context, L, ...string) error - MoveGroup(context.Context, L, rule.Position, []E) error + MoveGroup(context.Context, L, movement.Position, []E) error } type uuidObjectWithState[E EntryObject] struct { @@ -156,7 +156,7 @@ func (o *UuidObjectManager[E, L, S]) entriesProperlySorted(existing []E, planEnt return movementRequired, nil } -func (o *UuidObjectManager[E, L, S]) moveExhaustive(ctx context.Context, location L, entriesByName map[string]uuidObjectWithState[E], position rule.Position) error { +func (o *UuidObjectManager[E, L, S]) moveExhaustive(ctx context.Context, location L, entriesByName map[string]uuidObjectWithState[E], position movement.Position) error { existing, err := o.service.List(ctx, location, "get", "", "") if err != nil && err.Error() != "Object not found" { return &Error{err: err, message: "Failed to list existing entries"} @@ -202,84 +202,21 @@ type position struct { // When moveNonExhaustive is called, the given list is not entirely managed by the Terraform resource. // In that case a care has to be taken to only execute movement on a subset of entries, those that // are under Terraform control. -func (o *UuidObjectManager[E, L, S]) moveNonExhaustive(ctx context.Context, location L, planEntries []E, planEntriesByName map[string]uuidObjectWithState[E], sdkPosition rule.Position) error { - - existing, err := o.service.List(ctx, location, "get", "", "") - if err != nil { - return fmt.Errorf("failed to list remote entries: %w", err) - } - - movementRequired, err := o.entriesProperlySorted(existing, planEntriesByName) - - // If all entries are ordered properly, check if their position matches the requested - // position. - if !movementRequired { - existingEntriesByName := o.entriesByName(existing, entryOk) - p, err := parseSDKPosition(sdkPosition) - if err != nil { - return ErrInvalidPosition - } - - switch p.Where { - case PositionWhereFirst: - planEntryName := planEntries[0].EntryName() - movementRequired = existing[0].EntryName() != planEntryName - case PositionWhereLast: - planEntryName := planEntries[len(planEntries)-1].EntryName() - movementRequired = existing[len(existing)-1].EntryName() != planEntryName - case PositionWhereBefore: - lastPlanElementName := planEntries[len(planEntries)-1].EntryName() - if existingPivot, found := existingEntriesByName[p.PivotEntry]; !found { - return ErrMissingPivotPoint - } else if p.Directly { - if existingPivot.StateIdx == 0 { - movementRequired = true - } else if existing[existingPivot.StateIdx-1].EntryName() != lastPlanElementName { - movementRequired = true - } - } else { - if lastPlanElementInExisting, found := existingEntriesByName[lastPlanElementName]; !found { - return ErrMissingPivotPoint - } else if lastPlanElementInExisting.StateIdx >= existingPivot.StateIdx { - movementRequired = true - } - } - case PositionWhereAfter: - firstPlanElementName := planEntries[0].EntryName() - if existingPivot, found := existingEntriesByName[p.PivotEntry]; !found { - return ErrMissingPivotPoint - } else if p.Directly { - if existingPivot.StateIdx == len(existing)-1 { - movementRequired = true - } else if existing[existingPivot.StateIdx+1].EntryName() != firstPlanElementName { - movementRequired = true - } - } else { - if firstPlanElementInExisting, found := existingEntriesByName[firstPlanElementName]; !found { - return ErrMissingPivotPoint - } else if firstPlanElementInExisting.StateIdx <= existingPivot.StateIdx { - movementRequired = true - } - } - } +func (o *UuidObjectManager[E, L, S]) moveNonExhaustive(ctx context.Context, location L, planEntries []E, planEntriesByName map[string]uuidObjectWithState[E], sdkPosition movement.Position) error { + entries := make([]E, len(planEntriesByName)) + for _, elt := range planEntriesByName { + entries[elt.StateIdx] = elt.Entry } - if movementRequired { - entries := make([]E, len(planEntriesByName)) - for _, elt := range planEntriesByName { - entries[elt.StateIdx] = elt.Entry - } - - err = o.service.MoveGroup(ctx, location, sdkPosition, entries) - if err != nil { - return &Error{err: err, message: "Failed to move group of entries"} - } + err := o.service.MoveGroup(ctx, location, sdkPosition, entries) + if err != nil { + return &Error{err: err, message: "Failed to move group of entries"} } return nil } -func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, planEntries []E, exhaustive ExhaustiveType, sdkPosition rule.Position) ([]E, error) { +func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, planEntries []E, exhaustive ExhaustiveType, sdkPosition movement.Position) ([]E, error) { var diags diag.Diagnostics planEntriesByName := o.entriesByName(planEntries, entryUnknown) @@ -367,7 +304,7 @@ func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, return entries, nil } -func (o *UuidObjectManager[E, L, S]) UpdateMany(ctx context.Context, location L, stateEntries []E, planEntries []E, exhaustive ExhaustiveType, position rule.Position) ([]E, error) { +func (o *UuidObjectManager[E, L, S]) UpdateMany(ctx context.Context, location L, stateEntries []E, planEntries []E, exhaustive ExhaustiveType, position movement.Position) ([]E, error) { stateEntriesByName := o.entriesByName(stateEntries, entryUnknown) planEntriesByName := o.entriesByName(planEntries, entryUnknown) @@ -670,45 +607,3 @@ func (o *UuidObjectManager[E, L, S]) Delete(ctx context.Context, location L, ent } return nil } - -func parseSDKPosition(sdkPosition rule.Position) (position, error) { - if sdkPosition.IsValid(false) != nil { - return position{}, ErrInvalidPosition - } - - if sdkPosition.DirectlyAfter != nil { - return position{ - Directly: true, - Where: PositionWhereAfter, - PivotEntry: *sdkPosition.DirectlyAfter, - }, nil - } else if sdkPosition.DirectlyBefore != nil { - return position{ - Directly: true, - Where: PositionWhereBefore, - PivotEntry: *sdkPosition.DirectlyBefore, - }, nil - } else if sdkPosition.SomewhereAfter != nil { - return position{ - Directly: false, - Where: PositionWhereAfter, - PivotEntry: *sdkPosition.SomewhereAfter, - }, nil - } else if sdkPosition.SomewhereBefore != nil { - return position{ - Directly: false, - Where: PositionWhereBefore, - PivotEntry: *sdkPosition.SomewhereBefore, - }, nil - } else if sdkPosition.First != nil { - return position{ - Where: PositionWhereFirst, - }, nil - } else if sdkPosition.Last != nil { - return position{ - Where: PositionWhereLast, - }, nil - } - - return position{}, ErrInvalidPosition -} diff --git a/assets/terraform/internal/provider/position.go b/assets/terraform/internal/provider/position.go index c1b128c1..deafc80f 100644 --- a/assets/terraform/internal/provider/position.go +++ b/assets/terraform/internal/provider/position.go @@ -8,7 +8,7 @@ import ( rsschema "github.com/hashicorp/terraform-plugin-framework/resource/schema" "github.com/hashicorp/terraform-plugin-framework/types" - "github.com/PaloAltoNetworks/pango/rule" + "github.com/PaloAltoNetworks/pango/movement" ) type TerraformPositionObject struct { @@ -34,36 +34,21 @@ func TerraformPositionObjectSchema() rsschema.SingleNestedAttribute { } } -func (o *TerraformPositionObject) CopyToPango() rule.Position { - trueVal := true +func (o *TerraformPositionObject) CopyToPango() movement.Position { switch o.Where.ValueString() { case "first": - return rule.Position{ - First: &trueVal, - } + return movement.PositionFirst{} case "last": - return rule.Position{ - Last: &trueVal, - } + return movement.PositionLast{} case "before": - if o.Directly.ValueBool() == true { - return rule.Position{ - DirectlyBefore: o.Pivot.ValueStringPointer(), - } - } else { - return rule.Position{ - SomewhereBefore: o.Pivot.ValueStringPointer(), - } + return movement.PositionBefore{ + Pivot: o.Pivot.ValueString(), + Directly: o.Directly.ValueBool(), } case "after": - if o.Directly.ValueBool() == true { - return rule.Position{ - DirectlyAfter: o.Pivot.ValueStringPointer(), - } - } else { - return rule.Position{ - SomewhereAfter: o.Pivot.ValueStringPointer(), - } + return movement.PositionAfter{ + Pivot: o.Pivot.ValueString(), + Directly: o.Directly.ValueBool(), } default: panic("unreachable") diff --git a/pkg/translate/terraform_provider/template.go b/pkg/translate/terraform_provider/template.go index 13c5b580..5438fca2 100644 --- a/pkg/translate/terraform_provider/template.go +++ b/pkg/translate/terraform_provider/template.go @@ -494,8 +494,7 @@ if err != nil { return } {{- else if .Exhaustive }} -trueVal := true -processed, err := r.manager.CreateMany(ctx, location, entries, sdkmanager.Exhaustive, rule.Position{First: &trueVal}) +processed, err := r.manager.CreateMany(ctx, location, entries, sdkmanager.Exhaustive, movement.PositionFirst{}) if err != nil { resp.Diagnostics.AddError("Error during CreateMany() call", err.Error()) return @@ -1037,8 +1036,7 @@ for idx, elt := range elements { {{ $exhaustive := "sdkmanager.NonExhaustive" }} {{- if .Exhaustive }} {{ $exhaustive = "sdkmanager.Exhaustive" }} -trueValue := true -position := rule.Position{First: &trueValue} +position := movement.PositionFirst{} {{- else }} position := state.Position.CopyToPango() {{- end }} diff --git a/pkg/translate/terraform_provider/terraform_provider_file.go b/pkg/translate/terraform_provider/terraform_provider_file.go index 6e42fd65..6a7bfeb8 100644 --- a/pkg/translate/terraform_provider/terraform_provider_file.go +++ b/pkg/translate/terraform_provider/terraform_provider_file.go @@ -198,7 +198,7 @@ func (g *GenerateTerraformProvider) GenerateTerraformResource(resourceTyp proper terraformProvider.ImportManager.AddStandardImport("errors", "") switch resourceTyp { case properties.ResourceUuid: - terraformProvider.ImportManager.AddSdkImport("github.com/PaloAltoNetworks/pango/rule", "") + terraformProvider.ImportManager.AddSdkImport("github.com/PaloAltoNetworks/pango/movement", "") terraformProvider.ImportManager.AddSdkImport("github.com/PaloAltoNetworks/pango/errors", "sdkerrors") case properties.ResourceEntry: case properties.ResourceUuidPlural: From 4d6e83b54e19f9f679d6c6d0ac6bd68e7e47a419 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Tue, 14 Jan 2025 15:47:17 +0100 Subject: [PATCH 12/19] When calling CreateMany() on uuid-style resources, make sure order is preserved --- assets/terraform/internal/manager/uuid.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/assets/terraform/internal/manager/uuid.go b/assets/terraform/internal/manager/uuid.go index 289e5add..3c40cb6b 100644 --- a/assets/terraform/internal/manager/uuid.go +++ b/assets/terraform/internal/manager/uuid.go @@ -250,13 +250,13 @@ func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, } } - for name, elt := range planEntriesByName { - path, err := location.XpathWithEntryName(o.client.Versioning(), name) + for _, elt := range planEntries { + path, err := location.XpathWithEntryName(o.client.Versioning(), elt.EntryName()) if err != nil { return nil, ErrMarshaling } - xmlEntry, err := o.specifier(elt.Entry) + xmlEntry, err := o.specifier(elt) if err != nil { diags.AddError("Failed to transform entry into XML document", err.Error()) return nil, ErrMarshaling From c3a84ec1fe8043ef493a4e2c2484214dab87b437 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Tue, 14 Jan 2025 15:57:19 +0100 Subject: [PATCH 13/19] Update acceptance tests to better verify entries order on the server post creation --- assets/terraform/test/resource_security_policy_test.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/assets/terraform/test/resource_security_policy_test.go b/assets/terraform/test/resource_security_policy_test.go index c943219d..ab007921 100644 --- a/assets/terraform/test/resource_security_policy_test.go +++ b/assets/terraform/test/resource_security_policy_test.go @@ -426,8 +426,8 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { nameSuffix := acctest.RandStringFromCharSet(6, acctest.CharSetAlphaNum) prefix := fmt.Sprintf("test-acc-%s", nameSuffix) - rulesInitial := []string{"rule-1", "rule-2", "rule-3"} - rulesReordered := []string{"rule-2", "rule-1", "rule-3"} + rulesInitial := []string{"rule-1", "rule-2", "rule-3", "rule-4", "rule-5"} + rulesReordered := []string{"rule-2", "rule-1", "rule-3", "rule-4", "rule-5"} prefixed := func(name string) string { return fmt.Sprintf("%s-%s", prefix, name) @@ -481,6 +481,8 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { stateExpectedRuleName(0, "rule-1"), stateExpectedRuleName(1, "rule-2"), stateExpectedRuleName(2, "rule-3"), + stateExpectedRuleName(3, "rule-4"), + stateExpectedRuleName(4, "rule-5"), ExpectServerSecurityRulesCount(prefix, sdkLocation, len(rulesInitial)), ExpectServerSecurityRulesOrder(prefix, sdkLocation, rulesInitial), }, @@ -508,12 +510,16 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { planExpectedRuleName(0, "rule-2"), planExpectedRuleName(1, "rule-1"), planExpectedRuleName(2, "rule-3"), + planExpectedRuleName(3, "rule-4"), + planExpectedRuleName(4, "rule-5"), }, }, ConfigStateChecks: []statecheck.StateCheck{ stateExpectedRuleName(0, "rule-2"), stateExpectedRuleName(1, "rule-1"), stateExpectedRuleName(2, "rule-3"), + stateExpectedRuleName(3, "rule-4"), + stateExpectedRuleName(4, "rule-5"), ExpectServerSecurityRulesOrder(prefix, sdkLocation, rulesReordered), }, }, From 432dbc156c19fc1cfd096fd9a0250b23a6dc3c2d Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Mon, 27 Jan 2025 15:13:56 +0100 Subject: [PATCH 14/19] Update pango example for the new movement API --- assets/pango/example/main.go | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/assets/pango/example/main.go b/assets/pango/example/main.go index 500c6a54..bf623028 100644 --- a/assets/pango/example/main.go +++ b/assets/pango/example/main.go @@ -9,6 +9,7 @@ import ( "github.com/PaloAltoNetworks/pango" "github.com/PaloAltoNetworks/pango/device/services/dns" "github.com/PaloAltoNetworks/pango/device/services/ntp" + "github.com/PaloAltoNetworks/pango/movement" "github.com/PaloAltoNetworks/pango/network/interface/ethernet" "github.com/PaloAltoNetworks/pango/network/interface/loopback" "github.com/PaloAltoNetworks/pango/network/profiles/interface_management" @@ -23,7 +24,6 @@ import ( "github.com/PaloAltoNetworks/pango/panorama/template" "github.com/PaloAltoNetworks/pango/panorama/template_stack" "github.com/PaloAltoNetworks/pango/policies/rules/security" - "github.com/PaloAltoNetworks/pango/rule" "github.com/PaloAltoNetworks/pango/util" ) @@ -773,22 +773,11 @@ func checkSecurityPolicyRulesMove(c *pango.Client, ctx context.Context) { log.Printf("Security policy rule '%s:%s' with description '%s' created", *securityPolicyRuleItemReply.Uuid, securityPolicyRuleItemReply.Name, *securityPolicyRuleItemReply.Description) } - rulePositionBefore7 := rule.Position{ - First: nil, - Last: nil, - SomewhereBefore: nil, - DirectlyBefore: util.String("codegen_rule7"), - SomewhereAfter: nil, - DirectlyAfter: nil, - } - rulePositionBottom := rule.Position{ - First: nil, - Last: util.Bool(true), - SomewhereBefore: nil, - DirectlyBefore: nil, - SomewhereAfter: nil, - DirectlyAfter: nil, + positionBefore7 := movement.PositionBefore{ + Directly: true, + Pivot: "codegen_rule7", } + positionLast := movement.PositionLast{} var securityPolicyRulesEntriesToMove []*security.Entry securityPolicyRulesEntriesToMove = append(securityPolicyRulesEntriesToMove, securityPolicyRulesEntries[3]) @@ -797,7 +786,7 @@ func checkSecurityPolicyRulesMove(c *pango.Client, ctx context.Context) { for _, securityPolicyRuleItemToMove := range securityPolicyRulesEntriesToMove { log.Printf("Security policy rule '%s' is going to be moved", securityPolicyRuleItemToMove.Name) } - err := securityPolicyRuleApi.MoveGroup(ctx, *securityPolicyRuleLocation, rulePositionBefore7, securityPolicyRulesEntriesToMove) + err := securityPolicyRuleApi.MoveGroup(ctx, *securityPolicyRuleLocation, positionBefore7, securityPolicyRulesEntriesToMove) if err != nil { log.Printf("Failed to move security policy rules %v: %s", securityPolicyRulesEntriesToMove, err) return @@ -807,7 +796,7 @@ func checkSecurityPolicyRulesMove(c *pango.Client, ctx context.Context) { for _, securityPolicyRuleItemToMove := range securityPolicyRulesEntriesToMove { log.Printf("Security policy rule '%s' is going to be moved", securityPolicyRuleItemToMove.Name) } - err = securityPolicyRuleApi.MoveGroup(ctx, *securityPolicyRuleLocation, rulePositionBottom, securityPolicyRulesEntriesToMove) + err = securityPolicyRuleApi.MoveGroup(ctx, *securityPolicyRuleLocation, positionLast, securityPolicyRulesEntriesToMove) if err != nil { log.Printf("Failed to move security policy rules %v: %s", securityPolicyRulesEntriesToMove, err) return From 23ef3c98131dc7c7848f5323b0e7717a40d8d5fc Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Tue, 4 Feb 2025 16:21:50 +0100 Subject: [PATCH 15/19] Update tests for the new movement API --- .../terraform/internal/manager/uuid_test.go | 20 ++++++------ .../internal/manager/uuid_utils_test.go | 32 ++++++++----------- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/assets/terraform/internal/manager/uuid_test.go b/assets/terraform/internal/manager/uuid_test.go index 7e3d07f5..db199a4d 100644 --- a/assets/terraform/internal/manager/uuid_test.go +++ b/assets/terraform/internal/manager/uuid_test.go @@ -7,7 +7,7 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/PaloAltoNetworks/pango/rule" + "github.com/PaloAltoNetworks/pango/movement" sdkmanager "github.com/PaloAltoNetworks/terraform-provider-panos/internal/manager" ) @@ -32,18 +32,16 @@ var _ = Describe("Server", func() { var client *MockUuidClient[*MockUuidObject] var service sdkmanager.SDKUuidService[*MockUuidObject, MockLocation] var mockService *MockUuidService[*MockUuidObject, MockLocation] - var trueVal bool var location MockLocation var ctx context.Context - var position rule.Position + var position movement.Position var entries []*MockUuidObject var mode sdkmanager.ExhaustiveType BeforeEach(func() { location = MockLocation{} ctx = context.Background() - trueVal = true initial = []*MockUuidObject{{Name: "1", Value: "A"}, {Name: "2", Value: "B"}, {Name: "3", Value: "C"}} client = NewMockUuidClient(initial) service = NewMockUuidService[*MockUuidObject, MockLocation](client) @@ -65,7 +63,7 @@ var _ = Describe("Server", func() { It("CreateMany() should create new entries on the server, and return them with uuid set", func() { entries := []*MockUuidObject{{Name: "1", Value: "A"}} - processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.Exhaustive, rule.Position{First: &trueVal}) + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.Exhaustive, movement.PositionFirst{}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(1)) @@ -100,7 +98,7 @@ var _ = Describe("Server", func() { Context("and all entries being created are new to the server", func() { It("should create those entries in the correct position", func() { - processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, rule.Position{First: &trueVal}) + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(2)) @@ -117,7 +115,7 @@ var _ = Describe("Server", func() { BeforeEach(func() { entries = []*MockUuidObject{{Name: "1", Value: "A'"}, {Name: "3", Value: "C"}} mode = sdkmanager.Exhaustive - position = rule.Position{First: &trueVal} + position = movement.PositionFirst{} }) It("should not return any error and overwrite all entries on the server", func() { @@ -169,7 +167,7 @@ var _ = Describe("Server", func() { Expect(processed).To(HaveLen(3)) Expect(processed).NotTo(MatchEntries(entries)) - processed, err = manager.UpdateMany(ctx, location, entries, entries, sdkmanager.NonExhaustive, rule.Position{First: &trueVal}) + processed, err = manager.UpdateMany(ctx, location, entries, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(3)) @@ -184,7 +182,7 @@ var _ = Describe("Server", func() { It("should create new entries on the bottom of the list", func() { entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} - processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, rule.Position{Last: &trueVal}) + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(3)) @@ -206,7 +204,7 @@ var _ = Describe("Server", func() { It("should create new entries directly after first existing element", func() { entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} - processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, rule.Position{DirectlyAfter: &initial[0].Name}) + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionAfter{Directly: true, Pivot: initial[0].Name}) Expect(err).ToNot(HaveOccurred()) Expect(processed).To(HaveLen(3)) @@ -234,7 +232,7 @@ var _ = Describe("Server", func() { entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} pivot := initial[2].Name // "3" - position = rule.Position{DirectlyBefore: &pivot} + position = movement.PositionBefore{Directly: true, Pivot: pivot} processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, position) Expect(err).ToNot(HaveOccurred()) diff --git a/assets/terraform/internal/manager/uuid_utils_test.go b/assets/terraform/internal/manager/uuid_utils_test.go index de51b87a..df393bef 100644 --- a/assets/terraform/internal/manager/uuid_utils_test.go +++ b/assets/terraform/internal/manager/uuid_utils_test.go @@ -8,7 +8,7 @@ import ( "net/http" "net/url" - "github.com/PaloAltoNetworks/pango/rule" + "github.com/PaloAltoNetworks/pango/movement" "github.com/PaloAltoNetworks/pango/version" "github.com/PaloAltoNetworks/pango/xmlapi" @@ -180,7 +180,7 @@ func (o *MockUuidService[E, L]) removeEntriesFromCurrent(entries []*MockUuidObje return firstIdx } -func (o *MockUuidService[E, T]) MoveGroup(ctx context.Context, location MockLocation, position rule.Position, entries []*MockUuidObject) error { +func (o *MockUuidService[E, T]) MoveGroup(ctx context.Context, location MockLocation, position movement.Position, entries []*MockUuidObject) error { o.moveGroupEntries = entries firstIdx := o.removeEntriesFromCurrent(entries) @@ -190,34 +190,30 @@ func (o *MockUuidService[E, T]) MoveGroup(ctx context.Context, location MockLoca entriesList.PushBack(elt) } - if position.First != nil { + switch position.(type) { + case movement.PositionFirst: o.client.Current.PushFrontList(entriesList) return nil - } else if position.Last != nil { + case movement.PositionLast: o.client.Current.PushBackList(entriesList) return nil + case movement.PositionBefore, movement.PositionAfter: } var pivotEntry string var after bool var directly bool - if position.DirectlyBefore != nil { - pivotEntry = *position.DirectlyBefore + switch typed := position.(type) { + case movement.PositionBefore: after = false - directly = true - } else if position.DirectlyAfter != nil { - pivotEntry = *position.DirectlyAfter + directly = typed.Directly + pivotEntry = typed.Pivot + case movement.PositionAfter: after = true - directly = true - } else if position.SomewhereBefore != nil { - pivotEntry = *position.SomewhereBefore - after = false - directly = false - } else if position.SomewhereAfter != nil { - pivotEntry = *position.SomewhereAfter - after = true - directly = false + directly = typed.Directly + pivotEntry = typed.Pivot + case movement.PositionFirst, movement.PositionLast: } var pivotElt *list.Element From e4133ebafec25e45b293e80de0e1f302d0893fb3 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Tue, 4 Feb 2025 17:15:21 +0100 Subject: [PATCH 16/19] Verify that all entries from plan have unique names --- assets/terraform/internal/manager/manager.go | 1 + assets/terraform/internal/manager/uuid.go | 6 +++ .../terraform/internal/manager/uuid_test.go | 40 +++++++++++++++++-- 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/assets/terraform/internal/manager/manager.go b/assets/terraform/internal/manager/manager.go index 4e18e952..e598ec1b 100644 --- a/assets/terraform/internal/manager/manager.go +++ b/assets/terraform/internal/manager/manager.go @@ -30,6 +30,7 @@ func (o *Error) Unwrap() error { } var ( + ErrPlanConflict = errors.New("multiple plan entries with shared name") ErrConflict = errors.New("entry from the plan already exists on the server") ErrMissingUuid = errors.New("entry is missing required uuid") ErrMarshaling = errors.New("failed to marshal entry to XML document") diff --git a/assets/terraform/internal/manager/uuid.go b/assets/terraform/internal/manager/uuid.go index 06e2f60b..2119f163 100644 --- a/assets/terraform/internal/manager/uuid.go +++ b/assets/terraform/internal/manager/uuid.go @@ -220,6 +220,9 @@ func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, var diags diag.Diagnostics planEntriesByName := o.entriesByName(planEntries, entryUnknown) + if len(planEntriesByName) != len(planEntries) { + return nil, ErrPlanConflict + } existing, err := o.service.List(ctx, location, "get", "", "") if err != nil && !sdkerrors.IsObjectNotFound(err) { @@ -307,6 +310,9 @@ func (o *UuidObjectManager[E, L, S]) CreateMany(ctx context.Context, location L, func (o *UuidObjectManager[E, L, S]) UpdateMany(ctx context.Context, location L, stateEntries []E, planEntries []E, exhaustive ExhaustiveType, position movement.Position) ([]E, error) { stateEntriesByName := o.entriesByName(stateEntries, entryUnknown) planEntriesByName := o.entriesByName(planEntries, entryUnknown) + if len(planEntriesByName) != len(planEntries) { + return nil, ErrPlanConflict + } findMatchingStateEntry := func(entry E) (E, bool) { var found bool diff --git a/assets/terraform/internal/manager/uuid_test.go b/assets/terraform/internal/manager/uuid_test.go index db199a4d..7fbce77c 100644 --- a/assets/terraform/internal/manager/uuid_test.go +++ b/assets/terraform/internal/manager/uuid_test.go @@ -2,6 +2,7 @@ package manager_test import ( "context" + "log" "log/slog" . "github.com/onsi/ginkgo/v2" @@ -11,6 +12,7 @@ import ( sdkmanager "github.com/PaloAltoNetworks/terraform-provider-panos/internal/manager" ) +var _ = log.Printf var _ = Expect var _ = slog.Debug @@ -178,8 +180,8 @@ var _ = Describe("Server", func() { Context("initially has some entries", func() { Context("when creating new entries with NonExhaustive type", func() { - Context("and position is set to Last", func() { - It("should create new entries on the bottom of the list", func() { + Context("and position is set to first", func() { + It("should create new entries on the top of the list", func() { entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) @@ -193,11 +195,35 @@ var _ = Describe("Server", func() { clientEntries := client.list() Expect(clientEntries).To(HaveLen(6)) + Expect(mockService.moveGroupEntries).To(Equal(entries)) + + Expect(clientEntries[0]).To(Equal(entries[0])) + Expect(clientEntries[1]).To(Equal(entries[1])) + Expect(clientEntries[2]).To(Equal(entries[2])) + + }) + }) + Context("and position is set to last", func() { + It("should create new entries on the bottom of the list", func() { + entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "5", Value: "E"}, {Name: "6", Value: "F"}} + + processed, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionLast{}) + Expect(err).ToNot(HaveOccurred()) + Expect(processed).To(HaveLen(3)) + + Expect(processed[0]).To(Equal(entries[0])) + Expect(processed[1]).To(Equal(entries[1])) + Expect(processed[2]).To(Equal(entries[2])) + + clientEntries := client.list() + Expect(clientEntries).To(HaveLen(6)) + + Expect(mockService.moveGroupEntries).To(Equal(entries)) + Expect(clientEntries[3]).To(Equal(entries[0])) Expect(clientEntries[4]).To(Equal(entries[1])) Expect(clientEntries[5]).To(Equal(entries[2])) - Expect(mockService.moveGroupEntries).To(Equal(entries)) }) }) Context("and position is set to directly after first element", func() { @@ -249,6 +275,14 @@ var _ = Describe("Server", func() { Expect(mockService.moveGroupEntries).To(Equal(entries)) }) }) + Context("and there is a duplicate entry within a list", func() { + It("should properly raise an error", func() { + entries := []*MockUuidObject{{Name: "4", Value: "D"}, {Name: "4", Value: "D"}} + _, err := manager.CreateMany(ctx, location, entries, sdkmanager.NonExhaustive, movement.PositionFirst{}) + + Expect(err).To(MatchError(sdkmanager.ErrPlanConflict)) + }) + }) }) }) }) From b572b2339e65b230c53414d383f9fdf09c68dfb1 Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Wed, 5 Feb 2025 10:13:07 +0100 Subject: [PATCH 17/19] Update tests to not use map for location --- .../test/resource_nat_policy_test.go | 96 +++++++----- .../test/resource_security_policy_test.go | 141 +++++------------- 2 files changed, 98 insertions(+), 139 deletions(-) diff --git a/assets/terraform/test/resource_nat_policy_test.go b/assets/terraform/test/resource_nat_policy_test.go index a29d61c2..b9e860b5 100644 --- a/assets/terraform/test/resource_nat_policy_test.go +++ b/assets/terraform/test/resource_nat_policy_test.go @@ -40,9 +40,12 @@ type expectServerNatRulesOrder struct { RuleNames []string } -func ExpectServerNatRulesOrder(prefix string, location nat.Location, ruleNames []string) *expectServerNatRulesOrder { +func ExpectServerNatRulesOrder(prefix string, ruleNames []string) *expectServerNatRulesOrder { + location := nat.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + return &expectServerNatRulesOrder{ - Location: location, + Location: *location, Prefix: prefix, RuleNames: ruleNames, } @@ -111,10 +114,13 @@ type expectServerNatRulesCount struct { Count int } -func ExpectServerNatRulesCount(prefix string, location nat.Location, count int) *expectServerNatRulesCount { +func ExpectServerNatRulesCount(prefix string, count int) *expectServerNatRulesCount { + location := nat.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + return &expectServerNatRulesCount{ Prefix: prefix, - Location: location, + Location: *location, Count: count, } } @@ -143,10 +149,23 @@ func (o *expectServerNatRulesCount) CheckState(ctx context.Context, req stateche const natPolicyExtendedResource1Tmpl = ` variable "prefix" { type = string } -variable "location" { type = map } + +resource "panos_template" "template" { + location = { panorama = {} } + + name = format("%s-tmpl", var.prefix) +} + + +resource "panos_device_group" "dg" { + location = { panorama = {} } + + name = format("%s-dg", var.prefix) + templates = [ resource.panos_template.template.name ] +} resource "panos_nat_policy" "policy" { - location = var.location + location = { device_group = { name = resource.panos_device_group.dg.name }} rules = [{ name = format("%s-rule1", var.prefix) @@ -331,19 +350,15 @@ func TestAccNatPolicyExtended(t *testing.T) { nameSuffix := acctest.RandStringFromCharSet(6, acctest.CharSetAlphaNum) prefix := fmt.Sprintf("test-acc-%s", nameSuffix) - device := devicePanorama - sdkLocation, cfgLocation := natPolicyLocationByDeviceType(device, "post-rulebase") - resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: natPolicyExtendedResource1Tmpl, ConfigVariables: map[string]config.Variable{ - "prefix": config.StringVariable(prefix), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ statecheck.ExpectKnownValue( @@ -433,13 +448,12 @@ func TestAccNatPolicyExtended(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: natPolicyExtendedResource2Tmpl, ConfigVariables: map[string]config.Variable{ - "prefix": config.StringVariable(prefix), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ statecheck.ExpectKnownValue( @@ -517,13 +531,12 @@ func TestAccNatPolicyExtended(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: natPolicyExtendedResource3Tmpl, ConfigVariables: map[string]config.Variable{ - "prefix": config.StringVariable(prefix), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ statecheck.ExpectKnownValue( @@ -565,13 +578,12 @@ func TestAccNatPolicyExtended(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: natPolicyExtendedResource4Tmpl, ConfigVariables: map[string]config.Variable{ - "prefix": config.StringVariable(prefix), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ statecheck.ExpectKnownValue( @@ -645,7 +657,7 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { device := devicePanorama - sdkLocation, cfgLocation := natPolicyLocationByDeviceType(device, "pre-rulebase") + sdkLocation, _ := natPolicyLocationByDeviceType(device, "pre-rulebase") stateExpectedRuleName := func(idx int, value string) statecheck.StateCheck { return statecheck.ExpectKnownValue( @@ -670,27 +682,27 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: natPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: natPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: makeNatPolicyConfig(prefix), ConfigVariables: map[string]config.Variable{ "rule_names": config.ListVariable(withPrefix(rulesInitial)...), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigStateChecks: []statecheck.StateCheck{ stateExpectedRuleName(0, "rule-1"), stateExpectedRuleName(1, "rule-2"), stateExpectedRuleName(2, "rule-3"), - ExpectServerNatRulesCount(prefix, sdkLocation, len(rulesInitial)), - ExpectServerNatRulesOrder(prefix, sdkLocation, rulesInitial), + ExpectServerNatRulesCount(prefix, len(rulesInitial)), + ExpectServerNatRulesOrder(prefix, rulesInitial), }, }, { Config: makeNatPolicyConfig(prefix), ConfigVariables: map[string]config.Variable{ "rule_names": config.ListVariable(withPrefix(rulesInitial)...), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{ @@ -702,7 +714,7 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { Config: makeNatPolicyConfig(prefix), ConfigVariables: map[string]config.Variable{ "rule_names": config.ListVariable(withPrefix(rulesReordered)...), - "location": cfgLocation, + "prefix": config.StringVariable(prefix), }, ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{ @@ -715,7 +727,7 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { stateExpectedRuleName(0, "rule-2"), stateExpectedRuleName(1, "rule-1"), stateExpectedRuleName(2, "rule-3"), - ExpectServerNatRulesOrder(prefix, sdkLocation, rulesReordered), + ExpectServerNatRulesOrder(prefix, rulesReordered), }, }, }, @@ -723,11 +735,24 @@ func TestAccPanosNatPolicyOrdering(t *testing.T) { } const configTmpl = ` +variable "prefix" { type = string } variable "rule_names" { type = list(string) } -variable "location" { type = map } + +resource "panos_template" "template" { + location = { panorama = {} } + + name = format("%s-tmpl", var.prefix) +} + +resource "panos_device_group" "dg" { + location = { panorama = {} } + + name = format("%s-dg", var.prefix) + templates = [ resource.panos_template.template.name ] +} resource "panos_nat_policy" "{{ .ResourceName }}" { - location = var.location + location = { device_group = { name = resource.panos_device_group.dg.name }} rules = [ for index, name in var.rule_names: { @@ -830,12 +855,15 @@ func natPolicyPreCheck(prefix string, location nat.Location) { } } -func natPolicyCheckDestroy(prefix string, location nat.Location) func(s *terraform.State) error { +func natPolicyCheckDestroy(prefix string) func(s *terraform.State) error { return func(s *terraform.State) error { service := nat.NewService(sdkClient) ctx := context.TODO() - rules, err := service.List(ctx, location, "get", "", "") + location := nat.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + + rules, err := service.List(ctx, *location, "get", "", "") if err != nil && !sdkerrors.IsObjectNotFound(err) { return err } @@ -849,7 +877,7 @@ func natPolicyCheckDestroy(prefix string, location nat.Location) func(s *terrafo if len(danglingNames) > 0 { err := DanglingObjectsError - delErr := service.Delete(ctx, location, danglingNames...) + delErr := service.Delete(ctx, *location, danglingNames...) if delErr != nil { err = errors.Join(err, delErr) } diff --git a/assets/terraform/test/resource_security_policy_test.go b/assets/terraform/test/resource_security_policy_test.go index a070ffa2..ce27203c 100644 --- a/assets/terraform/test/resource_security_policy_test.go +++ b/assets/terraform/test/resource_security_policy_test.go @@ -26,9 +26,12 @@ type expectServerSecurityRulesOrder struct { RuleNames []string } -func ExpectServerSecurityRulesOrder(prefix string, location security.Location, ruleNames []string) *expectServerSecurityRulesOrder { +func ExpectServerSecurityRulesOrder(prefix string, ruleNames []string) *expectServerSecurityRulesOrder { + location := security.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + return &expectServerSecurityRulesOrder{ - Location: location, + Location: *location, Prefix: prefix, RuleNames: ruleNames, } @@ -97,10 +100,12 @@ type expectServerSecurityRulesCount struct { Count int } -func ExpectServerSecurityRulesCount(prefix string, location security.Location, count int) *expectServerSecurityRulesCount { +func ExpectServerSecurityRulesCount(prefix string, count int) *expectServerSecurityRulesCount { + location := security.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) return &expectServerSecurityRulesCount{ Prefix: prefix, - Location: location, + Location: *location, Count: count, } } @@ -133,13 +138,13 @@ variable "prefix" { type = string } resource "panos_template" "template" { location = { panorama = {} } - name = format("%s-secgroup-tmpl1", var.prefix) + name = format("%s-tmpl", var.prefix) } resource "panos_device_group" "dg" { location = { panorama = {} } - name = format("%s-secgroup-dg1", var.prefix) + name = format("%s-dg", var.prefix) templates = [ resource.panos_template.template.name ] } @@ -399,11 +404,11 @@ func TestAccSecurityPolicyExtended(t *testing.T) { } const securityPolicyOrderingTmpl = ` +variable "prefix" { type = string } variable "rule_names" { type = list(string) } -variable "location" { type = map } resource "panos_security_policy" "policy" { - location = var.location + location = { device_group = { name = format("%s-dg", var.prefix) }} rules = [ for index, name in var.rule_names: { @@ -444,10 +449,6 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { return result } - device := devicePanorama - - sdkLocation, cfgLocation := securityPolicyLocationByDeviceType(device, "pre-rulebase") - stateExpectedRuleName := func(idx int, value string) statecheck.StateCheck { return statecheck.ExpectKnownValue( "panos_security_policy.policy", @@ -467,24 +468,24 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { resource.Test(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) - securityPolicyPreCheck(prefix, sdkLocation) + securityPolicyPreCheck(prefix) }, ProtoV6ProviderFactories: testAccProviders, - CheckDestroy: securityPolicyCheckDestroy(prefix, sdkLocation), + CheckDestroy: securityPolicyCheckDestroy(prefix), Steps: []resource.TestStep{ { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable([]config.Variable{}...), - "location": cfgLocation, }, }, { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable([]config.Variable{}...), - "location": cfgLocation, }, PlanOnly: true, ExpectNonEmptyPlan: false, @@ -492,8 +493,8 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable(withPrefix(rulesInitial)...), - "location": cfgLocation, }, ConfigStateChecks: []statecheck.StateCheck{ stateExpectedRuleName(0, "rule-1"), @@ -501,15 +502,15 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { stateExpectedRuleName(2, "rule-3"), stateExpectedRuleName(3, "rule-4"), stateExpectedRuleName(4, "rule-5"), - ExpectServerSecurityRulesCount(prefix, sdkLocation, len(rulesInitial)), - ExpectServerSecurityRulesOrder(prefix, sdkLocation, rulesInitial), + ExpectServerSecurityRulesCount(prefix, len(rulesInitial)), + ExpectServerSecurityRulesOrder(prefix, rulesInitial), }, }, { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable(withPrefix(rulesInitial)...), - "location": cfgLocation, }, ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{ @@ -520,8 +521,8 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable(withPrefix(rulesReordered)...), - "location": cfgLocation, }, ConfigPlanChecks: resource.ConfigPlanChecks{ PreApply: []plancheck.PlanCheck{ @@ -538,21 +539,21 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { stateExpectedRuleName(2, "rule-3"), stateExpectedRuleName(3, "rule-4"), stateExpectedRuleName(4, "rule-5"), - ExpectServerSecurityRulesOrder(prefix, sdkLocation, rulesReordered), + ExpectServerSecurityRulesOrder(prefix, rulesReordered), }, }, { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable([]config.Variable{}...), - "location": cfgLocation, }, }, { Config: securityPolicyOrderingTmpl, ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), "rule_names": config.ListVariable([]config.Variable{}...), - "location": cfgLocation, }, PlanOnly: true, ExpectNonEmptyPlan: false, @@ -561,39 +562,7 @@ func TestAccSecurityPolicyOrdering(t *testing.T) { }) } -func securityPolicyLocationByDeviceType(typ deviceType, rulebase string) (security.Location, config.Variable) { - var sdkLocation security.Location - var cfgLocation config.Variable - switch typ { - case devicePanorama: - sdkLocation = security.Location{ - Shared: &security.SharedLocation{ - Rulebase: rulebase, - }, - } - cfgLocation = config.ObjectVariable(map[string]config.Variable{ - "shared": config.ObjectVariable(map[string]config.Variable{ - "rulebase": config.StringVariable(rulebase), - }), - }) - case deviceFirewall: - sdkLocation = security.Location{ - Vsys: &security.VsysLocation{ - NgfwDevice: "localhost.localdomain", - Vsys: "vsys1", - }, - } - cfgLocation = config.ObjectVariable(map[string]config.Variable{ - "vsys": config.ObjectVariable(map[string]config.Variable{ - "name": config.StringVariable("vsys1"), - }), - }) - } - - return sdkLocation, cfgLocation -} - -func securityPolicyPreCheck(prefix string, location security.Location) { +func securityPolicyPreCheck(prefix string) { service := security.NewService(sdkClient) ctx := context.TODO() @@ -620,8 +589,11 @@ func securityPolicyPreCheck(prefix string, location security.Location) { }, } + location := security.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + for _, elt := range rules { - _, err := service.Create(ctx, location, &elt) + _, err := service.Create(ctx, *location, &elt) if err != nil { panic(fmt.Sprintf("natPolicyPreCheck failed: %s", err)) } @@ -629,12 +601,15 @@ func securityPolicyPreCheck(prefix string, location security.Location) { } } -func securityPolicyCheckDestroy(prefix string, location security.Location) func(s *terraform.State) error { +func securityPolicyCheckDestroy(prefix string) func(s *terraform.State) error { return func(s *terraform.State) error { service := security.NewService(sdkClient) ctx := context.TODO() - rules, err := service.List(ctx, location, "get", "", "") + location := security.NewDeviceGroupLocation() + location.DeviceGroup.DeviceGroup = fmt.Sprintf("%s-dg", prefix) + + rules, err := service.List(ctx, *location, "get", "", "") if err != nil && !sdkerrors.IsObjectNotFound(err) { return err } @@ -648,7 +623,7 @@ func securityPolicyCheckDestroy(prefix string, location security.Location) func( if len(danglingNames) > 0 { err := DanglingObjectsError - delErr := service.Delete(ctx, location, danglingNames...) + delErr := service.Delete(ctx, *location, danglingNames...) if delErr != nil { err = errors.Join(err, delErr) } @@ -659,47 +634,3 @@ func securityPolicyCheckDestroy(prefix string, location security.Location) func( return nil } } - -func init() { - resource.AddTestSweepers("pango_security_policy", &resource.Sweeper{ - Name: "pango_security_policy", - F: func(typ string) error { - service := security.NewService(sdkClient) - - var deviceTyp deviceType - switch typ { - case "panorama": - deviceTyp = devicePanorama - case "firewall": - deviceTyp = deviceFirewall - default: - panic("invalid device type") - } - - for _, rulebase := range []string{"pre-rulebase", "post-rulebase"} { - location, _ := securityPolicyLocationByDeviceType(deviceTyp, rulebase) - ctx := context.TODO() - objects, err := service.List(ctx, location, "get", "", "") - if err != nil && !sdkerrors.IsObjectNotFound(err) { - return fmt.Errorf("Failed to list Security Rules during sweep: %w", err) - } - - var names []string - for _, elt := range objects { - if strings.HasPrefix(elt.Name, "test-acc") { - names = append(names, elt.Name) - } - } - - if len(names) > 0 { - err = service.Delete(ctx, location, names...) - if err != nil { - return fmt.Errorf("Failed to delete Security Rules during sweep: %w", err) - } - } - } - - return nil - }, - }) -} From 8f0796475ac8a264cf4f9d588cb59d10404d2cca Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Wed, 5 Feb 2025 10:28:17 +0100 Subject: [PATCH 18/19] Add validation of name uniqueness in uuid resources --- .../test/resource_security_policy_test.go | 63 +++++++++++++++++++ pkg/translate/terraform_provider/template.go | 26 +++++++- .../terraform_provider_file.go | 17 ++--- 3 files changed, 98 insertions(+), 8 deletions(-) diff --git a/assets/terraform/test/resource_security_policy_test.go b/assets/terraform/test/resource_security_policy_test.go index ce27203c..844d7854 100644 --- a/assets/terraform/test/resource_security_policy_test.go +++ b/assets/terraform/test/resource_security_policy_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "regexp" "strings" "testing" @@ -132,6 +133,47 @@ func (o *expectServerSecurityRulesCount) CheckState(ctx context.Context, req sta } } +const securityPolicyDuplicatedTmpl = ` +variable "prefix" { type = string } + +resource "panos_template" "template" { + location = { panorama = {} } + + name = format("%s-tmpl", var.prefix) +} + +resource "panos_device_group" "dg" { + location = { panorama = {} } + + name = format("%s-dg", var.prefix) + templates = [ resource.panos_template.template.name ] +} + + +resource "panos_security_policy" "policy" { + location = { device_group = { name = resource.panos_device_group.dg.name }} + + rules = [ + { + name = format("%s-rule", var.prefix) + source_zones = ["any"] + source_addresses = ["any"] + + destination_zones = ["any"] + destination_addresses = ["any"] + }, + { + name = format("%s-rule", var.prefix) + source_zones = ["any"] + source_addresses = ["any"] + + destination_zones = ["any"] + destination_addresses = ["any"] + } + ] +} +` + const securityPolicyExtendedResource1Tmpl = ` variable "prefix" { type = string } @@ -197,6 +239,27 @@ resource "panos_security_policy" "policy" { } ` +func TestAccSecurityPolicyDuplicatedPlan(t *testing.T) { + t.Parallel() + + nameSuffix := acctest.RandStringFromCharSet(6, acctest.CharSetAlphaNum) + prefix := fmt.Sprintf("test-acc-%s", nameSuffix) + + resource.Test(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + ProtoV6ProviderFactories: testAccProviders, + Steps: []resource.TestStep{ + { + Config: securityPolicyDuplicatedTmpl, + ConfigVariables: map[string]config.Variable{ + "prefix": config.StringVariable(prefix), + }, + ExpectError: regexp.MustCompile("List entries must have unique names"), + }, + }, + }) +} + func TestAccSecurityPolicyExtended(t *testing.T) { t.Parallel() diff --git a/pkg/translate/terraform_provider/template.go b/pkg/translate/terraform_provider/template.go index 9db88bfc..ad8db6e6 100644 --- a/pkg/translate/terraform_provider/template.go +++ b/pkg/translate/terraform_provider/template.go @@ -275,13 +275,37 @@ func (r *{{ resourceStructName }}) Metadata(ctx context.Context, req resource.Me func (r *{{ resourceStructName }}) ValidateConfig(ctx context.Context, req resource.ValidateConfigRequest, resp *resource.ValidateConfigResponse) { {{- if HasPosition }} + { var resource {{ resourceStructName }}Model resp.Diagnostics.Append(req.Config.Get(ctx, &resource)...) if resp.Diagnostics.HasError() { return } - resource.Position.ValidateConfig(resp) + } +{{- end }} + +{{- if IsUuid }} + { + var resource {{ resourceStructName }}Model + resp.Diagnostics.Append(req.Config.Get(ctx, &resource)...) + if resp.Diagnostics.HasError() { + return + } + {{ $resourceTFStructName := printf "%s%sObject" resourceStructName ListAttribute.CamelCase }} + entries := make(map[string]struct{}) + var elements []{{ $resourceTFStructName }} + resource.{{ ListAttribute.CamelCase }}.ElementsAs(ctx, &elements, false) + + for _, elt := range elements { + entry := elt.Name.ValueString() + if _, found := entries[entry]; found { + resp.Diagnostics.AddError("Failed to validate resource", "List entries must have unique names") + return + } + entries[entry] = struct{}{} + } + } {{- end }} } diff --git a/pkg/translate/terraform_provider/terraform_provider_file.go b/pkg/translate/terraform_provider/terraform_provider_file.go index a926101e..372c1e98 100644 --- a/pkg/translate/terraform_provider/terraform_provider_file.go +++ b/pkg/translate/terraform_provider/terraform_provider_file.go @@ -127,13 +127,16 @@ func (g *GenerateTerraformProvider) GenerateTerraformResource(resourceTyp proper } funcMap := template.FuncMap{ - "GoSDKSkipped": func() bool { return spec.GoSdkSkip }, - "IsEntry": func() bool { return spec.HasEntryName() && !spec.HasEntryUuid() }, - "HasImports": func() bool { return len(spec.Imports) > 0 }, - "IsCustom": func() bool { return spec.TerraformProviderConfig.ResourceType == properties.TerraformResourceCustom }, - "IsUuid": func() bool { return spec.HasEntryUuid() }, - "IsConfig": func() bool { return !spec.HasEntryName() && !spec.HasEntryUuid() }, - "IsImportable": func() bool { return resourceTyp == properties.ResourceEntry }, + "GoSDKSkipped": func() bool { return spec.GoSdkSkip }, + "IsEntry": func() bool { return spec.HasEntryName() && !spec.HasEntryUuid() }, + "HasImports": func() bool { return len(spec.Imports) > 0 }, + "IsCustom": func() bool { return spec.TerraformProviderConfig.ResourceType == properties.TerraformResourceCustom }, + "IsUuid": func() bool { return spec.HasEntryUuid() }, + "IsConfig": func() bool { return !spec.HasEntryName() && !spec.HasEntryUuid() }, + "IsImportable": func() bool { return resourceTyp == properties.ResourceEntry }, + "ListAttribute": func() *properties.NameVariant { + return properties.NewNameVariant(spec.TerraformProviderConfig.PluralName) + }, "resourceSDKName": func() string { return names.PackageName }, "HasPosition": func() bool { return hasPosition }, "metaName": func() string { return names.MetaName }, From 63bbb0a46876e960c00a6b8aafa7ae90d4cbe56c Mon Sep 17 00:00:00 2001 From: Krzysztof Klimonda Date: Wed, 5 Feb 2025 12:21:19 +0100 Subject: [PATCH 19/19] Enable cpu profiling --- assets/terraform/main.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/assets/terraform/main.go b/assets/terraform/main.go index a1540c1a..368bc3da 100644 --- a/assets/terraform/main.go +++ b/assets/terraform/main.go @@ -4,12 +4,16 @@ import ( "context" "flag" "log" + "os" + "runtime/pprof" "github.com/PaloAltoNetworks/terraform-provider-panos/internal/provider" "github.com/hashicorp/terraform-plugin-framework/providerserver" ) +var _ = pprof.StartCPUProfile + // Run "go generate" to format example terraform files and generate the docs for the registry/website // If you do not have terraform installed, you can remove the formatting command, but its suggested to @@ -35,6 +39,16 @@ func main() { flag.BoolVar(&debug, "debug", false, "set to true to run the provider with support for debuggers like delve") flag.Parse() + cpuprofile := os.Getenv("TF_PANOS_PROFILE") + if cpuprofile != "" { + f, err := os.Create(cpuprofile) + if err != nil { + log.Fatal(err) + } + pprof.StartCPUProfile(f) + defer pprof.StopCPUProfile() + } + opts := providerserver.ServeOpts{ Address: "registry.terraform.io/paloaltonetworks/panos", Debug: debug,