|
| 1 | +//! Bulk data loading and seeding tools (bulk_load, seed) |
| 2 | +
|
| 3 | +use std::time::Instant; |
| 4 | + |
| 5 | +use tower_mcp::{CallToolResult, ResultExt}; |
| 6 | + |
| 7 | +use crate::serde_helpers; |
| 8 | +use crate::tools::macros::{database_tool, mcp_module}; |
| 9 | + |
| 10 | +/// A single Redis command represented as a list of arguments. |
| 11 | +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] |
| 12 | +pub struct Command { |
| 13 | + /// Redis command arguments (e.g. ["SET", "key", "value"] or ["ZADD", "myset", "1.0", "member"]) |
| 14 | + pub args: Vec<String>, |
| 15 | +} |
| 16 | + |
| 17 | +/// A field-value pair for hash seeding. |
| 18 | +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] |
| 19 | +pub struct FieldValue { |
| 20 | + /// Field name (supports {i} substitution) |
| 21 | + pub name: String, |
| 22 | + /// Value pattern (supports {i} substitution) |
| 23 | + pub value: String, |
| 24 | +} |
| 25 | + |
| 26 | +fn default_batch_size() -> usize { |
| 27 | + 1000 |
| 28 | +} |
| 29 | + |
| 30 | +/// Substitute `{i}` and `{i:0N}` patterns in a string with the given index. |
| 31 | +fn substitute_pattern(pattern: &str, i: u64) -> String { |
| 32 | + let mut result = pattern.to_string(); |
| 33 | + // Handle {i:0N} zero-padded patterns |
| 34 | + while let Some(start) = result.find("{i:0") { |
| 35 | + if let Some(end) = result[start..].find('}') { |
| 36 | + let width_str = &result[start + 4..start + end]; |
| 37 | + if let Ok(width) = width_str.parse::<usize>() { |
| 38 | + let replacement = format!("{:0>width$}", i, width = width); |
| 39 | + result = format!( |
| 40 | + "{}{}{}", |
| 41 | + &result[..start], |
| 42 | + replacement, |
| 43 | + &result[start + end + 1..] |
| 44 | + ); |
| 45 | + continue; |
| 46 | + } |
| 47 | + } |
| 48 | + break; |
| 49 | + } |
| 50 | + result.replace("{i}", &i.to_string()) |
| 51 | +} |
| 52 | + |
| 53 | +mcp_module! { |
| 54 | + bulk_load => "redis_bulk_load", |
| 55 | + seed => "redis_seed", |
| 56 | +} |
| 57 | + |
| 58 | +database_tool!(write, bulk_load, "redis_bulk_load", |
| 59 | + "Pipelined command execution. Accept a batch of Redis commands and execute them \ |
| 60 | + using Redis pipelining for high throughput. Returns count of commands executed, \ |
| 61 | + elapsed time, and throughput.", |
| 62 | + { |
| 63 | + /// List of commands to execute. Each command is an args array (e.g. [\"SET\", \"k\", \"v\"]). |
| 64 | + pub commands: Vec<Command>, |
| 65 | + /// Pipeline batch size (default: 1000). Commands are sent in batches of this size. |
| 66 | + #[serde(default = "default_batch_size", deserialize_with = "serde_helpers::string_or_usize::deserialize")] |
| 67 | + pub batch_size: usize, |
| 68 | + } => |conn, input| { |
| 69 | + if input.commands.is_empty() { |
| 70 | + return Ok(CallToolResult::text("No commands to execute")); |
| 71 | + } |
| 72 | + |
| 73 | + let batch_size = input.batch_size.max(1); |
| 74 | + let start = Instant::now(); |
| 75 | + let mut total_ok = 0usize; |
| 76 | + |
| 77 | + for (batch_idx, chunk) in input.commands.chunks(batch_size).enumerate() { |
| 78 | + let mut pipe = redis::pipe(); |
| 79 | + for cmd_input in chunk { |
| 80 | + if cmd_input.args.is_empty() { |
| 81 | + continue; |
| 82 | + } |
| 83 | + let mut cmd = redis::cmd(&cmd_input.args[0]); |
| 84 | + for arg in &cmd_input.args[1..] { |
| 85 | + cmd.arg(arg); |
| 86 | + } |
| 87 | + pipe.add_command(cmd).ignore(); |
| 88 | + } |
| 89 | + pipe.query_async::<()>(&mut conn) |
| 90 | + .await |
| 91 | + .tool_context(format!("Pipeline batch {} failed", batch_idx))?; |
| 92 | + total_ok += chunk.len(); |
| 93 | + } |
| 94 | + |
| 95 | + let elapsed = start.elapsed(); |
| 96 | + let rate = if elapsed.as_secs_f64() > 0.0 { |
| 97 | + total_ok as f64 / elapsed.as_secs_f64() |
| 98 | + } else { |
| 99 | + total_ok as f64 |
| 100 | + }; |
| 101 | + |
| 102 | + Ok(CallToolResult::text(format!( |
| 103 | + "Bulk load complete: {} commands executed in {:.2}s ({:.0} cmd/s)", |
| 104 | + total_ok, |
| 105 | + elapsed.as_secs_f64(), |
| 106 | + rate |
| 107 | + ))) |
| 108 | + } |
| 109 | +); |
| 110 | + |
| 111 | +database_tool!(write, seed, "redis_seed", |
| 112 | + "Declarative data generation for test/prototype data. Generates keys matching a pattern \ |
| 113 | + using Redis pipelining for high throughput.\n\n\ |
| 114 | + Supported data_type values: \"string\", \"hash\", \"sorted_set\", \"set\", \"list\", \"json\".\n\n\ |
| 115 | + Pattern substitution: use {i} for the index, {i:0N} for zero-padded (e.g. {i:06} for 6 digits).\n\n\ |
| 116 | + Examples:\n\ |
| 117 | + - String: key_pattern=\"user:{i}\", value_pattern=\"value-{i}\", count=1000\n\ |
| 118 | + - Hash: key_pattern=\"user:{i}\", field_values=[{name:\"name\",value:\"user-{i}\"},{name:\"score\",value:\"{i}\"}], count=1000\n\ |
| 119 | + - Sorted set: key_pattern=\"leaderboard\", member_pattern=\"player-{i:06}\", count=10000, score_min=0, score_max=10000\n\ |
| 120 | + - JSON: key_pattern=\"doc:{i}\", value_pattern='{\"id\":{i},\"name\":\"item-{i}\"}', count=1000", |
| 121 | + { |
| 122 | + /// Data type to generate: "string", "hash", "sorted_set", "set", "list", "json" |
| 123 | + pub data_type: String, |
| 124 | + /// Key pattern with {i} placeholder for index (e.g. "user:{i}", "shard:{i:04}:data") |
| 125 | + pub key_pattern: String, |
| 126 | + /// Number of items to generate |
| 127 | + #[serde(deserialize_with = "serde_helpers::string_or_u64::deserialize")] |
| 128 | + pub count: u64, |
| 129 | + /// For hash type: field-value pairs to set on each key. Supports {i} in both name and value. |
| 130 | + #[serde(default)] |
| 131 | + pub field_values: Option<Vec<FieldValue>>, |
| 132 | + /// For sorted_set/set/list: member pattern with {i} (e.g. "member-{i:08}") |
| 133 | + #[serde(default)] |
| 134 | + pub member_pattern: Option<String>, |
| 135 | + /// For string/json: value pattern with {i} |
| 136 | + #[serde(default)] |
| 137 | + pub value_pattern: Option<String>, |
| 138 | + /// For sorted_set: minimum score (default: 0.0) |
| 139 | + #[serde(default)] |
| 140 | + pub score_min: Option<f64>, |
| 141 | + /// For sorted_set: maximum score (default: count) |
| 142 | + #[serde(default)] |
| 143 | + pub score_max: Option<f64>, |
| 144 | + /// Optional TTL in seconds (applied to string, hash, and json types) |
| 145 | + #[serde(default, deserialize_with = "serde_helpers::string_or_opt_u64::deserialize")] |
| 146 | + pub ttl: Option<u64>, |
| 147 | + /// Pipeline batch size (default: 1000) |
| 148 | + #[serde(default = "default_batch_size", deserialize_with = "serde_helpers::string_or_usize::deserialize")] |
| 149 | + pub batch_size: usize, |
| 150 | + } => |conn, input| { |
| 151 | + let batch_size = input.batch_size.max(1); |
| 152 | + let count = input.count; |
| 153 | + let data_type = input.data_type.to_lowercase(); |
| 154 | + let start = Instant::now(); |
| 155 | + let mut total_commands = 0usize; |
| 156 | + |
| 157 | + // Validate data type |
| 158 | + match data_type.as_str() { |
| 159 | + "string" | "hash" | "sorted_set" | "set" | "list" | "json" => {} |
| 160 | + _ => { |
| 161 | + return Err(tower_mcp::Error::tool(format!( |
| 162 | + "Invalid data_type '{}'. Valid types: string, hash, sorted_set, set, list, json", |
| 163 | + input.data_type |
| 164 | + ))); |
| 165 | + } |
| 166 | + } |
| 167 | + |
| 168 | + // Validate required fields per type |
| 169 | + match data_type.as_str() { |
| 170 | + "string" => { |
| 171 | + if input.value_pattern.is_none() { |
| 172 | + return Err(tower_mcp::Error::tool( |
| 173 | + "value_pattern is required for string type" |
| 174 | + )); |
| 175 | + } |
| 176 | + } |
| 177 | + "hash" => { |
| 178 | + if input.field_values.as_ref().is_none_or(|f| f.is_empty()) { |
| 179 | + return Err(tower_mcp::Error::tool( |
| 180 | + "field_values with at least one entry is required for hash type" |
| 181 | + )); |
| 182 | + } |
| 183 | + } |
| 184 | + "sorted_set" | "set" | "list" => { |
| 185 | + if input.member_pattern.is_none() { |
| 186 | + return Err(tower_mcp::Error::tool(format!( |
| 187 | + "member_pattern is required for {} type", |
| 188 | + data_type |
| 189 | + ))); |
| 190 | + } |
| 191 | + } |
| 192 | + "json" => { |
| 193 | + if input.value_pattern.is_none() { |
| 194 | + return Err(tower_mcp::Error::tool( |
| 195 | + "value_pattern is required for json type" |
| 196 | + )); |
| 197 | + } |
| 198 | + } |
| 199 | + _ => unreachable!(), |
| 200 | + } |
| 201 | + |
| 202 | + let score_min = input.score_min.unwrap_or(0.0); |
| 203 | + let score_max = input.score_max.unwrap_or(count as f64); |
| 204 | + |
| 205 | + // Generate commands in batches |
| 206 | + let indices: Vec<u64> = (0..count).collect(); |
| 207 | + for chunk in indices.chunks(batch_size) { |
| 208 | + let mut pipe = redis::pipe(); |
| 209 | + |
| 210 | + for &i in chunk { |
| 211 | + let key = substitute_pattern(&input.key_pattern, i); |
| 212 | + |
| 213 | + match data_type.as_str() { |
| 214 | + "string" => { |
| 215 | + let value = substitute_pattern(input.value_pattern.as_ref().unwrap(), i); |
| 216 | + let mut cmd = redis::cmd("SET"); |
| 217 | + cmd.arg(&key).arg(&value); |
| 218 | + pipe.add_command(cmd).ignore(); |
| 219 | + total_commands += 1; |
| 220 | + |
| 221 | + if let Some(ttl) = input.ttl { |
| 222 | + let mut cmd = redis::cmd("EXPIRE"); |
| 223 | + cmd.arg(&key).arg(ttl); |
| 224 | + pipe.add_command(cmd).ignore(); |
| 225 | + total_commands += 1; |
| 226 | + } |
| 227 | + } |
| 228 | + "hash" => { |
| 229 | + let fields = input.field_values.as_ref().unwrap(); |
| 230 | + let mut cmd = redis::cmd("HSET"); |
| 231 | + cmd.arg(&key); |
| 232 | + for fv in fields { |
| 233 | + let name = substitute_pattern(&fv.name, i); |
| 234 | + let value = substitute_pattern(&fv.value, i); |
| 235 | + cmd.arg(&name).arg(&value); |
| 236 | + } |
| 237 | + pipe.add_command(cmd).ignore(); |
| 238 | + total_commands += 1; |
| 239 | + |
| 240 | + if let Some(ttl) = input.ttl { |
| 241 | + let mut cmd = redis::cmd("EXPIRE"); |
| 242 | + cmd.arg(&key).arg(ttl); |
| 243 | + pipe.add_command(cmd).ignore(); |
| 244 | + total_commands += 1; |
| 245 | + } |
| 246 | + } |
| 247 | + "sorted_set" => { |
| 248 | + let member = substitute_pattern(input.member_pattern.as_ref().unwrap(), i); |
| 249 | + let score = if count > 1 { |
| 250 | + score_min + (score_max - score_min) * (i as f64 / (count - 1) as f64) |
| 251 | + } else { |
| 252 | + score_min |
| 253 | + }; |
| 254 | + let mut cmd = redis::cmd("ZADD"); |
| 255 | + cmd.arg(&key).arg(score).arg(&member); |
| 256 | + pipe.add_command(cmd).ignore(); |
| 257 | + total_commands += 1; |
| 258 | + } |
| 259 | + "set" => { |
| 260 | + let member = substitute_pattern(input.member_pattern.as_ref().unwrap(), i); |
| 261 | + let mut cmd = redis::cmd("SADD"); |
| 262 | + cmd.arg(&key).arg(&member); |
| 263 | + pipe.add_command(cmd).ignore(); |
| 264 | + total_commands += 1; |
| 265 | + } |
| 266 | + "list" => { |
| 267 | + let member = substitute_pattern(input.member_pattern.as_ref().unwrap(), i); |
| 268 | + let mut cmd = redis::cmd("RPUSH"); |
| 269 | + cmd.arg(&key).arg(&member); |
| 270 | + pipe.add_command(cmd).ignore(); |
| 271 | + total_commands += 1; |
| 272 | + } |
| 273 | + "json" => { |
| 274 | + let value = substitute_pattern(input.value_pattern.as_ref().unwrap(), i); |
| 275 | + let mut cmd = redis::cmd("JSON.SET"); |
| 276 | + cmd.arg(&key).arg("$").arg(&value); |
| 277 | + pipe.add_command(cmd).ignore(); |
| 278 | + total_commands += 1; |
| 279 | + |
| 280 | + if let Some(ttl) = input.ttl { |
| 281 | + let mut cmd = redis::cmd("EXPIRE"); |
| 282 | + cmd.arg(&key).arg(ttl); |
| 283 | + pipe.add_command(cmd).ignore(); |
| 284 | + total_commands += 1; |
| 285 | + } |
| 286 | + } |
| 287 | + _ => unreachable!(), |
| 288 | + } |
| 289 | + } |
| 290 | + |
| 291 | + pipe.query_async::<()>(&mut conn) |
| 292 | + .await |
| 293 | + .tool_context("Seed pipeline failed")?; |
| 294 | + } |
| 295 | + |
| 296 | + let elapsed = start.elapsed(); |
| 297 | + let rate = if elapsed.as_secs_f64() > 0.0 { |
| 298 | + total_commands as f64 / elapsed.as_secs_f64() |
| 299 | + } else { |
| 300 | + total_commands as f64 |
| 301 | + }; |
| 302 | + |
| 303 | + Ok(CallToolResult::text(format!( |
| 304 | + "Seed complete: {} {} items seeded ({} commands) in {:.2}s ({:.0} cmd/s)\n\n\ |
| 305 | + Tip: use redis_info with section=\"memory\" to check memory impact, \ |
| 306 | + or redis_dbsize to verify key count.", |
| 307 | + count, |
| 308 | + data_type, |
| 309 | + total_commands, |
| 310 | + elapsed.as_secs_f64(), |
| 311 | + rate |
| 312 | + ))) |
| 313 | + } |
| 314 | +); |
| 315 | + |
| 316 | +#[cfg(test)] |
| 317 | +mod tests { |
| 318 | + use super::*; |
| 319 | + |
| 320 | + #[test] |
| 321 | + fn test_substitute_simple() { |
| 322 | + assert_eq!(substitute_pattern("user:{i}", 42), "user:42"); |
| 323 | + assert_eq!(substitute_pattern("no-placeholder", 5), "no-placeholder"); |
| 324 | + } |
| 325 | + |
| 326 | + #[test] |
| 327 | + fn test_substitute_padded() { |
| 328 | + assert_eq!(substitute_pattern("user-{i:06}", 42), "user-000042"); |
| 329 | + assert_eq!(substitute_pattern("key-{i:08}", 1), "key-00000001"); |
| 330 | + assert_eq!( |
| 331 | + substitute_pattern("shard-{i:02}:member-{i}", 7), |
| 332 | + "shard-07:member-7" |
| 333 | + ); |
| 334 | + } |
| 335 | + |
| 336 | + #[test] |
| 337 | + fn test_substitute_multiple() { |
| 338 | + assert_eq!( |
| 339 | + substitute_pattern("{i}-{i}-{i}", 3), |
| 340 | + "3-3-3" |
| 341 | + ); |
| 342 | + } |
| 343 | +} |
0 commit comments