diff --git a/keras/backend/mxnet_backend.py b/keras/backend/mxnet_backend.py index 206121bd7533..fb0910cfa3f5 100644 --- a/keras/backend/mxnet_backend.py +++ b/keras/backend/mxnet_backend.py @@ -203,7 +203,11 @@ def __init__(self, symbol, name=None, neighbor=None, is_var=False): def bind(self, data): - self.tensor = data + if not hasattr(self, 'tensor'): + self.tensor = data + else: + self.tensor[:] = data + if self.name in self._bind_values: assert self._bind_values[self.name].shape == data.shape, \ "Redefinition of variable %s" % self.name @@ -884,7 +888,8 @@ def random_uniform_variable(shape, low, high, dtype=None, value = mx.random.uniform(low=low, high=high, dtype='float32', shape=shape) if dtype != np.float32: value = mx.nd.Cast(value, dtype=dtype) - name = _autogen_name("uniform") + if name is None: + name = _autogen_name("uniform") ret = KerasVariable(name, value.shape, value.dtype) ret.bind(value) return ret @@ -923,7 +928,8 @@ def random_normal_variable(shape, mean, scale, dtype=None, value = mx.random.normal(loc=mean, scale=scale, dtype='float32', shape=shape) if dtype != np.float32: value = mx.nd.Cast(value, dtype=dtype) - name = _autogen_name("normal") + if name is None: + name = _autogen_name("normal") ret = KerasVariable(name, value.shape, value.dtype) ret.bind(value) return ret @@ -2956,7 +2962,8 @@ def random_normal(shape, mean=0.0, std=1.0, dtype=None, seed=None): if dtype is None: dtype = floatx() dtype = np.dtype(dtype) - name = _autogen_name('normal') + if name is None: + name = _autogen_name('normal') sym = mx.sym.normal(loc=mean, scale=std, shape=shape, dtype='float32', name=name) if dtype != np.float32: sym = mx.sym.Cast(data=sym, dtype=dtype) @@ -2983,7 +2990,8 @@ def random_uniform(shape, low=0.0, high=1.0, dtype=None, seed=None): if dtype is None: dtype = floatx() dtype = np.dtype(dtype) - name = _autogen_name('uniform') + if name is None: + name = _autogen_name('uniform') sym = mx.sym.uniform(low=low, high=high, shape=shape, dtype='float32', name=name) if dtype != np.float32: sym = mx.sym.Cast(data=sym, dtype=dtype)