Skip to content

Commit

Permalink
Merge pull request jax-ml#2999 from gnecula/jax_outfeed_undo
Browse files Browse the repository at this point in the history
Undo the id_print/id_tap feature (PR jax-ml#2791)
  • Loading branch information
gnecula authored May 7, 2020
2 parents 0a7974e + 769d703 commit d679ccd
Show file tree
Hide file tree
Showing 11 changed files with 12 additions and 1,963 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ before_install:
- conda update -q conda
install:
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip absl-py opt_einsum numpy scipy pytest-xdist pytest-benchmark mypy=0.770
- pip install msgpack
- if [ "$JAX_ONLY_CHECK_TYPES" = true ]; then
pip install pytype ;
fi
Expand Down
16 changes: 0 additions & 16 deletions docs/jax.experimental.host_callback.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/jax.experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ jax.experimental package
.. toctree::
:maxdepth: 1

jax.experimental.host_callback
jax.experimental.loops
jax.experimental.optimizers
jax.experimental.optix
Expand Down
7 changes: 0 additions & 7 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,3 @@ pytype_library(
srcs_version = "PY3",
deps = [":jax"],
)

pytype_library(
name = "experimental_host_callback",
srcs = ["experimental/host_callback.py"],
srcs_version = "PY3",
deps = [":jax"],
)
1 change: 0 additions & 1 deletion jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ def computation_maker(*args, **kwargs):
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals,
instantiate=instantiate_const_outputs,
stage_out=True)
jaxpr, _ = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
xla_consts = map(partial(xb.constant, c), consts)
Expand Down
7 changes: 1 addition & 6 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,12 +932,7 @@ def strip_weak_type(self):
_oct = partialmethod(_forward_to_value, oct)


class AbstractToken(AbstractValue):
def join(self, other):
if isinstance(other, AbstractToken):
return self
else:
assert False, f"Cannot join {self} with {other}"
class AbstractToken(AbstractValue): pass

abstract_token = AbstractToken()

Expand Down
Loading

0 comments on commit d679ccd

Please sign in to comment.