-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy patheval.go
157 lines (143 loc) · 3.07 KB
/
eval.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package eval
import (
"encoding/json"
"fmt"
"io"
"sort"
)
// KeyValue is a generic set of key-value pairs.
//
// Values are expected to be string, float64, or bool, or recursively another
// KeyValue (see Validate). Product additionally interprets []interface{}
// values as lists of alternatives to expand.
type KeyValue map[string]interface{}

// Validate reports an error if any value in kv, at any nesting depth, is not
// a string, float64, bool, or KeyValue.
//
// Keys are checked in sorted order so that, when several values are invalid,
// the same one is reported on every run. A nested offending key is reported
// by its dot-joined path from kv (e.g. "outer.inner"), matching the key names
// Flatten produces.
func (kv KeyValue) Validate() error {
	return kv.validate("")
}

// validate implements Validate; prefix is the dot-terminated path of kv
// inside the top-level KeyValue ("" at the top level).
func (kv KeyValue) validate(prefix string) error {
	for _, pair := range kv.Pairs() {
		switch v := pair.Val.(type) {
		case string, float64, bool:
			// allowed, fine
		case KeyValue:
			if err := v.validate(prefix + pair.Key + "."); err != nil {
				return err
			}
		default:
			return fmt.Errorf("key %v is of type %T and not "+
				"string, float64, or bool", prefix+pair.Key, v)
		}
	}
	return nil
}

// KeyValuePair is a single entry of a KeyValue, as returned by Pairs.
type KeyValuePair struct {
	Key string
	Val interface{}
}

// Flatten returns a new KeyValue in which nested KeyValues, at any depth, are
// replaced by their leaf entries under dot-joined keys: for example
// {"a": {"b": {"c": 1}}} becomes {"a.b.c": 1}. Other values are copied by
// assignment (reference values such as slices are shared with kv).
//
// If a flattened key collides with another entry's key (e.g. a literal "a.b"
// next to {"a": {"b": ...}}), the entry whose top-level key sorts last wins;
// top-level keys are visited in sorted order so the outcome is deterministic.
func (kv KeyValue) Flatten() KeyValue {
	flat := make(KeyValue, len(kv))
	for _, pair := range kv.Pairs() {
		nested, ok := pair.Val.(KeyValue)
		if !ok {
			flat[pair.Key] = pair.Val
			continue
		}
		// Recurse so that arbitrarily deep nesting ends up fully flattened.
		for subKey, subVal := range nested.Flatten() {
			flat[pair.Key+"."+subKey] = subVal
		}
	}
	return flat
}

// Pairs returns the key-value pairs in kv, sorted by key. It returns nil when
// kv is empty.
func (kv KeyValue) Pairs() []KeyValuePair {
	var pairs []KeyValuePair
	for key, val := range kv {
		pairs = append(pairs, KeyValuePair{key, val})
	}
	sort.Slice(pairs, func(i int, j int) bool {
		return pairs[i].Key < pairs[j].Key
	})
	return pairs
}

// Delete returns a deep copy of kv (see Clone) with key removed; kv itself is
// not modified. Deleting a missing key simply returns the copy.
func (kv KeyValue) Delete(key string) KeyValue {
	filtered := kv.Clone()
	delete(filtered, key)
	return filtered
}

// Clone returns a deep copy of kv. Nested KeyValues and []interface{} values
// are copied recursively, so mutating them in the clone does not affect kv.
// All other values are copied by assignment, so other reference types (for
// example a plain map[string]interface{}) remain shared.
func (kv KeyValue) Clone() KeyValue {
	kv2 := make(KeyValue, len(kv))
	for k, v := range kv {
		kv2[k] = cloneValue(v)
	}
	return kv2
}

// cloneValue returns a copy of v that shares no KeyValue or []interface{}
// storage with v. A nil []interface{} is returned as-is.
func cloneValue(v interface{}) interface{} {
	switch v := v.(type) {
	case KeyValue:
		return v.Clone()
	case []interface{}:
		if v == nil {
			return v
		}
		out := make([]interface{}, len(v))
		for i, elem := range v {
			out[i] = cloneValue(elem)
		}
		return out
	default:
		return v
	}
}

// Extend adds all key-value pairs from kv2 to kv, overwriting keys present in
// both. Values are deep-copied (see Clone), so later changes to kv2's nested
// values do not leak into kv.
//
// It modifies kv in place and returns kv (for chaining). If kv is nil and kv2
// is non-empty, a new KeyValue is allocated and returned instead (assigning
// into a nil map would panic), so in that case callers must use the result.
func (kv KeyValue) Extend(kv2 KeyValue) KeyValue {
	if kv == nil && len(kv2) > 0 {
		kv = make(KeyValue, len(kv2))
	}
	for k, v := range kv2 {
		kv[k] = cloneValue(v)
	}
	return kv
}

// Product takes the cross product of any fields that are slices.
//
// Every []interface{} value is treated as a list of alternatives; every other
// value is fixed. The result holds one KeyValue per combination, containing
// one element from each list plus all fixed values. For example,
// {"a": [1, 2], "b": "x"} yields [{"a": 1, "b": "x"}, {"a": 2, "b": "x"}].
// Only top-level slices are expanded.
//
// Keys are processed in sorted order, so the output order is deterministic:
// the alphabetically first list key varies slowest. Every variant receives its
// own copy of each value (see Clone), so variants can be mutated
// independently of each other and of kv. If any list is empty the product is
// empty and nil is returned. If kv has no lists, the result is a single copy
// of kv.
func (kv KeyValue) Product() []KeyValue {
	variants := []KeyValue{{}}
	for _, pair := range kv.Pairs() {
		choices, isList := pair.Val.([]interface{})
		if !isList {
			for _, variant := range variants {
				// Clone per variant so nested KeyValues are never shared.
				variant[pair.Key] = cloneValue(pair.Val)
			}
			continue
		}
		var expanded []KeyValue
		for _, variant := range variants {
			for _, choice := range choices {
				next := variant.Clone()
				next[pair.Key] = cloneValue(choice)
				expanded = append(expanded, next)
			}
		}
		variants = expanded
	}
	return variants
}
// Observation is one recorded result: the measured Values together with the
// Config that produced them. It serializes to a JSON object with the keys
// "values" and "config".
type Observation struct {
	Values KeyValue `json:"values"`
	Config KeyValue `json:"config"`
}

// Write marshals o to JSON and hands the encoded bytes to w in a single Write
// call. No trailing newline is added, and nothing is written if marshaling
// fails; the first error encountered is returned.
func (o Observation) Write(w io.Writer) error {
	encoded, err := json.Marshal(o)
	if err == nil {
		_, err = w.Write(encoded)
	}
	return err
}
// ReadObservation gets the next observation in r. When no further JSON value
// remains in r, the returned error is io.EOF.
//
// A json.Decoder buffers its input and may read past the end of the value it
// decodes. Because a fresh decoder is created on every call, any such
// read-ahead would be discarded and the following call would start in the
// middle of the stream. To keep repeated calls on the same reader correct, r
// is fed to the decoder one byte at a time, so nothing past the end of a
// decoded JSON object is consumed. Readers implementing io.ByteReader (e.g.
// *bufio.Reader, *bytes.Reader, *strings.Reader) are read via ReadByte. Wrap
// unbuffered sources such as files in a bufio.Reader to avoid one Read call
// per byte.
func ReadObservation(r io.Reader) (o Observation, err error) {
	d := json.NewDecoder(oneByteReader{r})
	err = d.Decode(&o)
	return
}

// oneByteReader adapts an io.Reader so that each Read returns at most one
// byte, preventing a buffering consumer (here, json.Decoder) from taking more
// input than it needs.
type oneByteReader struct {
	r io.Reader
}

// Read fills at most p[0], preferring io.ByteReader when r implements it.
func (b oneByteReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if br, ok := b.r.(io.ByteReader); ok {
		c, err := br.ReadByte()
		if err != nil {
			return 0, err
		}
		p[0] = c
		return 1, nil
	}
	return b.r.Read(p[:1])
}
// WriteObservation writes o to w as one line of JSON: the same encoding
// Observation.Write produces, followed by '\n'. Before encoding, o.Config is
// replaced with o.Config.Flatten(). Only the local copy of o is modified; the
// caller's Observation is untouched.
//
// The object and its newline are emitted in a single Write call rather than
// two. This means a failure can no longer leave a fully written object
// without its line terminator. It also reduces the chance of another writer
// sharing w interleaving output between them.
//
// WriteObservation panics if encoding or writing fails.
func WriteObservation(w io.Writer, o Observation) {
	o.Config = o.Config.Flatten()
	// Encoder.Encode marshals exactly like json.Marshal (HTML escaping on),
	// appends '\n', and issues a single Write to w.
	if err := json.NewEncoder(w).Encode(o); err != nil {
		panic(fmt.Errorf("could not write output: %w", err))
	}
}