Skip to content

Commit ad1b2c4

Browse files
committed
Fix tests
1 parent 7026afd commit ad1b2c4

File tree

5 files changed

+40
-20
lines changed

5 files changed

+40
-20
lines changed

NBitcoin/BIP174/Maps.cs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#nullable enable
2+
using NBitcoin.DataEncoders;
23
using NBitcoin.Protocol;
34
using System;
45
using System.Collections.Generic;
@@ -52,6 +53,12 @@ public Maps()
5253

5354
}
5455

56+
public void ThrowIfInvalidKeysLeft()
57+
{
58+
foreach (var m in this)
59+
m.ThrowIfInvalidKeysLeft();
60+
}
61+
5562
public Map NewMap()
5663
{
5764
Map map = new Map();
@@ -128,8 +135,23 @@ private static bool StartWith(byte[] prefix, byte[] data)
128135
}
129136
return true;
130137
}
131-
public bool TryRemove<T>(byte key, [MaybeNullWhen(false)] out T value) => TryRemove<T>([key], out value);
132-
public bool TryRemove<T>(byte[] key, [MaybeNullWhen(false)] out T value)
138+
139+
public void ThrowIfInvalidKeysLeft()
140+
{
141+
var readen = new HashSet<byte>(singleByteKeys);
142+
foreach (var kv in this)
143+
{
144+
if (readen.Contains(kv.Key[0]))
145+
throw new FormatException("Invalid PSBT, unexpected key " + Encoders.Hex.EncodeData(kv.Key));
146+
}
147+
}
148+
List<byte> singleByteKeys = new();
149+
public bool TryRemove<T>(byte key, [MaybeNullWhen(false)] out T value)
150+
{
151+
singleByteKeys.Add(key);
152+
return TryRemove<T>([key], out value);
153+
}
154+
bool TryRemove<T>(byte[] key, [MaybeNullWhen(false)] out T value)
133155
{
134156
value = default;
135157
object? val = null;

NBitcoin/BIP174/PSBT0.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ protected override void WriteCore(JsonTextWriter jsonWriter)
7979

8080
internal PSBT0(Maps maps, Network network) : base(maps, network, PSBTVersion.PSBTv0)
8181
{
82-
if (!maps.Global.TryRemove<byte[]>([PSBTConstants.PSBT_GLOBAL_UNSIGNED_TX], out var txBytes))
82+
if (!maps.Global.TryRemove<byte[]>(PSBTConstants.PSBT_GLOBAL_UNSIGNED_TX, out var txBytes))
8383
throw new FormatException("Invalid PSBT. No global TX");
8484
tx = Transaction.Load(txBytes, Network);
8585
tx.PrecomputeHash(true, true);
@@ -100,10 +100,9 @@ internal PSBT0(Maps maps, Network network) : base(maps, network, PSBTVersion.PSB
100100
{
101101
var index = (int)(1 + Inputs.Count + indexedOutput.N);
102102
var map = maps[index];
103-
if (map.Keys.Any(bytes => bytes.Length == 1 && PSBT2Constants.PSBT_V0_OUTPUT_EXCLUSIONSET.Contains(bytes[0])))
104-
throw new FormatException("Invalid PSBT v0. Contains v2 fields");
105103
Outputs.Add(new PSBT0Output(map, this, indexedOutput.N, indexedOutput.TxOut));
106104
}
105+
maps.ThrowIfInvalidKeysLeft();
107106
}
108107

109108
internal override void FillMap(Map map)
@@ -129,6 +128,8 @@ public PSBT0Output(Map map, PSBT parent, uint index, TxOut txOut) : base(map, pa
129128
{
130129
if (txOut is null)
131130
throw new ArgumentNullException(nameof(txOut));
131+
if (map.Keys.Any(bytes => bytes.Length == 1 && PSBT2Constants.PSBT_V0_OUTPUT_EXCLUSIONSET.Contains(bytes[0])))
132+
throw new FormatException("Invalid PSBT v0. Contains v2 fields");
132133
TxOut = txOut;
133134
}
134135

@@ -168,10 +169,6 @@ internal PSBT0Input(Map map, PSBT0 parent, uint index) : base(map, parent, index
168169
if (this.Unknown.Keys.Any(bytes => bytes.Length == 1 && PSBT2Constants.PSBT_V0_INPUT_EXCLUSIONSET.Contains(bytes[0])))
169170
throw new FormatException("Invalid PSBT v0. Contains v2 fields");
170171
txIn = parent.tx.Inputs[index];
171-
if (!Script.IsNullOrEmpty(txIn.ScriptSig))
172-
this.FinalScriptSig = txIn.ScriptSig;
173-
if (!WitScript.IsNullOrEmpty(txIn.WitScript))
174-
this.WitnessScript = txIn.WitScript;
175172
}
176173
TxIn txIn;
177174
public override OutPoint PrevOut => txIn.PrevOut;

NBitcoin/BIP174/PartiallySignedTransaction.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ public static PSBT Load(byte[] rawBytes, Network network)
204204

205205

206206
var maps = Maps.Load(stream);
207-
if (maps.Global.TryRemove<int>([PSBTConstants.PSBT_GLOBAL_VERSION], out var psbtVersion) && psbtVersion == 0)
207+
if (maps.Global.TryRemove<int>(PSBTConstants.PSBT_GLOBAL_VERSION, out var psbtVersion) && psbtVersion == 0)
208208
throw new FormatException("PSBTv0 should not include PSBT_GLOBAL_VERSION");
209209
return psbtVersion switch
210210
{

NBitcoin/BIP370/PSBT2.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,20 @@ internal PSBT2(Maps maps, Network network) : base(maps, network, PSBTVersion.PSB
4242
throw new FormatException("PSBT v2 must not contain PSBT_GLOBAL_UNSIGNED_TX");
4343
}
4444

45-
if (globalMap.TryRemove<uint>([PSBT2Constants.PSBT_GLOBAL_FALLBACK_LOCKTIME], out var v))
45+
if (globalMap.TryRemove<uint>(PSBT2Constants.PSBT_GLOBAL_FALLBACK_LOCKTIME, out var v))
4646
FallbackLockTime = new LockTime(v);
4747

48-
if (globalMap.TryRemove<uint>([PSBT2Constants.PSBT_GLOBAL_TX_VERSION], out var txVersion))
48+
if (globalMap.TryRemove<uint>(PSBT2Constants.PSBT_GLOBAL_TX_VERSION, out var txVersion))
4949
TransactionVersion = txVersion;
5050
else
5151
throw new FormatException("PSBT v2 must contain PSBT_GLOBAL_TX_VERSION");
5252

53-
if (globalMap.TryRemove<byte>([PSBT2Constants.PSBT_GLOBAL_TX_MODIFIABLE], out var modifiableFlagsByte))
53+
if (globalMap.TryRemove<byte>(PSBT2Constants.PSBT_GLOBAL_TX_MODIFIABLE, out var modifiableFlagsByte))
5454
ModifiableFlags = (PSBTModifiable)modifiableFlagsByte;
5555

56-
if (!globalMap.TryRemove<VarInt>([PSBT2Constants.PSBT_GLOBAL_INPUT_COUNT], out var inputCount))
56+
if (!globalMap.TryRemove<VarInt>(PSBT2Constants.PSBT_GLOBAL_INPUT_COUNT, out var inputCount))
5757
throw new FormatException("PSBT v2 must contain PSBT_GLOBAL_INPUT_COUNT");
58-
if (!globalMap.TryRemove<VarInt>([PSBT2Constants.PSBT_GLOBAL_OUTPUT_COUNT], out var outputCount))
58+
if (!globalMap.TryRemove<VarInt>(PSBT2Constants.PSBT_GLOBAL_OUTPUT_COUNT, out var outputCount))
5959
throw new FormatException("PSBT v2 must contain PSBT_GLOBAL_OUTPUT_COUNT");
6060

6161
Unknown = globalMap;
@@ -74,6 +74,7 @@ internal PSBT2(Maps maps, Network network) : base(maps, network, PSBTVersion.PSB
7474
var map = maps[(int)mapIndex];
7575
Outputs.Add(new PSBT2Output(map, this, (uint)outputIndex));
7676
}
77+
maps.ThrowIfInvalidKeysLeft();
7778
}
7879

7980
internal override Transaction GetGlobalTransaction(bool @unsafe)

NBitcoin/BIP370/PSBT2Input.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,26 +96,26 @@ internal PSBT2Input(OutPoint prevOut, PSBT parent, uint inputIndex) : base(new M
9696
}
9797
internal PSBT2Input(Map map, PSBT parent, uint inputIndex) : base(map, parent, inputIndex)
9898
{
99-
if (!map.TryRemove<byte[]>([PSBT2Constants.PSBT_IN_PREVIOUS_TXID], out var txidBytes) || txidBytes.Length != 32)
99+
if (!map.TryRemove<byte[]>(PSBT2Constants.PSBT_IN_PREVIOUS_TXID, out var txidBytes) || txidBytes.Length != 32)
100100
throw new FormatException("PSBT v2 must contain PSBT_IN_PREVIOUS_TXID");
101101

102-
if (!map.TryRemove<uint>([PSBT2Constants.PSBT_IN_OUTPUT_INDEX], out var index))
102+
if (!map.TryRemove<uint>(PSBT2Constants.PSBT_IN_OUTPUT_INDEX, out var index))
103103
throw new FormatException("PSBT v2 must contain PSBT_IN_OUTPUT_INDEX");
104104

105105
this.PrevOut = new OutPoint(new uint256(txidBytes), index);
106106

107-
if (map.TryRemove<uint>([PSBT2Constants.PSBT_IN_SEQUENCE], out var s))
107+
if (map.TryRemove<uint>(PSBT2Constants.PSBT_IN_SEQUENCE, out var s))
108108
Sequence = new Sequence(s);
109109

110-
if (map.TryRemove<uint>([PSBT2Constants.PSBT_IN_REQUIRED_TIME_LOCKTIME], out var timeLockTimeV))
110+
if (map.TryRemove<uint>(PSBT2Constants.PSBT_IN_REQUIRED_TIME_LOCKTIME, out var timeLockTimeV))
111111
{
112112
var locktime = new LockTime(timeLockTimeV);
113113
if (!locktime.IsTimeLock)
114114
throw new FormatException("PSBT v2 input locktime must be a time lock");
115115
LockTime = locktime.Date;
116116
}
117117

118-
if (map.TryRemove<uint>([PSBT2Constants.PSBT_IN_REQUIRED_HEIGHT_LOCKTIME], out var locktimeV))
118+
if (map.TryRemove<uint>(PSBT2Constants.PSBT_IN_REQUIRED_HEIGHT_LOCKTIME, out var locktimeV))
119119
{
120120
var locktime = new LockTime(locktimeV);
121121
if (!locktime.IsHeightLock)

0 commit comments

Comments
 (0)