Skip to content

use position in vardata to mark internal hooks #4549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 6, 2025
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
{% import 'web/pages/utils.js.jinja2' as utils %}

{% set all_hooks = component._get_all_hooks().items() %}

export function {{tag_name}} () {
{% for hook in component._get_all_hooks_internal() %}
{% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.INTERNAL %}
{{ hook }}
{% endfor %}

{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
{% for hook, data in all_hooks if not data or (not data.position or data.position == const.hook_position.PRE_TRIGGER) %}
{{ hook }}
{% endfor %}

{% for hook in memo_trigger_hooks %}
{{ hook }}
{% endfor %}

{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
{% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.POST_TRIGGER %}
{{ hook }}
{% endfor %}

Expand Down
5 changes: 3 additions & 2 deletions reflex/components/base/bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData


class Bare(Component):
Expand All @@ -32,7 +33,7 @@ def create(cls, contents: Any) -> Component:
contents = str(contents) if contents is not None else ""
return cls(contents=contents) # type: ignore

def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Include the hooks for the component.

Returns:
Expand All @@ -43,7 +44,7 @@ def _get_all_hooks_internal(self) -> dict[str, None]:
hooks |= self.contents._var_value._get_all_hooks_internal()
return hooks

def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Include the hooks for the component.

Returns:
Expand Down
51 changes: 33 additions & 18 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ def render(self) -> dict:
"""

@abstractmethod
def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.

Returns:
The code that should appear just before user-defined hooks.
"""

@abstractmethod
def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component.

Returns:
Expand Down Expand Up @@ -1338,7 +1338,7 @@ def _get_hooks_imports(self) -> ParsedImportDict:
"""
_imports = {}

if self._get_ref_hook():
if self._get_ref_hook() is not None:
# Handle hooks needed for attaching react refs to DOM nodes.
_imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
_imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
Expand Down Expand Up @@ -1454,19 +1454,20 @@ def _get_mount_lifecycle_hook(self) -> str | None:
}}
}}, []);"""

def _get_ref_hook(self) -> str | None:
def _get_ref_hook(self) -> Var | None:
"""Generate the ref hook for the component.

Returns:
The useRef hook for managing refs.
"""
ref = self.get_ref()
if ref is not None:
return (
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};"
return Var(
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};",
_var_data=VarData(position=Hooks.HookPosition.INTERNAL),
)

def _get_vars_hooks(self) -> dict[str, None]:
def _get_vars_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by vars referenced in this component.

Returns:
Expand All @@ -1479,27 +1480,38 @@ def _get_vars_hooks(self) -> dict[str, None]:
vars_hooks.update(
var_data.hooks
if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks}
else {
k: VarData(position=Hooks.HookPosition.INTERNAL)
for k in var_data.hooks
}
)
return vars_hooks

def _get_events_hooks(self) -> dict[str, None]:
def _get_events_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by events referenced in this component.

Returns:
The hooks for the events.
"""
return {Hooks.EVENTS: None} if self.event_triggers else {}
return (
{Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.event_triggers
else {}
)

def _get_special_hooks(self) -> dict[str, None]:
def _get_special_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by special actions referenced in this component.

Returns:
The hooks for special actions.
"""
return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
return (
{Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.autofocus
else {}
)

def _get_hooks_internal(self) -> dict[str, None]:
def _get_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component managed by the framework.

Downstream components should NOT override this method to avoid breaking
Expand All @@ -1510,7 +1522,7 @@ def _get_hooks_internal(self) -> dict[str, None]:
"""
return {
**{
hook: None
str(hook): VarData(position=Hooks.HookPosition.INTERNAL)
for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
if hook is not None
},
Expand Down Expand Up @@ -1559,7 +1571,7 @@ def _get_hooks(self) -> str | None:
"""
return

def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.

Returns:
Expand All @@ -1574,14 +1586,17 @@ def _get_all_hooks_internal(self) -> dict[str, None]:

return code

def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component and its children.

Returns:
The code that should appear just before returning the rendered component.
"""
code = {}

# Add the internal hooks for this component.
code.update(self._get_all_hooks_internal())

# Add the hook code for this component.
hooks = self._get_hooks()
if hooks is not None:
Expand Down Expand Up @@ -2277,15 +2292,15 @@ def _get_memoized_event_triggers(
)
return trigger_memo

def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.

Returns:
The code that should appear just before user-defined hooks.
"""
return {}

def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component.

Returns:
Expand Down
1 change: 1 addition & 0 deletions reflex/constants/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class Hooks(SimpleNamespace):
class HookPosition(enum.Enum):
"""The position of the hook in the component."""

INTERNAL = "internal"
PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"

Expand Down
2 changes: 1 addition & 1 deletion reflex/experimental/client_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create(
else:
default_var = default
setter_name = f"set{var_name.capitalize()}"
hooks = {
hooks: dict[str, VarData | None] = {
f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
}
imports = {
Expand Down
6 changes: 4 additions & 2 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
hooks: dict[str, VarData | None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
Expand Down Expand Up @@ -194,7 +194,9 @@ def merge(*all: VarData | None) -> VarData | None:
(var_data.state for var_data in all_var_datas if var_data.state), ""
)

hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks}
hooks: dict[str, VarData | None] = {
hook: None for var_data in all_var_datas for hook in var_data.hooks
}

_imports = imports.merge_imports(
*(var_data.imports for var_data in all_var_datas)
Expand Down
Loading