-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsmooth.py
More file actions
50 lines (44 loc) · 1.97 KB
/
smooth.py
File metadata and controls
50 lines (44 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import functools
import hashlib
import inspect
import os

import numpy as np

from learn import train
def _framed(parts):
    """Concatenate byte strings with length prefixes so the result is unambiguous."""
    return b''.join(str(len(part)).encode('ascii') + b':' + part for part in parts)


def _const_fingerprint(const):
    """Return deterministic bytes for one entry of a code object's ``co_consts``."""
    if inspect.iscode(const):
        # Nested function, lambda, class body or comprehension.
        body = _code_fingerprint(const)
    elif isinstance(const, tuple):
        body = _framed([_const_fingerprint(item) for item in const])
    elif isinstance(const, frozenset):
        # Frozenset iteration order is not stable across runs (string hashing
        # is randomized), so sort the members' fingerprints.
        body = _framed(sorted(_const_fingerprint(item) for item in const))
    else:
        body = repr(const).encode('utf-8')
    # Tag each constant with its type name.
    return _framed([type(const).__name__.encode('ascii'), body])


def _code_fingerprint(code):
    """Return deterministic bytes describing what a code object does.

    ``co_code`` alone is not enough: it only holds indices into ``co_consts``
    and ``co_names``, so ``x + 1`` and ``x + 2`` (or ``sin(x)`` and
    ``cos(x)``) compile to the same raw bytecode.
    """
    parts = [
        code.co_code,
        str(code.co_argcount).encode('ascii'),
        _framed([name.encode('utf-8') for name in code.co_names]),
        _framed([name.encode('utf-8') for name in code.co_varnames]),
    ]
    parts.extend(_const_fingerprint(const) for const in code.co_consts)
    return _framed(parts)


def sha224_hex(func):
    """Return the SHA-224 hex digest of ``func``'s compiled code.

    The digest covers the function's own bytecode together with the constants,
    global/attribute names and local names it refers to (recursively for
    nested code objects), so a change to the body yields a new digest.
    Line numbers and file names are not included. Decorator arguments are not
    part of the digest. Bytecode differs between Python versions, so digests
    are only comparable within one interpreter version.

    Args:
        func: a Python function (anything with a ``__code__`` attribute).

    Returns:
        A 56-character lowercase hexadecimal string.
    """
    # Hash func.__code__ directly instead of re-compiling inspect.getsource():
    # a module compiled from the source only holds the `def` statement in its
    # co_code (the body is a nested constant), and re-compiling fails with
    # IndentationError for nested functions and methods.
    return hashlib.sha224(_code_fingerprint(func.__code__)).hexdigest()
def smoothen(output_size, cache_dir='cached-nets', hash_func=sha224_hex, **ranges):
    """Decorator that replaces a function with a "smoothed" neural network.

    The decorated function is used as the training label for a network built
    by ``train.Trainer``, whose inputs are drawn from the declared ranges.
    Training happens when the decorator is applied and can take some time,
    unless a cached network already exists at
    ``<cache_dir>/<func name>_<hash_func(func)>``. All of the function's
    inputs and outputs should be numeric.

    Args:
        output_size: number of outputs of the function.
        cache_dir: directory in which to look for and save trained networks;
            created if it does not exist.
        hash_func: callable mapping the function to a string. Together with
            the function's name it identifies the cached network.
        **ranges: one infinite iterator per parameter of the function, keyed
            by parameter name, that generates values in that parameter's
            range (see utils.iterators for examples).

    Returns:
        A decorator. The decorated callable accepts the function's arguments
        positionally or by name and returns ``trainer.forward_pass`` applied
        to a float32 array of shape ``(1, number_of_parameters)``.

    Raises:
        KeyError: when decorating, if a parameter has no entry in ``ranges``.
        TypeError: when calling the decorated function with missing, extra or
            unknown arguments.
    """
    def smoothed(func):
        # Checkpoint path is <cache_dir>/<name>_<hash>, so a change that alters
        # hash_func(func) selects a fresh checkpoint instead of a stale one.
        chkpt = os.path.join(cache_dir, func.__name__ + '_' + hash_func(func))
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)

        # inspect.getargspec was removed in Python 3.11; prefer
        # getfullargspec and fall back to getargspec on Python 2.
        get_spec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        arg_names = list(get_spec(func).args)
        # Will raise KeyError if there's an arg with an unspecified range.
        # This is intended.
        feature_iters = [ranges[name] for name in arg_names]

        trainer = train.Trainer(
            feature_iters=feature_iters,
            output_size=output_size,
            label=func,
            chkpt=chkpt)
        # Train only when no checkpoint file exists yet.
        # NOTE(review): presumably Trainer restores from `chkpt` by itself when
        # the file is present -- confirm in learn/train.py.
        if not os.path.isfile(chkpt):
            trainer.train()

        @functools.wraps(func)
        def network(*args, **kwargs):
            # Arrange inputs in the order of feature_iters, i.e. the order of
            # the function's parameters that the network was trained with.
            if len(args) > len(arg_names):
                raise TypeError('%s() takes %d arguments (%d given)'
                                % (func.__name__, len(arg_names), len(args)))
            values = list(args)
            for name in arg_names[len(args):]:
                if name not in kwargs:
                    raise TypeError('%s() missing argument %r'
                                    % (func.__name__, name))
                values.append(kwargs.pop(name))
            if kwargs:
                raise TypeError('%s() got unexpected or duplicate keyword '
                                'argument(s): %s'
                                % (func.__name__, ', '.join(sorted(kwargs))))
            return trainer.forward_pass(np.array([values], dtype='float32'))

        return network
    return smoothed