Skip to content

Commit 8f6cdc2

Browse files
committed
Add PDA seed automation and runtime validation
1 parent aea4a21 commit 8f6cdc2

5 files changed

Lines changed: 536 additions & 0 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
.vscode/
22
.direnv/
33
.envrc
4+
.claude/
45

56
solana-zig/
67
zig-cache/

anchor/src/context.zig

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
const std = @import("std");
2323
const seeds_mod = @import("seeds.zig");
24+
const pda_mod = @import("pda.zig");
2425
const sol = @import("solana_program_sdk");
2526

2627
// Import from parent SDK
@@ -56,6 +57,28 @@ fn validateSeedAccountRef(comptime Accounts: type, seed: SeedSpec) void {
5657
}
5758
}
5859

60+
/// Validate seedField references against account DataType
61+
fn validateSeedFieldRef(comptime DataType: type, comptime field_name: []const u8) void {
62+
if (!@hasField(DataType, field_name)) {
63+
@compileError("seedField references unknown data field: " ++ field_name ++ " in " ++ @typeName(DataType));
64+
}
65+
66+
// Verify the field type is valid for seeds (PublicKey or byte array)
67+
const field_type = @TypeOf(@field(@as(DataType, undefined), field_name));
68+
const is_valid_type = comptime blk: {
69+
if (field_type == PublicKey) break :blk true;
70+
const info = @typeInfo(field_type);
71+
if (info == .array) {
72+
if (info.array.child == u8) break :blk true;
73+
}
74+
break :blk false;
75+
};
76+
77+
if (!is_valid_type) {
78+
@compileError("seedField '" ++ field_name ++ "' must be PublicKey or [N]u8, found " ++ @typeName(field_type));
79+
}
80+
}
81+
5982
fn validateAccountRefs(comptime Accounts: type) void {
6083
const fields = @typeInfo(Accounts).@"struct".fields;
6184

@@ -93,6 +116,15 @@ fn validateAccountRefs(comptime Accounts: type) void {
93116
if (FieldType.SEEDS) |seeds| {
94117
inline for (seeds) |seed| {
95118
validateSeedAccountRef(Accounts, seed);
119+
// Validate seedField references against account DataType
120+
switch (seed) {
121+
.field => |field_name| {
122+
if (@hasDecl(FieldType, "DataType")) {
123+
validateSeedFieldRef(FieldType.DataType, field_name);
124+
}
125+
},
126+
else => {},
127+
}
96128
}
97129
}
98130
if (FieldType.SEEDS_PROGRAM) |seed| {
@@ -471,6 +503,176 @@ pub fn loadAccountsWithPda(
471503
return .{ .accounts = accounts, .bumps = bumps };
472504
}
473505

506+
/// Load accounts with automatic PDA seed resolution for all seed types
507+
///
508+
/// Unlike `loadAccountsWithPda` which only handles literal-only seeds,
509+
/// this function automatically resolves seedAccount and seedField references
510+
/// by first loading all accounts, then resolving seeds and validating PDAs.
511+
///
512+
/// Seed resolution order:
513+
/// 1. Load all accounts (without PDA validation)
514+
/// 2. For each account with seeds:
515+
/// - Resolve literal seeds directly
516+
/// - Resolve seedAccount by getting the referenced account's public key
517+
/// - Resolve seedField by reading the field from the account's data
518+
/// - Resolve seedBump from previously validated PDAs
519+
/// 3. Validate PDA addresses and store bumps
520+
/// 4. Run Phase 3 constraint validation
521+
///
522+
/// Note: seedField resolution requires the referenced data to already be deserialized,
523+
/// which means the account must be loaded before the field can be accessed.
524+
///
525+
/// Example:
526+
/// ```zig
527+
/// const result = try loadAccountsWithDependencies(MyAccounts, &program_id, account_infos);
528+
/// const accounts = result.accounts;
529+
/// const bumps = result.bumps;
530+
/// ```
531+
pub fn loadAccountsWithDependencies(
532+
comptime Accounts: type,
533+
program_id: *const PublicKey,
534+
infos: []const AccountInfo,
535+
) !struct { accounts: Accounts, bumps: Bumps } {
536+
const fields = @typeInfo(Accounts).@"struct".fields;
537+
538+
if (infos.len < fields.len) {
539+
return error.AccountNotEnoughAccountKeys;
540+
}
541+
542+
// Phase 1: Load all accounts without PDA validation
543+
var accounts: Accounts = undefined;
544+
545+
inline for (fields, 0..) |field, i| {
546+
const FieldType = field.type;
547+
const info = &infos[i];
548+
549+
if (@hasDecl(FieldType, "load")) {
550+
@field(accounts, field.name) = try FieldType.load(info);
551+
} else {
552+
@field(accounts, field.name) = info;
553+
}
554+
}
555+
556+
// Phase 2: Resolve seeds and validate PDAs
557+
var bumps = Bumps{};
558+
559+
inline for (fields, 0..) |field, i| {
560+
const FieldType = field.type;
561+
const info = &infos[i];
562+
563+
// Skip if no PDA seeds
564+
if (!@hasDecl(FieldType, "HAS_SEEDS") or !FieldType.HAS_SEEDS) continue;
565+
if (!@hasDecl(FieldType, "SEEDS")) continue;
566+
567+
const seed_specs = FieldType.SEEDS orelse continue;
568+
569+
// If all seeds are literals, we can resolve at comptime
570+
if (seeds_mod.areAllLiteralSeeds(seed_specs)) {
571+
const resolved_seeds = seeds_mod.resolveComptimeSeeds(seed_specs);
572+
573+
if (@hasDecl(FieldType, "loadWithPda")) {
574+
const result = try FieldType.loadWithPda(info, resolved_seeds, program_id);
575+
@field(accounts, field.name) = result.account;
576+
bumps.set(field.name, result.bump);
577+
}
578+
} else {
579+
// Runtime seed resolution required
580+
var seed_buffer = seeds_mod.SeedBuffer{};
581+
582+
inline for (seed_specs) |spec| {
583+
switch (spec) {
584+
.literal => |lit| {
585+
try seeds_mod.appendSeed(&seed_buffer, lit);
586+
},
587+
.account => |account_name| {
588+
// Get public key from referenced account
589+
const ref_account = @field(accounts, account_name);
590+
const RefType = @TypeOf(ref_account);
591+
592+
// Check if it's an Account wrapper with key() method
593+
if (@hasDecl(RefType, "key")) {
594+
const key_ptr = ref_account.key();
595+
try seeds_mod.appendSeed(&seed_buffer, &key_ptr.*.bytes);
596+
} else {
597+
// Handle AccountInfo or *AccountInfo
598+
const ActualType = if (@typeInfo(RefType) == .pointer)
599+
@typeInfo(RefType).pointer.child
600+
else
601+
RefType;
602+
603+
if (@hasField(ActualType, "id")) {
604+
// AccountInfo has id field (which is *PublicKey)
605+
const info_ptr = if (@typeInfo(RefType) == .pointer)
606+
ref_account
607+
else
608+
&ref_account;
609+
try seeds_mod.appendSeed(&seed_buffer, &info_ptr.id.*.bytes);
610+
} else {
611+
return error.AccountNotFound;
612+
}
613+
}
614+
},
615+
.field => |field_name| {
616+
// Get field value from account data
617+
const account = @field(accounts, field.name);
618+
if (@hasDecl(@TypeOf(account), "data")) {
619+
const data = account.data;
620+
const DataType = @TypeOf(data);
621+
const ActualDataType = if (@typeInfo(DataType) == .pointer)
622+
@typeInfo(DataType).pointer.child
623+
else
624+
DataType;
625+
626+
if (@hasField(ActualDataType, field_name)) {
627+
const data_ptr = if (@typeInfo(DataType) == .pointer) &data.* else &data;
628+
const field_ptr = &@field(data_ptr.*, field_name);
629+
const SeedFieldType = @TypeOf(field_ptr.*);
630+
631+
if (SeedFieldType == PublicKey) {
632+
try seeds_mod.appendSeed(&seed_buffer, &field_ptr.*.bytes);
633+
} else {
634+
const seed_field_info = @typeInfo(SeedFieldType);
635+
if (seed_field_info == .array and seed_field_info.array.child == u8) {
636+
try seeds_mod.appendSeed(&seed_buffer, field_ptr.*);
637+
} else {
638+
return error.FieldNotFound;
639+
}
640+
}
641+
} else {
642+
return error.FieldNotFound;
643+
}
644+
} else {
645+
return error.FieldNotFound;
646+
}
647+
},
648+
.bump => |bump_name| {
649+
// Get bump from previously validated PDA
650+
const bump_value = bumps.get(bump_name) orelse return error.BumpNotFound;
651+
try seeds_mod.appendBumpSeed(&seed_buffer, bump_value);
652+
},
653+
}
654+
}
655+
656+
// Validate PDA with resolved seeds using pda module
657+
const bump_value = pda_mod.validatePdaRuntime(
658+
info.id,
659+
seed_buffer.asSlice(),
660+
program_id,
661+
) catch {
662+
return error.ConstraintSeeds;
663+
};
664+
bumps.set(field.name, bump_value);
665+
}
666+
}
667+
668+
try validateDuplicateMutableAccounts(Accounts, &accounts);
669+
670+
// Phase 3: Validate constraints after all accounts are loaded
671+
try validatePhase3Constraints(Accounts, &accounts);
672+
673+
return .{ .accounts = accounts, .bumps = bumps };
674+
}
675+
474676
fn validateDuplicateMutableAccounts(comptime Accounts: type, accounts: *const Accounts) !void {
475677
const fields = @typeInfo(Accounts).@"struct".fields;
476678

anchor/src/pda.zig

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,41 @@ pub fn derivePda(
126126
};
127127
}
128128

129+
/// Validate PDA using runtime-resolved seeds (slice-based)
130+
///
131+
/// Use this when seeds are resolved at runtime (e.g., seedAccount, seedField).
132+
/// Unlike `validatePda`, this function accepts a slice of seed byte slices.
133+
///
134+
/// Example:
135+
/// ```zig
136+
/// var seed_buffer = SeedBuffer{};
137+
/// try seeds_mod.appendSeed(&seed_buffer, "counter");
138+
/// try seeds_mod.appendSeed(&seed_buffer, &authority_key.bytes);
139+
///
140+
/// const bump = try validatePdaRuntime(
141+
/// counter_account.key(),
142+
/// seed_buffer.asSlice(),
143+
/// program_id,
144+
/// );
145+
/// ```
146+
pub fn validatePdaRuntime(
147+
account_key: *const PublicKey,
148+
seeds: []const []const u8,
149+
program_id: *const PublicKey,
150+
) PdaError!u8 {
151+
// Use SDK's slice-based findProgramAddress
152+
const pda = PublicKey.findProgramAddressSlice(seeds, program_id.*) catch {
153+
return PdaError.DerivationFailed;
154+
};
155+
156+
// Compare addresses
157+
if (!account_key.equals(pda.address)) {
158+
return PdaError.InvalidPda;
159+
}
160+
161+
return pda.bump_seed[0];
162+
}
163+
129164
/// Create a PDA address with known bump (no search)
130165
///
131166
/// Use this when you already know the bump seed to avoid
@@ -252,3 +287,60 @@ test "validatePda with multiple seeds" {
252287

253288
try std.testing.expectEqual(pda.bump_seed[0], bump);
254289
}
290+
291+
test "validatePdaRuntime succeeds for valid PDA" {
292+
const program_id = comptime PublicKey.comptimeFromBase58("BPFLoaderUpgradeab1e11111111111111111111111");
293+
294+
// Derive PDA using comptime seeds
295+
const pda = try derivePda(.{"runtime_test"}, &program_id);
296+
297+
// Validate using runtime slice
298+
const runtime_seeds: []const []const u8 = &.{"runtime_test"};
299+
const bump = try validatePdaRuntime(&pda.address, runtime_seeds, &program_id);
300+
301+
try std.testing.expectEqual(pda.bump_seed[0], bump);
302+
}
303+
304+
test "validatePdaRuntime with multiple seeds" {
305+
const program_id = comptime PublicKey.comptimeFromBase58("BPFLoaderUpgradeab1e11111111111111111111111");
306+
const authority = comptime PublicKey.comptimeFromBase58("TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA");
307+
308+
// Derive PDA
309+
const pda = try derivePda(.{ "counter", &authority.bytes }, &program_id);
310+
311+
// Validate using runtime slice (simulating seedAccount resolution)
312+
const runtime_seeds: []const []const u8 = &.{ "counter", &authority.bytes };
313+
const bump = try validatePdaRuntime(&pda.address, runtime_seeds, &program_id);
314+
315+
try std.testing.expectEqual(pda.bump_seed[0], bump);
316+
}
317+
318+
test "validatePdaRuntime fails for wrong address" {
319+
const program_id = comptime PublicKey.comptimeFromBase58("BPFLoaderUpgradeab1e11111111111111111111111");
320+
321+
// Use a random address that's not the PDA
322+
var wrong_address = PublicKey.default();
323+
wrong_address.bytes[0] = 0xFF;
324+
325+
const runtime_seeds: []const []const u8 = &.{"test_seed"};
326+
const result = validatePdaRuntime(&wrong_address, runtime_seeds, &program_id);
327+
try std.testing.expectError(PdaError.InvalidPda, result);
328+
}
329+
330+
test "validatePdaRuntime with SeedBuffer" {
331+
const program_id = comptime PublicKey.comptimeFromBase58("BPFLoaderUpgradeab1e11111111111111111111111");
332+
const user = comptime PublicKey.comptimeFromBase58("TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA");
333+
334+
// Derive PDA using comptime
335+
const pda = try derivePda(.{ "user_data", &user.bytes }, &program_id);
336+
337+
// Build seeds using SeedBuffer (simulating runtime resolution)
338+
var buffer = SeedBuffer{};
339+
try seeds_mod.appendSeed(&buffer, "user_data");
340+
try seeds_mod.appendSeed(&buffer, &user.bytes);
341+
342+
// Validate using buffer's slice
343+
const bump = try validatePdaRuntime(&pda.address, buffer.asSlice(), &program_id);
344+
345+
try std.testing.expectEqual(pda.bump_seed[0], bump);
346+
}

0 commit comments

Comments
 (0)