diff --git a/benchmark/query_benchmark.mojo b/benchmark/query_benchmark.mojo
index 1209ca4e..d06e735f 100644
--- a/benchmark/query_benchmark.mojo
+++ b/benchmark/query_benchmark.mojo
@@ -4,6 +4,9 @@ from larecs.world import World
 from larecs.entity import Entity
 from larecs.component import ComponentType
 from larecs.test_utils import *
+from larecs import MutableEntityAccessor
+from sys.info import simdwidthof
+from algorithm import vectorize
 
 
 fn benchmark_add_entity_1_000_000(mut bencher: Bencher) raises capturing:
@@ -19,18 +22,261 @@ fn benchmark_add_entity_1_000_000(mut bencher: Bencher) raises capturing:
 
 fn benchmark_query_1_comp_1_000_000(
     mut bencher: Bencher,
+) raises capturing:
+    """Benchmark 1000 query passes reading `Position.x` of 1000 entities."""
+
+    @always_inline
+    @parameter
+    fn bench_fn() capturing raises:
+        world = SmallWorld()
+        _ = world.add_entities(Position(1.0, 2.0), count=1000)
+        for _ in range(1000):
+            for entity in world.query[Position]():
+                keep(entity.get[Position]().x)
+
+    bencher.iter[bench_fn]()
+
+
+fn benchmark_vel_pos_add_1_000_000(
+    mut bencher: Bencher,
+) raises capturing:
+    """Benchmark 1000 query passes of `pos += vel` over 1000 entities."""
+
+    @always_inline
+    @parameter
+    fn bench_fn() capturing raises:
+        world = SmallWorld()
+        _ = world.add_entities(
+            Position(1.0, 2.0), Velocity(0.1, 0.2), count=1000
+        )
+        for _ in range(1000):
+            for entity in world.query[Position, Velocity]():
+                ref pos = entity.get[Position]()
+                ref vel = entity.get[Velocity]()
+                pos.x += vel.dx
+                pos.y += vel.dy
+
+    bencher.iter[bench_fn]()
+
+
+fn benchmark_vel_pos_add_aos_1_000_000(
+    mut bencher: Bencher,
 ) raises capturing:
+    """Benchmark 1000 passes of `pos += vel` over two 1000-item lists."""
     pos = Position(1.0, 2.0)
+    vel = Velocity(0.1, 0.2)
+
+    @always_inline
+    @parameter
+    fn bench_fn() capturing raises:
+        l1 = List[Position](length=1000, fill=pos)
+        l2 = List[Velocity](length=1000, fill=vel)
+        for _ in range(1000):
+            for i in range(len(l1)):
+                ref pos = l1[i]
+                ref vel = l2[i]
+                pos.x += vel.dx
+                pos.y += vel.dy
+        # Hint the optimizer that the updated values are observed.
+        keep(l1[0].x)
+
+    bencher.iter[bench_fn]()
+
+
+fn benchmark_vel_pos_add_aos_vec_1_000_000(
+    mut bencher: Bencher,
+) raises capturing:
+    """Benchmark 1000 SIMD passes of `pos += vel` over two 1000-item lists."""
+    pos2 = Position(1.0, 2.0)
+    vel2 = Velocity(0.1, 0.2)
+
+    alias simd_width = simdwidthof[Float64]()
+
+    @always_inline
+    @parameter
+    fn bench_fn() capturing raises:
+        l1 = List[Position](length=1000, fill=pos2)
+        l2 = List[Velocity](length=1000, fill=vel2)
+
+        @parameter
+        fn move[simd_width: Int](i: Int):
+            # Reinterpret the elements starting at `i` as flat Float64
+            # lanes, two per element (assumes Position and Velocity are
+            # exactly two contiguous Float64 fields -- TODO confirm).
+            var pos_ptr = UnsafePointer(to=l1[i]).bitcast[Float64]()
+            var pos = pos_ptr.load[width = simd_width * 2]()
+            var vel = (
+                l2.unsafe_ptr()
+                .offset(i)
+                .bitcast[Float64]()
+                .load[width = simd_width * 2]()
+            )
+
+            pos_ptr.store(pos + vel)
+
+        for _ in range(1000):
+            vectorize[move, simd_width // 2](len(l1))
+        # Hint the optimizer that the updated values are observed.
+        keep(l1[0].x)
+
+    bencher.iter[bench_fn]()
+
+
+@fieldwise_init
+struct PosX(Copyable & Movable):
+    """Component holding only the x position as a single `Float64`."""
+    var value: Float64
+
+
+@fieldwise_init
+struct PosY(Copyable & Movable):
+    """Component holding only the y position as a single `Float64`."""
+    var value: Float64
+
+
+@fieldwise_init
+struct VelX(Copyable & Movable):
+    """Component holding only the x velocity as a single `Float64`."""
+    var value: Float64
+
+
+@fieldwise_init
+struct VelY(Copyable & Movable):
+    """Component holding only the y velocity as a single `Float64`."""
+    var value: Float64
+
+
+fn benchmark_vel_pos_add_vec_optimized_1_000_000(
+    mut bencher: Bencher,
+) raises capturing:
+    """Benchmark 1000 `world.apply` passes of `pos += vel` on 1000 entities."""
+
+    @parameter
+    fn move[simd_width: Int](entity: MutableEntityAccessor):
+        try:
+            # Two Float64 lanes per entity: x and y move in one vector
+            # (assumes 2 contiguous Float64 per component -- TODO confirm).
+            var pos_ptr = UnsafePointer(to=entity.get[Position]()).bitcast[
+                Float64
+            ]()
+            var vel = (
+                UnsafePointer(to=entity.get[Velocity]())
+                .bitcast[Float64]()
+                .load[width = simd_width * 2]()
+            )
+            var pos = pos_ptr.load[width = simd_width * 2]()
+            pos_ptr.store(pos + vel)
+        except:
+            # `move` is declared non-raising, so a failed lookup is skipped.
+            return
+
+    alias simd_width = simdwidthof[Float64]()
 
     @always_inline
     @parameter
     fn bench_fn() capturing raises:
         world = SmallWorld()
+        _ = world.add_entities(
+            Position(1.0, 2.0), Velocity(0.1, 0.2), count=1000
+        )
         for _ in range(1000):
-            _ = world.add_entity(pos)
+            world.apply[move, simd_width = simd_width // 2](
+                world.query[Position, Velocity]()
+            )
+
+    bencher.iter[bench_fn]()
+
+
+fn benchmark_vel_pos_add_vec_1_000_000(
+    mut bencher: Bencher,
+) raises capturing:
+    """Benchmark 1000 `world.apply` passes over split x/y components."""
+
+    @parameter
+    fn move[simd_width: Int](entity: MutableEntityAccessor):
+        try:
+            var posX_ptr = UnsafePointer(to=entity.get[PosX]().value)
+            var posX = posX_ptr.load[width=simd_width]()
+            var velX = UnsafePointer(to=entity.get[VelX]().value).load[
+                width=simd_width
+            ]()
+            posX_ptr.store(posX + velX)
+
+            var posY_ptr = UnsafePointer(to=entity.get[PosY]().value)
+            var posY = posY_ptr.load[width=simd_width]()
+            var velY = UnsafePointer(to=entity.get[VelY]().value).load[
+                width=simd_width
+            ]()
+            posY_ptr.store(posY + velY)
+
+        except:
+            # `move` is declared non-raising, so a failed lookup is skipped.
+            return
+
+    alias simd_width = simdwidthof[Float64]()
+
+    @always_inline
+    @parameter
+    fn bench_fn() capturing raises:
+        world = World[PosX, VelX, PosY, VelY]()
+        _ = world.add_entities(
+            PosX(1.0), VelX(0.1), PosY(2.0), VelY(0.2), count=1000
+        )
         for _ in range(1000):
-            for entity in world.query[Position]():
-                keep(entity.get[Position]().x)
+            world.apply[move, simd_width=simd_width](
+                world.query[PosX, VelX, PosY, VelY]()
+            )
+
+    bencher.iter[bench_fn]()
+
+
+fn benchmark_vel_pos_add_vec_split_1_000_000(
+    mut bencher: Bencher,
+) raises capturing:
+    """Benchmark 1000 rounds of separate x and y `world.apply` passes."""
+
+    @parameter
+    fn move_x[simd_width: Int](entity: MutableEntityAccessor):
+        try:
+            var pos_ptr = UnsafePointer(to=entity.get[PosX]().value)
+            var pos = pos_ptr.load[width=simd_width]()
+            var vel = UnsafePointer(to=entity.get[VelX]().value).load[
+                width=simd_width
+            ]()
+            pos_ptr.store(pos + vel)
+        except:
+            # `move_x` is declared non-raising, so a failed lookup is skipped.
+            return
+
+    @parameter
+    fn move_y[simd_width: Int](entity: MutableEntityAccessor):
+        try:
+            var pos_ptr = UnsafePointer(to=entity.get[PosY]().value)
+            var pos = pos_ptr.load[width=simd_width]()
+            var vel = UnsafePointer(to=entity.get[VelY]().value).load[
+                width=simd_width
+            ]()
+            pos_ptr.store(pos + vel)
+        except:
+            # `move_y` is declared non-raising, so a failed lookup is skipped.
+            return
+
+    alias simd_width = simdwidthof[Float64]()
+
+    @always_inline
+    @parameter
+    fn bench_fn() capturing raises:
+        world = World[PosX, VelX, PosY, VelY]()
+        _ = world.add_entities(
+            PosX(1.0), VelX(0.1), PosY(2.0), VelY(0.2), count=1000
+        )
+        for _ in range(1000):
+            world.apply[move_x, simd_width=simd_width](
+                world.query[PosX, VelX, PosY, VelY]()
+            )
+            world.apply[move_y, simd_width=simd_width](
+                world.query[PosX, VelX, PosY, VelY]()
+            )
 
     bencher.iter[bench_fn]()
 
@@ -150,6 +396,24 @@ fn run_all_query_benchmarks(mut bench: Bench) raises:
     bench.bench_function[benchmark_query_get_iter_1_000_000](
         BenchId("10^6 * get query iter")
     )
+    bench.bench_function[benchmark_vel_pos_add_aos_1_000_000](
+        BenchId("10^3 * 10^3 * pos vel add aos")
+    )
+    bench.bench_function[benchmark_vel_pos_add_1_000_000](
+        BenchId("10^3 * 10^3 * pos vel add")
+    )
+    bench.bench_function[benchmark_vel_pos_add_aos_vec_1_000_000](
+        BenchId("10^3 * 10^3 * pos vel add aos vec optimized")
+    )
+    bench.bench_function[benchmark_vel_pos_add_vec_1_000_000](
+        BenchId("10^3 * 10^3 * pos vel add vec")
+    )
+    bench.bench_function[benchmark_vel_pos_add_vec_split_1_000_000](
+        BenchId("10^3 * 10^3 * pos vel add vec split")
+    )
+    bench.bench_function[benchmark_vel_pos_add_vec_optimized_1_000_000](
+        BenchId("10^3 * 10^3 * pos vel add vec optimized")
+    )
 
 
 def main():