diff --git a/src/tint/fuzzers/tint_structure_fuzzer/CMakeLists.txt b/src/tint/fuzzers/tint_structure_fuzzer/CMakeLists.txt new file mode 100644 index 00000000000..23e994d54b9 --- /dev/null +++ b/src/tint/fuzzers/tint_structure_fuzzer/CMakeLists.txt @@ -0,0 +1,33 @@ + +add_executable(tint_structure_fuzzer + tint_structure_fuzzer.cc + syntax.cc + syntax.h + probabilities.h + ../mersenne_twister_engine.cc + ../mersenne_twister_engine.h + ../random_generator.cc + ../random_generator.h + ../tint_common_fuzzer.cc + ../tint_common_fuzzer.h + ../random_generator_engine.cc + ../random_generator_engine.h +) + +if (NOT WIN32) + set_source_files_properties(syntax.cc PROPERTIES COMPILE_FLAGS -O1) +endif() + +tint_fuzzer_compile_options(tint_structure_fuzzer) +tint_spvtools_compile_options(tint_structure_fuzzer) +target_include_directories(tint_structure_fuzzer PRIVATE ${CMAKE_BINARY_DIR}) +target_link_libraries(tint_structure_fuzzer PRIVATE tint_lang_hlsl_writer_helpers) +target_link_libraries(tint_structure_fuzzer PRIVATE tint_lang_msl_writer_helpers) +target_link_libraries(tint_structure_fuzzer PRIVATE tint_lang_spirv_writer_helpers) + +tint_core_compile_options(tint_structure_fuzzer) + +if (TINT_STRUCTURE_FUZZER_SANITIZERS) + target_compile_options(tint_structure_fuzzer PRIVATE -fsanitize=address -fsanitize=undefined) + target_link_options(tint_structure_fuzzer PRIVATE -fsanitize=address -fsanitize=undefined) +endif () diff --git a/src/tint/fuzzers/tint_structure_fuzzer/blocklist.txt b/src/tint/fuzzers/tint_structure_fuzzer/blocklist.txt new file mode 100644 index 00000000000..2989055bedb --- /dev/null +++ b/src/tint/fuzzers/tint_structure_fuzzer/blocklist.txt @@ -0,0 +1,4 @@ +src:tint_structure_fuzzer/syntax.cc +src:tint_structure_fuzzer/syntax.h +src:tint_structure_fuzzer/tint_structure_fuzzer.cc +src:tint_structure_fuzzer/probabilities.h diff --git a/src/tint/fuzzers/tint_structure_fuzzer/optimize.py b/src/tint/fuzzers/tint_structure_fuzzer/optimize.py new file mode 100644 index 00000000000..d2914cd5da8 --- /dev/null +++ b/src/tint/fuzzers/tint_structure_fuzzer/optimize.py @@ -0,0 +1,34 @@ +import subprocess +import re +import random +import numpy as np +from scipy.optimize import differential_evolution + +np.set_printoptions(precision=0, suppress=True, formatter={'float_kind': lambda x: str(int(x))}) + +def print_stats(xk, convergence): + print(f"Current best solution: {xk} convergence measure: {convergence:.6f}") + +def fn(args): + args = [ + "./tint_structure_fuzzer", + "--prob="+",".join(str(int(x)) for x in args), + "-cross_over=0", + "-mutate_depth=1", + "-max_total_time=30", + "-print_funcs=0", + "-fsanitize-coverage-ignorelist=./blocklist.txt" + ] + print(f"Calling {args}...") + result = subprocess.run(args, stderr=subprocess.PIPE, stdout=subprocess.PIPE) + out = str(result.stderr) + matches = re.findall(r"cov: (\d+)", out) + if not matches: + raise ValueError("Coverage result not found") + result = int(matches[-1]) + print(f"Coverage = {result}") + return -result + +print("tint_structure_fuzzer optimizer") +val = differential_evolution(fn, bounds=[(0, 1000)] * 8, maxiter=100, popsize=10, callback=print_stats) +print(f"Best solution is {val.x} -> {-val.fun}") diff --git a/src/tint/fuzzers/tint_structure_fuzzer/probabilities.h b/src/tint/fuzzers/tint_structure_fuzzer/probabilities.h new file mode 100644 index 00000000000..bc97037de4f --- /dev/null +++ b/src/tint/fuzzers/tint_structure_fuzzer/probabilities.h @@ -0,0 +1,43 @@ +#ifndef SRC_TINT_FUZZERS_TINT_STRUCTURE_FUZZER_PROBABILITIES_H_ +#define SRC_TINT_FUZZERS_TINT_STRUCTURE_FUZZER_PROBABILITIES_H_ + +#include +#include +#include "src/tint/fuzzers/random_generator.h" + +namespace tint::fuzzers::structure_fuzzer { + +struct Probabilities { + Probabilities(std::vector values_) : values(std::move(values_)) { + unsigned sum = 0; + for (unsigned& v : values) { + unsigned s = v; + v = sum; + assert(sum + s > sum); + sum += s; + } + values.push_back(sum); + } + + size_t size() const { return values.size() - 1; } + + unsigned sum() const { return values.back(); } + + template + T sample(RandomGenerator& gen) const { + return static_cast(sample(gen)); + } + + unsigned sample(RandomGenerator& gen) const { + unsigned v = gen.GetUInt32(sum()); + auto it = std::upper_bound(values.begin(), values.end(), v); + assert(it != values.begin()); + --it; + return static_cast(std::distance(values.begin(), it)); + } + std::vector values; +}; + +} // namespace tint::fuzzers::structure_fuzzer + +#endif diff --git a/src/tint/fuzzers/tint_structure_fuzzer/pso.py b/src/tint/fuzzers/tint_structure_fuzzer/pso.py new file mode 100644 index 00000000000..38af5eb452c --- /dev/null +++ b/src/tint/fuzzers/tint_structure_fuzzer/pso.py @@ -0,0 +1,129 @@ +import subprocess +import re +import random +import numpy as np +import argparse + +class Particle: + def __init__(self, bounds): + self.position = np.array([random.uniform(low, high) for low, high in bounds]) + self.position[self.position == 0] = random.uniform(0.01, 1) + self.velocity = np.array([random.uniform(-1, 1) for _ in bounds]) + self.best_position = self.position.copy() + self.best_score = float('inf') + +class PSO: + def __init__(self, objective_func, bounds, num_particles=10, max_iter=100): + self.objective_func = objective_func + self.bounds = bounds + self.num_particles = num_particles + self.max_iter = max_iter + + # PSO parameters + self.w = 0.7 + self.c1 = 2.0 + self.c2 = 2.0 + self.particles = [Particle(bounds) for _ in range(num_particles)] + self.global_best_position = None + self.global_best_score = float('inf') + + def optimize(self, callback=None): + best_coverage_per_iteration = [] + + for iteration in range(self.max_iter): + iteration_best_score = float('inf') + + for particle in self.particles: + score = self.objective_func(particle.position) + + # Update personal best + if score < particle.best_score: + particle.best_score = score + particle.best_position = particle.position.copy() + + if score < self.global_best_score: + self.global_best_score = score + self.global_best_position = particle.position.copy() + + iteration_best_score = -self.global_best_score # Since we're maximizing the coverage + best_coverage_per_iteration.append(iteration_best_score) + + for particle in self.particles: + r1, r2 = random.random(), random.random() + + cognitive = self.c1 * r1 * (particle.best_position - particle.position) + social = self.c2 * r2 * (self.global_best_position - particle.position) + particle.velocity = (self.w * particle.velocity + cognitive + social) + + particle.position = particle.position + particle.velocity + + particle.position = np.clip(particle.position, + [b[0] for b in self.bounds], + [b[1] for b in self.bounds]) + + particle.position[particle.position == 0] = random.uniform(0.01, 1) + + if callback: + callback(self.global_best_position, iteration) + + print(f"Iteration {iteration}: Best coverage = {iteration_best_score}") + + return self.global_best_position, self.global_best_score + +def print_stats(xk, iteration): + print(f"Iteration {iteration}: Current best solution: {xk}") + +def fn(args, max_time): + args = np.round(args).astype(int) + args[args == 0] = random.randint(1, 1000) # Replace zeros with a random value + + cmd_args = [ + "./tint_structure_fuzzer", + "--prob=" + ",".join(str(x) for x in args), + "-cross_over=0", + "-mutate_depth=1", + f"-max_total_time={max_time}", + "-print_funcs=0", + "-fsanitize-coverage-ignorelist=./blocklist.txt", + "-jobs=1" + ] + print(f"Calling {cmd_args}...") + + result = subprocess.run(cmd_args, stderr=subprocess.PIPE, stdout=subprocess.PIPE) + out = str(result.stderr) + matches = re.findall(r"cov: (\d+)", out) + + if not matches: + raise ValueError("Coverage result not found") + + result = int(matches[-1]) + print(f"Coverage = {result}") + return -result + +def main(): + parser = argparse.ArgumentParser(description='Tint Structure Fuzzer Optimizer using PSO') + parser.add_argument('--max_time', type=int, default=120, + help='Maximum total time for each fuzzing run (default: 120)') + args = parser.parse_args() + + print("tint_structure_fuzzer optimizer (PSO)") + + # Define bounds for 8 parameters + bounds = [(0, 1000)] * 8 + + # Create PSO optimizer + pso = PSO( + objective_func=lambda x: fn(x, args.max_time), + bounds=bounds, + num_particles=10, + max_iter=100 + ) + + best_position, best_score = pso.optimize(callback=print_stats) + + print(f"Best solution is {best_position} -> {-best_score}") + +if __name__ == "__main__": + main() + + diff --git a/src/tint/fuzzers/tint_structure_fuzzer/syntax.cc b/src/tint/fuzzers/tint_structure_fuzzer/syntax.cc new file mode 100644 index 00000000000..aefe05db092 --- /dev/null +++ b/src/tint/fuzzers/tint_structure_fuzzer/syntax.cc @@ -0,0 +1,1262 @@ +#include "syntax.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "src/tint/fuzzers/random_generator.h" +#include "src/tint/utils/math/crc32.h" + +#define INLINE __attribute__((always_inline)) + +namespace tint::fuzzers::structure_fuzzer { + +namespace { + +struct Context { + std::unordered_map vars; + RandomGenerator& gen; + int nextVarId = 0; + + explicit Context(RandomGenerator& generator) : gen(generator) {} + + std::string createVariable(const std::string& type) { + std::string name = "v" + std::to_string(nextVarId++); + vars[name] = type; + return name; + } + + std::string getRandomVariable() { + if (vars.empty()) return ""; + auto it = vars.begin(); + std::advance(it, gen.GetUInt32(vars.size())); + return it->first; + } + + bool shouldUseVariable() { + return !vars.empty() && gen.GetUInt32(2) == 0; + } +}; + +struct VariableContext { + std::unordered_map vars; + int nextVarId = 0; + + std::string createVariable(const std::string& type) { + std::string name = "v" + std::to_string(nextVarId++); + vars[name] = type; + return name; + } + + std::string getRandomVariable(RandomGenerator& gen) { + if (vars.empty()) return ""; + auto it = vars.begin(); + std::advance(it, gen.GetUInt32(vars.size())); + return it->first; + } +}; + +struct ByteInputRnd { + RandomGenerator& gen; + Context& ctx; + + ByteInputRnd(RandomGenerator& g, Context& c) : gen(g), ctx(c) {} + + uint8_t byteTerm() { return gen.GetUInt32(256); } + + uint8_t byte() { return 0; } + + uint32_t range(uint32_t limit, bool repeat) INLINE { + if (repeat) { + return 0; + } + if (limit == 1) { + return 0; + } + uint32_t x = gen.GetUInt32(INT32_MAX); + float f = static_cast(x) / static_cast(INT32_MAX); + f = std::pow(f, 2.2f); + return std::clamp(static_cast(f * limit), 0u, limit - 1); + } +}; + +static void printHex(const uint8_t* data, size_t size) { + for (size_t i = 0; i < size; ++i) { + printf("%02X", data[i]); + } + printf("\n"); +} + +struct ByteInput { + const uint8_t* data; + size_t size; + size_t used; + bool acceptTruncated; + Context& ctx; + + ByteInput(const uint8_t* d, size_t s, size_t u, bool a, Context& c) + : data(d), size(s), used(u), acceptTruncated(a), ctx(c) {} + + void reset() { used = 0; } + + uint8_t byte() { + uint8_t result = 0; + if (used < size) { + result = data[used]; + ++used; + } else { + ++used; + } + return result; + } + + uint8_t byteTerm() { + if (used < size) { + return byte(); + } + return ctx.gen.GetUInt32(256); + } + + uint32_t range(uint32_t limit, bool repeat) INLINE { + if (limit == 1) { + return 0; + } + assert(limit <= 256); + return std::min(static_cast(byte()), limit - 1u); + } +}; + + +struct ByteOutput { + std::vector out; + void push(uint8_t val) { out.push_back(val); } +}; + +struct ByteOutputNull { + void push(uint8_t val) {} +}; + +struct TextOutput { + std::stringstream buffer; + char last = 0; + Context& ctx; + + explicit TextOutput(Context& c) : ctx(c) {} + + void raw(std::string_view s) INLINE { + if (s.empty()) { + return; + } + if (bool(std::isalnum(last)) == bool(std::isalnum(s.front()))) { + buffer << " "; + } + buffer << s; + last = s.back(); + } + + void ident(int n, std::string_view prefix = "x") INLINE { + raw(std::string(prefix) + std::to_string(n)); + } +}; + +enum class NodeId { + translation_unit = 0, // Implicit root node + additive_operator, + expression_list, + argument_expression_list, + assignment_statement, + attribute, + bitwise_expression_post_unary_expression_1, + bitwise_expression_post_unary_expression_2, + bitwise_expression_post_unary_expression_3, + bitwise_expression_post_unary_expression, + bool_literal, + case_selector, + component_or_swizzle_specifier, + compound_assignment_operator, + compound_statement, + core_lhs_expression, + decimal_float_literal, + decimal_int_literal, + diagnostic_control, + expression_1, + expression_2, + expression, + float_literal, + for_init, + for_update, + assign_expression, + comma_param, + global_decl_1, + return_type, + comma_struct_field, + global_decl, + comma_ident_pattern_token_1, + comma_ident_pattern_token_2, + global_directive, + global_value_decl, + hex_float_literal, + int_literal, + lhs_expression, + literal, + member_ident, + multiplicative_operator, + optionally_typed_ident_1, + optionally_typed_ident, + param, + primary_expression, + relational_expression_post_unary_expression, + multiplicative_operator_unary_expression, + shift_expression_post_unary_expression_1, + shift_expression_post_unary_expression, + elseif_statement, + else_statement, + breakif_statement, + continuing_statement, + statement, + switch_clause_1, + switch_clause, + swizzle_name, + template_arg_expression, + comma_expression, + unary_expression, + expression_list_angle, + variable_decl, + variable_or_value_statement, + variable_updating_statement, + + last = variable_updating_statement, +}; + +using GenerateFn0 = std::function; +using GenerateFn = std::function; + +GenerateFn0 emit(std::string_view sv); + +enum class IdentType { + Type, + UserType, + Function, + UserFunction, + Variable, + Other, +}; + +GenerateFn ident(IdentType type); + +enum class KeywordList { + diagnostic_severity, + requires_extensions, + address_space, +}; + +GenerateFn keywords(KeywordList list); + +void floatLiteral(TextOutput&); +void floatHexLiteral(TextOutput&); +void decimalLiteral(TextOutput&); +void hexLiteral(TextOutput&); + +struct Subnode { + Subnode(void (*fn)(uint8_t, TextOutput&), char mod = 0) : Subnode(GenerateFn(fn), mod) {} + Subnode(void (*fn)(TextOutput&), char mod = 0) : Subnode(GenerateFn0(fn), mod) {} + Subnode(GenerateFn fn, char mod = 0) : content(fn), mod(mod) {} + Subnode(GenerateFn0 fn, char mod = 0) : content(fn), mod(mod) {} + Subnode(NodeId node, char mod = 0) : content(node), mod(mod) {} + std::variant content; + char mod = 0; // '*', '?' +}; + +struct Modifier { + char v; +}; + +constexpr inline int maxDepth = 16; +constexpr inline int maxRepeats = 5; + +constexpr inline Modifier many{'*'}; +constexpr inline Modifier optional{'?'}; + +inline Subnode operator|(Subnode subnode, Modifier mod) { + subnode.mod = mod.v; + return subnode; +} + +struct Node : public std::vector> { + Node(std::initializer_list> list) INLINE + : std::vector>(list) { + assert(this->size() >= 1); + } +}; + +constexpr inline int limit = 9; + +const auto& nodes() { + static std::array instance{ + // translation_unit: + Node{ + {NodeId::global_decl | many}, + }, + // additive_operator: + Node{ + {emit("+")}, + {emit("-")}, + }, + // expression_list: + Node{ + {NodeId::expression, NodeId::comma_expression | many, emit(",") | optional}, + }, + // argument_expression_list: + Node{ + {emit("("), NodeId::expression_list | optional, emit(")")}, + }, + // assignment_statement: + Node{ + {NodeId::compound_assignment_operator}, + {emit("=")}, + }, + // attribute: + Node{ + {emit("@"), emit("compute")}, + {emit("@"), emit("const")}, + {emit("@"), emit("fragment")}, + {emit("@"), emit("interpolate"), emit("("), ident(IdentType::Other), + emit(",") | optional, emit(")")}, + {emit("@"), emit("interpolate"), emit("("), ident(IdentType::Other), emit(","), + ident(IdentType::Other), emit(",") | optional, emit(")")}, + {emit("@"), emit("invariant")}, + {emit("@"), emit("must_use")}, + {emit("@"), emit("vertex")}, + {emit("@"), emit("workgroup_size"), emit("("), NodeId::expression, emit(",") | optional, + emit(")")}, + {emit("@"), emit("workgroup_size"), emit("("), NodeId::expression, emit(","), + NodeId::expression, emit(",") | optional, emit(")")}, + {emit("@"), emit("workgroup_size"), emit("("), NodeId::expression, emit(","), + NodeId::expression, emit(","), NodeId::expression, emit(",") | optional, emit(")")}, + {emit("@"), emit("align"), emit("("), NodeId::expression, emit(",") | optional, + emit(")")}, + {emit("@"), emit("binding"), emit("("), NodeId::expression, emit(",") | optional, + emit(")")}, + {emit("@"), emit("blend_src"), emit("("), NodeId::expression, emit(",") | optional, + emit(")")}, + {emit("@"), emit("builtin"), emit("("), ident(IdentType::Other), emit(",") | optional, + emit(")")}, + {emit("@"), emit("diagnostic"), NodeId::diagnostic_control}, + {emit("@"), emit("group"), emit("("), NodeId::expression, emit(",") | optional, + emit(")")}, + {emit("@"), emit("id"), emit("("), NodeId::expression, emit(",") | optional, emit(")")}, + {emit("@"), emit("location"), emit("("), NodeId::expression, emit(",") | optional, + emit(")")}, + {emit("@"), emit("size"), emit("("), NodeId::expression, emit(",") | optional, + emit(")")}, + }, + // bitwise_expression_post_unary_expression_1: + Node{ + {emit("&"), NodeId::unary_expression}, + }, + // bitwise_expression_post_unary_expression_2: + Node{ + {emit("^"), NodeId::unary_expression}, + }, + // bitwise_expression_post_unary_expression_3: + Node{ + {emit("|"), NodeId::unary_expression}, + }, + // bitwise_expression_post_unary_expression: + Node{ + {emit("&"), NodeId::unary_expression, + NodeId::bitwise_expression_post_unary_expression_1 | many}, + {emit("^"), NodeId::unary_expression, + NodeId::bitwise_expression_post_unary_expression_2 | many}, + {emit("|"), NodeId::unary_expression, + NodeId::bitwise_expression_post_unary_expression_3 | many}, + }, + // bool_literal: + Node{ + {emit("false")}, + {emit("true")}, + }, + // case_selector: + Node{ + {NodeId::expression}, + {emit("default")}, + }, + // component_or_swizzle_specifier: + Node{ + {emit("."), NodeId::member_ident, NodeId::component_or_swizzle_specifier | optional}, + {emit("."), NodeId::swizzle_name, NodeId::component_or_swizzle_specifier | optional}, + {emit("["), NodeId::expression, emit("]"), + NodeId::component_or_swizzle_specifier | optional}, + }, + // compound_assignment_operator: + Node{ + {emit("<<=")}, + {emit(">>=")}, + {emit("%=")}, + {emit("&=")}, + {emit("*=")}, + {emit("+=")}, + {emit("-=")}, + {emit("/=")}, + {emit("^=")}, + {emit("|=")}, + }, + // compound_statement: + Node{ + {NodeId::attribute | many, emit("{"), NodeId::statement | many, emit("}")}, + }, + // core_lhs_expression: + Node{ + {ident(IdentType::Other)}, + {emit("("), NodeId::lhs_expression, emit(")")}, + }, + // decimal_float_literal: + Node{ + {&floatLiteral}, + }, + // decimal_int_literal: + Node{ + {&decimalLiteral}, + }, + // diagnostic_control: + Node{ + {emit("("), keywords(KeywordList::diagnostic_severity), emit(","), + emit("derivative_uniformity"), emit(",") | optional, emit(")")}, + }, + // expression_1: + Node{ + {emit("&&"), NodeId::unary_expression, + NodeId::relational_expression_post_unary_expression}, + }, + // expression_2: + Node{ + {emit("||"), NodeId::unary_expression, + NodeId::relational_expression_post_unary_expression}, + }, + // expression: + Node{ + {NodeId::unary_expression, NodeId::bitwise_expression_post_unary_expression}, + {NodeId::unary_expression, NodeId::relational_expression_post_unary_expression}, + {NodeId::unary_expression, NodeId::relational_expression_post_unary_expression, + emit("&&"), NodeId::unary_expression, + NodeId::relational_expression_post_unary_expression, NodeId::expression_1 | many}, + {NodeId::unary_expression, NodeId::relational_expression_post_unary_expression, + emit("||"), NodeId::unary_expression, + NodeId::relational_expression_post_unary_expression, NodeId::expression_2 | many}, + }, + // float_literal: + Node{ + {floatLiteral}, + {floatHexLiteral}, + }, + // for_init: + Node{ + {ident(IdentType::Variable), NodeId::argument_expression_list}, + {NodeId::variable_or_value_statement}, + {NodeId::variable_updating_statement}, + }, + // for_update: + Node{ + {ident(IdentType::Variable), NodeId::argument_expression_list}, + {NodeId::variable_updating_statement}, + }, + // assign_expression: + Node{ + {emit("="), NodeId::expression}, + }, + // comma_param: + Node{ + {emit(","), NodeId::param}, + }, + // global_decl_1: + Node{ + {NodeId::attribute | many, ident(IdentType::Variable), emit(":"), + ident(IdentType::Type), NodeId::comma_param | many, emit(",") | optional}, + }, + // return_type: + Node{ + {emit("->"), NodeId::attribute | many, ident(IdentType::Type)}, + }, + // comma_struct_field: + Node{ + {emit(","), NodeId::attribute | many, NodeId::member_ident, emit(":"), + ident(IdentType::Type)}, + }, + // global_decl: + Node{ + {NodeId::attribute | many, emit("fn"), ident(IdentType::UserFunction), emit("("), + NodeId::global_decl_1 | optional, emit(")"), NodeId::return_type | optional, + NodeId::attribute | many, emit("{"), NodeId::statement | many, emit("}")}, + {NodeId::attribute | many, emit("var"), NodeId::expression_list_angle | optional, + NodeId::optionally_typed_ident, NodeId::assign_expression | optional, emit(";")}, + {NodeId::global_value_decl, emit(";")}, + {emit(";")}, + {emit("struct"), ident(IdentType::UserType), emit("{"), NodeId::attribute | many, + NodeId::member_ident, emit(":"), ident(IdentType::Type), + NodeId::comma_struct_field | many, emit(",") | optional, emit("}")}, + {emit("const_assert"), NodeId::expression, emit(";")}, + {emit("alias"), ident(IdentType::UserType), emit("="), ident(IdentType::Type), + emit(";")}, + }, + // comma_ident_pattern_token_1: + Node{ + {emit(","), emit("f16")}, + }, + // comma_ident_pattern_token_2: + Node{ + {emit(","), keywords(KeywordList::requires_extensions)}, + }, + // global_directive: + Node{ + {emit("diagnostic"), emit("("), keywords(KeywordList::diagnostic_severity), emit(","), + emit("derivative_uniformity"), emit(",") | optional, emit(")"), emit(";")}, + {emit("enable"), emit("f16"), NodeId::comma_ident_pattern_token_1 | many, + emit(",") | optional, emit(";")}, + {emit("requires"), keywords(KeywordList::requires_extensions), + NodeId::comma_ident_pattern_token_2 | many, emit(",") | optional, emit(";")}, + }, + // global_value_decl: + Node{ + {NodeId::attribute | many, emit("override"), NodeId::optionally_typed_ident, + NodeId::assign_expression | optional}, + {emit("const"), NodeId::optionally_typed_ident, NodeId::assign_expression}, + }, + // hex_float_literal: + Node{ + {&floatHexLiteral}, + }, + // int_literal: + Node{ + {&decimalLiteral}, + {hexLiteral}, + }, + // lhs_expression: + Node{ + {NodeId::core_lhs_expression, NodeId::component_or_swizzle_specifier | optional}, + {emit("&"), NodeId::lhs_expression}, + {emit("*"), NodeId::lhs_expression}, + }, + // literal: + Node{ + {NodeId::int_literal}, + {NodeId::float_literal}, + {NodeId::bool_literal}, + }, + // member_ident: + Node{ + {ident(IdentType::Variable)}, + }, + // multiplicative_operator: + Node{ + {emit("*")}, + {emit("/")}, + {emit("%")}, + }, + // optionally_typed_ident_1: + Node{ + {emit(":"), ident(IdentType::Type)}, + }, + // optionally_typed_ident: + Node{ + {ident(IdentType::Variable), NodeId::optionally_typed_ident_1 | optional}, + }, + // param: + Node{ + {NodeId::attribute | many, ident(IdentType::Variable), emit(":"), + ident(IdentType::Type)}, + }, + // primary_expression: + Node{ + {NodeId::literal}, + {ident(IdentType::Variable)}, + {ident(IdentType::Function), NodeId::argument_expression_list}, + {emit("("), NodeId::expression, emit(")")}, + {ident(IdentType::Type), NodeId::argument_expression_list}, + }, + // relational_expression_post_unary_expression: + Node{ + {NodeId::shift_expression_post_unary_expression, emit("=="), NodeId::unary_expression, + NodeId::shift_expression_post_unary_expression}, + {NodeId::shift_expression_post_unary_expression, emit("!="), NodeId::unary_expression, + NodeId::shift_expression_post_unary_expression}, + {NodeId::shift_expression_post_unary_expression}, + {NodeId::shift_expression_post_unary_expression, emit(">"), NodeId::unary_expression, + NodeId::shift_expression_post_unary_expression}, + {NodeId::shift_expression_post_unary_expression, emit(">="), NodeId::unary_expression, + NodeId::shift_expression_post_unary_expression}, + {NodeId::shift_expression_post_unary_expression, emit("<"), NodeId::unary_expression, + NodeId::shift_expression_post_unary_expression}, + {NodeId::shift_expression_post_unary_expression, emit("<="), NodeId::unary_expression, + NodeId::shift_expression_post_unary_expression}, + }, + // multiplicative_operator_unary_expression: + Node{ + {NodeId::multiplicative_operator, NodeId::unary_expression}, + }, + // shift_expression_post_unary_expression_1: + Node{ + {NodeId::additive_operator, NodeId::unary_expression, + NodeId::multiplicative_operator_unary_expression | many}, + }, + // shift_expression_post_unary_expression: + Node{ + {NodeId::multiplicative_operator_unary_expression | many, + NodeId::shift_expression_post_unary_expression_1 | many}, + {emit("<<"), NodeId::unary_expression}, + {emit(">>"), NodeId::unary_expression}, + }, + // elseif_statement: + Node{ + {emit("else"), emit("if"), NodeId::expression, NodeId::compound_statement}, + }, + // else_statement: + Node{ + {emit("else"), NodeId::compound_statement}, + }, + // breakif_statement: + Node{ + {emit("break"), emit("if"), NodeId::expression, emit(";")}, + }, + // continuing_statement: + Node{ + {emit("continuing"), NodeId::attribute | many, emit("{"), NodeId::statement | many, + NodeId::breakif_statement | optional, emit("}")}, + }, + // statement: + Node{ + {emit("return"), NodeId::expression, emit(";")}, + {NodeId::variable_or_value_statement, emit(";")}, + {NodeId::variable_updating_statement, emit(";")}, + {NodeId::attribute | many, emit("if"), NodeId::expression, NodeId::compound_statement, + NodeId::elseif_statement | many, NodeId::else_statement | optional}, + {NodeId::attribute | many, emit("for"), emit("("), NodeId::for_init | optional, + emit(";"), NodeId::expression | optional, emit(";"), NodeId::for_update | optional, + emit(")"), NodeId::compound_statement}, + {emit("return"), emit(";")}, + {NodeId::attribute | many, emit("loop"), NodeId::attribute | many, emit("{"), + NodeId::statement | many, NodeId::continuing_statement | optional, emit("}")}, + {NodeId::attribute | many, emit("switch"), NodeId::expression, NodeId::attribute | many, + emit("{"), NodeId::switch_clause | many, emit("}")}, + {NodeId::attribute | many, emit("while"), NodeId::expression, + NodeId::compound_statement}, + {NodeId::compound_statement}, + {ident(IdentType::Type), NodeId::argument_expression_list, emit(";")}, + {emit("break"), emit(";")}, + {emit("continue"), emit(";")}, + {emit("const_assert"), NodeId::expression, emit(";")}, + {emit("discard"), emit(";")}, + {emit(";")}, + }, + // switch_clause_1: + Node{ + {emit(","), NodeId::case_selector}, + }, + // switch_clause: + Node{ + {emit("case"), NodeId::case_selector, NodeId::switch_clause_1 | many, + emit(",") | optional, emit(":") | optional, NodeId::compound_statement}, + {emit("default"), emit(":") | optional, NodeId::compound_statement}, + }, + // swizzle_name: + Node{ + {emit("x")}, + {emit("xy")}, + {emit("xyz")}, + {emit("xyzw")}, + {emit("r")}, + {emit("rg")}, + {emit("rgb")}, + {emit("rgba")}, + {emit("x")}, + {emit("xx")}, + {emit("xxx")}, + {emit("xxxx")}, + }, + // template_arg_expression: + Node{ + {NodeId::expression}, + }, + // comma_expression: + Node{ + {emit(","), NodeId::expression}, + }, + // unary_expression: + Node{ + {NodeId::primary_expression, NodeId::component_or_swizzle_specifier | optional}, + {emit("!"), NodeId::unary_expression}, + {emit("&"), NodeId::unary_expression}, + {emit("*"), NodeId::unary_expression}, + {emit("-"), NodeId::unary_expression}, + {emit("~"), NodeId::unary_expression}, + }, + // expression_list_angle: + Node{ + {emit("<"), keywords(KeywordList::address_space), emit(">")}, + }, + // variable_decl: + Node{ + {emit("var"), NodeId::expression_list_angle | optional, NodeId::optionally_typed_ident}, + }, + // variable_or_value_statement: + Node{ + // {NodeId::variable_decl}, + {NodeId::variable_decl, NodeId::assign_expression}, + {emit("const"), NodeId::optionally_typed_ident, NodeId::assign_expression}, + {emit("let"), NodeId::optionally_typed_ident, NodeId::assign_expression}, + }, + // variable_updating_statement: + Node{ + {NodeId::lhs_expression, NodeId::assign_expression}, + {NodeId::lhs_expression, NodeId::compound_assignment_operator, NodeId::expression}, + {NodeId::lhs_expression, emit("++")}, + {NodeId::lhs_expression, emit("--")}, + {emit("_"), NodeId::assign_expression}, + }, + }; + static_assert(instance.size() == static_cast(NodeId::last) + 1); + return instance; +} + +GenerateFn0 emit(std::string_view string) { + return [string](TextOutput& out) INLINE { out.raw(string); }; +} + +GenerateFn keywords(KeywordList list) { + return [list](uint8_t value, TextOutput& out) { + switch (list) { + case KeywordList::diagnostic_severity: { + out.raw(std::array{"off", "error", "warning", "info"}[value % 4]); + break; + } + case KeywordList::requires_extensions: { + out.raw(std::array{"packed_4x8_integer_dot_product", + "pointer_composite_access"}[value % 2]); + break; + } + case KeywordList::address_space: { + out.raw(std::array{"function", "private", "workgroup", "uniform", + "storage"}[value % 5]); + break; + } + } + }; +} + +GenerateFn ident(IdentType type) { + return [type](uint8_t value, TextOutput& out) INLINE { + constexpr std::string_view funcs[] = { + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atanh", + "atan2", + "ceil", + "clamp", + "cos", + "cosh", + "countLeadingZeros", + "countOneBits", + "countTrailingZeros", + "cross", + "degrees", + "determinant", + "distance", + "dot", + "dot4U8Packed", + "dot4I8Packed", + "exp", + "exp2", + "extractBits", + "faceForward", + "firstLeadingBit", + "firstTrailingBit", + "floor", + "fma", + "fract", + "frexp", + "insertBits", + "inverseSqrt", + "ldexp", + "length", + "log", + "log2", + "max", + "min", + "mix", + "modf", + "normalize", + "pow", + "radians", + "reflect", + "refract", + "reverseBits", + "round", + "saturate", + "sign", + "sin", + "sinh", + "smoothstep", + "sqrt", + "step", + "tan", + "tanh", + "transpose", + "trunc", + "dpdx", + "dpdxCoarse", + "dpdxFine", + "dpdy", + "dpdyCoarse", + "dpdyFine", + "fwidth", + "fwidthCoarse", + "fwidthFine", + "textureDimensions", + "textureGather", + "textureGatherCompare", + "textureLoad", + "textureNumLayers", + "textureNumLevels", + "textureNumSamples", + "textureSample", + "textureSampleBias", + "textureSampleCompare", + "textureSampleCompareLevel", + "textureSampleGrad", + "textureSampleLevel", + "textureSampleBaseClampToEdge", + "textureStore", + "atomicLoad", + "atomicStore", + "atomicAdd", + "atomicSub", + "atomicMax", + "atomicMin", + "atomicAnd", + "atomicOr", + "atomicXor", + "atomicExchange", + "atomicCompareExchangeWeak", + "pack4x8snorm", + "pack4x8unorm", + "pack4xI8", + "pack4xU8", + "pack4xI8Clamp", + "pack4xU8Clamp", + "pack2x16snorm", + "pack2x16unorm", + "pack2x16float", + "unpack4x8snorm", + "unpack4x8unorm", + "unpack4xI8", + "unpack4xU8", + "unpack2x16snorm", + "unpack2x16unorm", + "unpack2x16float", + "storageBarrier", + "textureBarrier", + "workgroupBarrier", + "workgroupUniformLoad", + }; + constexpr std::string_view types[] = { + "bool", + "vec2", + "vec3", + "vec4", + "u32", + "vec2", + "vec3", + "vec4", + "i32", + "vec2", + "vec3", + "vec4", + "f32", + "vec2", + "vec3", + "vec4", + "mat2x2", + "mat2x3", + "mat2x4", + "mat3x2", + "mat3x3", + "mat3x4", + "mat4x2", + "mat4x3", + "mat4x4", + "array", + "array", + "array", + "array", + "array", + "array", + "array", + "array", + }; + if (type == IdentType::Type || type == IdentType::UserType) { + if (type == IdentType::UserType || value < 12) { + out.ident(value, "t"); + } else { + out.raw(types[value % std::size(types)]); + } + } else if (type == IdentType::Function || type == IdentType::UserFunction) { + if (type == IdentType::UserFunction || value < 12) { + out.ident(value, "f"); + } else { + out.raw(funcs[value % std::size(funcs)]); + } + } else { + out.ident(value, "x"); + } + }; +} + +void floatLiteral(TextOutput& out) { + if (out.ctx.shouldUseVariable()) { + out.raw(out.ctx.getRandomVariable()); + } else { + out.raw("3.1416"); + std::string var = out.ctx.createVariable("f32"); + out.raw(" /* stored in " + var + " */"); + } +} + +void floatHexLiteral(TextOutput& out) { + if (out.ctx.shouldUseVariable()) { + out.raw(out.ctx.getRandomVariable()); + } else { + out.raw("0x1.Fp4"); + std::string var = out.ctx.createVariable("f32"); + out.raw(" /* stored in " + var + " */"); + } +} + +void decimalLiteral(TextOutput& out) { + if (out.ctx.shouldUseVariable()) { + out.raw(out.ctx.getRandomVariable()); + } else { + out.raw(std::to_string(out.ctx.gen.GetUInt32(1000))); + std::string var = out.ctx.createVariable("i32"); + out.raw(" /* stored in " + var + " */"); + } +} + +void hexLiteral(TextOutput& out) { + if (out.ctx.shouldUseVariable()) { + out.raw(out.ctx.getRandomVariable()); + } else { + out.raw("0x" + std::to_string(out.ctx.gen.GetUInt32(0xFFFF))); + std::string var = out.ctx.createVariable("i32"); + out.raw(" /* stored in " + var + " */"); + } +} + +struct MutationStat { + int alternatives = 0; + int repeats[3]{0, 0, 0}; + int optionals[2]{0, 0}; + int terminals = 0; +}; + +template +void mutate(ByteIn& in, + ByteOut& out, + Mutation mutation, + int& index, + NodeId id, + RandomGenerator& gen, + Context& ctx, + int depth = 0); + +template +void mutateOne(ByteIn& in, + ByteOut& out, + Mutation mutation, + int& index, + const Subnode& subnode, + RandomGenerator& gen, + Context& ctx, + int depth = 0); + +template +void mutateAlt(ByteIn& in, + ByteOut& out, + Mutation mutation, + int& index, + const std::vector& subnodes, + RandomGenerator& gen, + Context& ctx, + int depth = 0); + +template +void mutateOne(ByteIn& in, + ByteOut& out, + Mutation mutation, + int& index, + const Subnode& subnode, + RandomGenerator& gen, + Context& ctx, + int depth) { + if (const GenerateFn0* fn = std::get_if(&subnode.content)) { + } else if (const GenerateFn* fn = std::get_if(&subnode.content)) { + uint8_t val = in.byteTerm(); + if (mutation == Mutation::RandomTerminal) { + if (index-- == 0) { + val = gen.GetUInt32(256); + } + } + out.push(val); + } else if (const NodeId* nodeId = std::get_if(&subnode.content)) { + mutate(in, out, mutation, index, *nodeId, gen, ctx, depth + 1); + } +} + +template +void mutateAlt(ByteIn& in, + ByteOut& out, + Mutation mutation, + int& index, + const std::vector& subnodes, + RandomGenerator& gen, + Context& ctx, + int depth) { + for (const Subnode& subnode : subnodes) { + int repetitions = 1; + int newRepetitions = 1; + if (subnode.mod == '*') { + newRepetitions = repetitions = in.range(maxRepeats + 1, true); + if (mutation == Mutation::IncRepeat && repetitions < maxRepeats) { + if (index-- == 0) { + ++newRepetitions; + } + } + if (mutation == Mutation::DecRepeat && repetitions > 0) { + if (index-- == 0) { + --newRepetitions; + } + } + out.push(newRepetitions); + } else if (subnode.mod == '?') { + newRepetitions = repetitions = in.range(2, true); + if (repetitions == 0 && mutation >= Mutation::AddOptional) { + if (index-- == 0) { + newRepetitions = 1; + } + } + if (repetitions == 1 && mutation >= Mutation::RemoveOptional) { + if (index-- == 0) { + newRepetitions = 0; + } + } + out.push(newRepetitions); + } + + for (int i = 0; i < std::min(repetitions, newRepetitions); ++i) { + mutateOne(in, out, mutation, index, subnode, gen, ctx, depth); + } + + if (newRepetitions > repetitions) { + ByteInputRnd rndIn{gen, ctx}; + for (int i = 0; i < newRepetitions - repetitions; ++i) { + mutateOne(rndIn, out, mutation, index, subnode, gen, ctx, depth); + } + } else if (newRepetitions < repetitions) { + ByteOutputNull nullOut; + for (int i = 0; i < repetitions - newRepetitions; ++i) { + mutateOne(in, nullOut, mutation, index, subnode, gen, ctx, depth); + } + } + } +} + +template +void mutate(ByteIn& in, + ByteOut& out, + Mutation mutation, + int& index, + NodeId id, + RandomGenerator& gen, + Context& ctx, + int depth) { + if (depth > maxDepth) { + return; + } + const Node& node = nodes()[static_cast(id)]; + uint8_t alternative = 0; + if (node.size() > 1) { + alternative = in.range(node.size(), false); + if (mutation >= Mutation::NextAlternative && mutation <= Mutation::RandomAlternative) { + if (index-- == 0) { + ByteOutputNull nullOut; + mutateAlt(in, nullOut, mutation, index, node[alternative], gen, ctx, depth); + switch (mutation) { + case Mutation::NextAlternative: + alternative = (alternative + 1) % node.size(); + break; + case Mutation::PrevAlternative: + alternative = (alternative + node.size() - 1) % node.size(); + break; + case Mutation::RandomAlternative: { + uint8_t newAlternative = gen.GetUInt32(node.size() - 1); + alternative = newAlternative >= alternative ? newAlternative + 1 : newAlternative; + break; + } + default: + break; + } + assert(alternative < node.size()); + out.push(alternative); + ByteInputRnd rndIn{gen, ctx}; + mutateAlt(rndIn, out, mutation, index, node[alternative], gen, ctx, depth); + return; + } + } + out.push(alternative); + } + assert(alternative < node.size()); + mutateAlt(in, out, mutation, index, node[alternative], gen, ctx, depth); +} + +void count(MutationStat& stat, ByteInput& in, NodeId id, int depth = 0) { + if (depth > maxDepth) { + return; + } + const Node& node = nodes()[static_cast(id)]; + if (node.size() > 1) { + ++stat.alternatives; + } + uint8_t alternative = in.range(node.size(), false); + auto& alt = node[alternative]; + for (const Subnode& subnode : alt) { + int repetitions = 1; + if (subnode.mod == '*') { + repetitions = in.range(maxRepeats + 1, true); + if (repetitions == 0) { + ++stat.repeats[0]; + } else if (repetitions == maxRepeats) { + ++stat.repeats[2]; + } else { + ++stat.repeats[1]; + } + } else if (subnode.mod == '?') { + repetitions = in.range(2, true); + ++stat.optionals[repetitions]; + } + for (int i = 0; i < repetitions; ++i) { + if (std::get_if(&subnode.content)) { + // Do nothing + } else if (std::get_if(&subnode.content)) { + in.byteTerm(); + ++stat.terminals; + } else if (const NodeId* nodeId = std::get_if(&subnode.content)) { + count(stat, in, *nodeId, depth + 1); + } + } + } +} + +void generate(ByteInput& in, TextOutput& out, NodeId id, int depth = 0) { + if (depth > maxDepth) { + return; + } + const Node& node = nodes()[static_cast(id)]; + uint8_t alternative = in.range(node.size(), false); // alternative + auto& alt = node[alternative]; + for (const Subnode& subnode : alt) { + int repetitions = 1; + if (subnode.mod == '*') { + repetitions = in.range(maxRepeats + 1, true); // repeat + } else if (subnode.mod == '?') { + repetitions = in.range(2, true); // optional + } + for (int i = 0; i < repetitions; ++i) { + if (const GenerateFn0* fn = std::get_if(&subnode.content)) { + (*fn)(out); + } else if (const GenerateFn* fn = std::get_if(&subnode.content)) { + (*fn)(in.byteTerm(), out); + } else if (const NodeId* nodeId = std::get_if(&subnode.content)) { + generate(in, out, *nodeId, depth + 1); + } + } + } +} + +} + +std::vector WGSLMutate(Mutation mutation, + const uint8_t* data, + size_t size, + RandomGenerator& gen) { + Context ctx(gen); + ByteInput in{data, size, 0, true, ctx}; + ByteOutput out{}; + MutationStat stat{}; + int index = -1; + count(stat, in, NodeId::translation_unit); + + while (index < 0) { + index = -1; + switch (mutation) { + case Mutation::AddOptional: + if (stat.optionals[0] > 0) { + index = gen.GetUInt32(stat.optionals[0]); + } + break; + case Mutation::RemoveOptional: + if (stat.optionals[1] > 0) { + index = gen.GetUInt32(stat.optionals[1]); + } + break; + case Mutation::IncRepeat: + if (stat.repeats[0] + stat.repeats[1] > 0) { + index = gen.GetUInt32(stat.repeats[0] + stat.repeats[1]); + } + break; + case Mutation::DecRepeat: + if (stat.repeats[1] + stat.repeats[2] > 0) { + index = gen.GetUInt32(stat.repeats[1] + stat.repeats[2]); + } + break; + case Mutation::NextAlternative: + case Mutation::PrevAlternative: + case Mutation::RandomAlternative: + if (stat.alternatives > 0) { + index = gen.GetUInt32(stat.alternatives); + } + break; + case Mutation::RandomTerminal: + if (stat.terminals > 0) { + index = gen.GetUInt32(stat.terminals); + } + break; + } + if (index == -1) { + mutation = static_cast((static_cast(mutation) + 1) % + (static_cast(Mutation::Last) + 1)); + } + } + in.reset(); + mutate(in, out, mutation, index, NodeId::translation_unit, gen, ctx, 0); + return std::move(out.out); +} + +std::string WGSLSource(const uint8_t* data, size_t size) { + RandomGenerator gen{CRC32(data, size)}; + Context ctx(gen); + ByteInput in{data, size, 0, false, ctx}; + TextOutput out(ctx); + generate(in, out, NodeId::translation_unit); + return std::move(out.buffer).str(); +} + +} // namespace tint::fuzzers::structure_fuzzer + diff --git a/src/tint/fuzzers/tint_structure_fuzzer/syntax.h b/src/tint/fuzzers/tint_structure_fuzzer/syntax.h new file mode 100644 index 00000000000..2eeda2b3628 --- /dev/null +++ b/src/tint/fuzzers/tint_structure_fuzzer/syntax.h @@ -0,0 +1,35 @@ + +#ifndef SRC_TINT_FUZZERS_TINT_STRUCTURE_FUZZER_SYNTAX_H_ +#define SRC_TINT_FUZZERS_TINT_STRUCTURE_FUZZER_SYNTAX_H_ + +#include +#include +#include "src/tint/fuzzers/random_generator.h" + +namespace tint::fuzzers::structure_fuzzer { + +enum class Mutation { + AddOptional, // ? + RemoveOptional, // ? + IncRepeat, // * + DecRepeat, // * + NextAlternative, // | + PrevAlternative, // | + RandomAlternative, // | + RandomTerminal, // + + Last = RandomTerminal, +}; + +void WGSLInit(); + +std::vector WGSLMutate(Mutation mutation, + const uint8_t* data, + size_t size, + RandomGenerator& gen); + +std::string WGSLSource(const uint8_t* data, size_t size); + +} // namespace tint::fuzzers::structure_fuzzer + +#endif diff --git a/src/tint/fuzzers/tint_structure_fuzzer/tint_structure_fuzzer.cc b/src/tint/fuzzers/tint_structure_fuzzer/tint_structure_fuzzer.cc new file mode 100644 index 00000000000..20c58a6396d --- /dev/null +++ b/src/tint/fuzzers/tint_structure_fuzzer/tint_structure_fuzzer.cc @@ -0,0 +1,187 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "src/tint/cmd/fuzz/wgsl/fuzz.h" +#include "src/tint/fuzzers/random_generator.h" +#include "src/tint/fuzzers/tint_common_fuzzer.h" +#include "src/tint/fuzzers/transform_builder.h" +#include "src/tint/lang/wgsl/ast/transform/add_empty_entry_point.h" +#include "src/tint/lang/wgsl/ast/transform/builtin_polyfill.h" +#include "src/tint/lang/wgsl/ast/transform/fold_constants.h" +#include "src/tint/lang/wgsl/ast/transform/remove_unreachable_statements.h" +#include "src/tint/lang/wgsl/ast/transform/renamer.h" +#include "src/tint/lang/wgsl/ast/transform/unshadow.h" +#include "src/tint/lang/wgsl/reader/reader.h" +#include "src/tint/lang/wgsl/writer/writer.h" +#include "src/tint/utils/cli/cli.h" +#include "testing/libfuzzer/libfuzzer_exports.h" + +#include "probabilities.h" +#include "syntax.h" + +namespace { + +tint::fuzz::wgsl::Options options; + +} // namespace + +using namespace std::string_view_literals; + +namespace tint::fuzzers::structure_fuzzer { +namespace { + +static std::optional mutations; + +static void printHex(FILE* f, const uint8_t* data, size_t size) { + for (size_t i = 0; i < size; ++i) { + fprintf(f, "%02X", data[i]); + } + fprintf(f, "\n"); +} + +extern "C" int LLVMFuzzerInitialize(int* argc, char*** argv) { + tint::cli::OptionSet opts; + + constexpr size_t numMutations = static_cast(Mutation::Last) + 1; + + std::vector list; + for (int i = 1; i < *argc; ++i) { + std::string_view arg((*argv)[i]); + if (arg.find("--prob=") == 0) { + arg.remove_prefix(7); + size_t start = 0; + size_t stop = arg.find_first_of(", ", start); + while (start != std::string_view::npos) { + auto result = strconv::ParseNumber(arg.substr(start, stop - start)); + assert(result == tint::Success); + list.push_back(result.Get()); + if (stop == std::string_view::npos) + break; + start = stop + 1; + stop = arg.find_first_of(", ", start); + } + } + } + fprintf(stdout, "Mutation parameters:\n"); + for (int i = 0; i < list.size(); ++i) { + fprintf(stdout, "[%d] %u\n", i, list[i]); + } + + if (list.size() != numMutations) { + fprintf(stderr, "Incorrect number of arguments. Expected %zu\n", numMutations); + list.resize(numMutations, 10); + } + mutations.emplace(std::move(list)); + +#if 0 + { + RandomGenerator gen{1}; + std::string wgsl_str; + + std::vector input{}; + for (int i = 0; i < 1000; ++i) { + Mutation mutation = mutations->sample(gen); + wgsl_str = WGSLSource(input.data(), input.size()); + printHex(stderr, input.data(), input.size()); + fprintf(stderr, "%s\n\n", wgsl_str.c_str()); + input = WGSLMutate(mutation, input.data(), input.size(), gen); + } + std::exit(1); + } +#endif + return 0; +} + +extern "C" size_t LLVMFuzzerCustomMutator(uint8_t* data, + size_t size, + size_t maxSize, + unsigned int seed) { + if (!mutations.has_value()) { + return 0; + } + RandomGenerator gen(seed); + Mutation mutation = mutations->sample(gen); + std::vector output = WGSLMutate(mutation, data, size, gen); + if (output.size() > maxSize) { + fprintf(stderr, "Mutate output truncated %zu -> %zu\n", output.size(), maxSize); + fflush(stderr); + output.resize(maxSize); + } + memcpy(data, output.data(), output.size()); +#if 0 + static FILE* fdbg = nullptr; + if (!fdbg) { + fdbg = fopen("fuzzer-mutator.txt", "w"); + } + fprintf(fdbg, "@@@ Mutate m=%d seed=%u\n", (int)mutation, seed); + fprintf(fdbg, "IN: "); + printHex(fdbg, data, size); + std::string wgsl_str = WGSLSource(data, size); + fprintf(fdbg, "IN: %s\n", wgsl_str.c_str()); + fprintf(fdbg, "OUT: "); + printHex(fdbg, output.data(), output.size()); + wgsl_str = WGSLSource(output.data(), output.size()); + fprintf(fdbg, "OUT: %s\n", wgsl_str.c_str()); + fflush(fdbg); +#endif + return output.size(); +} + +static FILE* dbg[2]{nullptr, nullptr}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + if (size <= 1) { + return 0; + } + std::string wgsl_str = WGSLSource(data, size); + + static const bool debug = std::getenv("TINT_STRUCTURE_FUZZER_DEBUG"); + + if (debug && !dbg[0]) { + dbg[0] = fopen("fuzzer-failure.txt", "w"); + dbg[1] = fopen("fuzzer-success.txt", "w"); + } + + bool successfull = false; + + for (OutputFormat fmt : + {OutputFormat::kWGSL, OutputFormat::kSpv, OutputFormat::kHLSL, OutputFormat::kMSL}) { + TransformBuilder tb(data, size); + // tb.AddPlatformIndependentPasses(); + tb.AddTransform(); + tb.AddTransform(); + tb.manager()->Add(); + tb.manager()->Add(); + tb.manager()->Add(); + tb.manager()->Add(); + + CommonFuzzer fuzzer(InputFormat::kWGSL, fmt); + fuzzer.SetTransformManager(tb.manager(), tb.data_map()); + + fuzzer.Run(reinterpret_cast(wgsl_str.data()), wgsl_str.size()); + if (debug && fmt == OutputFormat::kSpv) { + if (!fuzzer.HasErrors()) { + successfull = true; + fprintf(dbg[1], "|IN(%zu): ", size); + printHex(dbg[1], data, size); + fprintf(dbg[1], " OUT: %s\n", wgsl_str.c_str()); + // fflush(dbg[1]); + } else { + fprintf(dbg[0], "|IN(%zu): ", size); + printHex(dbg[0], data, size); + fprintf(dbg[0], " OUT: %s\n", wgsl_str.c_str()); + fprintf(dbg[0], " ERROR: %s", fuzzer.Diagnostics().Str().c_str()); + // fflush(dbg[0]); + } + } + } + return 0; // successfull || size < 12 ? 0 : -1; // rand() % 2 == 0 ? -1 : 0; +} +} // namespace +} // namespace tint::fuzzers::structure_fuzzer