diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d16c98fa..0eed0fc7 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -11,7 +11,7 @@ permissions: pages: write env: - MODO_VERSION: v0.11.10 + MODO_VERSION: v0.11.12 HUGO_VERSION: 0.148.2 jobs: diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..a5098da5 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,31 @@ +# Agent Guidelines for Larecs + +## Build/Test Commands +- Run all tests: `pixi run mojo test -I src/ test/` +- Run single test: `pixi run mojo test -I src/ test/.mojo` +- Format code: `pixi run mojo format src test benchmark` +- Generate docs: `pixi run mojo doc -o docs/src/larecs.json src/larecs` + +## Code Style +- Use snake_case for functions/variables, PascalCase for types +- Use `fn` for static functions, `def` for dynamic functions +- Prefer `var` for mutable, immutable by default +- Use `inout` parameters for mutation, not return modified values +- Include comprehensive type hints with Mojo's progressive typing +- Use `@parameter` for compile-time constants, `@always_inline` for critical paths +- Leverage SIMD types `SIMD[type, width]` for vectorization +- Apply traits: `Copyable`, `Movable`, `Stringable` appropriately +- Use manual memory management with `UnsafePointer` when needed +- Include docstrings for public APIs +- Reference Mojo docs for LLMs: https://docs.modular.com/llms-mojo.txt + +## Error Handling & Safety +- Follow borrow checker principles for memory safety +- Prefer stack allocation and RAII patterns +- Use `debug_warn()` utility for debug messages + +## Performance Focus +- This is a performance-critical ECS library +- Memory layout and cache efficiency are crucial +- Always consider vectorization opportunities +- Update benchmarks when making performance changes \ No newline at end of file diff --git a/benchmark/bitmask_benchmark.mojo b/benchmark/bitmask_benchmark.mojo index d6383c6b..dd2ff0c9 100644 --- a/benchmark/bitmask_benchmark.mojo +++ b/benchmark/bitmask_benchmark.mojo @@ -43,7 +43,7 @@ fn benchmark_bitmask_flip_1_000_000(mut bencher: Bencher) capturing: @parameter fn bench_fn() capturing: for _ in range(1_000_000): - mask.flip(val) + mask.flip_mut(val) keep(mask._bytes) bencher.iter[bench_fn]() @@ -75,19 +75,6 @@ fn benchmark_bitmask_contains_any_1_000_000(mut bencher: Bencher) capturing: bencher.iter[bench_fn]() -fn benchmark_mask_filter_1_000_000(mut bencher: Bencher) capturing: - mask = BitMask(0, 1, 2).without() - bits = BitMask(0, 1, 2) - - @always_inline - @parameter - fn bench_fn() capturing: - for _ in range(1_000_000): - keep(mask.matches(bits)) - - bencher.iter[bench_fn]() - - fn benchmark_bitmask_eq_1_000_000(mut bencher: Bencher) capturing: mask1 = get_random_bitmask() mask2 = mask1 @@ -202,9 +189,6 @@ fn run_all_bitmask_benchmarks(mut bench: Bench) raises: bench.bench_function[benchmark_bitmask_eq_1_000_000]( BenchId("10^6 * bitmask_eq") ) - bench.bench_function[benchmark_mask_filter_1_000_000]( - BenchId("10^6 * mask_filter") - ) # bench.bench_function[benchmark_bitmask_get_indices_1_000_000]( # BenchId("10^6 * get_indices") # ) diff --git a/benchmark/custom_benchmark.mojo b/benchmark/custom_benchmark.mojo index a5da627b..8d6d88bd 100644 --- a/benchmark/custom_benchmark.mojo +++ b/benchmark/custom_benchmark.mojo @@ -3,16 +3,202 @@ from benchmark import ( BenchId, BenchConfig as BenchConfig_, Bench as Bench_, + Format, ) +from pathlib import Path from time import perf_counter_ns from collections import Dict +from larecs.bitmask import BitMask fn DefaultConfig() raises -> BenchConfig_: """Returns the default configuration for benchmarking.""" config = BenchConfig_(min_runtime_secs=2, max_batch_size=50) config.verbose_timing = True - return config + return config^ + + +struct ArgTypes: + alias path = "--path" + alias format = "--format" + + +struct FormatStrings: + alias csv = "csv" + alias table = "table" + alias tabular = "tabular" + + @staticmethod + fn contains(str: StringSlice[StaticConstantOrigin]) -> Bool: + return ( + str == FormatStrings.csv + or str == FormatStrings.table + or str == FormatStrings.tabular + ) + + +@fieldwise_init +struct Arg(Copyable, Movable): + var type: StringSlice[StaticConstantOrigin] + var value: List[StringSlice[StaticConstantOrigin]] + + fn __eq__(self: Self, other: Arg) -> Bool: + return self.type == other.type + + fn __eq__(self: Self, other: StringSlice[StaticConstantOrigin]) -> Bool: + return self.type == other + + +@register_passable("trivial") +struct ParserError(EqualityComparable): + var type: BitMask.IndexType + alias unknown_arg = ParserError(0) + alias format_missing = ParserError(1) + alias path_missing = ParserError(2) + + fn __init__(out self, type: BitMask.IndexType): + self.type = type + + fn __eq__(self: Self, other: ParserError) -> Bool: + return self.type == other.type + + +@fieldwise_init +struct ParserErrors(ImplicitlyCopyable, Movable): + var error_mask: BitMask + + fn __init__(out self): + self.error_mask = BitMask() + + fn has_errors(self: Self) -> Bool: + return not self.error_mask.is_zero() + + fn has_error(self: Self, error: ParserError) -> Bool: + return self.error_mask.get(error.type) + + fn add_error(mut self: Self, error: ParserError): + self.error_mask.set[True](error.type) + + fn clear_error(mut self: Self, error: ParserError): + self.error_mask.set[False](error.type) + + fn get_errors(self: Self) -> List[ParserError]: + errors = List[ParserError]() + for error_bit in self.error_mask.get_indices(): + errors.append(ParserError(error_bit)) + return errors^ + + +struct Parser: + var args: List[StringSlice[StaticConstantOrigin]] + var index: Int + var errors: ParserErrors + + fn __init__(out self, var args: List[StringSlice[StaticConstantOrigin]]): + self.args = args^ + self.index = 0 + self.errors = ParserErrors() + + fn has_next(self: Self) -> Bool: + return self.index < len(self.args) + + fn parse_next(mut self: Self) raises -> Arg: + type = self.args[self.index] + self.index += 1 + + value = List[StringSlice[StaticConstantOrigin]]() + if type == ArgTypes.path: + if not self.has_next(): + raise Error("Expected a value after --path") + + value.append(self.args[self.index]) + self.index += 1 + + if self.errors.has_error(ParserError.format_missing): + raise Error( + "--path specified without --format. Please provide --format" + " first." + ) + elif self.errors.has_error(ParserError.path_missing): + self.errors.clear_error(ParserError.path_missing) + else: + self.errors.add_error(ParserError.format_missing) + + elif type == ArgTypes.format: + if not self.has_next(): + raise Error("Expected a value after --format") + + format_str = self.args[self.index] + if not FormatStrings.contains(format_str): + raise Error("Unknown format: " + format_str) + + value.append(format_str) + self.index += 1 + + # but this arg should only appear if the --path arg is also given + if self.errors.has_error(ParserError.format_missing): + self.errors.clear_error(ParserError.format_missing) + elif self.errors.has_error(ParserError.path_missing): + raise Error( + "--format specified without --path. Please provide --path" + " first." + ) + else: + self.errors.add_error(ParserError.path_missing) + + else: + raise Error("Unknown argument: " + type) + + return Arg(type, value^) + + fn parse_all(mut self: Self) raises -> List[Arg]: + parsed_args = List[Arg]() + while self.has_next(): + parsed_args.append(self.parse_next()) + return parsed_args^ + + +fn config_from_args( + args: VariadicList[StringSlice[StaticConstantOrigin]], +) raises -> BenchConfig_: + """Parses command line arguments to create a BenchConfig. + + Currently supports: + --json : Outputs results in JSON format. + + Args: + args: The command line arguments. + + Returns: + A BenchConfig with the parsed settings. + """ + config = DefaultConfig() + + args_list = List[StringSlice[StaticConstantOrigin]](capacity=len(args)) + for arg in args: + args_list.append(arg) + + parser = Parser(args_list[1:]) + parsed_args = parser.parse_all() + + if parser.errors.has_errors(): + error_msgs = List[String]() + for error in parser.errors.get_errors(): + if error == ParserError.format_missing: + error_msgs.append("--format specified without --path") + elif error == ParserError.path_missing: + error_msgs.append("--path specified without --format") + elif error == ParserError.unknown_arg: + error_msgs.append("Unknown argument") + raise Error("Argument parsing errors:\n" + "\n".join(error_msgs^)) + + for arg in parsed_args: + if arg == ArgTypes.format: + config.out_file_format = Format(arg.value[0]) + elif arg == ArgTypes.path: + config.out_file = Path(arg.value[0]) + + return config^ fn DefaultBench() raises -> Bench_: diff --git a/benchmark/plots/pixi.lock b/benchmark/plots/pixi.lock index 79275ee0..59068727 100644 --- a/benchmark/plots/pixi.lock +++ b/benchmark/plots/pixi.lock @@ -94,9 +94,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.10.5-py312h7900ff3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.10.5-py312he3d6523_0.conda - - conda: https://repo.prefix.dev/max/noarch/mblack-25.5.0-release.conda - - conda: https://repo.prefix.dev/max/linux-64/mojo-25.5.0-release.conda - - conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-25.5.0-release.conda + - conda: https://repo.prefix.dev/max/noarch/mblack-25.6.0-release.conda + - conda: https://repo.prefix.dev/max/linux-64/mojo-0.25.6.0-release.conda + - conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-0.25.6.0-release.conda + - conda: https://repo.prefix.dev/max/noarch/mojo-python-0.25.6.0-release.conda - conda: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda @@ -1182,9 +1183,9 @@ packages: license_family: PSF size: 8071030 timestamp: 1754005868258 -- conda: https://repo.prefix.dev/max/noarch/mblack-25.5.0-release.conda +- conda: https://repo.prefix.dev/max/noarch/mblack-25.6.0-release.conda noarch: python - sha256: 85c6dd4ec8d3b08d54ba2c36872ac7c5bff8048b1a169d980772a3cd9ba4f599 + sha256: 51ef6f2dd19de5154cde4181e99380277549c60ff56172caba866411c26c5ee4 depends: - python >=3.9 - click >=8.0.0 @@ -1196,23 +1197,33 @@ packages: - typing_extensions >=v4.12.2 - python license: MIT - size: 131408 - timestamp: 1754336918677 -- conda: https://repo.prefix.dev/max/linux-64/mojo-25.5.0-release.conda - sha256: 2754518a61f2fd63c9a93f692d7f50a4bf8885f368813ea29033efea35ebd448 + size: 131727 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/linux-64/mojo-0.25.6.0-release.conda + sha256: a0a487de8d470e85e925c9ad76b493b01958ccc7affe0b92250db50a1960709a depends: - python >=3.9 - - mojo-compiler ==25.5.0 release - - mblack ==25.5.0 release + - mojo-compiler ==0.25.6.0 release + - mblack ==25.6.0 release - jupyter_client >=8.6.2,<8.7 license: LicenseRef-Modular-Proprietary - size: 86888525 - timestamp: 1754336826465 -- conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-25.5.0-release.conda - sha256: 22536e6258ff5739b12f5f3ba96ad1b09af1828db024894831fff02713b20603 + size: 90583156 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-0.25.6.0-release.conda + sha256: 41edf48721d11e186a4e6fbde89ae51ccd0aaf33909e4eab9763e00158a73119 + depends: + - mojo-python ==0.25.6.0 release + license: LicenseRef-Modular-Proprietary + size: 85409127 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/noarch/mojo-python-0.25.6.0-release.conda + noarch: python + sha256: 0638dcd3e79c5cd0ff2b385afe590eb81a68093d58e7532c3a0e5158b8eeeadb + depends: + - python license: LicenseRef-Modular-Proprietary - size: 78323265 - timestamp: 1754336826465 + size: 17887 + timestamp: 1758417230559 - conda: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda sha256: d09c47c2cf456de5c09fa66d2c3c5035aa1fa228a1983a433c47b876aa16ce90 md5: 37293a85a0f4f77bbd9cf7aaefc62609 diff --git a/benchmark/plots/pixi.toml b/benchmark/plots/pixi.toml index 10e79666..c3f03a79 100644 --- a/benchmark/plots/pixi.toml +++ b/benchmark/plots/pixi.toml @@ -9,7 +9,7 @@ version = "0.4.0" [tasks] [dependencies] -mojo = ">=25.5,<26" +mojo = "==0.25.6" python = "==3.12" matplotlib = ">=3.10.0,<4" pandas = ">=2.2" diff --git a/benchmark/plots/src/aos.mojo b/benchmark/plots/src/aos.mojo index 2931601f..dcc303d5 100644 --- a/benchmark/plots/src/aos.mojo +++ b/benchmark/plots/src/aos.mojo @@ -15,7 +15,7 @@ alias TARGET_ITERATIONS = 10**9 @fieldwise_init -struct BenchResult(Copyable, Movable): +struct BenchResult(ImplicitlyCopyable, Movable): var components: Int var entities: Int var nanos_ecs: Float64 @@ -23,7 +23,7 @@ struct BenchResult(Copyable, Movable): @fieldwise_init -struct BenchConfig[max_comp_exp: Int](Copyable, Movable): +struct BenchConfig[max_comp_exp: Int](ImplicitlyCopyable, Movable): var max_entity_exp: Int var target_iters: Int @@ -210,7 +210,7 @@ fn run_benchmarks(config: BenchConfig) raises -> List[BenchResult]: result = benchmark[compExp](rounds, entities) results.append(result) - return results + return results^ fn benchmark[ @@ -271,7 +271,7 @@ struct AosWorld[components_exp: Int](Copyable, Movable): @fieldwise_init -struct AosEntity[components_exp: Int](Copyable, Movable): +struct AosEntity[components_exp: Int](ImplicitlyCopyable, Movable): var comps: InlineArray[Position, 2**components_exp] fn __init__(out self): diff --git a/benchmark/resources_benchmark.mojo b/benchmark/resources_benchmark.mojo index 4bf6f4ae..fa14b946 100644 --- a/benchmark/resources_benchmark.mojo +++ b/benchmark/resources_benchmark.mojo @@ -20,7 +20,7 @@ fn benchmark_add_remove_resource_1_000(mut bencher: Bencher) raises capturing: test_resource = TestResource() for _ in range(1_000): try: - resources.add(test_resource) + resources.add(test_resource.copy()) resources.remove[TestResource]() except: pass diff --git a/benchmark/run_benchmarks.mojo b/benchmark/run_benchmarks.mojo index a4dee88f..d2507e59 100644 --- a/benchmark/run_benchmarks.mojo +++ b/benchmark/run_benchmarks.mojo @@ -3,11 +3,13 @@ import world_benchmark import component_benchmark import query_benchmark import resources_benchmark -from custom_benchmark import DefaultBench +from benchmark import Bench +from custom_benchmark import config_from_args +from sys import argv def main(): - bench = DefaultBench() + bench = Bench(config_from_args(argv())) world_benchmark.run_all_world_benchmarks(bench) query_benchmark.run_all_query_benchmarks(bench) bitmask_benchmark.run_all_bitmask_benchmarks(bench) diff --git a/benchmark/world_benchmark.mojo b/benchmark/world_benchmark.mojo index 31f51c87..03eeb339 100644 --- a/benchmark/world_benchmark.mojo +++ b/benchmark/world_benchmark.mojo @@ -463,7 +463,7 @@ fn benchmark_add_remove_1_comp_1_000_000( bencher.iter[bench_fn]() -fn prevent_inlining_add_remove_1_comp() raises: +fn prevent_inlining_add_remove_1_comp_1_000_000() raises: pos = Position(1.0, 2.0) vel = Velocity(0.1, 0.2) world = SmallWorld() @@ -472,6 +472,88 @@ fn prevent_inlining_add_remove_1_comp() raises: world.remove[Velocity](entity) +fn benchmark_add_remove_1_comp_batch_1_000_000( + mut bencher: Bencher, +) raises capturing: + pos = Position(1.0, 2.0) + comp = FlexibleComponent[1](1, 42.0) + + @always_inline + @parameter + fn bench_fn() capturing raises: + world = SmallWorld() + + # create 1_000_000 entities that initially do not have FlexibleComponent[1] + _ = world.add_entities(pos, count=1_000_000) + + _ = world.add( + world.query[Position]().without[FlexibleComponent[1]](), comp + ) + _ = world.remove[FlexibleComponent[1]]( + world.query[Position, FlexibleComponent[1]]() + ) + + bencher.iter[bench_fn]() + + +fn prevent_inlining_add_remove_1_comp_batch_1_000_000() raises: + pos = Position(1.0, 2.0) + comp = FlexibleComponent[1](1, 42.0) + world = SmallWorld() + + # create 1_000_000 entities that initially do not have FlexibleComponent[1] + _ = world.add_entities(pos, count=1_000_000) + + _ = world.add(world.query[Position]().without[FlexibleComponent[1]](), comp) + _ = world.remove[FlexibleComponent[1]]( + world.query[Position, FlexibleComponent[1]]() + ) + + +fn benchmark_add_remove_1_comp_1_000_batch_1_000( + mut bencher: Bencher, +) raises capturing: + pos = Position(1.0, 2.0) + comp1 = FlexibleComponent[1](1.0, 42.0) + + @always_inline + @parameter + fn bench_fn() capturing raises: + world = SmallWorld() + + # create 1_000 entities that initially do not have FlexibleComponent[1] + _ = world.add_entities(pos, count=1_000) + # then 1_000 x add component and remove it afterwards + for _ in range(1000): + _ = world.add( + world.query[Position]().without[FlexibleComponent[1]](), + comp1, + ) + _ = world.remove[FlexibleComponent[1]]( + world.query[Position, FlexibleComponent[1]]() + ) + + bencher.iter[bench_fn]() + + +fn prevent_inlining_add_remove_1_comp_1_000_batch_1_000() raises: + pos = Position(1.0, 2.0) + comp1 = FlexibleComponent[1](1.0, 42.0) + world = SmallWorld() + + # create 1_000 entities that initially do not have FlexibleComponent[1] + _ = world.add_entities(pos, count=1_000) + # then 1_000 x add component and remove it afterwards + for _ in range(1000): + _ = world.add( + world.query[Position]().without[FlexibleComponent[1]](), + comp1, + ) + _ = world.remove[FlexibleComponent[1]]( + world.query[Position, FlexibleComponent[1]]() + ) + + fn benchmark_add_remove_5_comp_1_000_000( mut bencher: Bencher, ) raises capturing: @@ -500,7 +582,7 @@ fn benchmark_add_remove_5_comp_1_000_000( bencher.iter[bench_fn]() -fn prevent_inlining_add_remove_5_comp() raises: +fn prevent_inlining_add_remove_5_comp_1_000_000() raises: c1 = FlexibleComponent[1](1.0, 2.0) c2 = FlexibleComponent[2](1.0, 2.0) c3 = FlexibleComponent[3](1.0, 2.0) @@ -520,11 +602,15 @@ fn prevent_inlining_add_remove_5_comp() raises: ](entity) -fn benchmark_batch_add_1_comp_1_000_000( +fn benchmark_add_remove_5_comp_batch_1_000_000( mut bencher: Bencher, ) raises capturing: pos = Position(1.0, 2.0) - comp = FlexibleComponent[1](1.0, 42.0) + comp1 = FlexibleComponent[1](1.0, 42.0) + comp2 = FlexibleComponent[2](2.0, 42.0) + comp3 = FlexibleComponent[3](3.0, 42.0) + comp4 = FlexibleComponent[4](4.0, 42.0) + comp5 = FlexibleComponent[5](5.0, 42.0) @always_inline @parameter @@ -535,30 +621,149 @@ fn benchmark_batch_add_1_comp_1_000_000( _ = world.add_entities(pos, count=1_000_000) _ = world.add( - world.query[Position]().without[FlexibleComponent[1]](), comp + world.query[Position]().without[ + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ](), + comp1, + comp2, + comp3, + comp4, + comp5, + ) + _ = world.remove[ + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ]( + world.query[ + Position, + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ]() ) bencher.iter[bench_fn]() -fn benchmark_batch_add_5_comp_1_000_000( - mut bencher: Bencher, -) raises capturing: +fn prevent_inlining_add_remove_5_comp_batch_1_000_000() raises: pos = Position(1.0, 2.0) comp1 = FlexibleComponent[1](1.0, 42.0) comp2 = FlexibleComponent[2](2.0, 42.0) comp3 = FlexibleComponent[3](3.0, 42.0) comp4 = FlexibleComponent[4](4.0, 42.0) comp5 = FlexibleComponent[5](5.0, 42.0) + world = SmallWorld() + + # create 1_000_000 entities that initially do not have FlexibleComponent[1] + _ = world.add_entities(pos, count=1_000_000) + + _ = world.add( + world.query[Position]().without[ + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ](), + comp1, + comp2, + comp3, + comp4, + comp5, + ) + _ = world.remove[ + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ]( + world.query[ + Position, + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ]() + ) + + +fn benchmark_add_remove_5_comp_1_000_batch_1_000( + mut bencher: Bencher, +) raises capturing: + pos = Position(1.0, 2.0) + comp1 = FlexibleComponent[1](1.0, 42.0) + comp2 = FlexibleComponent[2](1.0, 42.0) + comp3 = FlexibleComponent[3](1.0, 42.0) + comp4 = FlexibleComponent[4](1.0, 42.0) + comp5 = FlexibleComponent[5](1.0, 42.0) @always_inline @parameter fn bench_fn() capturing raises: world = SmallWorld() - # create 1_000_000 entities that initially do not have FlexibleComponent[1] - _ = world.add_entities(pos, count=1_000_000) + # create 1_000 entities that initially do not have FlexibleComponent[1...5] + _ = world.add_entities(pos, count=1_000) + # then 1_000 x add components and remove them afterwards + for _ in range(1000): + _ = world.add( + world.query[Position]().without[ + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ](), + comp1, + comp2, + comp3, + comp4, + comp5, + ) + _ = world.remove[ + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ]( + world.query[ + Position, + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ]() + ) + + bencher.iter[bench_fn]() + + +fn prevent_inlining_add_remove_5_comp_1_000_batch_1_000() raises: + pos = Position(1.0, 2.0) + comp1 = FlexibleComponent[1](1.0, 42.0) + comp2 = FlexibleComponent[2](1.0, 42.0) + comp3 = FlexibleComponent[3](1.0, 42.0) + comp4 = FlexibleComponent[4](1.0, 42.0) + comp5 = FlexibleComponent[5](1.0, 42.0) + world = SmallWorld() + # create 1_000 entities that initially do not have FlexibleComponent[1...5] + _ = world.add_entities(pos, count=1_000) + # then 1_000 x add components and remove them afterwards + for _ in range(1000): _ = world.add( world.query[Position]().without[ FlexibleComponent[1], @@ -573,8 +778,22 @@ fn benchmark_batch_add_5_comp_1_000_000( comp4, comp5, ) - - bencher.iter[bench_fn]() + _ = world.remove[ + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ]( + world.query[ + Position, + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + FlexibleComponent[5], + ]() + ) fn benchmark_replace_1_comp_1_000_000( @@ -599,6 +818,345 @@ fn benchmark_replace_1_comp_1_000_000( bencher.iter[bench_fn]() +fn benchmark_replace_1_comp_batch_1_000_000( + mut bencher: Bencher, +) raises capturing: + @always_inline + @parameter + fn bench_fn() capturing raises: + world = FullWorld() + _ = world.add_entities(FlexibleComponent[0](1.0, 2.0), count=1_000_000) + + _ = world.replace[FlexibleComponent[0]]().by( + world.query[FlexibleComponent[0]](), + FlexibleComponent[1](3.0, 4.0), + ) + + bencher.iter[bench_fn]() + + +fn prevent_inlining_batch_replace() raises: + world = FullWorld() + _ = world.add_entities(FlexibleComponent[0](1.0, 2.0), count=1_000_000) + + _ = world.replace[FlexibleComponent[0]]().by( + world.query[FlexibleComponent[0]](), + FlexibleComponent[1](3.0, 4.0), + ) + + +fn benchmark_replace_1_comp_1_000_batch_1_000( + mut bencher: Bencher, +) raises capturing: + @always_inline + @parameter + fn bench_fn() capturing raises: + world = FullWorld() + _ = world.add_entities(FlexibleComponent[0](1.0, 2.0), count=1_000) + + for i in range(500): + _ = world.replace[FlexibleComponent[0]]().by( + world.query[FlexibleComponent[0]](), + FlexibleComponent[1](3.0, 4.0), + ) + _ = world.replace[FlexibleComponent[1]]().by( + world.query[FlexibleComponent[1]](), + FlexibleComponent[0](1.0, 2.0), + ) + + bencher.iter[bench_fn]() + + +fn prevent_inlining_1_batch_1_000_replace() raises: + world = FullWorld() + _ = world.add_entities(FlexibleComponent[0](1.0, 2.0), count=1_000) + + for i in range(500): + _ = world.replace[FlexibleComponent[0]]().by( + world.query[FlexibleComponent[0]](), + FlexibleComponent[1](3.0, 4.0), + ) + _ = world.replace[FlexibleComponent[1]]().by( + world.query[FlexibleComponent[1]](), + FlexibleComponent[0](1.0, 2.0), + ) + + +fn benchmark_replace_5_comp_1_000_000( + mut bencher: Bencher, +) raises capturing: + @always_inline + @parameter + fn bench_fn() capturing raises: + for _ in range(50): + world = FullWorld() + entities = List[Entity]() + for _ in range(1000): + entities.append( + world.add_entity( + FlexibleComponent[0](1.0, 2.0), + FlexibleComponent[1](3.0, 4.0), + FlexibleComponent[2](5.0, 6.0), + FlexibleComponent[3](7.0, 8.0), + FlexibleComponent[4](9.0, 10.0), + ) + ) + + @parameter + for i in range(20): + alias base = i * 5 + for entity in entities: + world.replace[ + FlexibleComponent[base + 0], + FlexibleComponent[base + 1], + FlexibleComponent[base + 2], + FlexibleComponent[base + 3], + FlexibleComponent[base + 4], + ]().by( + entity, + FlexibleComponent[base + 5]( + i + 11.0, Float32(i + 12.0) + ), + FlexibleComponent[base + 6]( + i + 13.0, Float32(i + 14.0) + ), + FlexibleComponent[base + 7]( + i + 15.0, Float32(i + 16.0) + ), + FlexibleComponent[base + 8]( + i + 17.0, Float32(i + 18.0) + ), + FlexibleComponent[base + 9]( + i + 19.0, Float32(i + 20.0) + ), + ) + + bencher.iter[bench_fn]() + + +fn prevent_inlining_5_replace() raises: + for _ in range(50): + world = FullWorld() + entities = List[Entity]() + for _ in range(1000): + entities.append( + world.add_entity( + FlexibleComponent[0](1.0, 2.0), + FlexibleComponent[1](3.0, 4.0), + FlexibleComponent[2](5.0, 6.0), + FlexibleComponent[3](7.0, 8.0), + FlexibleComponent[4](9.0, 10.0), + ) + ) + + @parameter + for i in range(20): + alias base = i * 5 + for entity in entities: + world.replace[ + FlexibleComponent[base + 0], + FlexibleComponent[base + 1], + FlexibleComponent[base + 2], + FlexibleComponent[base + 3], + FlexibleComponent[base + 4], + ]().by( + entity, + FlexibleComponent[base + 5](i + 11.0, Float32(i + 12.0)), + FlexibleComponent[base + 6](i + 13.0, Float32(i + 14.0)), + FlexibleComponent[base + 7](i + 15.0, Float32(i + 16.0)), + FlexibleComponent[base + 8](i + 17.0, Float32(i + 18.0)), + FlexibleComponent[base + 9](i + 19.0, Float32(i + 20.0)), + ) + + +fn benchmark_replace_5_comp_batch_1_000_000( + mut bencher: Bencher, +) raises capturing: + @always_inline + @parameter + fn bench_fn() capturing raises: + world = FullWorld() + _ = world.add_entities( + FlexibleComponent[0](1.0, 2.0), + FlexibleComponent[1](3.0, 4.0), + FlexibleComponent[2](5.0, 6.0), + FlexibleComponent[3](7.0, 8.0), + FlexibleComponent[4](9.0, 10.0), + count=1_000_000, + ) + + _ = world.replace[ + FlexibleComponent[0], + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + ]().by( + world.query[ + FlexibleComponent[0], + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + ](), + FlexibleComponent[5](11.0, 12.0), + FlexibleComponent[6](13.0, 14.0), + FlexibleComponent[7](15.0, 16.0), + FlexibleComponent[8](17.0, 18.0), + FlexibleComponent[9](19.0, 20.0), + ) + + bencher.iter[bench_fn]() + + +fn prevent_inlining_5_batch_replace() raises: + world = FullWorld() + _ = world.add_entities( + FlexibleComponent[0](1.0, 2.0), + FlexibleComponent[1](3.0, 4.0), + FlexibleComponent[2](5.0, 6.0), + FlexibleComponent[3](7.0, 8.0), + FlexibleComponent[4](9.0, 10.0), + count=1_000_000, + ) + + _ = world.replace[ + FlexibleComponent[0], + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + ]().by( + world.query[ + FlexibleComponent[0], + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + ](), + FlexibleComponent[5](11.0, 12.0), + FlexibleComponent[6](13.0, 14.0), + FlexibleComponent[7](15.0, 16.0), + FlexibleComponent[8](17.0, 18.0), + FlexibleComponent[9](19.0, 20.0), + ) + + +fn benchmark_replace_5_comp_1_000_batch_1_000( + mut bencher: Bencher, +) raises capturing: + @always_inline + @parameter + fn bench_fn() capturing raises: + world = FullWorld() + _ = world.add_entities( + FlexibleComponent[0](1.0, 2.0), + FlexibleComponent[1](3.0, 4.0), + FlexibleComponent[2](5.0, 6.0), + FlexibleComponent[3](7.0, 8.0), + FlexibleComponent[4](9.0, 10.0), + count=1_000, + ) + + for _ in range(500): + _ = world.replace[ + FlexibleComponent[0], + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + ]().by( + world.query[ + FlexibleComponent[0], + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + ](), + FlexibleComponent[5](11.0, 12.0), + FlexibleComponent[6](13.0, 14.0), + FlexibleComponent[7](15.0, 16.0), + FlexibleComponent[8](17.0, 18.0), + FlexibleComponent[9](19.0, 20.0), + ) + _ = world.replace[ + FlexibleComponent[5], + FlexibleComponent[6], + FlexibleComponent[7], + FlexibleComponent[8], + FlexibleComponent[9], + ]().by( + world.query[ + FlexibleComponent[5], + FlexibleComponent[6], + FlexibleComponent[7], + FlexibleComponent[8], + FlexibleComponent[9], + ](), + FlexibleComponent[0](1.0, 2.0), + FlexibleComponent[1](3.0, 4.0), + FlexibleComponent[2](5.0, 6.0), + FlexibleComponent[3](7.0, 8.0), + FlexibleComponent[4](9.0, 0.0), + ) + + bencher.iter[bench_fn]() + + +fn prevent_inlining_5_batch_1_000_replace() raises: + world = FullWorld() + _ = world.add_entities( + FlexibleComponent[0](1.0, 2.0), + FlexibleComponent[1](3.0, 4.0), + FlexibleComponent[2](5.0, 6.0), + FlexibleComponent[3](7.0, 8.0), + FlexibleComponent[4](9.0, 10.0), + count=1_000, + ) + + for _ in range(500): + _ = world.replace[ + FlexibleComponent[0], + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + ]().by( + world.query[ + FlexibleComponent[0], + FlexibleComponent[1], + FlexibleComponent[2], + FlexibleComponent[3], + FlexibleComponent[4], + ](), + FlexibleComponent[5](11.0, 12.0), + FlexibleComponent[6](13.0, 14.0), + FlexibleComponent[7](15.0, 16.0), + FlexibleComponent[8](17.0, 18.0), + FlexibleComponent[9](19.0, 20.0), + ) + _ = world.replace[ + FlexibleComponent[5], + FlexibleComponent[6], + FlexibleComponent[7], + FlexibleComponent[8], + FlexibleComponent[9], + ]().by( + world.query[ + FlexibleComponent[5], + FlexibleComponent[6], + FlexibleComponent[7], + FlexibleComponent[8], + FlexibleComponent[9], + ](), + FlexibleComponent[0](1.0, 2.0), + FlexibleComponent[1](3.0, 4.0), + FlexibleComponent[2](5.0, 6.0), + FlexibleComponent[3](7.0, 8.0), + FlexibleComponent[4](9.0, 0.0), + ) + + fn benchmark_replace_1_comp_1_000_000_extra( mut bencher: Bencher, ) raises capturing: @@ -679,34 +1237,66 @@ fn run_all_world_benchmarks(mut bench: Bench) raises: bench.bench_function[benchmark_is_alive_1_000_000]( BenchId("10^6 * is_alive") ) + bench.bench_function[benchmark_add_remove_1_comp_1_000_000]( BenchId("10^6 * add & remove 1 component") ) + bench.bench_function[benchmark_add_remove_1_comp_batch_1_000_000]( + BenchId("10^0 * add & remove 1 component 10^6 batch") + ) + bench.bench_function[benchmark_add_remove_1_comp_1_000_batch_1_000]( + BenchId("10^3 * add & remove 1 component 10^3 batch") + ) + bench.bench_function[benchmark_add_remove_5_comp_1_000_000]( BenchId("10^6 * add & remove 5 components") ) - bench.bench_function[benchmark_batch_add_1_comp_1_000_000]( - BenchId("10^6 * batch add 1 component") + bench.bench_function[benchmark_add_remove_5_comp_batch_1_000_000]( + BenchId("10^0 * add & remove 5 components 10^6 batch") ) - bench.bench_function[benchmark_batch_add_5_comp_1_000_000]( - BenchId("10^6 * batch add 5 component") + bench.bench_function[benchmark_add_remove_5_comp_1_000_batch_1_000]( + BenchId("10^3 * add & remove 5 components 10^3 batch") ) + bench.bench_function[benchmark_replace_1_comp_1_000_000]( BenchId("10^6 * replace 1 component") ) + bench.bench_function[benchmark_replace_1_comp_batch_1_000_000]( + BenchId("10^0 * replace 1 component 10^6 batch") + ) + bench.bench_function[benchmark_replace_1_comp_1_000_batch_1_000]( + BenchId("10^3 * replace 1 component 10^3 batch") + ) + bench.bench_function[benchmark_replace_5_comp_1_000_000]( + BenchId("10^6 * replace 5 components") + ) + bench.bench_function[benchmark_replace_5_comp_batch_1_000_000]( + BenchId("10^0 * replace 5 components 10^6 batch") + ) + bench.bench_function[benchmark_replace_5_comp_1_000_batch_1_000]( + BenchId("10^3 * replace 5 components 10^3 batch") + ) # Functions to prevent inlining prevent_inlining_add_remove_entity_1_comp() prevent_inlining_add_remove_entity_5_comp() - prevent_inlining_add_remove_1_comp() - prevent_inlining_add_remove_5_comp() + prevent_inlining_add_remove_1_comp_1_000_000() + prevent_inlining_add_remove_1_comp_batch_1_000_000() + prevent_inlining_add_remove_1_comp_1_000_batch_1_000() + prevent_inlining_add_remove_5_comp_1_000_000() + prevent_inlining_add_remove_5_comp_batch_1_000_000() + prevent_inlining_add_remove_5_comp_1_000_batch_1_000() prevent_inlining_add_entity_1_comp() prevent_inlining_add_entity_5_comp() prevent_inlining_get() prevent_inlining_set_1_comp() prevent_inlining_set_5_comp() prevent_inlining_replace() - prevent_inlining_add_remove_5_comp() + prevent_inlining_batch_replace() + prevent_inlining_1_batch_1_000_replace() + prevent_inlining_5_replace() + prevent_inlining_5_batch_replace() + prevent_inlining_5_batch_1_000_replace() def main(): diff --git a/changelog.md b/changelog.md index 25127cbe..62e71503 100644 --- a/changelog.md +++ b/changelog.md @@ -3,12 +3,54 @@ ## [Unreleased](https://github.com/samufi/larecs/compare/v0.4.0...main) ### Breaking changes -- ... +- Update the utilized Mojo version to 25.6 and adjust the code accordingly. +- Revisit which structs are only `Copyable` and which can be also `ImplicitlyCopyable` + #### Copyable + - _ArchetypeByListIterator + - _ArchetypeByMaskIterator + - Archetype + - BitMaskGraph + - BitPool + - _EntityIterator + - EntityPool + - LockMask -> LockManager + + - Resources + - StaticOptional + - StaticVariant + - UnsafeBox + - World + + #### ImplicitlyCopyable + - BitMask + - Component + - Entity + - EntityIndex + - LockedContext + - Node + - Query + - QueryInfo + +- Rename `_ArchetypeIterator` to `_ArchetypeByMaskIterator` +- Rename `LockMask` to `LockManager` +- Remove all `hint_trivial_type` and `run_destructors` parameters from containers that leaked them from their underlying + List attribute +- Remove `BitMask.without` method +- Add `BitMask.set` overloads that work with multiple component IDs at once ### Other changes -- ... +- Implement batch component addition as overload of `world.add` +- Implement batch component removal as overload of `world.remove` +- Implement batch component replacing as overload of `world.Replacer.by` +- Remove `unsafe_take` from the `_utils` module +- Add `StaticVariant` +- Add `_ArchetypeByListIterator` to iterate over a given list of archetypes +- Optimize `archetype.reserve` to reduce frequent reallocations +- Add function `_utils.next_pow2` to calculate next power of 2 fast +- Add helper `QueryInfo.matches` to encapsulate query matching logic +- Add bit-wise operations for `BitMask` -## [Unreleased](https://github.com/samufi/larecs/compare/v0.3.0...v0.4.0) +## [v0.4.0 (2025-08-06)](https://github.com/samufi/larecs/compare/v0.3.0...v0.4.0) ### Breaking changes - Update the utilized Mojo version to 25.5 and adjust the code accordingly. diff --git a/docs/src/guide/changing_entities.md b/docs/src/guide/changing_entities.md index c48e4aeb..e7900255 100644 --- a/docs/src/guide/changing_entities.md +++ b/docs/src/guide/changing_entities.md @@ -173,10 +173,51 @@ for entity in entities: world.add(entity, Velocity(1.0, 0.5)) # Individual operations ``` -> [!Note] -> Currently, only batch adding of components is supported. -> Batch removal and replacement operations are planned for future releases. -> See the [roadmap](../../../README.md#next-steps) for more information. +#### Batch removing components + +You can remove components from multiple entities that match a query using the +{{< api World.remove remove >}} method with a query: + +```mojo {doctest="guide_change_entities"} +# Add 10 entities with Position and Velocity components +_ = world.add_entities(Position(0, 0), Velocity(1.0, 1.0), count=10) + +# Remove Velocity component from all entities that have both Position and Velocity +world.remove[Velocity]( + world.query[Position, Velocity]() +) + +# You can also remove multiple components at once from multiple entities +world.remove[Position, Velocity]( + world.query[Position, Velocity]() +) +``` + +The query must ensure that all matching entities have the components you want to remove, +otherwise an error will be raised. + +#### Batch replacing components + +You can replace components on multiple entities that match a query using the +{{< api World.replace replace >}} method in combination with {{< api Replacer.by by >}}: + +```mojo {doctest="guide_change_entities"} +# Add 10 entities with Position components +_ = world.add_entities(Position(0, 0), count=10) + +# Replace Position with Velocity for all entities that have Position +world.replace[Position]().by( + world.query[Position](), + Velocity(2.0, 2.0) +) + +# You can also replace multiple components with multiple other components +world.replace[Position, Velocity]().by( + world.query[Position, Velocity](), + Direction(0.0, 5.0), + Acceleration(0.2, 0.4) +) +``` > [!Tip] > Batch operations are significantly more efficient than individual operations diff --git a/examples/satellites/components.mojo b/examples/satellites/components.mojo index 2fa9d952..20d14a67 100644 --- a/examples/satellites/components.mojo +++ b/examples/satellites/components.mojo @@ -1,12 +1,12 @@ @fieldwise_init @register_passable("trivial") -struct Position(Copyable, Movable): +struct Position(ImplicitlyCopyable, Movable): var x: Float64 var y: Float64 @fieldwise_init @register_passable("trivial") -struct Velocity(Copyable, Movable): +struct Velocity(ImplicitlyCopyable, Movable): var x: Float64 var y: Float64 diff --git a/examples/satellites/parameters.mojo b/examples/satellites/parameters.mojo index 423d26a0..eb5f28c0 100644 --- a/examples/satellites/parameters.mojo +++ b/examples/satellites/parameters.mojo @@ -1,5 +1,5 @@ @fieldwise_init -struct Parameters(Copyable & Movable): +struct Parameters(ImplicitlyCopyable & Movable): var dt: Float64 var mass: Float64 diff --git a/examples/satellites/pixi.lock b/examples/satellites/pixi.lock index a0877e29..74134866 100644 --- a/examples/satellites/pixi.lock +++ b/examples/satellites/pixi.lock @@ -94,9 +94,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.10.5-py312h7900ff3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.10.5-py312he3d6523_0.conda - - conda: https://repo.prefix.dev/max/noarch/mblack-25.5.0-release.conda - - conda: https://repo.prefix.dev/max/linux-64/mojo-25.5.0-release.conda - - conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-25.5.0-release.conda + - conda: https://repo.prefix.dev/max/noarch/mblack-25.6.0-release.conda + - conda: https://repo.prefix.dev/max/linux-64/mojo-0.25.6.0-release.conda + - conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-0.25.6.0-release.conda + - conda: https://repo.prefix.dev/max/noarch/mojo-python-0.25.6.0-release.conda - conda: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda @@ -1179,9 +1180,9 @@ packages: license_family: PSF size: 8071030 timestamp: 1754005868258 -- conda: https://repo.prefix.dev/max/noarch/mblack-25.5.0-release.conda +- conda: https://repo.prefix.dev/max/noarch/mblack-25.6.0-release.conda noarch: python - sha256: 85c6dd4ec8d3b08d54ba2c36872ac7c5bff8048b1a169d980772a3cd9ba4f599 + sha256: 51ef6f2dd19de5154cde4181e99380277549c60ff56172caba866411c26c5ee4 depends: - python >=3.9 - click >=8.0.0 @@ -1193,23 +1194,33 @@ packages: - typing_extensions >=v4.12.2 - python license: MIT - size: 131408 - timestamp: 1754336918677 -- conda: https://repo.prefix.dev/max/linux-64/mojo-25.5.0-release.conda - sha256: 2754518a61f2fd63c9a93f692d7f50a4bf8885f368813ea29033efea35ebd448 + size: 131727 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/linux-64/mojo-0.25.6.0-release.conda + sha256: a0a487de8d470e85e925c9ad76b493b01958ccc7affe0b92250db50a1960709a depends: - python >=3.9 - - mojo-compiler ==25.5.0 release - - mblack ==25.5.0 release + - mojo-compiler ==0.25.6.0 release + - mblack ==25.6.0 release - jupyter_client >=8.6.2,<8.7 license: LicenseRef-Modular-Proprietary - size: 86888525 - timestamp: 1754336826465 -- conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-25.5.0-release.conda - sha256: 22536e6258ff5739b12f5f3ba96ad1b09af1828db024894831fff02713b20603 + size: 90583156 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-0.25.6.0-release.conda + sha256: 41edf48721d11e186a4e6fbde89ae51ccd0aaf33909e4eab9763e00158a73119 + depends: + - mojo-python ==0.25.6.0 release + license: LicenseRef-Modular-Proprietary + size: 85409127 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/noarch/mojo-python-0.25.6.0-release.conda + noarch: python + sha256: 0638dcd3e79c5cd0ff2b385afe590eb81a68093d58e7532c3a0e5158b8eeeadb + depends: + - python license: LicenseRef-Modular-Proprietary - size: 78323265 - timestamp: 1754336826465 + size: 17887 + timestamp: 1758417230559 - conda: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda sha256: d09c47c2cf456de5c09fa66d2c3c5035aa1fa228a1983a433c47b876aa16ce90 md5: 37293a85a0f4f77bbd9cf7aaefc62609 diff --git a/examples/satellites/pixi.toml b/examples/satellites/pixi.toml index a17f5c6d..4b9f6ceb 100644 --- a/examples/satellites/pixi.toml +++ b/examples/satellites/pixi.toml @@ -9,7 +9,7 @@ version = "0.4.0" [tasks] [dependencies] -mojo = ">=25.5,<26" +mojo = "==0.25.6" python = "==3.12" matplotlib = ">=3.10.0,<4" numpy = ">=1.26.4,<2" diff --git a/pixi.lock b/pixi.lock index cff3073b..fa2a7bd5 100644 --- a/pixi.lock +++ b/pixi.lock @@ -31,9 +31,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-15.1.0-h4852527_4.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - - conda: https://repo.prefix.dev/max/noarch/mblack-25.5.0-release.conda - - conda: https://repo.prefix.dev/max/linux-64/mojo-25.5.0-release.conda - - conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-25.5.0-release.conda + - conda: https://repo.prefix.dev/max/noarch/mblack-25.6.0-release.conda + - conda: https://repo.prefix.dev/max/linux-64/mojo-0.25.6.0-release.conda + - conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-0.25.6.0-release.conda + - conda: https://repo.prefix.dev/max/noarch/mojo-python-0.25.6.0-release.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.5.1-h7b32b05_0.conda @@ -317,9 +318,9 @@ packages: license_family: Other size: 60963 timestamp: 1727963148474 -- conda: https://repo.prefix.dev/max/noarch/mblack-25.5.0-release.conda +- conda: https://repo.prefix.dev/max/noarch/mblack-25.6.0-release.conda noarch: python - sha256: 85c6dd4ec8d3b08d54ba2c36872ac7c5bff8048b1a169d980772a3cd9ba4f599 + sha256: 51ef6f2dd19de5154cde4181e99380277549c60ff56172caba866411c26c5ee4 depends: - python >=3.9 - click >=8.0.0 @@ -331,23 +332,33 @@ packages: - typing_extensions >=v4.12.2 - python license: MIT - size: 131408 - timestamp: 1754336918677 -- conda: https://repo.prefix.dev/max/linux-64/mojo-25.5.0-release.conda - sha256: 2754518a61f2fd63c9a93f692d7f50a4bf8885f368813ea29033efea35ebd448 + size: 131727 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/linux-64/mojo-0.25.6.0-release.conda + sha256: a0a487de8d470e85e925c9ad76b493b01958ccc7affe0b92250db50a1960709a depends: - python >=3.9 - - mojo-compiler ==25.5.0 release - - mblack ==25.5.0 release + - mojo-compiler ==0.25.6.0 release + - mblack ==25.6.0 release - jupyter_client >=8.6.2,<8.7 license: LicenseRef-Modular-Proprietary - size: 86888525 - timestamp: 1754336826465 -- conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-25.5.0-release.conda - sha256: 22536e6258ff5739b12f5f3ba96ad1b09af1828db024894831fff02713b20603 + size: 90583156 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/linux-64/mojo-compiler-0.25.6.0-release.conda + sha256: 41edf48721d11e186a4e6fbde89ae51ccd0aaf33909e4eab9763e00158a73119 + depends: + - mojo-python ==0.25.6.0 release + license: LicenseRef-Modular-Proprietary + size: 85409127 + timestamp: 1758417230559 +- conda: https://repo.prefix.dev/max/noarch/mojo-python-0.25.6.0-release.conda + noarch: python + sha256: 0638dcd3e79c5cd0ff2b385afe590eb81a68093d58e7532c3a0e5158b8eeeadb + depends: + - python license: LicenseRef-Modular-Proprietary - size: 78323265 - timestamp: 1754336826465 + size: 17887 + timestamp: 1758417230559 - conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.1.0-pyha770c72_0.conda sha256: 6ed158e4e5dd8f6a10ad9e525631e35cee8557718f83de7a4e3966b1f772c4b1 md5: e9c622e0d00fa24a6292279af3ab6d06 diff --git a/pixi.toml b/pixi.toml index ec6039e8..04fa1d9f 100644 --- a/pixi.toml +++ b/pixi.toml @@ -9,4 +9,4 @@ version = "0.4.0" [tasks] [dependencies] -mojo = ">=25.5,<26" +mojo = "==0.25.6" diff --git a/src/larecs/_utils.mojo b/src/larecs/_utils.mojo index 337a5a24..f21b2c77 100644 --- a/src/larecs/_utils.mojo +++ b/src/larecs/_utils.mojo @@ -1,28 +1,5 @@ from memory import UnsafePointer -import math - - -@always_inline -fn unsafe_take[T: Movable](mut arg: T, out result: T): - """ - Takes a value and moves it to a different location in memory. - - [!Caution] - This function leaves the original value in an invalid state. - The value passed to this function should not be used after the call! - Also, you need to prevent calling the destructors of the elements. - You may use `__disable_del` for that. - - Parameters: - T: The type of the value to be moved. - - Args: - arg: The value to be moved. - - Returns: - Result: The moved value. - """ - result = UnsafePointer.take_pointee(UnsafePointer(to=arg)) +from math import log2 # Implementing a function generically over all integral types is not currently possible in Mojo. @@ -40,7 +17,7 @@ fn next_pow2(var value: UInt) -> UInt: Returns: The next power of two greater than or equal to the given value. """ - return UInt(next_pow2[DType.index](value)) + return UInt(next_pow2[DType.uint](value)) @always_inline @@ -100,7 +77,7 @@ fn next_pow2[dtype: DType](var value: Scalar[dtype]) -> Scalar[dtype]: return 1 @parameter - for i in range(Scalar[dtype](math.log2(Float32(dtype.bitwidth())))): + for i in range(Scalar[dtype](log2(Float32(dtype.bit_width())))): value |= value >> (2**i) return value + 1 diff --git a/src/larecs/archetype.mojo b/src/larecs/archetype.mojo index e7b91837..3ba233dd 100644 --- a/src/larecs/archetype.mojo +++ b/src/larecs/archetype.mojo @@ -88,7 +88,7 @@ struct EntityAccessor[ *Ts: ComponentType ]( mut self: EntityAccessor[archetype_mutability=True], - owned *components: *Ts, + var *components: *Ts, ) raises: """ Overwrites components for an [..entity.Entity], using the given content. @@ -106,9 +106,9 @@ struct EntityAccessor[ @parameter for i in range(components.__len__()): - self._archetype[].get_component[T = Ts[i.value]]( + self._archetype[].get_component[T = Ts[i]]( self._index_in_archetype - ) = components[i] + ) = components[i].copy() @always_inline fn has[T: ComponentType](self) -> Bool: @@ -127,7 +127,7 @@ struct EntityAccessor[ struct Archetype[ *Ts: ComponentType, component_manager: ComponentManager[*Ts], -](Boolable, Copyable, ExplicitlyCopyable, Movable, Sized): +](Boolable, Copyable, Movable, Sized): """ Archetype represents an ECS archetype. @@ -154,9 +154,7 @@ struct Archetype[ """The type of the entity accessors generated by the archetype.""" # Pointers to the component data. - var _data: InlineArray[ - UnsafePointer[UInt8], Self.max_size, run_destructors=True - ] + var _data: InlineArray[UnsafePointer[UInt8], Self.max_size] # Current number of entities. var _size: UInt @@ -168,7 +166,7 @@ struct Archetype[ var _component_count: UInt # Sizes of the component types by column - var _item_sizes: InlineArray[UInt32, Self.max_size, run_destructors=True] + var _item_sizes: InlineArray[UInt32, Self.max_size] # The indices of the present components var _ids: SIMD[Self.dType, Self.max_size] @@ -226,12 +224,10 @@ struct Archetype[ self._component_count = 0 self._capacity = capacity self._ids = SIMD[Self.dType, Self.max_size]() - self._data = InlineArray[ - UnsafePointer[UInt8], Self.max_size, run_destructors=True - ](fill=UnsafePointer[UInt8]()) - self._item_sizes = InlineArray[ - UInt32, Self.max_size, run_destructors=True - ](fill=0) + self._data = InlineArray[UnsafePointer[UInt8], Self.max_size]( + fill=UnsafePointer[UInt8]() + ) + self._item_sizes = InlineArray[UInt32, Self.max_size](fill=0) self._entities = List[Entity]() self._node_index = node_index @@ -346,14 +342,6 @@ struct Archetype[ ) self._component_count += 1 - fn copy(self, out other: Self): - """Returns a copy of the archetype. - - Returns: - A copy of the current archetype. - """ - other = self - fn __copyinit__(out self, existing: Self): """Copies the data from an existing archetype to a new one. @@ -366,15 +354,15 @@ struct Archetype[ self._capacity = existing._capacity self._component_count = existing._component_count self._item_sizes = existing._item_sizes - self._entities = existing._entities + self._entities = existing._entities.copy() self._ids = existing._ids self._node_index = existing._node_index self._mask = existing._mask # Copy the data - self._data = InlineArray[ - UnsafePointer[UInt8], Self.max_size, run_destructors=True - ](fill=UnsafePointer[UInt8]()) + self._data = InlineArray[UnsafePointer[UInt8], Self.max_size]( + fill=UnsafePointer[UInt8]() + ) for i in range(existing._component_count): id = existing._ids[i] @@ -386,7 +374,7 @@ struct Archetype[ size, ) - fn __del__(owned self): + fn __del__(deinit self): """Frees the memory of the archetype.""" for i in range(self._component_count): self._data[self._ids[i]].free() @@ -703,7 +691,7 @@ struct Archetype[ self._size -= 1 - var swapped = Int(idx) != self._size + var swapped = index(idx) != self._size if swapped: self._entities[idx] = self._entities.pop() diff --git a/src/larecs/bitmask.mojo b/src/larecs/bitmask.mojo index e194efde..4c9c613b 100644 --- a/src/larecs/bitmask.mojo +++ b/src/larecs/bitmask.mojo @@ -4,7 +4,7 @@ from hashlib import Hasher @fieldwise_init -struct _BitMaskIndexIter(Copyable, ExplicitlyCopyable, Movable, Sized): +struct _BitMaskIndexIter(ImplicitlyCopyable, Movable, Sized): """Iterator for BitMask indices.""" alias DataContainerType = SIMD[DType.uint8, BitMask.total_bytes] @@ -17,7 +17,7 @@ struct _BitMaskIndexIter(Copyable, ExplicitlyCopyable, Movable, Sized): var _index: UInt8 var _size: Int - fn __init__(out self, owned bytes: Self.DataContainerType): + fn __init__(out self, var bytes: Self.DataContainerType): self._bytes = bytes self._mask = Self.DataContainerType(1) self._compare = self._bytes & self._mask @@ -54,7 +54,9 @@ struct _BitMaskIndexIter(Copyable, ExplicitlyCopyable, Movable, Sized): @register_passable -struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): +struct BitMask( + EqualityComparable, ImplicitlyCopyable, KeyElement, Movable, Stringable +): """BitMask is a 256 bit bitmask.""" alias IndexDType = DType.uint8 @@ -69,6 +71,13 @@ struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): """Initializes the mask with the given bytes.""" self._bytes = bytes + @always_inline + fn __init__(out self, *bits: Self.IndexType): + """Initializes the mask with the bits at the given indices set to True. + """ + self = Self(bits) + + @implicit @always_inline fn __init__(out self, bits: VariadicList[BitMask.IndexType]): """Initializes the mask with the bits at the given indices set to True. @@ -77,6 +86,7 @@ struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): for bit in bits: self.set[True](bit) + @implicit @always_inline fn __init__[ size: Int @@ -89,12 +99,6 @@ struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): for i in range(size): self.set[True](bits[i]) - @always_inline - fn __init__(out self, *bits: Self.IndexType): - """Initializes the mask with the bits at the given indices set to True. - """ - self = Self(bits) - @always_inline fn __hash__[H: Hasher](self, mut hasher: H): """Hashes the mask.""" @@ -103,38 +107,13 @@ struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): @always_inline fn __eq__(self, other: Self) -> Bool: """Compares two masks for equality.""" - return (self._bytes == other._bytes).reduce_and() + return self._bytes == other._bytes @always_inline fn __ne__(self, other: Self) -> Bool: """Compares two masks for inequality.""" return not self.__eq__(other) - @always_inline - fn matches(self, bits: Self) -> Bool: - """Matches the mask as filter against another mask.""" - return bits.contains(self) - - @always_inline - fn without(self, *comps: Self.IndexType) -> MaskFilter: - """Creates a [..filter.MaskFilter] which filters for including the mask's components - and excludes the components given as arguments. - """ - return MaskFilter( - include=self, - exclude=BitMask(comps), - ) - - @always_inline - fn exclusive(self) -> MaskFilter: - """Creates a [..filter.MaskFilter] which filters for exactly the mask's components. - matches only entities that have exactly the given components, and no other. - """ - return MaskFilter( - include=self, - exclude=self.invert(), - ) - @always_inline fn get(self, bit: Self.IndexType) -> Bool: """Reports whether the bit at the given index is set. @@ -167,17 +146,111 @@ struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): self._bytes[index(idx)] &= ~(1 << offset) @always_inline - fn flip(mut self, bit: Self.IndexType): + fn set(self, *comps: Self.IndexType, value: Bool) -> Self: + """Returns a BitMask where the bits given as indices are set/unset according to the given value. + + Arguments: + comps: The InlineArray containing the bits to be set or unset. + value: If True, the bits in `comps` will be set in the resulting BitMask; if False, they will be unset. + + Returns: + A new BitMask with the specified bits set or unset. + """ + return self.set(BitMask(comps), value) + + @always_inline + fn set[value: Bool](self, *comps: Self.IndexType) -> Self: + """Returns a BitMask where the bits given as indices are set/unset according to the given value. + + Parameters: + value: If True, the bits in `comps` will be set in the resulting BitMask; if False, they will be unset. + + Arguments: + comps: The InlineArray containing the bits to be set or unset. + + Returns: + A new BitMask with the specified bits set or unset. + """ + return self.set[value](BitMask(comps)) + + @always_inline + fn set(self, comps: InlineArray[Self.IndexType], value: Bool) -> Self: + """Returns a BitMask where the bits set in the given InlineArray are set/unset according to the given value. + + Arguments: + comps: The InlineArray containing the bits to be set or unset. + value: If True, the bits in `comps` will be set in the resulting BitMask; if False, they will be unset. + + Returns: + A new BitMask with the specified bits set or unset. + """ + return self.set(BitMask(comps), value) + + @always_inline + fn set[value: Bool](self, comps: InlineArray[Self.IndexType]) -> Self: + """Returns a BitMask where the bits set in the given InlineArray are set/unset according to the given value. + + Parameters: + value: If True, the bits in `comps` will be set in the resulting BitMask; if False, they will be unset. + + Arguments: + comps: The InlineArray containing the bits to be set or unset. + + Returns: + A new BitMask with the specified bits set or unset. + """ + return self.set[value](BitMask(comps)) + + @always_inline + fn set(self, other: BitMask, value: Bool) -> Self: + """Returns a BitMask where the bits set in the other BitMask are set/unset according to the given value. + + Arguments: + other: The BitMask containing the bits to be set or unset. + value: If True, the bits in 'other' will be set in the resulting BitMask; if False, they will be unset. + + Returns: + A new BitMask with the specified bits set or unset. + """ + if value: + return self.set[True](other) + else: + return self.set[False](other) + + @always_inline + fn set[value: Bool](self, other: BitMask) -> Self: + """Returns a BitMask where the bits set in the other BitMask are set/unset according to the given value. + + Parameters: + value: If True, the bits in 'other' will be set in the resulting BitMask; if False, they will be unset. + + Arguments: + other: The BitMask containing the bits to be set or unset. + + Returns: + A new BitMask with the specified bits set or unset. + """ + + @parameter + if value: + return self | other + else: + return self & ~other + + @always_inline + fn flip(self, bit: Self.IndexType) -> Self: + """Flips the state of bit at the given index.""" + copy = self.copy() + copy.flip_mut(bit) + return copy + + @always_inline + fn flip_mut(mut self, bit: Self.IndexType): """Flips the state of bit at the given index.""" var idx: Self.IndexType = bit >> 3 # equivalent to bit // 8 var offset: Self.IndexType = bit & 7 # equivalent to bit - (8 * idx) self._bytes[index(idx)] ^= 1 << offset - @always_inline - fn invert(self) -> BitMask: - """Returns the inversion of this mask.""" - return BitMask(bytes=bit_not(self._bytes)) - @always_inline fn is_zero(self) -> Bool: """Returns whether no bits are set in the mask.""" @@ -191,12 +264,12 @@ struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): @always_inline fn contains(self, other: Self) -> Bool: """Reports if the other mask is a subset of this mask.""" - return ((self._bytes & other._bytes) == other._bytes).reduce_and() + return (self._bytes & other._bytes) == other._bytes @always_inline fn contains_any(self, other: Self) -> Bool: """Reports if any bit of the other mask is in this mask.""" - return ((self._bytes & other._bytes) != 0).reduce_or() + return (self._bytes & other._bytes) != 0 @always_inline fn total_bits_set(self) -> Int: @@ -208,6 +281,11 @@ struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): """Returns the indices of the bits that are set.""" result = _BitMaskIndexIter(self._bytes) + @always_inline + fn __invert__(self) -> BitMask: + """Returns the inversion of this mask.""" + return BitMask(bytes=~self._bytes) + @always_inline fn __or__(self, other: Self) -> BitMask: """Returns the bitwise OR of this mask and another mask. @@ -225,7 +303,102 @@ struct BitMask(Copyable, EqualityComparable, KeyElement, Movable, Stringable): This operation is highly optimized using SIMD instructions for fast parallel bitwise operations across all 256 bits simultaneously. """ - return BitMask(bytes=self._bytes | other._bytes) + copy = self.copy() + copy |= other + return copy + + @always_inline + fn __ior__(mut self, other: Self): + """Performs in-place bitwise OR with another BitMask. + + This method modifies the current BitMask by performing a bitwise OR operation + with another BitMask. Each bit in the resulting BitMask is set if it is set + in either the current BitMask or the provided BitMask. + + Args: + other: The BitMask to perform the bitwise OR operation with. + + **Performance Note:** + This operation is optimized using SIMD instructions for efficient parallel + processing of all 256 bits. + """ + self._bytes |= other._bytes + + @always_inline + fn __and__(self, other: Self) -> BitMask: + """Returns the bitwise AND of this mask and another mask. + + Performs element-wise bitwise AND operation between this mask and another mask, + creating a new mask where a bit is set if it's set in both operands. + + Args: + other: The other BitMask to AND with this mask. + + Returns: + A new BitMask containing the bitwise AND of both masks. + + **Performance Note:** + This operation is highly optimized using SIMD instructions for fast parallel + bitwise operations across all 256 bits simultaneously. + """ + copy = self.copy() + copy &= other + return copy + + @always_inline + fn __iand__(mut self, other: Self): + """Performs in-place bitwise AND with another BitMask. + + This method modifies the current BitMask by performing a bitwise AND operation + with another BitMask. Each bit in the resulting BitMask is set if it is set + in both the current BitMask and the provided BitMask. + + Args: + other: The BitMask to perform the bitwise AND operation with. + + **Performance Note:** + This operation is optimized using SIMD instructions for efficient parallel + processing of all 256 bits. + """ + self._bytes &= other._bytes + + @always_inline + fn __xor__(self, other: Self) -> BitMask: + """Returns the bitwise XOR of this mask and another mask. + + Performs element-wise bitwise XOR operation between this mask and another mask, + creating a new mask where a bit is set if it's set in one operand but not both. + + Args: + other: The other BitMask to XOR with this mask. + + Returns: + A new BitMask containing the bitwise XOR of both masks. + + **Performance Note:** + This operation is highly optimized using SIMD instructions for fast parallel + bitwise operations across all 256 bits simultaneously. + """ + copy = self.copy() + copy ^= other + return copy + + @always_inline + fn __ixor__(mut self, other: Self): + """Performs in-place bitwise XOR with another BitMask. + + This method modifies the current BitMask by performing a bitwise XOR operation + with another BitMask. Each bit in the resulting BitMask is set if it is set + in one operand but not both. + + Args: + other: The BitMask to perform the bitwise XOR operation with. + + **Performance Note:** + This operation is optimized using SIMD instructions for efficient parallel + processing of all 256 bits. + """ + self._bytes ^= other._bytes fn __str__(self) -> String: """Implements String(...).""" diff --git a/src/larecs/component.mojo b/src/larecs/component.mojo index 636b4ed0..1b7b4b3d 100644 --- a/src/larecs/component.mojo +++ b/src/larecs/component.mojo @@ -1,4 +1,4 @@ -from sys.info import sizeof +from sys import size_of from sys.intrinsics import _type_is_eq # from collections import Dict @@ -39,7 +39,7 @@ fn get_sizes[ @parameter for i in range(len(VariadicList(Ts))): - sizes[i] = sizeof[Ts[i]]() + sizes[i] = size_of[Ts[i]]() return sizes @@ -190,7 +190,7 @@ struct ComponentManager[ Returns: The size of the component type. """ - return sizeof[ComponentTypes[i]]() + return size_of[ComponentTypes[i]]() @staticmethod @always_inline diff --git a/src/larecs/entity.mojo b/src/larecs/entity.mojo index 8ae7afd9..1e64d513 100644 --- a/src/larecs/entity.mojo +++ b/src/larecs/entity.mojo @@ -15,7 +15,14 @@ from .archetype import Archetype, EntityAccessor @register_passable("trivial") -struct Entity(Boolable, EqualityComparable, Hashable, KeyElement, Stringable): +struct Entity( + Boolable, + EqualityComparable, + Hashable, + ImplicitlyCopyable, + KeyElement, + Stringable, +): """Entity identifier. Holds an entity ID and it's generation for recycling. @@ -108,7 +115,7 @@ struct Entity(Boolable, EqualityComparable, Hashable, KeyElement, Stringable): @fieldwise_init @register_passable("trivial") -struct EntityIndex: +struct EntityIndex(ImplicitlyCopyable, Movable): """Indicates where an entity is currently stored.""" # Entity's current index in the archetype diff --git a/src/larecs/filter.mojo b/src/larecs/filter.mojo index edd31b1f..c60c068e 100644 --- a/src/larecs/filter.mojo +++ b/src/larecs/filter.mojo @@ -1,4 +1,5 @@ from .bitmask import BitMask +from .static_optional import StaticOptional # Filter is the interface for logic filters. @@ -13,21 +14,75 @@ from .bitmask import BitMask # matches(bits BitMask): Bool -@fieldwise_init -struct MaskFilter: +struct MaskFilter[has_exclude: Bool = False](ImplicitlyCopyable, Movable): """MaskFilter is a filter for including and excluding certain components. - See [..bitmask.BitMask.without] and [..bitmask.BitMask.exclusive]. + This struct can be constructed implicitly from a [.Query] instance. + Therefore, [.Query] instances can be used instead of MaskFilter in function + arguments. + + Parameters: + has_exclude: If True, the filter excludes components given in the exclude mask. """ var include: BitMask # Components to include. - var exclude: BitMask # Components to exclude. + var exclude: StaticOptional[BitMask, has_exclude] # Components to exclude. + + fn __init__( + out self, + include: BitMask, + exclude: StaticOptional[BitMask, has_exclude] = None, + ): + self.include = include + self.exclude = exclude.copy() + + @implicit + fn __init__( + out self, + query: Query[has_exclude=has_exclude], + ): + """ + Takes the filter from an existing query. + + Args: + query: The query the filter information should be taken from. + """ + self = query._mask_filter + + fn __copyinit__(out self, other: MaskFilter[has_exclude]): + self.include = other.include + self.exclude = other.exclude.copy() fn matches(self, bits: BitMask) -> Bool: """Matches the filter against a mask.""" - return bits.contains(self.include) and ( - self.exclude.is_zero() or not bits.contains_any(self.exclude) - ) + + is_matching = bits.contains(self.include) + + @parameter + if has_exclude: + is_matching &= self.exclude[].is_zero() or not bits.contains_any( + self.exclude[] + ) + + return is_matching + + fn without(self, exclude: BitMask) -> MaskFilter[has_exclude=True]: + """Returns a new MaskFilter that excludes the given components in addition to the existing ones. + """ + + new_exclude = exclude.copy() + + @parameter + if has_exclude: + new_exclude |= self.exclude[] + + return MaskFilter[has_exclude=True](self.include, new_exclude) + + fn exclusive(self) -> MaskFilter[has_exclude=True]: + """Returns a new MaskFilter that includes only the currently set components and excludes all others. + """ + + return MaskFilter[has_exclude=True](self.include, ~self.include) # # RelationFilter is a [Filter] for a [Relation] target, in addition to components. diff --git a/src/larecs/graph.mojo b/src/larecs/graph.mojo index 1e1f460d..6fc70800 100644 --- a/src/larecs/graph.mojo +++ b/src/larecs/graph.mojo @@ -3,7 +3,7 @@ from .bitmask import BitMask @fieldwise_init -struct Node[DataType: KeyElement](Copyable, ExplicitlyCopyable, Movable): +struct Node[DataType: KeyElement](ImplicitlyCopyable, Movable): """Node in a BitMaskGraph. Parameters: @@ -24,7 +24,7 @@ struct Node[DataType: KeyElement](Copyable, ExplicitlyCopyable, Movable): # The mask of the node. var bit_mask: BitMask - fn __init__(out self, bit_mask: BitMask, owned value: DataType): + fn __init__(out self, bit_mask: BitMask, var value: DataType): """Initializes the node with the given mask and value. Args: @@ -35,14 +35,13 @@ struct Node[DataType: KeyElement](Copyable, ExplicitlyCopyable, Movable): self.neighbours = InlineArray[Int, 256](fill=Self.null_index) self.bit_mask = bit_mask - fn copy(self, out other: Self): - other = Self(self.bit_mask, self.value) + fn __copyinit__(out self, other: Self): + self = Self(other.bit_mask, other.value.copy()) struct BitMaskGraph[ DataType: KeyElement, //, null_value: DataType, - hint_trivial_type: Bool = False, ](Copyable, Movable): """A graph where each node is identified by a BitMask. @@ -56,30 +55,26 @@ struct BitMaskGraph[ Parameters: DataType: The type of the value stored in the nodes. null_value: The place holder stored in nodes by default. - hint_trivial_type: Hint to the compiler whether the type - is trivially copyable. """ # The node index indicating a non-established link. alias null_index = Node[DataType].null_index # The list of nodes in the graph. - var _nodes: List[Node[DataType], hint_trivial_type=hint_trivial_type] + var _nodes: List[Node[DataType]] # A mapping for random lookup of nodes by their mask. # Used for slow lookup of nodes. var _map: Dict[BitMask, Int] - fn __init__(out self, owned first_value: DataType = Self.null_value): + fn __init__(out self, var first_value: DataType = Self.null_value): """Initializes the graph. Args: first_value: The value stored in the first node, corresponding to an empty bitmask. """ - self._nodes = List[ - Node[DataType], hint_trivial_type=hint_trivial_type - ]() + self._nodes = List[Node[DataType]]() self._map = Dict[BitMask, Int]() _ = self.add_node(BitMask(), first_value^) @@ -87,7 +82,7 @@ struct BitMaskGraph[ fn add_node( mut self, node_mask: BitMask, - owned value: DataType = Self.null_value, + var value: DataType = Self.null_value, ) -> Int: """Adds a node to the graph. @@ -120,8 +115,7 @@ struct BitMaskGraph[ Returns: The index of the node to which the link is created. """ - new_mask = self._nodes[from_node_index].bit_mask - new_mask.flip(changed_bit) + new_mask = self._nodes[from_node_index].bit_mask.flip(changed_bit) optional_to_index = self._map.get(new_mask) if optional_to_index: to_node_index = optional_to_index.value() @@ -205,4 +199,4 @@ struct BitMaskGraph[ Args: node_index: The index of the node. """ - return self[node_index] != Self.null_value + return self[node_index] != materialize[Self.null_value]() diff --git a/src/larecs/lock.mojo b/src/larecs/lock.mojo index ef614b78..4421b1c0 100644 --- a/src/larecs/lock.mojo +++ b/src/larecs/lock.mojo @@ -3,7 +3,7 @@ from .pool import BitPool @fieldwise_init -struct LockMask(Copyable, ExplicitlyCopyable, Movable): +struct LockManager(Copyable, Movable): """ Manages locks by mask bits. @@ -71,26 +71,24 @@ struct LockMask(Copyable, ExplicitlyCopyable, Movable): @fieldwise_init -struct LockedContext[origin: MutableOrigin]( - Copyable, ExplicitlyCopyable, Movable -): +struct LockedContext[origin: MutableOrigin](ImplicitlyCopyable, Movable): """ A context manager for locking and unlocking the world. Parameters: - origin: The origin of the LockMask to handle. + origin: The origin of the LockManager to handle. """ - var _locks: Pointer[LockMask, origin] + var _locks: Pointer[LockManager, origin] var _lock: UInt8 @always_inline - fn __init__(out self, locks: Pointer[LockMask, origin]): + fn __init__(out self, locks: Pointer[LockManager, origin]): """ Initializes the LockedContext. Args: - locks: The LockMask to handle. + locks: The LockManager to handle. """ self._locks = locks self._lock = 0 diff --git a/src/larecs/pool.mojo b/src/larecs/pool.mojo index e7bcb1b4..5f4bbcb5 100644 --- a/src/larecs/pool.mojo +++ b/src/larecs/pool.mojo @@ -165,14 +165,14 @@ struct IntPool: Implements https:#skypjack.github.io/2019-05-06-ecs-baf-part-3/ """ - var _pool: List[Int, True] + var _pool: List[Int] var _next: Int var _available: UInt32 @always_inline fn __init__(out self): """Creates a new, initialized entity pool.""" - self._pool = List[Int, True]() + self._pool = List[Int]() self._next = 0 self._available = 0 diff --git a/src/larecs/query.mojo b/src/larecs/query.mojo index 2c96df95..beb3bab5 100644 --- a/src/larecs/query.mojo +++ b/src/larecs/query.mojo @@ -3,8 +3,9 @@ from .bitmask import BitMask from .component import ComponentType, ComponentManager from .archetype import Archetype as _Archetype from .world import World -from .lock import LockMask +from .lock import LockManager from .debug_utils import debug_warn +from .filter import MaskFilter from .static_optional import StaticOptional from .static_variant import StaticVariant @@ -12,8 +13,8 @@ from .static_variant import StaticVariant struct Query[ world_origin: MutableOrigin, *ComponentTypes: ComponentType, - has_without_mask: Bool = False, -](Copyable, ExplicitlyCopyable, Movable, SizedRaising): + has_exclude: Bool = False, +](ImplicitlyCopyable, Movable, SizedRaising): """Query builder for entities with and without specific components. This type should not be used directly, but through the [..world.World.query] method: @@ -35,7 +36,7 @@ struct Query[ Parameters: world_origin: The origin of the world. ComponentTypes: The types of the components to include in the query. - has_without_mask: Whether the query has excluded components. + has_exclude: Whether the query has excluded components. """ alias World = World[*ComponentTypes] @@ -43,19 +44,18 @@ struct Query[ alias QueryWithWithout = Query[ world_origin, *ComponentTypes, - has_without_mask=True, + has_exclude=True, ] + alias MaskFilter = MaskFilter[has_exclude=has_exclude] var _world: Pointer[Self.World, world_origin] - var _mask: BitMask - var _without_mask: StaticOptional[BitMask, has_without_mask] + var _mask_filter: Self.MaskFilter @doc_private fn __init__( out self, world: Pointer[Self.World, world_origin], - owned mask: BitMask, - owned without_mask: StaticOptional[BitMask, has_without_mask] = None, + mask_filter: Self.MaskFilter, ): """ Creates a new query. @@ -64,12 +64,20 @@ struct Query[ Args: world: A pointer to the world. - mask: The mask of the components to iterate over. - without_mask: The mask for components to exclude. + mask_filter: The mask filter to use. """ self._world = world - self._mask = mask^ - self._without_mask = without_mask^ + self._mask_filter = mask_filter + + fn __copyinit__(out self, other: Self): + """ + Copy constructor. + + Args: + other: The query to copy. + """ + self._world = other._world + self._mask_filter = other._mask_filter fn __len__(self) raises -> Int: """ @@ -89,7 +97,7 @@ struct Query[ __origin_of(self._world[]._locks), arch_iter_variant_idx=_ArchetypeByMaskIteratorIdx, has_start_indices=False, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ], ) raises: """ @@ -101,14 +109,10 @@ struct Query[ Raises: Error: If the lock cannot be acquired (more than 256 locks exist). """ - iterator = self._world[]._get_entity_iterator( - self._mask, self._without_mask - ) + iterator = self._world[]._get_entity_iterator(self._mask_filter) @always_inline - fn without[ - *Ts: ComponentType - ](owned self, out query: Self.QueryWithWithout): + fn without[*Ts: ComponentType](var self, out query: Self.QueryWithWithout): """ Excludes the given components from the query. @@ -134,12 +138,13 @@ struct Query[ """ query = Self.QueryWithWithout( self._world, - self._mask, - BitMask(Self.World.component_manager.get_id_arr[*Ts]()), + self._mask_filter.without( + Self.World.component_manager.get_id_arr[*Ts]() + ), ) @always_inline - fn exclusive(owned self, out query: Self.QueryWithWithout): + fn exclusive(var self, out query: Self.QueryWithWithout): """ Makes the query only match entities with exactly the query's components. @@ -162,69 +167,17 @@ struct Query[ """ query = Self.QueryWithWithout( self._world, - self._mask, - self._mask.invert(), + self._mask_filter.exclusive(), ) -@fieldwise_init -struct QueryInfo[ - has_without_mask: Bool = False, -](Copyable, ExplicitlyCopyable, Movable): - """ - Class that holds the same information as a query but no reference to the world. - - This struct can be constructed implicitly from a [.Query] instance. - Therefore, [.Query] instances can be used instead of QueryInfo in function - arguments. - - Parameters: - has_without_mask: Whether the query has excluded components. - """ - - var mask: BitMask - var without_mask: StaticOptional[BitMask, has_without_mask] - - @implicit - fn __init__( - out self, - query: Query[has_without_mask=has_without_mask], - ): - """ - Takes the query info from an existing query. - - Args: - query: The query the information should be taken from. - """ - self.mask = query._mask - self.without_mask = query._without_mask - - fn matches(self, archetype_mask: BitMask) -> Bool: - """ - Checks whether the given archetype mask matches the query. - - Args: - archetype_mask: The mask of the archetype to check. - - Returns: - Whether the archetype matches the query. - """ - is_valid = archetype_mask.contains(self.mask) - - @parameter - if has_without_mask: - is_valid &= not archetype_mask.contains_any(self.without_mask[]) - - return is_valid - - struct _ArchetypeByMaskIterator[ archetype_mutability: Bool, //, archetype_origin: Origin[archetype_mutability], *ComponentTypes: ComponentType, component_manager: ComponentManager[*ComponentTypes], - has_without_mask: Bool = False, -](Boolable, Copyable, ExplicitlyCopyable, Iterator, Movable, Sized): + has_exclude: Bool = False, +](Boolable, Copyable, Iterator, Movable, Sized): """ Iterator over non-empty archetypes corresponding to given include and exclude masks. @@ -235,7 +188,7 @@ struct _ArchetypeByMaskIterator[ archetype_origin: The origin of the archetypes. ComponentTypes: The types of the components. component_manager: The component manager. - has_without_mask: Whether the iterator has excluded components. + has_exclude: Whether the iterator has excluded components. """ alias buffer_size = 8 @@ -243,11 +196,10 @@ struct _ArchetypeByMaskIterator[ *ComponentTypes, component_manager=component_manager ] alias Element = Pointer[Self.Archetype, archetype_origin] - alias QueryInfo = QueryInfo[has_without_mask=has_without_mask] + alias MaskFilter = MaskFilter[has_exclude=has_exclude] var _archetypes: Pointer[List[Self.Archetype], archetype_origin] var _archetype_index_buffer: SIMD[DType.int32, Self.buffer_size] - var _mask: BitMask - var _without_mask: StaticOptional[BitMask, has_without_mask] + var _mask_filter: Self.MaskFilter var _archetype_count: Int var _buffer_index: Int var _max_buffer_index: Int @@ -255,22 +207,19 @@ struct _ArchetypeByMaskIterator[ fn __init__( out self, archetypes: Pointer[List[Self.Archetype], archetype_origin], - owned mask: BitMask, - owned without_mask: StaticOptional[BitMask, has_without_mask] = None, + mask_filter: Self.MaskFilter, ): """ Creates an archetype by mask iterator. Args: archetypes: a pointer to the world's archetypes. - mask: The mask of the archetypes to iterate over. - without_mask: An optional mask for archetypes to exclude. + mask_filter: The mask filter to use. """ self._archetypes = archetypes self._archetype_count = len(self._archetypes[]) - self._mask = mask^ - self._without_mask = without_mask^ + self._mask_filter = mask_filter self._buffer_index = 0 self._max_buffer_index = Self.buffer_size @@ -289,8 +238,7 @@ struct _ArchetypeByMaskIterator[ out self, archetypes: Pointer[List[Self.Archetype], archetype_origin], archetype_index_buffer: SIMD[DType.int32, Self.buffer_size], - owned mask: BitMask, - owned without_mask: StaticOptional[BitMask, has_without_mask], + mask_filter: Self.MaskFilter, archetype_count: Int, buffer_index: Int, max_buffer_index: Int, @@ -301,16 +249,14 @@ struct _ArchetypeByMaskIterator[ Args: archetypes: A pointer to the world's archetypes. archetype_index_buffer: The buffer of valid archetypes indices. - mask: The mask of the archetypes to iterate over. - without_mask: An optional mask for archetypes to exclude. + mask_filter: The mask filter to use. archetype_count: The number of archetypes in the world. buffer_index: Current index in the archetype buffer. max_buffer_index: Maximal valid index in the archetype buffer. """ self._archetypes = archetypes self._archetype_index_buffer = archetype_index_buffer - self._mask = mask^ - self._without_mask = without_mask^ + self._mask_filter = mask_filter self._archetype_count = archetype_count self._buffer_index = buffer_index self._max_buffer_index = max_buffer_index @@ -322,17 +268,15 @@ struct _ArchetypeByMaskIterator[ Fills the _archetype_index_buffer with the archetypes' indices. """ - query_info = Self.QueryInfo( - mask=self._mask, - without_mask=self._without_mask, - ) buffer_index = 0 for i in range( self._archetype_index_buffer[self._buffer_index] + 1, self._archetype_count, ): - is_valid = self._archetypes[].unsafe_get(i) and query_info.matches( + is_valid = self._archetypes[].unsafe_get( + i + ) and self._mask_filter.matches( self._archetypes[].unsafe_get(i).get_mask() ) @@ -347,7 +291,7 @@ struct _ArchetypeByMaskIterator[ self._max_buffer_index = buffer_index - 1 @always_inline - fn __iter__(owned self, out iterator: Self): + fn __iter__(var self, out iterator: Self): """ Returns self as an iterator usable in for loops. @@ -387,17 +331,15 @@ struct _ArchetypeByMaskIterator[ size = Self.buffer_size - query_info = Self.QueryInfo( - mask=self._mask, - without_mask=self._without_mask, - ) # If there are more archetypes than the buffer size, we # need to iterate over the remaining archetypes. for i in range( self._archetype_index_buffer[Self.buffer_size - 1] + 1, len(self._archetypes[]), ): - is_valid = self._archetypes[].unsafe_get(i) and query_info.matches( + is_valid = self._archetypes[].unsafe_get( + i + ) and self._mask_filter.matches( self._archetypes[].unsafe_get(i).get_mask() ) @@ -431,7 +373,7 @@ struct _ArchetypeByListIterator[ archetype_origin: Origin[archetype_mutability], *ComponentTypes: ComponentType, component_manager: ComponentManager[*ComponentTypes], -](Boolable, Copyable, ExplicitlyCopyable, Iterator, Movable, Sized): +](Boolable, Copyable, Iterator, Movable, Sized): """ Iterator over non-empty archetypes corresponding to given list of Archetype IDs. @@ -450,13 +392,13 @@ struct _ArchetypeByListIterator[ ] alias Element = Pointer[Self.Archetype, archetype_origin] var _archetypes: Pointer[List[Self.Archetype], archetype_origin] - var _archetype_indices: List[Int, hint_trivial_type=True] + var _archetype_indices: List[Int] var _index: Int fn __init__( out self, archetypes: Pointer[List[Self.Archetype], archetype_origin], - archetype_indices: List[Int, hint_trivial_type=True], + var archetype_indices: List[Int], ): """ Creates an archetype by list iterator. @@ -467,11 +409,11 @@ struct _ArchetypeByListIterator[ """ self._archetypes = archetypes - self._archetype_indices = archetype_indices + self._archetype_indices = archetype_indices^ self._index = 0 @always_inline - fn __iter__(owned self, out iterator: Self): + fn __iter__(var self, out iterator: Self): """ Returns self as an iterator usable in for loops. @@ -529,14 +471,14 @@ alias _ArchetypeIterator[ *ComponentTypes: ComponentType, component_manager: ComponentManager[*ComponentTypes], arch_iter_variant_idx: Int, - has_without_mask: Bool = False, + has_exclude: Bool = False, ] = StaticVariant[ arch_iter_variant_idx, _ArchetypeByMaskIterator[ archetype_origin, *ComponentTypes, component_manager=component_manager, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ], _ArchetypeByListIterator[ archetype_origin, @@ -558,7 +500,7 @@ struct _EntityIterator[ component_manager: ComponentManager[*ComponentTypes], arch_iter_variant_idx: Int = _ArchetypeByMaskIteratorIdx, has_start_indices: Bool = False, - has_without_mask: Bool = False, + has_exclude: Bool = False, ](Boolable, Movable, Sized): """Iterator over all entities corresponding to a mask. @@ -567,36 +509,34 @@ struct _EntityIterator[ Parameters: archetype_mutability: Whether the reference to the archetypes is mutable. archetype_origin: The origin of the archetypes. - lock_origin: The origin of the LockMask. + lock_origin: The origin of the LockManager. ComponentTypes: The types of the components. component_manager: The component manager. arch_iter_variant_idx: The index of the variant that holds the archetype iterator. has_start_indices: Whether the iterator starts iterating the archetypes at given indices. - has_without_mask: Whether the iterator has excluded components. + has_exclude: Whether the iterator has excluded components. """ alias buffer_size = 8 alias Archetype = _Archetype[ *ComponentTypes, component_manager=component_manager ] - alias StartIndices = StaticOptional[ - List[UInt, hint_trivial_type=True], has_start_indices - ] + alias StartIndices = StaticOptional[List[UInt], has_start_indices] alias ArchetypeIterator = _ArchetypeIterator[ archetype_origin, *ComponentTypes, component_manager=component_manager, arch_iter_variant_idx=arch_iter_variant_idx, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ] alias ArchetypeByMaskIterator = _ArchetypeByMaskIterator[ archetype_origin, *ComponentTypes, component_manager=component_manager, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ] alias ArchetypeByListIterator = _ArchetypeByListIterator[ @@ -606,7 +546,7 @@ struct _EntityIterator[ ] var _current_archetype: Pointer[Self.Archetype, archetype_origin] - var _lock_ptr: Pointer[LockMask, lock_origin] + var _lock_ptr: Pointer[LockManager, lock_origin] var _lock: UInt8 var _entity_index: Int var _last_entity_index: Int @@ -617,9 +557,9 @@ struct _EntityIterator[ fn __init__( out self, - lock_ptr: Pointer[LockMask, lock_origin], - owned archetype_iterator: Self.ArchetypeIterator, - owned start_indices: Self.StartIndices = None, + lock_ptr: Pointer[LockManager, lock_origin], + var archetype_iterator: Self.ArchetypeIterator, + var start_indices: Self.StartIndices = None, ) raises: """ Creates an entity iterator with or without excluded components. @@ -688,7 +628,7 @@ struct _EntityIterator[ # first call to __next__ will increment it. self._entity_index -= 1 - fn __del__(owned self): + fn __del__(deinit self): """ Releases the lock. """ @@ -698,7 +638,7 @@ struct _EntityIterator[ debug_warn("Failed to unlock the lock. This should not happen.") @always_inline - fn __iter__(owned self, out iterator: Self): + fn __iter__(var self, out iterator: Self): """ Returns self as an iterator usable in for loops. diff --git a/src/larecs/resource.mojo b/src/larecs/resource.mojo index bf7275de..29d0f3a6 100644 --- a/src/larecs/resource.mojo +++ b/src/larecs/resource.mojo @@ -8,14 +8,13 @@ from .component import ( contains_type, ) from .unsafe_box import UnsafeBox -from ._utils import unsafe_take alias ResourceType = Copyable & Movable """The trait that resources must conform to.""" @fieldwise_init -struct Resources(ExplicitlyCopyable, Movable, Sized): +struct Resources(Copyable, Movable, Sized): """Manages resources.""" alias IdType = StringSlice[StaticConstantOrigin] @@ -39,15 +38,7 @@ struct Resources(ExplicitlyCopyable, Movable, Sized): """ return len(self._storage) - fn copy(self, out resources: Self): - """Creates a copy of the resources. - - Returns: - A copy of the resources. - """ - resources = Resources(self._storage.copy()) - - fn add[*Ts: ResourceType](mut self, owned *resources: *Ts) raises: + fn add[*Ts: ResourceType](mut self, var *resources: *Ts) raises: """Adds resources. Parameters: @@ -57,57 +48,42 @@ struct Resources(ExplicitlyCopyable, Movable, Sized): resources: The resources to add. Raises: - Error: If the resource already exists. + Error: If some resource already exists. """ - @parameter - for i in range(resources.__len__()): - self._add(get_type_name[Ts[i]](), unsafe_take(resources[i])) - __disable_del resources + conflicting_ids = List[StringSlice[StaticConstantOrigin]]() - @always_inline - fn _add[ - T: Copyable & Movable - ](mut self, id: Self.IdType, owned resource: Pointer[T]) raises: - """Adds a resource by ID. + @parameter + for idx in range(resources.__len__()): + alias id = get_type_name[Ts[idx]]() + if id in self._storage: + conflicting_ids.append(id) - Parameters: - T: The type of the resource to add. + if conflicting_ids: + raise Error("Duplicate resource: " + ", ".join(conflicting_ids)) - Args: - id: The ID of the resource to add. - resource: The resource to add. + @parameter + fn take_resource[idx: Int](var resource: Ts[idx]) -> None: + self._add(get_type_name[Ts[idx]](), resource^) - Raises: - Error: If the resource already exists. - """ - if id in self._storage: - raise Error("Resource already exists.") - self._storage[id] = UnsafeBox(resource[]) + resources^.consume_elements[take_resource]() @always_inline - fn _add[ - T: Copyable & Movable - ](mut self, id: Self.IdType, owned resource: T) raises: + fn _add[T: Copyable & Movable](mut self, id: Self.IdType, var resource: T): """Adds a resource by ID. Parameters: T: The type of the resource to add. Args: - id: The ID of the resource to add. + id: The ID of the resource to add. It has to be not used already. resource: The resource to add. - - Raises: - Error: If the resource already exists. """ - if id in self._storage: - raise Error("Resource already exists.") self._storage[id] = UnsafeBox(resource^) fn set[ *Ts: ResourceType, add_if_not_found: Bool = False - ](mut self: Resources, owned *resources: *Ts) raises: + ](mut self: Resources, var *resources: *Ts) raises: """Sets the values of resources. Parameters: @@ -122,17 +98,31 @@ struct Resources(ExplicitlyCopyable, Movable, Sized): """ @parameter - for i in range(resources.__len__()): + if not add_if_not_found: + conflicting_ids = List[StringSlice[StaticConstantOrigin]]() + + @parameter + for idx in range(resources.__len__()): + alias id = get_type_name[Ts[idx]]() + if id not in self._storage: + conflicting_ids.append(id) + + if len(conflicting_ids) > 0: + raise Error("Unknown resource: " + ", ".join(conflicting_ids)) + + @parameter + fn take_resource[idx: Int](var resource: Ts[idx]) -> None: self._set[add_if_not_found=add_if_not_found]( - get_type_name[Ts[i]](), - unsafe_take(resources[i]), + get_type_name[Ts[idx]](), + resource^, ) - __disable_del resources + + resources^.consume_elements[take_resource]() @always_inline fn _set[ - T: Copyable & Movable, add_if_not_found: Bool - ](mut self, id: Self.IdType, owned resource: T) raises: + T: Copyable & Movable, //, add_if_not_found: Bool + ](mut self, id: Self.IdType, var resource: T): """Sets the values of the resources Parameters: @@ -140,22 +130,17 @@ struct Resources(ExplicitlyCopyable, Movable, Sized): add_if_not_found: If true, adds resources that do not exist. Args: - id: The ID of the resource to set. + id: The ID of the resource to set. If add_if_not_found is false, the resource ID must be already known. resource: The resource to set. - - Raises: - Error: If one of the resources does not exist. """ try: - self._storage._find_ref(id).unsafe_get[T]() = resource^ + self._storage[id].unsafe_get[T]() = resource^ except: @parameter if add_if_not_found: self._add(id, resource^) - else: - raise Error("Resource " + String(id) + " not found.") fn remove[*Ts: ResourceType](mut self: Resources) raises: """Removes resources. @@ -190,7 +175,7 @@ struct Resources(ExplicitlyCopyable, Movable, Sized): fn get[ T: ResourceType ](mut self) raises -> ref [ - __origin_of(self._storage._find_ref("").unsafe_get[T]()) + __origin_of(self._storage[""].unsafe_get[T]()) ] T: """Gets a resource. @@ -200,7 +185,7 @@ struct Resources(ExplicitlyCopyable, Movable, Sized): Returns: A reference to the resource. """ - return self._storage._find_ref(get_type_name[T]()).unsafe_get[T]() + return self._storage[get_type_name[T]()].unsafe_get[T]() @always_inline fn has[T: ResourceType](mut self) -> Bool: diff --git a/src/larecs/scheduler.mojo b/src/larecs/scheduler.mojo index 79c74bbd..6060eae8 100644 --- a/src/larecs/scheduler.mojo +++ b/src/larecs/scheduler.mojo @@ -146,7 +146,7 @@ struct Scheduler[*ComponentTypes: ComponentType](Movable): ]() self.world = Self.World() - fn __init__(out self, owned world: Self.World): + fn __init__(out self, var world: Self.World): """ Initializes the scheduler with a given world. @@ -163,7 +163,7 @@ struct Scheduler[*ComponentTypes: ComponentType](Movable): ]() self.world = world^ - fn add_system[S: System](mut self, owned system: S): + fn add_system[S: System](mut self, var system: S): """Adds a system to the scheduler. Args: diff --git a/src/larecs/static_optional.mojo b/src/larecs/static_optional.mojo index 868e3059..66b27345 100644 --- a/src/larecs/static_optional.mojo +++ b/src/larecs/static_optional.mojo @@ -2,7 +2,7 @@ struct StaticOptional[ ElementType: Copyable & Movable, has_value: Bool = True, -](Boolable, Copyable, ExplicitlyCopyable, Movable): +](Boolable, Copyable, Movable): """An optional type that can potentially hold a value of ElementType. In contrast to the built-in optional, it is decided at @@ -15,7 +15,7 @@ struct StaticOptional[ """ # Fields - var _value: InlineArray[ElementType, Int(has_value), run_destructors=True] + var _value: InlineArray[ElementType, Int(has_value)] """The underlying storage for the optional.""" # ===------------------------------------------------------------------===# @@ -36,7 +36,7 @@ struct StaticOptional[ @always_inline @implicit - fn __init__(out self, owned value: Self.ElementType): + fn __init__(out self, var value: Self.ElementType): """Constructs an optional type holding the provided value. Args: diff --git a/src/larecs/static_variant.mojo b/src/larecs/static_variant.mojo index d89a84bd..61deeee1 100644 --- a/src/larecs/static_variant.mojo +++ b/src/larecs/static_variant.mojo @@ -3,7 +3,7 @@ from sys.intrinsics import _type_is_eq from .static_optional import StaticOptional -alias StaticVariantType = Movable +alias StaticVariantType = Movable & Copyable """ A trait that defines the requirements for types that can be used in StaticVariant. @@ -12,7 +12,9 @@ and ownership transfer semantics. """ -struct StaticVariant[variant_idx: Int, *Ts: StaticVariantType](Movable): +struct StaticVariant[variant_idx: Int, *Ts: StaticVariantType]( + Copyable, Movable +): """ A compile-time variant type that can hold exactly one of the provided types. @@ -93,7 +95,7 @@ struct StaticVariant[variant_idx: Int, *Ts: StaticVariantType](Movable): alias ElementType = Ts[variant_idx] var _data: Self.ElementType - fn __init__(out self, owned value: Self.ElementType) raises: + fn __init__(out self, var value: Self.ElementType) raises: """ Initializes the variant with a value of the specified type. @@ -183,28 +185,6 @@ struct StaticVariant[variant_idx: Int, *Ts: StaticVariantType](Movable): # BUG: Mojo crashes with these methods (see https://github.com/modular/modular/issues/5172). When fixed, we can use # these for better ergonomics when working with StaticVariant. # This may also be fixable when conditional conformance with `requires` is released. -# -# fn __copyinit__[ -# T: Copyable & StaticVariantType, // -# ](out self: Self[variant_idx, T], read other: Self[variant_idx, T]): -# """ -# Initializes the variant by copying the value from another variant. - -# Args: -# other: The variant to copy from. -# """ -# self._data = other._data - -# fn copy[ -# T: ExplicitlyCopyable & StaticVariantType, // -# ](read self: Self[variant_idx, T], out copy: Self[variant_idx, T]): -# """ -# Initializes the variant by copying the value from another variant. - -# Args: -# other: The variant to copy from. -# """ -# copy = Self(self._data.copy()) # fn __bool__[ # T: Boolable & StaticVariantType, // diff --git a/src/larecs/test_utils.mojo b/src/larecs/test_utils.mojo index b8b1e024..52b97a21 100644 --- a/src/larecs/test_utils.mojo +++ b/src/larecs/test_utils.mojo @@ -8,15 +8,27 @@ from .world import World from .resource import Resources +# TODO: Revisit the function parameters of `load` and `store` when crash report: https://github.com/modular/modular/issues/5361 is resolved. @always_inline fn load[ - dType: DType, //, simd_width: Int, stride: Int = 1 -](ref val: SIMD[dType, 1], out simd: SIMD[dType, simd_width]): + dType: DType, + is_mut: Bool, + origin: Origin[is_mut], + address_space: AddressSpace, //, + simd_width: Int, + stride: Int = 1, +]( + ref [origin, address_space]val: SIMD[dType, 1], + out simd: SIMD[dType, simd_width], +): """ Load multiple values from a SIMD. Parameters: dType: The data type of the SIMD. + is_mut: Whether the value is mutable. + origin: The origin of the value. + address_space: The address space of the value. simd_width: The number of values to load. stride: The stride between the values. @@ -28,13 +40,24 @@ fn load[ @always_inline fn store[ - dType: DType, //, simd_width: Int, stride: Int = 1 -](ref val: SIMD[dType, 1], simd: SIMD[dType, simd_width]): + dType: DType, + is_mut: Bool, + origin: Origin[is_mut], + address_space: AddressSpace, //, + simd_width: Int, + stride: Int = 1, +]( + ref [origin, address_space]val: SIMD[dType, 1], + simd: SIMD[dType, simd_width], +): """ Store the values of a SIMD into memory with a given start SIMD value. Parameters: dType: The data type of the SIMD. + is_mut: Whether the value is mutable. + origin: The origin of the value. + address_space: The address space of the value. simd_width: The number of values to load. stride: The stride between the values. @@ -101,30 +124,27 @@ fn assert_equal_lists[ assert_equal(a[i], b[i], msg) -alias ExplicitlyCopyableComponentType = ComponentType & ExplicitlyCopyable - - @fieldwise_init -struct Position(ExplicitlyCopyableComponentType): +struct Position(ComponentType & ImplicitlyCopyable): var x: Float64 var y: Float64 @fieldwise_init -struct Velocity(ExplicitlyCopyableComponentType): +struct Velocity(ComponentType & ImplicitlyCopyable): var dx: Float64 var dy: Float64 @fieldwise_init -struct LargerComponent(ExplicitlyCopyableComponentType): +struct LargerComponent(ComponentType & ImplicitlyCopyable): var x: Float64 var y: Float64 var z: Float64 @fieldwise_init -struct FlexibleComponent[i: Int](ExplicitlyCopyableComponentType): +struct FlexibleComponent[i: Int](ComponentType & ImplicitlyCopyable): var x: Float64 var y: Float32 @@ -412,7 +432,7 @@ struct MemTestStruct(Copyable, Movable): var move_counter: UnsafePointer[Int] var del_counter: UnsafePointer[Int] - fn __moveinit__(out self, owned other: Self): + fn __moveinit__(out self, deinit other: Self): self.move_counter = other.move_counter self.del_counter = other.del_counter self.copy_counter = other.copy_counter @@ -424,13 +444,13 @@ struct MemTestStruct(Copyable, Movable): self.copy_counter = other.copy_counter self.copy_counter[] += 1 - fn __del__(owned self): + fn __del__(deinit self): self.del_counter[] += 1 fn test_copy_move_del[ Container: Copyable & Movable, //, - container_factory: fn (owned val: MemTestStruct) -> Container, + container_factory: fn (var val: MemTestStruct) -> Container, ](*, init_moves: Int = 0, copy_moves: Int = 0, move_moves: Int = 0) raises: """Test the copy, move, and delete operations of a container. @@ -474,7 +494,7 @@ fn test_copy_move_del[ assert_equal(copy_counter, test_copy_counter) # Copy - container2 = container + container2 = container.copy() test_copy_counter += 1 test_move_counter += copy_moves assert_equal(del_counter, test_del_counter) diff --git a/src/larecs/unsafe_box.mojo b/src/larecs/unsafe_box.mojo index 963d3bb3..1652cdc5 100644 --- a/src/larecs/unsafe_box.mojo +++ b/src/larecs/unsafe_box.mojo @@ -1,5 +1,5 @@ from memory import UnsafePointer -from sys.info import sizeof +from sys import size_of fn _destructor[T: Copyable & Movable](box_storage: UnsafeBox.data_type): @@ -49,7 +49,7 @@ fn _copy_initializer[ """ @parameter - if sizeof[T]() == 0: + if size_of[T]() == 0: ptr = UnsafePointer[T]() else: ptr = UnsafePointer[T].alloc(1) @@ -115,7 +115,7 @@ struct UnsafeBox(Copyable, Movable): self._destructor = _dummy_destructor self._copy_initializer = _dummy_copy_initializer - fn __init__[T: Copyable & Movable](out self, owned data: T): + fn __init__[T: Copyable & Movable](out self, var data: T): """ Constructor for the UnsafeBox. @@ -127,7 +127,7 @@ struct UnsafeBox(Copyable, Movable): """ @parameter - if sizeof[T]() == 0: + if size_of[T]() == 0: ptr = UnsafePointer[T]() else: ptr = UnsafePointer[T].alloc(1) @@ -150,7 +150,7 @@ struct UnsafeBox(Copyable, Movable): self._copy_initializer = other._copy_initializer @always_inline - fn __del__(owned self): + fn __del__(deinit self): """ Destructor for the UnsafeBox. diff --git a/src/larecs/world.mojo b/src/larecs/world.mojo index f22c3915..e9c4aef9 100644 --- a/src/larecs/world.mojo +++ b/src/larecs/world.mojo @@ -1,6 +1,6 @@ -from memory import UnsafePointer, Span +from memory import UnsafePointer, Span, memcpy from algorithm import vectorize -from sys.info import sizeof +from sys import size_of from .pool import EntityPool from .entity import Entity, EntityIndex @@ -14,11 +14,11 @@ from .component import ( constrain_components_unique, ) from .bitmask import BitMask +from .filter import MaskFilter from .static_optional import StaticOptional from .static_variant import StaticVariant from .query import ( Query, - QueryInfo, _ArchetypeIterator, _EntityIterator, _ArchetypeByMaskIterator, @@ -26,7 +26,7 @@ from .query import ( _ArchetypeByMaskIteratorIdx, _ArchetypeByListIteratorIdx, ) -from .lock import LockMask, LockedContext +from .lock import LockManager, LockedContext from .resource import Resources @@ -101,10 +101,115 @@ struct Replacer[ self._remove_ids, ) + fn by[ + *AddTs: ComponentType, + has_exclude: Bool = False, + ]( + self, + filter: MaskFilter[has_exclude=has_exclude], + *components: *AddTs, + out iterator: World[*component_types].Iterator[ + __origin_of(self._world[]._archetypes), + __origin_of(self._world[]._locks), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + has_start_indices=True, + ], + ) raises: + """ + Removes and adds the components to a multiple [..entity.Entity] specified by a [..filter.MaskFilter]. + + Parameters: + AddTs: The types of the components to add. + has_exclude: Whether the filter has an exclude mask. + + Args: + filter: The filter to determine which entities to modify. + components: The components to add. + + Raises: + Error: when called with components that can't be added because they are already present. + Error: when called with components that can't be removed because they are not present. + Error: when called on a locked world. Do not use during [.World.query] iteration. + """ + return self._by( + components, + filter=filter, + ) + + fn by[ + *AddTs: ComponentType, + has_exclude: Bool = False, + ]( + self, + *components: *AddTs, + filter: MaskFilter[has_exclude=has_exclude], + out iterator: World[*component_types].Iterator[ + __origin_of(self._world[]._archetypes), + __origin_of(self._world[]._locks), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + has_start_indices=True, + ], + ) raises: + """ + Removes and adds the components to a multiple [..entity.Entity] specified by a [..filter.MaskFilter]. + + Parameters: + AddTs: The types of the components to add. + has_exclude: Whether the filter has an exclude mask. + + Args: + components: The components to add. + filter: The filter to determine which entities to modify. + + Raises: + Error: when called with components that can't be added because they are already present. + Error: when called with components that can't be removed because they are not present. + Error: when called on a locked world. Do not use during [.World.query] iteration. + """ + return self._by( + components, + filter=filter, + ) + + fn _by[ + *AddTs: ComponentType, + has_exclude: Bool = False, + ]( + self, + components: VariadicPack[_, _, ComponentType, *AddTs], + filter: MaskFilter[has_exclude=has_exclude], + out iterator: World[*component_types].Iterator[ + __origin_of(self._world[]._archetypes), + __origin_of(self._world[]._locks), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + has_start_indices=True, + ], + ) raises: + """ + Private helper to remove and add components to multiple [..entity.Entity] specified by a [..filter.MaskFilter]. + + Parameters: + AddTs: The types of the components to add. + has_exclude: Whether the filter has an exclude mask. + + Args: + components: The components to add. + filter: The filter to determine which entities to modify. -struct World[*component_types: ComponentType]( - ExplicitlyCopyable, Movable, Sized -): + Raises: + Error: when called with components that can't be added because they are already present. + Error: when called with components that can't be removed because they are not present. + Error: when called on a locked world. Do not use during [.World.query] iteration. + """ + + return self._world[]._batch_remove_and_add( + filter, + components, + self._remove_ids, + ) + + +struct World[*component_types: ComponentType](Copyable, Movable, Sized): """ World is the central type holding entity and component data, as well as resources. @@ -120,7 +225,7 @@ struct World[*component_types: ComponentType]( alias Query = Query[ _, *component_types, - has_without_mask=_, + has_exclude=_, ] alias Iterator[ @@ -130,7 +235,7 @@ struct World[*component_types: ComponentType]( *, arch_iter_variant_idx: Int = _ArchetypeByMaskIteratorIdx, has_start_indices: Bool = False, - has_without_mask: Bool = False, + has_exclude: Bool = False, ] = _EntityIterator[ archetype_origin, lock_origin, @@ -138,7 +243,7 @@ struct World[*component_types: ComponentType]( component_manager = Self.component_manager, arch_iter_variant_idx=arch_iter_variant_idx, has_start_indices=has_start_indices, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ] """ Primary entity iterator type alias for the World. @@ -153,7 +258,7 @@ struct World[*component_types: ComponentType]( Parameters: arch_iter_variant_idx: Selects iteration strategy (ByMask=0, ByList=1) has_start_indices: Enables iteration from specific entity ranges (batch ops) - has_without_mask: Includes exclusion filtering capabilities for complex queries + has_exclude: Includes exclusion filtering capabilities for complex queries **Performance Considerations:** The iterator variant significantly affects performance - choose ByMask for general @@ -163,12 +268,12 @@ struct World[*component_types: ComponentType]( alias ArchetypeByMaskIterator[ archetype_mutability: Bool, //, archetype_origin: Origin[archetype_mutability], - has_without_mask: Bool = False, + has_exclude: Bool = False, ] = _ArchetypeByMaskIterator[ archetype_origin, *component_types, component_manager = Self.component_manager, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ] """ Archetype iterator optimized for component mask-based queries. @@ -180,7 +285,7 @@ struct World[*component_types: ComponentType]( **Optimizations:** - Uses SIMD-optimized bitmask operations for fast archetype matching - Skips empty archetypes automatically to reduce iteration overhead - - Supports exclusion masks via `has_without_mask` for complex filtering + - Supports exclusion masks via `has_exclude` for complex filtering **Best Use Cases:** - Standard component-based entity queries (e.g., entities with Position + Velocity) @@ -219,13 +324,13 @@ struct World[*component_types: ComponentType]( archetype_mutability: Bool, //, archetype_origin: Origin[archetype_mutability], arch_iter_variant_idx: Int = _ArchetypeByMaskIteratorIdx, - has_without_mask: Bool = False, + has_exclude: Bool = False, ] = _ArchetypeIterator[ archetype_origin, *component_types, component_manager = Self.component_manager, arch_iter_variant_idx=arch_iter_variant_idx, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ] # _listener Listener # EntityEvent _listener. @@ -239,12 +344,12 @@ struct World[*component_types: ComponentType]( # _stats _stats.World # Cached world statistics. var _entity_pool: EntityPool # Pool for entities. var _entities: List[ - EntityIndex, hint_trivial_type=True + EntityIndex ] # Mapping from entities to archetype and index. var _archetype_map: BitMaskGraph[ - -1, hint_trivial_type=True + -1 ] # Mapping from component masks to archetypes. - var _locks: LockMask # World _locks. + var _locks: LockManager # World _locks. var _archetypes: List[ Self.Archetype @@ -256,13 +361,11 @@ struct World[*component_types: ComponentType]( """ Creates a new [.World]. """ - self._archetype_map = BitMaskGraph[-1, hint_trivial_type=True](0) + self._archetype_map = BitMaskGraph[-1](0) self._archetypes = List[Self.Archetype](Self.Archetype()) - self._entities = List[EntityIndex, hint_trivial_type=True]( - EntityIndex(0, 0) - ) + self._entities = List[EntityIndex](EntityIndex(0, 0)) self._entity_pool = EntityPool() - self._locks = LockMask() + self._locks = LockManager() self.resources = Resources() # TODO @@ -276,41 +379,6 @@ struct World[*component_types: ComponentType]( # var node = self.createArchetypeNode(Mask, ID, false) - @always_inline - fn __init__(out self, other: Self): - """ - Initializes a [.World] by copying another instance. - - Args: - other: The other instance to copy. - """ - self._archetype_map = other._archetype_map - self._archetypes = other._archetypes - self._entities = other._entities - self._entity_pool = other._entity_pool - self._locks = other._locks - self.resources = other.resources.copy() - - fn __moveinit__(out self, owned other: Self): - """ - Moves the contents of another [.World] into a new one. - - Args: - other: The instance to move. - """ - self._archetype_map = other._archetype_map^ - self._archetypes = other._archetypes^ - self._entities = other._entities^ - self._entity_pool = other._entity_pool^ - self._locks = other._locks^ - self.resources = other.resources^ - - fn copy(self, out other: Self): - """ - Copies the contents of another [.World] into a new one. - """ - other = Self(self) - fn __len__(self) -> Int: """ Returns the number of entities in the world. @@ -466,8 +534,8 @@ struct World[*component_types: ComponentType]( @parameter for i in range(size): archetype[].get_component[ - T = Ts[i.value], assert_has_component=False - ](index_in_archetype) = components[i] + T = Ts[i], assert_has_component=False + ](index_in_archetype) = components[i].copy() # TODO # if self._listener != nil: @@ -564,7 +632,7 @@ struct World[*component_types: ComponentType]( Span( UnsafePointer( to=archetype[].get_component[ - T = Ts[i.value], assert_has_component=False + T = Ts[i], assert_has_component=False ](first_index_in_archetype) ), count, @@ -580,7 +648,7 @@ struct World[*component_types: ComponentType]( Pointer(to=self._archetypes), [archetype_index] ), ), - StaticOptional(List[UInt, True](UInt(first_index_in_archetype))), + StaticOptional(List[UInt](UInt(first_index_in_archetype))), ) @always_inline @@ -672,7 +740,7 @@ struct World[*component_types: ComponentType]( swap_entity = old_archetype[].get_entity(idx.index) self._entities[swap_entity.get_id()].index = idx.index - fn remove_entities(mut self, query: QueryInfo) raises: + fn remove_entities(mut self, filter: MaskFilter) raises: """ Removes multiple entities based on the provided query, making them eligible for recycling. @@ -694,17 +762,16 @@ struct World[*component_types: ComponentType]( ``` Args: - query: The query to determine which entities to remove. Note, you can - either use [..query.Query] or [..query.QueryInfo]. + filter: The filter to determine which entities to remove. Note, you can + either use [..filter.MaskFilter] or [..query.Query] because [..filter.MaskFilter] + can be implicitly constructed from [..query.Query]. Raises: Error: If the world is locked. """ self._assert_unlocked() - for archetype in self._get_archetype_iterator( - query.mask, query.without_mask - ): + for archetype in self._get_archetype_iterator(filter): for entity in archetype[].get_entities(): self._entity_pool.recycle(entity) archetype[].clear() @@ -774,9 +841,7 @@ struct World[*component_types: ComponentType]( ).get_component[T=T](entity_index.index) @always_inline - fn set[ - T: ComponentType - ](mut self, entity: Entity, owned component: T) raises: + fn set[T: ComponentType](mut self, entity: Entity, var component: T) raises: """ Overwrites a component for an [..entity.Entity], using the given content. @@ -823,9 +888,9 @@ struct World[*component_types: ComponentType]( @parameter for i in range(components.__len__()): - archetype[].get_component[T = Ts[i.value]]( + archetype[].get_component[T = Ts[i]]( entity_index.index - ) = components[i] + ) = components[i].copy() fn add[ *Ts: ComponentType @@ -868,12 +933,12 @@ struct World[*component_types: ComponentType]( self._remove_and_add(entity, add_components) fn add[ - has_without_mask: Bool, //, + has_exclude: Bool, //, *Ts: ComponentType, ]( mut self, - query: QueryInfo[has_without_mask=has_without_mask], - owned *add_components: *Ts, + filter: MaskFilter[has_exclude=has_exclude], + var *add_components: *Ts, out iterator: Self.Iterator[ __origin_of(self._archetypes), __origin_of(self._locks), @@ -882,8 +947,8 @@ struct World[*component_types: ComponentType]( ], ) raises: """ - Adds components to multiple entities at once that are specified by a [..query.Query]. - The provided query must ensure that matching entities do not already have one or more of the + Adds components to multiple entities at once that are specified by a [..filter.MaskFilter]. + The provided filter must ensure that matching entities do not already have one or more of the components to add. **Example:** @@ -915,12 +980,12 @@ struct World[*component_types: ComponentType]( ``` Parameters: - has_without_mask: Whether the query has a without mask. + has_exclude: Whether the filter has an exclude mask. Ts: The types of the components to add. Args: - query: The query specifying which entities to modify. The query must explicitly exclude existing entities - that already have some of the components to add. + filter: The [..filter.MaskFilter] specifying which entities to modify. The filter must explicitly exclude + existing entities that already have some of the components to add. add_components: The components to add. Raises: @@ -929,158 +994,90 @@ struct World[*component_types: ComponentType]( components to add. """ - # Note: - # This operation can never map multiple archetypes onto one, due to the requirement that components to add - # must be excluded in the query. Therefore, we can apply the transformation to each matching archetype - # individually without checking for edge cases where multiple archetypes get merged into one. - # This also enables potential parallelization optimizations. - - self._assert_unlocked() - - alias component_ids = Self.component_manager.get_id_arr[*Ts]() - - # If query could match archetypes that already have at least one of the components, raise an error - # FIXME: When https://github.com/modular/modular/issues/5347 is fixed, we can use short-circuiting here. - - var strict_check_needed: Bool - - @parameter - if has_without_mask: - strict_check_needed = not query.without_mask[].contains( - BitMask(component_ids) - ) - else: - strict_check_needed = True - - if strict_check_needed: - for archetype in self._get_archetype_iterator( - query.mask, query.without_mask - ): - if archetype[] and archetype[].get_mask().contains_any( - BitMask(component_ids) - ): - raise Error( - "Query matches entities that already have at least" - " one of the components to add. Use" - " `Query.without[Component, ...]()` to exclude" - " those components." - ) - - alias _2kb_of_UInt_or_Int = (1024 * 2) // sizeof[UInt]() - arch_start_idcs = List[UInt, True]( - min(len(self._archetypes), _2kb_of_UInt_or_Int) - ) - changed_archetype_idcs = List[Int, True]( - min(len(self._archetypes), _2kb_of_UInt_or_Int) + return self._batch_remove_and_add( + filter, + add_components, ) - # Search for the archetype that matches the query mask - with self._locked(): - for old_archetype in self._get_archetype_iterator( - query.mask, query.without_mask - ): - # Two cases per matching archetype A: - # 1. If an archetype B with the new component combination exists, move entities from A to B - # and insert new component data for moved entities. - # 2. If an archetype with the new component combination does not exist yet, - # create new archetype B = A + component_ids and move entities and component data from A to B. - new_archetype_idx = self._get_archetype_index( - component_ids, old_archetype[].get_node_index() - ) - - # We need to update the pointer to the old archetype, because the `self._archetypes` list may have been - # resized during the call to `_get_archetype_index`. - old_archetype_idx = self._archetype_map[ - old_archetype[].get_node_index() - ] - old_archetype = Pointer( - to=self._archetypes.unsafe_get(index(old_archetype_idx)) - ) - - new_archetype = Pointer( - to=self._archetypes.unsafe_get(new_archetype_idx) - ) + fn remove[*Ts: ComponentType](mut self, entity: Entity) raises: + """ + Removes components from an [..entity.Entity]. - arch_start_idx = len(new_archetype[]) - new_archetype[].reserve(arch_start_idx + len(old_archetype[])) + Parameters: + Ts: The types of the components to remove. - # Save arch_start_idx for the iterator. - arch_start_idcs.append(arch_start_idx) - changed_archetype_idcs.append(new_archetype_idx) + Args: + entity: The entity to modify. - # Move component data from old archetype to new archetype. - for i in range(old_archetype[]._component_count): - id = old_archetype[]._ids[i] + Raises: + Error: when called for a removed (and potentially recycled) entity. + Error: when called with components that can't be removed because they are not present. + Error: when called on a locked world. Do not use during [.World.query] iteration. + """ + self._remove_and_add( + entity, remove_ids=Self.component_manager.get_id_arr[*Ts]() + ) - new_archetype[].unsafe_set( - arch_start_idx, - id, - old_archetype[]._data[id], - len(old_archetype[]), - ) + fn remove[ + *Ts: ComponentType, has_exclude: Bool = False + ]( + mut self, + filter: MaskFilter[has_exclude=has_exclude], + out iterator: Self.Iterator[ + __origin_of(self._archetypes), + __origin_of(self._locks), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + has_start_indices=True, + ], + ) raises: + """ + Removes components from multiple entities at once, specified by a [..filter.MaskFilter]. + The provided filter must ensure that matching entities have all of the components that should get removed. - # Move entities from old archetype to new archetype. - for idx in range(len(old_archetype[])): - new_idx = new_archetype[].add( - old_archetype[].get_entity(idx) - ) - entity = new_archetype[].get_entity(new_idx) - self._entities[entity.get_id()] = EntityIndex( - new_idx, - new_archetype_idx, - ) + Example: - # Set new component data - @parameter - for comp_idx in range(add_components.__len__()): - alias comp_id = component_ids[comp_idx] + ```mojo {doctest="remove_query_comps" global=true} + from larecs import World - new_archetype[].unsafe_set( - new_idx, - comp_id, - UnsafePointer(to=add_components[comp_idx]).bitcast[ - UInt8 - ](), - ) + @fieldwise_init + struct Position(Copyable, Movable): + var x: Float64 + var y: Float64 - old_archetype[].clear() + @fieldwise_init + struct Velocity(Copyable, Movable): + var x: Float64 + var y: Float64 - # Return iterator to iterate over the changed entities. - iterator = Self.Iterator[ - __origin_of(self._archetypes), - __origin_of(self._locks), - arch_iter_variant_idx=_ArchetypeByListIteratorIdx, - has_start_indices=True, - ]( - Pointer(to=self._locks), - Self.ArchetypeIterator[ - __origin_of(self._archetypes), - arch_iter_variant_idx=_ArchetypeByListIteratorIdx, - ]( - Self.ArchetypeByListIterator[__origin_of(self._archetypes)]( - Pointer(to=self._archetypes), changed_archetype_idcs - ), - ), - StaticOptional(arch_start_idcs), - ) + world = World[Position, Velocity]() + _ = world.add_entities(Position(0, 0), Velocity(1, 0), 100) - fn remove[*Ts: ComponentType](mut self, entity: Entity) raises: - """ - Removes components from an [..entity.Entity]. + for entity in world.remove[Velocity]( + world.query[Position, Velocity]() + ): + position = entity.get[Position]() + ``` Parameters: Ts: The types of the components to remove. + has_exclude: Whether the query has a without mask. Args: - entity: The entity to modify. + filter: The [..filter.MaskFilter] to determine which entities to modify. Raises: - Error: when called for a removed (and potentially recycled) entity. - Error: when called with components that can't be removed because they are not present. Error: when called on a locked world. Do not use during [.World.query] iteration. + Error: when called with a filter that could match entities that don't have all of the components to remove. """ - self._remove_and_add( - entity, remove_ids=Self.component_manager.get_id_arr[*Ts]() + + # Note: + # This operation can never map multiple archetypes onto one, due to the requirement that components to remove + # must be already present on archetypes matched by the filter. Therefore, we can apply the transformation to + # each matching archetype individually, without checking for edge cases where multiple archetypes get merged + # into one. This also enables potential parallelization optimizations. + + return self._batch_remove_and_add( + filter, remove_ids=Self.component_manager.get_id_arr[*Ts]() ) @always_inline @@ -1159,99 +1156,87 @@ struct World[*component_types: ComponentType]( See documentation of overloaded function for details. """ alias add_size = add_components.__len__() - alias ComponentIdsType = StaticOptional[ - __type_of(Self.component_manager.get_id_arr[*Ts]()), add_size - ] + alias add_ids = Self.component_manager.get_id_arr[*Ts]() self._assert_unlocked() self._assert_alive(entity) - @parameter - if not add_size and not rem_size: - return - # Reserve space for the possibility that a new archetype gets created # This ensure that no further allocations can happen in this function and # therefore all pointers to the current memory space stay valid! self._archetypes.reserve(len(self._archetypes) + 1) - idx = self._entities[entity.get_id()] + entity_index = self._entities[entity.get_id()] - old_archetype_index = idx.archetype_index + old_archetype_idx = entity_index.archetype_index old_archetype = Pointer( - to=self._archetypes.unsafe_get(index(old_archetype_index)) + to=self._archetypes.unsafe_get(index(old_archetype_idx)) ) - - index_in_old_archetype = idx.index - - var component_ids: ComponentIdsType - - @parameter - if add_size: - component_ids = ComponentIdsType( - Self.component_manager.get_id_arr[*Ts]() - ) - else: - component_ids = None - - start_node_index = old_archetype[].get_node_index() - - var archetype_index: Int = -1 - compare_mask = old_archetype[].get_mask() - - alias add_error_msg = "Entity already has one of the components to add." - alias remove_error_msg = "Entity does not have one of the components to remove." + old_archetype_mask = old_archetype[].get_mask() @parameter if rem_size: @parameter - if add_size: - start_node_index = self._archetype_map.get_node_index( - remove_ids[], start_node_index - ) - if not compare_mask.contains( - self._archetype_map.get_node_mask(start_node_index) - ): - raise Error(remove_error_msg) - - compare_mask = self._archetype_map.get_node_mask( - start_node_index - ) - else: - archetype_index = self._get_archetype_index( - remove_ids[], start_node_index - ) - # No need for Pointer revalidation due to previous memory reservation! + if remove_some: + if not old_archetype_mask.contains(BitMask(remove_ids[])): + raise Error( + "Entity does not have one of the components to remove." + ) @parameter if add_size: - archetype_index = self._get_archetype_index( - component_ids[], start_node_index - ) - # No need for Pointer revalidation due to previous memory reservation! + compare_mask = old_archetype_mask - archetype = Pointer(to=self._archetypes.unsafe_get(archetype_index)) - index_in_archetype = archetype[].add(entity) + @parameter + if remove_some: + compare_mask = compare_mask.set(remove_ids[], False) + if compare_mask.contains(BitMask(add_ids)): + raise Error("Entity already has one of the components to add.") + + alias ComponentIdsType = InlineArray[Self.Id, add_size + rem_size] + var component_ids: ComponentIdsType @parameter - if add_size: - if not archetype[].get_mask().contains(compare_mask): - raise Error(add_error_msg) + if add_size and rem_size: + component_ids = ComponentIdsType(uninitialized=True) + memcpy( + component_ids.unsafe_ptr(), + remove_ids[].unsafe_ptr(), + rem_size, + ) + memcpy( + component_ids.unsafe_ptr() + rem_size * size_of[Self.Id](), + add_ids.unsafe_ptr(), + add_size, + ) + elif Bool(add_size) and not rem_size: + component_ids = rebind[ComponentIdsType](add_ids) + elif not add_size and Bool(rem_size): + component_ids = rebind[ComponentIdsType](remove_ids[]) else: - if not compare_mask.contains(archetype[].get_mask()): - raise Error(remove_error_msg) + return + + index_in_old_archetype = entity_index.index + new_archetype_idx = self._get_archetype_index( + component_ids, old_archetype[].get_node_index() + ) + new_archetype = Pointer( + to=self._archetypes.unsafe_get(new_archetype_idx) + ) + index_in_new_archetype = new_archetype[].add(entity) + # Move component data from old archetype to new archetype. for i in range(old_archetype[]._component_count): id = old_archetype[]._ids[i] @parameter if rem_size: - if not archetype[].has_component(id): + if not new_archetype[].has_component(id): continue - archetype[].unsafe_set( - index_in_archetype, + new_archetype[].unsafe_set( + index_in_new_archetype, id, old_archetype[]._get_component_ptr( index(index_in_old_archetype), id @@ -1260,19 +1245,305 @@ struct World[*component_types: ComponentType]( @parameter for i in range(add_size): - archetype[].unsafe_set( - index_in_archetype, - component_ids[][i], + new_archetype[].unsafe_set( + index_in_new_archetype, + add_ids[i], UnsafePointer(to=add_components[i]).bitcast[UInt8](), ) swapped = old_archetype[].remove(index_in_old_archetype) if swapped: - var swapEntity = old_archetype[].get_entity(idx.index) - self._entities[swapEntity.get_id()].index = idx.index + var swapEntity = old_archetype[].get_entity(entity_index.index) + self._entities[swapEntity.get_id()].index = entity_index.index self._entities[entity.get_id()] = EntityIndex( - index_in_archetype, archetype_index + index_in_new_archetype, new_archetype_idx + ) + + @always_inline + fn _batch_remove_and_add[ + *Ts: ComponentType, + rem_size: Int = 0, + remove_some: Bool = False, + has_exclude: Bool = False, + ]( + mut self, + filter: MaskFilter[has_exclude=has_exclude], + *add_components: *Ts, + remove_ids: StaticOptional[ + InlineArray[Self.Id, rem_size], remove_some + ] = None, + out iterator: Self.Iterator[ + __origin_of(self._archetypes), + __origin_of(self._locks), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + has_start_indices=True, + ], + ) raises: + """ + Adds and removes components to a multiple [..entity.Entity] specified by a [..filter.MaskFilter]. + + Parameters: + Ts: The types of the components to add. + rem_size: The number of components to remove. + remove_some: Whether to remove some components. + has_exclude: Whether the filter has an exclude mask. + + Args: + filter: The filter to determine which entities to modify. + add_components: The components to add. + remove_ids: The IDs of the components to remove. + + Returns: + An iterator over the modified entities. + + Raises: + Error: when called with nothing to do (i.e. no components to add or remove). + Error: when called with a filter that could match existing entities that already have at least one of the + components to add. + Error: when called with a filter that could match entities that don't have all of the components to remove. + Error: when called on a locked world. Do not use during [.World.query] iteration. + """ + return self._batch_remove_and_add(filter, add_components, remove_ids) + + @always_inline + fn _batch_remove_and_add[ + *Ts: ComponentType, + rem_size: Int = 0, + remove_some: Bool = False, + has_exclude: Bool = False, + ]( + mut self, + filter: MaskFilter[has_exclude=has_exclude], + add_components: VariadicPack[_, _, ComponentType, *Ts], + remove_ids: StaticOptional[ + InlineArray[Self.Id, rem_size], remove_some + ] = None, + out iterator: Self.Iterator[ + __origin_of(self._archetypes), + __origin_of(self._locks), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + has_start_indices=True, + ], + ) raises: + """ + Adds and removes components to a multiple [..entity.Entity] specified by a [..filter.MaskFilter]. + + Parameters: + Ts: The types of the components to add. + rem_size: The number of components to remove. + remove_some: Whether to remove some components. + has_exclude: Whether the filter has an exclude mask. + + Args: + filter: The filter to determine which entities to modify. + add_components: The components to add. + remove_ids: The IDs of the components to remove. + + Returns: + An iterator over the modified entities. + + Raises: + Error: when called with nothing to do (i.e. no components to add or remove). + Error: when called on a locked world. Do not use during [.World.query] iteration. + """ + alias add_size = add_components.__len__() + alias add_ids = Self.component_manager.get_id_arr[*Ts]() + + alias ComponentIdsType = InlineArray[Self.Id, add_size + rem_size] + + var component_ids: ComponentIdsType + + # Note: + # This operation can never map multiple archetypes onto one, due to the requirement that components to add + # must be excluded in the query. Therefore, we can apply the transformation to each matching archetype + # individually without checking for edge cases where multiple archetypes get merged into one. + # This also enables potential parallelization optimizations. + + @parameter + if add_size: + # If query could match archetypes that already have at least one of the components, raise an error + # FIXME: When https://github.com/modular/modular/issues/5347 is fixed, we can use short-circuiting here. + + var strict_check_needed: Bool + + @parameter + if has_exclude: + strict_check_needed = not filter.exclude[].contains( + BitMask(add_ids) + ) + else: + strict_check_needed = True + + if strict_check_needed: + for archetype in self._get_archetype_iterator(filter): + archetype_mask = archetype[].get_mask() + + @parameter + if remove_some: + archetype_mask = archetype_mask.set(remove_ids[], False) + + if archetype[] and archetype_mask.contains_any( + BitMask(add_ids) + ): + raise Error( + "Filter matches entities that already have at least" + " one of the components to add. Use" + " `Filter.without[Component, ...]()` to exclude" + " those components." + ) + + @parameter + if rem_size: + # If filter could match archetypes that don't have all of the components, raise an error + if not filter.include.contains(BitMask(remove_ids[])): + raise Error( + "Filter matches entities that don't have all of the" + " components to remove. Use `Filter(Component, ...)` to" + " include those components." + ) + + @parameter + if has_exclude: + if filter.exclude[].contains_any(BitMask(remove_ids[])): + raise Error( + "Filter excludes entities that have a component which" + " should be removed in the without mask. Remove all" + " components that get removed from" + " `Filter.without(...)`." + ) + + @parameter + if add_size and rem_size: + component_ids = ComponentIdsType(uninitialized=True) + memcpy( + component_ids.unsafe_ptr(), + remove_ids[].unsafe_ptr(), + rem_size, + ) + memcpy( + component_ids.unsafe_ptr() + rem_size * size_of[Self.Id](), + add_ids.unsafe_ptr(), + add_size, + ) + elif Bool(add_size) and not rem_size: + component_ids = rebind[ComponentIdsType](add_ids) + elif not add_size and Bool(rem_size): + component_ids = rebind[ComponentIdsType](remove_ids[]) + else: + return Self.Iterator[ + __origin_of(self._archetypes), + __origin_of(self._locks), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + has_start_indices=True, + ]( + Pointer(to=self._locks), + Self.ArchetypeIterator[ + __origin_of(self._archetypes), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + ]( + Self.ArchetypeByListIterator[__origin_of(self._archetypes)]( + Pointer(to=self._archetypes), List[Int]() + ), + ), + StaticOptional(List[UInt]()), + ) + + self._assert_unlocked() + + alias _2kb_of_UInt_or_Int = (1024 * 2) // size_of[UInt]() + arch_start_idcs = List[UInt]( + capacity=min(len(self._archetypes), _2kb_of_UInt_or_Int) + ) + changed_archetype_idcs = List[Int]( + capacity=min(len(self._archetypes), _2kb_of_UInt_or_Int) + ) + + # Search for the archetype that matches the query mask + with self._locked(): + for old_archetype in self._get_archetype_iterator(filter): + # Two cases per matching archetype A: + # 1. If an archetype B with the new component combination exists, move entities from A to B + # and insert new component data for moved entities. + # 2. If an archetype with the new component combination does not exist yet, + # create new archetype B = A.different_by(component_ids) and move entities and component data from A to B. + new_archetype_idx = self._get_archetype_index( + component_ids, old_archetype[].get_node_index() + ) + + # We need to update the pointer to the old archetype, because the `self._archetypes` list may have been + # resized during the call to `_get_archetype_index`. + old_archetype_idx = self._archetype_map[ + old_archetype[].get_node_index() + ] + old_archetype = Pointer( + to=self._archetypes.unsafe_get(index(old_archetype_idx)) + ) + + new_archetype = Pointer( + to=self._archetypes.unsafe_get(new_archetype_idx) + ) + + new_archetype[].reserve( + len(new_archetype[]) + len(old_archetype[]) + ) + + # Save arch_start_idx for the iterator. + arch_start_idx = len(new_archetype[]) + arch_start_idcs.append(arch_start_idx) + changed_archetype_idcs.append(new_archetype_idx) + + # Move component data from old archetype to new archetype. + for i in range(old_archetype[]._component_count): + id = old_archetype[]._ids[i] + + new_archetype[].unsafe_set( + arch_start_idx, + id, + old_archetype[]._data[id], + len(old_archetype[]), + ) + + # Move entities to the new archetype and update entity index mappings + for i in range(len(old_archetype[])): + entity = old_archetype[].get_entity(i) + new_index = new_archetype[].add(entity) + self._entities[entity.get_id()] = EntityIndex( + new_index, new_archetype_idx + ) + + # Set new component data + @parameter + for add_comp_idx in range(add_components.__len__()): + alias comp_id = add_ids[add_comp_idx] + + new_archetype[].unsafe_set( + new_index, + comp_id, + UnsafePointer( + to=add_components[add_comp_idx] + ).bitcast[UInt8](), + ) + + old_archetype[].clear() + + # Return iterator to iterate over the changed entities. + iterator = Self.Iterator[ + __origin_of(self._archetypes), + __origin_of(self._locks), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + has_start_indices=True, + ]( + Pointer(to=self._locks), + Self.ArchetypeIterator[ + __origin_of(self._archetypes), + arch_iter_variant_idx=_ArchetypeByListIteratorIdx, + ]( + Self.ArchetypeByListIterator[__origin_of(self._archetypes)]( + Pointer(to=self._archetypes), changed_archetype_idcs^ + ), + ), + StaticOptional(arch_start_idcs^), ) @always_inline @@ -1305,7 +1576,7 @@ struct World[*component_types: ComponentType]( operation: fn (accessor: MutableEntityAccessor) capturing -> None, *, unroll_factor: Int = 1, - ](mut self, query: QueryInfo) raises: + ](mut self, filter: MaskFilter) raises: """ Applies an operation to all entities with the given components. @@ -1315,7 +1586,7 @@ struct World[*component_types: ComponentType]( (see [vectorize doc](https://docs.modular.com/mojo/stdlib/algorithm/functional/vectorize)). Args: - query: The query to determine which entities to apply the operation to. + filter: The [..filter.MaskFilter] to determine which entities to apply the operation to. Raises: Error: If the world is locked. @@ -1326,7 +1597,7 @@ struct World[*component_types: ComponentType]( fn operation_wrapper[simd_width: Int](accessor: MutableEntityAccessor): operation(accessor) - self.apply[operation_wrapper, unroll_factor=unroll_factor](query) + self.apply[operation_wrapper, unroll_factor=unroll_factor](filter) fn apply[ operation: fn[simd_width: Int] ( @@ -1335,7 +1606,7 @@ struct World[*component_types: ComponentType]( *, simd_width: Int = 1, unroll_factor: Int = 1, - ](mut self, query: QueryInfo) raises: + ](mut self, filter: MaskFilter) raises: """ Applies an operation to all entities with the given components. @@ -1407,7 +1678,7 @@ struct World[*component_types: ComponentType]( (see [vectorize doc](https://docs.modular.com/mojo/stdlib/algorithm/functional/vectorize)). Args: - query: The query to determine which entities to apply the operation to. + filter: The [..filter.MaskFilter] to determine which entities to apply the operation to. Constraints: The simd_width must be a power of 2. @@ -1419,9 +1690,7 @@ struct World[*component_types: ComponentType]( with self._locked(): for archetype in _ArchetypeByMaskIterator( - Pointer(to=self._archetypes), - query.mask, - query.without_mask, + Pointer(to=self._archetypes), filter ): @always_inline @@ -1510,7 +1779,7 @@ struct World[*component_types: ComponentType]( *Ts: ComponentType ]( mut self, - out iterator: Self.Query[__origin_of(self), has_without_mask=False], + out iterator: Self.Query[__origin_of(self), has_exclude=False], ): """ Returns an [..query.Query] for all [..entity.Entity Entities] with the given components. @@ -1521,27 +1790,40 @@ struct World[*component_types: ComponentType]( Returns: A [..query.Query] for all entities with the given components. """ + + iterator = Self.Query(Pointer(to=self), self.filter[*Ts]()) + + @always_inline + fn filter[ + *Ts: ComponentType + ](mut self, out mask_filter: MaskFilter[has_exclude=False]): + """ + Returns a [..filter.MaskFilter] for all [..entity.Entity Entities] with the given components. + + Parameters: + Ts: The types of the components. + + Returns: + A [..filter.MaskFilter] for all entities with the given components. + """ alias size = VariadicPack[ True, MutableAnyOrigin, ComponentType, *Ts ].__len__() - var bitmask: BitMask - @parameter if not size: bitmask = BitMask() else: bitmask = BitMask(Self.component_manager.get_id_arr[*Ts]()) - iterator = Self.Query(Pointer(to=self), bitmask) + mask_filter = MaskFilter(bitmask) fn _get_entity_iterator[ - has_without_mask: Bool = False, has_start_indices: Bool = False + has_exclude: Bool = False, has_start_indices: Bool = False ]( mut self, - owned mask: BitMask, - owned without_mask: StaticOptional[BitMask, has_without_mask], - owned start_indices: _EntityIterator[ + mask_filter: MaskFilter[has_exclude=has_exclude], + var start_indices: _EntityIterator[ __origin_of(self._archetypes), __origin_of(self._locks), *component_types, @@ -1554,19 +1836,18 @@ struct World[*component_types: ComponentType]( __origin_of(self._locks), arch_iter_variant_idx=_ArchetypeByMaskIteratorIdx, has_start_indices=has_start_indices, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ], ) raises: """ Creates an iterator over all entities that have / do not have the components in the provided masks. Parameters: - has_without_mask: Whether a without_mask is provided. + has_exclude: Whether a without_mask is provided. has_start_indices: Whether start_indices are provided. Args: - mask: The mask of components to include. - without_mask: The mask of components to exclude. + mask_filter: The mask filter to use for selecting archetypes. start_indices: The start indices of the iterator. See [..query._EntityIterator]. """ iterator = Self.Iterator[ @@ -1574,35 +1855,33 @@ struct World[*component_types: ComponentType]( __origin_of(self._locks), arch_iter_variant_idx=_ArchetypeByMaskIteratorIdx, has_start_indices=has_start_indices, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ]( Pointer(to=self._locks), Self.ArchetypeIterator[ __origin_of(self._archetypes), arch_iter_variant_idx=_ArchetypeByMaskIteratorIdx, - has_without_mask=has_without_mask, + has_exclude=has_exclude, ]( Self.ArchetypeByMaskIterator[ __origin_of(self._archetypes), - has_without_mask=has_without_mask, + has_exclude=has_exclude, ]( Pointer(to=self._archetypes), - mask, - without_mask, + mask_filter, ) ), - start_indices, + start_indices^, ) @always_inline fn _get_archetype_iterator[ - has_without_mask: Bool = False + has_exclude: Bool = False ]( ref self, - mask: BitMask, - without_mask: StaticOptional[BitMask, has_without_mask] = None, + mask_filter: MaskFilter[has_exclude=has_exclude], out iterator: Self.ArchetypeByMaskIterator[ - __origin_of(self._archetypes), has_without_mask=has_without_mask + __origin_of(self._archetypes), has_exclude=has_exclude ], ): """ @@ -1613,8 +1892,7 @@ struct World[*component_types: ComponentType]( """ iterator = _ArchetypeByMaskIterator( Pointer(to=self._archetypes), - mask, - without_mask, + mask_filter, ) @always_inline diff --git a/test/bitmask_test.mojo b/test/bitmask_test.mojo index 2894078d..eea066f6 100644 --- a/test/bitmask_test.mojo +++ b/test/bitmask_test.mojo @@ -51,7 +51,6 @@ fn get_random_1_true_bitmasks(size: Int, out vals: List[BitMask]): fn run_all_bitmask_tests() raises: print("Running all bitmask tests...") test_bit_mask() - test_bit_mask_without_exclusive() test_bit_mask_256() test_bit_mask_eq() test_bitmask_get_indices() @@ -82,14 +81,14 @@ fn test_bit_mask() raises: assert_true(mask.get(0)) assert_false(mask.get(1)) - mask.flip(UInt8(0)) - mask.flip(UInt8(1)) + mask.flip_mut(UInt8(0)) + mask.flip_mut(UInt8(1)) assert_false(mask.get(0)) assert_true(mask.get(1)) - mask.flip(UInt8(0)) - mask.flip(UInt8(1)) + mask.flip_mut(UInt8(0)) + mask.flip_mut(UInt8(1)) var other1 = BitMask(UInt8(1), UInt8(2), UInt8(32)) var other2 = BitMask(UInt8(0), UInt8(2)) @@ -108,41 +107,13 @@ fn test_bit_mask() raises: assert_false(mask.contains_any(other2)) -fn test_bit_mask_without_exclusive() raises: - mask = BitMask(UInt8(1), UInt8(2), UInt8(13)) - assert_true(mask.matches(BitMask(UInt8(1), UInt8(2), UInt8(13)))) - assert_true(mask.matches(BitMask(UInt8(1), UInt8(2), UInt8(13), UInt8(27)))) - - assert_false(mask.matches(BitMask(UInt8(1), UInt8(2)))) - - without = mask.without(UInt8(3)) - - assert_true(without.matches(BitMask(UInt8(1), UInt8(2), UInt8(13)))) - assert_true( - without.matches(BitMask(UInt8(1), UInt8(2), UInt8(13), UInt8(27))) - ) - - assert_false( - without.matches(BitMask(UInt8(1), UInt8(2), UInt8(3), UInt8(13))) - ) - assert_false(without.matches(BitMask(UInt8(1), UInt8(2)))) - - excl = mask.exclusive() - - assert_true(excl.matches(BitMask(UInt8(1), UInt8(2), UInt8(13)))) - assert_false( - excl.matches(BitMask(UInt8(1), UInt8(2), UInt8(13), UInt8(27))) - ) - assert_false(excl.matches(BitMask(UInt8(1), UInt8(2), UInt8(3), UInt8(13)))) - - fn test_bit_mask_eq() raises: mask1 = get_random_bitmask() mask2 = mask1 assert_true(mask1 == mask2) - mask2.flip(3) + mask2.flip_mut(3) assert_false(mask1 == mask2) diff --git a/test/graph_test.mojo b/test/graph_test.mojo index 87a7f4b2..a80622d5 100644 --- a/test/graph_test.mojo +++ b/test/graph_test.mojo @@ -84,7 +84,7 @@ struct S: self.l = List[Node[Int]]() self.add(BitMask(), -1) - fn add(mut self, owned node_mask: BitMask, owned value: Int): + fn add(mut self, var node_mask: BitMask, var value: Int): self.l.append(Node(node_mask, value)) diff --git a/test/query_test.mojo b/test/query_test.mojo index 1bb43300..f4edf6c7 100644 --- a/test/query_test.mojo +++ b/test/query_test.mojo @@ -4,6 +4,7 @@ from larecs import Entity, Query from larecs.archetype import Archetype as _Archetype from larecs.component import ComponentManager from larecs.query import _ArchetypeByMaskIterator +from larecs.filter import MaskFilter def test_query_length(): @@ -393,14 +394,14 @@ def test_query_archetype_iterator(): a = Archetype(0, BitMask(0)) _ = a.add(Entity(0, 0)) - l = List[Archetype](a, a, a) + l = List[Archetype](a.copy(), a.copy(), a.copy()) var count = 0 for _ in _ArchetypeByMaskIterator[ __origin_of(l), FlexibleComponent[0], component_manager = ComponentManager[FlexibleComponent[0]](), - ](Pointer(to=l), BitMask(0)): + ](Pointer(to=l), MaskFilter(BitMask(0))): count += 1 assert_equal(count, 3) diff --git a/test/static_optional_test.mojo b/test/static_optional_test.mojo index 92c0b859..10832fa9 100644 --- a/test/static_optional_test.mojo +++ b/test/static_optional_test.mojo @@ -1,4 +1,5 @@ from testing import * +from sys import size_of from larecs.test_utils import * from larecs.static_optional import StaticOptional @@ -9,7 +10,7 @@ def test_comptime_optional_init(): assert_false(opt.has_value) _ = opt._value l = List[Int](42) - opt_with_value = StaticOptional(l) + opt_with_value = StaticOptional(l^) assert_true(opt_with_value.has_value) assert_equal(opt_with_value[][0], 42) @@ -22,13 +23,11 @@ def test_comptime_optional_copy(): opt_without_value = StaticOptional[Int, False]() opt_copy_without = opt_without_value.copy() _ = opt_copy_without._value - opt_copy_without = opt_without_value - _ = opt_copy_without def test_comptime_optional_move_del(): fn factory( - owned val: MemTestStruct, + var val: MemTestStruct, out result: StaticOptional[MemTestStruct, True], ): result = __type_of(result)(val^) @@ -42,8 +41,8 @@ def test_comptime_optional_value(): def test_comptime_optional_size(): - assert_equal(sizeof[StaticOptional[UInt16, True]](), 2) - assert_equal(sizeof[StaticOptional[UInt16, False]](), 0) + assert_equal(size_of[StaticOptional[UInt16, True]](), 2) + assert_equal(size_of[StaticOptional[UInt16, False]](), 0) fn optional_argument_application[ diff --git a/test/static_variant_test.mojo b/test/static_variant_test.mojo index d8d30f90..c50bb6a7 100644 --- a/test/static_variant_test.mojo +++ b/test/static_variant_test.mojo @@ -32,7 +32,7 @@ def test_comptime_variant_init(): # def test_comptime_variant_move_del(): # fn factory( -# owned val: MemTestStruct, +# var val: MemTestStruct, # out result: StaticVariant[MemTestStruct, True], # ): # result = __type_of(result)(val^) diff --git a/test/unsafe_box_test.mojo b/test/unsafe_box_test.mojo index 72b82fbc..e7630ec3 100644 --- a/test/unsafe_box_test.mojo +++ b/test/unsafe_box_test.mojo @@ -12,7 +12,7 @@ struct TestStruct: def test_unsafe_box_copy_move_del(): fn factory( - owned val: MemTestStruct, + var val: MemTestStruct, out result: UnsafeBox, ): result = __type_of(result)(val^) diff --git a/test/world_test.mojo b/test/world_test.mojo index f8677349..dc4100b0 100644 --- a/test/world_test.mojo +++ b/test/world_test.mojo @@ -232,7 +232,7 @@ def test_world_add(): def test_world_batch_add(): world = SmallWorld() - n = 10 + n = 100 _ = world.add_entities(Position(1.0, 2.0), count=n) assert_equal(len(world.query[Position]().without[Velocity]()), n) @@ -250,7 +250,7 @@ def test_world_batch_add(): with assert_raises( contains=( - "Query matches entities that already have at least one of the" + "Filter matches entities that already have at least one of the" " components to add." ) ): @@ -266,6 +266,9 @@ def test_world_batch_add(): LargerComponent(0.3, 0.4, 0.5), ) + assert_equal(len(world.query[Position]().without[Velocity]()), 0) + assert_equal(len(world.query[Position, Velocity]()), n) + def test_world_remove(): world = SmallWorld() @@ -300,6 +303,37 @@ def test_world_remove(): assert_equal(index1, world._entities[entity2._id].index) +def test_world_batch_remove(): + world = SmallWorld() + n = 100 + _ = world.add_entities( + Position(1.0, 2.0), Velocity(0.1, 0.2), count=n + ) + + assert_equal(len(world.query[Position, Velocity]()), n) + assert_equal(len(world.query[Position]().without[Velocity]()), 0) + + for entity in world.remove[Velocity]( + world.query[Position, Velocity]()) + : + assert_false(entity.has[Velocity]()) + assert_equal(entity.get[Position]().x, 1.0) + assert_equal(entity.get[Position]().y, 2.0) + + assert_equal(len(world.query[Position, Velocity]()), 0) + assert_equal(len(world.query[Position]().without[Velocity]()), n) + + with assert_raises( + contains=( + "Filter matches entities that don't have all of the" + " components to remove." + ) + ): + _ = world.remove[Velocity]( + world.query[Position](), + ) + + def test_remove_and_add(): world = SmallWorld() pos = Position(1.0, 2.0) @@ -315,13 +349,53 @@ def test_remove_and_add(): world.replace[Position]().by(entity, vel) assert_false(world.has[Position](entity)) assert_true(world.has[Velocity](entity)) + assert_equal(world.get[Velocity](entity).dx, vel.dx) + assert_equal(world.get[Velocity](entity).dy, vel.dy) with assert_raises(): world.replace[Position]().by(entity, vel) + assert_false(world.has[Position](entity)) + assert_true(world.has[Velocity](entity)) assert_equal(world.get[Velocity](entity).dx, vel.dx) assert_equal(world.get[Velocity](entity).dy, vel.dy) +def test_batch_remove_and_add(): + world = SmallWorld() + n = 100 + _ = world.add_entities( + Position(1.0, 2.0), Velocity(0.1, 0.2), count=n + ) + + assert_equal(len(world.query[Position, Velocity]()), n) + assert_equal(len(world.query[Position, FlexibleComponent[1]]().without[Velocity]()), 0) + + for entity in world.replace[Velocity]().by(FlexibleComponent[1](3.0, 4.0), filter= + world.query[Position, Velocity]()): + assert_false(entity.has[Velocity]()) + assert_true(entity.has[Position]()) + assert_true(entity.has[FlexibleComponent[1]]()) + assert_equal(entity.get[Position]().x, 1.0) + assert_equal(entity.get[Position]().y, 2.0) + assert_equal(entity.get[FlexibleComponent[1]]().x, 3.0) + assert_equal(entity.get[FlexibleComponent[1]]().y, 4.0) + + assert_equal(len(world.query[Position, Velocity]()), 0) + assert_equal(len(world.query[Position, FlexibleComponent[1]]().without[Velocity]()), n) + + with assert_raises( + contains=( + "Filter matches entities that already have at least" + " one of the components to add." + ) + ): + _ = world.replace[Velocity]().by(Position(5.0, 6.0), filter=world.query[Position]()) + + for entity in world.replace[Position]().by(Position(42.0, 6.0), filter=world.query[Position]()): + assert_true(entity.has[Position]()) + assert_equal(entity.get[Position]().x, 42.0) + assert_equal(entity.get[Position]().y, 6.0) + @fieldwise_init struct Resource1(ResourceType): @@ -371,7 +445,7 @@ def test_world_apply(): except: pass - world.apply[operation, unroll_factor=3](world.query[Position, Velocity]()) + world.apply[operation, unroll_factor=3](world.filter[Position, Velocity]()) for entity in world.query[Position, Velocity](): assert_equal(entity.get[Position]().x, new_pos.x)