forked from giuse/DNE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathobservation_compressor.rb
132 lines (115 loc) · 4.88 KB
/
observation_compressor.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
require 'forwardable'
require 'machine_learning_workbench'
module DNE
  # Wraps a WB (machine_learning_workbench) compressor for usage on
  # observations in a UL-ERL + PyCall context.
  class ObservationCompressor
    extend Forwardable
    def_delegators :@compr, :ncentrs, :centrs, :ntrains, :ntrains_skip, :encoding, :code_size
    attr_reader :downsample, :downsampled_size, :compr, :train_set, :obs_range

    # @param type [Symbol, String] name of a compressor class under `WB::Compressor`
    # @param orig_size [Array<Integer>] original observation dimensions
    # @param obs_range [Array] value range of incoming observations (e.g. [0, 255])
    # @param downsample [Array<Integer>, nil] per-dimension integer reduction
    #   factors (each >= 1); `nil` means no downsampling
    # @param compr_args [Hash] forwarded to the compressor constructor
    # @raise [ArgumentError] if a downsample factor is < 1 (upscaling) or if
    #   `type` is not a known compressor class
    def initialize type:, orig_size:, obs_range:, downsample: nil, **compr_args
      @obs_range = obs_range
      # BUGFIX: the advertised default `downsample: nil` used to crash on
      # `nil.any?`; treat nil as "no downsampling" (all factors 1)
      downsample ||= [1] * orig_size.size
      raise ArgumentError, "Only downward scaling" if downsample.any? { |v| v < 1 }
      @downsample = downsample
      @downsampled_size = orig_size.zip(downsample).map { |s, r| s / r }
      centr_size = downsampled_size.reduce(:*)
      compr_class = begin
        WB::Compressor.const_get(type)
      rescue NameError
        raise ArgumentError, "Unrecognized compressor type `#{type}`"
      end
      @compr = compr_class.new(**compr_args.merge({dims: centr_size}))
      @train_set = []
    end

    # Reset the centroids to something else than random noise
    # TODO: refactor
    # @param img [NImage, NArray] base image for centroid initialization
    # @param proport [Numeric, nil] proportion forwarded to the compressor
    def reset_centrs img, proport: nil
      img = normalize img if img.kind_of? NImage
      compr.init_centrs base: img, proport: proport
    end

    # Normalize an observation into a form preferable for the WB compressor
    # @param observation [NImage] an observation coming from the environment
    #   through `AtariWrapper` (already resampled and converted to NArray)
    # @return [NArray] the normalized observation ready for processing
    def normalize observation
      WB::Tools::Normalization.feature_scaling observation,
        from: obs_range, to: compr.vrange
      # from: [observation.class::MIN, observation.class::MAX], to: compr.vrange
    end

    # Encodes an observation using the underlying compressor
    # @param obs [NArray, NImage] observation (normalized first if an NImage)
    # @return [NArray] encoded observation
    def encode obs
      obs = normalize(obs) if obs.kind_of? NImage
      compr.encode obs
    end

    # Compute the novelty of an observation as aggregated reconstruction error
    # @param obs [NArray] observation
    # @param code [Array] code previously obtained from `#encode`
    # @return [Float] novelty score (mean absolute reconstruction error)
    def novelty obs, code
      compr.reconstr_error(obs, code: code).abs.mean
    end

    # Train the compressor on the observations collected so far, then
    # empties the training set.
    def train
      # TODO: I can save the most similar centroid and corresponding simil in the
      # training set, then the training should be just some sums! Super fast!
      # NOTE: if I go back to using centroid training at all...
      # NOTE: careful if you're using normalized centroids from here on
      compr.train(train_set.map(&method(:normalize)))
      @train_set = []
    end

    # TODO: move most of this to VQ?
    # Show a centroid using ImageMagick
    # @param idx [Integer] centroid index
    # @param disp_size [Array<Integer>] display size in pixels
    def show_centroid idx, disp_size: [300,300]
      WB::Tools::Imaging.display centrs[idx, true], shape: downsampled_size.reverse, disp_size: disp_size
    end

    # Show centroids using ImageMagick
    # @param to_disp [Integer, Enumerable] how many (or which) centroids to show
    # @param disp_size [Array<Integer>] display size in pixels
    # @param wait [Boolean] block on stdin before killing the display forks
    def show_centroids to_disp=ncentrs, disp_size: [300, 300], wait: true
      to_disp = to_disp.times unless to_disp.kind_of? Enumerable
      to_disp.each &method(:show_centroid)
      puts "#{to_disp.size}/#{centrs.shape.first} centroids displayed"
      if wait
        puts "Hit return to close them"
        gets
      end
      nil
    ensure
      WB::Tools::Execution.kill_forks
    end

    # Save centroids to files (`centr_<idx>.pdf`) using ImageMagick
    # @param to_save [Integer, Enumerable] how many (or which) centroids to save
    # @param disp_size [Array<Integer>, nil] target size; nil skips resizing
    def save_centroids to_save=ncentrs, disp_size: [700, 800]
      require 'rmagick'
      to_save = to_save.times unless to_save.kind_of? Enumerable
      to_save.each do |idx|
        img = WB::Tools::Imaging.narr_to_img centrs[idx, true], shape: downsampled_size.reverse
        img.resize!(*disp_size, Magick::TriangleFilter,0.51) if disp_size
        img.write "centr_#{idx}.pdf"
      end
      puts "#{to_save.size}/#{centrs.shape.first} centroids saved"
    end

    # Returns a hash of values to maintain from parallel execution
    # See `AtariUlerlExperiment#gen_fit_fn`
    # @return [Hash] `:tset`, `:utility`, `:ncodes` collected in this process
    def parall_info
      { tset: train_set,
        utility: compr.utility,
        ncodes: compr.ncodes
      }
    end

    # Loads data from a hash generated by forks using `parall_info`
    # See `#parall_info` and `AtariUlerlExperiment#gen_fit_fn`
    # @param tset [Array] observations to append to the training set
    # @param utility [Numeric, Array] fork-local utility to merge
    # @param ncodes [Integer] number of codes produced by the fork
    # @raise [ArgumentError] on an unrecognized `compr.encoding_type`
    def add_from_parall_info tset:, utility:, ncodes:
      # NOTE: here tset contains `NImage`s, will convert into `NArray`s in `#train`
      @train_set += tset # this already works regardless of the size of tset
      compr.utility = case compr.encoding_type
        when :ensemble, :norm_ensemble, :sparse_coding # cumulative moving average
          ((compr.ncodes * compr.utility) + (ncodes * utility)) / (compr.ncodes + ncodes)
        when :most_similar_ary # counts occurrencies in array
          compr.utility + utility
        when :most_similar # only counts occurrencies
          # BUGFIX: was `compr.utility[util] += 1` — `util` is undefined
          # (NameError). Also return the counter container itself, not the
          # Integer result of `+=`, so the assignment above does not clobber
          # it — presumably `utility` here is an index into the counts;
          # TODO confirm against WB compressor semantics
          compr.utility[utility] += 1
          compr.utility
        # BUGFIX: missing comma made this a call to nonexistent `ArgumentError(...)`
        else raise ArgumentError, "how did you even get here?"
      end
      compr.ncodes += ncodes
    end
  end
end