|
21 | 21 |
|
22 | 22 | const std = @import("std"); |
23 | 23 | const seeds_mod = @import("seeds.zig"); |
| 24 | +const pda_mod = @import("pda.zig"); |
24 | 25 | const sol = @import("solana_program_sdk"); |
25 | 26 |
|
26 | 27 | // Import from parent SDK |
@@ -56,6 +57,28 @@ fn validateSeedAccountRef(comptime Accounts: type, seed: SeedSpec) void { |
56 | 57 | } |
57 | 58 | } |
58 | 59 |
|
| 60 | +/// Validate seedField references against account DataType |
| 61 | +fn validateSeedFieldRef(comptime DataType: type, comptime field_name: []const u8) void { |
| 62 | + if (!@hasField(DataType, field_name)) { |
| 63 | + @compileError("seedField references unknown data field: " ++ field_name ++ " in " ++ @typeName(DataType)); |
| 64 | + } |
| 65 | + |
| 66 | + // Verify the field type is valid for seeds (PublicKey or byte array) |
| 67 | + const field_type = @TypeOf(@field(@as(DataType, undefined), field_name)); |
| 68 | + const is_valid_type = comptime blk: { |
| 69 | + if (field_type == PublicKey) break :blk true; |
| 70 | + const info = @typeInfo(field_type); |
| 71 | + if (info == .array) { |
| 72 | + if (info.array.child == u8) break :blk true; |
| 73 | + } |
| 74 | + break :blk false; |
| 75 | + }; |
| 76 | + |
| 77 | + if (!is_valid_type) { |
| 78 | + @compileError("seedField '" ++ field_name ++ "' must be PublicKey or [N]u8, found " ++ @typeName(field_type)); |
| 79 | + } |
| 80 | +} |
| 81 | + |
59 | 82 | fn validateAccountRefs(comptime Accounts: type) void { |
60 | 83 | const fields = @typeInfo(Accounts).@"struct".fields; |
61 | 84 |
|
@@ -93,6 +116,15 @@ fn validateAccountRefs(comptime Accounts: type) void { |
93 | 116 | if (FieldType.SEEDS) |seeds| { |
94 | 117 | inline for (seeds) |seed| { |
95 | 118 | validateSeedAccountRef(Accounts, seed); |
| 119 | + // Validate seedField references against account DataType |
| 120 | + switch (seed) { |
| 121 | + .field => |field_name| { |
| 122 | + if (@hasDecl(FieldType, "DataType")) { |
| 123 | + validateSeedFieldRef(FieldType.DataType, field_name); |
| 124 | + } |
| 125 | + }, |
| 126 | + else => {}, |
| 127 | + } |
96 | 128 | } |
97 | 129 | } |
98 | 130 | if (FieldType.SEEDS_PROGRAM) |seed| { |
@@ -471,6 +503,176 @@ pub fn loadAccountsWithPda( |
471 | 503 | return .{ .accounts = accounts, .bumps = bumps }; |
472 | 504 | } |
473 | 505 |
|
| 506 | +/// Load accounts with automatic PDA seed resolution for all seed types |
| 507 | +/// |
| 508 | +/// Unlike `loadAccountsWithPda` which only handles literal-only seeds, |
| 509 | +/// this function automatically resolves seedAccount and seedField references |
| 510 | +/// by first loading all accounts, then resolving seeds and validating PDAs. |
| 511 | +/// |
| 512 | +/// Seed resolution order: |
| 513 | +/// 1. Load all accounts (without PDA validation) |
| 514 | +/// 2. For each account with seeds: |
| 515 | +/// - Resolve literal seeds directly |
| 516 | +/// - Resolve seedAccount by getting the referenced account's public key |
| 517 | +/// - Resolve seedField by reading the field from the account's data |
| 518 | +/// - Resolve seedBump from previously validated PDAs |
| 519 | +/// 3. Validate PDA addresses and store bumps |
| 520 | +/// 4. Run Phase 3 constraint validation |
| 521 | +/// |
| 522 | +/// Note: seedField resolution requires the referenced data to already be deserialized, |
| 523 | +/// which means the account must be loaded before the field can be accessed. |
| 524 | +/// |
| 525 | +/// Example: |
| 526 | +/// ```zig |
| 527 | +/// const result = try loadAccountsWithDependencies(MyAccounts, &program_id, account_infos); |
| 528 | +/// const accounts = result.accounts; |
| 529 | +/// const bumps = result.bumps; |
| 530 | +/// ``` |
| 531 | +pub fn loadAccountsWithDependencies( |
| 532 | + comptime Accounts: type, |
| 533 | + program_id: *const PublicKey, |
| 534 | + infos: []const AccountInfo, |
| 535 | +) !struct { accounts: Accounts, bumps: Bumps } { |
| 536 | + const fields = @typeInfo(Accounts).@"struct".fields; |
| 537 | + |
| 538 | + if (infos.len < fields.len) { |
| 539 | + return error.AccountNotEnoughAccountKeys; |
| 540 | + } |
| 541 | + |
| 542 | + // Phase 1: Load all accounts without PDA validation |
| 543 | + var accounts: Accounts = undefined; |
| 544 | + |
| 545 | + inline for (fields, 0..) |field, i| { |
| 546 | + const FieldType = field.type; |
| 547 | + const info = &infos[i]; |
| 548 | + |
| 549 | + if (@hasDecl(FieldType, "load")) { |
| 550 | + @field(accounts, field.name) = try FieldType.load(info); |
| 551 | + } else { |
| 552 | + @field(accounts, field.name) = info; |
| 553 | + } |
| 554 | + } |
| 555 | + |
| 556 | + // Phase 2: Resolve seeds and validate PDAs |
| 557 | + var bumps = Bumps{}; |
| 558 | + |
| 559 | + inline for (fields, 0..) |field, i| { |
| 560 | + const FieldType = field.type; |
| 561 | + const info = &infos[i]; |
| 562 | + |
| 563 | + // Skip if no PDA seeds |
| 564 | + if (!@hasDecl(FieldType, "HAS_SEEDS") or !FieldType.HAS_SEEDS) continue; |
| 565 | + if (!@hasDecl(FieldType, "SEEDS")) continue; |
| 566 | + |
| 567 | + const seed_specs = FieldType.SEEDS orelse continue; |
| 568 | + |
| 569 | + // If all seeds are literals, we can resolve at comptime |
| 570 | + if (seeds_mod.areAllLiteralSeeds(seed_specs)) { |
| 571 | + const resolved_seeds = seeds_mod.resolveComptimeSeeds(seed_specs); |
| 572 | + |
| 573 | + if (@hasDecl(FieldType, "loadWithPda")) { |
| 574 | + const result = try FieldType.loadWithPda(info, resolved_seeds, program_id); |
| 575 | + @field(accounts, field.name) = result.account; |
| 576 | + bumps.set(field.name, result.bump); |
| 577 | + } |
| 578 | + } else { |
| 579 | + // Runtime seed resolution required |
| 580 | + var seed_buffer = seeds_mod.SeedBuffer{}; |
| 581 | + |
| 582 | + inline for (seed_specs) |spec| { |
| 583 | + switch (spec) { |
| 584 | + .literal => |lit| { |
| 585 | + try seeds_mod.appendSeed(&seed_buffer, lit); |
| 586 | + }, |
| 587 | + .account => |account_name| { |
| 588 | + // Get public key from referenced account |
| 589 | + const ref_account = @field(accounts, account_name); |
| 590 | + const RefType = @TypeOf(ref_account); |
| 591 | + |
| 592 | + // Check if it's an Account wrapper with key() method |
| 593 | + if (@hasDecl(RefType, "key")) { |
| 594 | + const key_ptr = ref_account.key(); |
| 595 | + try seeds_mod.appendSeed(&seed_buffer, &key_ptr.*.bytes); |
| 596 | + } else { |
| 597 | + // Handle AccountInfo or *AccountInfo |
| 598 | + const ActualType = if (@typeInfo(RefType) == .pointer) |
| 599 | + @typeInfo(RefType).pointer.child |
| 600 | + else |
| 601 | + RefType; |
| 602 | + |
| 603 | + if (@hasField(ActualType, "id")) { |
| 604 | + // AccountInfo has id field (which is *PublicKey) |
| 605 | + const info_ptr = if (@typeInfo(RefType) == .pointer) |
| 606 | + ref_account |
| 607 | + else |
| 608 | + &ref_account; |
| 609 | + try seeds_mod.appendSeed(&seed_buffer, &info_ptr.id.*.bytes); |
| 610 | + } else { |
| 611 | + return error.AccountNotFound; |
| 612 | + } |
| 613 | + } |
| 614 | + }, |
| 615 | + .field => |field_name| { |
| 616 | + // Get field value from account data |
| 617 | + const account = @field(accounts, field.name); |
| 618 | + if (@hasDecl(@TypeOf(account), "data")) { |
| 619 | + const data = account.data; |
| 620 | + const DataType = @TypeOf(data); |
| 621 | + const ActualDataType = if (@typeInfo(DataType) == .pointer) |
| 622 | + @typeInfo(DataType).pointer.child |
| 623 | + else |
| 624 | + DataType; |
| 625 | + |
| 626 | + if (@hasField(ActualDataType, field_name)) { |
| 627 | + const data_ptr = if (@typeInfo(DataType) == .pointer) &data.* else &data; |
| 628 | + const field_ptr = &@field(data_ptr.*, field_name); |
| 629 | + const SeedFieldType = @TypeOf(field_ptr.*); |
| 630 | + |
| 631 | + if (SeedFieldType == PublicKey) { |
| 632 | + try seeds_mod.appendSeed(&seed_buffer, &field_ptr.*.bytes); |
| 633 | + } else { |
| 634 | + const seed_field_info = @typeInfo(SeedFieldType); |
| 635 | + if (seed_field_info == .array and seed_field_info.array.child == u8) { |
| 636 | + try seeds_mod.appendSeed(&seed_buffer, field_ptr.*); |
| 637 | + } else { |
| 638 | + return error.FieldNotFound; |
| 639 | + } |
| 640 | + } |
| 641 | + } else { |
| 642 | + return error.FieldNotFound; |
| 643 | + } |
| 644 | + } else { |
| 645 | + return error.FieldNotFound; |
| 646 | + } |
| 647 | + }, |
| 648 | + .bump => |bump_name| { |
| 649 | + // Get bump from previously validated PDA |
| 650 | + const bump_value = bumps.get(bump_name) orelse return error.BumpNotFound; |
| 651 | + try seeds_mod.appendBumpSeed(&seed_buffer, bump_value); |
| 652 | + }, |
| 653 | + } |
| 654 | + } |
| 655 | + |
| 656 | + // Validate PDA with resolved seeds using pda module |
| 657 | + const bump_value = pda_mod.validatePdaRuntime( |
| 658 | + info.id, |
| 659 | + seed_buffer.asSlice(), |
| 660 | + program_id, |
| 661 | + ) catch { |
| 662 | + return error.ConstraintSeeds; |
| 663 | + }; |
| 664 | + bumps.set(field.name, bump_value); |
| 665 | + } |
| 666 | + } |
| 667 | + |
| 668 | + try validateDuplicateMutableAccounts(Accounts, &accounts); |
| 669 | + |
| 670 | + // Phase 3: Validate constraints after all accounts are loaded |
| 671 | + try validatePhase3Constraints(Accounts, &accounts); |
| 672 | + |
| 673 | + return .{ .accounts = accounts, .bumps = bumps }; |
| 674 | +} |
| 675 | + |
474 | 676 | fn validateDuplicateMutableAccounts(comptime Accounts: type, accounts: *const Accounts) !void { |
475 | 677 | const fields = @typeInfo(Accounts).@"struct".fields; |
476 | 678 |
|
|
0 commit comments