Skip to content

Commit

Permalink
use position in vardata to mark internal hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Lendemor committed Dec 17, 2024
1 parent d7956c1 commit 95222e4
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
{% import 'web/pages/utils.js.jinja2' as utils %}

{% set all_hooks = component._get_all_hooks().items() %}

export function {{tag_name}} () {
{% for hook in component._get_all_hooks_internal() %}
{% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.INTERNAL %}
{{ hook }}
{% endfor %}

{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
{% for hook, data in all_hooks if not data or (not data.position or data.position == const.hook_position.PRE_TRIGGER) %}
{{ hook }}
{% endfor %}

{% for hook in memo_trigger_hooks %}
{{ hook }}
{% endfor %}

{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
{% for hook, data in all_hooks if data and data.position and data.position == const.hook_position.POST_TRIGGER %}
{{ hook }}
{% endfor %}

Expand Down
5 changes: 3 additions & 2 deletions reflex/components/base/bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData


class Bare(Component):
Expand All @@ -32,7 +33,7 @@ def create(cls, contents: Any) -> Component:
contents = str(contents) if contents is not None else ""
return cls(contents=contents) # type: ignore

def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Include the hooks for the component.
Returns:
Expand All @@ -43,7 +44,7 @@ def _get_all_hooks_internal(self) -> dict[str, None]:
hooks |= self.contents._var_value._get_all_hooks_internal()
return hooks

def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Include the hooks for the component.
Returns:
Expand Down
51 changes: 33 additions & 18 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ def render(self) -> dict:
"""

@abstractmethod
def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
The code that should appear just before user-defined hooks.
"""

@abstractmethod
def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component.
Returns:
Expand Down Expand Up @@ -1338,7 +1338,7 @@ def _get_hooks_imports(self) -> ParsedImportDict:
"""
_imports = {}

if self._get_ref_hook():
if self._get_ref_hook() is not None:
# Handle hooks needed for attaching react refs to DOM nodes.
_imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
_imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
Expand Down Expand Up @@ -1454,19 +1454,20 @@ def _get_mount_lifecycle_hook(self) -> str | None:
}}
}}, []);"""

def _get_ref_hook(self) -> str | None:
def _get_ref_hook(self) -> Var | None:
"""Generate the ref hook for the component.
Returns:
The useRef hook for managing refs.
"""
ref = self.get_ref()
if ref is not None:
return (
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};"
return Var(
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};",
_var_data=VarData(position=Hooks.HookPosition.INTERNAL),
)

def _get_vars_hooks(self) -> dict[str, None]:
def _get_vars_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by vars referenced in this component.
Returns:
Expand All @@ -1479,27 +1480,38 @@ def _get_vars_hooks(self) -> dict[str, None]:
vars_hooks.update(
var_data.hooks
if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks}
else {
k: VarData(position=Hooks.HookPosition.INTERNAL)
for k in var_data.hooks
}
)
return vars_hooks

def _get_events_hooks(self) -> dict[str, None]:
def _get_events_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by events referenced in this component.
Returns:
The hooks for the events.
"""
return {Hooks.EVENTS: None} if self.event_triggers else {}
return (
{Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.event_triggers
else {}
)

def _get_special_hooks(self) -> dict[str, None]:
def _get_special_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by special actions referenced in this component.
Returns:
The hooks for special actions.
"""
return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
return (
{Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.autofocus
else {}
)

def _get_hooks_internal(self) -> dict[str, None]:
def _get_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component managed by the framework.
Downstream components should NOT override this method to avoid breaking
Expand All @@ -1510,7 +1522,7 @@ def _get_hooks_internal(self) -> dict[str, None]:
"""
return {
**{
hook: None
str(hook): VarData(position=Hooks.HookPosition.INTERNAL)
for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
if hook is not None
},
Expand Down Expand Up @@ -1559,7 +1571,7 @@ def _get_hooks(self) -> str | None:
"""
return

def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
Expand All @@ -1574,14 +1586,17 @@ def _get_all_hooks_internal(self) -> dict[str, None]:

return code

def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component and its children.
Returns:
The code that should appear just before returning the rendered component.
"""
code = {}

# Add the internal hooks for this component.
code.update(self._get_all_hooks_internal())

# Add the hook code for this component.
hooks = self._get_hooks()
if hooks is not None:
Expand Down Expand Up @@ -2277,15 +2292,15 @@ def _get_memoized_event_triggers(
)
return trigger_memo

def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
The code that should appear just before user-defined hooks.
"""
return {}

def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component.
Returns:
Expand Down
1 change: 1 addition & 0 deletions reflex/constants/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class Hooks(SimpleNamespace):
class HookPosition(enum.Enum):
"""The position of the hook in the component."""

INTERNAL = "internal"
PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"

Expand Down
2 changes: 1 addition & 1 deletion reflex/experimental/client_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create(
else:
default_var = default
setter_name = f"set{var_name.capitalize()}"
hooks = {
hooks: dict[str, VarData | None] = {
f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
}
imports = {
Expand Down
6 changes: 4 additions & 2 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
hooks: dict[str, VarData | None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
Expand Down Expand Up @@ -194,7 +194,9 @@ def merge(*all: VarData | None) -> VarData | None:
(var_data.state for var_data in all_var_datas if var_data.state), ""
)

hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks}
hooks: dict[str, VarData | None] = {
hook: None for var_data in all_var_datas for hook in var_data.hooks
}

_imports = imports.merge_imports(
*(var_data.imports for var_data in all_var_datas)
Expand Down

0 comments on commit 95222e4

Please sign in to comment.