diff --git a/econml/_cate_estimator.py b/econml/_cate_estimator.py index 3c453b291..be559ad8e 100644 --- a/econml/_cate_estimator.py +++ b/econml/_cate_estimator.py @@ -591,7 +591,9 @@ def effect(self, X=None, *, T0, T1): Note that when Y is a vector rather than a 2-dimensional array, the corresponding singleton dimension will be collapsed (so this method will return a vector) """ - X, T0, T1 = self._expand_treatments(X, T0, T1) + X, T1 = self._expand_treatments(X, T1) + is_default = ndim(T0) == 0 and T0 == 0 + _, T0 = self._expand_treatments(None, T0, suppress_warn=is_default) # TODO: what if input is sparse? - there's no equivalent to einsum, # but tensordot can't be applied to this problem because we don't sum over m eff = self.const_marginal_effect(X) @@ -599,7 +601,6 @@ def effect(self, X=None, *, T0, T1): # of rows of T was not taken into account if X is None: eff = np.repeat(eff, shape(T0)[0], axis=0) - m = shape(eff)[0] dT = T1 - T0 einsum_str = 'myt,mt->my' if ndim(dT) == 1: @@ -847,12 +848,12 @@ def _postfit(self, Y, T, *args, **kwargs): if self.transformer: self._set_transformed_treatment_names() - def _expand_treatments(self, X=None, *Ts, transform=True): + def _expand_treatments(self, X=None, *Ts, transform=True, suppress_warn=False): X, *Ts = check_input_arrays(X, *Ts) n_rows = 1 if X is None else shape(X)[0] outTs = [] for T in Ts: - if (ndim(T) == 0) and self._d_t_in and self._d_t_in[0] > 1: + if (ndim(T) == 0) and self._d_t_in and self._d_t_in[0] > 1 and not suppress_warn: warn("A scalar was specified but there are multiple treatments; " "the same value will be used for each treatment. Consider specifying" "all treatments, or using the const_marginal_effect method.")