diff --git a/.github/workflows/cmake-multi-platform.yml b/.github/workflows/cmake-multi-platform.yml index 0c6098ea5..ac4127f4e 100644 --- a/.github/workflows/cmake-multi-platform.yml +++ b/.github/workflows/cmake-multi-platform.yml @@ -56,16 +56,15 @@ jobs: run: | echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT" - - name: Install SPIRV-Headers - run: | - git clone https://github.com/KhronosGroup/SPIRV-Headers - mkdir SPIRV-Headers/build - cmake -B SPIRV-Headers/build -S SPIRV-Headers - mkdir SPIRV-Headers/install - cmake --install SPIRV-Headers/build --prefix SPIRV-Headers/install + #- name: Install SPIRV-Headers + # run: | + # git clone https://github.com/KhronosGroup/SPIRV-Headers + # mkdir SPIRV-Headers/build + # cmake -B SPIRV-Headers/build -S SPIRV-Headers + # mkdir SPIRV-Headers/install + # cmake --install SPIRV-Headers/build --prefix SPIRV-Headers/install - name: Install Clang & LLVM (setup-clang) - if: (!startsWith(matrix.os,'windows')) uses: egor-tensin/setup-clang@v1.4 #- name: Install LLVM (winlibs) @@ -89,12 +88,6 @@ jobs: run: | echo "CMAKE_PLATFORM_SPECIFIC_ARGS=${{ env.CMAKE_PLATFORM_SPECIFIC_ARGS }} -DLLVM_DIR=/usr/lib/llvm-14/cmake" >> $GITHUB_ENV - - name: Install json-c (vcpkg) - if: startsWith(matrix.os,'windows') - run: | - vcpkg install json-c - echo "CMAKE_PLATFORM_SPECIFIC_ARGS=-DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake" >> "$env:GITHUB_ENV" - - name: Configure CMake # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. 
# See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..d16386367 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +build/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 075ae1c91..e24b50ab0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ -[submodule "murmur3"] - path = murmur3 - url = https://github.com/PeterScott/murmur3 +[submodule "SPIRV-Headers"] + path = SPIRV-Headers + url = https://github.com/shady-gang/SPIRV-Headers.git + branch = main diff --git a/CMakeLists.txt b/CMakeLists.txt index e4b98c444..8a38d3dac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,53 +1,120 @@ cmake_minimum_required(VERSION 3.13) project (shady C) +include(ExternalProject) +include(FetchContent) + +find_package(Git) + set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +option(BUILD_SHARED_LIBS "Build using shared libraries" ON) +option(SHADY_USE_FETCHCONTENT "Use FetchContent to grab json-c" ON) +option(SHADY_WIN32_FIX_PARTIAL_LLVM_INSTALL "If you install LLVM on windows, it doesn't come with header files. 
This fixes it" ON) if (MSVC) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) endif() -find_package(SPIRV-Headers REQUIRED) +if (WIN32) + set(BUILD_SHARED_LIBS OFF) +endif() + +add_subdirectory(SPIRV-Headers) + +if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24" AND ${SHADY_USE_FETCHCONTENT}) + FetchContent_Declare( + json-c + GIT_REPOSITORY https://github.com/json-c/json-c + GIT_TAG master + OVERRIDE_FIND_PACKAGE + ) + + FetchContent_MakeAvailable(json-c) + add_library(json-c::json-c ALIAS json-c) + configure_file(${json-c_SOURCE_DIR}/json.h.cmakein ${json-c_BINARY_DIR}/json-c/json.h @ONLY) + target_include_directories(json-c PUBLIC $) +else () + message("CMake 3.24 or later is required to use FetchContent") +endif () + +find_package(LLVM QUIET) +if (LLVM_FOUND) + message("LLVM ${LLVM_VERSION} found") +endif() +if(NOT ${LLVM_FOUND} AND WIN32 AND ${SHADY_WIN32_FIX_PARTIAL_LLVM_INSTALL}) + find_program(clang_exe "clang.exe") + if(${clang_exe} STREQUAL "clang_exe-NOTFOUND") + message(STATUS "Win32: Installed LLVM not found") + else() + execute_process(COMMAND ${clang_exe} --version OUTPUT_VARIABLE clang_status) + string(REGEX MATCH "InstalledDir: (.*)[\r\n]" match ${clang_status}) + file(TO_CMAKE_PATH "${CMAKE_MATCH_1}/../" LLVM_DIR) + cmake_path(ABSOLUTE_PATH LLVM_DIR NORMALIZE) + string(REGEX MATCH "clang version ([0-9]+).([0-9]+).([0-9]+)" match2 ${clang_status}) + set(LLVM_VERSION_MAJOR ${CMAKE_MATCH_1}) + set(LLVM_VERSION_MINOR ${CMAKE_MATCH_2}) + set(LLVM_VERSION_PATCH ${CMAKE_MATCH_3}) + set(LLVM_VERSION "${LLVM_VERSION_MAJOR}.${LLVM_VERSION_MINOR}.${LLVM_VERSION_PATCH}") + add_library(LLVM-C SHARED IMPORTED) + set_property(TARGET LLVM-C PROPERTY + IMPORTED_LOCATION "${LLVM_DIR}bin/LLVM-C.dll") + set_property(TARGET LLVM-C PROPERTY + IMPORTED_IMPLIB "${LLVM_DIR}lib/LLVM-C.lib") + + execute_process( + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${GIT_EXECUTABLE} clone -n --depth 1 --filter=tree:0 https://github.com/llvm/llvm-project/ --branch 
"llvmorg-${LLVM_VERSION_MAJOR}.${LLVM_VERSION_MINOR}.${LLVM_VERSION_PATCH}" + ) + execute_process( + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/llvm-project + COMMAND ${GIT_EXECUTABLE} sparse-checkout set --no-cone llvm/include/llvm-c + ) + execute_process( + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/llvm-project + COMMAND ${GIT_EXECUTABLE} checkout + ) + target_include_directories(LLVM-C INTERFACE ${CMAKE_CURRENT_BINARY_DIR}/llvm-project/llvm/include) + target_compile_definitions(LLVM-C INTERFACE -DLLVM_VERSION_MAJOR=${LLVM_VERSION_MAJOR} -DLLVM_VERSION_MINOR=${LLVM_VERSION_MINOR} -DLLVM_VERSION_PATCH=${LLVM_VERSION_PATCH}) + message(STATUS "Win32: Installed LLVM ${LLVM_VERSION} found at ${LLVM_DIR}") + set(LLVM_FOUND TRUE) + endif() +endif() + +include(GNUInstallDirs) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) -set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib") +set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}") +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) # required for MSVC set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS 1) cmake_policy(SET CMP0118 NEW) add_subdirectory(src) +add_subdirectory(vcc) + +add_subdirectory(zhady) include(CTest) if (BUILD_TESTING) add_subdirectory(test) endif() -set(BUILD_SAMPLES ON CACHE BOOL "Whether to build built-in demo applications") -if (BUILD_SAMPLES) - add_subdirectory(samples) -endif() +add_subdirectory(samples) include(CMakePackageConfigHelpers) -install(TARGETS api EXPORT shady_export_set) -install(TARGETS shady EXPORT shady_export_set ARCHIVE DESTINATION ${CMAKE_ARCHIVE_OUTPUT_DIRECTORY}) - if (TARGET vcc) add_subdirectory(vcc-std) endif () -if (TARGET runtime) - install(TARGETS runtime EXPORT shady_export_set ARCHIVE DESTINATION ${CMAKE_ARCHIVE_OUTPUT_DIRECTORY}) -endif() - -install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/shady DESTINATION include) 
install(EXPORT shady_export_set DESTINATION share/cmake/shady/ NAMESPACE shady:: FILE shady-targets.cmake) configure_file(cmake/shady-config.cmake.in shady-config.cmake @ONLY) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/shady-config.cmake" DESTINATION share/cmake/shady) -#install(FILES "${CMAKE_CURRENT_BINARY_DIR}/shady-config.cmake" DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/shady) diff --git a/SPIRV-Headers b/SPIRV-Headers new file mode 160000 index 000000000..1453552cf --- /dev/null +++ b/SPIRV-Headers @@ -0,0 +1 @@ +Subproject commit 1453552cfe66d4aaf58d083372e6bf81f715e4a8 diff --git a/doc/code_style.md b/doc/code_style.md new file mode 100644 index 000000000..6db664898 --- /dev/null +++ b/doc/code_style.md @@ -0,0 +1,40 @@ +# Code style Guidelines + +This guide is for our own reference as much as anyone else. +If you have a good reason for breaking any of those rules we're happy to consider your contributions without prejudice. + + * 4 spaces indentation + * `Type* name` + * `snake_case` by default + * `UpperCamelCase` for types and typedefs + * Use only typedefs or append `_` to struct/union/enums names + * `LOUD_CASE` for public macros + * `{ spaces, surrounding, initializer, lists }` + * Unless you're literally contributing using a 80-column display (for which I'll ask visual proof), don't format your code as if you do. + * Include order: + * If appropriate, (Private) `self.h` should always come first, with other local headers in the same group + * Then other `shady/` headers + * Then in-project utility headers + * Then external headers + * Finally, standard C library headers. + * Each category of includes spaced by an empty line + +## Symbol naming + +Due to C not having namespaces, we have to deal with this painfully automatable problem ourselves. 
+ + * Avoid exposing any symbols that don't need to be exposed (use `static` wherever you can) + * Prefixes: + * `shd_` in front of API functions (in the root `include` folder) + * `slim_`, `l2s_`, `spv_` and `vcc_` are used in various sub-projects + * `subsystem_` is acceptable for internal use + * `shd_subsystem_` is preferable where a clear delineation can be made + * `shd_new_thing` and `shd_destroy_thing` for constructors and destructors + * `static inline` functions in headers are not required to be prefixed + * Types & Typedefs may be prefixed with `Shd` + * Alternatively, subsystem-relevant ones can use another prefix, much like for functions +* Do not expose global variables to external APIs at all (provide getter functions if necessary) + +## Cursing in comments + +Can be funny but keep it somewhat family friendly. \ No newline at end of file diff --git a/include/shady/analysis/literal.h b/include/shady/analysis/literal.h new file mode 100644 index 000000000..9125d4504 --- /dev/null +++ b/include/shady/analysis/literal.h @@ -0,0 +1,20 @@ +#ifndef SHADY_ANALYSIS_LITERAL_H +#define SHADY_ANALYSIS_LITERAL_H + +#include "shady/ir/grammar.h" + +const char* shd_get_string_literal(IrArena* arena, const Node* node); + +typedef struct { + bool enter_loads; + bool allow_incompatible_types; + bool assume_globals_immutability; +} NodeResolveConfig; + +NodeResolveConfig shd_default_node_resolve_config(void); +const Node* shd_chase_ptr_to_source(const Node* ptr, NodeResolveConfig config); +const Node* shd_resolve_ptr_to_value(const Node* ptr, NodeResolveConfig config); + +const Node* shd_resolve_node_to_definition(const Node* node, NodeResolveConfig config); + +#endif diff --git a/include/shady/be/c.h b/include/shady/be/c.h new file mode 100644 index 000000000..d16a7292f --- /dev/null +++ b/include/shady/be/c.h @@ -0,0 +1,27 @@ +#ifndef SHD_BE_C_H +#define SHD_BE_C_H + +#include "shady/ir/base.h" + +typedef enum { + CDialect_C11, + CDialect_GLSL, + CDialect_ISPC, + 
CDialect_CUDA, +} CDialect; + +typedef struct { + CDialect dialect; + bool explicitly_sized_types; + bool allow_compound_literals; + bool decay_unsized_arrays; + int glsl_version; +} CEmitterConfig; + +CEmitterConfig shd_default_c_emitter_config(void); + +typedef struct CompilerConfig_ CompilerConfig; +void shd_emit_c(const CompilerConfig* compiler_config, CEmitterConfig config, Module* mod, size_t* output_size, char** output, Module** new_mod); + +#endif + diff --git a/include/shady/be/dump.h b/include/shady/be/dump.h new file mode 100644 index 000000000..a2dc858aa --- /dev/null +++ b/include/shady/be/dump.h @@ -0,0 +1,16 @@ +#ifndef SHD_BE_DUMP_H +#define SHD_BE_DUMP_H + +#include "shady/ir/base.h" +#include "shady/ir/module.h" + +#include + +void shd_dump_module(Module* mod); +void shd_dump_node(const Node* node); + +void shd_dump_cfgs(FILE* output, Module* mod); +void shd_dump_loop_trees(FILE* output, Module* mod); + +#endif + diff --git a/include/shady/be/spirv.h b/include/shady/be/spirv.h new file mode 100644 index 000000000..a66b0dd31 --- /dev/null +++ b/include/shady/be/spirv.h @@ -0,0 +1,10 @@ +#ifndef SHD_BE_SPIRV_H +#define SHD_BE_SPIRV_H + +#include "shady/ir/base.h" + +typedef struct CompilerConfig_ CompilerConfig; +void shd_emit_spirv(const CompilerConfig* config, Module* mod, size_t* output_size, char** output, Module** new_mod); + +#endif + diff --git a/include/shady/builtins.h b/include/shady/builtins.h deleted file mode 100644 index 7cffc8bd1..000000000 --- a/include/shady/builtins.h +++ /dev/null @@ -1,55 +0,0 @@ -#ifndef SHADY_EMIT_BUILTINS -#define SHADY_EMIT_BUILTINS - -#include "shady/ir.h" - -#define u32vec3_type(arena) pack_type(arena, (PackType) { .width = 3, .element_type = uint32_type(arena) }) -#define i32vec3_type(arena) pack_type(arena, (PackType) { .width = 3, .element_type = int32_type(arena) }) -#define i32vec4_type(arena) pack_type(arena, (PackType) { .width = 4, .element_type = int32_type(arena) }) - -#define f32vec4_type(arena) 
pack_type(arena, (PackType) { .width = 4, .element_type = fp32_type(arena) }) - -#define SHADY_BUILTINS() \ -BUILTIN(BaseInstance, AsInput, uint32_type(arena) )\ -BUILTIN(BaseVertex, AsInput, uint32_type(arena) )\ -BUILTIN(DeviceIndex, AsInput, uint32_type(arena) )\ -BUILTIN(DrawIndex, AsInput, uint32_type(arena) )\ -BUILTIN(VertexIndex, AsInput, uint32_type(arena) )\ -BUILTIN(FragCoord, AsInput, f32vec4_type(arena) )\ -BUILTIN(FragDepth, AsOutput, fp32_type(arena) )\ -BUILTIN(InstanceId, AsInput, uint32_type(arena) )\ -BUILTIN(InvocationId, AsInput, uint32_type(arena) )\ -BUILTIN(InstanceIndex, AsInput, uint32_type(arena) )\ -BUILTIN(LocalInvocationId, AsInput, u32vec3_type(arena) )\ -BUILTIN(LocalInvocationIndex, AsInput, uint32_type(arena) )\ -BUILTIN(GlobalInvocationId, AsInput, u32vec3_type(arena) )\ -BUILTIN(WorkgroupId, AsUInput, u32vec3_type(arena) )\ -BUILTIN(WorkgroupSize, AsUInput, u32vec3_type(arena) )\ -BUILTIN(NumSubgroups, AsUInput, uint32_type(arena) )\ -BUILTIN(NumWorkgroups, AsUInput, u32vec3_type(arena) )\ -BUILTIN(Position, AsOutput, f32vec4_type(arena) )\ -BUILTIN(PrimitiveId, AsInput, uint32_type(arena) )\ -BUILTIN(SubgroupLocalInvocationId, AsInput, uint32_type(arena) )\ -BUILTIN(SubgroupId, AsUInput, uint32_type(arena) )\ -BUILTIN(SubgroupSize, AsInput, uint32_type(arena) )\ - -typedef enum { -#define BUILTIN(name, as, datatype) Builtin##name, -SHADY_BUILTINS() -#undef BUILTIN - BuiltinsCount -} Builtin; - -AddressSpace get_builtin_as(Builtin); -String get_builtin_name(Builtin); - -const Type* get_builtin_type(IrArena* arena, Builtin); -Builtin get_builtin_by_name(String); - -typedef enum SpvBuiltIn_ SpvBuiltIn; -Builtin get_builtin_by_spv_id(SpvBuiltIn id); - -bool is_decl_builtin(const Node*); -Builtin get_decl_builtin(const Node*); - -#endif diff --git a/include/shady/config.h b/include/shady/config.h new file mode 100644 index 000000000..eabf5169a --- /dev/null +++ b/include/shady/config.h @@ -0,0 +1,133 @@ +#ifndef SHD_CONFIG_H +#define 
SHD_CONFIG_H + +#include "shady/ir/base.h" +#include "shady/ir/int.h" +#include "shady/ir/grammar.h" +#include "shady/ir/execution_model.h" + +typedef struct { + IntSizes ptr_size; + /// The base type for emulated memory + IntSizes word_size; +} PointerModel; + +typedef struct { + PointerModel memory; +} TargetConfig; + +TargetConfig shd_default_target_config(void); + +typedef enum { + /// Uses the MaskType + SubgroupMaskAbstract, + /// Uses a 64-bit integer + SubgroupMaskInt64 +} SubgroupMaskRepresentation; + +typedef struct ArenaConfig_ ArenaConfig; +struct ArenaConfig_ { + bool name_bound; + bool check_op_classes; + bool check_types; + bool allow_fold; + bool validate_builtin_types; // do @Builtins variables need to match their type in builtins.h ? + bool is_simt; + + struct { + bool physical; + bool allowed; + } address_spaces[NumAddressSpaces]; + + struct { + /// Selects which type the subgroup intrinsic primops use to manipulate masks + SubgroupMaskRepresentation subgroup_mask_representation; + + uint32_t workgroup_size[3]; + } specializations; + + PointerModel memory; + + /// 'folding' optimisations - happen in the constructors directly + struct { + bool inline_single_use_bbs; + bool fold_static_control_flow; + bool delete_unreachable_structured_cases; + bool weaken_non_leaking_allocas; + } optimisations; +}; + +ArenaConfig shd_default_arena_config(const TargetConfig* target); +const ArenaConfig* shd_get_arena_config(const IrArena* a); + +typedef struct CompilerConfig_ CompilerConfig; +struct CompilerConfig_ { + bool dynamic_scheduling; + uint32_t per_thread_stack_size; + + struct { + uint8_t major; + uint8_t minor; + } target_spirv_version; + + struct { + bool restructure_with_heuristics; + bool add_scope_annotations; + bool has_scope_annotations; + } input_cf; + + struct { + bool emulate_generic_ptrs; + bool emulate_physical_memory; + + bool emulate_subgroup_ops; + bool emulate_subgroup_ops_extended_types; + bool int64; + bool decay_ptrs; + } lower; + + 
struct { + bool spv_shuffle_instead_of_broadcast_first; + bool force_join_point_lifting; + } hacks; + + struct { + struct { + bool after_every_pass; + bool delete_unused_instructions; + } cleanup; + bool inline_everything; + } optimisations; + + struct { + bool memory_accesses; + bool stack_accesses; + bool god_function; + bool stack_size; + bool subgroup_ops; + } printf_trace; + + struct { + int max_top_iterations; + } shader_diagnostics; + + struct { + bool print_generated, print_builtin, print_internal; + } logging; + + struct { + String entry_point; + ExecutionModel execution_model; + uint32_t subgroup_size; + } specialization; + + TargetConfig target; + + struct { + struct { void* uptr; void (*fn)(void*, String, Module*); } after_pass; + } hooks; +}; + +CompilerConfig shd_default_compiler_config(void); + +#endif diff --git a/include/shady/driver.h b/include/shady/driver.h index a0a5ace03..b485cd52f 100644 --- a/include/shady/driver.h +++ b/include/shady/driver.h @@ -1,7 +1,11 @@ #ifndef SHADY_CLI #define SHADY_CLI -#include "shady/ir.h" +#include "shady/ir/base.h" +#include "shady/config.h" + +#include "shady/be/c.h" +#include "shady/be/spirv.h" struct List; @@ -25,9 +29,9 @@ typedef enum { SrcLLVM, } SourceLanguage; -SourceLanguage guess_source_language(const char* filename); -ShadyErrorCodes driver_load_source_file(SourceLanguage lang, size_t, const char* file_contents, Module* mod); -ShadyErrorCodes driver_load_source_file_from_filename(const char* filename, Module* mod); +SourceLanguage shd_driver_guess_source_language(const char* filename); +ShadyErrorCodes shd_driver_load_source_file(const CompilerConfig* config, SourceLanguage lang, size_t len, const char* file_contents, String name, Module** mod); +ShadyErrorCodes shd_driver_load_source_file_from_filename(const CompilerConfig* config, const char* filename, String name, Module** mod); typedef enum { TgtAuto, @@ -37,16 +41,16 @@ typedef enum { TgtISPC, } CodegenTarget; -CodegenTarget guess_target(const 
char* filename); +CodegenTarget shd_guess_target(const char* filename); -void cli_pack_remaining_args(int* pargc, char** argv); +void shd_pack_remaining_args(int* pargc, char** argv); // parses 'common' arguments such as log level etc -void cli_parse_common_args(int* pargc, char** argv); +void shd_parse_common_args(int* pargc, char** argv); // parses compiler pipeline options -void cli_parse_compiler_config_args(CompilerConfig*, int* pargc, char** argv); +void shd_parse_compiler_config_args(CompilerConfig* config, int* pargc, char** argv); // parses the remaining arguments into a list of files -void cli_parse_input_files(struct List*, int* pargc, char** argv); +void shd_driver_parse_input_files(struct List* list, int* pargc, char** argv); typedef struct { CompilerConfig config; @@ -59,12 +63,18 @@ typedef struct { const char* loop_tree_output_filename; } DriverConfig; -DriverConfig default_driver_config(); -void destroy_driver_config(DriverConfig*); +DriverConfig shd_default_driver_config(void); +void shd_destroy_driver_config(DriverConfig* config); + +void shd_parse_driver_args(DriverConfig* args, int* pargc, char** argv); + +ShadyErrorCodes shd_driver_load_source_files(DriverConfig* args, Module* mod); +ShadyErrorCodes shd_driver_compile(DriverConfig* args, Module* mod); -void cli_parse_driver_arguments(DriverConfig* args, int* pargc, char** argv); +typedef enum CompilationResult_ { + CompilationNoError +} CompilationResult; -ShadyErrorCodes driver_load_source_files(DriverConfig* args, Module* mod); -ShadyErrorCodes driver_compile(DriverConfig* args, Module* mod); +CompilationResult shd_run_compiler_passes(CompilerConfig* config, Module** pmod); #endif diff --git a/include/shady/fe/slim.h b/include/shady/fe/slim.h new file mode 100644 index 000000000..f5dd20d37 --- /dev/null +++ b/include/shady/fe/slim.h @@ -0,0 +1,13 @@ +#ifndef SHADY_SLIM_H +#define SHADY_SLIM_H + +#include "shady/ir/module.h" +#include "shady/config.h" + +typedef struct { + bool front_end; +} 
SlimParserConfig; + +Module* shd_parse_slim_module(const CompilerConfig* config, const SlimParserConfig* pconfig, const char* contents, String name); + +#endif diff --git a/include/shady/grammar.h b/include/shady/grammar.h deleted file mode 100644 index a49d64701..000000000 --- a/include/shady/grammar.h +++ /dev/null @@ -1,91 +0,0 @@ -#ifndef SHADY_IR_H -#error "do not include this file by itself, include shady/ir.h instead" -#endif - -typedef enum DivergenceQualifier_ { - Unknown, - Uniform, - Varying -} DivergenceQualifier; - -typedef enum { - NotSpecial, - /// for instructions with multiple yield values. Must be deconstructed by a let, cannot appear anywhere else - MultipleReturn, - /// Gets the 'Block' SPIR-V annotation, needed for UBO/SSBO variables - DecorateBlock -} RecordSpecialFlag; - -typedef enum { - IntTy8, - IntTy16, - IntTy32, - IntTy64, -} IntSizes; - -enum { - IntSizeMin = IntTy8, - IntSizeMax = IntTy64, -}; - -static inline int int_size_in_bytes(IntSizes s) { - switch (s) { - case IntTy8: return 1; - case IntTy16: return 2; - case IntTy32: return 4; - case IntTy64: return 8; - } -} - -typedef enum { - FloatTy16, - FloatTy32, - FloatTy64 -} FloatSizes; - -static inline int float_size_in_bytes(FloatSizes s) { - switch (s) { - case FloatTy16: return 2; - case FloatTy32: return 4; - case FloatTy64: return 8; - } -} - -#define EXECUTION_MODELS(EM) \ -EM(Compute, 1) \ -EM(Fragment, 0) \ -EM(Vertex, 0) \ - -typedef enum { - EmNone, -#define EM(name, _) Em##name, -EXECUTION_MODELS(EM) -#undef EM -} ExecutionModel; - -ExecutionModel execution_model_from_string(const char*); - -//////////////////////////////// Generated definitions //////////////////////////////// - -// see primops.json -#include "primops_generated.h" - -String get_primop_name(Op op); -bool has_primop_got_side_effects(Op op); - -// see grammar.json -#include "grammar_generated.h" - -extern const char* node_tags[]; -extern const bool node_type_has_payload[]; - 
-//////////////////////////////// Node categories //////////////////////////////// - -inline static bool is_nominal(const Node* node) { - NodeTag tag = node->tag; - if (node->tag == PrimOp_TAG && has_primop_got_side_effects(node->payload.prim_op.op)) - return true; - return tag == Function_TAG || tag == BasicBlock_TAG || tag == Constant_TAG || tag == Variable_TAG || tag == GlobalVariable_TAG || tag == NominalType_TAG || tag == Case_TAG; -} - -inline static bool is_function(const Node* node) { return node->tag == Function_TAG; } diff --git a/include/shady/grammar.json b/include/shady/grammar.json index ff131a9e5..d6a278783 100644 --- a/include/shady/grammar.json +++ b/include/shady/grammar.json @@ -2,53 +2,32 @@ "address-spaces": [ { "name": "Generic", - "llvm-id": 0, - "physical": true + "llvm-id": 0 }, { - "name": "GlobalPhysical", + "name": "Global", "description": "Global memory, all threads see the same data (not necessarily consistent!)", - "llvm-id": 1, - "physical": true + "llvm-id": 1 }, { - "name": "SharedPhysical", + "name": "Shared", "description": "Points into workgroup-private memory (aka shared memory)", - "llvm-id": 3, - "physical": true + "llvm-id": 3 }, { - "name": "SubgroupPhysical", + "name": "Subgroup", "description": [ "Points into subgroup-private memory", "All threads in a subgroup see the same contents for the same address, but threads in different subgroups see different data.", "Needs to be lowered to something else since targets do not understand this" ], - "llvm-id": 9, - "physical": true + "llvm-id": 9 }, { - "name": "PrivatePhysical", + "name": "Private", "description": [ "Points into thread-private memory (all threads see different contents for the same address)" ], - "llvm-id": 5, - "physical": true - }, - { - "name": "GlobalLogical", - "llvm-id": 388 - }, - { - "name": "SharedLogical", - "llvm-id": 387 - }, - { - "name": "SubgroupLogical", - "llvm-id": 386 - }, - { - "name": "PrivateLogical", - "llvm-id": 385 + "llvm-id": 5 }, { 
"name": "Input", @@ -74,8 +53,8 @@ "llvm-id": 392 }, { - "name": "FunctionLogical", - "description": "Weird SPIR-V nonsense: this is like PrivateLogical, but with non-static lifetimes (ie function lifetime)", + "name": "Function", + "description": "Weird SPIR-V nonsense: this is like Private, but with non-static lifetimes (ie function lifetime)", "llvm-id": 393 }, { @@ -106,30 +85,53 @@ "name": "value" }, { - "name": "variable", + "name": "param", "generate-enum": false }, { - "name": "instruction" + "name": "abstraction", + "ops": [ + { "name": "params", "class": "param", "list": true }, + { "name": "body", "class": "terminator", "nullable": true } + ] }, { - "name": "terminator" + "name": "instruction" }, { - "name": "declaration" + "name": "terminator", + "ops": [ + { "name": "mem", "class": "mem" } + ] }, { - "name": "case" + "name": "declaration", + "ops": [ + { "name": "annotations", "class": "annotation", "list": true }, + { "name": "name", "class": "string" } + ] }, { "name": "basic_block" }, { - "name": "annotation" + "name": "annotation", + "ops": [ + { "name": "name", "class": "string" } + ] }, { "name": "jump", "generate-enum": false + }, + { + "name": "structured_construct", + "ops": [ + { "name": "tail", "class": "basic_block" } + ] + }, + { + "name": "mem" } ], "nodes": [ @@ -148,6 +150,10 @@ { "name": "NoRet", "snake_name": "noret_type", + "description": [ + "Empty type: there are no values of this type.", + "Useful for the codomain of things that don't return at all" + ], "class": "type", "type": false }, @@ -218,7 +224,8 @@ "class": "type", "ops": [ { "name": "address_space", "type": "AddressSpace" }, - { "name": "pointed_type", "class": "type" } + { "name": "pointed_type", "class": "type" }, + { "name": "is_reference", "type": "bool" } ] }, { @@ -234,7 +241,7 @@ "class": "type", "ops": [ { "name": "element_type", "class": "type" }, - { "name": "size", "class": "value" } + { "name": "size", "class": "value", "nullable": true } ] }, { @@ -242,7 
+249,7 @@ "class": "type", "ops": [ { "name": "element_type", "class": "type" }, - { "name": "width", "type": "int" } + { "name": "width", "type": "uint32_t" } ] }, { @@ -254,39 +261,12 @@ ] }, { - "name": "ImageType", - "class": "type", - "type": false, - "ops": [ - { "name": "sampled_type", "class": "type" }, - { "name": "dim", "type": "uint32_t" }, - { "name": "depth", "type": "uint32_t" }, - { "name": "onion", "type": "bool" }, - { "name": "multisample", "type": "bool" }, - { "name": "sampled", "type": "uint32_t" } - ] - }, - { - "name": "SamplerType", - "class": "type", - "type": false - }, - { - "name": "CombinedImageSamplerType", - "class": "type", - "type": false, - "ops": [ - { "name": "image_type", "class": "type" } - ] - }, - { - "name": "Variable", - "snake_name": "var", - "class": ["value", "variable"], + "name": "Param", + "class": ["value", "param"], "constructor": "custom", + "nominal": true, "ops": [ { "name": "type", "class": "type" }, - { "name": "id", "type": "VarId" }, { "name": "name", "class": "string" }, { "name": "abs", "type": "const Node*", "ignore": true }, { "name": "pindex", "type": "unsigned", "ignore": true } @@ -360,7 +340,7 @@ "Re-ordering values does not count as a computation here !" 
], "ops": [ - { "name": "type", "class": "type" }, + { "name": "type", "class": "type", "nullable": true }, { "name": "contents", "class": "value", "list": true } ] }, @@ -400,102 +380,86 @@ }, { "name": "Call", - "class": "instruction", + "class": ["instruction", "value", "mem"], + "nominal": true, "ops": [ + { "name": "mem", "class": "mem" }, { "name": "callee", "class": "value" }, { "name": "args", "class": "value", "list": true } ] }, { - "name": "PrimOp", - "class": "instruction", + "name": "MemAndValue", + "class": ["value", "mem"], + "description": "Associate a value with a mem object, this allows adding side effects to a value", "ops": [ - { "name": "op", "type": "Op" }, - { "name": "type_arguments", "class": "type", "list": true }, - { "name": "operands", "class": "value", "list": true } + { "name": "mem", "class": "mem" }, + { "name": "value", "class": "value" } ] }, { "name": "If", "snake_name": "if_instr", - "class": "instruction", + "class": ["terminator", "structured_construct"], "ops": [ + { "name": "mem", "class": "mem" }, { "name": "yield_types", "class": "type", "list": true }, { "name": "condition", "class": "value" }, - { "name": "if_true", "class": "case" }, - { "name": "if_false", "class": "case" } + { "name": "if_true", "class": "basic_block" }, + { "name": "if_false", "class": "basic_block", "nullable": true }, + { "name": "tail", "class": "basic_block" } ] }, { "name": "Match", "snake_name": "match_instr", - "class": "instruction", + "class": ["terminator", "structured_construct"], "ops": [ + { "name": "mem", "class": "mem" }, { "name": "yield_types", "class": "type", "list": true }, { "name": "inspect", "class": "value" }, { "name": "literals", "class": "value", "list": true }, - { "name": "cases", "class": "case", "list": true }, - { "name": "default_case", "class": "case" } + { "name": "cases", "class": "basic_block", "list": true }, + { "name": "default_case", "class": "basic_block" }, + { "name": "tail", "class": "basic_block" } ] }, { 
"name": "Loop", "snake_name": "loop_instr", - "class": "instruction", + "class": ["terminator", "structured_construct"], "ops": [ + { "name": "mem", "class": "mem" }, { "name": "yield_types", "class": "type", "list": true }, - { "name": "body", "class": "case" }, - { "name": "initial_args", "class": "value", "list": true } + { "name": "body", "class": "basic_block" }, + { "name": "initial_args", "class": "value", "list": true }, + { "name": "tail", "class": "basic_block" } ] }, { "name": "Control", - "class": "instruction", + "class": ["terminator", "structured_construct"], "ops": [ + { "name": "mem", "class": "mem" }, { "name": "yield_types", "class": "type", "list": true }, - { "name": "inside", "class": "case" } - ] - }, - { - "name": "Block", - "class": "instruction", - "description": "Used as a helper to insert multiple instructions in place of one", - "ops": [ - { "name": "yield_types", "class": "type", "list": true }, - { "name": "inside", "class": "case" } + { "name": "inside", "class": "basic_block" }, + { "name": "tail", "class": "basic_block" } ] }, { "name": "Comment", - "class": "instruction", + "class": ["instruction", "mem"], "ops": [ + { "name": "mem", "class": "mem" }, { "name": "string", "class": "string" } ] }, - { - "name": "Let", - "class": "terminator", - "constructor": "custom", - "ops": [ - { "name": "instruction", "class": "instruction" }, - { "name": "tail", "class": "case" } - ] - }, - { - "name": "LetMut", - "class": "terminator", - "constructor": "custom", - "front-end-only": true, - "ops": [ - { "name": "instruction", "class": "instruction" }, - { "name": "tail", "class": "case" } - ] - }, { "name": "TailCall", "class": "terminator", "ops": [ - { "name": "target", "class": "value" }, + { "name": "mem", "class": "mem" }, + { "name": "callee", "class": "value" }, { "name": "args", "class": "value", "list": true } ] }, @@ -503,6 +467,7 @@ "name": "Jump", "class": ["terminator", "jump"], "ops": [ + { "name": "mem", "class": "mem" }, { 
"name": "target", "class": "basic_block" }, { "name": "args", "class": "value", "list": true } ] @@ -515,7 +480,8 @@ "Branch alternatives are made out of Jump terminators" ], "ops": [ - { "name": "branch_condition", "class": "value" }, + { "name": "mem", "class": "mem" }, + { "name": "condition", "class": "value" }, { "name": "true_jump", "class": "jump" }, { "name": "false_jump", "class": "jump" } ] @@ -526,6 +492,7 @@ "class": "terminator", "description": "N-way variant of Branch. See Branch.", "ops": [ + { "name": "mem", "class": "mem" }, { "name": "switch_value", "class": "value" }, { "name": "case_values", "class": "value", "list": true }, { "name": "case_jumps", "class": "jump", "list": true }, @@ -541,6 +508,7 @@ "If @p is_indirect is set, the target must be a function pointer. Otherwise, the target must be a function directly." ], "ops": [ + { "name": "mem", "class": "mem" }, { "name": "join_point", "class": "value" }, { "name": "args", "class": "value", "list": true } ] @@ -548,21 +516,27 @@ { "name": "MergeContinue", "class": "terminator", + "nominal": true, "ops": [ + { "name": "mem", "class": "mem" }, { "name": "args", "class": "value", "list": true } ] }, { "name": "MergeBreak", "class": "terminator", + "nominal": true, "ops": [ + { "name": "mem", "class": "mem" }, { "name": "args", "class": "value", "list": true } ] }, { - "name": "Yield", + "name": "MergeSelection", "class": "terminator", + "nominal": true, "ops": [ + { "name": "mem", "class": "mem" }, { "name": "args", "class": "value", "list": true } ] }, @@ -570,26 +544,31 @@ "name": "Return", "snake_name": "fn_ret", "class": "terminator", + "nominal": true, "ops": [ - { "name": "fn", "class": "declaration" }, + { "name": "mem", "class": "mem" }, { "name": "args", "class": "value", "list": true } ] }, { "name": "Unreachable", - "class": "terminator" + "class": "terminator", + "ops": [ + { "name": "mem", "class": "mem" } + ] }, { "name": "Function", "snake_name": "fun", - "class": "declaration", + 
"class": ["abstraction", "declaration"], "constructor": "custom", + "nominal": true, "ops": [ { "name": "module", "type": "Module*", "ignore": true }, { "name": "name", "class": "string" }, { "name": "annotations", "class": "annotation", "list": true }, - { "name": "params", "class": "variable", "list": true }, - { "name": "body", "class": "terminator" }, + { "name": "params", "class": "param", "list": true }, + { "name": "body", "class": "terminator", "nullable": true }, { "name": "return_types", "class": "type", "list": true } ] }, @@ -598,25 +577,27 @@ "class": "declaration", "constructor": "custom", "description": "Constants are used to express possibly complicated compile-time expressions", + "nominal": true, "ops": [ { "name": "module", "type": "Module*", "ignore": true }, { "name": "name", "class": "string" }, { "name": "annotations", "class": "annotation", "list": true }, { "name": "type_hint", "class": "type" }, - { "name": "instruction", "class": "instruction" } + { "name": "value", "class": "value", "nullable": true } ] }, { "name": "GlobalVariable", "class": "declaration", "constructor": "custom", + "nominal": true, "ops": [ { "name": "module", "type": "Module*", "ignore": true }, { "name": "name", "class": "string" }, { "name": "annotations", "class": "annotation", "list": true }, { "name": "type", "class": "type" }, { "name": "address_space", "type": "AddressSpace"}, - { "name": "init", "class": "value", "ignore": true } + { "name": "init", "class": "value", "nullable": true } ] }, { @@ -624,12 +605,13 @@ "snake_name": "nom_type", "class": "declaration", "constructor": "custom", + "nominal": true, "type": false, "ops": [ { "name": "module", "type": "Module*", "ignore": true }, { "name": "name", "class": "string" }, { "name": "annotations", "class": "annotation", "list": true }, - { "name": "body", "class": "type" } + { "name": "body", "class": "type", "nullable": true } ] }, { @@ -671,46 +653,33 @@ "name": "BasicBlock", "constructor": "custom", 
"description": "A named abstraction that lives inside a function and can be jumped to", - "class": "basic_block", - "ops": [ - { "name": "params", "class": "variable", "list": true }, - { "name": "body", "class": "terminator" }, - { "name": "fn", "class": "declaration" }, - { "name": "name", "class": "string" } - ] - }, - { - "name": "Case", - "snake_name": "case_", - "constructor": "custom", - "class": "case", - "description": [ - "An unnamed abstraction that lives inside a function, and can be used as part of various control-flow constructs", - "Most notably, the tails of standard `let` nodes" - ], + "class": ["abstraction", "basic_block"], + "nominal": true, "ops": [ - { "name": "params", "class": "variable", "list": true }, - { "name": "body", "class": "terminator" }, - { "name": "structured_construct", "type": "const Node*", "ignored": true } + { "name": "params", "class": "param", "list": true }, + { "name": "body", "class": "terminator", "nullable": true }, + { "name": "name", "class": "string" }, + { "name": "insert", "type": "BodyBuilder*", "ignore": true } ] }, { - "name": "Unbound", - "description": "Unbound identifier, obtained by parsing a file", + "name": "AbsMem", + "class": "mem", "type": false, - "front-end-only": true, "ops": [ - { "name": "name", "type": "String" } + { "name": "abs", "class": "abstraction" } ] }, { - "name": "UnboundBBs", - "description": "A node together with unbound basic blocks it dominates, obtained by parsing a file", - "type": false, - "front-end-only": true, + "name": "ExtInstr", + "description": "Invokes an instruction from an extended instruction set", + "class": ["mem", "value", "instruction"], "ops": [ - { "name": "body", "class": "terminator" }, - { "name": "children_blocks", "class": "basic_block", "list": true } + { "name": "mem", "class": "mem", "nullable": true }, + { "name": "result_t", "class": "type" }, + { "name": "set", "class": "string" }, + { "name": "opcode", "type": "uint32_t" }, + { "name": "operands", "class":
"value", "list": true } ] } ] diff --git a/include/shady/ir.h b/include/shady/ir.h index d8c616627..c266ddea6 100644 --- a/include/shady/ir.h +++ b/include/shady/ir.h @@ -1,345 +1,25 @@ #ifndef SHADY_IR_H #define SHADY_IR_H -#include -#include -#include -#include - -typedef struct IrArena_ IrArena; -typedef struct Node_ Node; -typedef struct Node_ Type; -typedef unsigned int VarId; -typedef const char* String; - -//////////////////////////////// Lists & Strings //////////////////////////////// - -typedef struct Nodes_ { - size_t count; - const Node** nodes; -} Nodes; - -typedef struct Strings_ { - size_t count; - String* strings; -} Strings; - -Nodes nodes(IrArena*, size_t count, const Node*[]); -Strings strings(IrArena*, size_t count, const char*[]); - -Nodes empty(IrArena*); -Nodes singleton(const Node*); -#define mk_nodes(arena, ...) nodes(arena, sizeof((const Node*[]) { __VA_ARGS__ }) / sizeof(const Node*), (const Node*[]) { __VA_ARGS__ }) - -const Node* first(Nodes nodes); - -Nodes append_nodes(IrArena*, Nodes, const Node*); -Nodes prepend_nodes(IrArena*, Nodes, const Node*); -Nodes concat_nodes(IrArena*, Nodes, Nodes); -Nodes change_node_at_index(IrArena*, Nodes, size_t, const Node*); - -String string_sized(IrArena*, size_t size, const char* start); -String string(IrArena*, const char*); -// see also: format_string in util.h -String format_string_interned(IrArena*, const char* str, ...); -String unique_name(IrArena*, const char* base_name); -String name_type_safe(IrArena*, const Type*); - -//////////////////////////////// Modules //////////////////////////////// - -typedef struct Module_ Module; - -Module* new_module(IrArena*, String name); - -IrArena* get_module_arena(const Module*); -String get_module_name(const Module*); -Nodes get_module_declarations(const Module*); -const Node* get_declaration(const Module*, String); - -//////////////////////////////// Grammar //////////////////////////////// - -// The language grammar is big enough that it deserve its 
own files - -#include "grammar.h" - -//////////////////////////////// IR Arena //////////////////////////////// - -typedef struct { - bool name_bound; - bool check_op_classes; - bool check_types; - bool allow_fold; - bool untyped_ptrs; - bool validate_builtin_types; // do @Builtins variables need to match their type in builtins.h ? - bool is_simt; - - bool allow_subgroup_memory; - bool allow_shared_memory; - - struct { - /// Selects which type the subgroup intrinsic primops use to manipulate masks - enum { - /// Uses the MaskType - SubgroupMaskAbstract, - /// Uses a 64-bit integer - SubgroupMaskInt64 - } subgroup_mask_representation; - - uint32_t subgroup_size; - uint32_t workgroup_size[3]; - } specializations; - - struct { - IntSizes ptr_size; - /// The base type for emulated memory - IntSizes word_size; - } memory; - - /// 'folding' optimisations - happen in the constructors directly - struct { - bool delete_unreachable_structured_cases; - } optimisations; -} ArenaConfig; - -typedef struct CompilerConfig_ CompilerConfig; -ArenaConfig default_arena_config(); - -IrArena* new_ir_arena(ArenaConfig); -void destroy_ir_arena(IrArena*); -ArenaConfig get_arena_config(const IrArena*); - -//////////////////////////////// Getters //////////////////////////////// - -/// Get the name out of a global variable, function or constant -String get_decl_name(const Node*); -String get_value_name(const Node*); -String get_value_name_safe(const Node*); -void set_variable_name(Node* var, String); - -const Node* get_quoted_value(const Node* instruction); -const IntLiteral* resolve_to_int_literal(const Node* node); -int64_t get_int_literal_value(IntLiteral, bool sign_extend); -const FloatLiteral* resolve_to_float_literal(const Node* node); -double get_float_literal_value(FloatLiteral); -const char* get_string_literal(IrArena*, const Node*); - -String get_address_space_name(AddressSpace); -/// Returns false iff pointers in that address space can contain different data at the same address 
-/// (amongst threads in the same subgroup) -bool is_addr_space_uniform(IrArena*, AddressSpace); - -const Node* lookup_annotation(const Node* decl, const char* name); -const Node* lookup_annotation_list(Nodes, const char* name); -const Node* get_annotation_value(const Node* annotation); -Nodes get_annotation_values(const Node* annotation); -/// Gets the string literal attached to an annotation, if present. -const char* get_annotation_string_payload(const Node* annotation); -bool lookup_annotation_with_string_payload(const Node* decl, const char* annotation_name, const char* expected_payload); -String get_annotation_name(const Node* node); -Nodes filter_out_annotation(IrArena*, Nodes, const char* name); - -bool is_abstraction (const Node*); -String get_abstraction_name (const Node* abs); -const Node* get_abstraction_body (const Node* abs); -Nodes get_abstraction_params(const Node* abs); - -void set_abstraction_body (Node* abs, const Node* body); - -const Node* get_let_instruction(const Node* let); -const Node* get_let_tail(const Node* let); - -typedef struct { - bool enter_loads; - bool allow_incompatible_types; - bool assume_globals_immutability; -} NodeResolveConfig; -NodeResolveConfig default_node_resolve_config(); -const Node* resolve_ptr_to_value(const Node* node, NodeResolveConfig config); - -const Node* resolve_node_to_definition(const Node* node, NodeResolveConfig config); - -//////////////////////////////// Constructors //////////////////////////////// - -/// For typing things that don't return at all -const Type* noret_type(IrArena*); -/// For making pointers to nothing in particular (equivalent to C's void*) -const Node* unit_type(IrArena*); -/// For typing instructions that return nothing (equivalent to C's void f()) -const Node* empty_multiple_return_type(IrArena*); - -const Type* int_type_helper(IrArena*, bool, IntSizes); - -const Type* int8_type(IrArena*); -const Type* int16_type(IrArena*); -const Type* int32_type(IrArena*); -const Type* 
int64_type(IrArena*); - -const Type* uint8_type(IrArena*); -const Type* uint16_type(IrArena*); -const Type* uint32_type(IrArena*); -const Type* uint64_type(IrArena*); - -const Type* int8_literal(IrArena*, int8_t i); -const Type* int16_literal(IrArena*, int16_t i); -const Type* int32_literal(IrArena*, int32_t i); -const Type* int64_literal(IrArena*, int64_t i); - -const Type* uint8_literal(IrArena*, uint8_t i); -const Type* uint16_literal(IrArena*, uint16_t i); -const Type* uint32_literal(IrArena*, uint32_t i); -const Type* uint64_literal(IrArena*, uint64_t i); - -const Type* fp16_type(IrArena*); -const Type* fp32_type(IrArena*); -const Type* fp64_type(IrArena*); - -const Node* fp_literal_helper(IrArena*, FloatSizes, double); - -const Node* type_decl_ref_helper(IrArena*, const Node* decl); - -// values -Node* var(IrArena*, const Type* type, const char* name); - -const Node* tuple_helper(IrArena*, Nodes contents); -const Node* composite_helper(IrArena*, const Type*, Nodes contents); -const Node* fn_addr_helper(IrArena*, const Node* fn); -const Node* ref_decl_helper(IrArena*, const Node* decl); -const Node* string_lit_helper(IrArena* a, String s); -const Node* annotation_value_helper(IrArena* a, String n, const Node* v); - -// instructions -/// Turns a value into an 'instruction' (the enclosing let will be folded away later) -/// Useful for local rewrites -const Node* quote_helper(IrArena*, Nodes values); -const Node* prim_op_helper(IrArena*, Op, Nodes, Nodes); - -// terminators -const Node* let(IrArena*, const Node* instruction, const Node* tail); -const Node* let_mut(IrArena*, const Node* instruction, const Node* tail); -const Node* jump_helper(IrArena* a, const Node* dst, Nodes args); - -// decl ctors -Node* function (Module*, Nodes params, const char* name, Nodes annotations, Nodes return_types); -Node* constant (Module*, Nodes annotations, const Type*, const char* name); -Node* global_var (Module*, Nodes annotations, const Type*, String, AddressSpace); -Type* 
nominal_type(Module*, Nodes annotations, String name); - -// basic blocks, lambdas and their helpers -Node* basic_block(IrArena*, Node* function, Nodes params, const char* name); -const Node* case_(IrArena* a, Nodes params, const Node* body); - -/// Used to build a chain of let -typedef struct BodyBuilder_ BodyBuilder; -BodyBuilder* begin_body(IrArena*); - -/// Appends an instruction to the builder, may apply optimisations. -/// If the arena is typed, returns a list of variables bound to the values yielded by that instruction -Nodes bind_instruction(BodyBuilder*, const Node* instruction); -Nodes bind_instruction_named(BodyBuilder*, const Node* instruction, String const output_names[]); - -/// Like append instruction, but you explicitly give it information about any yielded values -/// ! In untyped arenas, you need to call this because we can't guess how many things are returned without typing info ! -Nodes bind_instruction_explicit_result_types(BodyBuilder*, const Node* initial_value, Nodes provided_types, String const output_names[], bool mut); -Nodes bind_instruction_outputs_count(BodyBuilder*, const Node* initial_value, size_t outputs_count, String const output_names[], bool mut); - -void bind_variables(BodyBuilder*, Nodes vars, Nodes values); - -const Node* finish_body(BodyBuilder*, const Node* terminator); -void cancel_body(BodyBuilder*); -const Node* yield_values_and_wrap_in_block_explicit_return_types(BodyBuilder*, Nodes, const Nodes*); -const Node* yield_values_and_wrap_in_block(BodyBuilder*, Nodes); -const Node* bind_last_instruction_and_wrap_in_block_explicit_return_types(BodyBuilder*, const Node*, const Nodes*); -const Node* bind_last_instruction_and_wrap_in_block(BodyBuilder*, const Node*); - -//////////////////////////////// Compilation //////////////////////////////// - -struct CompilerConfig_ { - bool dynamic_scheduling; - uint32_t per_thread_stack_size; - - struct { - uint8_t major; - uint8_t minor; - } target_spirv_version; - - struct { - bool 
emulate_subgroup_ops; - bool emulate_subgroup_ops_extended_types; - bool simt_to_explicit_simd; - bool int64; - bool decay_ptrs; - } lower; - - struct { - bool spv_shuffle_instead_of_broadcast_first; - bool force_join_point_lifting; - bool no_physical_global_ptrs; - } hacks; - - struct { - struct { - bool after_every_pass; - bool delete_unused_instructions; - } cleanup; - } optimisations; - - struct { - bool memory_accesses; - bool stack_accesses; - bool god_function; - bool stack_size; - bool subgroup_ops; - } printf_trace; - - struct { - int max_top_iterations; - } shader_diagnostics; - - struct { - bool skip_generated, skip_builtin, skip_internal; - } logging; - - struct { - String entry_point; - ExecutionModel execution_model; - uint32_t subgroup_size; - } specialization; - - struct { - struct { void* uptr; void (*fn)(void*, String, Module*); } after_pass; - } hooks; -}; - -CompilerConfig default_compiler_config(); - -typedef enum CompilationResult_ { - CompilationNoError -} CompilationResult; - -CompilationResult run_compiler_passes(CompilerConfig* config, Module** mod); - -//////////////////////////////// Emission //////////////////////////////// - -void emit_spirv(CompilerConfig* config, Module*, size_t* output_size, char** output, Module** new_mod); - -typedef enum { - C, - GLSL, - ISPC -} CDialect; - -typedef struct { - CDialect dialect; - bool explicitly_sized_types; - bool allow_compound_literals; -} CEmitterConfig; - -void emit_c(CompilerConfig compiler_config, CEmitterConfig emitter_config, Module*, size_t* output_size, char** output, Module** new_mod); - -void dump_cfg(FILE* file, Module*); -void dump_loop_trees(FILE* output, Module* mod); -void dump_module(Module*); -void print_module_into_str(Module*, char** str_ptr, size_t*); -void dump_node(const Node* node); +#include "shady/ir/base.h" + +#include "shady/ir/arena.h" +#include "shady/ir/module.h" +#include "shady/ir/grammar.h" + +#include "shady/ir/int.h" +#include "shady/ir/float.h" +#include 
"shady/ir/composite.h" +#include "shady/ir/execution_model.h" +#include "shady/ir/primop.h" +#include "shady/ir/debug.h" +#include "shady/ir/annotation.h" +#include "shady/ir/mem.h" +#include "shady/ir/type.h" +#include "shady/ir/function.h" +#include "shady/ir/decl.h" + +#include "shady/ir/builder.h" +#include "shady/analysis/literal.h" #endif diff --git a/include/shady/ir/annotation.h b/include/shady/ir/annotation.h new file mode 100644 index 000000000..3c481fd9c --- /dev/null +++ b/include/shady/ir/annotation.h @@ -0,0 +1,15 @@ +#ifndef SHADY_IR_ANNOTATION_H +#define SHADY_IR_ANNOTATION_H + +#include "shady/ir/base.h" + +const Node* shd_lookup_annotation(const Node* decl, const char* name); +const Node* shd_lookup_annotation_list(Nodes annotations, const char* name); +const Node* shd_get_annotation_value(const Node* annotation); +Nodes shd_get_annotation_values(const Node* annotation); +/// Gets the string literal attached to an annotation, if present. +const char* shd_get_annotation_string_payload(const Node* annotation); +bool shd_lookup_annotation_with_string_payload(const Node* decl, const char* annotation_name, const char* expected_payload); +Nodes shd_filter_out_annotation(IrArena* arena, Nodes annotations, const char* name); + +#endif diff --git a/include/shady/ir/arena.h b/include/shady/ir/arena.h new file mode 100644 index 000000000..e700e6cf2 --- /dev/null +++ b/include/shady/ir/arena.h @@ -0,0 +1,11 @@ +#ifndef SHADY_IR_ARENA_H +#define SHADY_IR_ARENA_H + +#include "shady/ir/base.h" +#include "shady/config.h" + +IrArena* shd_new_ir_arena(const ArenaConfig* config); +void shd_destroy_ir_arena(IrArena* arena); +const Node* shd_get_node_by_id(const IrArena* a, NodeId id); + +#endif diff --git a/include/shady/ir/base.h b/include/shady/ir/base.h new file mode 100644 index 000000000..aeaf97364 --- /dev/null +++ b/include/shady/ir/base.h @@ -0,0 +1,54 @@ +#ifndef SHADY_IR_BASE_H +#define SHADY_IR_BASE_H + +#include +#include +#include + +#ifdef __GNUC__ 
+#define SHADY_DESIGNATED_INIT __attribute__((designated_init)) +#else +#define SHADY_DESIGNATED_INIT +#endif + +typedef struct IrArena_ IrArena; +typedef struct Module_ Module; +typedef struct Node_ Node; +typedef struct Node_ Type; +typedef uint32_t NodeId; +typedef const char* String; + +typedef struct Nodes_ { + size_t count; + const Node** nodes; +} Nodes; + +typedef struct Strings_ { + size_t count; + String* strings; +} Strings; + +Nodes shd_nodes(IrArena*, size_t count, const Node*[]); +Strings shd_strings(IrArena* arena, size_t count, const char** in_strs); + +Nodes shd_empty(IrArena* a); +Nodes shd_singleton(const Node* n); + +#define mk_nodes(arena, ...) shd_nodes(arena, sizeof((const Node*[]) { __VA_ARGS__ }) / sizeof(const Node*), (const Node*[]) { __VA_ARGS__ }) + +const Node* shd_first(Nodes nodes); + +Nodes shd_nodes_append(IrArena*, Nodes, const Node*); +Nodes shd_nodes_prepend(IrArena*, Nodes, const Node*); +Nodes shd_concat_nodes(IrArena* arena, Nodes a, Nodes b); +Nodes shd_change_node_at_index(IrArena* arena, Nodes old, size_t i, const Node* n); +bool shd_find_in_nodes(Nodes nodes, const Node* n); + +String shd_string_sized(IrArena*, size_t size, const char* start); +String shd_string(IrArena*, const char*); + +// see also: format_string in util.h +String shd_fmt_string_irarena(IrArena* arena, const char* str, ...); +String shd_make_unique_name(IrArena* arena, const char* str); + +#endif diff --git a/include/shady/ir/builder.h b/include/shady/ir/builder.h new file mode 100644 index 000000000..3fce3e89a --- /dev/null +++ b/include/shady/ir/builder.h @@ -0,0 +1,65 @@ +#ifndef SHADY_BUILDER_H +#define SHADY_BUILDER_H + +#include "shady/ir/base.h" + +typedef struct BodyBuilder_ BodyBuilder; + +/// Used to build a chain of let +BodyBuilder* shd_bld_begin(IrArena* a, const Node* mem); +BodyBuilder* shd_bld_begin_pure(IrArena* a); +BodyBuilder* shd_bld_begin_pseudo_instr(IrArena* a, const Node* mem); + +IrArena* shd_get_bb_arena(BodyBuilder* bb); + 
+/// Appends an instruction to the builder, may apply optimisations. +/// If the arena is typed, returns a list of variables bound to the values yielded by that instruction +Nodes shd_bld_add_instruction_extract(BodyBuilder* bb, const Node* instruction); +const Node* shd_bld_add_instruction(BodyBuilder* bb, const Node* instr); + +/// Like shd_bld_add_instruction_extract, but you explicitly give it information about any yielded values +/// ! In untyped arenas, you need to call this because we can't guess how many things are returned without typing info ! +Nodes shd_bld_add_instruction_extract_count(BodyBuilder* bb, const Node* instruction, size_t outputs_count); + +Nodes shd_bld_if(BodyBuilder* bb, Nodes yield_types, const Node* condition, const Node* true_case, Node* false_case); +Nodes shd_bld_match(BodyBuilder* bb, Nodes yield_types, const Node* inspectee, Nodes literals, Nodes cases, Node* default_case); +Nodes shd_bld_loop(BodyBuilder* bb, Nodes yield_types, Nodes initial_args, Node* body); + +typedef struct { + Nodes results; + Node* case_; + const Node* jp; +} begin_control_t; +begin_control_t shd_bld_begin_control(BodyBuilder* bb, Nodes yield_types); + +typedef struct { + Nodes results; + Node* loop_body; + Nodes params; + const Node* continue_jp; + const Node* break_jp; +} begin_loop_helper_t; +begin_loop_helper_t shd_bld_begin_loop_helper(BodyBuilder* bb, Nodes yield_types, Nodes arg_types, Nodes initial_values); + +Nodes shd_bld_control(BodyBuilder* bb, Nodes yield_types, Node* body); + +const Node* shd_bb_mem(BodyBuilder* bb); + +const Node* shd_bld_finish(BodyBuilder* bb, const Node* terminator); +const Node* shd_bld_return(BodyBuilder* bb, Nodes args); +const Node* shd_bld_unreachable(BodyBuilder* bb); +const Node* shd_bld_selection_merge(BodyBuilder* bb, Nodes args); +const Node* shd_bld_loop_continue(BodyBuilder* bb, Nodes args); +const Node* shd_bld_loop_break(BodyBuilder* bb, Nodes args); +const Node* shd_bld_join(BodyBuilder* bb, const
Node* jp, Nodes args); +const Node* shd_bld_jump(BodyBuilder* bb, const Node* target, Nodes args); + +void shd_bld_cancel(BodyBuilder* bb); + +const Node* shd_bld_to_instr_yield_value(BodyBuilder* bb, const Node* value); +const Node* shd_bld_to_instr_yield_values(BodyBuilder* bb, Nodes values); +const Node* shd_bld_to_instr_with_last_instr(BodyBuilder* bb, const Node* instruction); + +const Node* shd_bld_to_instr_pure_with_values(BodyBuilder* bb, Nodes values); + +#endif diff --git a/include/shady/ir/builtin.h b/include/shady/ir/builtin.h new file mode 100644 index 000000000..67bcd80d4 --- /dev/null +++ b/include/shady/ir/builtin.h @@ -0,0 +1,65 @@ +#ifndef SHADY_IR_BUILTIN_H +#define SHADY_IR_BUILTIN_H + +#include "shady/ir/grammar.h" +#include "shady/ir/int.h" +#include "shady/ir/float.h" + +#define shd_u32vec3_type(arena) pack_type(arena, (PackType) { .width = 3, .element_type = shd_uint32_type(arena) }) +#define shd_i32vec3_type(arena) pack_type(arena, (PackType) { .width = 3, .element_type = shd_int32_type(arena) }) +#define shd_i32vec4_type(arena) pack_type(arena, (PackType) { .width = 4, .element_type = shd_int32_type(arena) }) + +#define shd_f32vec4_type(arena) pack_type(arena, (PackType) { .width = 4, .element_type = shd_fp32_type(arena) }) + +#define SHADY_BUILTINS() \ +BUILTIN(BaseInstance, AsInput, shd_uint32_type(arena) )\ +BUILTIN(BaseVertex, AsInput, shd_uint32_type(arena) )\ +BUILTIN(DeviceIndex, AsInput, shd_uint32_type(arena) )\ +BUILTIN(DrawIndex, AsInput, shd_uint32_type(arena) )\ +BUILTIN(VertexIndex, AsInput, shd_uint32_type(arena) )\ +BUILTIN(FragCoord, AsInput, shd_f32vec4_type(arena) )\ +BUILTIN(FragDepth, AsOutput, shd_fp32_type(arena) )\ +BUILTIN(InstanceId, AsInput, shd_uint32_type(arena) )\ +BUILTIN(InvocationId, AsInput, shd_uint32_type(arena) )\ +BUILTIN(InstanceIndex, AsInput, shd_uint32_type(arena) )\ +BUILTIN(LocalInvocationId, AsInput, shd_u32vec3_type(arena) )\ +BUILTIN(LocalInvocationIndex, AsInput, shd_uint32_type(arena) )\ 
+BUILTIN(GlobalInvocationId, AsInput, shd_u32vec3_type(arena) )\ +BUILTIN(WorkgroupId, AsUInput, shd_u32vec3_type(arena) )\ +BUILTIN(WorkgroupSize, AsUInput, shd_u32vec3_type(arena) )\ +BUILTIN(NumSubgroups, AsUInput, shd_uint32_type(arena) )\ +BUILTIN(NumWorkgroups, AsUInput, shd_u32vec3_type(arena) )\ +BUILTIN(Position, AsOutput, shd_f32vec4_type(arena) )\ +BUILTIN(PrimitiveId, AsInput, shd_uint32_type(arena) )\ +BUILTIN(SubgroupLocalInvocationId, AsInput, shd_uint32_type(arena) )\ +BUILTIN(SubgroupId, AsUInput, shd_uint32_type(arena) )\ +BUILTIN(SubgroupSize, AsInput, shd_uint32_type(arena) )\ + +typedef enum { +#define BUILTIN(name, as, datatype) Builtin##name, +SHADY_BUILTINS() +#undef BUILTIN + BuiltinsCount +} Builtin; + +AddressSpace shd_get_builtin_address_space(Builtin builtin); +String shd_get_builtin_name(Builtin builtin); + +const Type* shd_get_builtin_type(IrArena* arena, Builtin builtin); +Builtin shd_get_builtin_by_name(String s); + +typedef enum SpvBuiltIn_ SpvBuiltIn; +Builtin shd_get_builtin_by_spv_id(SpvBuiltIn id); + +bool shd_is_decl_builtin(const Node* decl); +Builtin shd_get_decl_builtin(const Node* decl); + +int32_t shd_get_builtin_spv_id(Builtin builtin); + +bool shd_is_builtin_load_op(const Node* n, Builtin* out); + +const Node* shd_get_builtin(Module* m, Builtin b); +const Node* shd_get_or_create_builtin(Module* m, Builtin b, String n); +const Node* shd_bld_builtin_load(Module* m, BodyBuilder* bb, Builtin b); + +#endif diff --git a/include/shady/ir/cast.h b/include/shady/ir/cast.h new file mode 100644 index 000000000..6d660ed15 --- /dev/null +++ b/include/shady/ir/cast.h @@ -0,0 +1,13 @@ +#ifndef SHADY_IR_CAST_H +#define SHADY_IR_CAST_H + +#include "shady/ir/base.h" +#include "shady/ir/builder.h" + +const Node* shd_bld_reinterpret_cast(BodyBuilder* bb, const Type* dst, const Node* src); +const Node* shd_bld_conversion(BodyBuilder* bb, const Type* dst, const Node* src); + +bool shd_is_reinterpret_cast_legal(const Type* src_type, const 
Type* dst_type); +bool shd_is_conversion_legal(const Type* src_type, const Type* dst_type); + +#endif diff --git a/include/shady/ir/composite.h b/include/shady/ir/composite.h new file mode 100644 index 000000000..e34eab39c --- /dev/null +++ b/include/shady/ir/composite.h @@ -0,0 +1,17 @@ +#ifndef SHADY_IR_COMPOSITE_H +#define SHADY_IR_COMPOSITE_H + +#include "shady/ir/grammar.h" + +const Node* shd_maybe_tuple_helper(IrArena* a, Nodes values); +const Node* shd_tuple_helper(IrArena*, Nodes contents); + +const Node* shd_extract_helper(IrArena* a, const Node* base, Nodes selectors); +const Node* shd_extract_single_helper(IrArena* a, const Node* composite, const Node* index); + +void shd_enter_composite_type(const Type** datatype, bool* uniform, const Node* selector, bool allow_entering_pack); +void shd_enter_composite_type_indices(const Type** datatype, bool* uniform, Nodes indices, bool allow_entering_pack); + +Nodes shd_deconstruct_composite(IrArena* a, const Node* value, size_t outputs_count); + +#endif diff --git a/include/shady/ir/debug.h b/include/shady/ir/debug.h new file mode 100644 index 000000000..694ba0178 --- /dev/null +++ b/include/shady/ir/debug.h @@ -0,0 +1,15 @@ +#ifndef SHADY_IR_DEBUG_H +#define SHADY_IR_DEBUG_H + +#include "shady/ir/base.h" +#include "shady/ir/builder.h" + +/// Get the name out of a global variable, function or constant +String shd_get_value_name_safe(const Node* v); +String shd_get_value_name_unsafe(const Node* v); +void shd_set_value_name(const Node* var, String name); + +void shd_bld_comment(BodyBuilder* bb, String str); +void shd_bld_debug_printf(BodyBuilder* bb, String pattern, Nodes args); + +#endif diff --git a/include/shady/ir/decl.h b/include/shady/ir/decl.h new file mode 100644 index 000000000..6a2755da0 --- /dev/null +++ b/include/shady/ir/decl.h @@ -0,0 +1,16 @@ +#ifndef SHADY_IR_DECL_H +#define SHADY_IR_DECL_H + +#include "shady/ir/base.h" +#include "shady/ir/grammar.h" + +Node* _shd_constant(Module*, Nodes annotations, 
const Type*, const char* name); +Node* _shd_global_var(Module*, Nodes annotations, const Type*, String, AddressSpace); + +static inline Node* constant(Module* m, Nodes annotations, const Type* t, const char* name) { return _shd_constant(m, annotations, t, name); } +static inline Node* global_var(Module* m, Nodes annotations, const Type* t, String name, AddressSpace as) { return _shd_global_var(m, annotations, t, name, as); } + +typedef struct Rewriter_ Rewriter; +const Node* shd_find_or_process_decl(Rewriter* rewriter, const char* name); + +#endif diff --git a/include/shady/ir/execution_model.h b/include/shady/ir/execution_model.h new file mode 100644 index 000000000..3586d730c --- /dev/null +++ b/include/shady/ir/execution_model.h @@ -0,0 +1,18 @@ +#ifndef SHADY_IR_EXECUTION_MODEL_H +#define SHADY_IR_EXECUTION_MODEL_H + +#define EXECUTION_MODELS(EM) \ +EM(Compute, 1) \ +EM(Fragment, 0) \ +EM(Vertex, 0) \ + +typedef enum { + EmNone, +#define EM(name, _) Em##name, + EXECUTION_MODELS(EM) +#undef EM +} ExecutionModel; + +ExecutionModel shd_execution_model_from_string(const char*); + +#endif diff --git a/include/shady/ir/ext.h b/include/shady/ir/ext.h new file mode 100644 index 000000000..75faf2166 --- /dev/null +++ b/include/shady/ir/ext.h @@ -0,0 +1,9 @@ +#ifndef SHADY_IR_EXT_H +#define SHADY_IR_EXT_H + +#include "shady/ir/base.h" +#include "shady/ir/builder.h" + +const Node* shd_bld_ext_instruction(BodyBuilder* bb, String set, int opcode, const Type* return_t, Nodes operands); + +#endif diff --git a/include/shady/ir/float.h b/include/shady/ir/float.h new file mode 100644 index 000000000..4d32d36d8 --- /dev/null +++ b/include/shady/ir/float.h @@ -0,0 +1,23 @@ +#ifndef SHADY_IR_FLOAT_H +#define SHADY_IR_FLOAT_H + +#include "shady/ir/grammar.h" + +static inline int float_size_in_bytes(FloatSizes s) { + switch (s) { + case FloatTy16: return 2; + case FloatTy32: return 4; + case FloatTy64: return 8; + } +} + +const Type* shd_fp16_type(IrArena* arena); +const Type* 
shd_fp32_type(IrArena* arena); +const Type* shd_fp64_type(IrArena* arena); + +const Node* shd_fp_literal_helper(IrArena* a, FloatSizes size, double value); + +const FloatLiteral* shd_resolve_to_float_literal(const Node* node); +double shd_get_float_literal_value(FloatLiteral literal); + +#endif diff --git a/include/shady/ir/function.h b/include/shady/ir/function.h new file mode 100644 index 000000000..126eca4a6 --- /dev/null +++ b/include/shady/ir/function.h @@ -0,0 +1,32 @@ +#ifndef SHADY_IR_FUNCTION_H +#define SHADY_IR_FUNCTION_H + +#include "shady/ir/grammar.h" +#include "shady/ir/type.h" + +Node* _shd_param(IrArena*, const Type* type, const char* name); +Node* _shd_function(Module*, Nodes params, const char* name, Nodes annotations, Nodes return_types); +Node* _shd_basic_block(IrArena*, Nodes params, const char* name); + +static inline Node* param(IrArena* a, const Type* type, const char* name) { return _shd_param(a, type, name); } +static inline Node* function(Module* m, Nodes params, const char* name, Nodes annotations, Nodes return_types) { return _shd_function(m, params, name, annotations, return_types); } +static inline Node* basic_block(IrArena* a, Nodes params, const char* name) { return _shd_basic_block(a, params, name); } +static inline Node* case_(IrArena* a, Nodes params) { return basic_block(a, params, NULL); } + +/// For typing instructions that return nothing (equivalent to C's void f()) +static inline const Type* empty_multiple_return_type(IrArena* arena) { + return shd_as_qualified_type(unit_type(arena), true); +} + +inline static bool is_function(const Node* node) { return node->tag == Function_TAG; } + +const Node* shd_get_abstraction_mem(const Node* abs); +String shd_get_abstraction_name(const Node* abs); +String shd_get_abstraction_name_unsafe(const Node* abs); +String shd_get_abstraction_name_safe(const Node* abs); + +void shd_set_abstraction_body(Node* abs, const Node* body); + +Nodes shd_bld_call(BodyBuilder* bb, const Node* callee, Nodes 
args); + +#endif diff --git a/include/shady/ir/grammar.h b/include/shady/ir/grammar.h new file mode 100644 index 000000000..32c1e1be3 --- /dev/null +++ b/include/shady/ir/grammar.h @@ -0,0 +1,46 @@ +#ifndef SHADY_IR_GRAMMAR_H +#define SHADY_IR_GRAMMAR_H + +#include "shady/ir/base.h" + +// These enums and structs are used in the node payloads so they must live here +// instead of in the relevant header + +typedef enum { + IntTy8, + IntTy16, + IntTy32, + IntTy64, +} IntSizes; + +enum { + IntSizeMin = IntTy8, + IntSizeMax = IntTy64, +}; + +typedef enum { + FloatTy16, + FloatTy32, + FloatTy64 +} FloatSizes; + +typedef enum { + NotSpecial, + /// for instructions with multiple yield values. Must be deconstructed by a let, cannot appear anywhere else + MultipleReturn, + /// Gets the 'Block' SPIR-V annotation, needed for UBO/SSBO variables + DecorateBlock +} RecordSpecialFlag; + +typedef struct BodyBuilder_ BodyBuilder; + +// see primops.json +#include "primops_generated.h" + +// see grammar.json +#include "grammar_generated.h" + +bool shd_is_node_nominal(const Node* node); +const char* shd_get_node_tag_string(NodeTag tag); + +#endif diff --git a/include/shady/ir/int.h b/include/shady/ir/int.h new file mode 100644 index 000000000..8248fae7a --- /dev/null +++ b/include/shady/ir/int.h @@ -0,0 +1,48 @@ +#ifndef SHADY_IR_INT_H +#define SHADY_IR_INT_H + +#include "shady/ir/grammar.h" +#include "shady/ir/builder.h" + +static inline int int_size_in_bytes(IntSizes s) { + switch (s) { + case IntTy8: return 1; + case IntTy16: return 2; + case IntTy32: return 4; + case IntTy64: return 8; + } +} + +const Type* shd_int_type_helper(IrArena* a, bool s, IntSizes w); + +const Type* shd_int8_type(IrArena* arena); +const Type* shd_int16_type(IrArena* arena); +const Type* shd_int32_type(IrArena* arena); +const Type* shd_int64_type(IrArena* arena); + +const Type* shd_uint8_type(IrArena* arena); +const Type* shd_uint16_type(IrArena* arena); +const Type* shd_uint32_type(IrArena* arena); +const 
Type* shd_uint64_type(IrArena* arena); + +const Node* shd_int8_literal(IrArena* arena, int8_t i); +const Node* shd_int16_literal(IrArena* arena, int16_t i); +const Node* shd_int32_literal(IrArena* arena, int32_t i); +const Node* shd_int64_literal(IrArena* arena, int64_t i); + +const Node* shd_uint8_literal(IrArena* arena, uint8_t u); +const Node* shd_uint16_literal(IrArena* arena, uint16_t u); +const Node* shd_uint32_literal(IrArena* arena, uint32_t u); +const Node* shd_uint64_literal(IrArena* arena, uint64_t u); + +const IntLiteral* shd_resolve_to_int_literal(const Node* node); +int64_t shd_get_int_literal_value(IntLiteral literal, bool sign_extend); + +int64_t shd_get_int_value(const Node* node, bool sign_extend); + +const Node* shd_bld_convert_int_extend_according_to_src_t(BodyBuilder* bb, const Type* dst_type, const Node* src); +const Node* shd_bld_convert_int_extend_according_to_dst_t(BodyBuilder* bb, const Type* dst_type, const Node* src); +const Node* shd_bld_convert_int_zero_extend(BodyBuilder* bb, const Type* dst_type, const Node* src); +const Node* shd_bld_convert_int_sign_extend(BodyBuilder* bb, const Type* dst_type, const Node* src); + +#endif diff --git a/include/shady/ir/mem.h b/include/shady/ir/mem.h new file mode 100644 index 000000000..452fc5aed --- /dev/null +++ b/include/shady/ir/mem.h @@ -0,0 +1,15 @@ +#ifndef SHADY_IR_MEM_H +#define SHADY_IR_MEM_H + +#include "shady/ir/base.h" + +const Node* shd_get_parent_mem(const Node* mem); +const Node* shd_get_original_mem(const Node* mem); + +const Node* shd_bld_stack_alloc(BodyBuilder* bb, const Type* type); +const Node* shd_bld_local_alloc(BodyBuilder* bb, const Type* type); + +const Node* shd_bld_load(BodyBuilder* bb, const Node* ptr); +void shd_bld_store(BodyBuilder* bb, const Node* ptr, const Node* value); + +#endif diff --git a/include/shady/ir/memory_layout.h b/include/shady/ir/memory_layout.h new file mode 100644 index 000000000..543896f8b --- /dev/null +++ b/include/shady/ir/memory_layout.h @@ 
-0,0 +1,41 @@ +#ifndef SHADY_MEMORY_LAYOUT_H +#define SHADY_MEMORY_LAYOUT_H + +#include "shady/ir/base.h" +#include "shady/ir/grammar.h" +#include "shady/config.h" + +typedef struct { + const Type* type; + size_t size_in_bytes; + size_t alignment_in_bytes; +} TypeMemLayout; + +typedef struct { + TypeMemLayout mem_layout; + size_t offset_in_bytes; +} FieldLayout; + +TypeMemLayout shd_get_mem_layout(IrArena* a, const Type* type); + +TypeMemLayout shd_get_record_layout(IrArena* a, const Node* record_type, FieldLayout* fields); +size_t shd_get_record_field_offset_in_bytes(IrArena* a, const Type* t, size_t i); + +static inline const Node* size_t_type(IrArena* a) { + return int_type(a, (Int) { .width = shd_get_arena_config(a)->memory.ptr_size, .is_signed = false }); +} + +static inline const Node* size_t_literal(IrArena* a, uint64_t value) { + return int_literal(a, (IntLiteral) { .width = shd_get_arena_config(a)->memory.ptr_size, .is_signed = false, .value = value }); +} + +const Node* shd_bytes_to_words(BodyBuilder* bb, const Node* bytes); +uint64_t shd_bytes_to_words_static(const IrArena* a, uint64_t bytes); +IntSizes shd_float_to_int_width(FloatSizes width); + +size_t shd_get_type_bitwidth(const Type* t); + +const Node* _shd_lea_helper(IrArena* a, const Node* ptr, const Node* offset, Nodes indices); +static inline const Node* lea_helper(IrArena* a, const Node* ptr, const Node* offset, Nodes indices) { return _shd_lea_helper(a, ptr, offset, indices); } + +#endif diff --git a/include/shady/ir/module.h b/include/shady/ir/module.h new file mode 100644 index 000000000..0dac817d7 --- /dev/null +++ b/include/shady/ir/module.h @@ -0,0 +1,15 @@ +#ifndef SHADY_IR_MODULE_H +#define SHADY_IR_MODULE_H + +#include "shady/ir/base.h" + +Module* shd_new_module(IrArena* arena, String name); + +IrArena* shd_module_get_arena(const Module* m); +String shd_module_get_name(const Module* m); +Nodes shd_module_get_declarations(const Module* m); +Node* shd_module_get_declaration(const Module* 
m, String name); + +void shd_module_link(Module* dst, Module* src); + +#endif diff --git a/include/shady/ir/primop.h b/include/shady/ir/primop.h new file mode 100644 index 000000000..643ff0870 --- /dev/null +++ b/include/shady/ir/primop.h @@ -0,0 +1,11 @@ +#ifndef SHADY_IR_PRIMOP_H +#define SHADY_IR_PRIMOP_H + +#include "shady/ir/grammar.h" + +OpClass shd_get_primop_class(Op op); + +String shd_get_primop_name(Op op); +bool shd_has_primop_got_side_effects(Op op); + +#endif diff --git a/include/shady/ir/stack.h b/include/shady/ir/stack.h new file mode 100644 index 000000000..690e4bd8f --- /dev/null +++ b/include/shady/ir/stack.h @@ -0,0 +1,14 @@ +#ifndef SHADY_IR_STACK_H +#define SHADY_IR_STACK_H + +#include "shady/ir/base.h" +#include "shady/ir/builder.h" + +void shd_bld_stack_push_value(BodyBuilder* bb, const Node* value); +void shd_bld_stack_push_values(BodyBuilder* bb, Nodes values); +const Node* shd_bld_stack_pop_value(BodyBuilder* bb, const Type* type); +const Node* shd_bld_get_stack_base_addr(BodyBuilder* bb); +const Node* shd_bld_get_stack_size(BodyBuilder* bb); +void shd_bld_set_stack_size(BodyBuilder* bb, const Node* new_size); + +#endif diff --git a/include/shady/ir/type.h b/include/shady/ir/type.h new file mode 100644 index 000000000..9c4f9471e --- /dev/null +++ b/include/shady/ir/type.h @@ -0,0 +1,97 @@ +#ifndef SHADY_IR_TYPE_H +#define SHADY_IR_TYPE_H + +#include "shady/ir/grammar.h" + +/// Unit type, carries no information (equivalent to C's void) +/// There is exactly one possible value of this type: () +static inline const Node* unit_type(IrArena* arena) { + return record_type(arena, (RecordType) { + .members = shd_empty(arena), + }); +} + +Type* _shd_nominal_type(Module*, Nodes annotations, String name); +static inline Type* nominal_type(Module* m, Nodes annotations, String name) { return _shd_nominal_type(m, annotations, name); } + +const Type* shd_get_actual_mask_type(IrArena* arena); +const Node* shd_get_default_value(IrArena* a, const Type* t); 
+ +bool shd_is_subtype(const Type* supertype, const Type* type); +void shd_check_subtype(const Type* supertype, const Type* type); + +/// Is this a type that a value in the language can have ? +bool shd_is_value_type(const Type*); + +/// Is this a valid data type (for usage in other types and as type arguments) ? +bool shd_is_data_type(const Type*); + +bool shd_is_arithm_type(const Type*); +bool shd_is_shiftable_type(const Type*); +bool shd_has_boolean_ops(const Type*); +bool shd_is_comparable_type(const Type*); +bool shd_is_ordered_type(const Type*); +bool shd_is_physical_ptr_type(const Type* t); +bool shd_is_generic_ptr_type(const Type* t); + +/// Returns the (possibly qualified) pointee type from a (possibly qualified) ptr type +const Type* shd_get_pointee_type(IrArena*, const Type*); + +String shd_get_address_space_name(AddressSpace); +/// Returns false iff pointers in that address space can contain different data at the same address +/// (amongst threads in the same subgroup) +bool shd_is_addr_space_uniform(IrArena*, AddressSpace); + +String shd_get_type_name(IrArena* arena, const Type* t); + +const Type* shd_maybe_multiple_return(IrArena* arena, Nodes types); +Nodes shd_unwrap_multiple_yield_types(IrArena* arena, const Type* type); + +/// Collects the annotated types in the list of variables +/// NB: this is different from get_values_types, that function uses node.type, whereas this one uses node.payload.var.type +/// This means this function works in untyped modules where node.type is NULL. 
+Nodes shd_get_param_types(IrArena* arena, Nodes variables); + +Nodes shd_get_values_types(IrArena*, Nodes); + +// Qualified type helpers +/// Ensures an operand has divergence-annotated type and extracts it +const Type* shd_get_unqualified_type(const Type*); +bool shd_is_qualified_type_uniform(const Type*); +bool shd_deconstruct_qualified_type(const Type**); + +const Type* shd_as_qualified_type(const Type* type, bool uniform); + +Nodes shd_strip_qualifiers(IrArena*, Nodes); +Nodes shd_add_qualifiers(IrArena*, Nodes, bool); + +// Pack (vector) type helpers +const Type* shd_get_packed_type_element(const Type* type); +size_t shd_get_packed_type_width(const Type* type); +size_t shd_deconstruct_packed_type(const Type** type); + +/// Helper for creating pack types, wraps type in a pack_type if width > 1 +const Type* shd_maybe_packed_type_helper(const Type* type, size_t width); + +/// 'Maybe' variants that work with any types, and assume width=1 for non-packed types +/// Useful for writing generic type checking code ! 
+const Type* shd_get_maybe_packed_type_element(const Type* type); +size_t shd_get_maybe_packed_type_width(const Type* type); +size_t shd_deconstruct_maybe_packed_type(const Type** type); + +// Pointer type helpers +const Type* shd_get_pointer_type_element(const Type* type); +AddressSpace shd_deconstruct_pointer_type(const Type** type); + +// Nominal type helpers +const Node* shd_get_nominal_type_decl(const Type* type); +const Type* shd_get_nominal_type_body(const Type* type); +const Node* shd_get_maybe_nominal_type_decl(const Type* type); +const Type* shd_get_maybe_nominal_type_body(const Type* type); + +// Composite type helpers +Nodes shd_get_composite_type_element_types(const Type* type); +const Node* shd_get_fill_type_element_type(const Type* composite_t); +const Node* shd_get_fill_type_size(const Type* composite_t); + +#endif diff --git a/include/shady/pass.h b/include/shady/pass.h new file mode 100644 index 000000000..1db0a2515 --- /dev/null +++ b/include/shady/pass.h @@ -0,0 +1,19 @@ +#ifndef SHADY_PASS_H +#define SHADY_PASS_H + +#include "shady/ir/arena.h" +#include "shady/ir/module.h" +#include "shady/config.h" +#include "shady/rewrite.h" + +typedef Module* (RewritePass)(const CompilerConfig* config, Module* src); +typedef bool (OptPass)(const CompilerConfig* config, Module** m); + +void shd_run_pass_impl(const CompilerConfig* config, Module** pmod, IrArena* initial_arena, RewritePass pass, String pass_name); +#define RUN_PASS(pass_name) shd_run_pass_impl(config, pmod, initial_arena, pass_name, #pass_name); + +void shd_apply_opt_impl(const CompilerConfig* config, bool* todo, Module** m, OptPass pass, String pass_name); +#define APPLY_OPT(pass_name) shd_apply_opt_impl(config, &todo, &m, pass_name, #pass_name); + +#endif + diff --git a/include/shady/primops.json b/include/shady/primops.json index 1caf963a2..177b04418 100644 --- a/include/shady/primops.json +++ b/include/shady/primops.json @@ -15,24 +15,12 @@ { "name": "math" }, - { - "name": "stack" - }, - { 
- "name": "memory" - }, { "name": "memory_layout" }, { "name": "subgroup_intrinsic" }, - { - "name": "ast" - }, - { - "name": "join_point" - }, { "name": "mask" } @@ -66,6 +54,10 @@ "name": "div", "class": "arithmetic" }, + { + "name": "fma", + "class": "arithmetic" + }, { "name": "mod", "class": "arithmetic" @@ -182,68 +174,6 @@ "name": "cos", "class": "math" }, - { - "name": "push_stack", - "class": "stack", - "side-effects": true - }, - { - "name": "pop_stack", - "class": "stack", - "side-effects": true - }, - { - "name": "get_stack_pointer", - "class": "stack" - }, - { - "name": "get_stack_base", - "class": "stack" - }, - { - "name": "set_stack_pointer", - "class": "stack", - "side-effects": true - }, - { - "name": "alloca", - "class": "memory", - "side-effects": true - }, - { - "name": "alloca_logical", - "class": "memory", - "side-effects": true - }, - { - "name": "alloca_subgroup", - "class": "memory", - "side-effects": true - }, - { - "name": "load", - "class": "memory", - "side-effects": true - }, - { - "name": "store", - "class": "memory", - "side-effects": true - }, - { - "name": "lea", - "class": "memory" - }, - { - "name": "memcpy", - "class": "memory", - "side-effects": true - }, - { - "name": "memset", - "class": "memory", - "side-effects": true - }, { "name": "size_of", "class": "memory_layout" @@ -256,53 +186,10 @@ "name": "offset_of", "class": "memory_layout" }, - { - "name": "subgroup_elect_first", - "class": "subgroup_intrinsic" - }, - { - "name": "subgroup_broadcast_first", - "class": "subgroup_intrinsic" - }, { "name": "subgroup_assume_uniform", "class": "subgroup_intrinsic" }, - { - "name": "subgroup_reduce_sum", - "class": "subgroup_intrinsic" - }, - { - "name": "subgroup_active_mask", - "class": "subgroup_intrinsic" - }, - { - "name": "subgroup_ballot", - "class": "subgroup_intrinsic" - }, - { - "name": "assign", - "class": "ast", - "side-effects": true - }, - { - "name": "subscript", - "class": "ast", - "side-effects": true - }, - { - 
"name": "deref", - "class": "ast", - "side-effects": true - }, - { - "name": "addrof", - "class": "ast", - "side-effects": true - }, - { - "name": "quote" - }, { "name": "select" }, @@ -324,24 +211,10 @@ { "name": "shuffle" }, - { - "name": "debug_printf", - "side-effects": true - }, { "name": "sample_texture", "side-effects": true }, - { - "name": "create_joint_point", - "class": "join_point", - "side-effects": true - }, - { - "name": "default_join_point", - "class": "join_point", - "side-effects": true - }, { "name": "empty_mask", "class": "mask" @@ -350,5 +223,144 @@ "name": "mask_is_thread_active", "class": "mask" } + ], + "nodes": [ + { + "name": "PrimOp", + "class": ["instruction", "value"], + "ops": [ + { "name": "op", "type": "Op" }, + { "name": "type_arguments", "class": "type", "list": true }, + { "name": "operands", "class": "value", "list": true } + ] + }, + { + "name": "StackAlloc", + "class": ["instruction", "value", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "type", "class": "type" } + ] + }, + { + "name": "LocalAlloc", + "class": ["instruction", "value", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "type", "class": "type" } + ] + }, + { + "name": "Load", + "class": ["instruction", "value", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "ptr", "class": "value" } + ] + }, + { + "name": "Store", + "class": ["instruction", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "ptr", "class": "value" }, + { "name": "value", "class": "value" } + ] + }, + { + "name": "PtrCompositeElement", + "class": ["instruction", "value"], + "ops": [ + { "name": "ptr", "class": "value" }, + { "name": "index", "class": "value" } + ] + }, + { + "name": "PtrArrayElementOffset", + "class": ["instruction", "value"], + "ops": [ + { "name": "ptr", "class": "value" }, + { "name": "offset", "class": "value" } + ] + }, + { 
+ "name": "CopyBytes", + "class": ["instruction", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "dst", "class": "value" }, + { "name": "src", "class": "value" }, + { "name": "count", "class": "value" } + ] + }, + { + "name": "FillBytes", + "class": ["instruction", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "dst", "class": "value" }, + { "name": "src", "class": "value" }, + { "name": "count", "class": "value" } + ] + }, + { + "name": "PushStack", + "class": ["instruction", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "value", "class": "value" } + ] + }, + { + "name": "PopStack", + "class": ["instruction", "value", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "type", "class": "type" } + ] + }, + { + "name": "GetStackSize", + "class": ["instruction", "value", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" } + ] + }, + { + "name": "SetStackSize", + "class": ["instruction", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "value", "class": "value" } + ] + }, + { + "name": "GetStackBaseAddr", + "class": ["instruction", "value"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" } + ] + }, + { + "name": "DebugPrintf", + "class": ["instruction", "mem"], + "nominal": true, + "ops": [ + { "name": "mem", "class": "mem" }, + { "name": "string", "class": "string" }, + { "name": "args", "class": "value", "list": true } + ] + } ] } \ No newline at end of file diff --git a/include/shady/print.h b/include/shady/print.h new file mode 100644 index 000000000..76fbd3f74 --- /dev/null +++ b/include/shady/print.h @@ -0,0 +1,25 @@ +#ifndef SHADY_PRINT +#define SHADY_PRINT + +#include "shady/ir/base.h" +#include "shady/be/dump.h" + +#include "printer.h" + +typedef struct { + bool print_builtin; + bool print_internal; + bool print_generated; 
+ bool print_ptrs; + bool color; + bool reparseable; + bool in_cfg; +} NodePrintConfig; + +void shd_print_module_into_str(Module* mod, char** str_ptr, size_t* size); +void shd_print_node_into_str(const Node* node, char** str_ptr, size_t* size); + +void shd_print_module(Printer* printer, NodePrintConfig config, Module* mod); +void shd_print_node(Printer* printer, NodePrintConfig config, const Node* node); + +#endif \ No newline at end of file diff --git a/include/shady/rewrite.h b/include/shady/rewrite.h new file mode 100644 index 000000000..fbe41a43d --- /dev/null +++ b/include/shady/rewrite.h @@ -0,0 +1,74 @@ +#ifndef SHADY_REWRITE_H +#define SHADY_REWRITE_H + +#include "shady/ir/grammar.h" + +typedef struct Rewriter_ Rewriter; + +typedef const Node* (*RewriteNodeFn)(Rewriter*, const Node*); +typedef const Node* (*RewriteOpFn)(Rewriter*, NodeClass, String, const Node*); + +const Node* shd_rewrite_node(Rewriter* rewriter, const Node* node); +const Node* shd_rewrite_node_with_fn(Rewriter* rewriter, const Node* node, RewriteNodeFn fn); + +const Node* shd_rewrite_op(Rewriter* rewriter, NodeClass class, String op_name, const Node* node); +const Node* shd_rewrite_op_with_fn(Rewriter* rewriter, NodeClass class, String op_name, const Node* node, RewriteOpFn fn); + +/// Applies the rewriter to all nodes in the collection +Nodes shd_rewrite_nodes(Rewriter* rewriter, Nodes old_nodes); +Nodes shd_rewrite_nodes_with_fn(Rewriter* rewriter, Nodes values, RewriteNodeFn fn); + +Nodes shd_rewrite_ops(Rewriter* rewriter, NodeClass class, String op_name, Nodes old_nodes); +Nodes shd_rewrite_ops_with_fn(Rewriter* rewriter, NodeClass class, String op_name, Nodes values, RewriteOpFn fn); + +struct Rewriter_ { + RewriteNodeFn rewrite_fn; + RewriteOpFn rewrite_op_fn; + IrArena* src_arena; + IrArena* dst_arena; + Module* src_module; + Module* dst_module; + struct { + bool search_map; + bool write_map; + } config; + + Rewriter* parent; + + struct Dict* map; + bool own_decls; + struct Dict* 
decls_map; +}; + +Rewriter shd_create_rewriter_base(Module* src, Module* dst); +Rewriter shd_create_node_rewriter(Module* src, Module* dst, RewriteNodeFn fn); +Rewriter shd_create_op_rewriter(Module* src, Module* dst, RewriteOpFn fn); +Rewriter shd_create_importer(Module* src, Module* dst); + +Rewriter shd_create_children_rewriter(Rewriter* parent); +Rewriter shd_create_decl_rewriter(Rewriter* parent); +void shd_destroy_rewriter(Rewriter* r); + +void shd_rewrite_module(Rewriter* rewriter); + +/// Rewrites a node using the rewriter to provide the node and type operands +const Node* shd_recreate_node(Rewriter* rewriter, const Node* node); + +/// Rewrites a constant / function header +Node* shd_recreate_node_head(Rewriter* rewriter, const Node* old); +void shd_recreate_node_body(Rewriter* rewriter, const Node* old, Node* new); + +/// Rewrites a variable under a new identity +const Node* shd_recreate_param(Rewriter* rewriter, const Node* old); +Nodes shd_recreate_params(Rewriter* rewriter, Nodes oparams); + +/// Looks up if the node was already processed +const Node** shd_search_processed(const Rewriter* ctx, const Node* old); +/// Same as shd_search_processed but asserts if it fails to find a mapping +const Node* shd_find_processed(const Rewriter* ctx, const Node* old); +void shd_register_processed(Rewriter* ctx, const Node* old, const Node* new); +void shd_register_processed_list(Rewriter* rewriter, Nodes old, Nodes new); + +void shd_dump_rewriter_map(Rewriter* r); + +#endif diff --git a/include/shady/runtime.h b/include/shady/runtime.h index 0604dc39e..68aa75a80 100644 --- a/include/shady/runtime.h +++ b/include/shady/runtime.h @@ -11,39 +11,44 @@ typedef struct { bool allow_no_devices; } RuntimeConfig; +RuntimeConfig shd_rt_default_config(); +void shd_rt_cli_parse_runtime_config(RuntimeConfig* config, int* pargc, char** argv); + typedef struct Runtime_ Runtime; typedef struct Device_ Device; typedef struct Program_ Program; -typedef struct Command_ Command; 
+typedef struct Command_ Command; typedef struct Buffer_ Buffer; -Runtime* initialize_runtime(RuntimeConfig config); -void shutdown_runtime(Runtime*); +Runtime* shd_rt_initialize(RuntimeConfig config); +void shd_rt_shutdown(Runtime* runtime); -size_t device_count(Runtime*); -Device* get_device(Runtime*, size_t i); -Device* get_an_device(Runtime*); -const char* get_device_name(Device*); +size_t shd_rt_device_count(Runtime* r); +Device* shd_rt_get_device(Runtime* r, size_t i); +Device* shd_rt_get_an_device(Runtime* r); +const char* shd_rt_get_device_name(Device* d); typedef struct CompilerConfig_ CompilerConfig; typedef struct Module_ Module; -Program* new_program_from_module(Runtime*, const CompilerConfig*, Module*); -Program* load_program(Runtime*, const CompilerConfig*, const char* program_src); -Program* load_program_from_disk(Runtime*, const CompilerConfig*, const char* path); +Program* shd_rt_new_program_from_module(Runtime* runtime, const CompilerConfig* base_config, Module* mod); + +typedef struct { + uint64_t* profiled_gpu_time; +} ExtraKernelOptions; -Command* launch_kernel(Program*, Device*, const char* entry_point, int dimx, int dimy, int dimz, int args_count, void** args); -bool wait_completion(Command*); +Command* shd_rt_launch_kernel(Program* p, Device* d, const char* entry_point, int dimx, int dimy, int dimz, int args_count, void** args, ExtraKernelOptions* extra_options); +bool shd_rt_wait_completion(Command* cmd); -Buffer* allocate_buffer_device(Device*, size_t); -bool can_import_host_memory(Device*); -Buffer* import_buffer_host(Device*, void*, size_t); -void destroy_buffer(Buffer*); +Buffer* shd_rt_allocate_buffer_device(Device* device, size_t bytes); +bool shd_rt_can_import_host_memory(Device* device); +Buffer* shd_rt_import_buffer_host(Device* device, void* ptr, size_t bytes); +void shd_rt_destroy_buffer(Buffer* buf); -void* get_buffer_host_pointer(Buffer* buf); -uint64_t get_buffer_device_pointer(Buffer* buf); +void* 
shd_rt_get_buffer_host_pointer(Buffer* buf); +uint64_t shd_rt_get_buffer_device_pointer(Buffer* buf); -bool copy_to_buffer(Buffer* dst, size_t buffer_offset, void* src, size_t size); -bool copy_from_buffer(Buffer* src, size_t buffer_offset, void* dst, size_t size); +bool shd_rt_copy_to_buffer(Buffer* dst, size_t buffer_offset, void* src, size_t size); +bool shd_rt_copy_from_buffer(Buffer* src, size_t buffer_offset, void* dst, size_t size); #endif diff --git a/include/shady/spv_imports.json b/include/shady/spv_imports.json new file mode 100644 index 000000000..0c4a105ed --- /dev/null +++ b/include/shady/spv_imports.json @@ -0,0 +1,47 @@ +{ + "instruction-filters": [ + { + "operand-filters": [ + { + "import": "yes" + }, + { + "filter-kind": { "IdResult": {}}, + "import": "no" + }, + { + "filter-kind": { "IdRef": {} }, + "overlay": { + "class": "value" + } + }, + { + "filter-kind": { "Dim": {}, "LiteralInteger": {}, "ImageFormat": {} }, + "overlay": { + "type": "uint32_t" + } + } + ] + }, + { + "filter-name": { "OpTypeSampler": {}, "OpTypeImage": {}, "OpTypeSampledImage": {} }, + "import": "yes", + "operand-filters": [ + { + "filter-kind": { "AccessQualifier": {}}, + "import": "no" + }, + { + "filter-kind": { "IdRef": {}}, + "overlay": { + "class": "type" + } + } + ], + "overlay": { + "class": "type", + "type": false + } + } + ] +} \ No newline at end of file diff --git a/include/shady/visit.h b/include/shady/visit.h new file mode 100644 index 000000000..c8e9d1b03 --- /dev/null +++ b/include/shady/visit.h @@ -0,0 +1,28 @@ +#ifndef SHADY_VISIT_H +#define SHADY_VISIT_H + +#include "shady/ir/grammar.h" + +typedef struct Visitor_ Visitor; +typedef void (*VisitNodeFn)(Visitor*, const Node*); +typedef void (*VisitOpFn)(Visitor*, NodeClass, String, const Node*, size_t); + +struct Visitor_ { + VisitNodeFn visit_node_fn; + VisitOpFn visit_op_fn; +}; + +void shd_visit_node_operands(Visitor* visitor, NodeClass exclude, const Node* node); +void shd_visit_module(Visitor* visitor, 
Module* mod); + +void shd_visit_node(Visitor* visitor, const Node* node); +void shd_visit_nodes(Visitor* visitor, Nodes nodes); + +void shd_visit_op(Visitor* visitor, NodeClass op_class, String op_name, const Node* op, size_t i); +void shd_visit_ops(Visitor* visitor, NodeClass op_class, String op_name, Nodes ops); + +// visits the abstractions in the function, starting with the entry block (ie the function itself) +void shd_visit_function_rpo(Visitor* visitor, const Node* function); +void shd_visit_function_bodies_rpo(Visitor* visitor, const Node* function); + +#endif diff --git a/include/vcc/driver.h b/include/vcc/driver.h new file mode 100644 index 000000000..d41aa3aed --- /dev/null +++ b/include/vcc/driver.h @@ -0,0 +1,22 @@ +#ifndef VCC_DRIVER_H +#define VCC_DRIVER_H + +#include "shady/driver.h" + +typedef struct { + bool delete_tmp_file; + bool only_run_clang; + const char* tmp_filename; + const char* include_path; +} VccConfig; + +void vcc_check_clang(void); + +VccConfig vcc_init_config(CompilerConfig* compiler_config); +void cli_parse_vcc_args(VccConfig* options, int* pargc, char** argv); +void destroy_vcc_options(VccConfig vcc_options); + +void vcc_run_clang(VccConfig* vcc_options, size_t num_source_files, String input_filenames[]); +Module* vcc_parse_back_into_module(CompilerConfig* config, VccConfig* vcc_options, String module_name); + +#endif diff --git a/murmur3 b/murmur3 deleted file mode 160000 index dae94be0c..000000000 --- a/murmur3 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit dae94be0c0f54a399d23ea6cbe54bca5a4e93ce4 diff --git a/readme.md b/readme.md index f05c8a763..0a3f17eee 100644 --- a/readme.md +++ b/readme.md @@ -75,7 +75,7 @@ The grammar is defined in [grammar.json](include/shady/grammar.json), this file ## Language syntax -The textual syntax of the language is C-like in that return types come first. Variance annotations are supported. +The textual syntax of the language is C-like in that return types come first. 
Variance annotations are supported. Overall the language is structurally close to SPIR-V and LLVM, very much on purpose. There is a 'front-end' (slim) variant of the IR that allows for mutable variables and using instructions as values. diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index 2af9c4d93..47f50f483 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -1,4 +1,6 @@ +option(SHADY_ENABLE_SAMPLES "Demo applications and gpu programs" ON) +# TODO: this probably doesn't belong here # find math lib; based on https://stackoverflow.com/a/74867749 find_library(MATH_LIBRARY m) if (MATH_LIBRARY) @@ -8,5 +10,7 @@ else() add_library(m INTERFACE) endif() -add_subdirectory(checkerboard) -add_subdirectory(aobench) +if (SHADY_ENABLE_SAMPLES) + add_subdirectory(checkerboard) + add_subdirectory(aobench) +endif() diff --git a/samples/aobench/CMakeLists.txt b/samples/aobench/CMakeLists.txt index ef8e7503f..ce5f18ced 100644 --- a/samples/aobench/CMakeLists.txt +++ b/samples/aobench/CMakeLists.txt @@ -1,22 +1,21 @@ -find_program(CLANG_EXE "clang") -find_program(LLVM-SPIRV_EXE "llvm-spirv") - -if (CLANG_EXE AND LLVM-SPIRV_EXE) +if (NOT TARGET vcc) + message("Vcc unavailable. Skipping aobench sample.") +elseif(NOT TARGET runtime) + message("Runtime component unavailable. 
Skipping aobench sample.") +else() add_executable(aobench_host ao_host.c ao_main.c) + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ao.comp.c.ll COMMAND vcc ARGS ${CMAKE_CURRENT_SOURCE_DIR}/ao.comp.cpp --only-run-clang -O3 -fno-slp-vectorize -fno-vectorize -o ${CMAKE_CURRENT_BINARY_DIR}/ao.comp.c.ll COMMENT ao.comp.c.ll DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/ao.comp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ao.c) - add_custom_command(OUTPUT ao.cl.ll COMMAND clang ARGS ${CMAKE_CURRENT_SOURCE_DIR}/ao.cl -std=clc++2021 -emit-llvm -o ${CMAKE_CURRENT_BINARY_DIR}/ao.cl.ll -c -target spir64-unknown-unknown -O2 DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/ao.cl ${CMAKE_CURRENT_SOURCE_DIR}/ao.c) - add_custom_command(OUTPUT ao.cl.spv COMMAND llvm-spirv ARGS ${CMAKE_CURRENT_BINARY_DIR}/ao.cl.ll -o ${CMAKE_CURRENT_BINARY_DIR}/ao.cl.spv DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/ao.cl.ll) - - set_property(SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/ao_main.c APPEND PROPERTY OBJECT_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/ao.cl.spv) + set_property(SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/ao_main.c APPEND PROPERTY OBJECT_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/ao.comp.c.ll) add_custom_command(TARGET aobench_host POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - ${CMAKE_CURRENT_BINARY_DIR}/ao.cl.spv - ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ao.cl.spv) + ${CMAKE_CURRENT_BINARY_DIR}/ao.comp.c.ll + ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ao.comp.c.ll) find_program(ISPC_EXE "ispc") if (ISPC_EXE) target_compile_definitions(aobench_host PUBLIC ENABLE_ISPC=1) - add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc COMMAND slim ARGS ao.cl.spv --output ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc --entry-point aobench_kernel COMMENT generating aobench.ispc DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/ao.cl.spv) + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc COMMAND slim ARGS ${CMAKE_CURRENT_BINARY_DIR}/ao.comp.c.ll -o ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc --entry-point aobench_kernel --restructure-everything COMMENT 
generating aobench.ispc DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/ao.comp.c.ll) add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc.o COMMAND ispc ARGS ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc -o ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc.o --pic -g -O2 -woff COMMENT generating aobench.ispc.o DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc) add_library(aobench_ispc OBJECT ${CMAKE_CURRENT_BINARY_DIR}/aobench.ispc.o) set_target_properties(aobench_ispc PROPERTIES LINKER_LANGUAGE C) @@ -24,6 +23,4 @@ if (CLANG_EXE AND LLVM-SPIRV_EXE) endif() target_link_libraries(aobench_host PRIVATE m shady runtime common) -else() - message("Clang and/or llvm-spirv not found. Skipping aobench.") endif() diff --git a/samples/aobench/ao.c b/samples/aobench/ao.c index 1f9010dd8..c5e5b4b46 100644 --- a/samples/aobench/ao.c +++ b/samples/aobench/ao.c @@ -2,7 +2,11 @@ #include "ao.h" -unsigned int FNVHash(char* str, unsigned int length) { +#ifndef FUNCTION +#define FUNCTION +#endif + +FUNCTION unsigned int FNVHash(char* str, unsigned int length) { const unsigned int fnv_prime = 0x811C9DC5; unsigned int hash = 0; unsigned int i = 0; @@ -16,31 +20,31 @@ unsigned int FNVHash(char* str, unsigned int length) { return hash; } -unsigned int nrand(unsigned int* rng) { +FUNCTION unsigned int nrand(unsigned int* rng) { unsigned int orand = *rng; *rng = FNVHash((char*) &orand, 4); return *rng; } -Scalar drand48(Ctx* ctx) { +FUNCTION Scalar drand48(Ctx* ctx) { Scalar n = (nrand(&ctx->rng) / 65536.0f); n = n - floorf(n); return n; } -static Scalar vdot(vec v0, vec v1) +FUNCTION Scalar vdot(vec v0, vec v1) { return v0.x * v1.x + v0.y * v1.y + v0.z * v1.z; } -static void vcross(vec *c, vec v0, vec v1) +FUNCTION void vcross(vec *c, vec v0, vec v1) { c->x = v0.y * v1.z - v0.z * v1.y; c->y = v0.z * v1.x - v0.x * v1.z; c->z = v0.x * v1.y - v0.y * v1.x; } -static void vnormalize(vec *c) +FUNCTION void vnormalize(vec *c) { Scalar length = sqrtf(vdot((*c), (*c))); @@ -51,10 +55,10 @@ static void 
vnormalize(vec *c) } } -static void +FUNCTION void ray_sphere_intersect(Isect *isect, const Ray *ray, const Sphere *sphere) { - vec rs; + vec rs = { 0 }; rs.x = ray->org.x - sphere->center.x; rs.y = ray->org.y - sphere->center.y; @@ -84,7 +88,7 @@ ray_sphere_intersect(Isect *isect, const Ray *ray, const Sphere *sphere) } } -static void +FUNCTION void ray_plane_intersect(Isect *isect, const Ray *ray, const Plane *plane) { Scalar d = -vdot(plane->p, plane->n); @@ -106,7 +110,7 @@ ray_plane_intersect(Isect *isect, const Ray *ray, const Plane *plane) } } -static void +FUNCTION void orthoBasis(vec *basis, vec n) { basis[2] = n; @@ -129,7 +133,7 @@ orthoBasis(vec *basis, vec n) vnormalize(&basis[1]); } -static void ambient_occlusion(Ctx* ctx, vec *col, const Isect *isect) +FUNCTION void ambient_occlusion(Ctx* ctx, vec *col, const Isect *isect) { int i, j; int ntheta = NAO_SAMPLES; @@ -189,7 +193,7 @@ static void ambient_occlusion(Ctx* ctx, vec *col, const Isect *isect) col->z = occlusion; } -unsigned char aobench_clamp(Scalar f) +FUNCTION unsigned char aobench_clamp(Scalar f) { Scalar s = (f * 255.5f); @@ -199,7 +203,13 @@ unsigned char aobench_clamp(Scalar f) return (unsigned char) s; } -void render_pixel(Ctx* ctx, int x, int y, int w, int h, int nsubsamples, unsigned char* img) { +EXTERNAL_FN Ctx get_init_context() { + return (Ctx) { + .rng = 0xFEEFDEED, + }; +} + +EXTERNAL_FN void render_pixel(Ctx* ctx, int x, int y, int w, int h, int nsubsamples, TEXEL_T* img) { Scalar pixel[3] = { 0, 0, 0 }; ctx->rng = x * w + y; @@ -210,7 +220,7 @@ void render_pixel(Ctx* ctx, int x, int y, int w, int h, int nsubsamples, unsigne Scalar px = (x + (u / (Scalar)nsubsamples) - (w / 2.0f)) / (w / 2.0f); Scalar py = -(y + (v / (Scalar)nsubsamples) - (h / 2.0f)) / (h / 2.0f); - Ray ray = {}; + Ray ray = { 0 }; ray.org.x = 0.0f; ray.org.y = 0.0f; @@ -221,7 +231,7 @@ void render_pixel(Ctx* ctx, int x, int y, int w, int h, int nsubsamples, unsigne ray.dir.z = -1.0f; vnormalize(&(ray.dir)); - 
Isect isect = {}; + Isect isect = { 0 }; isect.t = 1.0e+17f; isect.hit = 0; @@ -259,7 +269,7 @@ void render_pixel(Ctx* ctx, int x, int y, int w, int h, int nsubsamples, unsigne img[3 * (y * w + x) + 2] = aobench_clamp(pixel[2]); } -void init_scene(Ctx* ctx) +EXTERNAL_FN void init_scene(Ctx* ctx) { ctx->spheres[0].center.x = -2.0f; ctx->spheres[0].center.y = 0.0f; @@ -285,9 +295,3 @@ void init_scene(Ctx* ctx) ctx->plane.n.z = 0.0f; } - -Ctx get_init_context() { - return (Ctx) { - .rng = 0xFEEFDEED, - }; -} \ No newline at end of file diff --git a/samples/aobench/ao.cl b/samples/aobench/ao.cl deleted file mode 100644 index a93631672..000000000 --- a/samples/aobench/ao.cl +++ /dev/null @@ -1,43 +0,0 @@ -#pragma OPENCL EXTENSION __cl_clang_function_pointers : enable - -#include "ao.c" - -Scalar sqrtf(Scalar s) { return sqrt(s); } -Scalar floorf(Scalar s) { return floor(s); } -Scalar fabsf(Scalar s) { return fabs(s); } -Scalar sinf(Scalar s) { return sin(s); } -Scalar cosf(Scalar s) { return cos(s); } - -global char zero; - -extern "C" { - -void debug_printf_i64(const __constant char*, long int) __asm__("__shady::prim_op::debug_printf::i64"); -void debug_printf_i32_i32(const __constant char*, int, int) __asm__("__shady::prim_op::debug_printf::i32_i32"); - -} - -__attribute__((reqd_work_group_size(16, 16, 1))) -kernel void aobench_kernel(global unsigned char* out) { - int x = get_global_id(0); - int y = get_global_id(1); - - long int ptr = (long int) out; - //debug_printf_i64("ptr: %lu\n", ptr); - //debug_printf_i32_i32("ptr: %d %d\n", (int) (ptr << 32), (int) ptr); - - Ctx ctx = get_init_context(); - init_scene(&ctx); - - render_pixel(&ctx, x, y, WIDTH, HEIGHT, NSUBSAMPLES, out); - /*if (((x / 16) % 2) == ((y / 16) % 2)) { - out[((y * HEIGHT) + x) * 3 + 0] = x; - out[((y * HEIGHT) + x) * 3 + 1] = y; - out[((y * HEIGHT) + x) * 3 + 2] = 0; - } else { - out[((y * HEIGHT) + x) * 3 + 0] = 255; - out[((y * HEIGHT) + x) * 3 + 1] = zero; - out[((y * HEIGHT) + x) * 3 + 2] = 
zero; - }*/ - //out[index] = add(in1[index], in2[index]); -} \ No newline at end of file diff --git a/samples/aobench/ao.comp.cpp b/samples/aobench/ao.comp.cpp new file mode 100644 index 000000000..40bfc7fd6 --- /dev/null +++ b/samples/aobench/ao.comp.cpp @@ -0,0 +1,52 @@ +#include + +#define compute_shader __attribute__((annotate("shady::entry_point::Compute"))) + +#define location(i) __attribute__((annotate("shady::location::"#i))) + +#define input __attribute__((address_space(389))) +#define output __attribute__((address_space(390))) +#define global __attribute__((address_space(1))) + +typedef uint32_t uvec4 __attribute__((ext_vector_type(4))); +typedef float vec4 __attribute__((ext_vector_type(4))); + +typedef uint32_t uvec3 __attribute__((ext_vector_type(3))); +typedef float vec3 __attribute__((ext_vector_type(3))); + +/*__attribute__((annotate("shady::builtin::FragCoord"))) +input vec4 fragCoord; + +location(0) input vec3 fragColor; +location(0) output vec4 outColor;*/ + +__attribute__((annotate("shady::builtin::WorkgroupId"))) +input uvec3 workgroup_id; + +__attribute__((annotate("shady::builtin::GlobalInvocationId"))) +input uvec3 global_id; + +float sqrtf(float) __asm__("shady::prim_op::sqrt"); +float sinf(float) __asm__("shady::prim_op::sin"); +float cosf(float) __asm__("shady::prim_op::cos"); +float fmodf(float, float) __asm__("shady::prim_op::mod"); +float fabsf(float) __asm__("shady::prim_op::abs"); +float floorf(float) __asm__("shady::prim_op::floor"); + +#define EXTERNAL_FN static +#define FUNCTION static + +#include "ao.c" + +#define xstr(s) str(s) +#define str(s) #s + +extern "C" __attribute__((annotate("shady::workgroup_size::" xstr(BLOCK_SIZE) "::" xstr(BLOCK_SIZE) "::1"))) +compute_shader void aobench_kernel(global TEXEL_T* out) { + Ctx ctx = get_init_context(); + init_scene(&ctx); + + int x = global_id.x; + int y = global_id.y; + render_pixel(&ctx, x, y, WIDTH, HEIGHT, NSUBSAMPLES, (TEXEL_T*) out); +} diff --git a/samples/aobench/ao.cu 
b/samples/aobench/ao.cu new file mode 100644 index 000000000..8a3168459 --- /dev/null +++ b/samples/aobench/ao.cu @@ -0,0 +1,23 @@ +// #define EXTERNAL_FN static inline __device__ __attribute__((always_inline)) +// #define FUNCTION static inline __device__ __attribute__((always_inline)) + +#define EXTERNAL_FN static __device__ +#define FUNCTION static __device__ + +#include "ao.c" + +extern "C" { + +__global__ void aobench_kernel(TEXEL_T* out) { + int x = threadIdx.x + blockDim.x * blockIdx.x; + int y = threadIdx.y + blockDim.y * blockIdx.y; + + Ctx ctx = get_init_context(); + init_scene(&ctx); + render_pixel(&ctx, x, y, WIDTH, HEIGHT, NSUBSAMPLES, out); + // out[3 * (y * 2048 + x) + 0] = 255; + // out[3 * (y * 2048 + x) + 1] = 255; + // out[3 * (y * 2048 + x) + 2] = 255; +} + +} \ No newline at end of file diff --git a/samples/aobench/ao.h b/samples/aobench/ao.h index dae5dda64..d0b359810 100644 --- a/samples/aobench/ao.h +++ b/samples/aobench/ao.h @@ -2,6 +2,8 @@ #define HEIGHT 2048 #define NSUBSAMPLES 1 #define NAO_SAMPLES 8 +#define BLOCK_SIZE 16 +#define TEXEL_T unsigned char typedef float Scalar; @@ -63,6 +65,6 @@ typedef struct { unsigned int rng; } Ctx; -Ctx get_init_context(); -void init_scene(Ctx*); -void render_pixel(Ctx*, int x, int y, int w, int h, int nsubsamples, unsigned char* img); +EXTERNAL_FN Ctx get_init_context(); +EXTERNAL_FN void init_scene(Ctx*); +EXTERNAL_FN void render_pixel(Ctx*, int x, int y, int w, int h, int nsubsamples, TEXEL_T* img); diff --git a/samples/aobench/ao_host.c b/samples/aobench/ao_host.c index 086cf08d6..5dad91dc6 100644 --- a/samples/aobench/ao_host.c +++ b/samples/aobench/ao_host.c @@ -1,3 +1,4 @@ #define private - +#define EXTERNAL_FN /* not static */ +#define FUNCTION static #include "ao.c" \ No newline at end of file diff --git a/samples/aobench/ao_main.c b/samples/aobench/ao_main.c index 82980a730..aa57a9a5f 100644 --- a/samples/aobench/ao_main.c +++ b/samples/aobench/ao_main.c @@ -1,24 +1,27 @@ -#include "ao.h" 
+#define EXTERNAL_FN /* not static */ -#include -#include -#include -#include -#include -#include +#include "ao.h" +#include "../runtime/runtime_app_common.h" #include "shady/runtime.h" #include "shady/driver.h" +#include "portability.h" #include "log.h" -#include "list.h" #include "util.h" -static uint64_t timespec_to_nano(struct timespec t) { - return t.tv_sec * 1000000000 + t.tv_nsec; -} +#include +#include +#include +#include -void saveppm(const char *fname, int w, int h, unsigned char *img) { +typedef struct { + CompilerConfig compiler_config; + RuntimeConfig runtime_config; + CommonAppArgs common_app_args; +} Args; + +void saveppm(const char *fname, int w, int h, TEXEL_T* img) { FILE *fp; fp = fopen(fname, "wb"); @@ -27,18 +30,20 @@ void saveppm(const char *fname, int w, int h, unsigned char *img) { fprintf(fp, "P6\n"); fprintf(fp, "%d %d\n", w, h); fprintf(fp, "255\n"); - fwrite(img, w * h * 3, 1, fp); + // fwrite(img, w * h * 3, 1, fp); + for (size_t i = 0; i < w * h * 3; i++) { + unsigned char c = img[i]; + fwrite(&c, 1, 1, fp); + } fclose(fp); } -void render_host(unsigned char *img, int w, int h, int nsubsamples) { +void render_host(TEXEL_T* img, int w, int h, int nsubsamples) { int x, y; Scalar* fimg = (Scalar *)malloc(sizeof(Scalar) * w * h * 3); memset((void *)fimg, 0, sizeof(Scalar) * w * h * 3); - struct timespec ts; - timespec_get(&ts, TIME_UTC); - uint64_t tsn = timespec_to_nano(ts); + uint64_t tsn = shd_get_time_nano(); Ctx ctx = get_init_context(); init_scene(&ctx); @@ -47,10 +52,8 @@ void render_host(unsigned char *img, int w, int h, int nsubsamples) { render_pixel(&ctx, x, y, w, h, nsubsamples, img); } } - struct timespec tp; - timespec_get(&tp, TIME_UTC); - uint64_t tpn = timespec_to_nano(tp); - info_print("reference rendering took %d us\n", (tpn - tsn) / 1000); + uint64_t tpn = shd_get_time_nano(); + shd_info_print("reference rendering took %d us\n", (tpn - tsn) / 1000); } #ifdef ENABLE_ISPC @@ -60,12 +63,10 @@ typedef struct { uint32_t x, y, 
z; } Vec3u; -Vec3u builtin_NumWorkgroups; +extern Vec3u builtin_NumWorkgroups; -void render_ispc(unsigned char *img, int w, int h, int nsubsamples) { - struct timespec ts; - timespec_get(&ts, TIME_UTC); - uint64_t tsn = timespec_to_nano(ts); +void render_ispc(TEXEL_T* img, int w, int h, int nsubsamples) { + uint64_t tsn = shd_get_time_nano(); Ctx ctx = get_init_context(); init_scene(&ctx); for (size_t i = 0; i < WIDTH; i++) { @@ -81,14 +82,12 @@ void render_ispc(unsigned char *img, int w, int h, int nsubsamples) { builtin_NumWorkgroups.z = 1; aobench_kernel(img); - struct timespec tp; - timespec_get(&tp, TIME_UTC); - uint64_t tpn = timespec_to_nano(tp); - info_print("ispc rendering took %d us\n", (tpn - tsn) / 1000); + uint64_t tpn = shd_get_time_nano(); + shd_info_print("ispc rendering took %d us\n", (tpn - tsn) / 1000); } #endif -void render_device(const CompilerConfig* compiler_config, unsigned char *img, int w, int h, int nsubsamples, String path) { +void render_device(Args* args, TEXEL_T *img, int w, int h, int nsubsamples, String path, bool import_memory) { for (size_t i = 0; i < WIDTH; i++) { for (size_t j = 0; j < HEIGHT; j++) { img[j * WIDTH * 3 + i * 3 + 0] = 255; @@ -97,69 +96,80 @@ void render_device(const CompilerConfig* compiler_config, unsigned char *img, in } } - set_log_level(INFO); - - RuntimeConfig runtime_config = (RuntimeConfig) { - .use_validation = true, - .dump_spv = true, - }; - - info_print("Shady checkerboard test starting...\n"); + shd_info_print("Shady checkerboard test starting...\n"); - Runtime* runtime = initialize_runtime(runtime_config); - Device* device = get_device(runtime, 0); + Runtime* runtime = shd_rt_initialize(args->runtime_config); + Device* device = shd_rt_get_device(runtime, args->common_app_args.device); assert(device); img[0] = 69; - info_print("malloc'd address is: %zu\n", (size_t) img); + shd_info_print("malloc'd address is: %zu\n", (size_t) img); - Buffer* buf = import_buffer_host(device, img, sizeof(uint8_t) * 
WIDTH * HEIGHT * 3); - uint64_t buf_addr = get_buffer_device_pointer(buf); + Buffer* buf; + if (import_memory) + buf = shd_rt_import_buffer_host(device, img, sizeof(*img) * WIDTH * HEIGHT * 3); + else + buf = shd_rt_allocate_buffer_device(device, sizeof(*img) * WIDTH * HEIGHT * 3); - info_print("Device-side address is: %zu\n", buf_addr); + uint64_t buf_addr = shd_rt_get_buffer_device_pointer(buf); - Program* program = load_program_from_disk(runtime, compiler_config, path); + shd_info_print("Device-side address is: %zu\n", buf_addr); - // run it twice to compile everything and benefit from caches - wait_completion(launch_kernel(program, device, "aobench_kernel", WIDTH / 16, HEIGHT / 16, 1, 1, (void*[]) { &buf_addr })); - struct timespec ts; - timespec_get(&ts, TIME_UTC); - uint64_t tsn = timespec_to_nano(ts); - wait_completion(launch_kernel(program, device, "aobench_kernel", WIDTH / 16, HEIGHT / 16, 1, 1, (void*[]) { &buf_addr })); - struct timespec tp; - timespec_get(&tp, TIME_UTC); - uint64_t tpn = timespec_to_nano(tp); - info_print("device rendering took %d us\n", (tpn - tsn) / 1000); + Module* m; + CHECK(shd_driver_load_source_file_from_filename(&args->compiler_config, path, "aobench", &m) == NoError, return); + Program* program = shd_rt_new_program_from_module(runtime, &args->compiler_config, m); - debug_print("data %d\n", (int) img[0]); + // run it twice to compile everything and benefit from caches + shd_rt_wait_completion(shd_rt_launch_kernel(program, device, "aobench_kernel", WIDTH / BLOCK_SIZE, HEIGHT / BLOCK_SIZE, 1, 1, (void* []) { &buf_addr }, NULL)); + uint64_t tsn = shd_get_time_nano(); + uint64_t profiled_gpu_time = 0; + ExtraKernelOptions extra_kernel_options = { + .profiled_gpu_time = &profiled_gpu_time + }; + shd_rt_wait_completion(shd_rt_launch_kernel(program, device, "aobench_kernel", WIDTH / BLOCK_SIZE, HEIGHT / BLOCK_SIZE, 1, 1, (void* []) { &buf_addr }, &extra_kernel_options)); + uint64_t tpn = shd_get_time_nano(); + shd_info_print("device 
rendering took %dus (gpu time: %dus)\n", (tpn - tsn) / 1000, profiled_gpu_time / 1000); - destroy_buffer(buf); + if (!import_memory) + shd_rt_copy_from_buffer(buf, 0, img, sizeof(*img) * WIDTH * HEIGHT * 3); + shd_debug_print("data %d\n", (int) img[0]); + shd_rt_destroy_buffer(buf); - shutdown_runtime(runtime); + shd_rt_shutdown(runtime); } int main(int argc, char **argv) { - set_log_level(INFO); - CompilerConfig compiler_config = default_compiler_config(); + shd_log_set_level(INFO); + Args args = { + .compiler_config = shd_default_compiler_config(), + .runtime_config = shd_rt_default_config(), + }; + + args.compiler_config.input_cf.restructure_with_heuristics = true; - cli_parse_common_args(&argc, argv); - cli_parse_compiler_config_args(&compiler_config, &argc, argv); + shd_parse_common_args(&argc, argv); + shd_parse_compiler_config_args(&args.compiler_config, &argc, argv); + shd_rt_cli_parse_runtime_config(&args.runtime_config, &argc, argv); + cli_parse_common_app_arguments(&args.common_app_args, &argc, argv); bool do_host = false, do_ispc = false, do_device = false, do_all = true; - for (size_t i = 0; i < argc; i++) { - if (strcmp(argv[i], "--device") == 0) { + for (size_t i = 1; i < argc; i++) { + if (strcmp(argv[i], "--only-device") == 0) { do_device = true; do_all = false; - } else if (strcmp(argv[i], "--host") == 0) { + } else if (strcmp(argv[i], "--only-host") == 0) { do_host = true; do_all = false; - } else if (strcmp(argv[i], "--ispc") == 0) { + } else if (strcmp(argv[i], "--only-ispc") == 0) { do_ispc = true; do_all = false; + } else { + shd_error_print("Unrecognised argument: %s\n", argv[i]); + shd_error_die(); } } - unsigned char *img = (unsigned char *)malloc(WIDTH * HEIGHT * 3); + void *img = malloc(WIDTH * HEIGHT * 3 * sizeof(TEXEL_T)); if (do_host || do_all) { render_host(img, WIDTH, HEIGHT, NSUBSAMPLES); @@ -174,7 +184,7 @@ int main(int argc, char **argv) { #endif if (do_device || do_all) { - render_device(&compiler_config, img, WIDTH, HEIGHT, 
NSUBSAMPLES, "./ao.cl.spv"); + render_device(&args, img, WIDTH, HEIGHT, NSUBSAMPLES, "./ao.comp.c.ll", false); saveppm("device.ppm", WIDTH, HEIGHT, img); } diff --git a/samples/checkerboard/CMakeLists.txt b/samples/checkerboard/CMakeLists.txt index b4cfb24ca..e8d3d9010 100644 --- a/samples/checkerboard/CMakeLists.txt +++ b/samples/checkerboard/CMakeLists.txt @@ -1,7 +1,10 @@ -add_executable(checkerboard checkerboard.c) -target_link_libraries(checkerboard PRIVATE m runtime common driver) -target_include_directories(checkerboard PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -message(${CMAKE_CURRENT_BINARY_DIR}) +if(NOT TARGET runtime) + message("Runtime component unavailable. Skipping checkerboard sample.") +else() + add_executable(checkerboard checkerboard.c) + target_link_libraries(checkerboard PRIVATE m runtime common driver) + target_include_directories(checkerboard PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -embed_file(string checkerboard_kernel_src checkerboard_kernel.slim) -target_link_libraries(checkerboard PRIVATE checkerboard_kernel_src) + embed_file(string checkerboard_kernel_src checkerboard_kernel.slim) + target_link_libraries(checkerboard PRIVATE checkerboard_kernel_src) +endif() diff --git a/samples/checkerboard/checkerboard.c b/samples/checkerboard/checkerboard.c index fe4c7ac07..42a9905e8 100644 --- a/samples/checkerboard/checkerboard.c +++ b/samples/checkerboard/checkerboard.c @@ -2,6 +2,7 @@ #include #include +#include "shady/ir/arena.h" #include "shady/runtime.h" #include "shady/driver.h" @@ -40,47 +41,48 @@ int main(int argc, char **argv) } } - set_log_level(INFO); - CompilerConfig compiler_config = default_compiler_config(); + shd_log_set_level(INFO); + CompilerConfig compiler_config = shd_default_compiler_config(); - RuntimeConfig runtime_config = (RuntimeConfig) { - .use_validation = true, - .dump_spv = true, - }; + RuntimeConfig runtime_config = shd_rt_default_config(); - cli_parse_common_args(&argc, argv); - cli_parse_compiler_config_args(&compiler_config, 
&argc, argv); + shd_parse_common_args(&argc, argv); + shd_parse_compiler_config_args(&compiler_config, &argc, argv); + shd_rt_cli_parse_runtime_config(&runtime_config, &argc, argv); - info_print("Shady checkerboard test starting...\n"); + shd_info_print("Shady checkerboard test starting...\n"); - Runtime* runtime = initialize_runtime(runtime_config); - Device* device = get_device(runtime, 0); + Runtime* runtime = shd_rt_initialize(runtime_config); + Device* device = shd_rt_get_device(runtime, 0); assert(device); img[0] = 69; - info_print("malloc'd address is: %zu\n", (size_t) img); + shd_info_print("malloc'd address is: %zu\n", (size_t) img); int buf_size = sizeof(uint8_t) * WIDTH * HEIGHT * 3; - Buffer* buf = allocate_buffer_device(device, buf_size); - copy_to_buffer(buf, 0, img, buf_size); - uint64_t buf_addr = get_buffer_device_pointer(buf); + Buffer* buf = shd_rt_allocate_buffer_device(device, buf_size); + shd_rt_copy_to_buffer(buf, 0, img, buf_size); + uint64_t buf_addr = shd_rt_get_buffer_device_pointer(buf); - info_print("Device-side address is: %zu\n", buf_addr); + shd_info_print("Device-side address is: %zu\n", buf_addr); - IrArena* a = new_ir_arena(default_arena_config()); - Module* m = new_module(a, "checkerboard"); - driver_load_source_file(SrcSlim, sizeof(checkerboard_kernel_src), checkerboard_kernel_src, m); - Program* program = new_program_from_module(runtime, &compiler_config, m); + ArenaConfig aconfig = shd_default_arena_config(&compiler_config.target); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* m; + if (shd_driver_load_source_file(&compiler_config, SrcSlim, sizeof(checkerboard_kernel_src), checkerboard_kernel_src, + "checkerboard", &m) != NoError) + shd_error("Failed to load checkerboard module"); + Program* program = shd_rt_new_program_from_module(runtime, &compiler_config, m); - wait_completion(launch_kernel(program, device, "main", 16, 16, 1, 1, (void*[]) { &buf_addr })); + shd_rt_wait_completion(shd_rt_launch_kernel(program, device, 
"checkerboard", 16, 16, 1, 1, (void* []) { &buf_addr }, NULL)); - copy_from_buffer(buf, 0, img, buf_size); - info_print("data %d\n", (int) img[0]); + shd_rt_copy_from_buffer(buf, 0, img, buf_size); + shd_info_print("data %d\n", (int) img[0]); - destroy_buffer(buf); + shd_rt_destroy_buffer(buf); - shutdown_runtime(runtime); - saveppm("ao.ppm", WIDTH, HEIGHT, img); - destroy_ir_arena(a); + shd_rt_shutdown(runtime); + saveppm("checkerboard.ppm", WIDTH, HEIGHT, img); + shd_destroy_ir_arena(a); free(img); } diff --git a/samples/checkerboard/checkerboard_kernel.slim b/samples/checkerboard/checkerboard_kernel.slim index 8f43575de..937ebd742 100644 --- a/samples/checkerboard/checkerboard_kernel.slim +++ b/samples/checkerboard/checkerboard_kernel.slim @@ -2,9 +2,9 @@ const i32 WIDTH = 256; const i32 HEIGHT = 256; @Builtin("GlobalInvocationId") -uniform input pack[u32; 3] global_id; +var uniform input pack[u32; 3] global_id; -@EntryPoint("Compute") @WorkgroupSize(16, 16, 1) fn main(uniform ptr global [u8] p) { +@EntryPoint("Compute") @Exported @WorkgroupSize(16, 16, 1) fn checkerboard(uniform ptr global [u8] p) { val thread_id = global_id; val x = reinterpret[i32](thread_id#0); val y = reinterpret[i32](thread_id#1); diff --git a/samples/fib.slim b/samples/fib.slim index f8d6e4edd..ff21bab31 100644 --- a/samples/fib.slim +++ b/samples/fib.slim @@ -5,12 +5,12 @@ fn fib varying u32(varying u32 n) { } @Builtin("SubgroupLocalInvocationId") -input u32 subgroup_local_id; +var input u32 subgroup_local_id; @Builtin("SubgroupId") -uniform input u32 subgroup_id; +var uniform input u32 subgroup_id; -@EntryPoint("Compute") @WorkgroupSize(32, 1, 1) fn main() { +@EntryPoint("Compute") @Exported @WorkgroupSize(SUBGROUP_SIZE, 1, 1) fn main() { val n = subgroup_local_id % u32 16; debug_printf("fib(%d) = %d from thread %d:%d\n", n, fib(n), subgroup_id, subgroup_local_id); return (); diff --git a/samples/hello_world.slim b/samples/hello_world.slim index 0d71ddfed..952ed3a39 100644 --- 
a/samples/hello_world.slim +++ b/samples/hello_world.slim @@ -1,4 +1,4 @@ -@EntryPoint("Compute") @WorkgroupSize(64, 1, 1) fn main() { +@Exported @EntryPoint("Compute") @WorkgroupSize(64, 1, 1) fn main() { debug_printf("Hello World\n"); return (); } diff --git a/samples/vcc/hello_world.c b/samples/vcc/hello_world.c new file mode 100644 index 000000000..54ec3fd7f --- /dev/null +++ b/samples/vcc/hello_world.c @@ -0,0 +1,10 @@ +#include + +void debug_printf(const char*) __asm__("shady::prim_op::debug_printf"); +void debug_printfi(const char*, int) __asm__("shady::prim_op::debug_printf::i"); + +compute_shader local_size(1, 1, 1) +void main() { + debug_printf("Hello World from Vcc!\n"); + debug_printfi("I can print numbers too: %d!\n", 42); +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 609a0e0ee..f1480e0d9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,9 +1,6 @@ -add_library(murmur3 STATIC ../murmur3/murmur3.c) -target_include_directories(murmur3 INTERFACE ../murmur3) -set_target_properties(murmur3 PROPERTIES POSITION_INDEPENDENT_CODE ON) - add_subdirectory(common) add_subdirectory(shady) add_subdirectory(runtime) -add_subdirectory(frontends) add_subdirectory(driver) +add_subdirectory(frontend) +add_subdirectory(backend) diff --git a/src/backend/CMakeLists.txt b/src/backend/CMakeLists.txt new file mode 100644 index 000000000..2acd53267 --- /dev/null +++ b/src/backend/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(spirv) +add_subdirectory(c) + diff --git a/src/backend/c/CMakeLists.txt b/src/backend/c/CMakeLists.txt new file mode 100644 index 000000000..e65035f52 --- /dev/null +++ b/src/backend/c/CMakeLists.txt @@ -0,0 +1,28 @@ +add_library(shady_c STATIC + emit_c.c + emit_c_value.c + emit_c_type.c + emit_c_builtin.c + emit_c_control_flow.c +) +set_property(TARGET shady_c PROPERTY POSITION_INDEPENDENT_CODE ON) + +target_include_directories(shady_c PRIVATE $) + +target_link_libraries(shady_c PRIVATE "api") +target_link_libraries(shady_c 
INTERFACE "$") +target_link_libraries(shady_c PRIVATE "$") +target_link_libraries(shady_c PRIVATE "$") + +embed_file(string shady_cuda_prelude_src prelude.cu) +embed_file(string shady_cuda_runtime_src runtime.cu) +target_link_libraries(shady_c PRIVATE "$") +target_link_libraries(shady_c PRIVATE "$") + +embed_file(string shady_glsl_runtime_120_src runtime_120.glsl) +target_link_libraries(shady_c PRIVATE "$") + +embed_file(string shady_ispc_runtime_src runtime.ispc) +target_link_libraries(shady_c PRIVATE "$") + +target_link_libraries(driver PUBLIC "$") diff --git a/src/backend/c/emit_c.c b/src/backend/c/emit_c.c new file mode 100644 index 000000000..7099c460c --- /dev/null +++ b/src/backend/c/emit_c.c @@ -0,0 +1,537 @@ +#include "emit_c.h" + +#include "../shady/ir_private.h" +#include "../shady/passes/passes.h" +#include "../shady/analysis/cfg.h" +#include "../shady/analysis/scheduler.h" + +#include "shady_cuda_prelude_src.h" +#include "shady_cuda_runtime_src.h" +#include "shady_glsl_runtime_120_src.h" +#include "shady_ispc_runtime_src.h" + +#include "portability.h" +#include "dict.h" +#include "log.h" +#include "util.h" + +#include +#include +#include + +#pragma GCC diagnostic error "-Wswitch" + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +void shd_c_register_emitted(Emitter* emitter, FnEmitter* fn, const Node* node, CTerm as) { + //assert(as.value || as.var); + shd_dict_insert(const Node*, CTerm, fn ? 
fn->emitted_terms : emitter->emitted_terms, node, as); +} + +CTerm* shd_c_lookup_existing_term(Emitter* emitter, FnEmitter* fn, const Node* node) { + CTerm* found = NULL; + if (fn) + found = shd_dict_find_value(const Node*, CTerm, fn->emitted_terms, node); + if (!found) + found = shd_dict_find_value(const Node*, CTerm, emitter->emitted_terms, node); + return found; +} + +void shd_c_register_emitted_type(Emitter* emitter, const Node* node, String as) { + shd_dict_insert(const Node*, String, emitter->emitted_types, node, as); +} + +CType* shd_c_lookup_existing_type(Emitter* emitter, const Type* node) { + CType* found = shd_dict_find_value(const Node*, CType, emitter->emitted_types, node); + return found; +} + +CValue shd_c_to_ssa(SHADY_UNUSED Emitter* e, CTerm term) { + if (term.value) + return term.value; + if (term.var) + return shd_format_string_arena(e->arena->arena, "(&%s)", term.var); + assert(false); +} + +CAddr shd_c_deref(Emitter* e, CTerm term) { + if (term.value) + return shd_format_string_arena(e->arena->arena, "(*%s)", term.value); + if (term.var) + return term.var; + assert(false); +} + +// TODO: utf8 +static bool is_legal_identifier_char(char c) { + if (c >= '0' && c <= '9') + return true; + if (c >= 'a' && c <= 'z') + return true; + if (c >= 'A' && c <= 'Z') + return true; + if (c == '_') + return true; + return false; +} + +String shd_c_legalize_identifier(Emitter* e, String src) { + if (!src) + return "unnamed"; + size_t len = strlen(src); + LARRAY(char, dst, len + 1); + size_t i; + for (i = 0; i < len; i++) { + char c = src[i]; + if (is_legal_identifier_char(c)) + dst[i] = c; + else + dst[i] = '_'; + } + dst[i] = '\0'; + // TODO: collision handling using a dict + return shd_string(e->arena, dst); +} + +static bool has_forward_declarations(CDialect dialect) { + switch (dialect) { + case CDialect_C11: return true; + case CDialect_CUDA: return true; + case CDialect_GLSL: // no global variable forward declarations in GLSL + case CDialect_ISPC: // ISPC 
seems to share this quirk + return false; + } +} + +/// hack for ISPC: there is no nice way to get a set of varying pointers (instead of a "pointer to a varying") pointing to a varying global +CTerm shd_ispc_varying_ptr_helper(Emitter* emitter, Printer* block_printer, const Type* ptr_type, CTerm term) { + String interm = shd_make_unique_name(emitter->arena, "intermediary_ptr_value"); + assert(ptr_type->tag == PtrType_TAG); + const Type* ut = shd_as_qualified_type(ptr_type, true); + const Type* vt = shd_as_qualified_type(ptr_type, false); + String lhs = shd_c_emit_type(emitter, vt, interm); + shd_print(block_printer, "\n%s = ((%s) %s) + programIndex;", lhs, shd_c_emit_type(emitter, ut, NULL), shd_c_to_ssa(emitter, term)); + return term_from_cvalue(interm); +} + +void shd_c_emit_variable_declaration(Emitter* emitter, Printer* block_printer, const Type* t, String variable_name, bool mut, const CTerm* initializer) { + assert((mut || initializer != NULL) && "unbound results are only allowed when creating a mutable local variable"); + + String prefix = ""; + String center = variable_name; + + // add extra qualifiers if immutable + if (!mut) switch (emitter->config.dialect) { + case CDialect_ISPC: + center = shd_format_string_arena(emitter->arena->arena, "const %s", center); + break; + case CDialect_C11: + case CDialect_CUDA: + center = shd_format_string_arena(emitter->arena->arena, "const %s", center); + break; + case CDialect_GLSL: + if (emitter->config.glsl_version >= 130) + prefix = "const "; + break; + } + + String decl = shd_c_emit_type(emitter, t, center); + if (initializer) + shd_print(block_printer, "\n%s%s = %s;", prefix, decl, shd_c_to_ssa(emitter, *initializer)); + else + shd_print(block_printer, "\n%s%s;", prefix, decl); +} + +void shd_c_emit_pack_code(Printer* p, Strings src, String dst) { + for (size_t i = 0; i < src.count; i++) { + shd_print(p, "\n%s->_%d = %s", dst, i, src.strings[i]); + } +} + +void shd_c_emit_unpack_code(Printer* p, String src, Strings 
dst) { + for (size_t i = 0; i < dst.count; i++) { + shd_print(p, "\n%s = %s->_%d", dst.strings[i], src, i); + } +} + +void shd_c_emit_global_variable_definition(Emitter* emitter, AddressSpace as, String name, const Type* type, bool constant, String init) { + String prefix = NULL; + + bool is_fs = emitter->compiler_config->specialization.execution_model == EmFragment; + // GLSL wants 'const' to go on the left to start the declaration, but in C const should go on the right (east const convention) + switch (emitter->config.dialect) { + case CDialect_C11: { + if (as != AsGeneric) shd_warn_print_once(c11_non_generic_as, "warning: standard C does not have address spaces\n"); + prefix = ""; + if (constant) + name = shd_format_string_arena(emitter->arena->arena, "const %s", name); + break; + } + case CDialect_ISPC: + // ISPC doesn't really do address space qualifiers. + prefix = ""; + break; + case CDialect_CUDA: + switch (as) { + case AsPrivate: + assert(false); + // Note: this requires many hacks. + prefix = "__device__ "; + name = shd_format_string_arena(emitter->arena->arena, "__shady_private_globals.%s", name); + break; + case AsShared: prefix = "__shared__ "; break; + case AsGlobal: { + if (constant) + prefix = "__constant__ "; + else + prefix = "__device__ __managed__ "; + break; + } + default: { + prefix = shd_format_string_arena(emitter->arena->arena, "/* %s */", shd_get_address_space_name(as)); + shd_warn_print("warning: address space %s not supported in CUDA for global variables\n", shd_get_address_space_name(as)); + break; + } + } + break; + case CDialect_GLSL: + switch (as) { + case AsShared: prefix = "shared "; break; + case AsInput: + case AsUInput: prefix = emitter->config.glsl_version < 130 ? (is_fs ? "varying " : "attribute ") : "in "; break; + case AsOutput: prefix = emitter->config.glsl_version < 130 ? 
"varying " : "out "; break; + case AsPrivate: prefix = ""; break; + case AsUniformConstant: prefix = "uniform "; break; + case AsGlobal: { + assert(constant && "Only constants are supported"); + prefix = "const "; + break; + } + default: { + prefix = shd_format_string_arena(emitter->arena->arena, "/* %s */", shd_get_address_space_name(as)); + shd_warn_print("warning: address space %s not supported in GLSL for global variables\n", shd_get_address_space_name(as)); + break; + } + } + break; + } + + assert(prefix); + + // ISPC wants uniform/varying annotations + if (emitter->config.dialect == CDialect_ISPC) { + bool uniform = shd_is_addr_space_uniform(emitter->arena, as); + if (uniform) + name = shd_format_string_arena(emitter->arena->arena, "uniform %s", name); + else + name = shd_format_string_arena(emitter->arena->arena, "varying %s", name); + } + + if (init) + shd_print(emitter->fn_decls, "\n%s%s = %s;", prefix, shd_c_emit_type(emitter, type, name), init); + else + shd_print(emitter->fn_decls, "\n%s%s;", prefix, shd_c_emit_type(emitter, type, name)); + + //if (!has_forward_declarations(emitter->config.dialect) || !init) + // return; + // + //String declaration = c_emit_type(emitter, type, decl_center); + //shd_print(emitter->fn_decls, "\n%s;", declaration); +} + +void shd_c_emit_decl(Emitter* emitter, const Node* decl) { + assert(is_declaration(decl)); + + CTerm* found = shd_c_lookup_existing_term(emitter, NULL, decl); + if (found) return; + + CType* found2 = shd_c_lookup_existing_type(emitter, decl); + if (found2) return; + + const char* name = shd_c_legalize_identifier(emitter, get_declaration_name(decl)); + const Type* decl_type = decl->type; + const char* decl_center = name; + CTerm emit_as; + + switch (decl->tag) { + case GlobalVariable_TAG: { + String init = NULL; + if (decl->payload.global_variable.init) + init = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, NULL, decl->payload.global_variable.init)); + AddressSpace ass = 
decl->payload.global_variable.address_space; + if (ass == AsInput || ass == AsOutput) + init = NULL; + + const GlobalVariable* gvar = &decl->payload.global_variable; + if (shd_is_decl_builtin(decl)) { + Builtin b = shd_get_decl_builtin(decl); + CTerm t = shd_c_emit_builtin(emitter, b); + shd_c_register_emitted(emitter, NULL, decl, t); + return; + } + + if (ass == AsOutput && emitter->compiler_config->specialization.execution_model == EmFragment) { + int location = shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_annotation_value(shd_lookup_annotation(decl, "Location"))), false); + CTerm t = term_from_cvar(shd_fmt_string_irarena(emitter->arena, "gl_FragData[%d]", location)); + shd_c_register_emitted(emitter, NULL, decl, t); + return; + } + + decl_type = decl->payload.global_variable.type; + // we emit the global variable as a CVar, so we can refer to it's 'address' without explicit ptrs + emit_as = term_from_cvar(name); + if ((decl->payload.global_variable.address_space == AsPrivate) && emitter->config.dialect == CDialect_CUDA) { + if (emitter->use_private_globals) { + shd_c_register_emitted(emitter, NULL, decl, term_from_cvar(shd_format_string_arena(emitter->arena->arena, "__shady_private_globals->%s", name))); + // HACK + return; + } + emit_as = term_from_cvar(shd_fmt_string_irarena(emitter->arena, "__shady_thread_local_access(%s)", name)); + if (init) + init = shd_fmt_string_irarena(emitter->arena, "__shady_replicate_thread_local(%s)", init); + shd_c_register_emitted(emitter, NULL, decl, emit_as); + } + shd_c_register_emitted(emitter, NULL, decl, emit_as); + + AddressSpace as = decl->payload.global_variable.address_space; + shd_c_emit_global_variable_definition(emitter, as, decl_center, decl_type, false, init); + return; + } + case Function_TAG: { + emit_as = term_from_cvalue(name); + shd_c_register_emitted(emitter, NULL, decl, emit_as); + String head = shd_c_emit_fn_head(emitter, decl->type, name, decl); + const Node* body = decl->payload.fun.body; 
+ if (body) { + FnEmitter fn = { + .cfg = build_fn_cfg(decl), + .emitted_terms = shd_new_dict(Node*, CTerm, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + fn.scheduler = shd_new_scheduler(fn.cfg); + fn.instruction_printers = calloc(sizeof(Printer*), fn.cfg->size); + // for (size_t i = 0; i < fn.cfg->size; i++) + // fn.instruction_printers[i] = open_growy_as_printer(new_growy()); + + for (size_t i = 0; i < decl->payload.fun.params.count; i++) { + String param_name; + String variable_name = shd_get_value_name_unsafe(decl->payload.fun.params.nodes[i]); + param_name = shd_fmt_string_irarena(emitter->arena, "%s_%d", shd_c_legalize_identifier(emitter, variable_name), decl->payload.fun.params.nodes[i]->id); + shd_c_register_emitted(emitter, &fn, decl->payload.fun.params.nodes[i], term_from_cvalue(param_name)); + } + + String fn_body = shd_c_emit_body(emitter, &fn, decl); + if (emitter->config.dialect == CDialect_ISPC) { + // ISPC hack: This compiler (like seemingly all LLVM-based compilers) has broken handling of the execution mask - it fails to generated masked stores for the entry BB of a function that may be called non-uniformingly + // therefore we must tell ISPC to please, pretty please, mask everything by branching on what the mask should be + fn_body = shd_format_string_arena(emitter->arena->arena, "if ((lanemask() >> programIndex) & 1u) { %s}", fn_body); + // I hate everything about this too. 
+ } else if (emitter->config.dialect == CDialect_CUDA) { + if (shd_lookup_annotation(decl, "EntryPoint")) { + // fn_body = format_string_arena(emitter->arena->arena, "\n__shady_entry_point_init();%s", fn_body); + if (emitter->use_private_globals) { + fn_body = shd_format_string_arena(emitter->arena->arena, "\n__shady_PrivateGlobals __shady_private_globals_alloc;\n __shady_PrivateGlobals* __shady_private_globals = &__shady_private_globals_alloc;\n%s", fn_body); + } + fn_body = shd_format_string_arena(emitter->arena->arena, "\n__shady_prepare_builtins();%s", fn_body); + } + } + shd_print(emitter->fn_defs, "\n%s { ", head); + shd_printer_indent(emitter->fn_defs); + shd_print(emitter->fn_defs, " %s", fn_body); + shd_printer_deindent(emitter->fn_defs); + shd_print(emitter->fn_defs, "\n}"); + + shd_destroy_scheduler(fn.scheduler); + shd_destroy_cfg(fn.cfg); + shd_destroy_dict(fn.emitted_terms); + free(fn.instruction_printers); + } + + shd_print(emitter->fn_decls, "\n%s;", head); + return; + } + case Constant_TAG: { + emit_as = term_from_cvalue(name); + shd_c_register_emitted(emitter, NULL, decl, emit_as); + + String init = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, NULL, decl->payload.constant.value)); + shd_c_emit_global_variable_definition(emitter, AsGlobal, decl_center, decl->type, true, init); + return; + } + case NominalType_TAG: { + CType emitted = name; + shd_c_register_emitted_type(emitter, decl, emitted); + switch (emitter->config.dialect) { + case CDialect_ISPC: + default: shd_print(emitter->type_decls, "\ntypedef %s;", shd_c_emit_type(emitter, decl->payload.nom_type.body, emitted)); break; + case CDialect_GLSL: shd_c_emit_nominal_type_body(emitter, shd_format_string_arena(emitter->arena->arena, "struct %s /* nominal */", emitted), decl->payload.nom_type.body); break; + } + return; + } + default: shd_error("not a decl"); + } +} + +static Module* run_backend_specific_passes(const CompilerConfig* config, CEmitterConfig* econfig, Module* initial_mod) { + 
IrArena* initial_arena = initial_mod->arena; + Module** pmod = &initial_mod; + + // C lacks a nice way to express constants that can be used in type definitions afterwards, so let's just inline them all. + RUN_PASS(shd_pass_eliminate_constants) + if (econfig->dialect == CDialect_ISPC) { + RUN_PASS(shd_pass_lower_workgroups) + RUN_PASS(shd_pass_lower_inclusive_scan) + } + if (econfig->dialect != CDialect_GLSL) { + RUN_PASS(shd_pass_lower_vec_arr) + } + return *pmod; +} + +static String collect_private_globals_in_struct(Emitter* emitter, Module* m) { + Growy* g = shd_new_growy(); + Printer* p = shd_new_printer_from_growy(g); + + shd_print(p, "typedef struct __shady_PrivateGlobals {\n"); + Nodes decls = shd_module_get_declarations(m); + size_t count = 0; + for (size_t i = 0; i < decls.count; i++) { + const Node* decl = decls.nodes[i]; + if (decl->tag != GlobalVariable_TAG) + continue; + AddressSpace as = decl->payload.global_variable.address_space; + if (as != AsPrivate) + continue; + shd_print(p, "%s;\n", shd_c_emit_type(emitter, decl->payload.global_variable.type, decl->payload.global_variable.name)); + count++; + } + shd_print(p, "} __shady_PrivateGlobals;\n"); + + if (count == 0) { + shd_destroy_printer(p); + shd_destroy_growy(g); + return NULL; + } + return shd_printer_growy_unwrap(p); +} + +CEmitterConfig shd_default_c_emitter_config(void) { + return (CEmitterConfig) { + .glsl_version = 420, + }; +} + +void shd_emit_c(const CompilerConfig* compiler_config, CEmitterConfig config, Module* mod, size_t* output_size, char** output, Module** new_mod) { + IrArena* initial_arena = shd_module_get_arena(mod); + mod = run_backend_specific_passes(compiler_config, &config, mod); + IrArena* arena = shd_module_get_arena(mod); + + Growy* type_decls_g = shd_new_growy(); + Growy* fn_decls_g = shd_new_growy(); + Growy* fn_defs_g = shd_new_growy(); + + Emitter emitter = { + .compiler_config = compiler_config, + .config = config, + .arena = arena, + .type_decls = 
shd_new_printer_from_growy(type_decls_g), + .fn_decls = shd_new_printer_from_growy(fn_decls_g), + .fn_defs = shd_new_printer_from_growy(fn_defs_g), + .emitted_terms = shd_new_dict(Node*, CTerm, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .emitted_types = shd_new_dict(Node*, String, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + + Growy* final = shd_new_growy(); + Printer* finalp = shd_new_printer_from_growy(final); + + shd_print(finalp, "/* file generated by shady */\n"); + + switch (emitter.config.dialect) { + case CDialect_ISPC: { + shd_print(emitter.fn_defs, shady_ispc_runtime_src); + break; + } + case CDialect_C11: + shd_print(finalp, "\n#include "); + shd_print(finalp, "\n#include "); + shd_print(finalp, "\n#include "); + shd_print(finalp, "\n#include "); + shd_print(finalp, "\n#include "); + break; + case CDialect_GLSL: + shd_print(finalp, "#version %d\n", emitter.config.glsl_version); + if (emitter.need_64b_ext) + shd_print(finalp, "#extension GL_ARB_gpu_shader_int64: require\n"); + shd_print(finalp, "#define ubyte uint\n"); + shd_print(finalp, "#define uchar uint\n"); + shd_print(finalp, "#define ulong uint\n"); + if (emitter.config.glsl_version <= 120) + shd_print(finalp, shady_glsl_runtime_120_src); + break; + case CDialect_CUDA: { + size_t total_workgroup_size = emitter.arena->config.specializations.workgroup_size[0]; + total_workgroup_size *= emitter.arena->config.specializations.workgroup_size[1]; + total_workgroup_size *= emitter.arena->config.specializations.workgroup_size[2]; + + shd_print(finalp, "#define __shady_workgroup_size %d\n", total_workgroup_size); + shd_print(finalp, "#define __shady_replicate_thread_local(v) { "); + for (size_t i = 0; i < total_workgroup_size; i++) + shd_print(finalp, "v, "); + shd_print(finalp, "}\n"); + shd_print(finalp, shady_cuda_prelude_src); + + shd_print(emitter.type_decls, "\ntypedef %s;\n", shd_c_emit_type(&emitter, arr_type(arena, (ArrType) { + .size = shd_int32_literal(arena, 3), + 
.element_type = shd_uint32_type(arena) + }), "uvec3")); + shd_print(emitter.fn_defs, shady_cuda_runtime_src); + + String private_globals = collect_private_globals_in_struct(&emitter, mod); + if (private_globals) { + emitter.use_private_globals = true; + shd_print(emitter.type_decls, private_globals); + free((void*) private_globals); + } + break; + } + default: break; + } + + Nodes decls = shd_module_get_declarations(mod); + for (size_t i = 0; i < decls.count; i++) + shd_c_emit_decl(&emitter, decls.nodes[i]); + + shd_print(finalp, "\n/* types: */\n"); + shd_growy_append_bytes(final, shd_growy_size(type_decls_g), shd_growy_data(type_decls_g)); + + shd_print(finalp, "\n/* declarations: */\n"); + shd_growy_append_bytes(final, shd_growy_size(fn_decls_g), shd_growy_data(fn_decls_g)); + + shd_print(finalp, "\n/* definitions: */\n"); + shd_growy_append_bytes(final, shd_growy_size(fn_defs_g), shd_growy_data(fn_defs_g)); + + shd_print(finalp, "\n"); + shd_print(finalp, "\n"); + shd_print(finalp, "\n"); + shd_growy_append_bytes(final, 1, "\0"); + + shd_destroy_printer(emitter.type_decls); + shd_destroy_printer(emitter.fn_decls); + shd_destroy_printer(emitter.fn_defs); + + shd_destroy_growy(type_decls_g); + shd_destroy_growy(fn_decls_g); + shd_destroy_growy(fn_defs_g); + + shd_destroy_dict(emitter.emitted_types); + shd_destroy_dict(emitter.emitted_terms); + + *output_size = shd_growy_size(final) - 1; + *output = shd_growy_deconstruct(final); + shd_destroy_printer(finalp); + + if (new_mod) + *new_mod = mod; + else if (initial_arena != arena) + shd_destroy_ir_arena(arena); +} diff --git a/src/backend/c/emit_c.h b/src/backend/c/emit_c.h new file mode 100644 index 000000000..6075ae334 --- /dev/null +++ b/src/backend/c/emit_c.h @@ -0,0 +1,95 @@ +#ifndef SHADY_EMIT_C +#define SHADY_EMIT_C + +#include "shady/ir.h" +#include "shady/ir/builtin.h" +#include "shady/be/c.h" + +#include "growy.h" +#include "arena.h" +#include "printer.h" + +typedef struct CFG_ CFG; +typedef struct 
Scheduler_ Scheduler; + +/// SSA-like things, you can read them +typedef String CValue; +/// non-SSA like things, they represent addresses +typedef String CAddr; + +typedef String CType; + +typedef struct { + CValue value; + CAddr var; +} CTerm; + +#define term_from_cvalue(t) (CTerm) { .value = t } +#define term_from_cvar(t) (CTerm) { .var = t } +#define empty_term() (CTerm) { 0 } +#define is_term_empty(t) (!t.var && !t.value) + +typedef Strings Phis; + +typedef struct CompilerConfig_ CompilerConfig; + +typedef struct { + const CompilerConfig* compiler_config; + CEmitterConfig config; + IrArena* arena; + Printer *type_decls, *fn_decls, *fn_defs; + struct { + Phis selection, loop_continue, loop_break; + } phis; + + struct Dict* emitted_terms; + struct Dict* emitted_types; + + bool use_private_globals; + Printer* entrypoint_prelude; + + bool need_64b_ext; +} Emitter; + +typedef struct { + struct Dict* emitted_terms; + Printer** instruction_printers; + CFG* cfg; + Scheduler* scheduler; +} FnEmitter; + +void shd_c_register_emitted(Emitter* emitter, FnEmitter* fn, const Node* node, CTerm as); +void shd_c_register_emitted_type(Emitter* emitter, const Node* node, String as); + +CTerm* shd_c_lookup_existing_term(Emitter* emitter, FnEmitter* fn, const Node* node); +CType* shd_c_lookup_existing_type(Emitter* emitter, const Type* node); + +String shd_c_legalize_identifier(Emitter* e, String src); +CValue shd_c_to_ssa(Emitter* e, CTerm term); +CAddr shd_c_deref(Emitter* e, CTerm term); +void shd_c_emit_pack_code(Printer* p, Strings src, String dst); +void shd_c_emit_unpack_code(Printer* p, String src, Strings dst); +CTerm shd_c_bind_intermediary_result(Emitter* emitter, Printer* p, const Type* t, CTerm term); +void shd_c_emit_variable_declaration(Emitter* emitter, Printer* block_printer, const Type* t, String variable_name, bool mut, const CTerm* initializer); +CTerm shd_ispc_varying_ptr_helper(Emitter* emitter, Printer* block_printer, const Type* ptr_type, CTerm term); + 
+void shd_c_emit_decl(Emitter* emitter, const Node* decl); +void shd_c_emit_global_variable_definition(Emitter* emitter, AddressSpace as, String name, const Type* type, bool constant, String init); +CTerm shd_c_emit_builtin(Emitter* emitter, Builtin b); + +CType shd_c_emit_type(Emitter* emitter, const Type* type, const char* center); +String shd_c_get_record_field_name(const Type* t, size_t i); +String shd_c_emit_fn_head(Emitter* emitter, const Node* fn_type, String center, const Node* fn); +void shd_c_emit_nominal_type_body(Emitter* emitter, String name, const Type* type); + +CTerm shd_c_emit_value(Emitter* emitter, FnEmitter* fn_builder, const Node* node); +CTerm shd_c_emit_mem(Emitter* e, FnEmitter* b, const Node* mem); +String shd_c_emit_body(Emitter* emitter, FnEmitter* fn, const Node* abs); + +#define free_tmp_str(s) free((char*) (s)) + +inline static bool is_glsl_scalar_type(const Type* t) { + return t->tag == Bool_TAG || t->tag == Int_TAG || t->tag == Float_TAG; +} + +#endif diff --git a/src/shady/emit/c/emit_c_builtins.c b/src/backend/c/emit_c_builtin.c similarity index 72% rename from src/shady/emit/c/emit_c_builtins.c rename to src/backend/c/emit_c_builtin.c index c94532c87..d1fd395be 100644 --- a/src/shady/emit/c/emit_c_builtins.c +++ b/src/backend/c/emit_c_builtin.c @@ -15,16 +15,17 @@ static String glsl_builtins[BuiltinsCount] = { [BuiltinNumWorkgroups] = "gl_NumWorkGroups", [BuiltinWorkgroupSize] = "gl_WorkGroupSize", [BuiltinGlobalInvocationId] = "gl_GlobalInvocationID", + [BuiltinPosition] = "gl_Position", }; -CTerm emit_c_builtin(Emitter* emitter, Builtin b) { +CTerm shd_c_emit_builtin(Emitter* emitter, Builtin b) { String name = NULL; switch(emitter->config.dialect) { - case ISPC: name = ispc_builtins[b]; break; - case GLSL: name = glsl_builtins[b]; break; + case CDialect_ISPC: name = ispc_builtins[b]; break; + case CDialect_GLSL: name = glsl_builtins[b]; break; default: break; } if (name) return term_from_cvar(name); - return 
term_from_cvar(get_builtin_name(b)); + return term_from_cvar(shd_get_builtin_name(b)); } diff --git a/src/backend/c/emit_c_control_flow.c b/src/backend/c/emit_c_control_flow.c new file mode 100644 index 000000000..1aa947473 --- /dev/null +++ b/src/backend/c/emit_c_control_flow.c @@ -0,0 +1,247 @@ +#include "emit_c.h" + +#include "../shady/analysis/cfg.h" + +#include "log.h" +#include "portability.h" + +#include + +static void emit_terminator(Emitter* emitter, FnEmitter* fn, Printer* block_printer, const Node* terminator); + +String shd_c_emit_body(Emitter* emitter, FnEmitter* fn, const Node* abs) { + assert(abs && is_abstraction(abs)); + const Node* body = get_abstraction_body(abs); + assert(body && is_terminator(body)); + CFNode* cf_node = shd_cfg_lookup(fn->cfg, abs); + Printer* p = shd_new_printer_from_growy(shd_new_growy()); + fn->instruction_printers[cf_node->rpo_index] = p; + //indent(p); + + emit_terminator(emitter, fn, p, body); + + /*if (bbs && bbs->count > 0) { + assert(emitter->config.dialect != CDialect_GLSL); + error("TODO"); + }*/ + + //deindent(p); + // shd_print(p, "\n"); + + fn->instruction_printers[cf_node->rpo_index] = NULL; + String s2 = shd_printer_growy_unwrap(p); + String s = shd_string(emitter->arena, s2); + free((void*)s2); + return s; +} + +static Strings emit_variable_declarations(Emitter* emitter, FnEmitter* fn, Printer* p, String given_name, Strings* given_names, Nodes types, bool mut, const Nodes* init_values) { + if (given_names) + assert(given_names->count == types.count); + if (init_values) + assert(init_values->count == types.count); + LARRAY(String, names, types.count); + for (size_t i = 0; i < types.count; i++) { + String name = given_names ? 
given_names->strings[i] : given_name; + assert(name); + names[i] = shd_make_unique_name(emitter->arena, name); + if (init_values) { + CTerm initializer = shd_c_emit_value(emitter, fn, init_values->nodes[i]); + shd_c_emit_variable_declaration(emitter, p, types.nodes[i], names[i], mut, &initializer); + } else + shd_c_emit_variable_declaration(emitter, p, types.nodes[i], names[i], mut, NULL); + } + return shd_strings(emitter->arena, types.count, names); +} + +static void emit_if(Emitter* emitter, FnEmitter* fn, Printer* p, If if_) { + Emitter sub_emiter = *emitter; + Strings ephis = emit_variable_declarations(emitter, fn, p, "if_phi", NULL, if_.yield_types, true, NULL); + sub_emiter.phis.selection = ephis; + + assert(get_abstraction_params(if_.if_true).count == 0); + String true_body = shd_c_emit_body(&sub_emiter, fn, if_.if_true); + String false_body = if_.if_false ? shd_c_emit_body(&sub_emiter, fn, if_.if_false) : NULL; + String tail = shd_c_emit_body(emitter, fn, if_.tail); + CValue condition = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, if_.condition)); + shd_print(p, "\nif (%s) { ", condition); + shd_printer_indent(p); + shd_print(p, "%s", true_body); + shd_printer_deindent(p); + shd_print(p, "\n}"); + if (if_.if_false) { + assert(get_abstraction_params(if_.if_false).count == 0); + shd_print(p, " else {"); + shd_printer_indent(p); + shd_print(p, "%s", false_body); + shd_printer_deindent(p); + shd_print(p, "\n}"); + } + + Nodes results = get_abstraction_params(if_.tail); + for (size_t i = 0; i < ephis.count; i++) { + shd_c_register_emitted(emitter, fn, results.nodes[i], term_from_cvalue(ephis.strings[i])); + } + + shd_print(p, "%s", tail); +} + +static void emit_match(Emitter* emitter, FnEmitter* fn, Printer* p, Match match) { + Emitter sub_emiter = *emitter; + Strings ephis = emit_variable_declarations(emitter, fn, p, "match_phi", NULL, match.yield_types, true, NULL); + sub_emiter.phis.selection = ephis; + + // Of course, the sensible thing to do here 
would be to emit a switch statement. + // ... + // Except that doesn't work, because C/GLSL have a baffling design wart: the `break` statement is overloaded, + // meaning that if you enter a switch statement, which should be orthogonal to loops, you can't actually break + // out of the outer loop anymore. Brilliant. So we do this terrible if-chain instead. + // + // We could do GOTO for C, but at the cost of arguably even more noise in the output, and two different codepaths. + // I don't think it's quite worth it, just like it's not worth doing some data-flow based solution either. + + CValue inspectee = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, match.inspect)); + bool first = true; + LARRAY(CValue, literals, match.cases.count); + LARRAY(String, bodies, match.cases.count); + String default_case_body = shd_c_emit_body(&sub_emiter, fn, match.default_case); + String tail = shd_c_emit_body(emitter, fn, match.tail); + for (size_t i = 0; i < match.cases.count; i++) { + literals[i] = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, match.literals.nodes[i])); + bodies[i] = shd_c_emit_body(&sub_emiter, fn, match.cases.nodes[i]); + } + for (size_t i = 0; i < match.cases.count; i++) { + shd_print(p, "\n"); + if (!first) + shd_print(p, "else "); + shd_print(p, "if (%s == %s) { ", inspectee, literals[i]); + shd_printer_indent(p); + shd_print(p, "%s", bodies[i]); + shd_printer_deindent(p); + shd_print(p, "\n}"); + first = false; + } + if (match.default_case) { + shd_print(p, "\nelse { "); + shd_printer_indent(p); + shd_print(p, "%s", default_case_body); + shd_printer_deindent(p); + shd_print(p, "\n}"); + } + + Nodes results = get_abstraction_params(match.tail); + for (size_t i = 0; i < ephis.count; i++) { + shd_c_register_emitted(emitter, fn, results.nodes[i], term_from_cvalue(ephis.strings[i])); + } + + shd_print(p, "%s", tail); +} + +static void emit_loop(Emitter* emitter, FnEmitter* fn, Printer* p, Loop loop) { + Emitter sub_emiter = *emitter; + Nodes params = 
get_abstraction_params(loop.body); + Nodes variables = params; + LARRAY(String, arr, variables.count); + for (size_t i = 0; i < variables.count; i++) { + arr[i] = shd_get_value_name_unsafe(variables.nodes[i]); + if (!arr[i]) + arr[i] = shd_make_unique_name(emitter->arena, "phi"); + } + Strings param_names = shd_strings(emitter->arena, variables.count, arr); + Strings eparams = emit_variable_declarations(emitter, fn, p, NULL, &param_names, shd_get_param_types(emitter->arena, params), true, &loop.initial_args); + for (size_t i = 0; i < params.count; i++) + shd_c_register_emitted(&sub_emiter, fn, params.nodes[i], term_from_cvalue(eparams.strings[i])); + + sub_emiter.phis.loop_continue = eparams; + Strings ephis = emit_variable_declarations(emitter, fn, p, "loop_break_phi", NULL, loop.yield_types, true, NULL); + sub_emiter.phis.loop_break = ephis; + + String body = shd_c_emit_body(&sub_emiter, fn, loop.body); + String tail = shd_c_emit_body(emitter, fn, loop.tail); + shd_print(p, "\nwhile(true) { "); + shd_printer_indent(p); + shd_print(p, "%s", body); + shd_printer_deindent(p); + shd_print(p, "\n}"); + + Nodes results = get_abstraction_params(loop.tail); + for (size_t i = 0; i < ephis.count; i++) { + shd_c_register_emitted(emitter, fn, results.nodes[i], term_from_cvalue(ephis.strings[i])); + } + + shd_print(p, "%s", tail); +} + +static void emit_terminator(Emitter* emitter, FnEmitter* fn, Printer* block_printer, const Node* terminator) { + shd_c_emit_mem(emitter, fn, get_terminator_mem(terminator)); + switch (is_terminator(terminator)) { + case NotATerminator: assert(false); + case Join_TAG: shd_error("this must be lowered away!"); + case Jump_TAG: + case Branch_TAG: + case Switch_TAG: + case TailCall_TAG: shd_error("TODO"); + case If_TAG: return emit_if(emitter, fn, block_printer, terminator->payload.if_instr); + case Match_TAG: return emit_match(emitter, fn, block_printer, terminator->payload.match_instr); + case Loop_TAG: return emit_loop(emitter, fn, block_printer, 
terminator->payload.loop_instr); + case Control_TAG: shd_error("TODO") + case Terminator_Return_TAG: { + Nodes args = terminator->payload.fn_ret.args; + if (args.count == 0) { + shd_print(block_printer, "\nreturn;"); + } else if (args.count == 1) { + shd_print(block_printer, "\nreturn %s;", shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, args.nodes[0]))); + } else { + String packed = shd_make_unique_name(emitter->arena, "pack_return"); + LARRAY(CValue, values, args.count); + for (size_t i = 0; i < args.count; i++) + values[i] = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, args.nodes[i])); + shd_c_emit_pack_code(block_printer, shd_strings(emitter->arena, args.count, values), packed); + shd_print(block_printer, "\nreturn %s;", packed); + } + break; + } + case MergeSelection_TAG: { + Nodes args = terminator->payload.merge_selection.args; + Phis phis = emitter->phis.selection; + assert(phis.count == args.count); + for (size_t i = 0; i < phis.count; i++) + shd_print(block_printer, "\n%s = %s;", phis.strings[i], shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, args.nodes[i]))); + + break; + } + case MergeContinue_TAG: { + Nodes args = terminator->payload.merge_continue.args; + Phis phis = emitter->phis.loop_continue; + assert(phis.count == args.count); + for (size_t i = 0; i < phis.count; i++) + shd_print(block_printer, "\n%s = %s;", phis.strings[i], shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, args.nodes[i]))); + shd_print(block_printer, "\ncontinue;"); + break; + } + case MergeBreak_TAG: { + Nodes args = terminator->payload.merge_break.args; + Phis phis = emitter->phis.loop_break; + assert(phis.count == args.count); + for (size_t i = 0; i < phis.count; i++) + shd_print(block_printer, "\n%s = %s;", phis.strings[i], shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, args.nodes[i]))); + shd_print(block_printer, "\nbreak;"); + break; + } + case Terminator_Unreachable_TAG: { + switch (emitter->config.dialect) { + case CDialect_CUDA: + case 
                CDialect_C11:
                    shd_print(block_printer, "\n__builtin_unreachable();");
                    break;
                case CDialect_ISPC:
                    shd_print(block_printer, "\nassert(false);");
                    break;
                case CDialect_GLSL:
                    shd_print(block_printer, "\n//unreachable");
                    break;
            }
            break;
        }
    }
}

/* ======================================================================
 * emit_c_type.c — type emission for the shady C-family backends
 * (C11, CUDA, GLSL, ISPC). Produces C declarator strings where the
 * "center" parameter is the declarator name spliced into the type.
 * ====================================================================== */
#include "emit_c.h"

#include "dict.h"
#include "log.h"
#include "util.h"

#include "../shady/ir_private.h"

/* NOTE(review): the system header targets below were stripped during
 * extraction (angle-bracket contents lost); restore from upstream. */
#include
#include
#include

#pragma GCC diagnostic error "-Wswitch"

// Returns the printable name of record member #i: the declared name when one
// exists, otherwise a positional fallback ("_0", "_1", ...).
// NOTE(review): "%d" receives a size_t argument — works for small indices on
// common ABIs but is formally mismatched; "%zu" would be correct. Verify
// shd_fmt_string_irarena is printf-like before changing.
String shd_c_get_record_field_name(const Type* t, size_t i) {
    assert(t->tag == RecordType_TAG);
    RecordType r = t->payload.record_type;
    assert(i < r.members.count);
    if (i >= r.names.count)
        return shd_fmt_string_irarena(t->arena, "_%d", i);
    else
        return r.names.strings[i];
}

// Emits the full "struct <name> { ...members... };" body for a record type
// into the emitter's global type-declaration section. The text is staged in a
// temporary growy/printer pair, NUL-terminated, then copied out.
void shd_c_emit_nominal_type_body(Emitter* emitter, String name, const Type* type) {
    assert(type->tag == RecordType_TAG);
    Growy* g = shd_new_growy();
    Printer* p = shd_new_printer_from_growy(g);

    shd_print(p, "\n%s {", name);
    shd_printer_indent(p);
    for (size_t i = 0; i < type->payload.record_type.members.count; i++) {
        String member_identifier = shd_c_get_record_field_name(type, i);
        shd_print(p, "\n%s;", shd_c_emit_type(emitter, type->payload.record_type.members.nodes[i], member_identifier));
    }
    shd_printer_deindent(p);
    shd_print(p, "\n};\n");
    // manual NUL terminator: growy buffers are not implicitly terminated
    shd_growy_append_bytes(g, 1, (char[]) { '\0' });

    shd_print(emitter->type_decls, shd_growy_data(g));
    shd_destroy_growy(g);
    shd_destroy_printer(p);
}

// Builds the declarator for a function head: return type, name ("center"),
// and parameter list, with dialect-specific decorations (ISPC "export",
// CUDA "__global__"/"__device__"). When `fn` is non-NULL, real parameter
// names are used; otherwise only parameter types are printed (for function
// pointer types / forward declarations).
String shd_c_emit_fn_head(Emitter* emitter, const Node* fn_type, String center, const Node* fn) {
    assert(fn_type->tag == FnType_TAG);
    assert(!fn || fn->type == fn_type);
    Nodes codom = fn_type->payload.fn_type.return_types;

    // entry points keep their original ABI; non-entry functions may receive
    // an extra "private globals" context parameter (see below)
    const Node* entry_point = fn ? shd_lookup_annotation(fn, "EntryPoint") : NULL;

    Growy* paramg = shd_new_growy();
    Printer* paramp = shd_new_printer_from_growy(paramg);
    Nodes dom = fn_type->payload.fn_type.param_types;
    if (dom.count == 0 && emitter->config.dialect == CDialect_C11)
        // C requires an explicit "(void)" for a no-parameter prototype
        shd_print(paramp, "void");
    else if (fn) {
        Nodes params = fn->payload.fun.params;
        assert(params.count == dom.count);
        if (emitter->use_private_globals && !entry_point) {
            shd_print(paramp, "__shady_PrivateGlobals* __shady_private_globals");
            if (params.count > 0)
                shd_print(paramp, ", ");
        }
        for (size_t i = 0; i < dom.count; i++) {
            String param_name;
            String variable_name = shd_get_value_name_unsafe(fn->payload.fun.params.nodes[i]);
            // suffix with the node id to guarantee uniqueness after legalization
            param_name = shd_fmt_string_irarena(emitter->arena, "%s_%d", shd_c_legalize_identifier(emitter, variable_name), fn->payload.fun.params.nodes[i]->id);
            shd_print(paramp, shd_c_emit_type(emitter, params.nodes[i]->type, param_name));
            if (i + 1 < dom.count) {
                shd_print(paramp, ", ");
            }
        }
    } else {
        if (emitter->use_private_globals) {
            shd_print(paramp, "__shady_PrivateGlobals*");
            if (dom.count > 0)
                shd_print(paramp, ", ");
        }
        for (size_t i = 0; i < dom.count; i++) {
            shd_print(paramp, shd_c_emit_type(emitter, dom.nodes[i], ""));
            if (i + 1 < dom.count) {
                shd_print(paramp, ", ");
            }
        }
    }
    shd_growy_append_bytes(paramg, 1, (char[]) { 0 });
    const char* parameters = shd_printer_growy_unwrap(paramp);
    switch (emitter->config.dialect) {
        default:
            center = shd_format_string_arena(emitter->arena->arena, "(%s)(%s)", center, parameters);
            break;
        case CDialect_GLSL:
            // GLSL does not accept functions declared like void (foo)(int);
            // it also does not support higher-order functions and/or function pointers, so we drop the parentheses
            center = shd_format_string_arena(emitter->arena->arena, "%s(%s)", center, parameters);
            break;
    }
    free_tmp_str(parameters);

    String c_decl = shd_c_emit_type(emitter, shd_maybe_multiple_return(emitter->arena, codom), center);
    if (entry_point) {
        switch (emitter->config.dialect) {
            case CDialect_C11:
                break;
            case CDialect_GLSL:
                break;
            case CDialect_ISPC:
                c_decl = shd_format_string_arena(emitter->arena->arena, "export %s", c_decl);
                break;
            case CDialect_CUDA:
                c_decl = shd_format_string_arena(emitter->arena->arena, "extern \"C\" __global__ %s", c_decl);
                break;
        }
    } else if (emitter->config.dialect == CDialect_CUDA) {
        c_decl = shd_format_string_arena(emitter->arena->arena, "__device__ %s", c_decl);
    }

    return c_decl;
}

// Emits the C spelling of `type` with the declarator `center` spliced in
// (C declarator syntax wraps the name, hence the inside-out construction).
// Nominal types (records, arrays) are emitted once, registered in the
// emitter's cache, and referenced by name afterwards.
String shd_c_emit_type(Emitter* emitter, const Type* type, const char* center) {
    if (center == NULL)
        center = "";

    String emitted = NULL;
    CType* found = shd_c_lookup_existing_type(emitter, type);
    if (found) {
        emitted = *found;
        goto type_goes_on_left;
    }

    switch (is_type(type)) {
        case NotAType: assert(false); break;
        case LamType_TAG:
        case BBType_TAG: shd_error("these types do not exist in C");
        case MaskType_TAG: shd_error("should be lowered away");
        case Type_SampledImageType_TAG:
        case Type_SamplerType_TAG:
        case Type_ImageType_TAG:
        case JoinPointType_TAG: shd_error("TODO")
        // NOTE(review): NoRet falls through to "bool" — presumably a
        // placeholder since no-return has no C spelling; confirm intent.
        case NoRet_TAG:
        case Bool_TAG: emitted = "bool"; break;
        case Int_TAG: {
            bool sign = type->payload.int_type.is_signed;
            switch (emitter->config.dialect) {
                case CDialect_ISPC: {
                    // ISPC has its own fixed-width integer spellings
                    const char* ispc_int_types[4][2] = {
                        { "uint8" , "int8" },
                        { "uint16", "int16" },
                        { "uint32", "int32" },
                        { "uint64", "int64" },
                    };
                    emitted = ispc_int_types[type->payload.int_type.width][sign];
                    break;
                }
                case CDialect_CUDA:
                case CDialect_C11: {
                    const char* c_classic_int_types[4][2] = {
                        { "unsigned char" , "char" },
                        { "unsigned short", "short" },
                        { "unsigned int" , "int" },
                        { "unsigned long long" , "long long" },
                    };
                    const char* c_explicit_int_sizes[4][2] = {
                        { "uint8_t" , "int8_t" },
                        { "uint16_t", "int16_t" },
                        { "uint32_t", "int32_t" },
                        { "uint64_t", "int64_t" },
                    };
                    emitted = (emitter->config.explicitly_sized_types ? c_explicit_int_sizes : c_classic_int_types)[type->payload.int_type.width][sign];
                    break;
                }
                case CDialect_GLSL:
                    if (emitter->config.glsl_version <= 120) {
                        // GLSL 1.20 only has "int"; width/sign are ignored
                        emitted = "int";
                        break;
                    }
                    switch (type->payload.int_type.width) {
                        case IntTy8: shd_warn_print("vanilla GLSL does not support 8-bit integers\n");
                            emitted = sign ? "byte" : "ubyte";
                            break;
                        case IntTy16: shd_warn_print("vanilla GLSL does not support 16-bit integers\n");
                            emitted = sign ? "short" : "ushort";
                            break;
                        case IntTy32: emitted = sign ? "int" : "uint"; break;
                        case IntTy64:
                            emitter->need_64b_ext = true;
                            shd_warn_print("vanilla GLSL does not support 64-bit integers\n");
                            emitted = sign ? "int64_t" : "uint64_t";
                            break;
                    }
                    break;
            }
            break;
        }
        case Float_TAG:
            switch (type->payload.float_type.width) {
                case FloatTy16:
                    // no portable 16-bit float spelling in the C dialects
                    assert(false);
                    break;
                case FloatTy32:
                    emitted = "float";
                    break;
                case FloatTy64:
                    emitted = "double";
                    break;
            }
            break;
        case Type_RecordType_TAG: {
            //if (type->payload.record_type.members.count == 0) {
            //    emitted = "void";
            //    break;
            //}

            emitted = shd_make_unique_name(emitter->arena, "Record");
            String prefixed = shd_format_string_arena(emitter->arena->arena, "struct %s", emitted);
            shd_c_emit_nominal_type_body(emitter, prefixed, type);
            // C puts structs in their own namespace so we always need the prefix
            if (emitter->config.dialect == CDialect_C11)
                emitted = prefixed;

            break;
        }
        case Type_QualifiedType_TAG:
            if (type->payload.qualified_type.type == unit_type(emitter->arena)) {
                emitted = "void";
                break;
            }
            switch (emitter->config.dialect) {
                default:
                    // only ISPC materializes the uniform/varying qualifier
                    return shd_c_emit_type(emitter, type->payload.qualified_type.type, center);
                case CDialect_ISPC:
                    if (type->payload.qualified_type.is_uniform)
                        return shd_c_emit_type(emitter, type->payload.qualified_type.type, shd_format_string_arena(emitter->arena->arena, "uniform %s", center));
                    else
                        return shd_c_emit_type(emitter, type->payload.qualified_type.type, shd_format_string_arena(emitter->arena->arena, "varying %s", center));
            }
        case Type_PtrType_TAG: {
            CType t = shd_c_emit_type(emitter, type->payload.ptr_type.pointed_type, shd_format_string_arena(emitter->arena->arena, "* %s", center));
            // we always emit pointers to _uniform_ data, no exceptions
            if (emitter->config.dialect == CDialect_ISPC)
                t = shd_format_string_arena(emitter->arena->arena, "uniform %s", t);
            return t;
        }
        case Type_FnType_TAG: {
            return shd_c_emit_fn_head(emitter, type, center, NULL);
        }
        case Type_ArrType_TAG: {
            emitted = shd_make_unique_name(emitter->arena, "Array");
            String prefixed = shd_format_string_arena(emitter->arena->arena, "struct %s", emitted);
            Growy* g = shd_new_growy();
            Printer* p = shd_new_printer_from_growy(g);

            const Node* size = type->payload.arr_type.size;
            // unsized arrays may decay to a plain element declarator
            if (!size && emitter->config.decay_unsized_arrays)
                return shd_c_emit_type(emitter, type->payload.arr_type.element_type, center);

            // arrays are wrapped in a struct so they get value semantics
            shd_print(p, "\n%s {", prefixed);
            shd_printer_indent(p);
            String inner_decl_rhs;
            if (size)
                inner_decl_rhs = shd_format_string_arena(emitter->arena->arena, "arr[%zu]", shd_get_int_literal_value(*shd_resolve_to_int_literal(size), false));
            else
                inner_decl_rhs = shd_format_string_arena(emitter->arena->arena, "arr[0]");
            shd_print(p, "\n%s;", shd_c_emit_type(emitter, type->payload.arr_type.element_type, inner_decl_rhs));
            shd_printer_deindent(p);
            shd_print(p, "\n};\n");
            shd_growy_append_bytes(g, 1, (char[]) { '\0' });

            String subdecl = shd_printer_growy_unwrap(p);
            shd_print(emitter->type_decls, subdecl);
            free_tmp_str(subdecl);

            // ditto from RecordType
            switch (emitter->config.dialect) {
                default:
                    emitted = prefixed;
                    break;
                case CDialect_GLSL:
                    break;
            }
            break;
        }
        case Type_PackType_TAG: {
            int width = type->payload.pack_type.width;
            const Type* element_type = type->payload.pack_type.element_type;
            switch (emitter->config.dialect) {
                case CDialect_CUDA: shd_error("TODO")
                case CDialect_GLSL: {
                    assert(is_glsl_scalar_type(element_type));
                    assert(width > 1);
                    String base;
                    switch (element_type->tag) {
                        case Bool_TAG: base = "bvec"; break;
                        case Int_TAG: base = "uvec"; break; // TODO not every int is 32-bit
                        case Float_TAG: base = "vec"; break;
                        default: shd_error("not a valid GLSL vector type");
                    }
                    emitted = shd_format_string_arena(emitter->arena->arena, "%s%d", base, width);
                    break;
                }
                case CDialect_ISPC: shd_error("Please lower to something else")
                case CDialect_C11: {
                    // use the GCC/Clang vector extension
                    emitted = shd_c_emit_type(emitter, element_type, NULL);
                    emitted = shd_format_string_arena(emitter->arena->arena, "__attribute__ ((vector_size (%d * sizeof(%s) ))) %s", width, emitted, emitted);
                    break;
                }
            }
            break;
        }
        case Type_TypeDeclRef_TAG: {
            // force the referenced declaration out first, then reuse its name
            shd_c_emit_decl(emitter, type->payload.type_decl_ref.decl);
            emitted = *shd_c_lookup_existing_type(emitter, type->payload.type_decl_ref.decl);
            goto type_goes_on_left;
        }
    }
    assert(emitted != NULL);
    shd_c_register_emitted_type(emitter, type, emitted);

    type_goes_on_left:
    assert(emitted != NULL);

    if (strlen(center) > 0)
        emitted = shd_format_string_arena(emitter->arena->arena, "%s %s", emitted, center);

    return emitted;
}
/* ======================================================================
 * emit_c_value.c — value emission for the shady C-family backends.
 * ====================================================================== */
#include "emit_c.h"

#include "portability.h"
#include "log.h"
#include "dict.h"
#include "util.h"

#include "../shady/ir_private.h"
#include "../shady/analysis/scheduler.h"

/* NOTE(review): the system header targets below were stripped during
 * extraction (angle-bracket contents lost); restore from upstream. */
#include

#include
#include
#include
#include
#include

#pragma GCC diagnostic error "-Wswitch"

static CTerm emit_instruction(Emitter* emitter, FnEmitter* fn, Printer* p, const Node* instruction);

// Prints the contents of an array literal into `p` (backed by `g`) and
// classifies how they should be wrapped by the caller:
//   StringLit  — i8 array ending in NUL → emit as "..." string literal
//   CharsLit   — i8 array without trailing NUL → emit as '...'
//   ObjectsList— anything else → comma-separated element list (NUL-terminated
//                in `g` by this helper)
static enum { ObjectsList, StringLit, CharsLit } array_insides_helper(Emitter* e, FnEmitter* fn, Printer* p, Growy* g, const Node* t, Nodes c) {
    if (t->tag == Int_TAG && t->payload.int_type.width == 8) {
        uint8_t* tmp = malloc(sizeof(uint8_t) * c.count);
        bool ends_zero = false;
        for (size_t i = 0; i < c.count; i++) {
            tmp[i] = shd_get_int_literal_value(*shd_resolve_to_int_literal(c.nodes[i]), false);
            if (tmp[i] == 0) {
                if (i == c.count - 1)
                    ends_zero = true;
            }
        }
        bool is_stringy = ends_zero;
        for (size_t i = 0; i < c.count; i++) {
            // ignore the last char in a string
            if (is_stringy && i == c.count - 1)
                break;
            if (isprint(tmp[i]))
                shd_print(p, "%c", tmp[i]);
            else
                shd_print(p, "\\x%02x", tmp[i]);
        }
        free(tmp);
        // NOTE(review): the stringy paths do not NUL-terminate `g` here,
        // while the ObjectsList path does — the caller reads shd_growy_data
        // in both cases; verify the printer terminates the buffer.
        return is_stringy ? StringLit : CharsLit;
    } else {
        for (size_t i = 0; i < c.count; i++) {
            shd_print(p, shd_c_to_ssa(e, shd_c_emit_value(e, fn, c.nodes[i])));
            if (i + 1 < c.count)
                shd_print(p, ", ");
        }
        shd_growy_append_bytes(g, 1, "\0");
        return ObjectsList;
    }
}

// Emits a shady value node as a C expression (CTerm). Instructions are
// delegated to emit_instruction; literals/composites are rendered inline,
// with dialect-specific syntax for compound literals.
static CTerm c_emit_value_(Emitter* emitter, FnEmitter* fn, Printer* p, const Node* value) {
    if (is_instruction(value))
        return emit_instruction(emitter, fn, p, value);

    String emitted = NULL;

    switch (is_value(value)) {
        case NotAValue: assert(false);
        case Value_ConstrainedValue_TAG:
        case Value_UntypedNumber_TAG: shd_error("lower me");
        case Param_TAG: shd_error("tried to emit a param: all params should be emitted by their binding abstraction !");
        default: {
            assert(!is_instruction(value));
            shd_error("Unhandled value for code generation: %s", shd_get_node_tag_string(value->tag));
        }
        case Value_IntLiteral_TAG: {
            if (value->payload.int_literal.is_signed)
                emitted = shd_format_string_arena(emitter->arena->arena, "%" PRIi64, value->payload.int_literal.value);
            else
                emitted = shd_format_string_arena(emitter->arena->arena, "%" PRIu64, value->payload.int_literal.value);

            bool is_long = value->payload.int_literal.width == IntTy64;
            bool is_signed = value->payload.int_literal.is_signed;
            // GLSL 1.30+ understands U/L literal suffixes
            if (emitter->config.dialect == CDialect_GLSL && emitter->config.glsl_version >= 130) {
                if (!is_signed)
                    emitted = shd_format_string_arena(emitter->arena->arena, "%sU", emitted);
                if (is_long)
                    emitted = shd_format_string_arena(emitter->arena->arena, "%sL", emitted);
            }

            break;
        }
        case Value_FloatLiteral_TAG: {
            uint64_t v = value->payload.float_literal.value;
            switch (value->payload.float_literal.width) {
                case FloatTy16:
                    // 16-bit floats have no literal spelling here
                    assert(false);
                case FloatTy32: {
                    float f;
                    memcpy(&f, &v, sizeof(uint32_t));
                    double d = (double) f;
                    emitted = shd_format_string_arena(emitter->arena->arena, "%#.9gf", d); break;
                }
                case FloatTy64: {
                    double d;
                    memcpy(&d, &v, sizeof(uint64_t));
                    emitted = shd_format_string_arena(emitter->arena->arena, "%.17g", d); break;
                }
            }
            break;
        }
        case Value_True_TAG: return term_from_cvalue("true");
        case Value_False_TAG: return term_from_cvalue("false");
        case Value_Undef_TAG: {
            // GLSL has no "undef": substitute the type's default value
            if (emitter->config.dialect == CDialect_GLSL)
                return shd_c_emit_value(emitter, fn, shd_get_default_value(emitter->arena, value->payload.undef.type));
            String name = shd_make_unique_name(emitter->arena, "undef");
            // c_emit_variable_declaration(emitter, block_printer, value->type, name, true, NULL);
            shd_c_emit_global_variable_definition(emitter, AsGlobal, name, value->payload.undef.type, true, NULL);
            emitted = name;
            break;
        }
        case Value_NullPtr_TAG: return term_from_cvalue("NULL");
        case Value_Composite_TAG: {
            const Type* type = value->payload.composite.type;
            Nodes elements = value->payload.composite.contents;

            Growy* g = shd_new_growy();
            // shadow `p`: composite contents are staged in their own printer
            Printer* p2 = p;
            Printer* p = shd_new_printer_from_growy(g);

            if (type->tag == ArrType_TAG) {
                switch (array_insides_helper(emitter, fn, p, g, type, elements)) {
                    case ObjectsList:
                        emitted = shd_growy_data(g);
                        break;
                    case StringLit:
                        emitted = shd_format_string_arena(emitter->arena->arena, "\"%s\"", shd_growy_data(g));
                        break;
                    case CharsLit:
                        emitted = shd_format_string_arena(emitter->arena->arena, "'%s'", shd_growy_data(g));
                        break;
                }
            } else {
                for (size_t i = 0; i < elements.count; i++) {
                    shd_print(p, "%s", shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, elements.nodes[i])));
                    if (i + 1 < elements.count)
                        shd_print(p, ", ");
                }
                emitted = shd_growy_data(g);
            }
            // NOTE(review): `emitted` may point into `g` here and this append
            // can realloc the buffer; verify the pointer is still valid (or
            // copied into the arena) before it is used below.
            shd_growy_append_bytes(g, 1, "\0");

            switch (emitter->config.dialect) {
                no_compound_literals:
                case CDialect_ISPC: {
                    // arrays need double the brackets
                    if (type->tag == ArrType_TAG)
                        emitted = shd_format_string_arena(emitter->arena->arena, "{ %s }", emitted);

                    if (p2) {
                        String tmp = shd_make_unique_name(emitter->arena, "composite");
                        shd_print(p2, "\n%s = { %s };", shd_c_emit_type(emitter, value->type, tmp), emitted);
                        emitted = tmp;
                    } else {
                        // this requires us to end up in the initialisation side of a declaration
                        emitted = shd_format_string_arena(emitter->arena->arena, "{ %s }", emitted);
                    }
                    break;
                }
                case CDialect_CUDA:
                case CDialect_C11:
                    // If we're C89 (ew)
                    if (!emitter->config.allow_compound_literals)
                        goto no_compound_literals;
                    emitted = shd_format_string_arena(emitter->arena->arena, "((%s) { %s })", shd_c_emit_type(emitter, value->type, NULL), emitted);
                    break;
                case CDialect_GLSL:
                    if (type->tag != PackType_TAG)
                        goto no_compound_literals;
                    // GLSL doesn't have compound literals, but it does have constructor syntax for vectors
                    emitted = shd_format_string_arena(emitter->arena->arena, "%s(%s)", shd_c_emit_type(emitter, value->type, NULL), emitted);
                    break;
            }

            shd_destroy_growy(g);
            shd_destroy_printer(p);
            break;
        }
        case Value_Fill_TAG: shd_error("lower me")
        case Value_StringLiteral_TAG: {
            Growy* g = shd_new_growy();
            Printer* p = shd_new_printer_from_growy(g);

            String str = value->payload.string_lit.string;
            size_t len = strlen(str);
            for (size_t i = 0; i < len; i++) {
                char c = str[i];
                switch (c) {
                    case '\n': shd_print(p, "\\n");
                        break;
                    default:
                        shd_growy_append_bytes(g, 1, &c);
                }
            }
            shd_growy_append_bytes(g, 1, "\0");

            emitted = shd_format_string_arena(emitter->arena->arena, "\"%s\"", shd_growy_data(g));
            shd_destroy_growy(g);
            shd_destroy_printer(p);
            break;
        }
        case Value_FnAddr_TAG: {
            emitted = shd_c_legalize_identifier(emitter, get_declaration_name(value->payload.fn_addr.fn));
            emitted = shd_format_string_arena(emitter->arena->arena, "(&%s)", emitted);
            break;
        }
        case Value_RefDecl_TAG: {
            const Node* decl = value->payload.ref_decl.decl;
            // make sure the referenced declaration has been emitted first
            shd_c_emit_decl(emitter, decl);

            if (emitter->config.dialect == CDialect_ISPC && decl->tag == GlobalVariable_TAG) {
                if (!shd_is_addr_space_uniform(emitter->arena, decl->payload.global_variable.address_space) && !shd_is_decl_builtin(
                    decl)) {
                    assert(fn && "ISPC backend cannot statically refer to a varying variable");
                    return shd_ispc_varying_ptr_helper(emitter, fn->instruction_printers[0], decl->type, *shd_c_lookup_existing_term(emitter, NULL, decl));
                }
            }

            return *shd_c_lookup_existing_term(emitter, NULL, decl);
        }
    }

    assert(emitted);
    return term_from_cvalue(emitted);
}

// Binds an intermediary result to a fresh local so it can be referenced more
// than once; void-typed results are emitted as a bare statement instead.
CTerm shd_c_bind_intermediary_result(Emitter* emitter, Printer* p, const Type* t, CTerm term) {
    if (is_term_empty(term))
        return term;
    if (t == empty_multiple_return_type(emitter->arena)) {
        shd_print(p, "%s;", shd_c_to_ssa(emitter, term));
        return empty_term();
    }
    String bind_to = shd_make_unique_name(emitter->arena, "");
    shd_c_emit_variable_declaration(emitter, p, t, bind_to, false, &term);
    return term_from_cvalue(bind_to);
}

// Returns the scalar type of the first operand, stripping qualifiers and any
// vector (pack) wrapper — used to select typed variants of intrinsics.
static const Type* get_first_op_scalar_type(Nodes ops) {
    const Type* t = shd_first(ops)->type;
    shd_deconstruct_qualified_type(&t);
    shd_deconstruct_maybe_packed_type(&t);
    return t;
}

// How an instruction-selection entry spells its operation.
typedef enum {
    OsInfix, OsPrefix, OsCall,
} OpStyle;

// Whether an entry has one spelling (IsMono) or per-type spellings (IsPoly).
typedef enum {
    IsNone, // empty entry
    IsMono,
    IsPoly
} ISelMechanism;
// One instruction-selection entry: either a single spelling (`op`) or
// per-width spellings for unsigned/signed integers and floats.
typedef struct {
    ISelMechanism isel_mechanism;
    OpStyle style;
    String op;
    String u_ops[4]; // unsigned int variants, indexed by IntTy width
    String s_ops[4]; // signed int variants, indexed by IntTy width
    String f_ops[3]; // float variants, indexed by FloatTy width
} ISelTableEntry;

static const ISelTableEntry isel_dummy = { IsNone };

// Base table: operators shared by every dialect.
static const ISelTableEntry isel_table[PRIMOPS_COUNT] = {
    [add_op] = { IsMono, OsInfix, "+" },
    [sub_op] = { IsMono, OsInfix, "-" },
    [mul_op] = { IsMono, OsInfix, "*" },
    [div_op] = { IsMono, OsInfix, "/" },
    [mod_op] = { IsMono, OsInfix, "%" },
    [neg_op] = { IsMono, OsPrefix, "-" },
    [gt_op]  = { IsMono, OsInfix, ">" },
    [gte_op] = { IsMono, OsInfix, ">=" },
    [lt_op]  = { IsMono, OsInfix, "<" },
    [lte_op] = { IsMono, OsInfix, "<=" },
    [eq_op]  = { IsMono, OsInfix, "==" },
    [neq_op] = { IsMono, OsInfix, "!=" },
    [and_op] = { IsMono, OsInfix, "&" },
    [or_op]  = { IsMono, OsInfix, "|" },
    [xor_op] = { IsMono, OsInfix, "^" },
    [not_op] = { IsMono, OsPrefix, "!" },
    /*[rshift_arithm_op] = { IsMono, OsInfix, ">>" },
    [rshift_logical_op] = { IsMono, OsInfix, ">>" }, // TODO achieve desired right shift semantics through unsigned/signed casts
    [lshift_op] = { IsMono, OsInfix, "<<" },*/
};

// C/CUDA: libm calls, with per-type function names.
static const ISelTableEntry isel_table_c[PRIMOPS_COUNT] = {
    [abs_op] = { IsPoly, OsCall, .s_ops = { "abs", "abs", "abs", "llabs" }, .f_ops = {"fabsf", "fabsf", "fabs"}},

    [sin_op] = { IsPoly, OsCall, .f_ops = {"sinf", "sinf", "sin"}},
    [cos_op] = { IsPoly, OsCall, .f_ops = {"cosf", "cosf", "cos"}},
    [floor_op] = { IsPoly, OsCall, .f_ops = {"floorf", "floorf", "floor"}},
    [ceil_op] = { IsPoly, OsCall, .f_ops = {"ceilf", "ceilf", "ceil"}},
    [round_op] = { IsPoly, OsCall, .f_ops = {"roundf", "roundf", "round"}},

    [sqrt_op] = { IsPoly, OsCall, .f_ops = {"sqrtf", "sqrtf", "sqrt"}},
    [exp_op] = { IsPoly, OsCall, .f_ops = {"expf", "expf", "exp"}},
    [pow_op] = { IsPoly, OsCall, .f_ops = {"powf", "powf", "pow"}},
};

// GLSL: built-ins are overloaded, so a single spelling suffices.
static const ISelTableEntry isel_table_glsl[PRIMOPS_COUNT] = {
    [abs_op] = { IsMono, OsCall, "abs" },

    [sin_op] = { IsMono, OsCall, "sin" },
    [cos_op] = { IsMono, OsCall, "cos" },
    [floor_op] = { IsMono, OsCall, "floor" },
    [ceil_op] = { IsMono, OsCall, "ceil" },
    [round_op] = { IsMono, OsCall, "round" },

    [sqrt_op] = { IsMono, OsCall, "sqrt" },
    [exp_op] = { IsMono, OsCall, "exp" },
    [pow_op] = { IsMono, OsCall, "pow" },
};

// GLSL 1.20 and older: no %, &, |, ^, ! operators on these types — use the
// function forms instead.
static const ISelTableEntry isel_table_glsl_120[PRIMOPS_COUNT] = {
    [mod_op] = { IsMono, OsCall, "mod" },

    [and_op] = { IsMono, OsCall, "and" },
    [ or_op] = { IsMono, OsCall, "or" },
    [xor_op] = { IsMono, OsCall, "xor" },
    [not_op] = { IsMono, OsCall, "not" },
};

// ISPC standard library spellings.
static const ISelTableEntry isel_table_ispc[PRIMOPS_COUNT] = {
    [abs_op] = { IsMono, OsCall, "abs" },

    [sin_op] = { IsMono, OsCall, "sin" },
    [cos_op] = { IsMono, OsCall, "cos" },
    [floor_op] = { IsMono, OsCall, "floor" },
    [ceil_op] = { IsMono, OsCall, "ceil" },
    [round_op] = { IsMono, OsCall, "round" },

    [sqrt_op] = { IsMono, OsCall, "sqrt" },
    [exp_op] = { IsMono, OsCall, "exp" },
    [pow_op] = { IsMono, OsCall, "pow" },
};

// Renders a primop through an instruction-selection entry: picks the spelling
// (per-type for IsPoly), then formats it as an infix/prefix operator or a
// function call over the emitted operands.
static CTerm emit_using_entry(Emitter* emitter, FnEmitter* fn, Printer* p, const ISelTableEntry* entry, Nodes operands) {
    String operator_str = NULL;
    switch (entry->isel_mechanism) {
        case IsNone: return empty_term();
        case IsMono: operator_str = entry->op; break;
        case IsPoly: {
            const Type* t = get_first_op_scalar_type(operands);
            if (t->tag == Float_TAG)
                operator_str = entry->f_ops[t->payload.float_type.width];
            else if (t->tag == Int_TAG && t->payload.int_type.is_signed)
                operator_str = entry->s_ops[t->payload.int_type.width];
            else if (t->tag == Int_TAG)
                operator_str = entry->u_ops[t->payload.int_type.width];
            break;
        }
    }

    if (!operator_str) {
        shd_log_fmt(ERROR, "emit_c: Missing or unsupported operands for this entry");
        return empty_term();
    }

    switch (entry->style) {
        case OsInfix: {
            CTerm a = shd_c_emit_value(emitter, fn, operands.nodes[0]);
            CTerm b = shd_c_emit_value(emitter, fn, operands.nodes[1]);
            return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "%s %s %s", shd_c_to_ssa(emitter, a), operator_str, shd_c_to_ssa(emitter, b)));
        }
        case OsPrefix: {
            CTerm operand = shd_c_emit_value(emitter, fn, operands.nodes[0]);
            return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "%s%s", operator_str, shd_c_to_ssa(emitter, operand)));
        }
        case OsCall: {
            LARRAY(CTerm, cops, operands.count);
            for (size_t i = 0; i < operands.count; i++)
                cops[i] = shd_c_emit_value(emitter, fn, operands.nodes[i]);
            if (operands.count == 1)
                return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "%s(%s)", operator_str, shd_c_to_ssa(emitter, cops[0])));
            else {
                Growy* g = shd_new_growy();
                shd_growy_append_string(g, operator_str);
                shd_growy_append_string_literal(g, "(");
                for (size_t i = 0; i < operands.count; i++) {
                    shd_growy_append_string(g, shd_c_to_ssa(emitter, cops[i]));
                    if (i + 1 < operands.count)
                        shd_growy_append_string_literal(g, ", ");
                }
                shd_growy_append_string_literal(g, ")");
                return term_from_cvalue(shd_growy_deconstruct(g));
            }
            break;
        }
    }

    SHADY_UNREACHABLE;
}

// Picks the isel entry for `op`: dialect-specific table first, base table as
// a fallback when the dialect has no entry.
// NOTE(review): for GLSL <= 1.20 the 1.20 table *replaces* the general GLSL
// table — ops like sin/cos then fall through to the base table (which lacks
// them) instead of isel_table_glsl. Confirm whether the fallback should chain
// through isel_table_glsl before hitting isel_table.
static const ISelTableEntry* lookup_entry(Emitter* emitter, Op op) {
    const ISelTableEntry* isel_entry = &isel_dummy;

    switch (emitter->config.dialect) {
        case CDialect_CUDA: /* TODO: do better than that */
        case CDialect_C11: isel_entry = &isel_table_c[op]; break;
        case CDialect_GLSL: isel_entry = &isel_table_glsl[op]; break;
        case CDialect_ISPC: isel_entry = &isel_table_ispc[op]; break;
    }

    if (emitter->config.dialect == CDialect_GLSL && emitter->config.glsl_version <= 120)
        isel_entry = &isel_table_glsl_120[op];

    if (isel_entry->isel_mechanism == IsNone)
        isel_entry = &isel_table[op];
    return isel_entry;
}
shd_format_string_arena(arena->arena, "int(%s)", shd_c_to_ssa(emitter, index)) : shd_c_to_ssa(emitter, index); + if (emitter->config.decay_unsized_arrays && !arr_type->payload.arr_type.size) + return shd_format_string_arena(arena->arena, "((&%s)[%s])", shd_c_deref(emitter, expr), index2); + else + return shd_format_string_arena(arena->arena, "(%s.arr[%s])", shd_c_deref(emitter, expr), index2); +} + +static CTerm broadcast_first(Emitter* emitter, CValue value, const Type* value_type) { + switch (emitter->config.dialect) { + case CDialect_ISPC: { + const Type* t = shd_get_unqualified_type(value_type); + if (t->tag == PtrType_TAG) + return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "extract_ptr(%s, count_trailing_zeros(lanemask()))", value)); + return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "extract(%s, count_trailing_zeros(lanemask()))", value)); + } + default: shd_error("TODO"); + } +} + +static CTerm emit_primop(Emitter* emitter, FnEmitter* fn, Printer* p, const Node* node) { + assert(node->tag == PrimOp_TAG); + IrArena* arena = emitter->arena; + const PrimOp* prim_op = &node->payload.prim_op; + CTerm term = term_from_cvalue(shd_fmt_string_irarena(emitter->arena, "/* todo %s */", shd_get_primop_name(prim_op->op))); + const ISelTableEntry* isel_entry = lookup_entry(emitter, prim_op->op); + switch (prim_op->op) { + case add_carry_op: + case sub_borrow_op: + case mul_extended_op: + shd_error("TODO: implement extended arithm ops in C"); + break; + // MATH OPS + case fract_op: { + CTerm floored = emit_using_entry(emitter, fn, p, lookup_entry(emitter, floor_op), prim_op->operands); + term = term_from_cvalue(shd_format_string_arena(arena->arena, "1 - %s", shd_c_to_ssa(emitter, floored))); + break; + } + case inv_sqrt_op: { + CTerm floored = emit_using_entry(emitter, fn, p, lookup_entry(emitter, sqrt_op), prim_op->operands); + term = term_from_cvalue(shd_format_string_arena(arena->arena, "1.0f / %s", shd_c_to_ssa(emitter, 
floored))); + break; + } + case min_op: { + CValue a = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, shd_first(prim_op->operands))); + CValue b = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[1])); + term = term_from_cvalue(shd_format_string_arena(arena->arena, "(%s > %s ? %s : %s)", a, b, b, a)); + break; + } + case max_op: { + CValue a = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, shd_first(prim_op->operands))); + CValue b = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[1])); + term = term_from_cvalue(shd_format_string_arena(arena->arena, "(%s > %s ? %s : %s)", a, b, a, b)); + break; + } + case sign_op: { + CValue src = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, shd_first(prim_op->operands))); + term = term_from_cvalue(shd_format_string_arena(arena->arena, "(%s > 0 ? 1 : -1)", src)); + break; + } + case fma_op: { + CValue a = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[0])); + CValue b = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[1])); + CValue c = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[2])); + switch (emitter->config.dialect) { + case CDialect_C11: + case CDialect_CUDA: { + term = term_from_cvalue(shd_format_string_arena(arena->arena, "fmaf(%s, %s, %s)", a, b, c)); + break; + } + default: { + term = term_from_cvalue(shd_format_string_arena(arena->arena, "(%s * %s) + %s", a, b, c)); + break; + } + } + break; + } + case lshift_op: + case rshift_arithm_op: + case rshift_logical_op: { + CValue src = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, shd_first(prim_op->operands))); + const Node* offset = prim_op->operands.nodes[1]; + CValue c_offset = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, offset)); + if (emitter->config.dialect == CDialect_GLSL) { + if (shd_get_unqualified_type(offset->type)->payload.int_type.width == IntTy64) + c_offset = 
shd_format_string_arena(arena->arena, "int(%s)", c_offset); + } + term = term_from_cvalue(shd_format_string_arena(arena->arena, "(%s %s %s)", src, prim_op->op == lshift_op ? "<<" : ">>", c_offset)); + break; + } + case size_of_op: + term = term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "sizeof(%s)", shd_c_emit_type(emitter, shd_first(prim_op->type_arguments), NULL))); + break; + case align_of_op: + term = term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "alignof(%s)", shd_c_emit_type(emitter, shd_first(prim_op->type_arguments), NULL))); + break; + case offset_of_op: { + const Type* t = shd_first(prim_op->type_arguments); + while (t->tag == TypeDeclRef_TAG) { + t = shd_get_nominal_type_body(t); + } + const Node* index = shd_first(prim_op->operands); + uint64_t index_literal = shd_get_int_literal_value(*shd_resolve_to_int_literal(index), false); + String member_name = shd_c_get_record_field_name(t, index_literal); + term = term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "offsetof(%s, %s)", shd_c_emit_type(emitter, t, NULL), member_name)); + break; + } case select_op: { + assert(prim_op->operands.count == 3); + CValue condition = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[0])); + CValue l = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[1])); + CValue r = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[2])); + term = term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "(%s) ? 
(%s) : (%s)", condition, l, r)); + break; + } + case convert_op: { + CTerm src = shd_c_emit_value(emitter, fn, shd_first(prim_op->operands)); + const Type* src_type = shd_get_unqualified_type(shd_first(prim_op->operands)->type); + const Type* dst_type = shd_first(prim_op->type_arguments); + if (emitter->config.dialect == CDialect_GLSL) { + if (is_glsl_scalar_type(src_type) && is_glsl_scalar_type(dst_type)) { + CType t = shd_c_emit_type(emitter, dst_type, NULL); + term = term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "%s(%s)", t, shd_c_to_ssa(emitter, src))); + } else + assert(false); + } else { + CType t = shd_c_emit_type(emitter, dst_type, NULL); + term = term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "((%s) %s)", t, shd_c_to_ssa(emitter, src))); + } + break; + } + case reinterpret_op: { + CTerm src_value = shd_c_emit_value(emitter, fn, shd_first(prim_op->operands)); + const Type* src_type = shd_get_unqualified_type(shd_first(prim_op->operands)->type); + const Type* dst_type = shd_first(prim_op->type_arguments); + switch (emitter->config.dialect) { + case CDialect_CUDA: + case CDialect_C11: { + String src = shd_make_unique_name(arena, "bitcast_src"); + String dst = shd_make_unique_name(arena, "bitcast_result"); + shd_print(p, "\n%s = %s;", shd_c_emit_type(emitter, src_type, src), shd_c_to_ssa(emitter, src_value)); + shd_print(p, "\n%s;", shd_c_emit_type(emitter, dst_type, dst)); + shd_print(p, "\nmemcpy(&%s, &%s, sizeof(%s));", dst, src, src); + return term_from_cvalue(dst); + } + // GLSL does not feature arbitrary casts, instead we need to run specialized conversion functions... + case CDialect_GLSL: { + String conv_fn = NULL; + if (dst_type->tag == Float_TAG) { + assert(src_type->tag == Int_TAG); + switch (dst_type->payload.float_type.width) { + case FloatTy16: break; + case FloatTy32: conv_fn = src_type->payload.int_type.is_signed ? 
"intBitsToFloat" : "uintBitsToFloat"; + break; + case FloatTy64: break; + } + } else if (dst_type->tag == Int_TAG) { + if (src_type->tag == Int_TAG) { + return src_value; + } + assert(src_type->tag == Float_TAG); + switch (src_type->payload.float_type.width) { + case FloatTy16: break; + case FloatTy32: conv_fn = dst_type->payload.int_type.is_signed ? "floatBitsToInt" : "floatBitsToUint"; + break; + case FloatTy64: break; + } + } + if (conv_fn) { + return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "%s(%s)", conv_fn, shd_c_to_ssa(emitter, src_value))); + } + shd_error_print("glsl: unsupported bit cast from "); + shd_log_node(ERROR, src_type); + shd_error_print(" to "); + shd_log_node(ERROR, dst_type); + shd_error_print(".\n"); + shd_error_die(); + } + case CDialect_ISPC: { + if (dst_type->tag == Float_TAG) { + assert(src_type->tag == Int_TAG); + String n; + switch (dst_type->payload.float_type.width) { + case FloatTy16: n = "float16bits"; + break; + case FloatTy32: n = "floatbits"; + break; + case FloatTy64: n = "doublebits"; + break; + } + return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "%s(%s)", n, shd_c_to_ssa(emitter, src_value))); + } else if (src_type->tag == Float_TAG) { + assert(dst_type->tag == Int_TAG); + return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "intbits(%s)", shd_c_to_ssa(emitter, src_value))); + } + + CType t = shd_c_emit_type(emitter, dst_type, NULL); + return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "((%s) %s)", t, shd_c_to_ssa(emitter, src_value))); + } + } + SHADY_UNREACHABLE; + } + case insert_op: + case extract_dynamic_op: + case extract_op: { + CValue acc = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, shd_first(prim_op->operands))); + bool insert = prim_op->op == insert_op; + + if (insert) { + String dst = shd_make_unique_name(arena, "modified"); + shd_print(p, "\n%s = %s;", shd_c_emit_type(emitter, node->type, dst), acc); + acc = dst; + term = 
term_from_cvalue(dst); + } + + const Type* t = shd_get_unqualified_type(shd_first(prim_op->operands)->type); + for (size_t i = (insert ? 2 : 1); i < prim_op->operands.count; i++) { + const Node* index = prim_op->operands.nodes[i]; + const IntLiteral* static_index = shd_resolve_to_int_literal(index); + + switch (is_type(t)) { + case Type_TypeDeclRef_TAG: { + const Node* decl = t->payload.type_decl_ref.decl; + assert(decl && decl->tag == NominalType_TAG); + t = decl->payload.nom_type.body; + SHADY_FALLTHROUGH + } + case Type_RecordType_TAG: { + assert(static_index); + Strings names = t->payload.record_type.names; + if (names.count == 0) + acc = shd_format_string_arena(emitter->arena->arena, "(%s._%d)", acc, static_index->value); + else + acc = shd_format_string_arena(emitter->arena->arena, "(%s.%s)", acc, names.strings[static_index->value]); + break; + } + case Type_PackType_TAG: { + assert(static_index); + assert(static_index->value < 4 && static_index->value < t->payload.pack_type.width); + String suffixes = "xyzw"; + acc = shd_format_string_arena(emitter->arena->arena, "(%s.%c)", acc, suffixes[static_index->value]); + break; + } + case Type_ArrType_TAG: { + acc = index_into_array(emitter, t, term_from_cvar(acc), shd_c_emit_value(emitter, fn, index)); + break; + } + default: + case NotAType: shd_error("Must be a type"); + } + } + + if (insert) { + shd_print(p, "\n%s = %s;", acc, shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[1]))); + break; + } + + term = term_from_cvalue(acc); + break; + } + case shuffle_op: { + String dst = shd_make_unique_name(arena, "shuffled"); + const Node* lhs = prim_op->operands.nodes[0]; + const Node* rhs = prim_op->operands.nodes[1]; + String lhs_e = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[0])); + String rhs_e = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[1])); + const Type* lhs_t = lhs->type; + const Type* rhs_t = rhs->type; + bool lhs_u = 
shd_deconstruct_qualified_type(&lhs_t); + bool rhs_u = shd_deconstruct_qualified_type(&rhs_t); + size_t left_size = lhs_t->payload.pack_type.width; + // size_t total_size = lhs_t->payload.pack_type.width + rhs_t->payload.pack_type.width; + String suffixes = "xyzw"; + shd_print(p, "\n%s = vec%d(", shd_c_emit_type(emitter, node->type, dst), prim_op->operands.count - 2); + for (size_t i = 2; i < prim_op->operands.count; i++) { + const IntLiteral* selector = shd_resolve_to_int_literal(prim_op->operands.nodes[i]); + if (selector->value < left_size) + shd_print(p, "%s.%c\n", lhs_e, suffixes[selector->value]); + else + shd_print(p, "%s.%c\n", rhs_e, suffixes[selector->value - left_size]); + if (i + 1 < prim_op->operands.count) + shd_print(p, ", "); + } + shd_print(p, ");\n"); + term = term_from_cvalue(dst); + break; + } + case subgroup_assume_uniform_op: { + if (emitter->config.dialect == CDialect_ISPC) { + CValue value = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, prim_op->operands.nodes[0])); + return broadcast_first(emitter, value, prim_op->operands.nodes[0]->type); + } + return shd_c_emit_value(emitter, fn, prim_op->operands.nodes[0]); + } + case empty_mask_op: + case mask_is_thread_active_op: shd_error("lower_me"); + default: break; + case PRIMOPS_COUNT: assert(false); break; + } + + if (isel_entry->isel_mechanism != IsNone) + return emit_using_entry(emitter, fn, p, isel_entry, prim_op->operands); + + return term; +} + +typedef struct { + String set; + SpvOp op; + size_t prefix_len; + uint32_t* prefix; +} ExtISelPattern; + +#define mk_prefix(...) 
.prefix_len = sizeof((uint32_t[]) {__VA_ARGS__}) / sizeof(uint32_t), .prefix = (uint32_t[]) {__VA_ARGS__} +#define subgroup_reduction mk_prefix(SpvScopeSubgroup, SpvGroupOperationReduce) + +typedef struct { + ExtISelPattern match; + ISelTableEntry payload; +} ExtISelEntry; + +ExtISelEntry ext_isel_ispc_entries[] = { + // reduce add + {{ "spirv.core", SpvOpGroupIAdd, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_add" }}, + {{ "spirv.core", SpvOpGroupFAdd, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_add" }}, + {{ "spirv.core", SpvOpGroupNonUniformIAdd, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_add" }}, + {{ "spirv.core", SpvOpGroupNonUniformFAdd, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_add" }}, + // min + {{ "spirv.core", SpvOpGroupSMin, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_min" }}, + {{ "spirv.core", SpvOpGroupUMin, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_min" }}, + {{ "spirv.core", SpvOpGroupFMin, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_min" }}, + {{ "spirv.core", SpvOpGroupNonUniformSMin, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_min" }}, + {{ "spirv.core", SpvOpGroupNonUniformUMin, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_min" }}, + {{ "spirv.core", SpvOpGroupNonUniformFMin, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_min" }}, + // max + {{ "spirv.core", SpvOpGroupSMax, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_max" }}, + {{ "spirv.core", SpvOpGroupUMax, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_max" }}, + {{ "spirv.core", SpvOpGroupFMax, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_max" }}, + {{ "spirv.core", SpvOpGroupNonUniformSMax, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_max" }}, + {{ "spirv.core", SpvOpGroupNonUniformUMax, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_max" }}, + {{ "spirv.core", SpvOpGroupNonUniformFMax, subgroup_reduction }, { IsMono, OsCall, .op = "reduce_max" }}, 
+ // rest + {{ "spirv.core", SpvOpGroupNonUniformAllEqual, mk_prefix(SpvScopeSubgroup) }, { IsMono, OsCall, .op = "reduce_equal" }}, + {{ "spirv.core", SpvOpGroupNonUniformBallot, mk_prefix(SpvScopeSubgroup) }, { IsMono, OsCall, .op = "packmask" }}, +}; + +ExtISelEntry ext_isel_entries[] = { + {{ "spirv.core", SpvOpGroupNonUniformBroadcastFirst, mk_prefix(SpvScopeSubgroup) }, { IsMono, OsCall, .op = "__shady_broadcast_first" }}, + {{ "spirv.core", SpvOpGroupNonUniformElect, mk_prefix(SpvScopeSubgroup) }, { IsMono, OsCall, .op = "__shady_elect_first" }}, +}; + +static bool check_ext_entry(const ExtISelPattern* entry, ExtInstr instr) { + if (strcmp(entry->set, instr.set) != 0 || entry->op != instr.opcode) + return false; + for (size_t i = 0; i < entry->prefix_len; i++) { + if (i >= instr.operands.count) + return false; + const IntLiteral* lit = shd_resolve_to_int_literal(instr.operands.nodes[i]); + if (!lit) + return false; + if (shd_get_int_literal_value(*lit, false) != entry->prefix[i]) + return false; + } + return true; +} + +static const ExtISelEntry* find_ext_entry_in_list(const ExtISelEntry table[], size_t size, ExtInstr instr) { + for (size_t i = 0; i < size; i++) { + if (check_ext_entry(&table[i].match, instr)) + return &table[i]; + } + return NULL; +} + +#define scan_entries(name) { const ExtISelEntry* f = find_ext_entry_in_list(name, sizeof(name) / sizeof(name[0]), instr); if (f) return f; } + +static const ExtISelEntry* find_ext_entry(Emitter* e, ExtInstr instr) { + switch (e->config.dialect) { + case CDialect_ISPC: scan_entries(ext_isel_ispc_entries); break; + default: break; + } + scan_entries(ext_isel_entries); + return NULL; +} + +static CTerm emit_ext_instruction(Emitter* emitter, FnEmitter* fn, Printer* p, ExtInstr instr) { + shd_c_emit_mem(emitter, fn, instr.mem); + if (strcmp(instr.set, "spirv.core") == 0) { + switch (instr.opcode) { + case SpvOpGroupNonUniformBroadcastFirst: { + assert(instr.operands.count == 2); + CValue value = 
shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, instr.operands.nodes[1])); + return broadcast_first(emitter, value, instr.operands.nodes[1]->type); + } + case SpvOpGroupNonUniformElect: { + assert(instr.operands.count == 1); + const IntLiteral* scope = shd_resolve_to_int_literal(shd_first(instr.operands)); + assert(scope && scope->value == SpvScopeSubgroup); + switch (emitter->config.dialect) { + case CDialect_ISPC: return term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "(programIndex == count_trailing_zeros(lanemask()))")); + default: break; + } + break; + } + default: break; + } + } + + const ExtISelEntry* entry = find_ext_entry(emitter, instr); + if (entry) { + Nodes operands = instr.operands; + if (entry->match.prefix_len > 0) + operands = shd_nodes(emitter->arena, operands.count - entry->match.prefix_len, &operands.nodes[entry->match.prefix_len]); + return emit_using_entry(emitter, fn, p, &entry->payload, operands); + } else { + shd_error("Unsupported extended instruction: (set = %s, opcode = %d )", instr.set, instr.opcode); + } +} + +static CTerm emit_call(Emitter* emitter, FnEmitter* fn, Printer* p, const Node* call) { + Call payload = call->payload.call; + shd_c_emit_mem(emitter, fn, payload.mem); + Nodes args; + if (call->tag == Call_TAG) + args = call->payload.call.args; + else + assert(false); + + Growy* g = shd_new_growy(); + Printer* paramsp = shd_new_printer_from_growy(g); + if (emitter->use_private_globals) { + shd_print(paramsp, "__shady_private_globals"); + if (args.count > 0) + shd_print(paramsp, ", "); + } + for (size_t i = 0; i < args.count; i++) { + shd_print(paramsp, shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, args.nodes[i]))); + if (i + 1 < args.count) + shd_print(paramsp, ", "); + } + + CValue e_callee; + const Node* callee = call->payload.call.callee; + if (callee->tag == FnAddr_TAG) + e_callee = shd_c_legalize_identifier(emitter, get_declaration_name(callee->payload.fn_addr.fn)); + else + e_callee = 
shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, callee)); + + String params = shd_printer_growy_unwrap(paramsp); + + CTerm called = term_from_cvalue(shd_format_string_arena(emitter->arena->arena, "\n%s(%s)", e_callee, params)); + called = shd_c_bind_intermediary_result(emitter, p, call->type, called); + + free_tmp_str(params); + return called; +} + +static CTerm emit_ptr_composite_element(Emitter* emitter, FnEmitter* fn, Printer* p, PtrCompositeElement lea) { + IrArena* arena = emitter->arena; + CTerm acc = shd_c_emit_value(emitter, fn, lea.ptr); + + const Type* src_qtype = lea.ptr->type; + bool uniform = shd_is_qualified_type_uniform(src_qtype); + const Type* curr_ptr_type = shd_get_unqualified_type(src_qtype); + assert(curr_ptr_type->tag == PtrType_TAG); + + const Type* pointee_type = shd_get_pointee_type(arena, curr_ptr_type); + const Node* selector = lea.index; + uniform &= shd_is_qualified_type_uniform(selector->type); + switch (is_type(pointee_type)) { + case ArrType_TAG: { + CTerm index = shd_c_emit_value(emitter, fn, selector); + acc = term_from_cvar(index_into_array(emitter, pointee_type, acc, index)); + curr_ptr_type = ptr_type(arena, (PtrType) { + .pointed_type = pointee_type->payload.arr_type.element_type, + .address_space = curr_ptr_type->payload.ptr_type.address_space + }); + break; + } + case TypeDeclRef_TAG: { + pointee_type = shd_get_nominal_type_body(pointee_type); + SHADY_FALLTHROUGH + } + case RecordType_TAG: { + // yet another ISPC bug and workaround + // ISPC cannot deal with subscripting if you've done pointer arithmetic (!) 
inside the expression + // so hum we just need to introduce a temporary variable to hold the pointer expression so far, and go again from there + // See https://github.com/ispc/ispc/issues/2496 + if (emitter->config.dialect == CDialect_ISPC) { + String interm = shd_make_unique_name(arena, "lea_intermediary_ptr_value"); + shd_print(p, "\n%s = %s;", shd_c_emit_type(emitter, shd_as_qualified_type(curr_ptr_type, uniform), interm), shd_c_to_ssa(emitter, acc)); + acc = term_from_cvalue(interm); + } + + assert(selector->tag == IntLiteral_TAG && "selectors when indexing into a record need to be constant"); + size_t static_index = shd_get_int_literal_value(*shd_resolve_to_int_literal(selector), false); + String field_name = shd_c_get_record_field_name(pointee_type, static_index); + acc = term_from_cvar(shd_format_string_arena(arena->arena, "(%s.%s)", shd_c_deref(emitter, acc), field_name)); + curr_ptr_type = ptr_type(arena, (PtrType) { + .pointed_type = pointee_type->payload.record_type.members.nodes[static_index], + .address_space = curr_ptr_type->payload.ptr_type.address_space + }); + break; + } + case Type_PackType_TAG: { + size_t static_index = shd_get_int_literal_value(*shd_resolve_to_int_literal(selector), false); + String suffixes = "xyzw"; + acc = term_from_cvar(shd_format_string_arena(emitter->arena->arena, "(%s.%c)", shd_c_deref(emitter, acc), suffixes[static_index])); + curr_ptr_type = ptr_type(arena, (PtrType) { + .pointed_type = pointee_type->payload.pack_type.element_type, + .address_space = curr_ptr_type->payload.ptr_type.address_space + }); + break; + } + default: shd_error("lea can't work on this"); + } + + // if (emitter->config.dialect == CDialect_ISPC) + // acc = c_bind_intermediary_result(emitter, p, curr_ptr_type, acc); + + return acc; +} + +static CTerm emit_ptr_array_element_offset(Emitter* emitter, FnEmitter* fn, Printer* p, PtrArrayElementOffset lea) { + IrArena* arena = emitter->arena; + CTerm acc = shd_c_emit_value(emitter, fn, lea.ptr); + + 
const Type* src_qtype = lea.ptr->type; + bool uniform = shd_is_qualified_type_uniform(src_qtype); + const Type* curr_ptr_type = shd_get_unqualified_type(src_qtype); + assert(curr_ptr_type->tag == PtrType_TAG); + + const IntLiteral* offset_static_value = shd_resolve_to_int_literal(lea.offset); + if (!offset_static_value || offset_static_value->value != 0) { + CTerm offset = shd_c_emit_value(emitter, fn, lea.offset); + // we sadly need to drop to the value level (aka explicit pointer arithmetic) to do this + // this means such code is never going to be legal in GLSL + // also the cast is to account for our arrays-in-structs hack + const Type* pointee_type = shd_get_pointee_type(arena, curr_ptr_type); + acc = term_from_cvalue(shd_format_string_arena(arena->arena, "((%s) &(%s)[%s])", shd_c_emit_type(emitter, curr_ptr_type, NULL), shd_c_to_ssa(emitter, acc), shd_c_to_ssa(emitter, offset))); + uniform &= shd_is_qualified_type_uniform(lea.offset->type); + } + + if (emitter->config.dialect == CDialect_ISPC) + acc = shd_c_bind_intermediary_result(emitter, p, curr_ptr_type, acc); + + return acc; +} + +static const Type* get_allocated_type(const Node* alloc) { + switch (alloc->tag) { + case Instruction_StackAlloc_TAG: return alloc->payload.stack_alloc.type; + case Instruction_LocalAlloc_TAG: return alloc->payload.local_alloc.type; + default: assert(false); return NULL; + } +} + +static CTerm emit_alloca(Emitter* emitter, Printer* p, const Type* instr) { + String variable_name = shd_make_unique_name(emitter->arena, "alloca"); + CTerm variable = (CTerm) { .value = NULL, .var = variable_name }; + shd_c_emit_variable_declaration(emitter, p, get_allocated_type(instr), variable_name, true, NULL); + const Type* ptr_type = instr->type; + shd_deconstruct_qualified_type(&ptr_type); + assert(ptr_type->tag == PtrType_TAG); + if (emitter->config.dialect == CDialect_ISPC && !ptr_type->payload.ptr_type.is_reference) { + variable = shd_ispc_varying_ptr_helper(emitter, p, 
shd_get_unqualified_type(instr->type), variable); + } + return variable; +} + +static CTerm emit_instruction(Emitter* emitter, FnEmitter* fn, Printer* p, const Node* instruction) { + assert(is_instruction(instruction)); + IrArena* a = emitter->arena; + + switch (is_instruction(instruction)) { + case NotAnInstruction: assert(false); + case Instruction_PushStack_TAG: + case Instruction_PopStack_TAG: + case Instruction_GetStackSize_TAG: + case Instruction_SetStackSize_TAG: + case Instruction_GetStackBaseAddr_TAG: shd_error("Stack operations need to be lowered."); + case Instruction_ExtInstr_TAG: return emit_ext_instruction(emitter, fn, p, instruction->payload.ext_instr); + case Instruction_PrimOp_TAG: return shd_c_bind_intermediary_result(emitter, p, instruction->type, emit_primop(emitter, fn, p, instruction)); + case Instruction_Call_TAG: return emit_call(emitter, fn, p, instruction); + case Instruction_Comment_TAG: shd_print(p, "/* %s */", instruction->payload.comment.string); return empty_term(); + case Instruction_StackAlloc_TAG: shd_c_emit_mem(emitter, fn, instruction->payload.local_alloc.mem); return emit_alloca(emitter, p, instruction); + case Instruction_LocalAlloc_TAG: shd_c_emit_mem(emitter, fn, instruction->payload.local_alloc.mem); return emit_alloca(emitter, p, instruction); + case Instruction_PtrArrayElementOffset_TAG: return emit_ptr_array_element_offset(emitter, fn, p, instruction->payload.ptr_array_element_offset); + case Instruction_PtrCompositeElement_TAG: return emit_ptr_composite_element(emitter, fn, p, instruction->payload.ptr_composite_element); + case Instruction_Load_TAG: { + Load payload = instruction->payload.load; + shd_c_emit_mem(emitter, fn, payload.mem); + CAddr dereferenced = shd_c_deref(emitter, shd_c_emit_value(emitter, fn, payload.ptr)); + return term_from_cvalue(dereferenced); + } + case Instruction_Store_TAG: { + Store payload = instruction->payload.store; + shd_c_emit_mem(emitter, fn, payload.mem); + const Type* addr_type = 
payload.ptr->type; + bool addr_uniform = shd_deconstruct_qualified_type(&addr_type); + bool value_uniform = shd_is_qualified_type_uniform(payload.value->type); + assert(addr_type->tag == PtrType_TAG); + CAddr dereferenced = shd_c_deref(emitter, shd_c_emit_value(emitter, fn, payload.ptr)); + CValue cvalue = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, payload.value)); + // ISPC lets you broadcast to a uniform address space iff the address is non-uniform, otherwise we need to do this + if (emitter->config.dialect == CDialect_ISPC && addr_uniform && shd_is_addr_space_uniform(a, addr_type->payload.ptr_type.address_space) && !value_uniform) + cvalue = shd_format_string_arena(emitter->arena->arena, "extract_ptr(%s, count_trailing_zeros(lanemask()))", cvalue); + + shd_print(p, "\n%s = %s;", dereferenced, cvalue); + return empty_term(); + } + case Instruction_CopyBytes_TAG: { + CopyBytes payload = instruction->payload.copy_bytes; + shd_c_emit_mem(emitter, fn, payload.mem); + shd_print(p, "\nmemcpy(%s, %s, %s);", shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, payload.dst)), shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, payload.src)), shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, payload.count))); + return empty_term(); + } + case Instruction_FillBytes_TAG:{ + FillBytes payload = instruction->payload.fill_bytes; + shd_c_emit_mem(emitter, fn, payload.mem); + shd_print(p, "\nmemset(%s, %s, %s);", shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, payload.dst)), shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, payload.src)), shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, payload.count))); + return empty_term(); + } + case Instruction_DebugPrintf_TAG: { + DebugPrintf payload = instruction->payload.debug_printf; + shd_c_emit_mem(emitter, fn, payload.mem); + Printer* args_printer = shd_new_printer_from_growy(shd_new_growy()); + shd_print(args_printer, "\""); + shd_printer_unescape(args_printer, instruction->payload.debug_printf.string); + 
shd_print(args_printer, "\""); + for (size_t i = 0; i < instruction->payload.debug_printf.args.count; i++) { + CValue str = shd_c_to_ssa(emitter, shd_c_emit_value(emitter, fn, instruction->payload.debug_printf.args.nodes[i])); + + // special casing for the special child + if (emitter->config.dialect == CDialect_ISPC) + shd_print(args_printer, ", extract(%s, printf_thread_index)", str); + else + shd_print(args_printer, ", %s", str); + } + String args_list = shd_printer_growy_unwrap(args_printer); + switch (emitter->config.dialect) { + case CDialect_ISPC:shd_print(p, "\nforeach_active(printf_thread_index) { print(%s); }", args_list); + break; + case CDialect_CUDA: + case CDialect_C11:shd_print(p, "\nprintf(%s);", args_list); + break; + case CDialect_GLSL: shd_warn_print("printf is not supported in GLSL"); + break; + } + free((char*) args_list); + + return empty_term(); + } + } + + SHADY_UNREACHABLE; +} + +static bool can_appear_at_top_level(Emitter* emitter, const Node* node) { + if (is_instruction(node)) + return false; + if (emitter->config.dialect == CDialect_ISPC) { + if (node->tag == RefDecl_TAG) { + const Node* decl = node->payload.ref_decl.decl; + if (decl->tag == GlobalVariable_TAG) + if (!shd_is_addr_space_uniform(emitter->arena, decl->payload.global_variable.address_space) && !shd_is_decl_builtin( + decl)) + //if (is_value(node) && !is_qualified_type_uniform(node->type)) + return false; + } + } + return true; +} + +CTerm shd_c_emit_value(Emitter* emitter, FnEmitter* fn_builder, const Node* node) { + CTerm* found = shd_c_lookup_existing_term(emitter, fn_builder, node); + if (found) return *found; + + CFNode* where = fn_builder ? 
shd_schedule_instruction(fn_builder->scheduler, node) : NULL; + if (where) { + CTerm emitted = c_emit_value_(emitter, fn_builder, fn_builder->instruction_printers[where->rpo_index], node); + shd_c_register_emitted(emitter, fn_builder, node, emitted); + return emitted; + } else if (!can_appear_at_top_level(emitter, node)) { + if (!fn_builder) { + shd_log_node(ERROR, node); + shd_log_fmt(ERROR, "cannot appear at top-level"); + exit(-1); + } + // Pick the entry block of the current fn + CTerm emitted = c_emit_value_(emitter, fn_builder, fn_builder->instruction_printers[0], node); + shd_c_register_emitted(emitter, fn_builder, node, emitted); + return emitted; + } else { + assert(!is_mem(node)); + CTerm emitted = c_emit_value_(emitter, NULL, NULL, node); + shd_c_register_emitted(emitter, NULL, node, emitted); + return emitted; + } +} + +CTerm shd_c_emit_mem(Emitter* e, FnEmitter* b, const Node* mem) { + assert(is_mem(mem)); + if (mem->tag == AbsMem_TAG) + return empty_term(); + if (is_instruction(mem)) + return shd_c_emit_value(e, b, mem); + shd_error("What sort of mem is this ?"); +} diff --git a/src/backend/c/prelude.cu b/src/backend/c/prelude.cu new file mode 100644 index 000000000..4f5279034 --- /dev/null +++ b/src/backend/c/prelude.cu @@ -0,0 +1,2 @@ +#define __shady_make_thread_local(var) var[__shady_workgroup_size] +#define __shady_thread_local_access(var) (var[((threadIdx.x * blockDim.y + threadIdx.y) * blockDim.z + threadIdx.z)]) diff --git a/src/backend/c/runtime.cu b/src/backend/c/runtime.cu new file mode 100644 index 000000000..2a4c6eea7 --- /dev/null +++ b/src/backend/c/runtime.cu @@ -0,0 +1,31 @@ +__shared__ uvec3 __shady_make_thread_local(RealGlobalInvocationId); +__shared__ uvec3 __shady_make_thread_local(RealLocalInvocationId); + +#define GlobalInvocationId __shady_thread_local_access(RealGlobalInvocationId) +#define LocalInvocationId __shady_thread_local_access(RealLocalInvocationId) + +__device__ void __shady_prepare_builtins() { + 
LocalInvocationId.arr[0] = threadIdx.x; + LocalInvocationId.arr[1] = threadIdx.y; + LocalInvocationId.arr[2] = threadIdx.z; + GlobalInvocationId.arr[0] = threadIdx.x + blockDim.x * blockIdx.x; + GlobalInvocationId.arr[1] = threadIdx.y + blockDim.y * blockIdx.y; + GlobalInvocationId.arr[2] = threadIdx.z + blockDim.z * blockIdx.z; +} + +__device__ bool __shady_elect_first() { + unsigned int writemask = __activemask(); + // Find the lowest-numbered active lane + int elected_lane = __ffs(writemask) - 1; + return threadIdx.x == __shfl_sync(writemask, threadIdx.x, elected_lane) + && threadIdx.y == __shfl_sync(writemask, threadIdx.y, elected_lane) + && threadIdx.z == __shfl_sync(writemask, threadIdx.z, elected_lane); +} + +template +__device__ T __shady_broadcast_first(T t) { + unsigned int writemask = __activemask(); + // Find the lowest-numbered active lane + int elected_lane = __ffs(writemask) - 1; + return __shfl_sync(writemask, t, elected_lane); +} diff --git a/src/backend/c/runtime.ispc b/src/backend/c/runtime.ispc new file mode 100644 index 000000000..cd6b3cba9 --- /dev/null +++ b/src/backend/c/runtime.ispc @@ -0,0 +1,4 @@ +template +T uniform* uniform extract_ptr(uniform T* varying t, uniform int i) { + return (uniform T* uniform) extract((intptr_t) t, i); +} diff --git a/src/backend/c/runtime_120.glsl b/src/backend/c/runtime_120.glsl new file mode 100644 index 000000000..d794ae47e --- /dev/null +++ b/src/backend/c/runtime_120.glsl @@ -0,0 +1,117 @@ +int mod(int a, int b) { + int q = a / b; + int qb = q * b; + return a - qb; +} + +int rshift1(int word) { + return word / 2; +} + +int lshift1(int word) { + int capped = mod(word, 32768); + return capped * 2; +} + +bool extract_bit(int word, int pos) { + int shifted = word; + for (int i = 0; i < pos; i++) { + shifted = rshift1(shifted); + } + return mod(shifted, 2) == 1; +} + +const int bits[16] = int[16](0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000); + +int 
set_bit(int word, int pos, bool value) { + int result = 0; + for (int i = 0; i < 16; i++) { + bool set; + if (i == pos) + set = value; + else + set = extract_bit(word, i); + if (set) + result += bits[i]; + } + return result; + //if (value) { + // return mod(unset + bits[pos], 65536); + //} + //return unset; +} + +int and(int a, int b) { + int shifteda = a; + int shiftedb = b; + int result = 0; + for (int i = 0; i < 16; i++) { + bool ba = mod(shifteda, 2) == 1; + bool bb = mod(shiftedb, 2) == 1; + bool br = ba && bb; + + if (br) + result += bits[i]; + + shifteda = rshift1(shifteda); + shiftedb = rshift1(shiftedb); + } + return result; +} + +int or(int a, int b) { + int shifteda = a; + int shiftedb = b; + int result = 0; + for (int i = 0; i < 16; i++) { + bool ba = mod(shifteda, 2) == 1; + bool bb = mod(shiftedb, 2) == 1; + bool br = ba || bb; + + if (br) + result += bits[i]; + + shifteda = rshift1(shifteda); + shiftedb = rshift1(shiftedb); + } + return result; +} + +int xor(int a, int b) { + int shifteda = a; + int shiftedb = b; + int result = 0; + for (int i = 0; i < 16; i++) { + bool ba = mod(shifteda, 2) == 1; + bool bb = mod(shiftedb, 2) == 1; + bool br = ba ^^ bb; + + if (br) + result += bits[i]; + + shifteda = rshift1(shifteda); + shiftedb = rshift1(shiftedb); + } + return result; +} + +int not(int a) { + int shifteda = a; + int result = 0; + for (int i = 0; i < 16; i++) { + bool ba = mod(shifteda, 2) == 1; + bool br = !ba; + + if (br) + result += bits[i]; + + shifteda = rshift1(shifteda); + } + return result; +} + +bool and(bool a, bool b) { return a && b; } +bool or(bool a, bool b) { return a || b; } +bool xor(bool a, bool b) { return a ^^ b; } +bool not(bool a) { return !a; } + diff --git a/src/backend/spirv/CMakeLists.txt b/src/backend/spirv/CMakeLists.txt new file mode 100644 index 000000000..e99e21ff6 --- /dev/null +++ b/src/backend/spirv/CMakeLists.txt @@ -0,0 +1,19 @@ +add_library(shady_spirv STATIC + emit_spv.c + emit_spv_type.c + emit_spv_value.c + 
emit_spv_control_flow.c + spirv_lift_globals_ssbo.c + spirv_map_entrypoint_args.c + spirv_builder.c +) +set_property(TARGET shady_spirv PROPERTY POSITION_INDEPENDENT_CODE ON) + +target_include_directories(shady_spirv PRIVATE $) + +target_link_libraries(shady_spirv PRIVATE "api") +target_link_libraries(shady_spirv INTERFACE "$") +target_link_libraries(shady_spirv PRIVATE "$") +target_link_libraries(shady_spirv PRIVATE "$") + +target_link_libraries(driver PUBLIC "$") diff --git a/src/backend/spirv/emit_spv.c b/src/backend/spirv/emit_spv.c new file mode 100644 index 000000000..bf3ae9821 --- /dev/null +++ b/src/backend/spirv/emit_spv.c @@ -0,0 +1,373 @@ +#include "emit_spv.h" + +#include "shady/ir/builtin.h" + +#include "../shady/ir_private.h" +#include "../shady/analysis/cfg.h" +#include "../shady/passes/passes.h" +#include "../shady/analysis/scheduler.h" + +#include "list.h" +#include "dict.h" +#include "log.h" +#include "portability.h" +#include "growy.h" + +#include +#include +#include + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +KeyHash shd_hash_string(const char** string); +bool shd_compare_string(const char** a, const char** b); + +#pragma GCC diagnostic error "-Wswitch" + +void spv_register_emitted(Emitter* emitter, FnBuilder* fn_builder, const Node* node, SpvId id) { + if (is_value(node)) { + String name = shd_get_value_name_unsafe(node); + if (name) + spvb_name(emitter->file_builder, id, name); + } + struct Dict* map = fn_builder ? 
fn_builder->emitted : emitter->global_node_ids; + shd_dict_insert_get_result(struct Node*, SpvId, map, node, id); +} + +SpvId* spv_search_emitted(Emitter* emitter, FnBuilder* fn_builder, const Node* node) { + SpvId* found = NULL; + if (fn_builder) + found = shd_dict_find_value(const Node*, SpvId, fn_builder->emitted, node); + if (!found) + found = shd_dict_find_value(const Node*, SpvId, emitter->global_node_ids, node); + return found; +} + +SpvId spv_find_emitted(Emitter* emitter, FnBuilder* fn_builder, const Node* node) { + SpvId* found = spv_search_emitted(emitter, fn_builder, node); + return *found; +} + +static void emit_basic_block(Emitter* emitter, FnBuilder* fn_builder, const CFNode* cf_node) { + const Node* bb_node = cf_node->node; + assert(is_basic_block(bb_node) || cf_node == fn_builder->cfg->entry); + + const Node* body = get_abstraction_body(bb_node); + + // Find the preassigned ID to this + BBBuilder bb_builder = spv_find_basic_block_builder(emitter, bb_node); + SpvId bb_id = spvb_get_block_builder_id(bb_builder); + spvb_add_bb(fn_builder->base, bb_builder); + + String name = shd_get_abstraction_name_safe(bb_node); + if (name) + spvb_name(emitter->file_builder, bb_id, name); + + spv_emit_terminator(emitter, fn_builder, bb_builder, bb_node, body); + + for (size_t i = 0; i < shd_list_count(cf_node->dominates); i++) { + CFNode* dominated = shd_read_list(CFNode*, cf_node->dominates)[i]; + emit_basic_block(emitter, fn_builder, dominated); + } + + if (fn_builder->per_bb[cf_node->rpo_index].continue_builder) + spvb_add_bb(fn_builder->base, fn_builder->per_bb[cf_node->rpo_index].continue_builder); +} + +static void emit_function(Emitter* emitter, const Node* node) { + assert(node->tag == Function_TAG); + + const Type* fn_type = node->type; + SpvId fn_id = spv_find_emitted(emitter, NULL, node); + FnBuilder fn_builder = { + .base = spvb_begin_fn(emitter->file_builder, fn_id, spv_emit_type(emitter, fn_type), spv_types_to_codom(emitter, 
node->payload.fun.return_types)), + .emitted = shd_new_dict(Node*, SpvId, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .cfg = build_fn_cfg(node), + }; + fn_builder.scheduler = shd_new_scheduler(fn_builder.cfg); + fn_builder.per_bb = calloc(sizeof(*fn_builder.per_bb), fn_builder.cfg->size); + + Nodes params = node->payload.fun.params; + for (size_t i = 0; i < params.count; i++) { + const Node* param = params.nodes[i]; + const Type* param_type = param->payload.param.type; + SpvId param_id = spvb_parameter(fn_builder.base, spv_emit_type(emitter, param_type)); + spv_register_emitted(emitter, false, param, param_id); + shd_deconstruct_qualified_type(¶m_type); + if (param_type->tag == PtrType_TAG && param_type->payload.ptr_type.address_space == AsGlobal) { + spvb_decorate(emitter->file_builder, param_id, SpvDecorationAliased, 0, NULL); + } + } + + if (node->payload.fun.body) { + // reserve a bunch of identifiers for the basic blocks in the CFG + for (size_t i = 0; i < fn_builder.cfg->size; i++) { + CFNode* cfnode = fn_builder.cfg->rpo[i]; + assert(cfnode); + const Node* bb = cfnode->node; + assert(is_basic_block(bb) || bb == node); + SpvId bb_id = spvb_fresh_id(emitter->file_builder); + BBBuilder basic_block_builder = spvb_begin_bb(fn_builder.base, bb_id); + shd_dict_insert(const Node*, BBBuilder, emitter->bb_builders, bb, basic_block_builder); + // add phis for every non-entry basic block + if (i > 0) { + assert(is_basic_block(bb) && bb != node); + Nodes bb_params = bb->payload.basic_block.params; + for (size_t j = 0; j < bb_params.count; j++) { + const Node* bb_param = bb_params.nodes[j]; + SpvId phi_id = spvb_fresh_id(emitter->file_builder); + spvb_add_phi(basic_block_builder, spv_emit_type(emitter, bb_param->type), phi_id); + spv_register_emitted(emitter, false, bb_param, phi_id); + } + // also make sure to register the label for basic blocks + spv_register_emitted(emitter, false, bb, bb_id); + } + } + emit_basic_block(emitter, &fn_builder, 
fn_builder.cfg->entry); + + spvb_define_function(emitter->file_builder, fn_builder.base); + } else { + Growy* g = shd_new_growy(); + spvb_literal_name(g, shd_get_abstraction_name(node)); + shd_growy_append_bytes(g, 4, (char*) &(uint32_t) { SpvLinkageTypeImport }); + spvb_decorate(emitter->file_builder, fn_id, SpvDecorationLinkageAttributes, shd_growy_size(g) / 4, (uint32_t*) shd_growy_data(g)); + shd_destroy_growy(g); + spvb_declare_function(emitter->file_builder, fn_builder.base); + } + + free(fn_builder.per_bb); + shd_destroy_scheduler(fn_builder.scheduler); + shd_destroy_cfg(fn_builder.cfg); + shd_destroy_dict(fn_builder.emitted); +} + +SpvId spv_emit_decl(Emitter* emitter, const Node* decl) { + SpvId* existing = shd_dict_find_value(const Node*, SpvId, emitter->global_node_ids, decl); + if (existing) + return *existing; + + switch (is_declaration(decl)) { + case GlobalVariable_TAG: { + const GlobalVariable* gvar = &decl->payload.global_variable; + SpvId given_id = spvb_fresh_id(emitter->file_builder); + spv_register_emitted(emitter, NULL, decl, given_id); + spvb_name(emitter->file_builder, given_id, gvar->name); + SpvId init = 0; + if (gvar->init) + init = spv_emit_value(emitter, NULL, gvar->init); + SpvStorageClass storage_class = spv_emit_addr_space(emitter, gvar->address_space); + spvb_global_variable(emitter->file_builder, given_id, spv_emit_type(emitter, decl->type), storage_class, false, init); + + Builtin b = BuiltinsCount; + for (size_t i = 0; i < gvar->annotations.count; i++) { + const Node* a = gvar->annotations.nodes[i]; + assert(is_annotation(a)); + String name = get_annotation_name(a); + if (strcmp(name, "Builtin") == 0) { + String builtin_name = shd_get_annotation_string_payload(a); + assert(builtin_name); + assert(b == BuiltinsCount && "Only one @Builtin annotation permitted."); + b = shd_get_builtin_by_name(builtin_name); + assert(b != BuiltinsCount); + SpvBuiltIn d = shd_get_builtin_spv_id(b); + uint32_t decoration_payload[] = { d }; + 
spvb_decorate(emitter->file_builder, given_id, SpvDecorationBuiltIn, 1, decoration_payload); + } else if (strcmp(name, "Location") == 0) { + size_t loc = shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_annotation_value(a)), false); + assert(loc >= 0); + spvb_decorate(emitter->file_builder, given_id, SpvDecorationLocation, 1, (uint32_t[]) { loc }); + } else if (strcmp(name, "DescriptorSet") == 0) { + size_t loc = shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_annotation_value(a)), false); + assert(loc >= 0); + spvb_decorate(emitter->file_builder, given_id, SpvDecorationDescriptorSet, 1, (uint32_t[]) { loc }); + } else if (strcmp(name, "DescriptorBinding") == 0) { + size_t loc = shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_annotation_value(a)), false); + assert(loc >= 0); + spvb_decorate(emitter->file_builder, given_id, SpvDecorationBinding, 1, (uint32_t[]) { loc }); + } + } + + switch (storage_class) { + case SpvStorageClassPushConstant: { + break; + } + case SpvStorageClassStorageBuffer: + case SpvStorageClassUniform: + case SpvStorageClassUniformConstant: { + const Node* descriptor_set = shd_lookup_annotation(decl, "DescriptorSet"); + const Node* descriptor_binding = shd_lookup_annotation(decl, "DescriptorBinding"); + assert(descriptor_set && descriptor_binding && "DescriptorSet and/or DescriptorBinding annotations are missing"); + break; + } + default: break; + } + + return given_id; + } case Function_TAG: { + SpvId given_id = spvb_fresh_id(emitter->file_builder); + spv_register_emitted(emitter, NULL, decl, given_id); + spvb_name(emitter->file_builder, given_id, decl->payload.fun.name); + emit_function(emitter, decl); + return given_id; + } case Constant_TAG: { + // We don't emit constants at all ! + // With RefDecl, we directly grab the underlying value and emit that there and then. 
+ // Emitting constants as their own IDs would be nicer, but it's painful to do because decls need their ID to be reserved in advance, + // but we also desire to cache reused values instead of emitting them multiple times. This means we can't really "force" an ID for a given value. + // The ideal fix would be if SPIR-V offered a way to "alias" an ID under a new one. This would allow applying new debug information to the decl ID, separate from the other instances of that value. + return 0; + } case NominalType_TAG: { + SpvId given_id = spvb_fresh_id(emitter->file_builder); + spv_register_emitted(emitter, NULL, decl, given_id); + spvb_name(emitter->file_builder, given_id, decl->payload.nom_type.name); + spv_emit_nominal_type_body(emitter, decl->payload.nom_type.body, given_id); + return given_id; + } + case NotADeclaration: shd_error(""); + } + shd_error("unreachable"); +} + +static SpvExecutionModel emit_exec_model(ExecutionModel model) { + switch (model) { + case EmCompute: return SpvExecutionModelGLCompute; + case EmVertex: return SpvExecutionModelVertex; + case EmFragment: return SpvExecutionModelFragment; + case EmNone: shd_error("No execution model but we were asked to emit it anyways"); + } +} + +static void emit_entry_points(Emitter* emitter, Nodes declarations) { + // First, collect all the global variables, they're needed for the interface section of OpEntryPoint + // it can be a superset of the ones actually used, so the easiest option is to just grab _all_ global variables and shove them in there + // my gut feeling says it's unlikely any drivers actually care, but validation needs to be happy so here we go... + LARRAY(SpvId, interface_arr, declarations.count); + size_t interface_size = 0; + for (size_t i = 0; i < declarations.count; i++) { + const Node* node = declarations.nodes[i]; + if (node->tag != GlobalVariable_TAG) continue; + // Prior to SPIRV 1.4, _only_ input and output variables should be found here. 
+ if (emitter->configuration->target_spirv_version.major == 1 && + emitter->configuration->target_spirv_version.minor < 4) { + switch (node->payload.global_variable.address_space) { + case AsOutput: + case AsInput: break; + default: continue; + } + } + interface_arr[interface_size++] = spv_find_emitted(emitter, NULL, node); + } + + for (size_t i = 0; i < declarations.count; i++) { + const Node* decl = declarations.nodes[i]; + if (decl->tag != Function_TAG) continue; + SpvId fn_id = spv_find_emitted(emitter, NULL, decl); + + const Node* entry_point = shd_lookup_annotation(decl, "EntryPoint"); + if (entry_point) { + ExecutionModel execution_model = shd_execution_model_from_string(shd_get_string_literal(emitter->arena, shd_get_annotation_value(entry_point))); + assert(execution_model != EmNone); + + spvb_entry_point(emitter->file_builder, emit_exec_model(execution_model), fn_id, decl->payload.fun.name, interface_size, interface_arr); + emitter->num_entry_pts++; + + const Node* workgroup_size = shd_lookup_annotation(decl, "WorkgroupSize"); + if (execution_model == EmCompute) + assert(workgroup_size); + if (workgroup_size) { + Nodes values = shd_get_annotation_values(workgroup_size); + assert(values.count == 3); + uint32_t wg_x_dim = (uint32_t) shd_get_int_literal_value(*shd_resolve_to_int_literal(values.nodes[0]), false); + uint32_t wg_y_dim = (uint32_t) shd_get_int_literal_value(*shd_resolve_to_int_literal(values.nodes[1]), false); + uint32_t wg_z_dim = (uint32_t) shd_get_int_literal_value(*shd_resolve_to_int_literal(values.nodes[2]), false); + + spvb_execution_mode(emitter->file_builder, fn_id, SpvExecutionModeLocalSize, 3, (uint32_t[3]) { wg_x_dim, wg_y_dim, wg_z_dim }); + } + + if (execution_model == EmFragment) { + spvb_execution_mode(emitter->file_builder, fn_id, SpvExecutionModeOriginUpperLeft, 0, NULL); + } + } + } +} + +static void emit_decls(Emitter* emitter, Nodes declarations) { + for (size_t i = 0; i < declarations.count; i++) { + const Node* decl = 
declarations.nodes[i]; + spv_emit_decl(emitter, decl); + } +} + +SpvId spv_get_extended_instruction_set(Emitter* emitter, const char* name) { + SpvId* found = shd_dict_find_value(const char*, SpvId, emitter->extended_instruction_sets, name); + if (found) + return *found; + + SpvId new = spvb_extended_import(emitter->file_builder, name); + shd_dict_insert(const char*, SpvId, emitter->extended_instruction_sets, name, new); + return new; +} + +RewritePass shd_spvbe_pass_map_entrypoint_args; +RewritePass shd_spvbe_pass_lift_globals_ssbo; + +static Module* run_backend_specific_passes(const CompilerConfig* config, Module* initial_mod) { + IrArena* initial_arena = initial_mod->arena; + Module** pmod = &initial_mod; + + RUN_PASS(shd_pass_lower_entrypoint_args) + RUN_PASS(shd_spvbe_pass_map_entrypoint_args) + RUN_PASS(shd_spvbe_pass_lift_globals_ssbo) + RUN_PASS(shd_pass_eliminate_constants) + RUN_PASS(shd_import) + + return *pmod; +} + +void shd_emit_spirv(const CompilerConfig* config, Module* mod, size_t* output_size, char** output, Module** new_mod) { + IrArena* initial_arena = shd_module_get_arena(mod); + mod = run_backend_specific_passes(config, mod); + IrArena* arena = shd_module_get_arena(mod); + + FileBuilder file_builder = spvb_begin(); + spvb_set_version(file_builder, config->target_spirv_version.major, config->target_spirv_version.minor); + spvb_set_addressing_model(file_builder, SpvAddressingModelLogical); + + Emitter emitter = { + .module = mod, + .arena = arena, + .configuration = config, + .file_builder = file_builder, + .global_node_ids = shd_new_dict(Node*, SpvId, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .bb_builders = shd_new_dict(Node*, BBBuilder, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .num_entry_pts = 0, + }; + + emitter.extended_instruction_sets = shd_new_dict(const char*, SpvId, (HashFn) shd_hash_string, (CmpFn) shd_compare_string); + + emitter.void_t = spvb_void_type(emitter.file_builder); + + spvb_extension(file_builder, 
"SPV_KHR_non_semantic_info"); + + Nodes decls = shd_module_get_declarations(mod); + emit_decls(&emitter, decls); + emit_entry_points(&emitter, decls); + + if (emitter.num_entry_pts == 0) + spvb_capability(file_builder, SpvCapabilityLinkage); + + spvb_capability(file_builder, SpvCapabilityShader); + + *output_size = spvb_finish(file_builder, output); + + // cleanup the emitter + shd_destroy_dict(emitter.global_node_ids); + shd_destroy_dict(emitter.bb_builders); + shd_destroy_dict(emitter.extended_instruction_sets); + + if (new_mod) + *new_mod = mod; + else if (initial_arena != arena) + shd_destroy_ir_arena(arena); +} diff --git a/src/backend/spirv/emit_spv.h b/src/backend/spirv/emit_spv.h new file mode 100644 index 000000000..2c70f0498 --- /dev/null +++ b/src/backend/spirv/emit_spv.h @@ -0,0 +1,63 @@ +#ifndef SHADY_EMIT_SPIRV_H +#define SHADY_EMIT_SPIRV_H + +#include "shady/ir.h" +#include "shady/be/spirv.h" + +#include "spirv_builder.h" + +typedef struct CFG_ CFG; +typedef struct Scheduler_ Scheduler; + +typedef SpvbFileBuilder* FileBuilder; +typedef SpvbBasicBlockBuilder* BBBuilder; + +typedef struct { + SpvbFnBuilder* base; + CFG* cfg; + Scheduler* scheduler; + struct Dict* emitted; + struct { + SpvId continue_id; + BBBuilder continue_builder; + }* per_bb; +} FnBuilder; + +typedef struct Emitter_ { + Module* module; + IrArena* arena; + const CompilerConfig* configuration; + FileBuilder file_builder; + SpvId void_t; + struct Dict* global_node_ids; + + struct Dict* bb_builders; + + size_t num_entry_pts; + + struct Dict* extended_instruction_sets; +} Emitter; + +typedef SpvbPhi** Phis; + +SpvId spv_emit_decl(Emitter*, const Node*); +SpvId spv_emit_type(Emitter*, const Type*); +SpvId spv_emit_value(Emitter*, FnBuilder*, const Node*); +SpvId spv_emit_mem(Emitter*, FnBuilder*, const Node*); +void spv_emit_terminator(Emitter*, FnBuilder*, BBBuilder, const Node* abs, const Node* terminator); + +void spv_register_emitted(Emitter*, FnBuilder*, const Node*, SpvId id); 
+SpvId* spv_search_emitted(Emitter* emitter, FnBuilder*, const Node* node); +SpvId spv_find_emitted(Emitter* emitter, FnBuilder*, const Node* node); + +BBBuilder spv_find_basic_block_builder(Emitter* emitter, const Node* bb); + +SpvId spv_get_extended_instruction_set(Emitter*, const char*); + +SpvStorageClass spv_emit_addr_space(Emitter*, AddressSpace address_space); +// SPIR-V doesn't have multiple return types, this bridges the gap... +SpvId spv_types_to_codom(Emitter* emitter, Nodes return_types); +const Type* spv_normalize_type(Emitter* emitter, const Type* type); +void spv_emit_nominal_type_body(Emitter* emitter, const Type* type, SpvId id); + +#endif diff --git a/src/backend/spirv/emit_spv_control_flow.c b/src/backend/spirv/emit_spv_control_flow.c new file mode 100644 index 000000000..9b943ccbd --- /dev/null +++ b/src/backend/spirv/emit_spv_control_flow.c @@ -0,0 +1,270 @@ +#include "emit_spv.h" + +#include "../shady/analysis/cfg.h" + +#include "list.h" +#include "dict.h" +#include "log.h" +#include "portability.h" + +#include + +BBBuilder spv_find_basic_block_builder(Emitter* emitter, const Node* bb) { + BBBuilder* found = shd_dict_find_value(const Node*, BBBuilder, emitter->bb_builders, bb); + assert(found); + return *found; +} + +static void add_phis(Emitter* emitter, FnBuilder* fn_builder, SpvId src, BBBuilder dst_builder, Nodes args) { + struct List* phis = spvb_get_phis(dst_builder); + assert(shd_list_count(phis) == args.count); + for (size_t i = 0; i < args.count; i++) { + SpvbPhi* phi = shd_read_list(SpvbPhi*, phis)[i]; + spvb_add_phi_source(phi, src, spv_emit_value(emitter, fn_builder, args.nodes[i])); + } +} + +static void add_branch_phis(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, const Node* dst, Nodes args) { + // because it's forbidden to jump back into the entry block of a function + // (which is actually a Function in this IR, not a BasicBlock) + // we assert that the destination must be an actual BasicBlock + 
assert(is_basic_block(dst)); + add_phis(emitter, fn_builder, spvb_get_block_builder_id(bb_builder), spv_find_basic_block_builder(emitter, dst), args); +} + +static void add_branch_phis_from_jump(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, Jump jump) { + return add_branch_phis(emitter, fn_builder, bb_builder, jump.target, jump.args); +} + +static void emit_if(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, If if_instr) { + spv_emit_mem(emitter, fn_builder, if_instr.mem); + SpvId join_bb_id = spv_find_emitted(emitter, fn_builder, if_instr.tail); + + SpvId true_id = spv_find_emitted(emitter, fn_builder, if_instr.if_true); + SpvId false_id = if_instr.if_false ? spv_find_emitted(emitter, fn_builder, if_instr.if_false) : join_bb_id; + + spvb_selection_merge(bb_builder, join_bb_id, 0); + SpvId condition = spv_emit_value(emitter, fn_builder, if_instr.condition); + spvb_branch_conditional(bb_builder, condition, true_id, false_id); +} + +static void emit_match(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, Match match) { + spv_emit_mem(emitter, fn_builder, match.mem); + SpvId join_bb_id = spv_find_emitted(emitter, fn_builder, match.tail); + + assert(shd_get_unqualified_type(match.inspect->type)->tag == Int_TAG); + SpvId inspectee = spv_emit_value(emitter, fn_builder, match.inspect); + + SpvId default_id = spv_find_emitted(emitter, fn_builder, match.default_case); + + const Type* inspectee_t = match.inspect->type; + shd_deconstruct_qualified_type(&inspectee_t); + assert(inspectee_t->tag == Int_TAG); + size_t literal_width = inspectee_t->payload.int_type.width == IntTy64 ? 
2 : 1; + size_t literal_case_entry_size = literal_width + 1; + LARRAY(uint32_t, literals_and_cases, match.cases.count * literal_case_entry_size); + for (size_t i = 0; i < match.cases.count; i++) { + uint64_t value = (uint64_t) shd_get_int_literal_value(*shd_resolve_to_int_literal(match.literals.nodes[i]), false); + if (inspectee_t->payload.int_type.width == IntTy64) { + literals_and_cases[i * literal_case_entry_size + 0] = (SpvId) (uint32_t) (value & 0xFFFFFFFF); + literals_and_cases[i * literal_case_entry_size + 1] = (SpvId) (uint32_t) (value >> 32); + } else { + literals_and_cases[i * literal_case_entry_size + 0] = (SpvId) (uint32_t) value; + } + literals_and_cases[i * literal_case_entry_size + literal_width] = spv_find_emitted(emitter, fn_builder, match.cases.nodes[i]); + } + + spvb_selection_merge(bb_builder, join_bb_id, 0); + spvb_switch(bb_builder, inspectee, default_id, match.cases.count * literal_case_entry_size, literals_and_cases); +} + +static void emit_loop(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, const Node* abs, Loop loop_instr) { + spv_emit_mem(emitter, fn_builder, loop_instr.mem); + SpvId body_id = spv_find_emitted(emitter, fn_builder, loop_instr.body); + + SpvId continue_id = spvb_fresh_id(emitter->file_builder); + BBBuilder continue_builder = spvb_begin_bb(fn_builder->base, continue_id); + spvb_name(emitter->file_builder, continue_id, "loop_continue"); + + SpvId header_id = spvb_fresh_id(emitter->file_builder); + BBBuilder header_builder = spvb_begin_bb(fn_builder->base, header_id); + spvb_name(emitter->file_builder, header_id, "loop_header"); + + Nodes body_params = get_abstraction_params(loop_instr.body); + LARRAY(SpvbPhi*, loop_continue_phis, body_params.count); + for (size_t i = 0; i < body_params.count; i++) { + SpvId loop_param_type = spv_emit_type(emitter, shd_get_unqualified_type(body_params.nodes[i]->type)); + + SpvId continue_phi_id = spvb_fresh_id(emitter->file_builder); + SpvbPhi* continue_phi = 
spvb_add_phi(continue_builder, loop_param_type, continue_phi_id); + loop_continue_phis[i] = continue_phi; + + // To get the actual loop parameter, we make a second phi for the nodes that go into the header + // We already know the two edges into the header so we immediately add the Phi sources for it. + SpvId header_phi_id = spvb_fresh_id(emitter->file_builder); + SpvbPhi* header_phi = spvb_add_phi(header_builder, loop_param_type, header_phi_id); + SpvId param_initial_value = spv_emit_value(emitter, fn_builder, loop_instr.initial_args.nodes[i]); + spvb_add_phi_source(header_phi, spvb_get_block_builder_id(bb_builder), param_initial_value); + spvb_add_phi_source(header_phi, spvb_get_block_builder_id(continue_builder), continue_phi_id); + + BBBuilder body_builder = spv_find_basic_block_builder(emitter, loop_instr.body); + spvb_add_phi_source(shd_read_list(SpvbPhi*, spvb_get_phis(body_builder))[i], spvb_get_block_builder_id(header_builder), header_phi_id); + } + + fn_builder->per_bb[shd_cfg_lookup(fn_builder->cfg, loop_instr.body)->rpo_index].continue_id = continue_id; + fn_builder->per_bb[shd_cfg_lookup(fn_builder->cfg, loop_instr.body)->rpo_index].continue_builder = continue_builder; + + SpvId tail_id = spv_find_emitted(emitter, fn_builder, loop_instr.tail); + + // the header block receives the loop merge annotation + spvb_loop_merge(header_builder, tail_id, continue_id, 0, 0, NULL); + spvb_branch(header_builder, body_id); + + spvb_add_bb(fn_builder->base, header_builder); + + // the continue block just jumps back into the header + spvb_branch(continue_builder, header_id); + + spvb_branch(bb_builder, header_id); +} + +typedef enum { + SelectionConstruct, + LoopConstruct, +} Construct; + +static CFNode* find_surrounding_structured_construct_node(Emitter* emitter, FnBuilder* fn_builder, const Node* abs, Construct construct) { + const Node* oabs = abs; + for (CFNode* n = shd_cfg_lookup(fn_builder->cfg, abs); n; oabs = n->node, n = n->idom) { + const Node* terminator = 
get_abstraction_body(n->node); + assert(terminator); + if (is_structured_construct(terminator) && get_structured_construct_tail(terminator) == oabs) { + continue; + } + if (construct == LoopConstruct && terminator->tag == Loop_TAG) + return n; + if (construct == SelectionConstruct && terminator->tag == If_TAG) + return n; + if (construct == SelectionConstruct && terminator->tag == Match_TAG) + return n; + + } + return NULL; +} + +static const Node* find_construct(Emitter* emitter, FnBuilder* fn_builder, const Node* abs, Construct construct) { + CFNode* found = find_surrounding_structured_construct_node(emitter, fn_builder, abs, construct); + return found ? get_abstraction_body(found->node) : NULL; +} + +void spv_emit_terminator(Emitter* emitter, FnBuilder* fn_builder, BBBuilder basic_block_builder, const Node* abs, const Node* terminator) { + switch (is_terminator(terminator)) { + case Return_TAG: { + Return payload = terminator->payload.fn_ret; + spv_emit_mem(emitter, fn_builder, payload.mem); + const Nodes* ret_values = &terminator->payload.fn_ret.args; + switch (ret_values->count) { + case 0: spvb_return_void(basic_block_builder); return; + case 1: spvb_return_value(basic_block_builder, spv_emit_value(emitter, fn_builder, ret_values->nodes[0])); return; + default: { + LARRAY(SpvId, arr, ret_values->count); + for (size_t i = 0; i < ret_values->count; i++) + arr[i] = spv_emit_value(emitter, fn_builder, ret_values->nodes[i]); + SpvId return_that = spvb_composite(basic_block_builder, spvb_fn_ret_type_id(fn_builder->base), ret_values->count, arr); + spvb_return_value(basic_block_builder, return_that); + return; + } + } + } + case Unreachable_TAG: { + Unreachable payload = terminator->payload.unreachable; + spv_emit_mem(emitter, fn_builder, payload.mem); + spvb_unreachable(basic_block_builder); + return; + } + case Jump_TAG: { + Jump payload = terminator->payload.jump; + spv_emit_mem(emitter, fn_builder, payload.mem); + add_branch_phis_from_jump(emitter, fn_builder, 
basic_block_builder, terminator->payload.jump); + spvb_branch(basic_block_builder, spv_find_emitted(emitter, fn_builder, terminator->payload.jump.target)); + return; + } + case Branch_TAG: { + Branch payload = terminator->payload.branch; + spv_emit_mem(emitter, fn_builder, payload.mem); + SpvId condition = spv_emit_value(emitter, fn_builder, terminator->payload.branch.condition); + add_branch_phis_from_jump(emitter, fn_builder, basic_block_builder, terminator->payload.branch.true_jump->payload.jump); + add_branch_phis_from_jump(emitter, fn_builder, basic_block_builder, terminator->payload.branch.false_jump->payload.jump); + spvb_branch_conditional(basic_block_builder, condition, spv_find_emitted(emitter, fn_builder, terminator->payload.branch.true_jump->payload.jump.target), spv_find_emitted(emitter, fn_builder, terminator->payload.branch.false_jump->payload.jump.target)); + return; + } + case Switch_TAG: { + Switch payload = terminator->payload.br_switch; + spv_emit_mem(emitter, fn_builder, payload.mem); + SpvId inspectee = spv_emit_value(emitter, fn_builder, terminator->payload.br_switch.switch_value); + LARRAY(SpvId, targets, terminator->payload.br_switch.case_jumps.count * 2); + for (size_t i = 0; i < terminator->payload.br_switch.case_jumps.count; i++) { + add_branch_phis_from_jump(emitter, fn_builder, basic_block_builder, terminator->payload.br_switch.case_jumps.nodes[i]->payload.jump); + } + add_branch_phis_from_jump(emitter, fn_builder, basic_block_builder, terminator->payload.br_switch.default_jump->payload.jump); + SpvId default_tgt = spv_find_emitted(emitter, fn_builder, terminator->payload.br_switch.default_jump->payload.jump.target); + + spvb_switch(basic_block_builder, inspectee, default_tgt, terminator->payload.br_switch.case_jumps.count, targets); + return; + } + case If_TAG: return emit_if(emitter, fn_builder, basic_block_builder, terminator->payload.if_instr); + case Match_TAG: return emit_match(emitter, fn_builder, basic_block_builder, 
terminator->payload.match_instr); + case Loop_TAG: return emit_loop(emitter, fn_builder, basic_block_builder, abs, terminator->payload.loop_instr); + case MergeSelection_TAG: { + MergeSelection payload = terminator->payload.merge_selection; + spv_emit_mem(emitter, fn_builder, payload.mem); + const Node* construct = find_construct(emitter, fn_builder, abs, SelectionConstruct); + assert(construct); + const Node* tail = get_structured_construct_tail(construct); + Nodes args = terminator->payload.merge_selection.args; + add_branch_phis(emitter, fn_builder, basic_block_builder, tail, args); + assert(tail != abs); + spvb_branch(basic_block_builder, spv_find_emitted(emitter, fn_builder, tail)); + return; + } + case MergeContinue_TAG: { + MergeContinue payload = terminator->payload.merge_continue; + spv_emit_mem(emitter, fn_builder, payload.mem); + const Node* construct = find_construct(emitter, fn_builder, abs, LoopConstruct); + assert(construct); + Loop loop_payload = construct->payload.loop_instr; + CFNode* loop_body = shd_cfg_lookup(fn_builder->cfg, loop_payload.body); + assert(loop_body); + Nodes args = terminator->payload.merge_continue.args; + add_phis(emitter, fn_builder, spvb_get_block_builder_id(basic_block_builder), fn_builder->per_bb[loop_body->rpo_index].continue_builder, args); + spvb_branch(basic_block_builder, fn_builder->per_bb[loop_body->rpo_index].continue_id); + return; + } + case MergeBreak_TAG: { + MergeBreak payload = terminator->payload.merge_break; + spv_emit_mem(emitter, fn_builder, payload.mem); + const Node* construct = find_construct(emitter, fn_builder, abs, LoopConstruct); + assert(construct); + Loop loop_payload = construct->payload.loop_instr; + Nodes args = terminator->payload.merge_break.args; + add_branch_phis(emitter, fn_builder, basic_block_builder, loop_payload.tail, args); + spvb_branch(basic_block_builder, spv_find_emitted(emitter, fn_builder, loop_payload.tail)); + return; + } + case Terminator_Control_TAG: + case TailCall_TAG: { + 
TailCall payload = terminator->payload.tail_call; + spv_emit_mem(emitter, fn_builder, payload.mem); + LARRAY(SpvId, args, payload.args.count + 1); + args[0] = spv_emit_value(emitter, fn_builder, payload.callee); + for (size_t i = 0; i < payload.args.count; i++) + args[i + 1] = spv_emit_value(emitter, fn_builder, payload.args.nodes[i]); + spvb_capability(emitter->file_builder, SpvCapabilityIndirectTailCallsSHADY); + spvb_terminator(basic_block_builder, SpvOpIndirectTailCallSHADY, payload.args.count + 1, args); + return; + } + case Join_TAG: shd_error("Lower me"); + case NotATerminator: shd_error("TODO: emit terminator %s", shd_get_node_tag_string(terminator->tag)); + } + SHADY_UNREACHABLE; +} diff --git a/src/shady/emit/spirv/emit_spv_type.c b/src/backend/spirv/emit_spv_type.c similarity index 59% rename from src/shady/emit/spirv/emit_spv_type.c rename to src/backend/spirv/emit_spv_type.c index c782cb2d7..f0a53ff95 100644 --- a/src/shady/emit/spirv/emit_spv_type.c +++ b/src/backend/spirv/emit_spv_type.c @@ -1,27 +1,23 @@ #include "emit_spv.h" -#include "portability.h" -#include "log.h" +#include "shady/ir/memory_layout.h" -#include "../../rewrite.h" -#include "../../transform/memory_layout.h" +#include "shady/rewrite.h" +#include "portability.h" +#include "log.h" #include "dict.h" -#include "assert.h" +#include #pragma GCC diagnostic error "-Wswitch" -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); - -SpvStorageClass emit_addr_space(Emitter* emitter, AddressSpace address_space) { +SpvStorageClass spv_emit_addr_space(Emitter* emitter, AddressSpace address_space) { switch(address_space) { - case AsGlobalLogical: return SpvStorageClassStorageBuffer; - case AsSharedLogical: return SpvStorageClassWorkgroup; - case AsPrivateLogical: return SpvStorageClassPrivate; - case AsFunctionLogical: return SpvStorageClassFunction; - case AsGlobalPhysical: + case AsShared: return SpvStorageClassWorkgroup; + case AsPrivate: return SpvStorageClassPrivate; + case 
AsFunction: return SpvStorageClassFunction; + case AsGlobal: spvb_set_addressing_model(emitter->file_builder, SpvAddressingModelPhysicalStorageBuffer64); spvb_extension(emitter->file_builder, "SPV_KHR_physical_storage_buffer"); spvb_capability(emitter->file_builder, SpvCapabilityPhysicalStorageBufferAddresses); @@ -36,80 +32,77 @@ SpvStorageClass emit_addr_space(Emitter* emitter, AddressSpace address_space) { case AsUniformConstant: return SpvStorageClassUniformConstant; default: { - error_print("Cannot emit address space %s.\n", get_address_space_name(address_space)); - error_die(); + shd_error_print("Cannot emit address space %s.\n", shd_get_address_space_name(address_space)); + shd_error_die(); SHADY_UNREACHABLE; } } } static const Node* rewrite_normalize(Rewriter* rewriter, const Node* node) { - const Node* found = search_processed(rewriter, node); - if (found) return found; - if (!is_type(node)) { - register_processed(rewriter, node, node); + shd_register_processed(rewriter, node, node); return node; } switch (node->tag) { - case QualifiedType_TAG: return qualified_type(rewriter->dst_arena, (QualifiedType) { .type = rewrite_node(rewriter, node->payload.qualified_type.type), .is_uniform = false }); - default: return recreate_node_identity(rewriter, node); + case QualifiedType_TAG: return qualified_type(rewriter->dst_arena, (QualifiedType) { .type = shd_rewrite_node(rewriter, node->payload.qualified_type.type), .is_uniform = false }); + default: return shd_recreate_node(rewriter, node); } } -const Type* normalize_type(Emitter* emitter, const Type* type) { - Rewriter rewriter = create_rewriter(emitter->module, emitter->module, rewrite_normalize); - const Node* rewritten = rewrite_node(&rewriter, type); - destroy_rewriter(&rewriter); +const Type* spv_normalize_type(Emitter* emitter, const Type* type) { + Rewriter rewriter = shd_create_node_rewriter(emitter->module, emitter->module, rewrite_normalize); + const Node* rewritten = shd_rewrite_node(&rewriter, type); + 
shd_destroy_rewriter(&rewriter); return rewritten; } -SpvId nodes_to_codom(Emitter* emitter, Nodes return_types) { +SpvId spv_types_to_codom(Emitter* emitter, Nodes return_types) { switch (return_types.count) { case 0: return emitter->void_t; - case 1: return emit_type(emitter, return_types.nodes[0]); + case 1: return spv_emit_type(emitter, return_types.nodes[0]); default: { - const Type* codom_ret_type = record_type(emitter->arena, (RecordType) { .members = return_types, .special = MultipleReturn }); - return emit_type(emitter, codom_ret_type); + const Type* codom_ret_type = record_type(emitter->arena, (RecordType) { .members = return_types, .special = 0 }); + return spv_emit_type(emitter, codom_ret_type); } } } -void emit_nominal_type_body(Emitter* emitter, const Type* type, SpvId id) { +void spv_emit_nominal_type_body(Emitter* emitter, const Type* type, SpvId id) { switch (type->tag) { case RecordType_TAG: { Nodes member_types = type->payload.record_type.members; LARRAY(SpvId, members, member_types.count); for (size_t i = 0; i < member_types.count; i++) - members[i] = emit_type(emitter, member_types.nodes[i]); + members[i] = spv_emit_type(emitter, member_types.nodes[i]); spvb_struct_type(emitter->file_builder, id, member_types.count, members); if (type->payload.record_type.special == DecorateBlock) { spvb_decorate(emitter->file_builder, id, SpvDecorationBlock, 0, NULL); } LARRAY(FieldLayout, fields, member_types.count); - get_record_layout(emitter->arena, type, fields); + shd_get_record_layout(emitter->arena, type, fields); for (size_t i = 0; i < member_types.count; i++) { spvb_decorate_member(emitter->file_builder, id, i, SpvDecorationOffset, 1, (uint32_t[]) { fields[i].offset_in_bytes }); } break; } - default: error("not a suitable nominal type body (tag=%s)", node_tags[type->tag]); + default: shd_error("not a suitable nominal type body (tag=%s)", shd_get_node_tag_string(type->tag)); } } -SpvId emit_type(Emitter* emitter, const Type* type) { +SpvId 
spv_emit_type(Emitter* emitter, const Type* type) { // Some types in shady lower to the same spir-v type, but spir-v is unhappy with having duplicates of the same types // we could hash the spirv types we generate to find duplicates, but it is easier to normalise our shady types and reuse their infra - type = normalize_type(emitter, type); + type = spv_normalize_type(emitter, type); - SpvId* existing = find_value_dict(struct Node*, SpvId, emitter->node_ids, type); + SpvId* existing = spv_search_emitted(emitter, NULL, type); if (existing) return *existing; SpvId new; switch (is_type(type)) { - case NotAType: error("Not a type"); + case NotAType: shd_error("Not a type"); case Int_TAG: { int width; switch (type->payload.int_type.width) { @@ -144,9 +137,19 @@ SpvId emit_type(Emitter* emitter, const Type* type) { new = spvb_float_type(emitter->file_builder, width); break; } case PtrType_TAG: { - SpvId pointee = emit_type(emitter, type->payload.ptr_type.pointed_type); - SpvStorageClass sc = emit_addr_space(emitter, type->payload.ptr_type.address_space); + SpvStorageClass sc = spv_emit_addr_space(emitter, type->payload.ptr_type.address_space); + const Type* pointed_type = type->payload.ptr_type.pointed_type; + if (shd_get_maybe_nominal_type_decl(pointed_type) && sc == SpvStorageClassPhysicalStorageBuffer) { + new = spvb_forward_ptr_type(emitter->file_builder, sc); + spv_register_emitted(emitter, NULL, type, new); + SpvId pointee = spv_emit_type(emitter, pointed_type); + spvb_ptr_type_define(emitter->file_builder, new, sc, pointee); + return new; + } + + SpvId pointee = spv_emit_type(emitter, pointed_type); new = spvb_ptr_type(emitter->file_builder, sc, pointee); + //if (is_physical_as(type->payload.ptr_type.address_space) && type->payload.ptr_type.pointed_type->tag == ArrType_TAG) { // TypeMemLayout elem_mem_layout = get_mem_layout(emitter->arena, type->payload.ptr_type.pointed_type); // spvb_decorate(emitter->file_builder, new, SpvDecorationArrayStride, 1, (uint32_t[]) 
{elem_mem_layout.size_in_bytes}); @@ -155,35 +158,35 @@ SpvId emit_type(Emitter* emitter, const Type* type) { } case NoRet_TAG: case LamType_TAG: - case BBType_TAG: error("we can't emit arrow types that aren't those of first-class functions") + case BBType_TAG: shd_error("we can't emit arrow types that aren't those of shd_first-class functions") case FnType_TAG: { const FnType* fnt = &type->payload.fn_type; LARRAY(SpvId, params, fnt->param_types.count); for (size_t i = 0; i < fnt->param_types.count; i++) - params[i] = emit_type(emitter, fnt->param_types.nodes[i]); + params[i] = spv_emit_type(emitter, fnt->param_types.nodes[i]); - new = spvb_fn_type(emitter->file_builder, fnt->param_types.count, params, nodes_to_codom(emitter, fnt->return_types)); + new = spvb_fn_type(emitter->file_builder, fnt->param_types.count, params, spv_types_to_codom(emitter, fnt->return_types)); break; } case QualifiedType_TAG: { // SPIR-V does not care about our type qualifiers. - new = emit_type(emitter, type->payload.qualified_type.type); + new = spv_emit_type(emitter, type->payload.qualified_type.type); break; } case ArrType_TAG: { - SpvId element_type = emit_type(emitter, type->payload.arr_type.element_type); + SpvId element_type = spv_emit_type(emitter, type->payload.arr_type.element_type); if (type->payload.arr_type.size) { - new = spvb_array_type(emitter->file_builder, element_type, emit_value(emitter, NULL, type->payload.arr_type.size)); + new = spvb_array_type(emitter->file_builder, element_type, spv_emit_value(emitter, NULL, type->payload.arr_type.size)); } else { new = spvb_runtime_array_type(emitter->file_builder, element_type); } - TypeMemLayout elem_mem_layout = get_mem_layout(emitter->arena, type->payload.arr_type.element_type); + TypeMemLayout elem_mem_layout = shd_get_mem_layout(emitter->arena, type->payload.arr_type.element_type); spvb_decorate(emitter->file_builder, new, SpvDecorationArrayStride, 1, (uint32_t[]) { elem_mem_layout.size_in_bytes }); break; } case 
PackType_TAG: { assert(type->payload.pack_type.width >= 2); - SpvId element_type = emit_type(emitter, type->payload.pack_type.element_type); + SpvId element_type = spv_emit_type(emitter, type->payload.pack_type.element_type); new = spvb_vector_type(emitter->file_builder, element_type, type->payload.pack_type.width); break; } @@ -193,24 +196,32 @@ SpvId emit_type(Emitter* emitter, const Type* type) { break; } new = spvb_fresh_id(emitter->file_builder); - emit_nominal_type_body(emitter, type, new); - break; + spv_register_emitted(emitter, NULL, type, new); + spv_emit_nominal_type_body(emitter, type, new); + return new; } case Type_TypeDeclRef_TAG: { - new = emit_decl(emitter, type->payload.type_decl_ref.decl); + new = spv_emit_decl(emitter, type->payload.type_decl_ref.decl); break; } - case Type_CombinedImageSamplerType_TAG: new = spvb_sampled_image_type(emitter->file_builder, emit_type(emitter, type->payload.combined_image_sampler_type.image_type)); break; + case Type_SampledImageType_TAG: new = spvb_sampled_image_type(emitter->file_builder, spv_emit_type(emitter, type->payload.sampled_image_type.image_type)); break; case Type_SamplerType_TAG: new = spvb_sampler_type(emitter->file_builder); break; case Type_ImageType_TAG: { ImageType p = type->payload.image_type; - new = spvb_image_type(emitter->file_builder, emit_type(emitter, p.sampled_type), p.dim, p.depth, p.onion, p.multisample, p.sampled, SpvImageFormatUnknown); + new = spvb_image_type(emitter->file_builder, spv_emit_type(emitter, p.sampled_type), p.dim, p.depth, p.arrayed, p.ms, p.sampled, p.imageformat); break; } case Type_MaskType_TAG: - case Type_JoinPointType_TAG: error("These must be lowered beforehand") + case Type_JoinPointType_TAG: shd_error("These must be lowered beforehand") + } + + if (shd_is_data_type(type)) { + if (type->tag == PtrType_TAG && type->payload.ptr_type.address_space == AsGlobal) { + //TypeMemLayout elem_mem_layout = get_mem_layout(emitter->arena, type->payload.ptr_type.pointed_type); 
+ //spvb_decorate(emitter->file_builder, new, SpvDecorationArrayStride, 1, (uint32_t[]) {elem_mem_layout.size_in_bytes}); + } } - insert_dict_and_get_result(struct Node*, SpvId, emitter->node_ids, type, new); + spv_register_emitted(emitter, NULL, type, new); return new; } diff --git a/src/backend/spirv/emit_spv_value.c b/src/backend/spirv/emit_spv_value.c new file mode 100644 index 000000000..28dabb7c5 --- /dev/null +++ b/src/backend/spirv/emit_spv_value.c @@ -0,0 +1,626 @@ +#include "emit_spv.h" + +#include "shady/ir/memory_layout.h" +#include "shady/ir/builtin.h" + +#include "../shady/analysis/cfg.h" +#include "../shady/analysis/scheduler.h" + +#include "log.h" +#include "dict.h" +#include "portability.h" + +#include "spirv/unified1/NonSemanticDebugPrintf.h" +#include "spirv/unified1/GLSL.std.450.h" + +#include +#include + +typedef enum { + Custom, Plain, +} InstrClass; + +/// What is considered when searching for an instruction +typedef enum { + None, Monomorphic, FirstOp, FirstAndResult +} ISelMechanism; + +typedef enum { + Same, SameTuple, Bool, Void, TyOperand +} ResultClass; + +typedef enum { + Signed, Unsigned, FP, Logical, Ptr, OperandClassCount +} OperandClass; + +static OperandClass classify_operand_type(const Type* type) { + assert(is_type(type) && shd_is_data_type(type)); + + if (type->tag == PackType_TAG) + return classify_operand_type(type->payload.pack_type.element_type); + + switch (type->tag) { + case Int_TAG: return type->payload.int_type.is_signed ? 
Signed : Unsigned; + case Bool_TAG: return Logical; + case PtrType_TAG: return Ptr; + case Float_TAG: return FP; + default: shd_error("we don't know what to do with this") + } +} + +typedef struct { + InstrClass class; + ISelMechanism isel_mechanism; + ResultClass result_kind; + union { + SpvOp op; + // matches first operand + SpvOp fo[OperandClassCount]; + // matches first operand and return type [first operand][result type] + SpvOp foar[OperandClassCount][OperandClassCount]; + }; + const char* extended_set; +} IselTableEntry; + +#define ISEL_IDENTITY (SpvOpNop /* no-op, should be lowered to nothing beforehand */) +#define ISEL_LOWERME (SpvOpMax /* boolean conversions don't exist as a single instruction, a pass should lower them instead */) +#define ISEL_ILLEGAL (SpvOpMax /* doesn't make sense to support */) +#define ISEL_CUSTOM (SpvOpMax /* doesn't make sense to support */) + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" + +static const IselTableEntry isel_table[] = { + [add_op] = {Plain, FirstOp, Same, .fo = {SpvOpIAdd, SpvOpIAdd, SpvOpFAdd, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [sub_op] = {Plain, FirstOp, Same, .fo = {SpvOpISub, SpvOpISub, SpvOpFSub, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [mul_op] = {Plain, FirstOp, Same, .fo = {SpvOpIMul, SpvOpIMul, SpvOpFMul, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [div_op] = {Plain, FirstOp, Same, .fo = {SpvOpSDiv, SpvOpUDiv, SpvOpFDiv, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [mod_op] = {Plain, FirstOp, Same, .fo = {SpvOpSMod, SpvOpUMod, SpvOpFMod, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + + [add_carry_op] = {Plain, FirstOp, SameTuple, .fo = {SpvOpIAddCarry, SpvOpIAddCarry, ISEL_ILLEGAL }}, + [sub_borrow_op] = {Plain, FirstOp, SameTuple, .fo = {SpvOpISubBorrow, SpvOpISubBorrow, ISEL_ILLEGAL }}, + [mul_extended_op] = {Plain, FirstOp, SameTuple, .fo = {SpvOpSMulExtended, SpvOpUMulExtended, ISEL_ILLEGAL }}, + + [neg_op] = {Plain, FirstOp, Same, .fo = {SpvOpSNegate, SpvOpSNegate, SpvOpFNegate }}, + + [eq_op] = 
{Plain, FirstOp, Bool, .fo = {SpvOpIEqual, SpvOpIEqual, SpvOpFOrdEqual, SpvOpLogicalEqual, SpvOpPtrEqual }}, + [neq_op] = {Plain, FirstOp, Bool, .fo = {SpvOpINotEqual, SpvOpINotEqual, SpvOpFOrdNotEqual, SpvOpLogicalNotEqual, SpvOpPtrNotEqual }}, + [lt_op] = {Plain, FirstOp, Bool, .fo = {SpvOpSLessThan, SpvOpULessThan, SpvOpFOrdLessThan, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [lte_op] = {Plain, FirstOp, Bool, .fo = {SpvOpSLessThanEqual, SpvOpULessThanEqual, SpvOpFOrdLessThanEqual, ISEL_ILLEGAL, ISEL_ILLEGAL}}, + [gt_op] = {Plain, FirstOp, Bool, .fo = {SpvOpSGreaterThan, SpvOpUGreaterThan, SpvOpFOrdGreaterThan, ISEL_ILLEGAL, ISEL_ILLEGAL}}, + [gte_op] = {Plain, FirstOp, Bool, .fo = {SpvOpSGreaterThanEqual, SpvOpUGreaterThanEqual, SpvOpFOrdGreaterThanEqual, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + + [not_op] = {Plain, FirstOp, Same, .fo = {SpvOpNot, SpvOpNot, ISEL_ILLEGAL, SpvOpLogicalNot }}, + + [and_op] = {Plain, FirstOp, Same, .fo = {SpvOpBitwiseAnd, SpvOpBitwiseAnd, ISEL_ILLEGAL, SpvOpLogicalAnd }}, + [or_op] = {Plain, FirstOp, Same, .fo = {SpvOpBitwiseOr, SpvOpBitwiseOr, ISEL_ILLEGAL, SpvOpLogicalOr }}, + [xor_op] = {Plain, FirstOp, Same, .fo = {SpvOpBitwiseXor, SpvOpBitwiseXor, ISEL_ILLEGAL, SpvOpLogicalNotEqual }}, + + [lshift_op] = {Plain, FirstOp, Same, .fo = {SpvOpShiftLeftLogical, SpvOpShiftLeftLogical, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [rshift_arithm_op] = {Plain, FirstOp, Same, .fo = {SpvOpShiftRightArithmetic, SpvOpShiftRightArithmetic, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [rshift_logical_op] = {Plain, FirstOp, Same, .fo = {SpvOpShiftRightLogical, SpvOpShiftRightLogical, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + + [convert_op] = {Plain, FirstAndResult, TyOperand, .foar = { + { SpvOpSConvert, SpvOpUConvert, SpvOpConvertSToF, ISEL_LOWERME, ISEL_LOWERME }, + { SpvOpSConvert, SpvOpUConvert, SpvOpConvertUToF, ISEL_LOWERME, ISEL_LOWERME }, + { SpvOpConvertFToS, SpvOpConvertFToU, SpvOpFConvert, ISEL_ILLEGAL, ISEL_ILLEGAL }, + { ISEL_LOWERME, ISEL_LOWERME, ISEL_ILLEGAL, ISEL_IDENTITY, 
ISEL_ILLEGAL }, + { ISEL_LOWERME, ISEL_LOWERME, ISEL_ILLEGAL, ISEL_ILLEGAL, ISEL_IDENTITY } + }}, + + [reinterpret_op] = {Plain, FirstAndResult, TyOperand, .foar = { + { ISEL_ILLEGAL, SpvOpBitcast, SpvOpBitcast, ISEL_ILLEGAL, SpvOpConvertUToPtr }, + { SpvOpBitcast, ISEL_ILLEGAL, SpvOpBitcast, ISEL_ILLEGAL, SpvOpConvertUToPtr }, + { SpvOpBitcast, SpvOpBitcast, ISEL_IDENTITY, ISEL_ILLEGAL, ISEL_ILLEGAL /* no fp-ptr casts */ }, + { ISEL_ILLEGAL, ISEL_ILLEGAL, ISEL_ILLEGAL, ISEL_IDENTITY, ISEL_ILLEGAL /* no bool reinterpret */ }, + { SpvOpConvertPtrToU, SpvOpConvertPtrToU, ISEL_ILLEGAL, ISEL_ILLEGAL, ISEL_CUSTOM } + }}, + + [sqrt_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Sqrt }, + [inv_sqrt_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450InverseSqrt}, + [floor_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Floor }, + [ceil_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Ceil }, + [round_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Round }, + [fract_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Fract }, + [sin_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Sin }, + [cos_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Cos }, + + [abs_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = { (SpvOp) GLSLstd450SAbs, ISEL_ILLEGAL, (SpvOp) GLSLstd450FAbs, ISEL_ILLEGAL }}, + [sign_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = { (SpvOp) GLSLstd450SSign, ISEL_ILLEGAL, (SpvOp) GLSLstd450FSign, ISEL_ILLEGAL }}, + + [min_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = {(SpvOp) GLSLstd450SMin, (SpvOp) GLSLstd450UMin, (SpvOp) GLSLstd450FMin, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [max_op] = { 
Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = {(SpvOp) GLSLstd450SMax, (SpvOp) GLSLstd450UMax, (SpvOp) GLSLstd450FMax, ISEL_ILLEGAL, ISEL_ILLEGAL }}, + [exp_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Exp }, + [pow_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Pow }, + [fma_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Fma }, + + [sample_texture_op] = {Plain, Monomorphic, TyOperand, .op = SpvOpImageSampleImplicitLod }, + + [subgroup_assume_uniform_op] = {Plain, Monomorphic, Same, .op = ISEL_IDENTITY }, + + [PRIMOPS_COUNT] = { Custom } +}; + +#pragma GCC diagnostic pop +#pragma GCC diagnostic error "-Wswitch" + +static const Type* get_result_t(Emitter* emitter, IselTableEntry entry, Nodes args, Nodes type_arguments) { + switch (entry.result_kind) { + case Same: return shd_get_unqualified_type(shd_first(args)->type); + case SameTuple: return record_type(emitter->arena, (RecordType) { .members = mk_nodes(emitter->arena, shd_get_unqualified_type(shd_first(args)->type), shd_get_unqualified_type(shd_first(args)->type)) }); + case Bool: return bool_type(emitter->arena); + case TyOperand: return shd_first(type_arguments); + case Void: return unit_type(emitter->arena); + } +} + +static SpvOp get_opcode(SHADY_UNUSED Emitter* emitter, IselTableEntry entry, Nodes args, Nodes type_arguments) { + switch (entry.isel_mechanism) { + case None: return SpvOpMax; + case Monomorphic: return entry.op; + case FirstOp: { + assert(args.count >= 1); + OperandClass op_class = classify_operand_type(shd_get_unqualified_type(shd_first(args)->type)); + return entry.fo[op_class]; + } + case FirstAndResult: { + assert(args.count >= 1); + assert(type_arguments.count == 1); + OperandClass op_class = classify_operand_type(shd_get_unqualified_type(shd_first(args)->type)); + OperandClass return_t_class = classify_operand_type(shd_first(type_arguments)); + 
return entry.foar[op_class][return_t_class]; + } + } +} + +static SpvId emit_primop(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, const Node* instr) { + PrimOp the_op = instr->payload.prim_op; + Nodes args = the_op.operands; + Nodes type_arguments = the_op.type_arguments; + + IselTableEntry entry = isel_table[the_op.op]; + if (entry.class != Custom) { + LARRAY(SpvId, emitted_args, args.count); + for (size_t i = 0; i < args.count; i++) + emitted_args[i] = spv_emit_value(emitter, fn_builder, args.nodes[i]); + + switch (entry.class) { + case Plain: { + SpvOp opcode = get_opcode(emitter, entry, args, type_arguments); + if (opcode == SpvOpNop) { + assert(args.count == 1); + return emitted_args[0]; + } + + if (opcode == SpvOpMax) + goto custom; + + SpvId result_t = instr->type == empty_multiple_return_type(emitter->arena) ? emitter->void_t : spv_emit_type(emitter, instr->type); + if (entry.extended_set) { + SpvId set_id = spv_get_extended_instruction_set(emitter, entry.extended_set); + return spvb_ext_instruction(bb_builder, result_t, set_id, opcode, args.count, emitted_args); + } else { + return spvb_op(bb_builder, opcode, result_t, args.count, emitted_args); + } + } + case Custom: SHADY_UNREACHABLE; + } + SHADY_UNREACHABLE; + } + + custom: + switch (the_op.op) { + case reinterpret_op: { + const Type* dst = shd_first(the_op.type_arguments); + const Type* src = shd_get_unqualified_type(shd_first(the_op.operands)->type); + assert(dst->tag == PtrType_TAG && src->tag == PtrType_TAG); + assert(src != dst); + return spvb_op(bb_builder, SpvOpBitcast, spv_emit_type(emitter, dst), 1, (SpvId[]) {spv_emit_value(emitter, fn_builder, shd_first(the_op.operands)) }); + } + case insert_op: + case extract_dynamic_op: + case extract_op: { + bool insert = the_op.op == insert_op; + + const Node* src_value = shd_first(args); + const Type* result_t = instr->type; + size_t indices_start = insert ? 
2 : 1; + size_t indices_count = args.count - indices_start; + assert(args.count > indices_start); + + bool dynamic = the_op.op == extract_dynamic_op; + + if (dynamic) { + LARRAY(SpvId, indices, indices_count); + for (size_t i = 0; i < indices_count; i++) { + indices[i] = spv_emit_value(emitter, fn_builder, args.nodes[i + indices_start]); + } + assert(indices_count == 1); + return spvb_vector_extract_dynamic(bb_builder, spv_emit_type(emitter, result_t), spv_emit_value(emitter, fn_builder, src_value), indices[0]); + } + LARRAY(uint32_t, indices, indices_count); + for (size_t i = 0; i < indices_count; i++) { + // TODO: fallback to Dynamic variants transparently + indices[i] = shd_get_int_literal_value(*shd_resolve_to_int_literal(args.nodes[i + indices_start]), false); + } + + if (insert) + return spvb_insert(bb_builder, spv_emit_type(emitter, result_t), spv_emit_value(emitter, fn_builder, args.nodes[1]), spv_emit_value(emitter, fn_builder, src_value), indices_count, indices); + else + return spvb_extract(bb_builder, spv_emit_type(emitter, result_t), spv_emit_value(emitter, fn_builder, src_value), indices_count, indices); + } + case shuffle_op: { + const Type* result_t = instr->type; + SpvId a = spv_emit_value(emitter, fn_builder, args.nodes[0]); + SpvId b = spv_emit_value(emitter, fn_builder, args.nodes[1]); + LARRAY(uint32_t, indices, args.count - 2); + for (size_t i = 0; i < args.count - 2; i++) { + int64_t indice = shd_get_int_literal_value(*shd_resolve_to_int_literal(args.nodes[i + 2]), true); + if (indice == -1) + indices[i] = 0xFFFFFFFF; + else + indices[i] = indice; + } + return spvb_vecshuffle(bb_builder, spv_emit_type(emitter, result_t), a, b, args.count - 2, indices); + } + case select_op: { + SpvId cond = spv_emit_value(emitter, fn_builder, shd_first(args)); + SpvId truv = spv_emit_value(emitter, fn_builder, args.nodes[1]); + SpvId flsv = spv_emit_value(emitter, fn_builder, args.nodes[2]); + + return spvb_select(bb_builder, spv_emit_type(emitter, 
args.nodes[1]->type), cond, truv, flsv); + } + default: shd_error("TODO: unhandled op"); + } + shd_error("unreachable"); +} + +static SpvId emit_ext_instr(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, ExtInstr instr) { + spv_emit_mem(emitter, fn_builder, instr.mem); + if (strcmp("spirv.core", instr.set) == 0) { + switch (instr.opcode) { + case SpvOpGroupNonUniformBroadcastFirst: { + spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniformBallot); + SpvId scope_subgroup = spv_emit_value(emitter, fn_builder, shd_int32_literal(emitter->arena, SpvScopeSubgroup)); + if (emitter->configuration->hacks.spv_shuffle_instead_of_broadcast_first) { + spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniformShuffle); + const Node* b = ref_decl_helper(emitter->arena, shd_get_or_create_builtin(emitter->module, BuiltinSubgroupLocalInvocationId, NULL)); + SpvId local_id = spvb_op(bb_builder, SpvOpLoad, spv_emit_type(emitter, shd_uint32_type(emitter->arena)), 1, (SpvId []) { spv_emit_value(emitter, fn_builder, b) }); + return spvb_group_shuffle(bb_builder, spv_emit_type(emitter, instr.result_t), scope_subgroup, spv_emit_value(emitter, fn_builder, shd_first(instr.operands)), local_id); + } + break; + } + case SpvOpGroupNonUniformBallot: { + spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniformBallot); + assert(instr.operands.count == 2); + // SpvId scope_subgroup = spv_emit_value(emitter, fn_builder, int32_literal(emitter->arena, SpvScopeSubgroup)); + // ad-hoc extension for my sanity + if (shd_get_unqualified_type(instr.result_t) == shd_get_actual_mask_type(emitter->arena)) { + const Type* i32x4 = pack_type(emitter->arena, (PackType) { .width = 4, .element_type = shd_uint32_type(emitter->arena) }); + SpvId raw_result = spvb_group_ballot(bb_builder, spv_emit_type(emitter, i32x4), spv_emit_value(emitter, fn_builder, instr.operands.nodes[1]), spv_emit_value(emitter, fn_builder, shd_first(instr.operands))); + // TODO: why are we doing 
this in SPIR-V and not the IR ? + SpvId low32 = spvb_extract(bb_builder, spv_emit_type(emitter, shd_uint32_type(emitter->arena)), raw_result, 1, (uint32_t[]) { 0 }); + SpvId hi32 = spvb_extract(bb_builder, spv_emit_type(emitter, shd_uint32_type(emitter->arena)), raw_result, 1, (uint32_t[]) { 1 }); + SpvId low64 = spvb_op(bb_builder, SpvOpUConvert, spv_emit_type(emitter, shd_uint64_type(emitter->arena)), 1, &low32); + SpvId hi64 = spvb_op(bb_builder, SpvOpUConvert, spv_emit_type(emitter, shd_uint64_type(emitter->arena)), 1, &hi32); + hi64 = spvb_op(bb_builder, SpvOpShiftLeftLogical, spv_emit_type(emitter, shd_uint64_type(emitter->arena)), 2, (SpvId []) { hi64, spv_emit_value(emitter, fn_builder, shd_int64_literal(emitter->arena, 32)) }); + SpvId final_result = spvb_op(bb_builder, SpvOpBitwiseOr, spv_emit_type(emitter, shd_uint64_type(emitter->arena)), 2, (SpvId []) { low64, hi64 }); + return final_result; + } + break; + } + case SpvOpGroupNonUniformIAdd: { + spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniformArithmetic); + SpvId scope = spv_emit_value(emitter, fn_builder, shd_first(instr.operands)); + SpvGroupOperation group_op = shd_get_int_literal_value(*shd_resolve_to_int_literal(instr.operands.nodes[2]), false); + return spvb_group_non_uniform_group_op(bb_builder, spv_emit_type(emitter, instr.result_t), instr.opcode, scope, group_op, spv_emit_value(emitter, fn_builder, instr.operands.nodes[1]), NULL); + } + case SpvOpGroupNonUniformElect: { + spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniform); + assert(instr.operands.count == 1); + break; + // SpvId result_t = spv_emit_type(emitter, bool_type(emitter->arena)); + // SpvId scope_subgroup = spv_emit_value(emitter, fn_builder, int32_literal(emitter->arena, SpvScopeSubgroup)); + // return spvb_group_elect(bb_builder, result_t, scope_subgroup); + } + default: break; + } + LARRAY(SpvId, ops, instr.operands.count); + for (size_t i = 0; i < instr.operands.count; i++) + ops[i] = 
spv_emit_value(emitter, fn_builder, instr.operands.nodes[i]); + return spvb_op(bb_builder, instr.opcode, spv_emit_type(emitter, instr.result_t), instr.operands.count, ops); + } + LARRAY(SpvId, ops, instr.operands.count); + for (size_t i = 0; i < instr.operands.count; i++) + ops[i] = spv_emit_value(emitter, fn_builder, instr.operands.nodes[i]); + SpvId set_id = spv_get_extended_instruction_set(emitter, instr.set); + return spvb_ext_instruction(bb_builder, spv_emit_type(emitter, instr.result_t), set_id, instr.opcode, instr.operands.count, ops); +} + +static SpvId emit_fn_call(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, Call call) { + spv_emit_mem(emitter, fn_builder, call.mem); + + const Node* fn = call.callee; + const Type* callee_type = fn->type; + assert(callee_type->tag == FnType_TAG); + Nodes return_types = callee_type->payload.fn_type.return_types; + SpvId return_type = spv_types_to_codom(emitter, return_types); + + LARRAY(SpvId, args, call.args.count); + for (size_t i = 0; i < call.args.count; i++) + args[i] = spv_emit_value(emitter, fn_builder, call.args.nodes[i]); + + if (fn->tag == FnAddr_TAG) { + fn = fn->payload.fn_addr.fn; + SpvId callee = spv_emit_decl(emitter, fn); + return spvb_call(bb_builder, return_type, callee, call.args.count, args); + } else { + spvb_capability(emitter->file_builder, SpvCapabilityFunctionPointersINTEL); + return spvb_op(bb_builder, SpvOpFunctionPointerCallINTEL, return_type, call.args.count, args); + } +} + +static SpvId spv_emit_instruction(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, const Node* instruction) { + assert(is_instruction(instruction)); + + switch (is_instruction(instruction)) { + case NotAnInstruction: shd_error(""); + case Instruction_PushStack_TAG: + case Instruction_PopStack_TAG: + case Instruction_GetStackSize_TAG: + case Instruction_SetStackSize_TAG: + case Instruction_GetStackBaseAddr_TAG: shd_error("Stack operations need to be lowered."); + case 
Instruction_CopyBytes_TAG: + case Instruction_FillBytes_TAG: + case Instruction_StackAlloc_TAG: shd_error("Should be lowered elsewhere") + case Instruction_ExtInstr_TAG: return emit_ext_instr(emitter, fn_builder, bb_builder, instruction->payload.ext_instr); + case Instruction_Call_TAG: return emit_fn_call(emitter, fn_builder, bb_builder, instruction->payload.call); + case PrimOp_TAG: return emit_primop(emitter, fn_builder, bb_builder, instruction); + case Comment_TAG: { + spv_emit_mem(emitter, fn_builder, instruction->payload.comment.mem); + return 0; + } + case Instruction_LocalAlloc_TAG: { + LocalAlloc payload = instruction->payload.local_alloc; + spv_emit_mem(emitter, fn_builder, payload.mem); + assert(bb_builder); + return spvb_local_variable(spvb_get_fn_builder(bb_builder), spv_emit_type(emitter, ptr_type(emitter->arena, (PtrType) { + .address_space = AsFunction, + .pointed_type = instruction->payload.local_alloc.type + })), SpvStorageClassFunction); + } + case Instruction_Load_TAG: { + Load payload = instruction->payload.load; + spv_emit_mem(emitter, fn_builder, payload.mem); + const Type* ptr_type = payload.ptr->type; + shd_deconstruct_qualified_type(&ptr_type); + assert(ptr_type->tag == PtrType_TAG); + const Type* elem_type = ptr_type->payload.ptr_type.pointed_type; + + size_t operands_count = 0; + uint32_t operands[2]; + if (ptr_type->payload.ptr_type.address_space == AsGlobal) { + // TODO only do this in VK mode ? 
+ TypeMemLayout layout = shd_get_mem_layout(emitter->arena, elem_type); + operands[operands_count + 0] = SpvMemoryAccessAlignedMask; + operands[operands_count + 1] = (uint32_t) layout.alignment_in_bytes; + operands_count += 2; + } + + SpvId eptr = spv_emit_value(emitter, fn_builder, payload.ptr); + return spvb_load(bb_builder, spv_emit_type(emitter, elem_type), eptr, operands_count, operands); + } + case Instruction_Store_TAG: { + Store payload = instruction->payload.store; + spv_emit_mem(emitter, fn_builder, payload.mem); + const Type* ptr_type = payload.ptr->type; + shd_deconstruct_qualified_type(&ptr_type); + assert(ptr_type->tag == PtrType_TAG); + const Type* elem_type = ptr_type->payload.ptr_type.pointed_type; + + size_t operands_count = 0; + uint32_t operands[2]; + if (ptr_type->payload.ptr_type.address_space == AsGlobal) { + // TODO only do this in VK mode ? + TypeMemLayout layout = shd_get_mem_layout(emitter->arena, elem_type); + operands[operands_count + 0] = SpvMemoryAccessAlignedMask; + operands[operands_count + 1] = (uint32_t) layout.alignment_in_bytes; + operands_count += 2; + } + + SpvId eptr = spv_emit_value(emitter, fn_builder, payload.ptr); + SpvId eval = spv_emit_value(emitter, fn_builder, payload.value); + spvb_store(bb_builder, eval, eptr, operands_count, operands); + return 0; + } + case Instruction_PtrCompositeElement_TAG: { + PtrCompositeElement payload = instruction->payload.ptr_composite_element; + SpvId base = spv_emit_value(emitter, fn_builder, payload.ptr); + const Type* target_type = instruction->type; + SpvId index = spv_emit_value(emitter, fn_builder, payload.index); + return spvb_access_chain(bb_builder, spv_emit_type(emitter, target_type), base, 1, &index); + } + case Instruction_PtrArrayElementOffset_TAG: { + PtrArrayElementOffset payload = instruction->payload.ptr_array_element_offset; + SpvId base = spv_emit_value(emitter, fn_builder, payload.ptr); + const Type* target_type = instruction->type; + SpvId offset = 
spv_emit_value(emitter, fn_builder, payload.offset); + return spvb_ptr_access_chain(bb_builder, spv_emit_type(emitter, target_type), base, offset, 0, NULL); + } + case Instruction_DebugPrintf_TAG: { + DebugPrintf payload = instruction->payload.debug_printf; + spv_emit_mem(emitter, fn_builder, payload.mem); + SpvId set_id = spv_get_extended_instruction_set(emitter, "NonSemantic.DebugPrintf"); + LARRAY(SpvId, args, instruction->payload.debug_printf.args.count + 1); + args[0] = spv_emit_value(emitter, fn_builder, string_lit_helper(emitter->arena, instruction->payload.debug_printf.string)); + for (size_t i = 0; i < instruction->payload.debug_printf.args.count; i++) + args[i + 1] = spv_emit_value(emitter, fn_builder, instruction->payload.debug_printf.args.nodes[i]); + spvb_ext_instruction(bb_builder, spv_emit_type(emitter, instruction->type), set_id, (SpvOp) NonSemanticDebugPrintfDebugPrintf, instruction->payload.debug_printf.args.count + 1, args); + return 0; + } + } + SHADY_UNREACHABLE; +} + +static SpvId spv_emit_value_(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_builder, const Node* node) { + if (is_instruction(node)) + return spv_emit_instruction(emitter, fn_builder, bb_builder, node); + + SpvId new; + switch (is_value(node)) { + case NotAValue: shd_error(""); + case Param_TAG: shd_error("tried to emit a param: all params should be emitted by their binding abstraction !"); + case Value_ConstrainedValue_TAG: + case Value_UntypedNumber_TAG: + case Value_FnAddr_TAG: { + spvb_capability(emitter->file_builder, SpvCapabilityInModuleFunctionAddressSHADY); + SpvId fn = spv_emit_decl(emitter, node->payload.fn_addr.fn); + return spvb_constant_op(emitter->file_builder, spv_emit_type(emitter, node->type), SpvOpConstantFunctionPointerINTEL, 1, &fn); + } + case IntLiteral_TAG: { + new = spvb_fresh_id(emitter->file_builder); + SpvId ty = spv_emit_type(emitter, node->type); + // 64-bit constants take two spirv words, anything else fits in one + if 
(node->payload.int_literal.width == IntTy64) { + uint32_t arr[] = { node->payload.int_literal.value & 0xFFFFFFFF, node->payload.int_literal.value >> 32 }; + spvb_constant(emitter->file_builder, new, ty, 2, arr); + } else { + uint32_t arr[] = { node->payload.int_literal.value }; + spvb_constant(emitter->file_builder, new, ty, 1, arr); + } + break; + } + case FloatLiteral_TAG: { + new = spvb_fresh_id(emitter->file_builder); + SpvId ty = spv_emit_type(emitter, node->type); + switch (node->payload.float_literal.width) { + case FloatTy16: { + uint32_t arr[] = { node->payload.float_literal.value & 0xFFFF }; + spvb_constant(emitter->file_builder, new, ty, 1, arr); + break; + } + case FloatTy32: { + uint32_t arr[] = { node->payload.float_literal.value }; + spvb_constant(emitter->file_builder, new, ty, 1, arr); + break; + } + case FloatTy64: { + uint32_t arr[] = { node->payload.float_literal.value & 0xFFFFFFFF, node->payload.float_literal.value >> 32 }; + spvb_constant(emitter->file_builder, new, ty, 2, arr); + break; + } + } + break; + } + case True_TAG: { + new = spvb_fresh_id(emitter->file_builder); + spvb_bool_constant(emitter->file_builder, new, spv_emit_type(emitter, bool_type(emitter->arena)), true); + break; + } + case False_TAG: { + new = spvb_fresh_id(emitter->file_builder); + spvb_bool_constant(emitter->file_builder, new, spv_emit_type(emitter, bool_type(emitter->arena)), false); + break; + } + case Value_StringLiteral_TAG: { + new = spvb_debug_string(emitter->file_builder, node->payload.string_lit.string); + break; + } + case Value_NullPtr_TAG: { + new = spvb_constant_null(emitter->file_builder, spv_emit_type(emitter, node->payload.null_ptr.ptr_type)); + break; + } + case Composite_TAG: { + Nodes contents = node->payload.composite.contents; + LARRAY(SpvId, ids, contents.count); + for (size_t i = 0; i < contents.count; i++) { + ids[i] = spv_emit_value(emitter, fn_builder, contents.nodes[i]); + } + if (bb_builder) { + new = spvb_composite(bb_builder, 
spv_emit_type(emitter, node->type), contents.count, ids); + return new; + } else { + new = spvb_constant_composite(emitter->file_builder, spv_emit_type(emitter, node->type), contents.count, ids); + break; + } + } + case Value_Undef_TAG: { + new = spvb_undef(emitter->file_builder, spv_emit_type(emitter, node->payload.undef.type)); + break; + } + case Value_Fill_TAG: shd_error("lower me") + case RefDecl_TAG: { + const Node* decl = node->payload.ref_decl.decl; + switch (decl->tag) { + case GlobalVariable_TAG: { + new = spv_emit_decl(emitter, decl); + break; + } + case Constant_TAG: { + new = spv_emit_value(emitter, fn_builder, decl->payload.constant.value); + break; + } + default: shd_error("RefDecl must reference a constant or global"); + } + break; + } + default: { + shd_error("Unhandled value for code generation: %s", shd_get_node_tag_string(node->tag)); + } + } + + return new; +} + +static bool can_appear_at_top_level(const Node* node) { + switch (node->tag) { + case Undef_TAG: + case Composite_TAG: + case FloatLiteral_TAG: + case IntLiteral_TAG: + case True_TAG: + case False_TAG: + return true; + default: break; + } + return false; +} + +SpvId spv_emit_value(Emitter* emitter, FnBuilder* fn_builder, const Node* node) { + SpvId* existing = spv_search_emitted(emitter, fn_builder, node); + if (existing) + return *existing; + + CFNode* where = fn_builder ? 
shd_schedule_instruction(fn_builder->scheduler, node) : NULL; + if (where) { + BBBuilder bb_builder = spv_find_basic_block_builder(emitter, where->node); + SpvId emitted = spv_emit_value_(emitter, fn_builder, bb_builder, node); + spv_register_emitted(emitter, fn_builder, node, emitted); + return emitted; + } else if (!can_appear_at_top_level(node)) { + if (!fn_builder) { + shd_log_node(ERROR, node); + shd_log_fmt(ERROR, "cannot appear at top-level"); + exit(-1); + } + // Pick the entry block of the current fn + BBBuilder bb_builder = spv_find_basic_block_builder(emitter, fn_builder->cfg->entry->node); + SpvId emitted = spv_emit_value_(emitter, fn_builder, bb_builder, node); + spv_register_emitted(emitter, fn_builder, node, emitted); + return emitted; + } else { + assert(!is_mem(node)); + SpvId emitted = spv_emit_value_(emitter, NULL, NULL, node); + spv_register_emitted(emitter, NULL, node, emitted); + return emitted; + } +} + +SpvId spv_emit_mem(Emitter* e, FnBuilder* b, const Node* mem) { + assert(is_mem(mem)); + if (mem->tag == AbsMem_TAG) + return 0; + if (is_instruction(mem)) + return spv_emit_value(e, b, mem); + shd_error("What sort of mem is this ?"); +} diff --git a/src/shady/emit/spirv/spirv_builder.c b/src/backend/spirv/spirv_builder.c similarity index 82% rename from src/shady/emit/spirv/spirv_builder.c rename to src/backend/spirv/spirv_builder.c index 419482c15..d1ff6a051 100644 --- a/src/shady/emit/spirv/spirv_builder.c +++ b/src/backend/spirv/spirv_builder.c @@ -20,7 +20,7 @@ inline static int div_roundup(int a, int b) { } inline static void output_word(SpvbSectionBuilder data, uint32_t word) { - growy_append_bytes(data, sizeof(uint32_t), (char*) &word); + shd_growy_append_bytes(data, sizeof(uint32_t), (char*) &word); } #define op(opcode, size) op_(target_data, opcode, size) @@ -61,7 +61,7 @@ inline static void literal_int_(SpvbSectionBuilder data, uint32_t i) { #define copy_section(section) copy_section_(target_data, section) inline static void 
copy_section_(SpvbSectionBuilder target, SpvbSectionBuilder source) { - growy_append_bytes(target, growy_size(source), growy_data(source)); + shd_growy_append_bytes(target, shd_growy_size(source), shd_growy_data(source)); } struct SpvbFileBuilder_ { @@ -93,31 +93,31 @@ struct SpvbFileBuilder_ { struct Dict* extensions_set; }; -static KeyHash hash_u32(uint32_t* p) { return hash_murmur(p, sizeof(uint32_t)); } +static KeyHash hash_u32(uint32_t* p) { return shd_hash(p, sizeof(uint32_t)); } static bool compare_u32s(uint32_t* a, uint32_t* b) { return *a == *b; } -KeyHash hash_string(const char** string); -bool compare_string(const char** a, const char** b); +KeyHash shd_hash_string(const char** string); +bool shd_compare_string(const char** a, const char** b); SpvbFileBuilder* spvb_begin() { SpvbFileBuilder* file_builder = (SpvbFileBuilder*) malloc(sizeof(SpvbFileBuilder)); *file_builder = (SpvbFileBuilder) { .bound = 1, - .capabilities = new_growy(), - .extensions = new_growy(), - .ext_inst_import = new_growy(), - .entry_points = new_growy(), - .execution_modes = new_growy(), - .debug_string_source = new_growy(), - .debug_names = new_growy(), - .debug_module_processed = new_growy(), - .annotations = new_growy(), - .types_constants = new_growy(), - .fn_decls = new_growy(), - .fn_defs = new_growy(), - - .capabilities_set = new_set(SpvCapability, (HashFn) hash_u32, (CmpFn) compare_u32s), - .extensions_set = new_set(const char*, (HashFn) hash_string, (CmpFn) compare_string), + .capabilities = shd_new_growy(), + .extensions = shd_new_growy(), + .ext_inst_import = shd_new_growy(), + .entry_points = shd_new_growy(), + .execution_modes = shd_new_growy(), + .debug_string_source = shd_new_growy(), + .debug_names = shd_new_growy(), + .debug_module_processed = shd_new_growy(), + .annotations = shd_new_growy(), + .types_constants = shd_new_growy(), + .fn_decls = shd_new_growy(), + .fn_defs = shd_new_growy(), + + .capabilities_set = shd_new_set(SpvCapability, (HashFn) hash_u32, 
(CmpFn) compare_u32s), + .extensions_set = shd_new_set(const char*, (HashFn) shd_hash_string, (CmpFn) shd_compare_string), .memory_model = SpvMemoryModelGLSL450, }; @@ -175,30 +175,30 @@ static uint32_t byteswap(uint32_t v) { } size_t spvb_finish(SpvbFileBuilder* file_builder, char** output) { - Growy* g = new_growy(); + Growy* g = shd_new_growy(); merge_sections(file_builder, g); - destroy_growy(file_builder->fn_defs); - destroy_growy(file_builder->fn_decls); - destroy_growy(file_builder->types_constants); - destroy_growy(file_builder->annotations); - destroy_growy(file_builder->debug_module_processed); - destroy_growy(file_builder->debug_names); - destroy_growy(file_builder->debug_string_source); - destroy_growy(file_builder->execution_modes); - destroy_growy(file_builder->entry_points); - destroy_growy(file_builder->ext_inst_import); - destroy_growy(file_builder->extensions); - destroy_growy(file_builder->capabilities); - - destroy_dict(file_builder->capabilities_set); - destroy_dict(file_builder->extensions_set); + shd_destroy_growy(file_builder->fn_defs); + shd_destroy_growy(file_builder->fn_decls); + shd_destroy_growy(file_builder->types_constants); + shd_destroy_growy(file_builder->annotations); + shd_destroy_growy(file_builder->debug_module_processed); + shd_destroy_growy(file_builder->debug_names); + shd_destroy_growy(file_builder->debug_string_source); + shd_destroy_growy(file_builder->execution_modes); + shd_destroy_growy(file_builder->entry_points); + shd_destroy_growy(file_builder->ext_inst_import); + shd_destroy_growy(file_builder->extensions); + shd_destroy_growy(file_builder->capabilities); + + shd_destroy_dict(file_builder->capabilities_set); + shd_destroy_dict(file_builder->extensions_set); free(file_builder); - size_t s = growy_size(g); + size_t s = shd_growy_size(g); assert(s % 4 == 0); - *output = growy_deconstruct(g); + *output = shd_growy_deconstruct(g); if (is_big_endian()) for (size_t i = 0; i < s / 4; i++) { ((uint32_t*)*output)[i] = 
byteswap(((uint32_t*)(*output))[i]); @@ -223,7 +223,7 @@ void spvb_set_addressing_model(SpvbFileBuilder* file_builder, SpvAddressingModel #define target_data file_builder->capabilities void spvb_capability(SpvbFileBuilder* file_builder, SpvCapability cap) { - if (insert_set_get_result(SpvCapability, file_builder->capabilities_set, cap)) { + if (shd_set_insert_get_result(SpvCapability, file_builder->capabilities_set, cap)) { op(SpvOpCapability, 2); literal_int(cap); } @@ -232,7 +232,7 @@ void spvb_capability(SpvbFileBuilder* file_builder, SpvCapability cap) { #define target_data file_builder->extensions void spvb_extension(SpvbFileBuilder* file_builder, const char* name) { - if (insert_set_get_result(char*, file_builder->extensions_set, name)) { + if (shd_set_insert_get_result(char*, file_builder->extensions_set, name)) { op(SpvOpExtension, 1 + div_roundup(strlen(name) + 1, 4)); literal_name(name); } @@ -349,6 +349,21 @@ SpvId spvb_ptr_type(SpvbFileBuilder* file_builder, SpvStorageClass storage_class return id; } +SpvId spvb_forward_ptr_type(SpvbFileBuilder* file_builder, SpvStorageClass storage_class) { + op(SpvOpTypeForwardPointer, 3); + SpvId id = spvb_fresh_id(file_builder); + ref_id(id); + literal_int(storage_class); + return id; +} + +void spvb_ptr_type_define(SpvbFileBuilder* file_builder, SpvId id, SpvStorageClass storage_class, SpvId element_type) { + op(SpvOpTypePointer, 4); + ref_id(id); + literal_int(storage_class); + ref_id(element_type); +} + SpvId spvb_array_type(SpvbFileBuilder* file_builder, SpvId element_type, SpvId dim) { op(SpvOpTypeArray, 4); SpvId id = spvb_fresh_id(file_builder); @@ -464,6 +479,16 @@ SpvId spvb_global_variable(SpvbFileBuilder* file_builder, SpvId id, SpvId type, return id; } +SpvId spvb_constant_op(SpvbFileBuilder* file_builder, SpvId type, SpvOp op, size_t operands_count, SpvId operands[]) { + op(op, 3 + operands_count); + SpvId id = spvb_fresh_id(file_builder); + ref_id(type); + ref_id(id); + for (size_t i = 0; i < 
operands_count; i++) + literal_int(operands[i]); + return id; +} + SpvId spvb_undef(SpvbFileBuilder* file_builder, SpvId type) { op(SpvOpUndef, 3); ref_id(type); @@ -493,14 +518,14 @@ SpvbFnBuilder* spvb_begin_fn(SpvbFileBuilder* file_builder, SpvId fn_id, SpvId f .fn_type = fn_type, .fn_ret_type = fn_ret_type, .file_builder = file_builder, - .bbs = new_list(SpvbBasicBlockBuilder*), - .variables = new_growy(), - .header = new_growy(), + .bbs = shd_new_list(SpvbBasicBlockBuilder*), + .variables = shd_new_growy(), + .header = shd_new_growy(), }; return fnb; } -SpvId fn_ret_type_id(SpvbFnBuilder* fnb){ +SpvId spvb_fn_ret_type_id(SpvbFnBuilder* fnb){ return fnb->fn_ret_type; } @@ -536,20 +561,21 @@ void spvb_declare_function(SpvbFileBuilder* file_builder, SpvbFnBuilder* fn_buil // Includes stuff like OpFunctionParameters copy_section(fn_builder->header); - assert(entries_count_list(fn_builder->bbs) == 0 && "declared functions must be empty"); + assert(shd_list_count(fn_builder->bbs) == 0 && "declared functions must be shd_empty"); op(SpvOpFunctionEnd, 1); - destroy_list(fn_builder->bbs); - destroy_growy(fn_builder->header); - destroy_growy(fn_builder->variables); + shd_destroy_list(fn_builder->bbs); + shd_destroy_growy(fn_builder->header); + shd_destroy_growy(fn_builder->variables); free(fn_builder); } #undef target_data struct SpvbBasicBlockBuilder_ { SpvbFnBuilder* fn_builder; - SpvbSectionBuilder section_data; + SpvbSectionBuilder instructions_section; + SpvbSectionBuilder terminator_section; struct List* phis; SpvId label; @@ -580,7 +606,7 @@ void spvb_define_function(SpvbFileBuilder* file_builder, SpvbFnBuilder* fn_build bool first = true; for (size_t i = 0; i < fn_builder->bbs->elements_count; i++) { op(SpvOpLabel, 2); - SpvbBasicBlockBuilder* bb = read_list(SpvbBasicBlockBuilder*, fn_builder->bbs)[i]; + SpvbBasicBlockBuilder* bb = shd_read_list(SpvbBasicBlockBuilder*, fn_builder->bbs)[i]; ref_id(bb->label); if (first) { @@ -590,34 +616,36 @@ void 
spvb_define_function(SpvbFileBuilder* file_builder, SpvbFnBuilder* fn_build } for (size_t j = 0; j < bb->phis->elements_count; j++) { - SpvbPhi* phi = read_list(SpvbPhi*, bb->phis)[j]; + SpvbPhi* phi = shd_read_list(SpvbPhi*, bb->phis)[j]; op(SpvOpPhi, 3 + 2 * phi->preds->elements_count); ref_id(phi->type); ref_id(phi->value); assert(phi->preds->elements_count != 0); for (size_t k = 0; k < phi->preds->elements_count; k++) { - SpvbPhiSrc* pred = &read_list(SpvbPhiSrc, phi->preds)[k]; + SpvbPhiSrc* pred = &shd_read_list(SpvbPhiSrc, phi->preds)[k]; ref_id(pred->value); ref_id(pred->basic_block); } - destroy_list(phi->preds); + shd_destroy_list(phi->preds); free(phi); } - copy_section(bb->section_data); + copy_section(bb->instructions_section); + copy_section(bb->terminator_section); - destroy_list(bb->phis); - destroy_growy(bb->section_data); + shd_destroy_list(bb->phis); + shd_destroy_growy(bb->instructions_section); + shd_destroy_growy(bb->terminator_section); free(bb); } op(SpvOpFunctionEnd, 1); - destroy_list(fn_builder->bbs); - destroy_growy(fn_builder->header); - destroy_growy(fn_builder->variables); + shd_destroy_list(fn_builder->bbs); + shd_destroy_growy(fn_builder->header); + shd_destroy_growy(fn_builder->variables); free(fn_builder); } #undef target_data @@ -625,42 +653,47 @@ void spvb_define_function(SpvbFileBuilder* file_builder, SpvbFnBuilder* fn_build SpvbBasicBlockBuilder* spvb_begin_bb(SpvbFnBuilder* fn_builder, SpvId label) { SpvbBasicBlockBuilder* bbb = (SpvbBasicBlockBuilder*) malloc(sizeof(SpvbBasicBlockBuilder)); *bbb = (SpvbBasicBlockBuilder) { - .fn_builder = fn_builder, - .label = label, - .phis = new_list(SpvbPhi*), - .section_data = new_growy() + .fn_builder = fn_builder, + .label = label, + .phis = shd_new_list(SpvbPhi*), + .instructions_section = shd_new_growy(), + .terminator_section = shd_new_growy(), }; return bbb; } +SpvbFnBuilder* spvb_get_fn_builder(SpvbBasicBlockBuilder* bb_builder) { + return bb_builder->fn_builder; +} + void 
spvb_add_bb(SpvbFnBuilder* fn_builder, SpvbBasicBlockBuilder* bb_builder) { - append_list(SpvbBasicBlockBuilder*, fn_builder->bbs, bb_builder); + shd_list_append(SpvbBasicBlockBuilder*, fn_builder->bbs, bb_builder); } -SpvId get_block_builder_id(SpvbBasicBlockBuilder* basic_block_builder) { +SpvId spvb_get_block_builder_id(SpvbBasicBlockBuilder* basic_block_builder) { return basic_block_builder->label; } SpvbPhi* spvb_add_phi(SpvbBasicBlockBuilder* bb_builder, SpvId type, SpvId id) { SpvbPhi* phi = malloc(sizeof(SpvbPhi)); - phi->preds = new_list(SpvbPhiSrc); + phi->preds = shd_new_list(SpvbPhiSrc); phi->value = id; phi->type = type; - append_list(SpvbPhi*, bb_builder->phis, phi); + shd_list_append(SpvbPhi*, bb_builder->phis, phi); return phi; } void spvb_add_phi_source(SpvbPhi* phi, SpvId source_block, SpvId value) { SpvbPhiSrc op = { .value = value, .basic_block = source_block }; - append_list(SpvbPhi, phi->preds, op); + shd_list_append(SpvbPhi, phi->preds, op); } -struct List* spbv_get_phis(SpvbBasicBlockBuilder* bb_builder) { +struct List* spvb_get_phis(SpvbBasicBlockBuilder* bb_builder) { return bb_builder->phis; } // It is tiresome to pass the context over and over again. Let's not ! 
// We use this macro to save us some typing -#define target_data bb_builder->section_data +#define target_data bb_builder->instructions_section SpvId spvb_composite(SpvbBasicBlockBuilder* bb_builder, SpvId aggregate_t, size_t elements_count, SpvId elements[]) { op(SpvOpCompositeConstruct, 3u + elements_count); ref_id(aggregate_t); @@ -760,7 +793,7 @@ SpvId spvb_load(SpvbBasicBlockBuilder* bb_builder, SpvId target_type, SpvId poin return id; } -SpvId spvb_vecshuffle(SpvbBasicBlockBuilder* bb_builder, SpvId result_type, SpvId a, SpvId b, size_t operands_count, uint32_t operands[]) { +SpvId spvb_vecshuffle(SpvbBasicBlockBuilder* bb_builder, SpvId result_type, SpvId a, SpvId b, size_t operands_count, uint32_t operands[]) { op(SpvOpVectorShuffle, 5 + operands_count); SpvId id = spvb_fresh_id(bb_builder->fn_builder->file_builder); ref_id(result_type); @@ -830,8 +863,8 @@ SpvId spvb_group_shuffle(SpvbBasicBlockBuilder* bb_builder, SpvId result_type, S return rid; } -SpvId spvb_group_non_uniform_iadd(SpvbBasicBlockBuilder* bb_builder, SpvId result_type, SpvId value, SpvId scope, SpvGroupOperation group_op, SpvId* cluster_size) { - op(SpvOpGroupNonUniformIAdd, cluster_size ? 7 : 6); +SpvId spvb_group_non_uniform_group_op(SpvbBasicBlockBuilder* bb_builder, SpvId result_type, SpvOp op, SpvId scope, SpvGroupOperation group_op, SpvId value, SpvId* cluster_size) { + op(op, cluster_size ? 
7 : 6); SpvId id = spvb_fresh_id(bb_builder->fn_builder->file_builder); ref_id(result_type); ref_id(id); @@ -843,6 +876,32 @@ SpvId spvb_group_non_uniform_iadd(SpvbBasicBlockBuilder* bb_builder, SpvId resul return id; } +SpvId spvb_call(SpvbBasicBlockBuilder* bb_builder, SpvId return_type, SpvId callee, size_t arguments_count, SpvId arguments[]) { + op(SpvOpFunctionCall, 4u + arguments_count); + SpvId id = spvb_fresh_id(bb_builder->fn_builder->file_builder); + ref_id(return_type); + ref_id(id); + ref_id(callee); + for (size_t i = 0; i < arguments_count; i++) + ref_id(arguments[i]); + return id; +} + +SpvId spvb_ext_instruction(SpvbBasicBlockBuilder* bb_builder, SpvId return_type, SpvId set, uint32_t instruction, size_t arguments_count, SpvId arguments[]) { + op(SpvOpExtInst, 5 + arguments_count); + SpvId id = spvb_fresh_id(bb_builder->fn_builder->file_builder); + ref_id(return_type); + ref_id(id); + ref_id(set); + literal_int(instruction); + for (size_t i = 0; i < arguments_count; i++) + ref_id(arguments[i]); + return id; +} + +#undef target_data +#define target_data bb_builder->terminator_section + void spvb_branch(SpvbBasicBlockBuilder* bb_builder, SpvId target) { op(SpvOpBranch, 2); ref_id(target); @@ -880,29 +939,6 @@ void spvb_loop_merge(SpvbBasicBlockBuilder* bb_builder, SpvId merge_bb, SpvId co literal_int(loop_control_ops[i]); } -SpvId spvb_call(SpvbBasicBlockBuilder* bb_builder, SpvId return_type, SpvId callee, size_t arguments_count, SpvId arguments[]) { - op(SpvOpFunctionCall, 4u + arguments_count); - SpvId id = spvb_fresh_id(bb_builder->fn_builder->file_builder); - ref_id(return_type); - ref_id(id); - ref_id(callee); - for (size_t i = 0; i < arguments_count; i++) - ref_id(arguments[i]); - return id; -} - -SpvId spvb_ext_instruction(SpvbBasicBlockBuilder* bb_builder, SpvId return_type, SpvId set, uint32_t instruction, size_t arguments_count, SpvId arguments[]) { - op(SpvOpExtInst, 5 + arguments_count); - SpvId id = 
spvb_fresh_id(bb_builder->fn_builder->file_builder); - ref_id(return_type); - ref_id(id); - ref_id(set); - literal_int(instruction); - for (size_t i = 0; i < arguments_count; i++) - ref_id(arguments[i]); - return id; -} - void spvb_return_void(SpvbBasicBlockBuilder* bb_builder) { op(SpvOpReturn, 1); } @@ -915,4 +951,11 @@ void spvb_return_value(SpvbBasicBlockBuilder* bb_builder, SpvId value) { void spvb_unreachable(SpvbBasicBlockBuilder* bb_builder) { op(SpvOpUnreachable, 1); } + +void spvb_terminator(SpvbBasicBlockBuilder* bb_builder, SpvOp op, size_t operands_count, SpvId operands[]) { + op(op, 1 + operands_count); + for (size_t i = 0; i < operands_count; i++) + ref_id(operands[i]); +} + #undef target_data diff --git a/src/shady/emit/spirv/spirv_builder.h b/src/backend/spirv/spirv_builder.h similarity index 89% rename from src/shady/emit/spirv/spirv_builder.h rename to src/backend/spirv/spirv_builder.h index 6e8b74421..7d8bee7c3 100644 --- a/src/shady/emit/spirv/spirv_builder.h +++ b/src/backend/spirv/spirv_builder.h @@ -44,6 +44,8 @@ SpvId spvb_void_type(SpvbFileBuilder*); SpvId spvb_bool_type(SpvbFileBuilder*); SpvId spvb_int_type(SpvbFileBuilder*, int width, bool signed_); SpvId spvb_float_type(SpvbFileBuilder*, int width); +SpvId spvb_forward_ptr_type(SpvbFileBuilder*, SpvStorageClass storage_class); +void spvb_ptr_type_define(SpvbFileBuilder*, SpvId id, SpvStorageClass storage_class, SpvId element_type); SpvId spvb_ptr_type(SpvbFileBuilder*, SpvStorageClass storage_class, SpvId element_type); SpvId spvb_array_type(SpvbFileBuilder*, SpvId element_type, SpvId dim); SpvId spvb_runtime_array_type(SpvbFileBuilder*, SpvId element_type); @@ -61,24 +63,26 @@ void spvb_constant(SpvbFileBuilder*, SpvId result, SpvId type, size_t bit_patter SpvId spvb_constant_composite(SpvbFileBuilder*, SpvId type, size_t ops_count, SpvId ops[]); SpvId spvb_constant_null(SpvbFileBuilder*, SpvId type); SpvId spvb_global_variable(SpvbFileBuilder*, SpvId id, SpvId type, SpvStorageClass 
storage_class, bool has_initializer, SpvId initializer); +SpvId spvb_constant_op(SpvbFileBuilder*, SpvId type, SpvOp, size_t operands_count, SpvId operands[]); // Function building stuff SpvbFnBuilder* spvb_begin_fn(SpvbFileBuilder*, SpvId fn_id, SpvId fn_type, SpvId fn_ret_type); -SpvId fn_ret_type_id(SpvbFnBuilder*); +SpvId spvb_fn_ret_type_id(SpvbFnBuilder* fnb); SpvId spvb_parameter(SpvbFnBuilder* fn_builder, SpvId param_type); SpvId spvb_local_variable(SpvbFnBuilder* fn_builder, SpvId type, SpvStorageClass storage_class); void spvb_declare_function(SpvbFileBuilder*, SpvbFnBuilder* fn_builder); void spvb_define_function(SpvbFileBuilder*, SpvbFnBuilder* fn_builder); SpvbBasicBlockBuilder* spvb_begin_bb(SpvbFnBuilder*, SpvId label); +SpvbFnBuilder* spvb_get_fn_builder(SpvbBasicBlockBuilder*); /// Actually adds the basic block to the function /// This is a separate action from begin_bb because the ordering in which the basic blocks are written matters... void spvb_add_bb(SpvbFnBuilder*, SpvbBasicBlockBuilder*); -SpvId get_block_builder_id(SpvbBasicBlockBuilder*); +SpvId spvb_get_block_builder_id(SpvbBasicBlockBuilder* basic_block_builder); SpvbPhi* spvb_add_phi(SpvbBasicBlockBuilder*, SpvId type, SpvId id); void spvb_add_phi_source(SpvbPhi*, SpvId source_block, SpvId value); -struct List* spbv_get_phis(SpvbBasicBlockBuilder*); +struct List* spvb_get_phis(SpvbBasicBlockBuilder* bb_builder); // Normal instructions SpvId spvb_op(SpvbBasicBlockBuilder*, SpvOp op, SpvId result_type, size_t operands_count, SpvId operands[]); @@ -97,7 +101,7 @@ SpvId spvb_group_elect(SpvbBasicBlockBuilder*, SpvId result_type, SpvId scope); SpvId spvb_group_ballot(SpvbBasicBlockBuilder*, SpvId result_t, SpvId predicate, SpvId scope); SpvId spvb_group_shuffle(SpvbBasicBlockBuilder*, SpvId result_type, SpvId scope, SpvId value, SpvId id); SpvId spvb_group_broadcast_first(SpvbBasicBlockBuilder*, SpvId result_t, SpvId value, SpvId scope); -SpvId 
spvb_group_non_uniform_iadd(SpvbBasicBlockBuilder*, SpvId result_t, SpvId value, SpvId scope, SpvGroupOperation group_op, SpvId* cluster_size); +SpvId spvb_group_non_uniform_group_op(SpvbBasicBlockBuilder*, SpvId result_t, SpvOp op, SpvId scope, SpvGroupOperation group_op, SpvId value, SpvId* cluster_size); // Terminators void spvb_branch(SpvbBasicBlockBuilder*, SpvId target); @@ -110,5 +114,6 @@ SpvId spvb_ext_instruction(SpvbBasicBlockBuilder*, SpvId return_type, SpvId set, void spvb_return_void(SpvbBasicBlockBuilder*) ; void spvb_return_value(SpvbBasicBlockBuilder*, SpvId value); void spvb_unreachable(SpvbBasicBlockBuilder*); +void spvb_terminator(SpvbBasicBlockBuilder*, SpvOp, size_t operands_count, SpvId operands[]); #endif diff --git a/src/backend/spirv/spirv_lift_globals_ssbo.c b/src/backend/spirv/spirv_lift_globals_ssbo.c new file mode 100644 index 000000000..ed8c32390 --- /dev/null +++ b/src/backend/spirv/spirv_lift_globals_ssbo.c @@ -0,0 +1,126 @@ +#include "shady/pass.h" +#include "shady/ir/memory_layout.h" +#include "shady/ir/function.h" +#include "shady/ir/mem.h" +#include "shady/ir/decl.h" + +#include "dict.h" +#include "portability.h" +#include "log.h" + +typedef struct { + Rewriter rewriter; + const CompilerConfig* config; + BodyBuilder* bb; + Node* lifted_globals_decl; +} Context; + +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + + switch (node->tag) { + case Function_TAG: { + Node* newfun = shd_recreate_node_head(r, node); + if (get_abstraction_body(node)) { + Context functx = *ctx; + functx.rewriter.map = shd_clone_dict(functx.rewriter.map); + shd_dict_clear(functx.rewriter.map); + shd_register_processed_list(&functx.rewriter, get_abstraction_params(node), get_abstraction_params(newfun)); + functx.bb = shd_bld_begin(a, shd_get_abstraction_mem(newfun)); + Node* post_prelude = basic_block(a, shd_empty(a), "post-prelude"); + shd_register_processed(&functx.rewriter, 
shd_get_abstraction_mem(node), shd_get_abstraction_mem(post_prelude)); + shd_set_abstraction_body(post_prelude, shd_rewrite_node(&functx.rewriter, get_abstraction_body(node))); + shd_set_abstraction_body(newfun, shd_bld_finish(functx.bb, jump_helper(a, shd_bb_mem(functx.bb), post_prelude, + shd_empty(a)))); + shd_destroy_dict(functx.rewriter.map); + } + return newfun; + } + case RefDecl_TAG: { + const Node* odecl = node->payload.ref_decl.decl; + if (odecl->tag != GlobalVariable_TAG || odecl->payload.global_variable.address_space != AsGlobal) + break; + assert(ctx->bb && "this RefDecl node isn't appearing in an abstraction - we cannot replace it with a load!"); + const Node* ptr_addr = lea_helper(a, ref_decl_helper(a, ctx->lifted_globals_decl), shd_int32_literal(a, 0), shd_singleton(shd_rewrite_node(&ctx->rewriter, odecl))); + const Node* ptr = shd_bld_load(ctx->bb, ptr_addr); + return ptr; + } + case GlobalVariable_TAG: + if (node->payload.global_variable.address_space != AsGlobal) + break; + assert(false); + default: break; + } + + if (is_declaration(node)) { + Context declctx = *ctx; + declctx.bb = NULL; + return shd_recreate_node(&declctx.rewriter, node); + } + + return shd_recreate_node(&ctx->rewriter, node); +} + +Module* shd_spvbe_pass_lift_globals_ssbo(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .config = config + }; + + Nodes old_decls = shd_module_get_declarations(src); + LARRAY(const Type*, member_tys, old_decls.count); + LARRAY(String, member_names, old_decls.count); + + Nodes annotations = mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" })); + annotations = shd_empty(a); + + annotations = shd_nodes_append(a, annotations, annotation_value(a, 
(AnnotationValue) { .name = "DescriptorSet", .value = shd_int32_literal(a, 0) })); + annotations = shd_nodes_append(a, annotations, annotation_value(a, (AnnotationValue) { .name = "DescriptorBinding", .value = shd_int32_literal(a, 0) })); + annotations = shd_nodes_append(a, annotations, annotation(a, (Annotation) { .name = "Constants" })); + + size_t lifted_globals_count = 0; + for (size_t i = 0; i < old_decls.count; i++) { + const Node* odecl = old_decls.nodes[i]; + if (odecl->tag != GlobalVariable_TAG || odecl->payload.global_variable.address_space != AsGlobal) + continue; + + member_tys[lifted_globals_count] = shd_rewrite_node(&ctx.rewriter, odecl->type); + member_names[lifted_globals_count] = get_declaration_name(odecl); + + shd_register_processed(&ctx.rewriter, odecl, shd_int32_literal(a, lifted_globals_count)); + lifted_globals_count++; + } + + if (lifted_globals_count > 0) { + const Type* lifted_globals_struct_t = record_type(a, (RecordType) { + .members = shd_nodes(a, lifted_globals_count, member_tys), + .names = shd_strings(a, lifted_globals_count, member_names), + .special = DecorateBlock + }); + ctx.lifted_globals_decl = global_var(dst, annotations, lifted_globals_struct_t, "lifted_globals", AsShaderStorageBufferObject); + } + + shd_rewrite_module(&ctx.rewriter); + + lifted_globals_count = 0; + for (size_t i = 0; i < old_decls.count; i++) { + const Node* odecl = old_decls.nodes[i]; + if (odecl->tag != GlobalVariable_TAG || odecl->payload.global_variable.address_space != AsGlobal) + continue; + if (odecl->payload.global_variable.init) + ctx.lifted_globals_decl->payload.global_variable.annotations = shd_nodes_append(a, ctx.lifted_globals_decl->payload.global_variable.annotations, annotation_values(a, (AnnotationValues) { + .name = "InitialValue", + .values = mk_nodes(a, shd_int32_literal(a, lifted_globals_count), shd_rewrite_node(&ctx.rewriter, odecl->payload.global_variable.init)) + })); + + lifted_globals_count++; + } + + 
shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/backend/spirv/spirv_map_entrypoint_args.c b/src/backend/spirv/spirv_map_entrypoint_args.c new file mode 100644 index 000000000..b08e27451 --- /dev/null +++ b/src/backend/spirv/spirv_map_entrypoint_args.c @@ -0,0 +1,70 @@ +#include "shady/pass.h" +#include "shady/ir/memory_layout.h" +#include "shady/ir/decl.h" +#include "shady/ir/annotation.h" + +#include "portability.h" +#include "log.h" + +typedef struct { + Rewriter rewriter; + const CompilerConfig* config; +} Context; + +static const Node* rewrite_args_type(Rewriter* rewriter, const Node* old_type) { + IrArena* a = rewriter->dst_arena; + + if (old_type->tag != RecordType_TAG || old_type->payload.record_type.special != NotSpecial) + shd_error("EntryPointArgs type must be a plain record type"); + + const Node* new_type = record_type(a, (RecordType) { + .members = shd_rewrite_nodes(rewriter, old_type->payload.record_type.members), + .names = old_type->payload.record_type.names, + .special = DecorateBlock + }); + + shd_register_processed(rewriter, old_type, new_type); + + return new_type; +} + +static const Node* process(Context* ctx, const Node* node) { + switch (node->tag) { + case GlobalVariable_TAG: + if (shd_lookup_annotation(node, "EntryPointArgs")) { + if (node->payload.global_variable.address_space != AsExternal) + shd_error("EntryPointArgs address space must be extern"); + + Nodes annotations = shd_rewrite_nodes(&ctx->rewriter, node->payload.global_variable.annotations); + const Node* type = rewrite_args_type(&ctx->rewriter, node->payload.global_variable.type); + + const Node* new_var = global_var(ctx->rewriter.dst_module, + annotations, + type, + node->payload.global_variable.name, + AsPushConstant + ); + + shd_register_processed(&ctx->rewriter, node, new_var); + + return new_var; + } + break; + default: break; + } + + return shd_recreate_node(&ctx->rewriter, node); +} + +Module* shd_spvbe_pass_map_entrypoint_args(SHADY_UNUSED const 
CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .config = config + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index e941ac2df..552960de1 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -1,8 +1,15 @@ -add_library(common STATIC list.c dict.c log.c portability.c util.c growy.c arena.c printer.c) -target_include_directories(common INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) -target_link_libraries(common PRIVATE "$") +add_library(common list.c dict.c log.c portability.c util.c growy.c arena.c printer.c) set_property(TARGET common PROPERTY POSITION_INDEPENDENT_CODE ON) +# We need to export 'common' because otherwise when using static libraries we will not be able to resolve those symbols +install(TARGETS common EXPORT shady_export_set) + +# But we don't want projects outside this to be able to see these APIs and call into them +# (Also we couldn't since the header files live with the source anyways) +add_library(common_api INTERFACE) +target_include_directories(common_api INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(common INTERFACE "$") + add_executable(embedder embed.c) function(embed_file TYPE NAME SRC) @@ -14,3 +21,13 @@ function(embed_file TYPE NAME SRC) add_dependencies(${NAME} "${NAME}_h") target_include_directories(${NAME} INTERFACE ${CMAKE_CURRENT_BINARY_DIR}) endfunction() + +if (BUILD_TESTING) + add_executable(test_dict test_dict.c) + target_link_libraries(test_dict PRIVATE common) + add_test(NAME test_dict COMMAND test_dict) + + add_executable(test_util test_util.c) + target_link_libraries(test_util PRIVATE common) + add_test(NAME test_util 
COMMAND test_util) +endif () \ No newline at end of file diff --git a/src/common/arena.c b/src/common/arena.c index 2c8a37990..3b4525c64 100644 --- a/src/common/arena.c +++ b/src/common/arena.c @@ -19,12 +19,12 @@ inline static size_t round_up(size_t a, size_t b) { return divided * b; } -Arena* new_arena() { +Arena* shd_new_arena(void) { Arena* arena = malloc(sizeof(Arena)); *arena = (Arena) { .nblocks = 0, .maxblocks = 256, - .blocks = malloc(256 * sizeof(size_t)), + .blocks = malloc(256 * sizeof(void*)), .available = 0, }; for (int i = 0; i < arena->maxblocks; i++) @@ -32,7 +32,7 @@ Arena* new_arena() { return arena; } -void destroy_arena(Arena* arena) { +void shd_destroy_arena(Arena* arena) { for (int i = 0; i < arena->nblocks; i++) { free(arena->blocks[i]); } @@ -40,20 +40,39 @@ void destroy_arena(Arena* arena) { free(arena); } -void* arena_alloc(Arena* arena, size_t size) { +static void* new_block(Arena* arena, size_t size) { + assert(arena->nblocks <= arena->maxblocks); + // we need more storage for the block pointers themselves ! + if (arena->nblocks == arena->maxblocks) { + arena->maxblocks *= 2; + arena->blocks = realloc(arena->blocks, arena->maxblocks * sizeof(void*)); + } + + void* allocated = malloc(size); + assert(allocated); + arena->blocks[arena->nblocks++] = allocated; + return allocated; +} + +void* shd_arena_alloc(Arena* arena, size_t size) { size = round_up(size, (size_t) sizeof(max_align_t)); if (size == 0) return NULL; - // arena is full - if (size > arena->available) { - assert(arena->nblocks <= arena->maxblocks); - // we need more storage for the block pointers themselves ! 
- if (arena->nblocks == arena->maxblocks) { - arena->maxblocks *= 2; - arena->blocks = realloc(arena->blocks, arena->maxblocks); + if (size > alloc_size) { + void* allocated = new_block(arena, size); + memset(allocated, 0, size); + // swap the last two blocks + if (arena->nblocks >= 2) { + void* swap = arena->blocks[arena->nblocks - 1]; + arena->blocks[arena->nblocks - 1] = arena->blocks[arena->nblocks - 2]; + arena->blocks[arena->nblocks - 2] = swap; } + return allocated; + } - arena->blocks[arena->nblocks++] = malloc(alloc_size); + // arena is full + if (size > arena->available) { + new_block(arena, alloc_size); arena->available = alloc_size; } diff --git a/src/common/arena.h b/src/common/arena.h index baa16f5ae..45c3426cf 100644 --- a/src/common/arena.h +++ b/src/common/arena.h @@ -5,8 +5,8 @@ typedef struct Arena_ Arena; -Arena* new_arena(); -void destroy_arena(Arena* arena); -void* arena_alloc(Arena* arena, size_t size); +Arena* shd_new_arena(void); +void shd_destroy_arena(Arena* arena); +void* shd_arena_alloc(Arena* arena, size_t size); #endif diff --git a/src/common/dict.c b/src/common/dict.c index 2fbaf1af0..afa0294f9 100644 --- a/src/common/dict.c +++ b/src/common/dict.c @@ -1,12 +1,13 @@ #include "dict.h" +#include "log.h" + #include #include #include #include inline static size_t div_roundup(size_t a, size_t b) { - //return (a + b - 1) / b; if (a % b == 0) return a / b; else @@ -27,6 +28,9 @@ static size_t init_size = 32; struct BucketTag { bool is_present; bool is_thombstone; +#ifdef GOBLIB_DICT_DEBUG + KeyHash cached_hash; +#endif }; struct Dict { @@ -46,7 +50,46 @@ struct Dict { void* alloc; }; -struct Dict* new_dict_impl(size_t key_size, size_t value_size, size_t key_align, size_t value_align, KeyHash (*hash_fn)(void*), bool (*cmp_fn) (void*, void*)) { +#ifdef GOBLIB_DICT_DEBUG +static size_t dict_count_sanity(struct Dict* dict) { + size_t i = 0; + size_t count = 0; + while (dict_iter(dict, &i, NULL, NULL)) { + count++; + } + return count; +} + 
+static void validate_hashmap_integrity(const struct Dict* dict) { + const size_t alloc_base = (size_t) dict->alloc; + for (size_t i = 0; i < dict->size; i++) { + size_t bucket = alloc_base + i * dict->bucket_entry_size; + void* in_dict_key = (void*) bucket; + struct BucketTag tag = *(struct BucketTag*) (void*) (bucket + dict->tag_offset); + if (tag.is_present) { + KeyHash fresh_hash = dict->hash_fn(in_dict_key); + if (fresh_hash != tag.cached_hash) { + error("hash changed under our noses"); + } + } + } +} + +static void dump_dict_keys(struct Dict* dict) { + const size_t alloc_base = (size_t) dict->alloc; + for (size_t i = 0; i < dict->size; i++) { + size_t bucket = alloc_base + i * dict->bucket_entry_size; + void* in_dict_key = (void*) bucket; + struct BucketTag tag = *(struct BucketTag*) (void*) (bucket + dict->tag_offset); + if (tag.is_present) { + KeyHash hash = dict->hash_fn(in_dict_key); + printf("@i = %zu, hash = %d\n", i, hash); + } + } +} +#endif + +struct Dict* shd_new_dict_impl(size_t key_size, size_t value_size, size_t key_align, size_t value_align, KeyHash (*hash_fn)(void*), bool (*cmp_fn) (void*, void*)) { // offset of key is obviously zero size_t value_offset = align_offset(key_size, value_align); size_t tag_offset = align_offset(value_offset + value_size, alignof(struct BucketTag)); @@ -80,7 +123,7 @@ struct Dict* new_dict_impl(size_t key_size, size_t value_size, size_t key_align, return dict; } -struct Dict* clone_dict(struct Dict* source) { +struct Dict* shd_clone_dict(struct Dict* source) { struct Dict* dict = (struct Dict*) malloc(sizeof(struct Dict)); *dict = (struct Dict) { .entries_count = source->entries_count, @@ -100,25 +143,32 @@ struct Dict* clone_dict(struct Dict* source) { .alloc = malloc(source->bucket_entry_size * source->size) }; memcpy(dict->alloc, source->alloc, source->bucket_entry_size * source->size); +#ifdef GOBLIB_DICT_DEBUG + validate_hashmap_integrity(dict); + validate_hashmap_integrity(source); +#endif return dict; } -void 
destroy_dict(struct Dict* dict) { +void shd_destroy_dict(struct Dict* dict) { free(dict->alloc); free(dict); } -void clear_dict(struct Dict* dict) { +void shd_dict_clear(struct Dict* dict) { dict->entries_count = 0; dict->thombstones_count = 0; memset(dict->alloc, 0, dict->bucket_entry_size * dict->size); } -size_t entries_count_dict(struct Dict* dict) { +size_t shd_dict_count(struct Dict* dict) { return dict->entries_count; } -void* find_key_dict_impl(struct Dict* dict, void* key) { +void* shd_dict_find_impl(struct Dict* dict, void* key) { +#ifdef GOBLIB_DICT_DEBUG_PARANOID + validate_hashmap_integrity(dict); +#endif KeyHash hash = dict->hash_fn(key); size_t pos = hash % dict->size; const size_t init_pos = pos; @@ -146,15 +196,15 @@ void* find_key_dict_impl(struct Dict* dict, void* key) { return NULL; } -void* find_value_dict_impl(struct Dict* dict, void* key) { - void* found = find_key_dict_impl(dict, key); +void* shd_dict_find_value_impl(struct Dict* dict, void* key) { + void* found = shd_dict_find_impl(dict, key); if (found) return (void*) ((size_t)found + dict->value_offset); return NULL; } -bool remove_dict_impl(struct Dict* dict, void* key) { - void* found = find_key_dict_impl(dict, key); +bool shd_dict_remove_impl(struct Dict* dict, void* key) { + void* found = shd_dict_find_impl(dict, key); if (found) { struct BucketTag* tag = (void *) (((size_t) found) + dict->tag_offset); assert(tag->is_present && !tag->is_thombstone); @@ -167,42 +217,43 @@ bool remove_dict_impl(struct Dict* dict, void* key) { return false; } -bool insert_dict_impl(struct Dict* dict, void* key, void* value, void** out_ptr); -bool insert_dict_and_get_result_impl(struct Dict* dict, void* key, void* value) { +static bool dict_insert(struct Dict* dict, void* key, void* value, void** out_ptr); + +bool shd_dict_insert_impl(struct Dict* dict, void* key, void* value) { void* dont_care; - return insert_dict_impl(dict, key, value, &dont_care); + return dict_insert(dict, key, value, &dont_care); } 
-void* insert_dict_and_get_key_impl(struct Dict* dict, void* key, void* value) { +void* shd_dict_insert_get_key_impl(struct Dict* dict, void* key, void* value) { void* do_care; - insert_dict_impl(dict, key, value, &do_care); + dict_insert(dict, key, value, &do_care); return do_care; } -void* insert_dict_and_get_value_impl(struct Dict* dict, void* key, void* value) { +void* shd_dict_insert_get_value_impl(struct Dict* dict, void* key, void* value) { void* do_care; - insert_dict_impl(dict, key, value, &do_care); + dict_insert(dict, key, value, &do_care); return (void*) ((size_t)do_care + dict->value_offset); } static void rehash(struct Dict* dict, void* old_alloc, size_t old_size) { const size_t alloc_base = (size_t) old_alloc; // Go over all the old entries and add them back - for(size_t pos = 0; pos < old_size; pos++) { + for (size_t pos = 0; pos < old_size; pos++) { size_t bucket = alloc_base + pos * dict->bucket_entry_size; struct BucketTag* tag = (struct BucketTag*) (void*) (bucket + dict->tag_offset); if (tag->is_present) { void* key = (void*) bucket; void* value = (void*) (bucket + dict->value_offset); - insert_dict_and_get_result_impl(dict, key, value); + bool fresh = shd_dict_insert_impl(dict, key, value); + assert(fresh); } } } static void grow_and_rehash(struct Dict* dict) { - //printf("grow_rehash\n"); - size_t old_entries_count = entries_count_dict(dict); + size_t old_entries_count = shd_dict_count(dict); void* old_alloc = dict->alloc; size_t old_size = dict->size; @@ -215,12 +266,15 @@ static void grow_and_rehash(struct Dict* dict) { memset(dict->alloc, 0, dict->size * dict->bucket_entry_size); rehash(dict, old_alloc, old_size); - assert(old_entries_count == entries_count_dict(dict)); +#ifdef GOBLIB_DICT_DEBUG + assert(dict_count_sanity(dict) == entries_count_dict(dict)); +#endif + assert(old_entries_count == shd_dict_count(dict)); free(old_alloc); } -bool insert_dict_impl(struct Dict* dict, void* key, void* value, void** out_ptr) { +static bool 
dict_insert(struct Dict* dict, void* key, void* value, void** out_ptr) { float load_factor = (float) (dict->entries_count + dict->thombstones_count) / (float) dict->size; if (load_factor > 0.6) grow_and_rehash(dict); @@ -290,15 +344,22 @@ bool insert_dict_impl(struct Dict* dict, void* key, void* value, void** out_ptr) dst_tag->is_present = true; dst_tag->is_thombstone = false; +#ifdef GOBLIB_DICT_DEBUG + dst_tag->cached_hash = hash; +#endif memcpy(in_dict_key, key, dict->key_size); if (dict->value_size) memcpy(in_dict_value, value, dict->value_size); *out_ptr = in_dict_key; +#ifdef GOBLIB_DICT_DEBUG + validate_hashmap_integrity(dict); +#endif + return mode == Inserting; } -bool dict_iter(struct Dict* dict, size_t* iterator_state, void* key, void* value) { +bool shd_dict_iter(struct Dict* dict, size_t* iterator_state, void* key, void* value) { bool found_something = false; while (!found_something) { if (*iterator_state >= dict->size) { @@ -321,16 +382,25 @@ bool dict_iter(struct Dict* dict, size_t* iterator_state, void* key, void* value return true; } -#include "murmur3.h" +KeyHash shd_hash(const void* data, size_t size) { + const char* data_chars = (const char*) data; + const unsigned int fnv_prime = 0x811C9DC5; + unsigned int hash = 0; + unsigned int i = 0; + + for (i = 0; i < size; data_chars++, i++) + { + hash *= fnv_prime; + hash ^= (*data_chars); + } + + return hash; +} -KeyHash hash_murmur(const void* data, size_t size) { - int32_t out[4]; - MurmurHash3_x64_128(data, (int) size, 0x1234567, &out); +KeyHash shd_hash_ptr(void** p) { + return shd_hash(p, sizeof(void*)); +} - uint32_t final = 0; - final ^= out[0]; - final ^= out[1]; - final ^= out[2]; - final ^= out[3]; - return final; +bool shd_compare_ptrs(void** a, void** b) { + return *a == *b; } diff --git a/src/common/dict.h b/src/common/dict.h index 74fcdb5a4..5e83b12d8 100644 --- a/src/common/dict.h +++ b/src/common/dict.h @@ -6,44 +6,49 @@ #include #include +#include "portability.h" + typedef uint32_t 
KeyHash; typedef KeyHash (*HashFn)(void*); typedef bool (*CmpFn)(void*, void*); struct Dict; -#define new_dict(K, T, hash, cmp) new_dict_impl(sizeof(K), sizeof(T), alignof(K), alignof(T), hash, cmp) -#define new_set(K, hash, cmp) new_dict_impl(sizeof(K), 0, alignof(K), 0, hash, cmp) -struct Dict* new_dict_impl(size_t key_size, size_t value_size, size_t key_align, size_t value_align, KeyHash (*)(void*), bool (*)(void*, void*)); +#define shd_new_dict(K, T, hash, cmp) shd_new_dict_impl(sizeof(K), sizeof(T), alignof(K), alignof(T), hash, cmp) +#define shd_new_set(K, hash, cmp) shd_new_dict_impl(sizeof(K), 0, alignof(K), 0, hash, cmp) +struct Dict* shd_new_dict_impl(size_t key_size, size_t value_size, size_t key_align, size_t value_align, KeyHash (* hash_fn)(void*), bool (* cmp_fn)(void*, void*)); + +struct Dict* shd_clone_dict(struct Dict* source); +void shd_destroy_dict(struct Dict* dict); +void shd_dict_clear(struct Dict* dict); -struct Dict* clone_dict(struct Dict*); -void destroy_dict(struct Dict*); -void clear_dict(struct Dict*); +bool shd_dict_iter(struct Dict* dict, size_t* iterator_state, void* key, void* value); -bool dict_iter(struct Dict*, size_t* iterator_state, void* key, void* value); +size_t shd_dict_count(struct Dict* dict); -size_t entries_count_dict(struct Dict*); +#define shd_dict_find_value(K, T, dict, key) (T*) shd_dict_find_value_impl(dict, (void*) (&(key))) +#define shd_dict_find_key(K, dict, key) (K*) shd_dict_find_impl(dict, (void*) (&(key))) +void* shd_dict_find_impl(struct Dict*, void*); +void* shd_dict_find_value_impl(struct Dict*, void*); -#define find_value_dict(K, T, dict, key) (T*) find_value_dict_impl(dict, (void*) (&(key))) -#define find_key_dict(K, dict, key) (K*) find_key_dict_impl(dict, (void*) (&(key))) -void* find_key_dict_impl(struct Dict*, void*); -void* find_value_dict_impl(struct Dict*, void*); +#define shd_dict_remove(K, dict, key) shd_dict_remove_impl(dict, (void*) (&(key))) +bool shd_dict_remove_impl(struct Dict* dict, 
void* key); -#define remove_dict(K, dict, key) remove_dict_impl(dict, (void*) (&(key))) -bool remove_dict_impl(struct Dict* dict, void* key); +#define shd_dict_insert_get_value(K, V, dict, key, value) *(V*) shd_dict_insert_get_value_impl(dict, (void*) (&(key)), (void*) (&(value))) +#define shd_dict_insert(K, V, dict, key, value) shd_dict_insert_get_value_impl(dict, (void*) (&(key)), (void*) (&(value))) +void* shd_dict_insert_get_value_impl(struct Dict*, void* key, void* value); -#define insert_dict_and_get_value(K, V, dict, key, value) *(V*) insert_dict_and_get_value_impl(dict, (void*) (&(key)), (void*) (&(value))) -#define insert_dict(K, V, dict, key, value) insert_dict_and_get_value_impl(dict, (void*) (&(key)), (void*) (&(value))) -void* insert_dict_and_get_value_impl(struct Dict*, void* key, void* value); +#define shd_dict_insert_get_key(K, V, dict, key, value) *(K*) shd_dict_insert_get_key_impl(dict, (void*) (&(key)), (void*) (&(value))) +#define shd_set_insert_get_key(K, dict, key) *(K*) shd_dict_insert_get_key_impl(dict, (void*) (&(key)), NULL) +void* shd_dict_insert_get_key_impl(struct Dict*, void* key, void* value); -#define insert_dict_and_get_key(K, V, dict, key, value) *(K*) insert_dict_and_get_key_impl(dict, (void*) (&(key)), (void*) (&(value))) -#define insert_set_get_key(K, dict, key) *(K*) insert_dict_and_get_key_impl(dict, (void*) (&(key)), NULL) -void* insert_dict_and_get_key_impl(struct Dict*, void* key, void* value); +#define shd_dict_insert_get_result(K, V, dict, key, value) shd_dict_insert_impl(dict, (void*) (&(key)), (void*) (&(value))) +#define shd_set_insert_get_result(K, dict, key) shd_dict_insert_impl(dict, (void*) (&(key)), NULL) +bool shd_dict_insert_impl(struct Dict*, void* key, void* value); -#define insert_dict_and_get_result(K, V, dict, key, value) insert_dict_and_get_result_impl(dict, (void*) (&(key)), (void*) (&(value))) -#define insert_set_get_result(K, dict, key) insert_dict_and_get_result_impl(dict, (void*) (&(key)), NULL) -bool 
insert_dict_and_get_result_impl(struct Dict*, void* key, void* value); +KeyHash shd_hash(const void* data, size_t size); -KeyHash hash_murmur(const void* data, size_t size); +KeyHash shd_hash_ptr(void**); +bool shd_compare_ptrs(void**, void**); #endif diff --git a/src/common/growy.c b/src/common/growy.c index 9cc43b39c..95ae760e2 100644 --- a/src/common/growy.c +++ b/src/common/growy.c @@ -11,7 +11,7 @@ struct Growy_ { size_t used, size; }; -Growy* new_growy() { +Growy* shd_new_growy(void) { Growy* g = calloc(1, sizeof(Growy)); *g = (Growy) { .buffer = calloc(1, init_size), @@ -21,7 +21,7 @@ Growy* new_growy() { return g; } -void growy_append_bytes(Growy* g, size_t s, const char* bytes) { +void shd_growy_append_bytes(Growy* g, size_t s, const char* bytes) { size_t old_used = g->used; g->used += s; while (g->used >= g->size) { @@ -31,30 +31,30 @@ void growy_append_bytes(Growy* g, size_t s, const char* bytes) { memcpy(g->buffer + old_used, bytes, s); } -void growy_append_string(Growy* g, const char* str) { +void shd_growy_append_string(Growy* g, const char* str) { size_t len = strlen(str); - growy_append_bytes(g, len, str); + shd_growy_append_bytes(g, len, str); } -void format_string_internal(const char* str, va_list args, void* uptr, void callback(void*, size_t, char*)); +void shd_format_string_internal(const char* str, va_list args, void* uptr, void callback(void*, size_t, char*)); -void growy_append_formatted(Growy* g, const char* str, ...) { +void shd_growy_append_formatted(Growy* g, const char* str, ...) 
{ va_list args; va_start(args, str); - format_string_internal(str, args, g, (void(*)(void*, size_t, char*)) growy_append_bytes); + shd_format_string_internal(str, args, g, (void (*)(void*, size_t, char*)) shd_growy_append_bytes); va_end(args); } -void destroy_growy(Growy* g) { +void shd_destroy_growy(Growy* g) { free(g->buffer); free(g); } -char* growy_deconstruct(Growy* g) { +char* shd_growy_deconstruct(Growy* g) { char* buf = g->buffer; free(g); return buf; } -size_t growy_size(const Growy* g) { return g->used; } -char* growy_data(const Growy* g) { return g->buffer; } +size_t shd_growy_size(const Growy* g) { return g->used; } +char* shd_growy_data(const Growy* g) { return g->buffer; } diff --git a/src/common/growy.h b/src/common/growy.h index e35b6a199..de0af8f55 100644 --- a/src/common/growy.h +++ b/src/common/growy.h @@ -7,16 +7,16 @@ /// Addresses not guaranteed to be stable. typedef struct Growy_ Growy; -Growy* new_growy(); -void growy_append_bytes(Growy*, size_t, const char*); -void growy_append_string(Growy*, const char*); -void growy_append_formatted(Growy* g, const char* str, ...); -#define growy_append_string_literal(a, v) growy_append_bytes(a, sizeof(v) - 1, (char*) &v) -#define growy_append_object(a, v) growy_append_bytes(a, sizeof(v), (char*) &v) -size_t growy_size(const Growy*); -char* growy_data(const Growy*); -void destroy_growy(Growy*g); +Growy* shd_new_growy(void); +void shd_growy_append_bytes(Growy*, size_t, const char*); +void shd_growy_append_string(Growy* g, const char* str); +void shd_growy_append_formatted(Growy* g, const char* str, ...); +#define shd_growy_append_string_literal(a, v) shd_growy_append_bytes(a, sizeof(v) - 1, (char*) &v) +#define shd_growy_append_object(a, v) shd_growy_append_bytes(a, sizeof(v), (char*) &v) +size_t shd_growy_size(const Growy* g); +char* shd_growy_data(const Growy* g); +void shd_destroy_growy(Growy*g); // Like destroy, but we scavenge the internal allocation for later use. 
-char* growy_deconstruct(Growy*); +char* shd_growy_deconstruct(Growy* g); #endif diff --git a/src/common/list.c b/src/common/list.c index da2f67fc0..5c344073b 100644 --- a/src/common/list.c +++ b/src/common/list.c @@ -6,7 +6,7 @@ #include "portability.h" -struct List* new_list_impl(size_t elem_size) { +struct List* shd_new_list_impl(size_t elem_size) { struct List* list = (struct List*) malloc(sizeof (struct List)); *list = (struct List) { .elements_count = 0, @@ -17,21 +17,21 @@ struct List* new_list_impl(size_t elem_size) { return list; } -void destroy_list(struct List* list) { +void shd_destroy_list(struct List* list) { free(list->alloc); free(list); } -size_t entries_count_list(struct List* list) { +size_t shd_list_count(struct List* list) { return list->elements_count; } -void grow_list(struct List* list) { +static void grow_list(struct List* list) { list->space = list->space * 2; list->alloc = realloc(list->alloc, list->space * list->element_size); } -void append_list_impl(struct List* list, void* element) { +void shd_list_append_impl(struct List* list, void* element) { if (list->elements_count == list->space) grow_list(list); size_t element_size = list->element_size; @@ -39,14 +39,14 @@ void append_list_impl(struct List* list, void* element) { list->elements_count++; } -void* pop_list_impl(struct List* list) { +void* shd_list_pop_impl(struct List* list) { assert(list->elements_count > 0); list->elements_count--; void* last = (void*) ((size_t)(list->alloc) + list->elements_count * list->element_size); return last; } -void add_list_impl(struct List* list, size_t index, void* element) { +void shd_list_insert_impl(struct List* list, size_t index, void* element) { size_t old_elements_count = list->elements_count; if (list->elements_count == list->space) grow_list(list); @@ -61,20 +61,7 @@ void add_list_impl(struct List* list, size_t index, void* element) { list->elements_count++; } -void delete_list_impl(struct List* list, size_t index) { - size_t 
old_elements_count = list->elements_count; - - size_t element_size = list->element_size; - - void* hole_at = (void*) ((size_t) list->alloc + element_size * index); - void* fill_with = (void*) ((size_t) list->alloc + element_size * (index + 1)); - size_t amount = old_elements_count - index; - memmove(hole_at, fill_with, element_size * amount); - - list->elements_count--; -} - -void* remove_list_impl(struct List* list, size_t index) { +void* shd_list_remove_impl(struct List* list, size_t index, bool move_to_end) { size_t old_elements_count = list->elements_count; size_t element_size = list->element_size; @@ -83,16 +70,19 @@ void* remove_list_impl(struct List* list, size_t index) { void* hole_at = (void*) ((size_t) list->alloc + element_size * index); void* fill_with = (void*) ((size_t) list->alloc + element_size * (index + 1)); size_t amount = old_elements_count - index; - memcpy(&temp, hole_at, element_size); + if (move_to_end) + memcpy(&temp, hole_at, element_size); memmove(hole_at, fill_with, element_size * amount); list->elements_count--; + if (!move_to_end) + return NULL; void* end = (void*) ((size_t) list->alloc + element_size * list->elements_count); memcpy(end, &temp, element_size); return end; } -void clear_list(struct List* list) { +void shd_clear_list(struct List* list) { list->elements_count = 0; } diff --git a/src/common/list.h b/src/common/list.h index 3196ae000..bd837f6ab 100644 --- a/src/common/list.h +++ b/src/common/list.h @@ -2,6 +2,7 @@ #define SHADY_LIST_H #include +#include struct List { size_t elements_count; @@ -10,31 +11,28 @@ struct List { void* alloc; }; -#define new_list(T) new_list_impl(sizeof(T)) -struct List* new_list_impl(size_t elem_size); +#define shd_new_list(T) shd_new_list_impl(sizeof(T)) +struct List* shd_new_list_impl(size_t elem_size); -void destroy_list(struct List* list); +void shd_destroy_list(struct List* list); -size_t entries_count_list(struct List* list); +size_t shd_list_count(struct List* list); -#define append_list(T, 
list, element) append_list_impl(list, (void*) &(element)) -void append_list_impl(struct List* list, void* element); +#define shd_list_append(T, list, element) shd_list_append_impl(list, (void*) &(element)) +void shd_list_append_impl(struct List* list, void* element); -#define pop_last_list(T, list) * ((T*) pop_list_impl(list)) -#define remove_last_list(T, list) (pop_list_impl(list)) -void* pop_list_impl(struct List* list); +#define shd_list_pop(T, list) * ((T*) shd_list_pop_impl(list)) +void* shd_list_pop_impl(struct List* list); -void clear_list(struct List* list); +void shd_clear_list(struct List* list); -#define add_list(T, list, i, e) add_list_impl(list, i, (void*) &(e)) -void add_list_impl(struct List* list, size_t index, void* element); +#define shd_list_insert(T, list, i, e) shd_list_insert_impl(list, i, (void*) &(e)) +void shd_list_insert_impl(struct List* list, size_t index, void* element); -#define delete_list(T, list, i) delete_list_impl(list, index) -void delete_list_impl(struct List* list, size_t index); +#define shd_list_delete(T, list, i) shd_list_remove_impl(list, index, false) +#define shd_list_remove(T, list, i) *((T*) shd_list_remove_impl(list, i, true)) +void* shd_list_remove_impl(struct List* list, size_t index, bool); -#define remove_list(T, list, i) *(T*) remove_list_impl(list, index) -void* remove_list_impl(struct List* list, size_t index); - -#define read_list(T, list) ((T*) (list)->alloc) +#define shd_read_list(T, list) ((T*) (list)->alloc) #endif diff --git a/src/common/log.c b/src/common/log.c index 66e27dbf6..58b4a7c7d 100644 --- a/src/common/log.c +++ b/src/common/log.c @@ -1,22 +1,25 @@ #include "log.h" #include -#include LogLevel shady_log_level = INFO; -LogLevel get_log_level() { +LogLevel shd_log_get_level(void) { return shady_log_level; } -void set_log_level(LogLevel l) { +void shd_log_set_level(LogLevel l) { shady_log_level = l; } -void log_string(LogLevel level, const char* format, ...) 
{ +void shd_log_fmt_va_list(LogLevel level, const char* format, va_list args) { + if (level <= shady_log_level) + vfprintf(stderr, format, args); +} + +void shd_log_fmt(LogLevel level, const char* format, ...) { va_list args; va_start(args, format); - if (level >= shady_log_level) - vfprintf(stderr, format, args); + shd_log_fmt_va_list(level, format, args); va_end(args); } diff --git a/src/common/log.h b/src/common/log.h index dc18de3d9..637f24f73 100644 --- a/src/common/log.h +++ b/src/common/log.h @@ -2,32 +2,41 @@ #define SHADY_LOG_H #include +#include typedef struct Node_ Node; typedef struct Module_ Module; typedef enum LogLevel_ { - DEBUGVV, - DEBUGV, - DEBUG, - INFO, + ERROR, WARN, - ERROR + INFO, + DEBUG, + DEBUGV, + DEBUGVV, } LogLevel; -LogLevel get_log_level(); -void set_log_level(LogLevel); -void log_string(LogLevel level, const char* format, ...); -void log_node(LogLevel level, const Node* node); +LogLevel shd_log_get_level(void); +void shd_log_set_level(LogLevel l); +void shd_log_fmt_va_list(LogLevel level, const char* format, va_list args); +void shd_log_fmt(LogLevel level, const char* format, ...); +void shd_log_node(LogLevel level, const Node* node); typedef struct CompilerConfig_ CompilerConfig; -void log_module(LogLevel level, CompilerConfig*, Module*); - -#define debugvv_print(...) log_string(DEBUGVV, __VA_ARGS__) -#define debugv_print(...) log_string(DEBUGV, __VA_ARGS__) -#define debug_print(...) log_string(DEBUG, __VA_ARGS__) -#define info_print(...) log_string(INFO, __VA_ARGS__) -#define warn_print(...) log_string(WARN, __VA_ARGS__) -#define error_print(...) log_string(ERROR, __VA_ARGS__) +void shd_log_module(LogLevel level, const CompilerConfig* compiler_cfg, Module* mod); + +#define shd_debugvv_print(...) shd_log_fmt(DEBUGVV, __VA_ARGS__) +#define shd_debugv_print(...) shd_log_fmt(DEBUGV, __VA_ARGS__) +#define shd_debug_print(...) shd_log_fmt(DEBUG, __VA_ARGS__) +#define shd_info_print(...) 
shd_log_fmt(INFO, __VA_ARGS__) +#define shd_warn_print(...) shd_log_fmt(WARN, __VA_ARGS__) +#define shd_error_print(...) shd_log_fmt(ERROR, __VA_ARGS__) + +#define shd_debugvv_print_once(flag, ...) { static bool flag = false; if (!flag) { flag = true; shd_debugvv_print(__VA_ARGS__ ); } } +#define shd_debugv_print_once(flag, ...) { static bool flag = false; if (!flag) { flag = true; shd_debugv_print(__VA_ARGS__ ); } } +#define shd_debug_print_once(flag, ...) { static bool flag = false; if (!flag) { flag = true; shd_debug_print(__VA_ARGS__ ); } } +#define shd_info_print_once(flag, ...) { static bool flag = false; if (!flag) { flag = true; shd_info_print(__VA_ARGS__ ); } } +#define shd_warn_print_once(flag, ...) { static bool flag = false; if (!flag) { flag = true; shd_warn_print(__VA_ARGS__ ); } } +#define shd_error_print_once(flag, ...) { static bool flag = false; if (!flag) { flag = true; shd_error_print(__VA_ARGS__ ); } } #ifdef _MSC_VER #define SHADY_UNREACHABLE __assume(0) @@ -40,14 +49,14 @@ void log_module(LogLevel level, CompilerConfig*, Module*); SHADY_UNREACHABLE; \ } -#define error(...) { \ +#define shd_error(...) 
{ \ fprintf (stderr, "Error at %s:%d: ", __FILE__, __LINE__); \ fprintf (stderr, __VA_ARGS__); \ fprintf (stderr, "\n"); \ - error_die(); \ - SHADY_UNREACHABLE; \ + shd_error_die(); \ } -void error_die(); +#include +noreturn void shd_error_die(void); #endif diff --git a/src/common/portability.c b/src/common/portability.c index 057a98a4c..7b392be0c 100644 --- a/src/common/portability.c +++ b/src/common/portability.c @@ -15,7 +15,7 @@ #endif -void platform_specific_terminal_init_extras() { +void shd_platform_specific_terminal_init_extras(void) { #ifdef NEED_COLOR_FIX HANDLE handle = GetStdHandle(STD_OUTPUT_HANDLE); if (handle != INVALID_HANDLE_VALUE) { @@ -28,6 +28,23 @@ void platform_specific_terminal_init_extras() { #endif } +#include +#if defined(__MINGW64__) | defined(__MINGW32__) +#include +uint64_t get_time_nano() { + struct timespec t; + clock_gettime(CLOCK_REALTIME, &t); + return t.tv_sec * 1000000000 + t.tv_nsec; +} +#else +#include +uint64_t shd_get_time_nano(void) { + struct timespec t; + timespec_get(&t, TIME_UTC); + return t.tv_sec * 1000000000 + t.tv_nsec; +} +#endif + #ifdef WIN32 #include #elif __APPLE__ @@ -37,7 +54,7 @@ void platform_specific_terminal_init_extras() { #include #include #endif -const char* get_executable_location(void) { +const char* shd_get_executable_location(void) { size_t len = 256; char* buf = calloc(len + 1, 1); #ifdef WIN32 diff --git a/src/common/portability.h b/src/common/portability.h index 722413a4c..c1e797fb4 100644 --- a/src/common/portability.h +++ b/src/common/portability.h @@ -15,10 +15,12 @@ static_assert(__STDC_VERSION__ >= 201112L, "C11 support is required to build sha #define SHADY_UNUSED #define LARRAY(T, name, size) T* name = alloca(sizeof(T) * (size)) #define alloca _alloca + #define popen _popen + #define pclose _pclose #define SHADY_FALLTHROUGH // It's mid 2022, and this typedef is missing from // MSVC is not a real C11 compiler. 
- typedef long long max_align_t; + typedef double max_align_t; #else #ifdef USE_VLAS #define LARRAY(T, name, size) T name[size] @@ -29,7 +31,7 @@ static_assert(__STDC_VERSION__ >= 201112L, "C11 support is required to build sha #define SHADY_FALLTHROUGH __attribute__((fallthrough)); #endif -static inline void* alloc_aligned(size_t size, size_t alignment) { +static inline void* shd_alloc_aligned(size_t size, size_t alignment) { #ifdef _WIN32 return _aligned_malloc(size, alignment); #else @@ -37,7 +39,7 @@ static inline void* alloc_aligned(size_t size, size_t alignment) { #endif } -static inline void free_aligned(void* ptr) { +static inline void shd_free_aligned(void* ptr) { #ifdef _WIN32 _aligned_free(ptr); #else @@ -45,8 +47,10 @@ static inline void free_aligned(void* ptr) { #endif } -const char* get_executable_location(void); +#include +uint64_t shd_get_time_nano(void); +const char* shd_get_executable_location(void); -void platform_specific_terminal_init_extras(); +void shd_platform_specific_terminal_init_extras(void); #endif diff --git a/src/common/printer.c b/src/common/printer.c index a73bb57ff..2b9988191 100644 --- a/src/common/printer.c +++ b/src/common/printer.c @@ -23,56 +23,56 @@ struct Printer_ { int indent; }; -Printer* open_file_as_printer(void* f) { +Printer* shd_new_printer_from_file(void* f) { Printer* p = calloc(1, sizeof(Printer)); p->output = PoFile; p->file = (FILE*) f; return p; } -Printer* open_growy_as_printer(Growy* g) { +Printer* shd_new_printer_from_growy(Growy* g) { Printer* p = calloc(1, sizeof(Printer)); p->output = PoGrowy; p->growy = g; return p; } -void destroy_printer(Printer* p) { +void shd_destroy_printer(Printer* p) { free(p); } -static void print_bare(Printer* p, size_t len, const char* str) { +static void shd_printer_print_raw(Printer* p, size_t len, const char* str) { assert(strlen(str) >= len); switch(p->output) { case PoFile: fwrite(str, sizeof(char), len, p->file); break; - case PoGrowy: growy_append_bytes(p->growy, len, 
str); + case PoGrowy: shd_growy_append_bytes(p->growy, len, str); } } -void flush(Printer* p) { +void shd_printer_flush(Printer* p) { switch(p->output) { case PoFile: fflush(p->file); break; case PoGrowy: break; } } -void indent(Printer* p) { +void shd_printer_indent(Printer* p) { p->indent++; } -void deindent(Printer* p) { +void shd_printer_deindent(Printer* p) { p->indent--; } -void newline(Printer* p) { - print_bare(p, 1, "\n"); +void shd_newline(Printer* p) { + shd_printer_print_raw(p, 1, "\n"); for (int i = 0; i < p->indent; i++) - print_bare(p, 4, " "); + shd_printer_print_raw(p, 4, " "); } #define LOCAL_BUFFER_SIZE 32 -Printer* print(Printer* p, const char* f, ...) { +Printer* shd_print(Printer* p, const char* f, ...) { size_t len = strlen(f) + 1; if (len == 1) return p; @@ -116,41 +116,24 @@ Printer* print(Printer* p, const char* f, ...) { size_t i = 0; while(i < written) { if (tmp[i] == '\n') { - print_bare(p, i - start, &tmp[start]); - newline(p); + shd_printer_print_raw(p, i - start, &tmp[start]); + shd_newline(p); start = i + 1; } i++; } if (start < i) - print_bare(p, i - start, &tmp[start]); + shd_printer_print_raw(p, i - start, &tmp[start]); free(alloc); return p; } -const char* printer_growy_unwrap(Printer* p) { +const char* shd_printer_growy_unwrap(Printer* p) { assert(p->output == PoGrowy); - const char* insides = growy_deconstruct(p->growy); + shd_growy_append_bytes(p->growy, 1, "\0"); + const char* insides = shd_growy_deconstruct(p->growy); free(p); return insides; } - -const char* replace_string(const char* source, const char* match, const char* replace_with) { - Growy* g = new_growy(); - size_t match_len = strlen(match); - size_t replace_len = strlen(replace_with); - const char* next_match = strstr(source, match); - while (next_match != NULL) { - size_t diff = next_match - source; - growy_append_bytes(g, diff, (char*) source); - growy_append_bytes(g, replace_len, (char*) replace_with); - source = next_match + match_len; - next_match = 
strstr(source, match); - } - growy_append_bytes(g, strlen(source), (char*) source); - char zero = '\0'; - growy_append_bytes(g, 1, &zero); - return growy_deconstruct(g); -} diff --git a/src/common/printer.h b/src/common/printer.h index 4a6a809d8..aff7b9d54 100644 --- a/src/common/printer.h +++ b/src/common/printer.h @@ -6,20 +6,20 @@ typedef struct Printer_ Printer; typedef struct Growy_ Growy; -Printer* open_file_as_printer(void* FILE); -Printer* open_growy_as_printer(Growy*); -void destroy_printer(Printer*); +Printer* shd_new_printer_from_file(void* FILE); +Printer* shd_new_printer_from_growy(Growy* g); +void shd_destroy_printer(Printer* p); -Printer* print(Printer*, const char*, ...); -void newline(Printer* p); -void indent(Printer* p); -void deindent(Printer* p); -void flush(Printer*); +Printer* shd_print(Printer*, const char*, ...); +void shd_newline(Printer* p); +void shd_printer_indent(Printer* p); +void shd_printer_deindent(Printer* p); +void shd_printer_flush(Printer* p); +void shd_printer_escape(Printer* p, const char*); +void shd_printer_unescape(Printer* p, const char*); -const char* printer_growy_unwrap(Printer* p); -Growy* new_growy(); -#define helper_format_string(f, ...) printer_growy_unwrap(cunk_print(cunk_open_growy_as_printer(cunk_new_growy()), (f), __VA_ARGS__)) - -const char* replace_string(const char* source, const char* match, const char* replace_with); +const char* shd_printer_growy_unwrap(Printer* p); +Growy* shd_new_growy(void); +#define shd_helper_format_string(f, ...) 
printer_growy_unwrap(cunk_print(cunk_open_growy_as_printer(cunk_new_growy()), (f), __VA_ARGS__)) #endif diff --git a/src/common/test_dict.c b/src/common/test_dict.c new file mode 100644 index 000000000..35f287b1c --- /dev/null +++ b/src/common/test_dict.c @@ -0,0 +1,91 @@ +#include "dict.h" +#include "log.h" + +#include +#include +#include + +// purposefully bad hash to make sure the collision handling is solid +KeyHash bad_hash_i32(int* i) { + return *i; +} + +bool compare_i32(int* pa, int* pb) { + return *pa == *pb; +} + +#define TEST_ENTRIES 10000 + +void shuffle(int arr[]) { + for (int i = 0; i < TEST_ENTRIES; i++) { + int a = rand() % TEST_ENTRIES; + int b = rand() % TEST_ENTRIES; + int tmp = arr[a]; + arr[a] = arr[b]; + arr[b] = tmp; + } +} + +int main(int argc, char** argv) { + srand((int) shd_get_time_nano()); + struct Dict* d = shd_new_set(int, (HashFn) bad_hash_i32, (CmpFn) compare_i32); + + int arr[TEST_ENTRIES]; + for (int i = 0; i < TEST_ENTRIES; i++) { + arr[i] = i; + } + + shuffle(arr); + + bool contained[TEST_ENTRIES]; + memset(contained, 0, sizeof(contained)); + + for (int i = 0; i < TEST_ENTRIES; i++) { + bool unique = shd_set_insert_get_result(int, d, arr[i]); + if (!unique) { + shd_error("Entry %d was thought to be already in the dict", arr[i]); + } + contained[arr[i]] = true; + } + + shuffle(arr); + for (int i = 0; i < TEST_ENTRIES; i++) { + assert(contained[arr[i]]); + assert(shd_dict_find_key(int, d, arr[i])); + } + + shuffle(arr); + for (int i = 0; i < rand() % TEST_ENTRIES; i++) { + assert(contained[arr[i]]); + bool removed = shd_dict_remove(int, d, arr[i]); + assert(removed); + contained[arr[i]] = false; + } + + shuffle(arr); + for (int i = 0; i < TEST_ENTRIES; i++) { + assert(!!shd_dict_find_key(int, d, arr[i]) == contained[arr[i]]); + } + + shuffle(arr); + for (int i = 0; i < TEST_ENTRIES; i++) { + assert(!!shd_dict_find_key(int, d, arr[i]) == contained[arr[i]]); + if (!contained[arr[i]]) { + bool unique = shd_set_insert_get_result(int, 
d, arr[i]); + if (!unique) { + shd_error("Entry %d was thought to be already in the dict", arr[i]); + } + contained[arr[i]] = true; + } + assert(contained[arr[i]]); + } + + shuffle(arr); + for (int i = 0; i < TEST_ENTRIES; i++) { + assert(contained[arr[i]]); + assert(shd_dict_find_key(int, d, arr[i])); + } + + shd_destroy_dict(d); + return 0; +} diff --git a/src/common/test_util.c b/src/common/test_util.c new file mode 100644 index 000000000..49373d880 --- /dev/null +++ b/src/common/test_util.c @@ -0,0 +1,62 @@ +#include "util.h" +#include "printer.h" + +#undef NDEBUG +#include +#include +#include +#include + +const char escaped[] = "hi\nthis is a backslash\\, \tthis is a tab and this backspace character ends it all\b"; +const char double_escaped[] = "hi\\nthis is a backslash\\\\, \\tthis is a tab and this backspace character ends it all\\b"; + +enum { + Len = sizeof(escaped), + DoubleLen = sizeof(double_escaped), + MaxLen = DoubleLen +}; + +void test_escape_unescape_basic(void) { + char output[MaxLen] = { 0 }; + + printf("escaped: %s\n---------------------\n", escaped); + printf("double_escaped: %s\n---------------------\n", double_escaped); + shd_apply_escape_codes(double_escaped, DoubleLen, output); + printf("shd_apply_escape_codes(double_escaped): %s\n---------------------\n", output); + assert(strcmp(output, escaped) == 0); + memset(output, 0, MaxLen); + shd_unapply_escape_codes(escaped, Len, output); + printf("shd_apply_escape_codes(escaped): %s\n---------------------\n", output); + assert(strcmp(output, double_escaped) == 0); +} + +void test_escape_printer(void) { + Printer* p = shd_new_printer_from_growy(shd_new_growy()); + shd_printer_escape(p, double_escaped); + const char* output = shd_printer_growy_unwrap(p); + printf("shd_printer_escape(escaped): %s\n---------------------\n", output); + assert(strlen(output) == Len - 1); + assert(strcmp(output, escaped) == 0); + free((char*) output); +} + +void test_unescape_printer(void) { + Printer* p = 
shd_new_printer_from_growy(shd_new_growy()); + shd_printer_unescape(p, escaped); + const char* output = shd_printer_growy_unwrap(p); + printf("shd_printer_unescape(escaped): %s\n---------------------\n", output); + assert(strlen(output) == DoubleLen - 1); + assert(strcmp(output, double_escaped) == 0); + free((char*) output); +} + +int main(int argc, char** argv) { + assert(strlen(double_escaped) == DoubleLen - 1); + assert(strlen(escaped) == Len - 1); + + test_escape_unescape_basic(); + test_escape_printer(); + test_unescape_printer(); + + return 0; +} \ No newline at end of file diff --git a/src/common/util.c b/src/common/util.c index 637834f8f..8ed07ed7b 100644 --- a/src/common/util.c +++ b/src/common/util.c @@ -1,4 +1,5 @@ #include "util.h" +#include "printer.h" #include "arena.h" #include @@ -19,25 +20,81 @@ X( 'f', '\f') \ X( 'a', '\a') \ X( 'v', '\v') \ -size_t apply_escape_codes(const char* src, size_t size, char* dst) { - char p, c = '\0'; +size_t shd_apply_escape_codes(const char* src, size_t size, char* dst) { + char prev, c = '\0'; size_t j = 0; for (size_t i = 0; i < size; i++) { - p = c; + prev = c; c = src[i]; -#define ESCAPE_CASE(m, s) if (p == '\\' && c == m) { \ +#define ESCAPE_CASE(m, s) if (prev == '\\' && c == m) { \ dst[j - 1] = s; \ continue; \ } \ ESCAPE_SEQS(ESCAPE_CASE) +#undef ESCAPE_CASE dst[j++] = c; } return j; } +size_t shd_unapply_escape_codes(const char* src, size_t size, char* dst) { + char c = '\0'; + size_t j = 0; + for (size_t i = 0; i < size; i++) { + c = src[i]; + +#define ESCAPE_CASE(m, s) if (c == s) { \ + dst[j++] = '\\'; \ + dst[j++] = m; \ + continue; \ + } \ + + ESCAPE_SEQS(ESCAPE_CASE) +#undef ESCAPE_CASE + + dst[j++] = c; + } + return j; +} + +void shd_printer_escape(Printer* p, const char* src) { + size_t size = strlen(src); + for (size_t i = 0; i < size; i++) { + char c = src[i]; + char next = i + 1 < size ? 
src[i + 1] : '\0'; + +#define ESCAPE_CASE(m, s) if (c == '\\' && next == m) { \ + char code = s; \ + shd_print(p, "%c", code); \ + i++; \ + continue; \ + } \ + + ESCAPE_SEQS(ESCAPE_CASE) +#undef ESCAPE_CASE + shd_print(p, "%c", c); + } +} + +void shd_printer_unescape(Printer* p, const char* src) { + size_t size = strlen(src); + for (size_t i = 0; i < size; i++) { + char c = src[i]; + +#define ESCAPE_CASE(m, s) if (c == s) { \ + shd_print(p, "\\%c", m); \ + continue; \ + } \ + + ESCAPE_SEQS(ESCAPE_CASE) +#undef ESCAPE_CASE + shd_print(p, "%c", c); + } +} + static long get_file_size(FILE* f) { if (fseek(f, 0, SEEK_END) != 0) return -1; @@ -53,7 +110,7 @@ static long get_file_size(FILE* f) { return fsize; } -bool read_file(const char* filename, size_t* size, char** output) { +bool shd_read_file(const char* filename, size_t* size, char** output) { FILE *f = fopen(filename, "rb"); if (f == NULL) return false; @@ -86,7 +143,7 @@ bool read_file(const char* filename, size_t* size, char** output) { return false; } -bool write_file(const char* filename, size_t size, const char* data) { +bool shd_write_file(const char* filename, size_t size, const char* data) { FILE* f = fopen(filename, "wb"); if (f == NULL) return false; @@ -109,7 +166,7 @@ enum { static char static_buffer[ThreadLocalStaticBufferSize]; -void format_string_internal(const char* str, va_list args, void* uptr, void callback(void*, size_t, char*)) { +void shd_format_string_internal(const char* str, va_list args, void* uptr, void callback(void*, size_t, char*)) { size_t buffer_size = ThreadLocalStaticBufferSize; int len; char* tmp; @@ -139,18 +196,18 @@ void format_string_internal(const char* str, va_list args, void* uptr, void call typedef struct { Arena* a; char** result; } InternInArenaPayload; static void intern_in_arena(InternInArenaPayload* uptr, size_t len, char* tmp) { - char* interned = (char*) arena_alloc(uptr->a, len + 1); + char* interned = (char*) shd_arena_alloc(uptr->a, len + 1); strncpy(interned, 
tmp, len); interned[len] = '\0'; *uptr->result = interned; } -char* format_string_arena(Arena* arena, const char* str, ...) { +char* shd_format_string_arena(Arena* arena, const char* str, ...) { char* result = NULL; InternInArenaPayload p = { .a = arena, .result = &result }; va_list args; va_start(args, str); - format_string_internal(str, args, &p, (void(*)(void*, size_t, char*)) intern_in_arena); + shd_format_string_internal(str, args, &p, (void (*)(void*, size_t, char*)) intern_in_arena); va_end(args); return result; } @@ -164,17 +221,17 @@ static void put_in_new(PutNewPayload* uptr, size_t len, char* tmp) { *uptr->result = allocated; } -char* format_string_new(const char* str, ...) { +char* shd_format_string_new(const char* str, ...) { char* result = NULL; PutNewPayload p = { .result = &result }; va_list args; va_start(args, str); - format_string_internal(str, args, &p, (void(*)(void*, size_t, char*)) put_in_new); + shd_format_string_internal(str, args, &p, (void (*)(void*, size_t, char*)) put_in_new); va_end(args); return result; } -bool string_starts_with(const char* string, const char* prefix) { +bool shd_string_starts_with(const char* string, const char* prefix) { size_t len = strlen(string); size_t slen = strlen(prefix); if (len < slen) @@ -182,7 +239,7 @@ bool string_starts_with(const char* string, const char* prefix) { return memcmp(string, prefix, slen) == 0; } -bool string_ends_with(const char* string, const char* suffix) { +bool shd_string_ends_with(const char* string, const char* suffix) { size_t len = strlen(string); size_t slen = strlen(suffix); if (len < slen) @@ -194,7 +251,7 @@ bool string_ends_with(const char* string, const char* suffix) { return true; } -char* strip_path(const char* path) { +char* shd_strip_path(const char* path) { char separator = strchr(path, '\\') == NULL ? 
'/' : '\\'; char* end = strrchr(path, separator); if (!end) { @@ -212,6 +269,6 @@ char* strip_path(const char* path) { return new; } -void error_die() { +void shd_error_die(void) { abort(); } diff --git a/src/common/util.h b/src/common/util.h index d4d93accb..f00e0634f 100644 --- a/src/common/util.h +++ b/src/common/util.h @@ -4,17 +4,18 @@ #include #include -size_t apply_escape_codes(const char* src, size_t og_len, char* dst); +size_t shd_apply_escape_codes(const char* src, size_t size, char* dst); +size_t shd_unapply_escape_codes(const char* src, size_t size, char* dst); -bool read_file(const char* filename, size_t* size, char** output); -bool write_file(const char* filename, size_t size, const char* data); +bool shd_read_file(const char* filename, size_t* size, char** output); +bool shd_write_file(const char* filename, size_t size, const char* data); typedef struct Arena_ Arena; -char* format_string_arena(Arena*, const char* str, ...); -char* format_string_new(const char* str, ...); -bool string_starts_with(const char* string, const char* prefix); -bool string_ends_with(const char* string, const char* suffix); +char* shd_format_string_arena(Arena* arena, const char* str, ...); +char* shd_format_string_new(const char* str, ...); +bool shd_string_starts_with(const char* string, const char* prefix); +bool shd_string_ends_with(const char* string, const char* suffix); -char* strip_path(const char*); +char* shd_strip_path(const char*); #endif diff --git a/src/driver/CMakeLists.txt b/src/driver/CMakeLists.txt index 8539478f8..16a872764 100644 --- a/src/driver/CMakeLists.txt +++ b/src/driver/CMakeLists.txt @@ -1,21 +1,10 @@ -add_library(driver STATIC driver.c cli.c) -target_link_libraries(driver PUBLIC "$") -set_property(TARGET driver PROPERTY POSITION_INDEPENDENT_CODE ON) +add_library(driver driver.c cli.c) +target_link_libraries(driver PUBLIC "api" common) +target_link_libraries(driver PRIVATE "$") +set_target_properties(driver PROPERTIES OUTPUT_NAME "shady_driver") 
+install(TARGETS driver EXPORT shady_export_set) add_executable(slim slim.c) target_link_libraries(slim PRIVATE driver) install(TARGETS slim EXPORT shady_export_set) -if (TARGET shady_s2s) - target_compile_definitions(driver PUBLIC SPV_PARSER_PRESENT) - target_link_libraries(driver PRIVATE shady_s2s) -endif() - -if (TARGET shady_fe_llvm) - target_link_libraries(driver PRIVATE shady_fe_llvm) - target_compile_definitions(driver PUBLIC LLVM_PARSER_PRESENT) - - add_executable(vcc vcc.c) - target_link_libraries(vcc PRIVATE driver api common) - install(TARGETS vcc EXPORT shady_export_set) -endif () diff --git a/src/driver/cli.c b/src/driver/cli.c index 9ad332cd7..c136a5b56 100644 --- a/src/driver/cli.c +++ b/src/driver/cli.c @@ -1,29 +1,30 @@ +#include "cli.h" + #include "shady/driver.h" #include "shady/ir.h" #include #include -#include #include "log.h" #include "portability.h" #include "list.h" #include "util.h" -CodegenTarget guess_target(const char* filename) { - if (string_ends_with(filename, ".c")) +CodegenTarget shd_guess_target(const char* filename) { + if (shd_string_ends_with(filename, ".c")) return TgtC; - else if (string_ends_with(filename, "glsl")) + else if (shd_string_ends_with(filename, "glsl")) return TgtGLSL; - else if (string_ends_with(filename, "spirv") || string_ends_with(filename, "spv")) + else if (shd_string_ends_with(filename, "spirv") || shd_string_ends_with(filename, "spv")) return TgtSPV; - else if (string_ends_with(filename, "ispc")) + else if (shd_string_ends_with(filename, "ispc")) return TgtISPC; - error_print("No target has been specified, and output filename '%s' did not allow guessing the right one\n"); + shd_error_print("No target has been specified, and output filename '%s' did not allow guessing the right one\n"); exit(InvalidTarget); } -void cli_pack_remaining_args(int* pargc, char** argv) { +void shd_pack_remaining_args(int* pargc, char** argv) { LARRAY(char*, nargv, *pargc); int nargc = 0; for (size_t i = 0; i < *pargc; i++) { @@ 
-34,7 +35,7 @@ void cli_pack_remaining_args(int* pargc, char** argv) { *pargc = nargc; } -void cli_parse_common_args(int* pargc, char** argv) { +void shd_parse_common_args(int* pargc, char** argv) { int argc = *pargc; bool help = false; @@ -47,22 +48,22 @@ void cli_parse_common_args(int* pargc, char** argv) { if (i == argc) goto incorrect_log_level; if (strcmp(argv[i], "debugvv") == 0) - set_log_level(DEBUGVV); + shd_log_set_level(DEBUGVV); else if (strcmp(argv[i], "debugv") == 0) - set_log_level(DEBUGV); + shd_log_set_level(DEBUGV); else if (strcmp(argv[i], "debug") == 0) - set_log_level(DEBUG); + shd_log_set_level(DEBUG); else if (strcmp(argv[i], "info") == 0) - set_log_level(INFO); + shd_log_set_level(INFO); else if (strcmp(argv[i], "warn") == 0) - set_log_level(WARN); + shd_log_set_level(WARN); else if (strcmp(argv[i], "error") == 0) - set_log_level(ERROR); + shd_log_set_level(ERROR); else { incorrect_log_level: - error_print("--log-level argument takes one of: "); - error_print("debug, info, warn, error"); - error_print("\n"); + shd_error_print("--log-level argument takes one of: "); + shd_error_print("debug, info, warn, error"); + shd_error_print("\n"); exit(IncorrectLogLevel); } } else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) { @@ -75,55 +76,92 @@ void cli_parse_common_args(int* pargc, char** argv) { } if (help) { - error_print(" --log-level debug[v[v]], info, warn, error]\n"); + shd_error_print(" --log-level debug[v[v]], info, warn, error]\n"); } - cli_pack_remaining_args(pargc, argv); + shd_pack_remaining_args(pargc, argv); } -void cli_parse_compiler_config_args(CompilerConfig* config, int* pargc, char** argv) { +#define COMPILER_CONFIG_TOGGLE_OPTIONS(F) \ +F(config->lower.emulate_physical_memory, emulate-physical-memory) \ +F(config->lower.emulate_generic_ptrs, emulate-generic-pointers) \ +F(config->dynamic_scheduling, dynamic-scheduling) \ +F(config->hacks.force_join_point_lifting, lift-join-points) \ 
+F(config->logging.print_internal, print-internal) \ +F(config->logging.print_generated, print-builtin) \ +F(config->logging.print_generated, print-generated) \ +F(config->optimisations.inline_everything, inline-everything) \ +F(config->input_cf.restructure_with_heuristics, restructure-everything) \ +F(config->input_cf.add_scope_annotations, add-scope-annotations) \ +F(config->input_cf.has_scope_annotations, has-scope-annotations) \ + +static IntSizes parse_int_size(String argv) { + if (strcmp(argv, "8") == 0) + return IntTy8; + if (strcmp(argv, "16") == 0) + return IntTy16; + if (strcmp(argv, "32") == 0) + return IntTy32; + if (strcmp(argv, "64") == 0) + return IntTy64; + shd_error("Valid pointer sizes are 8, 16, 32 or 64."); +} + +void shd_parse_compiler_config_args(CompilerConfig* config, int* pargc, char** argv) { int argc = *pargc; bool help = false; for (int i = 1; i < argc; i++) { if (argv[i] == NULL) continue; - if (strcmp(argv[i], "--no-dynamic-scheduling") == 0) { - config->dynamic_scheduling = false; - } else if (strcmp(argv[i], "--lift-join-points") == 0) { - config->hacks.force_join_point_lifting = true; - } else if (strcmp(argv[i], "--entry-point") == 0) { + + COMPILER_CONFIG_TOGGLE_OPTIONS(PARSE_TOGGLE_OPTION) + + if (strcmp(argv[i], "--entry-point") == 0) { argv[i] = NULL; i++; if (i == argc) - error("Missing entry point name"); + shd_error("Missing entry point name"); config->specialization.entry_point = argv[i]; } else if (strcmp(argv[i], "--subgroup-size") == 0) { argv[i] = NULL; i++; if (i == argc) - error("Missing subgroup size name"); + shd_error("Missing subgroup size"); config->specialization.subgroup_size = atoi(argv[i]); + } else if (strcmp(argv[i], "--stack-size") == 0) { + argv[i] = NULL; + i++; + if (i == argc) + shd_error("Missing stack size"); + config->per_thread_stack_size = atoi(argv[i]); } else if (strcmp(argv[i], "--execution-model") == 0) { argv[i] = NULL; i++; if (i == argc) - error("Missing execution model name"); + 
shd_error("Missing execution model name"); ExecutionModel em = EmNone; #define EM(n, _) if (strcmp(argv[i], #n) == 0) em = Em##n; EXECUTION_MODELS(EM) #undef EM if (em == EmNone) - error("Unknown execution model: %s", argv[i]); + shd_error("Unknown execution model: %s", argv[i]); + switch (em) { + case EmFragment: + case EmVertex: + config->dynamic_scheduling = false; + break; + default: break; + } config->specialization.execution_model = em; - } else if (strcmp(argv[i], "--simt2d") == 0) { - config->lower.simt_to_explicit_simd = true; - } else if (strcmp(argv[i], "--print-internal") == 0) { - config->logging.skip_internal = false; - } else if (strcmp(argv[i], "--print-generated") == 0) { - config->logging.skip_generated = false; - } else if (strcmp(argv[i], "--no-physical-global-ptrs") == 0) { - config->hacks.no_physical_global_ptrs = true; + } else if (strcmp(argv[i], "--word-size") == 0) { + argv[i] = NULL; + i++; + config->target.memory.word_size = parse_int_size(argv[i]); + } else if (strcmp(argv[i], "--pointer-size") == 0) { + argv[i] = NULL; + i++; + config->target.memory.ptr_size = parse_int_size(argv[i]); } else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) { help = true; continue; @@ -134,51 +172,54 @@ void cli_parse_compiler_config_args(CompilerConfig* config, int* pargc, char** a } if (help) { - error_print(" --print-internal Includes internal functions in the debug output\n"); - error_print(" --print-generated Includes generated functions in the debug output\n"); - error_print(" --no-dynamic-scheduling Disable the built-in dynamic scheduler, restricts code to only leaf functions\n"); - error_print(" --simt2d Emits SIMD code instead of SIMT, only effective with the C backend.\n"); - error_print(" --entry-point Selects an entry point for the program to be specialized on.\n"); + shd_error_print(" --shd_print-internal Includes internal functions in the debug output\n"); + shd_error_print(" --shd_print-generated Includes generated 
functions in the debug output\n"); + shd_error_print(" --no-dynamic-scheduling Disable the built-in dynamic scheduler, restricts code to only leaf functions\n"); + shd_error_print(" --simt2d Emits SIMD code instead of SIMT, only effective with the C backend.\n"); + shd_error_print(" --entry-point Selects an entry point for the program to be specialized on.\n"); + shd_error_print(" --word-size <8|16|32|64> Sets the word size for physical memory emulation (default=32)\n"); + shd_error_print(" --pointer-size <8|16|32|64> Sets the pointer size for physical pointers (default=64)\n"); #define EM(name, _) #name", " - error_print(" --execution-model Selects an entry point for the program to be specialized on.\nPossible values: " EXECUTION_MODELS(EM)); + shd_error_print(" --execution-model Selects an entry point for the program to be specialized on.\nPossible values: " EXECUTION_MODELS(EM)); #undef EM - error_print(" --subgroup-size N Sets the subgroup size the program will be specialized for.\n"); - error_print(" --lift-join-points Forcefully lambda-lifts all join points. Can help with reconvergence issues.\n"); + shd_error_print(" --subgroup-size N Sets the subgroup size the program will be specialized for.\n"); + shd_error_print(" --lift-join-points Forcefully lambda-lifts all join points. 
Can help with reconvergence issues.\n"); } - cli_pack_remaining_args(pargc, argv); + shd_pack_remaining_args(pargc, argv); } -void cli_parse_input_files(struct List* list, int* pargc, char** argv) { +void shd_driver_parse_input_files(struct List* list, int* pargc, char** argv) { int argc = *pargc; for (int i = 1; i < argc; i++) { if (argv[i] == NULL) continue; - append_list(const char*, list, argv[i]); + shd_list_append(const char*, list, argv[i]); argv[i] = NULL; } - cli_pack_remaining_args(pargc, argv); + shd_pack_remaining_args(pargc, argv); assert(*pargc == 1); } -DriverConfig default_driver_config() { +DriverConfig shd_default_driver_config(void) { return (DriverConfig) { - .config = default_compiler_config(), + .config = shd_default_compiler_config(), .target = TgtAuto, - .input_filenames = new_list(const char*), + .input_filenames = shd_new_list(const char*), .output_filename = NULL, .cfg_output_filename = NULL, .shd_output_filename = NULL, + .c_emitter_config = shd_default_c_emitter_config(), }; } -void destroy_driver_config(DriverConfig* config) { - destroy_list(config->input_filenames); +void shd_destroy_driver_config(DriverConfig* config) { + shd_destroy_list(config->input_filenames); } -void cli_parse_driver_arguments(DriverConfig* args, int* pargc, char** argv) { +void shd_parse_driver_args(DriverConfig* args, int* pargc, char** argv) { int argc = *pargc; bool help = false; @@ -187,7 +228,7 @@ void cli_parse_driver_arguments(DriverConfig* args, int* pargc, char** argv) { argv[i] = NULL; i++; if (i == argc) { - error_print("--output must be followed with a filename"); + shd_error_print("--output must be followed with a filename"); exit(MissingOutputArg); } args->output_filename = argv[i]; @@ -195,7 +236,7 @@ void cli_parse_driver_arguments(DriverConfig* args, int* pargc, char** argv) { argv[i] = NULL; i++; if (i == argc) { - error_print("--dump-cfg must be followed with a filename"); + shd_error_print("--dump-cfg must be followed with a filename"); 
exit(MissingDumpCfgArg); } args->cfg_output_filename = argv[i]; @@ -203,7 +244,7 @@ void cli_parse_driver_arguments(DriverConfig* args, int* pargc, char** argv) { argv[i] = NULL; i++; if (i == argc) { - error_print("--dump-loop-tree must be followed with a filename"); + shd_error_print("--dump-loop-tree must be followed with a filename"); exit(MissingDumpCfgArg); } args->loop_tree_output_filename = argv[i]; @@ -211,10 +252,14 @@ void cli_parse_driver_arguments(DriverConfig* args, int* pargc, char** argv) { argv[i] = NULL; i++; if (i == argc) { - error_print("--dump-ir must be followed with a filename"); + shd_error_print("--dump-ir must be followed with a filename"); exit(MissingDumpIrArg); } args->shd_output_filename = argv[i]; + } else if (strcmp(argv[i], "--glsl-version") == 0) { + argv[i] = NULL; + i++; + args->c_emitter_config.glsl_version = strtol(argv[i], NULL, 10); } else if (strcmp(argv[i], "--target") == 0) { argv[i] = NULL; i++; @@ -233,7 +278,7 @@ void cli_parse_driver_arguments(DriverConfig* args, int* pargc, char** argv) { argv[i] = NULL; continue; invalid_target: - error_print("--target must be followed with a valid target (see help for list of targets)"); + shd_error_print("--target must be followed with a valid target (see help for list of targets)"); exit(InvalidTarget); } else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) { help = true; @@ -245,14 +290,14 @@ void cli_parse_driver_arguments(DriverConfig* args, int* pargc, char** argv) { } if (help) { - // error_print("Usage: slim source.slim\n"); - // error_print("Available arguments: \n"); - error_print(" --target \n"); - error_print(" --output , -o \n"); - error_print(" --dump-cfg Dumps the control flow graph of the final IR\n"); - error_print(" --dump-loop-tree \n"); - error_print(" --dump-ir Dumps the final IR\n"); + // shd_error_print("Usage: slim source.slim\n"); + // shd_error_print("Available arguments: \n"); + shd_error_print(" --target \n"); + shd_error_print(" 
--output , -o \n"); + shd_error_print(" --dump-cfg Dumps the control flow graph of the final IR\n"); + shd_error_print(" --dump-loop-tree \n"); + shd_error_print(" --dump-ir Dumps the final IR\n"); } - cli_pack_remaining_args(pargc, argv); -} \ No newline at end of file + shd_pack_remaining_args(pargc, argv); +} diff --git a/src/driver/cli.h b/src/driver/cli.h new file mode 100644 index 000000000..f7e1995f1 --- /dev/null +++ b/src/driver/cli.h @@ -0,0 +1,15 @@ +#ifndef CLI_H +#define CLI_H + +#include + +#define PARSE_TOGGLE_OPTION(f, name) \ +if (strcmp(argv[i], "--no-"#name) == 0) { \ + f = false; argv[i] = NULL; continue; \ +} else if (strcmp(argv[i], "--"#name) == 0) { \ + f = true; argv[i] = NULL; continue; \ +} + +void shd_pack_remaining_args(int* pargc, char** argv); + +#endif diff --git a/src/driver/driver.c b/src/driver/driver.c index 7b09567c6..9edf18124 100644 --- a/src/driver/driver.c +++ b/src/driver/driver.c @@ -1,11 +1,14 @@ #include "shady/ir.h" #include "shady/driver.h" +#include "shady/print.h" +#include "shady/be/c.h" +#include "shady/be/spirv.h" +#include "shady/be/dump.h" -#include "frontends/slim/parser.h" +#include "../frontend/slim/parser.h" #include "list.h" #include "util.h" - #include "log.h" #include @@ -13,34 +16,35 @@ #include #ifdef LLVM_PARSER_PRESENT -#include "../frontends/llvm/l2s.h" +#include "../frontend/llvm/l2s.h" #endif #ifdef SPV_PARSER_PRESENT -#include "../frontends/spirv/s2s.h" +#include "../frontend/spirv/s2s.h" #endif #pragma GCC diagnostic error "-Wswitch" -SourceLanguage guess_source_language(const char* filename) { - if (string_ends_with(filename, ".ll") || string_ends_with(filename, ".bc")) +SourceLanguage shd_driver_guess_source_language(const char* filename) { + if (shd_string_ends_with(filename, ".ll") || shd_string_ends_with(filename, ".bc")) return SrcLLVM; - else if (string_ends_with(filename, ".spv")) + else if (shd_string_ends_with(filename, ".spv")) return SrcSPIRV; - else if (string_ends_with(filename, 
".slim")) + else if (shd_string_ends_with(filename, ".slim")) return SrcSlim; - else if (string_ends_with(filename, ".slim")) + else if (shd_string_ends_with(filename, ".slim")) return SrcShadyIR; - warn_print("unknown filename extension '%s', interpreting as Slim sourcecode by default."); + shd_warn_print("unknown filename extension '%s', interpreting as Slim sourcecode by default.", filename); return SrcSlim; } -ShadyErrorCodes driver_load_source_file(SourceLanguage lang, size_t len, const char* file_contents, Module* mod) { +ShadyErrorCodes shd_driver_load_source_file(const CompilerConfig* config, SourceLanguage lang, size_t len, const char* file_contents, String name, Module** mod) { switch (lang) { case SrcLLVM: { #ifdef LLVM_PARSER_PRESENT - parse_llvm_into_shady(mod, len, file_contents); + bool ok = shd_parse_llvm(config, len, file_contents, name, mod); + assert(ok); #else assert(false && "LLVM front-end missing in this version"); #endif @@ -48,7 +52,7 @@ ShadyErrorCodes driver_load_source_file(SourceLanguage lang, size_t len, const c } case SrcSPIRV: { #ifdef SPV_PARSER_PRESENT - parse_spirv_into_shady(mod, len, file_contents); + shd_parse_spirv(config, len, file_contents, name, mod); #else assert(false && "SPIR-V front-end missing in this version"); #endif @@ -56,81 +60,86 @@ ShadyErrorCodes driver_load_source_file(SourceLanguage lang, size_t len, const c } case SrcShadyIR: case SrcSlim: { - ParserConfig pconfig = { - .front_end = lang == SrcSlim, + SlimParserConfig pconfig = { + .front_end = lang == SrcSlim, }; - debugv_print("Parsing: \n%s\n", file_contents); - parse_shady_ir(pconfig, (const char*) file_contents, mod); + shd_debugvv_print("Parsing: \n%s\n", file_contents); + *mod = shd_parse_slim_module(config, &pconfig, (const char*) file_contents, name); } } return NoError; } -ShadyErrorCodes driver_load_source_file_from_filename(const char* filename, Module* mod) { +ShadyErrorCodes shd_driver_load_source_file_from_filename(const CompilerConfig* 
config, const char* filename, String name, Module** mod) { ShadyErrorCodes err; - SourceLanguage lang = guess_source_language(filename); + SourceLanguage lang = shd_driver_guess_source_language(filename); size_t len; char* contents; assert(filename); - bool ok = read_file(filename, &len, &contents); + bool ok = shd_read_file(filename, &len, &contents); if (!ok) { - error_print("Failed to read file '%s'\n", filename); + shd_error_print("Failed to read file '%s'\n", filename); err = InputFileIOError; goto exit; } if (contents == NULL) { - error_print("file does not exist\n"); + shd_error_print("file does not exist\n"); err = InputFileDoesNotExist; goto exit; } - err = driver_load_source_file(lang, len, contents, mod); - exit: + err = shd_driver_load_source_file(config, lang, len, contents, name, mod); free((void*) contents); + exit: return err; } -ShadyErrorCodes driver_load_source_files(DriverConfig* args, Module* mod) { - if (entries_count_list(args->input_filenames) == 0) { - error_print("Missing input file. See --help for proper usage"); +ShadyErrorCodes shd_driver_load_source_files(DriverConfig* args, Module* mod) { + if (shd_list_count(args->input_filenames) == 0) { + shd_error_print("Missing input file. 
See --help for proper usage"); return MissingInputArg; } - size_t num_source_files = entries_count_list(args->input_filenames); + size_t num_source_files = shd_list_count(args->input_filenames); for (size_t i = 0; i < num_source_files; i++) { - int err = driver_load_source_file_from_filename(read_list(const char*, args->input_filenames)[i], mod); + Module* m; + int err = shd_driver_load_source_file_from_filename(&args->config, + shd_read_list(const char*, args->input_filenames)[i], + shd_read_list(const char*, args->input_filenames)[i], &m); if (err) return err; + shd_module_link(mod, m); + shd_destroy_ir_arena(shd_module_get_arena(m)); } return NoError; } -ShadyErrorCodes driver_compile(DriverConfig* args, Module* mod) { - debugv_print("Parsed program successfully: \n"); - log_module(DEBUGV, &args->config, mod); +ShadyErrorCodes shd_driver_compile(DriverConfig* args, Module* mod) { + shd_debugv_print("Parsed program successfully: \n"); + shd_log_module(DEBUGV, &args->config, mod); - CompilationResult result = run_compiler_passes(&args->config, &mod); + CompilationResult result = shd_run_compiler_passes(&args->config, &mod); if (result != CompilationNoError) { - error_print("Compilation pipeline failed, errcode=%d\n", (int) result); + shd_error_print("Compilation pipeline failed, errcode=%d\n", (int) result); exit(result); } - debug_print("Ran all passes successfully\n"); - log_module(DEBUG, &args->config, mod); + shd_debug_print("Ran all passes successfully\n"); + shd_log_module(DEBUG, &args->config, mod); if (args->cfg_output_filename) { FILE* f = fopen(args->cfg_output_filename, "wb"); assert(f); - dump_cfg(f, mod); + shd_dump_cfgs(f, mod); fclose(f); - debug_print("CFG dumped\n"); + shd_debug_print("CFG dumped\n"); } if (args->loop_tree_output_filename) { FILE* f = fopen(args->loop_tree_output_filename, "wb"); assert(f); - dump_loop_trees(f, mod); + shd_dump_loop_trees(f, mod); fclose(f); - debug_print("Loop tree dumped\n"); + shd_debug_print("Loop tree 
dumped\n"); } if (args->shd_output_filename) { @@ -138,40 +147,40 @@ ShadyErrorCodes driver_compile(DriverConfig* args, Module* mod) { assert(f); size_t output_size; char* output_buffer; - print_module_into_str(mod, &output_buffer, &output_size); + shd_print_module_into_str(mod, &output_buffer, &output_size); fwrite(output_buffer, output_size, 1, f); free((void*) output_buffer); fclose(f); - debug_print("IR dumped\n"); + shd_debug_print("IR dumped\n"); } if (args->output_filename) { if (args->target == TgtAuto) - args->target = guess_target(args->output_filename); + args->target = shd_guess_target(args->output_filename); FILE* f = fopen(args->output_filename, "wb"); size_t output_size; char* output_buffer; switch (args->target) { case TgtAuto: SHADY_UNREACHABLE; - case TgtSPV: emit_spirv(&args->config, mod, &output_size, &output_buffer, NULL); break; + case TgtSPV: shd_emit_spirv(&args->config, mod, &output_size, &output_buffer, NULL); break; case TgtC: - args->c_emitter_config.dialect = C; - emit_c(args->config, args->c_emitter_config, mod, &output_size, &output_buffer, NULL); + args->c_emitter_config.dialect = CDialect_C11; + shd_emit_c(&args->config, args->c_emitter_config, mod, &output_size, &output_buffer, NULL); break; case TgtGLSL: - args->c_emitter_config.dialect = GLSL; - emit_c(args->config, args->c_emitter_config, mod, &output_size, &output_buffer, NULL); + args->c_emitter_config.dialect = CDialect_GLSL; + shd_emit_c(&args->config, args->c_emitter_config, mod, &output_size, &output_buffer, NULL); break; case TgtISPC: - args->c_emitter_config.dialect = ISPC; - emit_c(args->config, args->c_emitter_config, mod, &output_size, &output_buffer, NULL); + args->c_emitter_config.dialect = CDialect_ISPC; + shd_emit_c(&args->config, args->c_emitter_config, mod, &output_size, &output_buffer, NULL); break; } - debug_print("Wrote result to %s\n", args->output_filename); + shd_debug_print("Wrote result to %s\n", args->output_filename); fwrite(output_buffer, output_size, 
1, f); free((void*) output_buffer); fclose(f); } - destroy_ir_arena(get_module_arena(mod)); + shd_destroy_ir_arena(shd_module_get_arena(mod)); return NoError; } diff --git a/src/driver/slim.c b/src/driver/slim.c index d860e6e61..ae6c8fa97 100644 --- a/src/driver/slim.c +++ b/src/driver/slim.c @@ -6,30 +6,28 @@ #include "util.h" #include "portability.h" -#include - -#ifndef HOOK_STUFF -#define HOOK_STUFF -#endif - int main(int argc, char** argv) { - platform_specific_terminal_init_extras(); + shd_platform_specific_terminal_init_extras(); - DriverConfig args = default_driver_config(); - HOOK_STUFF - cli_parse_driver_arguments(&args, &argc, argv); - cli_parse_common_args(&argc, argv); - cli_parse_compiler_config_args(&args.config, &argc, argv); - cli_parse_input_files(args.input_filenames, &argc, argv); + DriverConfig args = shd_default_driver_config(); + shd_parse_driver_args(&args, &argc, argv); + shd_parse_common_args(&argc, argv); + shd_parse_compiler_config_args(&args.config, &argc, argv); + shd_driver_parse_input_files(args.input_filenames, &argc, argv); - IrArena* arena = new_ir_arena(default_arena_config()); - Module* mod = new_module(arena, "my_module"); // TODO name module after first filename, or perhaps the last one + ArenaConfig aconfig = shd_default_arena_config(&args.config.target); + IrArena* arena = shd_new_ir_arena(&aconfig); + Module* mod = shd_new_module(arena, "my_module"); // TODO name module after first filename, or perhaps the last one - driver_load_source_files(&args, mod); + ShadyErrorCodes err = shd_driver_load_source_files(&args, mod); + if (err) + exit(err); - driver_compile(&args, mod); - info_print("Done\n"); + err = shd_driver_compile(&args, mod); + if (err) + exit(err); + shd_info_print("Compilation successful\n"); - destroy_ir_arena(arena); - destroy_driver_config(&args); + shd_destroy_ir_arena(arena); + shd_destroy_driver_config(&args); } diff --git a/src/driver/vcc.c b/src/driver/vcc.c deleted file mode 100644 index 
9fe8522ee..000000000 --- a/src/driver/vcc.c +++ /dev/null @@ -1,136 +0,0 @@ -#include "shady/ir.h" -#include "shady/driver.h" - -#include "log.h" -#include "list.h" -#include "util.h" -#include "growy.h" -#include "portability.h" - -#include -#include - -typedef struct { - char* tmp_filename; - bool delete_tmp_file; -} VccOptions; - -static void cli_parse_vcc_args(VccOptions* options, int* pargc, char** argv) { - int argc = *pargc; - - for (int i = 1; i < argc; i++) { - if (argv[i] == NULL) - continue; - else if (strcmp(argv[i], "--vcc-keep-tmp-file") == 0) { - argv[i] = NULL; - options->delete_tmp_file = false; - options->tmp_filename = "vcc_tmp.ll"; - continue; - } - } - - cli_pack_remaining_args(pargc, argv); -} - -uint32_t hash_murmur(const void* data, size_t size); - -int main(int argc, char** argv) { - platform_specific_terminal_init_extras(); - - DriverConfig args = default_driver_config(); - VccOptions vcc_options = { - .tmp_filename = NULL, - .delete_tmp_file = true - }; - cli_parse_driver_arguments(&args, &argc, argv); - cli_parse_common_args(&argc, argv); - cli_parse_compiler_config_args(&args.config, &argc, argv); - cli_parse_vcc_args(&vcc_options, &argc, argv); - cli_parse_input_files(args.input_filenames, &argc, argv); - - if (entries_count_list(args.input_filenames) == 0) { - error_print("Missing input file. See --help for proper usage"); - exit(MissingInputArg); - } - - ArenaConfig aconfig = default_arena_config(); - aconfig.untyped_ptrs = true; // tolerate untyped ptrs... 
- IrArena* arena = new_ir_arena(aconfig); - Module* mod = new_module(arena, "my_module"); // TODO name module after first filename, or perhaps the last one - - int clang_retval = system("clang --version"); - if (clang_retval != 0) - error("clang not present in path or otherwise broken (retval=%d)", clang_retval); - - size_t num_source_files = entries_count_list(args.input_filenames); - if (!vcc_options.tmp_filename) { - vcc_options.tmp_filename = alloca(33); - vcc_options.tmp_filename[32] = '\0'; - uint32_t hash = 0; - for (size_t i = 0; i < num_source_files; i++) { - String filename = read_list(const char*, args.input_filenames)[i]; - hash ^= hash_murmur(filename, strlen(filename)); - } - srand(hash); - for (size_t i = 0; i < 32; i++) { - vcc_options.tmp_filename[i] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"[rand() % (10 + 26 * 2)]; - } - } - - Growy* g = new_growy(); - growy_append_string(g, "clang"); - char* self_path = get_executable_location(); - char* working_dir = strip_path(self_path); - growy_append_formatted(g, " -c -emit-llvm -S -g -O0 -ffreestanding -Wno-main-return-type -Xclang -fpreserve-vec3-type --target=spir64-unknown-unknown -isystem\"%s/../share/vcc/include/\" -D__SHADY__=1", working_dir); - free(working_dir); - free(self_path); - growy_append_formatted(g, " -o %s", vcc_options.tmp_filename); - - for (size_t i = 0; i < num_source_files; i++) { - String filename = read_list(const char*, args.input_filenames)[i]; - - growy_append_string(g, " \""); - growy_append_bytes(g, strlen(filename), filename); - growy_append_string(g, "\""); - } - - growy_append_bytes(g, 1, "\0"); - char* arg_string = growy_deconstruct(g); - - info_print("built command: %s\n", arg_string); - - FILE* stream = popen(arg_string, "r"); - free(arg_string); - - Growy* json_bytes = new_growy(); - while (true) { - char buf[4096]; - int read = fread(buf, 1, sizeof(buf), stream); - if (read == 0) - break; - growy_append_bytes(json_bytes, read, buf); - } - 
growy_append_string(json_bytes, "\0"); - char* llvm_result = growy_deconstruct(json_bytes); - int clang_returned = pclose(stream); - info_print("Clang returned %d and replied: \n%s", clang_returned, llvm_result); - free(llvm_result); - if (clang_returned) - exit(ClangInvocationFailed); - - size_t len; - char* llvm_ir; - if (!read_file(vcc_options.tmp_filename, &len, &llvm_ir)) - exit(InputFileIOError); - driver_load_source_file(SrcLLVM, len, llvm_ir, mod); - free(llvm_ir); - - if (vcc_options.delete_tmp_file) - remove(vcc_options.tmp_filename); - - driver_compile(&args, mod); - info_print("Done\n"); - - destroy_ir_arena(arena); - destroy_driver_config(&args); -} diff --git a/src/frontend/CMakeLists.txt b/src/frontend/CMakeLists.txt new file mode 100644 index 000000000..3510e80a2 --- /dev/null +++ b/src/frontend/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(slim) +add_subdirectory(spirv) +add_subdirectory(llvm) diff --git a/src/frontend/llvm/CMakeLists.txt b/src/frontend/llvm/CMakeLists.txt new file mode 100644 index 000000000..f4f8af10a --- /dev/null +++ b/src/frontend/llvm/CMakeLists.txt @@ -0,0 +1,32 @@ +if (NOT LLVM_FOUND) + message("LLVM not found. 
Skipping LLVM front-end.") +else () + option (SHADY_ENABLE_LLVM_FRONTEND "Uses LLVM-C to parse and then convert LLVM IR into Shady IR" ON) +endif () + +if (LLVM_FOUND AND SHADY_ENABLE_LLVM_FRONTEND) + add_generated_file(FILE_NAME l2s_generated.c SOURCES generator_l2s.c) + + add_library(shady_fe_llvm STATIC l2s.c l2s_type.c l2s_value.c l2s_instr.c l2s_meta.c l2s_postprocess.c l2s_annotations.c ${CMAKE_CURRENT_BINARY_DIR}/l2s_generated.c) + + target_include_directories(shady_fe_llvm PRIVATE ${LLVM_INCLUDE_DIRS}) + target_include_directories(shady_fe_llvm PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) # for l2s_generated.c + separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) + add_definitions(${LLVM_DEFINITIONS_LIST}) + target_compile_definitions(shady_fe_llvm PRIVATE "LLVM_VERSION_MAJOR=${LLVM_VERSION_MAJOR}") + + if (TARGET LLVM-C) + message("LLVM-C shared library target exists, major version = ${LLVM_VERSION_MAJOR}") + target_link_libraries(shady_fe_llvm PRIVATE LLVM-C) + elseif (TARGET LLVM) + message("LLVM shared library target exists, major version = ${LLVM_VERSION_MAJOR}") + target_link_libraries(shady_fe_llvm PRIVATE LLVM) + else () + message(FATAL_ERROR "Failed to find LLVM-C target, but found LLVM module earlier") + endif() + + target_link_libraries(shady_fe_llvm PRIVATE api common shady) + + target_compile_definitions(driver PUBLIC LLVM_PARSER_PRESENT) + target_link_libraries(driver PUBLIC "$") +endif () diff --git a/src/frontend/llvm/generator_l2s.c b/src/frontend/llvm/generator_l2s.c new file mode 100644 index 000000000..49b8f1fe0 --- /dev/null +++ b/src/frontend/llvm/generator_l2s.c @@ -0,0 +1,31 @@ +#include "generator.h" + +void generate_llvm_shady_address_space_conversion(Growy* g, json_object* address_spaces) { + shd_growy_append_formatted(g, "AddressSpace l2s_convert_llvm_address_space(unsigned as) {\n"); + shd_growy_append_formatted(g, "\tstatic bool warned = false;\n"); + shd_growy_append_formatted(g, "\tswitch (as) {\n"); + for 
(size_t i = 0; i < json_object_array_length(address_spaces); i++) { + json_object* as = json_object_array_get_idx(address_spaces, i); + String name = json_object_get_string(json_object_object_get(as, "name")); + json_object* llvm_id = json_object_object_get(as, "llvm-id"); + if (!llvm_id || json_object_get_type(llvm_id) != json_type_int) + continue; + shd_growy_append_formatted(g, "\t\t case %d: return As%s;\n", json_object_get_int(llvm_id), name); + } + shd_growy_append_formatted(g, "\t\tdefault:\n"); + shd_growy_append_formatted(g, "\t\t\tif (!warned)\n"); + shd_growy_append_string(g, "\t\t\t\tshd_warn_print(\"Warning: unrecognised address space %d\", as);\n"); + shd_growy_append_formatted(g, "\t\t\twarned = true;\n"); + shd_growy_append_formatted(g, "\t\t\treturn AsGeneric;\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n"); +} + +void generate(Growy* g, json_object* src) { + generate_header(g, src); + shd_growy_append_formatted(g, "#include \"l2s_private.h\"\n"); + shd_growy_append_formatted(g, "#include \"log.h\"\n"); + shd_growy_append_formatted(g, "#include \n"); + + generate_llvm_shady_address_space_conversion(g, json_object_object_get(src, "address-spaces")); +} diff --git a/src/frontend/llvm/l2s.c b/src/frontend/llvm/l2s.c new file mode 100644 index 000000000..6c7209e38 --- /dev/null +++ b/src/frontend/llvm/l2s.c @@ -0,0 +1,349 @@ +#include "l2s_private.h" + +#include "ir_private.h" +#include "analysis/verify.h" + +#include "log.h" +#include "dict.h" +#include "list.h" + +#include "llvm-c/IRReader.h" + +#include +#include +#include + +typedef struct OpaqueRef* OpaqueRef; + +static KeyHash hash_opaque_ptr(OpaqueRef* pvalue) { + if (!pvalue) + return 0; + size_t ptr = *(size_t*) pvalue; + return shd_hash(&ptr, sizeof(size_t)); +} + +static bool cmp_opaque_ptr(OpaqueRef* a, OpaqueRef* b) { + if (a == b) + return true; + if (!a ^ !b) + return false; + return *a == *b; +} + +KeyHash shd_hash_node(Node** pnode); +bool 
shd_compare_node(Node** pa, Node** pb); + +#ifdef LLVM_VERSION_MAJOR +int vcc_get_linked_major_llvm_version() { + return LLVM_VERSION_MAJOR; +} +#else +#error "wat" +#endif + +static void write_bb_body(Parser* p, FnParseCtx* fn_ctx, BBParseCtx* bb_ctx) { + bb_ctx->builder = shd_bld_begin(bb_ctx->nbb->arena, shd_get_abstraction_mem(bb_ctx->nbb)); + LLVMValueRef instr; + LLVMBasicBlockRef bb = bb_ctx->bb; + for (instr = bb_ctx->instr; instr; instr = LLVMGetNextInstruction(instr)) { + bool last = instr == LLVMGetLastInstruction(bb); + if (last) + assert(LLVMGetBasicBlockTerminator(bb) == instr); + // LLVMDumpValue(instr); + // printf("\n"); + if (LLVMIsATerminatorInst(instr)) + return; + const Node* emitted = l2s_convert_instruction(p, fn_ctx, bb_ctx->nbb, bb_ctx->builder, instr); + if (!emitted) + continue; + shd_dict_insert(LLVMValueRef, const Node*, p->map, instr, emitted); + } + shd_log_fmt(ERROR, "Reached end of LLVM basic block without encountering a terminator!"); + SHADY_UNREACHABLE; +} + +static void write_bb_tail(Parser* p, FnParseCtx* fn_ctx, BBParseCtx* bb_ctx) { + LLVMBasicBlockRef bb = bb_ctx->bb; + LLVMValueRef instr = LLVMGetLastInstruction(bb); + shd_set_abstraction_body(bb_ctx->nbb, shd_bld_finish(bb_ctx->builder, l2s_convert_instruction(p, fn_ctx, bb_ctx->nbb, bb_ctx->builder, instr))); +} + +static void prepare_bb(Parser* p, FnParseCtx* fn_ctx, BBParseCtx* ctx, LLVMBasicBlockRef bb) { + IrArena* a = shd_module_get_arena(p->dst); + shd_debug_print("l2s: preparing BB %s %d\n", LLVMGetBasicBlockName(bb), bb); + if (shd_log_get_level() >= DEBUG) + LLVMDumpValue((LLVMValueRef)bb); + + struct List* phis = shd_new_list(LLVMValueRef); + Nodes params = shd_empty(a); + LLVMValueRef instr = LLVMGetFirstInstruction(bb); + while (instr) { + switch (LLVMGetInstructionOpcode(instr)) { + case LLVMPHI: { + const Node* nparam = param(a, shd_as_qualified_type(l2s_convert_type(p, LLVMTypeOf(instr)), false), "phi"); + shd_dict_insert(LLVMValueRef, const Node*, p->map, 
instr, nparam); + shd_list_append(LLVMValueRef, phis, instr); + params = shd_nodes_append(a, params, nparam); + break; + } + default: goto after_phis; + } + instr = LLVMGetNextInstruction(instr); + } + after_phis: + { + String name = LLVMGetBasicBlockName(bb); + if (strlen(name) == 0) + name = NULL; + Node* nbb = basic_block(a, params, name); + shd_dict_insert(LLVMValueRef, const Node*, p->map, bb, nbb); + shd_dict_insert(const Node*, struct List*, fn_ctx->phis, nbb, phis); + *ctx = (BBParseCtx) { + .bb = bb, + .instr = instr, + .nbb = nbb, + }; + } +} + +static BBParseCtx* get_bb_ctx(Parser* p, FnParseCtx* fn_ctx, LLVMBasicBlockRef bb) { + BBParseCtx** found = shd_dict_find_value(LLVMValueRef, BBParseCtx*, fn_ctx->bbs, bb); + if (found) return *found; + + BBParseCtx* ctx = shd_arena_alloc(p->annotations_arena, sizeof(BBParseCtx)); + prepare_bb(p, fn_ctx, ctx, bb); + shd_dict_insert(LLVMBasicBlockRef, BBParseCtx*, fn_ctx->bbs, bb, ctx); + + return ctx; +} + +const Node* l2s_convert_basic_block_header(Parser* p, FnParseCtx* fn_ctx, LLVMBasicBlockRef bb) { + const Node** found = shd_dict_find_value(LLVMValueRef, const Node*, p->map, bb); + if (found) return *found; + + BBParseCtx* ctx = get_bb_ctx(p, fn_ctx, bb); + return ctx->nbb; +} + +const Node* l2s_convert_basic_block_body(Parser* p, FnParseCtx* fn_ctx, LLVMBasicBlockRef bb) { + BBParseCtx* ctx = get_bb_ctx(p, fn_ctx, bb); + if (ctx->translated) + return ctx->nbb; + + ctx->translated = true; + write_bb_body(p, fn_ctx, ctx); + write_bb_tail(p, fn_ctx, ctx); + return ctx->nbb; +} + +const Node* l2s_convert_function(Parser* p, LLVMValueRef fn) { + if (is_llvm_intrinsic(fn)) { + shd_warn_print("Skipping unknown LLVM intrinsic function: %s\n", LLVMGetValueName(fn)); + return NULL; + } + if (is_shady_intrinsic(fn)) { + shd_warn_print("Skipping shady intrinsic function: %s\n", LLVMGetValueName(fn)); + return NULL; + } + + const Node** found = shd_dict_find_value(LLVMValueRef, const Node*, p->map, fn); + if (found) 
return *found; + IrArena* a = shd_module_get_arena(p->dst); + shd_debug_print("Converting function: %s\n", LLVMGetValueName(fn)); + + Nodes params = shd_empty(a); + for (LLVMValueRef oparam = LLVMGetFirstParam(fn); oparam; oparam = LLVMGetNextParam(oparam)) { + LLVMTypeRef ot = LLVMTypeOf(oparam); + const Type* t = l2s_convert_type(p, ot); + const Node* nparam = param(a, shd_as_qualified_type(t, false), LLVMGetValueName(oparam)); + shd_dict_insert(LLVMValueRef, const Node*, p->map, oparam, nparam); + params = shd_nodes_append(a, params, nparam); + if (oparam == LLVMGetLastParam(fn)) + break; + } + const Type* fn_type = l2s_convert_type(p, LLVMGlobalGetValueType(fn)); + assert(fn_type->tag == FnType_TAG); + assert(fn_type->payload.fn_type.param_types.count == params.count); + Nodes annotations = shd_empty(a); + switch (LLVMGetLinkage(fn)) { + case LLVMExternalLinkage: + case LLVMExternalWeakLinkage: { + annotations = shd_nodes_append(a, annotations, annotation(a, (Annotation) { .name = "Exported" })); + break; + } + default: + break; + } + Node* f = function(p->dst, params, LLVMGetValueName(fn), annotations, fn_type->payload.fn_type.return_types); + FnParseCtx fn_parse_ctx = { + .fn = f, + .phis = shd_new_dict(const Node*, struct List*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .bbs = shd_new_dict(LLVMBasicBlockRef, BBParseCtx*, (HashFn) shd_hash_ptr, (CmpFn) shd_compare_ptrs), + .jumps_todo = shd_new_list(JumpTodo), + }; + const Node* r = fn_addr_helper(a, f); + r = prim_op_helper(a, reinterpret_op, shd_singleton(ptr_type(a, (PtrType) { .address_space = AsGeneric, .pointed_type = unit_type(a) })), shd_singleton(r)); + //r = prim_op_helper(a, convert_op, singleton(ptr_type(a, (PtrType) { .address_space = AsGeneric, .pointed_type = unit_type(a) })), singleton(r)); + shd_dict_insert(LLVMValueRef, const Node*, p->map, fn, r); + + size_t bb_count = LLVMCountBasicBlocks(fn); + if (bb_count > 0) { + LLVMBasicBlockRef first_bb = LLVMGetEntryBasicBlock(fn); + 
shd_dict_insert(LLVMValueRef, const Node*, p->map, first_bb, f); + + //LLVMBasicBlockRef bb = LLVMGetNextBasicBlock(first_bb); + //LARRAY(BBParseCtx, bbs, bb_count); + //bbs[0] = (BBParseCtx) { + BBParseCtx bb0 = { + .nbb = f, + .bb = first_bb, + .instr = LLVMGetFirstInstruction(first_bb), + }; + //BBParseCtx* bb0p = &bbs[0]; + BBParseCtx* bb0p = &bb0; + shd_dict_insert(LLVMBasicBlockRef, BBParseCtx*, fn_parse_ctx.bbs, first_bb, bb0p); + + write_bb_body(p, &fn_parse_ctx, &bb0); + write_bb_tail(p, &fn_parse_ctx, &bb0); + + /*for (size_t i = 1;bb; bb = LLVMGetNextBasicBlock(bb)) { + assert(i < bb_count); + prepare_bb(p, &fn_parse_ctx, &bbs[i++], bb); + } + + for (size_t i = 0; i < bb_count; i++) { + write_bb_body(p, &fn_parse_ctx, &bbs[i]); + } + + for (size_t i = 0; i < bb_count; i++) { + write_bb_tail(p, &fn_parse_ctx, &bbs[i]); + }*/ + } + + { + size_t i = 0; + struct List* phis_list; + while (shd_dict_iter(fn_parse_ctx.phis, &i, NULL, &phis_list)) { + shd_destroy_list(phis_list); + } + } + shd_destroy_dict(fn_parse_ctx.phis); + shd_destroy_dict(fn_parse_ctx.bbs); + shd_destroy_list(fn_parse_ctx.jumps_todo); + + return r; +} + +const Node* l2s_convert_global(Parser* p, LLVMValueRef global) { + const Node** found = shd_dict_find_value(LLVMValueRef, const Node*, p->map, global); + if (found) return *found; + IrArena* a = shd_module_get_arena(p->dst); + + String name = LLVMGetValueName(global); + String intrinsic = is_llvm_intrinsic(global); + if (intrinsic) { + if (strcmp(intrinsic, "llvm.global.annotations") == 0) { + return NULL; + } + shd_warn_print("Skipping unknown LLVM intrinsic function: %s\n", name); + return NULL; + } + shd_debug_print("Converting global: %s\n", name); + + Node* decl = NULL; + + if (LLVMIsAGlobalVariable(global)) { + LLVMValueRef value = LLVMGetInitializer(global); + const Type* type = l2s_convert_type(p, LLVMGlobalGetValueType(global)); + // nb: even if we have untyped pointers, they still carry useful address space info + const Type* 
ptr_t = l2s_convert_type(p, LLVMTypeOf(global)); + assert(ptr_t->tag == PtrType_TAG); + AddressSpace as = ptr_t->payload.ptr_type.address_space; + decl = global_var(p->dst, shd_empty(a), type, name, as); + if (value && as != AsUniformConstant) + decl->payload.global_variable.init = l2s_convert_value(p, value); + + if (UNTYPED_POINTERS) { + Node* untyped_wrapper = constant(p->dst, shd_singleton(annotation(a, (Annotation) { .name = "Inline" })), ptr_t, shd_fmt_string_irarena(a, "%s_untyped", name)); + untyped_wrapper->payload.constant.value = ref_decl_helper(a, decl); + untyped_wrapper->payload.constant.value = prim_op_helper(a, reinterpret_op, shd_singleton(ptr_t), shd_singleton(ref_decl_helper(a, decl))); + decl = untyped_wrapper; + } + } else { + const Type* type = l2s_convert_type(p, LLVMTypeOf(global)); + decl = constant(p->dst, shd_empty(a), type, name); + decl->payload.constant.value = l2s_convert_value(p, global); + } + + assert(decl && is_declaration(decl)); + const Node* r = ref_decl_helper(a, decl); + + shd_dict_insert(LLVMValueRef, const Node*, p->map, global, r); + return r; +} + +bool shd_parse_llvm(const CompilerConfig* config, size_t len, const char* data, String name, Module** dst) { + LLVMContextRef context = LLVMContextCreate(); + LLVMModuleRef src; + LLVMMemoryBufferRef mem = LLVMCreateMemoryBufferWithMemoryRange(data, len, "my_great_buffer", false); + char* parsing_diagnostic = ""; + if (LLVMParseIRInContext(context, mem, &src, &parsing_diagnostic)) { + shd_error_print("Failed to parse LLVM IR\n"); + shd_error_print(parsing_diagnostic); + shd_error_die(); + } + shd_info_print("LLVM IR parsed successfully\n"); + + ArenaConfig aconfig = shd_default_arena_config(&config->target); + aconfig.check_types = false; + aconfig.allow_fold = false; + aconfig.optimisations.inline_single_use_bbs = false; + + IrArena* arena = shd_new_ir_arena(&aconfig); + Module* dirty = shd_new_module(arena, "dirty"); + Parser p = { + .ctx = context, + .config = config, + .map 
= shd_new_dict(LLVMValueRef, const Node*, (HashFn) hash_opaque_ptr, (CmpFn) cmp_opaque_ptr), + .annotations = shd_new_dict(LLVMValueRef, ParsedAnnotation, (HashFn) hash_opaque_ptr, (CmpFn) cmp_opaque_ptr), + .annotations_arena = shd_new_arena(), + .src = src, + .dst = dirty, + }; + + LLVMValueRef global_annotations = LLVMGetNamedGlobal(src, "llvm.global.annotations"); + if (global_annotations) + l2s_process_llvm_annotations(&p, global_annotations); + + for (LLVMValueRef fn = LLVMGetFirstFunction(src); fn; fn = LLVMGetNextFunction(fn)) { + l2s_convert_function(&p, fn); + } + + LLVMValueRef global = LLVMGetFirstGlobal(src); + while (global) { + l2s_convert_global(&p, global); + if (global == LLVMGetLastGlobal(src)) + break; + global = LLVMGetNextGlobal(global); + } + shd_log_fmt(DEBUGVV, "Shady module parsed from LLVM:"); + shd_log_module(DEBUGVV, config, dirty); + + aconfig.check_types = true; + aconfig.allow_fold = true; + IrArena* arena2 = shd_new_ir_arena(&aconfig); + *dst = shd_new_module(arena2, name); + l2s_postprocess(&p, dirty, *dst); + shd_log_fmt(DEBUGVV, "Shady module parsed from LLVM, after cleanup:"); + shd_log_module(DEBUGVV, config, *dst); + shd_verify_module(config, *dst); + shd_destroy_ir_arena(arena); + + shd_destroy_dict(p.map); + shd_destroy_dict(p.annotations); + shd_destroy_arena(p.annotations_arena); + + LLVMContextDispose(context); + + return true; +} diff --git a/src/frontend/llvm/l2s.h b/src/frontend/llvm/l2s.h new file mode 100644 index 000000000..ec9b8e59b --- /dev/null +++ b/src/frontend/llvm/l2s.h @@ -0,0 +1,10 @@ +#ifndef SHADY_FE_LLVM_H +#define SHADY_FE_LLVM_H + +#include "shady/ir.h" +#include + +typedef struct CompilerConfig_ CompilerConfig; +bool shd_parse_llvm(const CompilerConfig* config, size_t len, const char* data, String name, Module** dst); + +#endif diff --git a/src/frontends/llvm/l2s_annotations.c b/src/frontend/llvm/l2s_annotations.c similarity index 56% rename from src/frontends/llvm/l2s_annotations.c rename to 
src/frontend/llvm/l2s_annotations.c index 44e2b364c..40573f98e 100644 --- a/src/frontends/llvm/l2s_annotations.c +++ b/src/frontend/llvm/l2s_annotations.c @@ -5,20 +5,20 @@ #include -ParsedAnnotation* find_annotation(Parser* p, const Node* n) { - return find_value_dict(const Node*, ParsedAnnotation, p->annotations, n); +ParsedAnnotation* l2s_find_annotation(Parser* p, const Node* n) { + return shd_dict_find_value(const Node*, ParsedAnnotation, p->annotations, n); } -void add_annotation(Parser* p, const Node* n, ParsedAnnotation a) { - ParsedAnnotation* found = find_value_dict(const Node*, ParsedAnnotation, p->annotations, n); +static void add_annotation(Parser* p, const Node* n, ParsedAnnotation a) { + ParsedAnnotation* found = shd_dict_find_value(const Node*, ParsedAnnotation, p->annotations, n); if (found) { - ParsedAnnotation* data = arena_alloc(p->annotations_arena, sizeof(a)); + ParsedAnnotation* data = shd_arena_alloc(p->annotations_arena, sizeof(a)); *data = a; while (found->next) found = found->next; found->next = data; } else { - insert_dict(const Node*, ParsedAnnotation, p->annotations, n, a); + shd_dict_insert(const Node*, ParsedAnnotation, p->annotations, n, a); } } @@ -29,29 +29,52 @@ static const Node* assert_and_strip_fn_addr(const Node* fn) { return fn; } -void process_llvm_annotations(Parser* p, LLVMValueRef global) { - IrArena* a = get_module_arena(p->dst); - const Type* t = convert_type(p, LLVMGlobalGetValueType(global)); +static const Node* look_past_stuff(const Node* thing) { + if (thing->tag == Constant_TAG) { + const Node* instr = thing->payload.constant.value; + assert(instr->tag == PrimOp_TAG); + thing = instr; + } + if (thing->tag == PrimOp_TAG) { + switch (thing->payload.prim_op.op) { + case reinterpret_op: + case convert_op: thing = shd_first(thing->payload.prim_op.operands); break; + default: assert(false); + } + } + if (thing->tag == PtrCompositeElement_TAG) { + thing = thing->payload.ptr_composite_element.ptr; + } + return thing; +} + 
+static bool is_io_as(AddressSpace as) { + switch (as) { + case AsInput: + case AsUInput: + case AsOutput: + case AsUniform: + case AsUniformConstant: return true; + default: break; + } + return false; +} + +void l2s_process_llvm_annotations(Parser* p, LLVMValueRef global) { + IrArena* a = shd_module_get_arena(p->dst); + const Type* t = l2s_convert_type(p, LLVMGlobalGetValueType(global)); assert(t->tag == ArrType_TAG); - size_t arr_size = get_int_literal_value(*resolve_to_int_literal(t->payload.arr_type.size), false); + size_t arr_size = shd_get_int_literal_value(*shd_resolve_to_int_literal(t->payload.arr_type.size), false); assert(arr_size > 0); - const Node* value = convert_value(p, LLVMGetInitializer(global)); + const Node* value = l2s_convert_value(p, LLVMGetInitializer(global)); assert(value->tag == Composite_TAG && value->payload.composite.contents.count == arr_size); for (size_t i = 0; i < arr_size; i++) { const Node* entry = value->payload.composite.contents.nodes[i]; + entry = look_past_stuff(entry); assert(entry->tag == Composite_TAG); const Node* annotation_payload = entry->payload.composite.contents.nodes[1]; // eliminate dummy reinterpret cast - if (annotation_payload->tag == Constant_TAG) { - const Node* instr = annotation_payload->payload.constant.instruction; - assert(instr->tag == PrimOp_TAG); - switch (instr->payload.prim_op.op) { - case reinterpret_op: - case convert_op: - case lea_op: annotation_payload = first(instr->payload.prim_op.operands); break; - default: assert(false); - } - } + annotation_payload = look_past_stuff(annotation_payload); if (annotation_payload->tag == RefDecl_TAG) { annotation_payload = annotation_payload->payload.ref_decl.decl; } @@ -59,16 +82,16 @@ void process_llvm_annotations(Parser* p, LLVMValueRef global) { annotation_payload = annotation_payload->payload.global_variable.init; } - NodeResolveConfig resolve_config = default_node_resolve_config(); + NodeResolveConfig resolve_config = shd_default_node_resolve_config(); 
// both of those assumptions are hacky but this front-end is a hacky deal anyways. resolve_config.assume_globals_immutability = true; resolve_config.allow_incompatible_types = true; - const char* ostr = get_string_literal(a, resolve_node_to_definition(annotation_payload, resolve_config)); + const char* ostr = shd_get_string_literal(a, shd_chase_ptr_to_source(annotation_payload, resolve_config)); char* str = calloc(strlen(ostr) + 1, 1); memcpy(str, ostr, strlen(ostr) + 1); if (strcmp(strtok(str, "::"), "shady") == 0) { const Node* target = entry->payload.composite.contents.nodes[0]; - target = resolve_node_to_definition(target, resolve_config); + target = shd_resolve_node_to_definition(target, resolve_config); char* keyword = strtok(NULL, "::"); if (strcmp(keyword, "entry_point") == 0) { @@ -84,7 +107,7 @@ void process_llvm_annotations(Parser* p, LLVMValueRef global) { add_annotation(p, target, (ParsedAnnotation) { .payload = annotation_values(a, (AnnotationValues) { .name = "WorkgroupSize", - .values = mk_nodes(a, int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)), int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)), int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10))) + .values = mk_nodes(a, shd_int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)), shd_int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)), shd_int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10))) }) }); } else if (strcmp(keyword, "builtin") == 0) { @@ -100,7 +123,7 @@ void process_llvm_annotations(Parser* p, LLVMValueRef global) { add_annotation(p, target, (ParsedAnnotation) { .payload = annotation_value(a, (AnnotationValue) { .name = "Location", - .value = int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)) + .value = shd_int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)) }) }); } else if (strcmp(keyword, "descriptor_set") == 0) { @@ -108,7 +131,7 @@ void process_llvm_annotations(Parser* p, LLVMValueRef global) { add_annotation(p, target, (ParsedAnnotation) { .payload = 
annotation_value(a, (AnnotationValue) { .name = "DescriptorSet", - .value = int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)) + .value = shd_int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)) }) }); } else if (strcmp(keyword, "descriptor_binding") == 0) { @@ -116,24 +139,28 @@ void process_llvm_annotations(Parser* p, LLVMValueRef global) { add_annotation(p, target, (ParsedAnnotation) { .payload = annotation_value(a, (AnnotationValue) { .name = "DescriptorBinding", - .value = int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)) + .value = shd_int32_literal(a, strtol(strtok(NULL, "::"), NULL, 10)) }) }); - } else if (strcmp(keyword, "uniform") == 0) { + } else if (strcmp(keyword, "extern") == 0) { assert(target->tag == GlobalVariable_TAG); + AddressSpace as = l2s_convert_llvm_address_space(strtol(strtok(NULL, "::"), NULL, 10)); + if (is_io_as(as)) + ((Node*) target)->payload.global_variable.init = NULL; add_annotation(p, target, (ParsedAnnotation) { - .payload = annotation(a, (Annotation) { - .name = "UniformConstant" + .payload = annotation_value(a, (AnnotationValue) { + .name = "AddressSpace", + .value = shd_int32_literal(a, as) }) }); } else { - error_print("Unrecognised shady annotation '%s'\n", keyword); - error_die(); + shd_error_print("Unrecognised shady annotation '%s'\n", keyword); + shd_error_die(); } } else { - warn_print("Ignoring annotation '%s'\n", ostr); + shd_warn_print("Ignoring annotation '%s'\n", ostr); } free(str); //dump_node(annotation_payload); } -} \ No newline at end of file +} diff --git a/src/frontend/llvm/l2s_instr.c b/src/frontend/llvm/l2s_instr.c new file mode 100644 index 000000000..5f03484c0 --- /dev/null +++ b/src/frontend/llvm/l2s_instr.c @@ -0,0 +1,666 @@ +#include "l2s_private.h" + +#include "shady/ir/memory_layout.h" + +#include "portability.h" +#include "log.h" +#include "dict.h" +#include "list.h" + +#include "llvm-c/DebugInfo.h" + +static Nodes convert_operands(Parser* p, size_t num_ops, LLVMValueRef v) { + IrArena* 
a = shd_module_get_arena(p->dst); + LARRAY(const Node*, ops, num_ops); + for (size_t i = 0; i < num_ops; i++) { + LLVMValueRef op = LLVMGetOperand(v, i); + if (LLVMIsAFunction(op) && (is_llvm_intrinsic(op) || is_shady_intrinsic(op))) + ops[i] = NULL; + else + ops[i] = l2s_convert_value(p, op); + } + Nodes operands = shd_nodes(a, num_ops, ops); + return operands; +} + +static const Type* change_int_t_sign(const Type* t, bool as_signed) { + assert(t); + assert(t->tag == Int_TAG); + return int_type(t->arena, (Int) { + .width = t->payload.int_type.width, + .is_signed = as_signed + }); +} + +static Nodes reinterpret_operands(BodyBuilder* b, Nodes ops, const Type* dst_t) { + assert(ops.count > 0); + IrArena* a = dst_t->arena; + LARRAY(const Node*, nops, ops.count); + for (size_t i = 0; i < ops.count; i++) + nops[i] = shd_first(shd_bld_add_instruction_extract_count(b, prim_op_helper(a, reinterpret_op, shd_singleton(dst_t), shd_singleton(ops.nodes[i])), 1)); + return shd_nodes(a, ops.count, nops); +} + +static LLVMValueRef remove_ptr_bitcasts(Parser* p, LLVMValueRef v) { + while (true) { + if (LLVMIsAInstruction(v) || LLVMIsAConstantExpr(v)) { + if (LLVMGetInstructionOpcode(v) == LLVMBitCast) { + LLVMTypeRef t = LLVMTypeOf(v); + if (LLVMGetTypeKind(t) == LLVMPointerTypeKind) + v = LLVMGetOperand(v, 0); + } + } + break; + } + return v; +} + +static const Node* convert_jump(Parser* p, FnParseCtx* fn_ctx, const Node* src, LLVMBasicBlockRef dst, const Node* mem) { + IrArena* a = fn_ctx->fn->arena; + const Node* dst_bb = l2s_convert_basic_block_body(p, fn_ctx, dst); + struct List* phis = *shd_dict_find_value(const Node*, struct List*, fn_ctx->phis, dst_bb); + assert(phis); + size_t params_count = shd_list_count(phis); + LARRAY(const Node*, params, params_count); + for (size_t i = 0; i < params_count; i++) { + LLVMValueRef phi = shd_read_list(LLVMValueRef, phis)[i]; + for (size_t j = 0; j < LLVMCountIncoming(phi); j++) { + if (l2s_convert_basic_block_header(p, fn_ctx, 
LLVMGetIncomingBlock(phi, j)) == src) { + params[i] = l2s_convert_value(p, LLVMGetIncomingValue(phi, j)); + goto next; + } + } + assert(false && "failed to find the appropriate source"); + next: continue; + } + return jump_helper(a, mem, dst_bb, shd_nodes(a, params_count, params)); +} + +static const Type* type_untyped_ptr(const Type* untyped_ptr_t, const Type* element_type) { + IrArena* a = untyped_ptr_t->arena; + assert(element_type); + assert(untyped_ptr_t->tag == PtrType_TAG); + assert(!untyped_ptr_t->payload.ptr_type.is_reference); + const Type* typed_ptr_t = ptr_type(a, (PtrType) { .pointed_type = element_type, .address_space = untyped_ptr_t->payload.ptr_type.address_space }); + return typed_ptr_t; +} + +/// instr may be an instruction or a constantexpr +const Node* l2s_convert_instruction(Parser* p, FnParseCtx* fn_ctx, Node* fn_or_bb, BodyBuilder* b, LLVMValueRef instr) { + Node* fn = fn_ctx ? fn_ctx->fn : NULL; + + IrArena* a = shd_module_get_arena(p->dst); + int num_ops = LLVMGetNumOperands(instr); + size_t num_results = 1; + Nodes result_types = shd_empty(a); + // const Node* r = NULL; + + LLVMOpcode opcode; + if (LLVMIsAInstruction(instr)) + opcode = LLVMGetInstructionOpcode(instr); + else if (LLVMIsAConstantExpr(instr)) + opcode = LLVMGetConstOpcode(instr); + else + assert(false); + + const Type* t = l2s_convert_type(p, LLVMTypeOf(instr)); + +#define BIND_PREV_R(t) shd_bld_add_instruction_extract_count(b, r, 1) + + //if (LLVMIsATerminatorInst(instr)) { + //if (LLVMIsAInstruction(instr) && !LLVMIsATerminatorInst(instr)) { + if (LLVMIsAInstruction(instr) && p->config->input_cf.add_scope_annotations) { + assert(fn && fn_or_bb); + LLVMMetadataRef dbgloc = LLVMInstructionGetDebugLoc(instr); + if (dbgloc) { + //Nodes* found = find_value_dict(const Node*, Nodes, p->scopes, fn_or_bb); + //if (!found) { + Nodes str = l2s_scope_to_string(p, dbgloc); + //insert_dict(const Node*, Nodes, p->scopes, fn_or_bb, str); + shd_debugv_print("Found a debug location for "); + 
shd_log_node(DEBUGV, fn_or_bb); + shd_debugv_print(" "); + for (size_t i = 0; i < str.count; i++) { + shd_log_node(DEBUGV, str.nodes[i]); + shd_debugv_print(" -> "); + } + shd_debugv_print(" (depth= %zu)\n", str.count); + shd_bld_add_instruction(b, ext_instr(a, (ExtInstr) { + .set = "shady.scope", + .opcode = 0, + .result_t = unit_type(a), + .mem = shd_bb_mem(b), + .operands = str, + })); + //} + } + } + + switch (opcode) { + case LLVMRet: return fn_ret(a, (Return) { + .args = num_ops == 0 ? shd_empty(a) : convert_operands(p, num_ops, instr), + .mem = shd_bb_mem(b), + }); + case LLVMBr: { + unsigned n_targets = LLVMGetNumSuccessors(instr); + LARRAY(LLVMBasicBlockRef, targets, n_targets); + for (size_t i = 0; i < n_targets; i++) + targets[i] = LLVMGetSuccessor(instr, i); + if (LLVMIsConditional(instr)) { + assert(n_targets == 2); + const Node* condition = l2s_convert_value(p, LLVMGetCondition(instr)); + return branch(a, (Branch) { + .condition = condition, + .true_jump = convert_jump(p, fn_ctx, fn_or_bb, targets[0], shd_bb_mem(b)), + .false_jump = convert_jump(p, fn_ctx, fn_or_bb, targets[1], shd_bb_mem(b)), + .mem = shd_bb_mem(b), + }); + } else { + assert(n_targets == 1); + return convert_jump(p, fn_ctx, fn_or_bb, targets[0], shd_bb_mem(b)); + } + } + case LLVMSwitch: { + const Node* inspectee = l2s_convert_value(p, LLVMGetOperand(instr, 0)); + const Node* default_jump = convert_jump(p, fn_ctx, fn_or_bb, (LLVMBasicBlockRef) LLVMGetOperand(instr, 1), shd_bb_mem(b)); + int n_targets = LLVMGetNumOperands(instr) / 2 - 1; + LARRAY(const Node*, targets, n_targets); + LARRAY(const Node*, literals, n_targets); + for (size_t i = 0; i < n_targets; i++) { + literals[i] = l2s_convert_value(p, LLVMGetOperand(instr, i * 2 + 2)); + targets[i] = convert_jump(p, fn_ctx, fn_or_bb, (LLVMBasicBlockRef) LLVMGetOperand(instr, i * 2 + 3), shd_bb_mem(b)); + } + return br_switch(a, (Switch) { + .switch_value = inspectee, + .default_jump = default_jump, + .case_values = shd_nodes(a, 
n_targets, literals), + .case_jumps = shd_nodes(a, n_targets, targets), + .mem = shd_bb_mem(b), + }); + } + case LLVMIndirectBr: + goto unimplemented; + case LLVMInvoke: + goto unimplemented; + case LLVMUnreachable: return unreachable(a, (Unreachable) { .mem = shd_bb_mem(b) }); + case LLVMCallBr: + goto unimplemented; + case LLVMFNeg: + return prim_op_helper(a, neg_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMFAdd: + case LLVMAdd: + return prim_op_helper(a, add_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMSub: + case LLVMFSub: + return prim_op_helper(a, sub_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMMul: + case LLVMFMul: + return prim_op_helper(a, mul_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMUDiv: + case LLVMFDiv: + return prim_op_helper(a, div_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMSDiv: { + const Type* int_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + const Type* signed_t = change_int_t_sign(int_t, true); + return prim_op_helper(a, reinterpret_op, shd_singleton(int_t), shd_singleton(prim_op_helper(a, div_op, shd_empty(a), reinterpret_operands(b, convert_operands(p, num_ops, instr), signed_t)))); + } case LLVMURem: + case LLVMFRem: + return prim_op_helper(a, mod_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMSRem: { + const Type* int_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + const Type* signed_t = change_int_t_sign(int_t, true); + return prim_op_helper(a, reinterpret_op, shd_singleton(int_t), shd_singleton(prim_op_helper(a, mod_op, shd_empty(a), reinterpret_operands(b, convert_operands(p, num_ops, instr), signed_t)))); + } case LLVMShl: + return prim_op_helper(a, lshift_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMLShr: + return prim_op_helper(a, rshift_logical_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMAShr: + return prim_op_helper(a, 
rshift_arithm_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMAnd: + return prim_op_helper(a, and_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMOr: + return prim_op_helper(a, or_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMXor: + return prim_op_helper(a, xor_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMAlloca: { + assert(t->tag == PtrType_TAG); + const Type* allocated_t = l2s_convert_type(p, LLVMGetAllocatedType(instr)); + const Type* allocated_ptr_t = ptr_type(a, (PtrType) { .pointed_type = allocated_t, .address_space = AsPrivate }); + const Node* r = shd_first(shd_bld_add_instruction_extract_count(b, stack_alloc(a, (StackAlloc) { .type = allocated_t, .mem = shd_bb_mem(b) }), 1)); + if (UNTYPED_POINTERS) { + const Type* untyped_ptr_t = ptr_type(a, (PtrType) { .pointed_type = unit_type(a), .address_space = AsPrivate }); + r = shd_first(shd_bld_add_instruction_extract_count(b, prim_op_helper(a, reinterpret_op, shd_singleton(untyped_ptr_t), shd_singleton(r)), 1)); + } + return prim_op_helper(a, convert_op, shd_singleton(t), shd_singleton(r)); + } + case LLVMLoad: { + Nodes ops = convert_operands(p, num_ops, instr); + assert(ops.count == 1); + const Node* ptr = shd_first(ops); + if (UNTYPED_POINTERS) { + const Type* element_t = t; + const Type* untyped_ptr_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + const Type* typed_ptr = type_untyped_ptr(untyped_ptr_t, element_t); + ptr = shd_first(shd_bld_add_instruction_extract_count(b, prim_op_helper(a, reinterpret_op, shd_singleton(typed_ptr), shd_singleton(ptr)), 1)); + } + return shd_bld_add_instruction(b, load(a, (Load) { .ptr = ptr, .mem = shd_bb_mem(b) })); + } + case LLVMStore: { + num_results = 0; + Nodes ops = convert_operands(p, num_ops, instr); + assert(ops.count == 2); + const Node* ptr = ops.nodes[1]; + if (UNTYPED_POINTERS) { + const Type* element_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + 
const Type* untyped_ptr_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 1))); + const Type* typed_ptr = type_untyped_ptr(untyped_ptr_t, element_t); + ptr = shd_first(shd_bld_add_instruction_extract_count(b, prim_op_helper(a, reinterpret_op, shd_singleton(typed_ptr), shd_singleton(ptr)), 1)); + } + return shd_bld_add_instruction(b, store(a, (Store) { .ptr = ptr, .value = ops.nodes[0], .mem = shd_bb_mem(b) })); + } + case LLVMGetElementPtr: { + Nodes ops = convert_operands(p, num_ops, instr); + const Node* ptr = shd_first(ops); + if (UNTYPED_POINTERS) { + const Type* element_t = l2s_convert_type(p, LLVMGetGEPSourceElementType(instr)); + const Type* untyped_ptr_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + const Type* typed_ptr = type_untyped_ptr(untyped_ptr_t, element_t); + ptr = shd_first(shd_bld_add_instruction_extract_count(b, prim_op_helper(a, reinterpret_op, shd_singleton(typed_ptr), shd_singleton(ptr)), 1)); + } + ops = shd_change_node_at_index(a, ops, 0, ptr); + const Node* r = lea_helper(a, ops.nodes[0], ops.nodes[1], shd_nodes(a, ops.count - 2, &ops.nodes[2])); + if (UNTYPED_POINTERS) { + const Type* element_t = l2s_convert_type(p, LLVMGetGEPSourceElementType(instr)); + const Type* untyped_ptr_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + bool idk; + //element_t = shd_as_qualified_type(element_t, false); + shd_enter_composite_type_indices(&element_t, &idk, shd_nodes(a, ops.count - 2, &ops.nodes[2]), true); + const Type* typed_ptr = type_untyped_ptr(untyped_ptr_t, element_t); + r = prim_op_helper(a, reinterpret_op, shd_singleton(untyped_ptr_t), BIND_PREV_R(typed_ptr)); + } + return r; + } + case LLVMTrunc: + case LLVMZExt: { + const Type* src_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + Nodes ops = convert_operands(p, num_ops, instr); + const Node* r; + if (src_t->tag == Bool_TAG) { + assert(t->tag == Int_TAG); + const Node* zero = int_literal(a, (IntLiteral) { .value = 0, .width = 
t->payload.int_type.width, .is_signed = t->payload.int_type.is_signed }); + const Node* one = int_literal(a, (IntLiteral) { .value = 1, .width = t->payload.int_type.width, .is_signed = t->payload.int_type.is_signed }); + r = prim_op_helper(a, select_op, shd_empty(a), mk_nodes(a, shd_first(ops), one, zero)); + } else if (t->tag == Bool_TAG) { + assert(src_t->tag == Int_TAG); + const Node* one = int_literal(a, (IntLiteral) { .value = 1, .width = src_t->payload.int_type.width, .is_signed = false }); + r = prim_op_helper(a, and_op, shd_empty(a), mk_nodes(a, shd_first(ops), one)); + r = prim_op_helper(a, eq_op, shd_empty(a), mk_nodes(a, shd_first(BIND_PREV_R(int_type(a, (Int) { .width = src_t->payload.int_type.width, .is_signed = false }))), one)); + } else { + // reinterpret as unsigned, convert to change size, reinterpret back to target T + const Type* unsigned_src_t = change_int_t_sign(src_t, false); + const Type* unsigned_dst_t = change_int_t_sign(t, false); + r = prim_op_helper(a, convert_op, shd_singleton(unsigned_dst_t), reinterpret_operands(b, ops, unsigned_src_t)); + r = prim_op_helper(a, reinterpret_op, shd_singleton(t), BIND_PREV_R(unsigned_dst_t)); + } + return r; + } case LLVMSExt: { + const Node* r; + // reinterpret as signed, convert to change size, reinterpret back to target T + const Type* src_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + Nodes ops = convert_operands(p, num_ops, instr); + if (src_t->tag == Bool_TAG) { + assert(t->tag == Int_TAG); + const Node* zero = int_literal(a, (IntLiteral) { .value = 0, .width = t->payload.int_type.width, .is_signed = t->payload.int_type.is_signed }); + uint64_t i = UINT64_MAX >> (64 - int_size_in_bytes(t->payload.int_type.width) * 8); + const Node* ones = int_literal(a, (IntLiteral) { .value = i, .width = t->payload.int_type.width, .is_signed = t->payload.int_type.is_signed }); + r = prim_op_helper(a, select_op, shd_empty(a), mk_nodes(a, shd_first(ops), ones, zero)); + } else { + const Type* 
signed_src_t = change_int_t_sign(src_t, true); + const Type* signed_dst_t = change_int_t_sign(t, true); + r = prim_op_helper(a, convert_op, shd_singleton(signed_dst_t), reinterpret_operands(b, ops, signed_src_t)); + r = prim_op_helper(a, reinterpret_op, shd_singleton(t), shd_singleton(r)); + } + return r; + } case LLVMFPToUI: + case LLVMFPToSI: + case LLVMUIToFP: + case LLVMSIToFP: + return prim_op_helper(a, convert_op, shd_singleton(t), convert_operands(p, num_ops, instr)); + case LLVMFPTrunc: + goto unimplemented; + case LLVMFPExt: + goto unimplemented; + case LLVMPtrToInt: + case LLVMIntToPtr: + case LLVMBitCast: + case LLVMAddrSpaceCast: { + // when constructing or deconstructing generic pointers, we need to emit a convert_op instead + assert(num_ops == 1); + const Node* src = shd_first(convert_operands(p, num_ops, instr)); + Op op = reinterpret_op; + const Type* src_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + if (src_t->tag == PtrType_TAG && t->tag == PtrType_TAG) { + if ((t->payload.ptr_type.address_space == AsGeneric)) { + switch (src_t->payload.ptr_type.address_space) { + case AsGeneric: // generic-to-generic isn't a conversion. 
+ break; + default: { + op = convert_op; + break; + } + } + } + } else { + assert(opcode != LLVMAddrSpaceCast); + } + return prim_op_helper(a, op, shd_singleton(t), shd_singleton(src)); + } + case LLVMICmp: { + Op op; + bool cast_to_signed = false; + switch(LLVMGetICmpPredicate(instr)) { + case LLVMIntEQ: + op = eq_op; + break; + case LLVMIntNE: + op = neq_op; + break; + case LLVMIntUGT: + op = gt_op; + break; + case LLVMIntUGE: + op = gte_op; + break; + case LLVMIntULT: + op = lt_op; + break; + case LLVMIntULE: + op = lte_op; + break; + case LLVMIntSGT: + op = gt_op; + cast_to_signed = true; + break; + case LLVMIntSGE: + op = gte_op; + cast_to_signed = true; + break; + case LLVMIntSLT: + op = lt_op; + cast_to_signed = true; + break; + case LLVMIntSLE: + op = lte_op; + cast_to_signed = true; + break; + } + Nodes ops = convert_operands(p, num_ops, instr); + if (cast_to_signed) { + const Type* unsigned_t = l2s_convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); + assert(unsigned_t->tag == Int_TAG); + const Type* signed_t = change_int_t_sign(unsigned_t, true); + ops = reinterpret_operands(b, ops, signed_t); + } + return prim_op_helper(a, op, shd_empty(a), ops); + } + case LLVMFCmp: { + Op op; + bool cast_to_signed = false; + switch(LLVMGetFCmpPredicate(instr)) { + case LLVMRealUEQ: + case LLVMRealOEQ: + op = eq_op; + break; + case LLVMRealUNE: + case LLVMRealONE: + op = neq_op; + break; + case LLVMRealUGT: + case LLVMRealOGT: + op = gt_op; + break; + case LLVMRealUGE: + case LLVMRealOGE: + op = gte_op; + break; + case LLVMRealULT: + case LLVMRealOLT: + op = lt_op; + break; + case LLVMRealULE: + case LLVMRealOLE: + op = lte_op; + break; + default: goto unimplemented; + } + Nodes ops = convert_operands(p, num_ops, instr); + return prim_op_helper(a, op, shd_empty(a), ops); + break; + } + case LLVMPHI: + assert(false && "We deal with phi nodes before, there shouldn't be one here"); + break; + case LLVMCall: { + const Node* r = NULL; + unsigned num_args = 
LLVMGetNumArgOperands(instr); + LLVMValueRef callee = LLVMGetCalledValue(instr); + LLVMTypeRef callee_type = LLVMGetCalledFunctionType(instr); + callee = remove_ptr_bitcasts(p, callee); + assert(num_args + 1 == num_ops); + String intrinsic = NULL; + if (LLVMIsAFunction(callee) || LLVMIsAConstant(callee)) { + intrinsic = is_llvm_intrinsic(callee); + if (!intrinsic) + intrinsic = is_shady_intrinsic(callee); + } + if (intrinsic) { + assert(LLVMIsAFunction(callee)); + if (strcmp(intrinsic, "llvm.dbg.declare") == 0) { + const Node* target = l2s_convert_value(p, LLVMGetOperand(instr, 0)); + const Node* meta = l2s_convert_value(p, LLVMGetOperand(instr, 1)); + assert(meta->tag == RefDecl_TAG); + meta = meta->payload.ref_decl.decl; + assert(meta->tag == GlobalVariable_TAG); + meta = meta->payload.global_variable.init; + assert(meta && meta->tag == Composite_TAG); + const Node* name_node = meta->payload.composite.contents.nodes[2]; + String name = shd_get_string_literal(target->arena, name_node); + assert(name); + shd_set_value_name((Node*) target, name); + return NULL; + } + if (strcmp(intrinsic, "llvm.dbg.label") == 0) { + // TODO + return NULL; + } + if (strcmp(intrinsic, "llvm.dbg.value") == 0) { + // TODO + return NULL; + } + if (shd_string_starts_with(intrinsic, "llvm.lifetime")) { + // don't care + return NULL; + } + if (shd_string_starts_with(intrinsic, "llvm.experimental.noalias.scope.decl")) { + // don't care + return NULL; + } + if (shd_string_starts_with(intrinsic, "llvm.var.annotation")) { + // don't care + return NULL; + } + if (shd_string_starts_with(intrinsic, "llvm.memcpy")) { + Nodes ops = convert_operands(p, num_ops, instr); + return shd_bld_add_instruction(b, copy_bytes(a, (CopyBytes) { .dst = ops.nodes[0], .src = ops.nodes[1], .count = ops.nodes[2], .mem = shd_bb_mem(b) })); + } else if (shd_string_starts_with(intrinsic, "llvm.memset")) { + Nodes ops = convert_operands(p, num_ops, instr); + return shd_bld_add_instruction(b, fill_bytes(a, (FillBytes) { 
.dst = ops.nodes[0], .src = ops.nodes[1], .count = ops.nodes[2], .mem = shd_bb_mem(b) })); + } else if (shd_string_starts_with(intrinsic, "llvm.fmuladd")) { + Nodes ops = convert_operands(p, num_ops, instr); + return prim_op_helper(a, fma_op, shd_empty(a), shd_nodes(a, 3, ops.nodes)); + } else if (shd_string_starts_with(intrinsic, "llvm.fabs")) { + Nodes ops = convert_operands(p, num_ops, instr); + return prim_op_helper(a, abs_op, shd_empty(a), shd_nodes(a, 1, ops.nodes)); + } else if (shd_string_starts_with(intrinsic, "llvm.floor")) { + Nodes ops = convert_operands(p, num_ops, instr); + return prim_op_helper(a, floor_op, shd_empty(a), shd_nodes(a, 1, ops.nodes)); + } + + typedef struct { + bool is_byval; + } DecodedParamAttr; + + size_t params_count = 0; + for (LLVMValueRef oparam = LLVMGetFirstParam(callee); oparam && oparam <= LLVMGetLastParam(callee); oparam = LLVMGetNextParam(oparam)) { + params_count++; + } + LARRAY(DecodedParamAttr, decoded, params_count); + memset(decoded, 0, sizeof(DecodedParamAttr) * params_count); + size_t param_index = 0; + for (LLVMValueRef oparam = LLVMGetFirstParam(callee); oparam && oparam <= LLVMGetLastParam(callee); oparam = LLVMGetNextParam(oparam)) { + size_t num_attrs = LLVMGetAttributeCountAtIndex(callee, param_index + 1); + LARRAY(LLVMAttributeRef, attrs, num_attrs); + LLVMGetAttributesAtIndex(callee, param_index + 1, attrs); + bool is_byval = false; + for (size_t i = 0; i < num_attrs; i++) { + LLVMAttributeRef attr = attrs[i]; + size_t k = LLVMGetEnumAttributeKind(attr); + size_t e = LLVMGetEnumAttributeKindForName("byval", 5); + uint64_t value = LLVMGetEnumAttributeValue(attr); + // printf("p = %zu, i = %zu, k = %zu, e = %zu\n", param_index, i, k, e); + if (k == e) + decoded[param_index].is_byval = true; + } + param_index++; + } + + String ostr = intrinsic; + char* str = calloc(strlen(ostr) + 1, 1); + memcpy(str, ostr, strlen(ostr) + 1); + + if (strcmp(strtok(str, "::"), "shady") == 0) { + char* keyword = strtok(NULL, 
"::"); + if (strcmp(keyword, "prim_op") == 0) { + char* opname = strtok(NULL, "::"); + Op op; + size_t i; + for (i = 0; i < PRIMOPS_COUNT; i++) { + if (strcmp(shd_get_primop_name(i), opname) == 0) { + op = (Op) i; + break; + } + } + assert(i != PRIMOPS_COUNT); + Nodes ops = convert_operands(p, num_args, instr); + LARRAY(const Node*, processed_ops, ops.count); + for (i = 0; i < num_args; i++) { + if (decoded[i].is_byval) + processed_ops[i] = shd_first(shd_bld_add_instruction_extract_count(b, load(a, (Load) { .ptr = ops.nodes[i], .mem = shd_bb_mem(b) }), 1)); + else + processed_ops[i] = ops.nodes[i]; + } + r = prim_op_helper(a, op, shd_empty(a), shd_nodes(a, num_args, processed_ops)); + free(str); + goto finish; + } else if (strcmp(keyword, "instruction") == 0) { + char* instructionname = strtok(NULL, "::"); + Nodes ops = convert_operands(p, num_args, instr); + if (strcmp(instructionname, "DebugPrintf") == 0) { + if (ops.count == 0) + shd_error("DebugPrintf called without arguments"); + size_t whocares; + shd_bld_debug_printf(b, LLVMGetAsString(LLVMGetInitializer(LLVMGetOperand(instr, 0)), &whocares), shd_nodes(a, ops.count - 1, &ops.nodes[1])); + return shd_tuple_helper(a, shd_empty(a)); + } + + shd_error_print("Unrecognised shady instruction '%s'\n", instructionname); + shd_error_die(); + } else { + shd_error_print("Unrecognised shady intrinsic '%s'\n", keyword); + shd_error_die(); + } + } + + shd_error_print("Unhandled intrinsic '%s'\n", intrinsic); + shd_error_die(); + } + finish: + + if (!r) { + Nodes ops = convert_operands(p, num_ops, instr); + r = shd_bld_add_instruction(b, call(a, (Call) { + .mem = shd_bb_mem(b), + .callee = prim_op_helper(a, reinterpret_op, shd_singleton(ptr_type(a, (PtrType) { + .address_space = AsGeneric, + .pointed_type = l2s_convert_type(p, callee_type) + })), shd_singleton(ops.nodes[num_args])), + .args = shd_nodes(a, num_args, ops.nodes), + })); + } + return r; + } + case LLVMSelect: + return prim_op_helper(a, select_op, shd_empty(a), 
convert_operands(p, num_ops, instr)); + case LLVMUserOp1: + goto unimplemented; + case LLVMUserOp2: + goto unimplemented; + case LLVMVAArg: + goto unimplemented; + case LLVMExtractElement: + return prim_op_helper(a, extract_dynamic_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMInsertElement: + return prim_op_helper(a, insert_op, shd_empty(a), convert_operands(p, num_ops, instr)); + case LLVMShuffleVector: { + Nodes ops = convert_operands(p, num_ops, instr); + unsigned num_indices = LLVMGetNumMaskElements(instr); + LARRAY(const Node*, cindices, num_indices); + for (size_t i = 0; i < num_indices; i++) + cindices[i] = shd_uint32_literal(a, LLVMGetMaskValue(instr, i)); + ops = shd_concat_nodes(a, ops, shd_nodes(a, num_indices, cindices)); + return prim_op_helper(a, shuffle_op, shd_empty(a), ops); + } + case LLVMExtractValue: + goto unimplemented; + case LLVMInsertValue: + goto unimplemented; + case LLVMFreeze: + goto unimplemented; + case LLVMFence: + goto unimplemented; + case LLVMAtomicCmpXchg: + goto unimplemented; + case LLVMAtomicRMW: + goto unimplemented; + case LLVMResume: + goto unimplemented; + case LLVMLandingPad: + goto unimplemented; + case LLVMCleanupRet: + goto unimplemented; + case LLVMCatchRet: + goto unimplemented; + case LLVMCatchPad: + goto unimplemented; + case LLVMCleanupPad: + goto unimplemented; + case LLVMCatchSwitch: + goto unimplemented; + } + /*shortcut: + if (r) { + if (num_results == 1) + result_types = singleton(convert_type(p, LLVMTypeOf(instr))); + assert(result_types.count == num_results); + return (EmittedInstr) { + .instruction = r, + .result_types = result_types, + }; + }*/ + + unimplemented: + shd_error_print("Shady: unimplemented LLVM instruction "); + LLVMDumpValue(instr); + shd_error_print(" (opcode=%d)\n", opcode); + shd_error_die(); +} diff --git a/src/frontends/llvm/l2s_meta.c b/src/frontend/llvm/l2s_meta.c similarity index 72% rename from src/frontends/llvm/l2s_meta.c rename to 
src/frontend/llvm/l2s_meta.c index eb760f744..914672d82 100644 --- a/src/frontends/llvm/l2s_meta.c +++ b/src/frontend/llvm/l2s_meta.c @@ -7,7 +7,7 @@ #include "llvm-c/DebugInfo.h" static Nodes convert_mdnode_operands(Parser* p, LLVMValueRef mdnode) { - IrArena* a = get_module_arena(p->dst); + IrArena* a = shd_module_get_arena(p->dst); assert(LLVMIsAMDNode(mdnode)); unsigned count = LLVMGetMDNodeNumOperands(mdnode); @@ -16,24 +16,24 @@ static Nodes convert_mdnode_operands(Parser* p, LLVMValueRef mdnode) { LARRAY(const Node*, cops, count); for (size_t i = 0; i < count; i++) - cops[i] = ops[i] ? convert_value(p, ops[i]) : string_lit_helper(a, "null"); - Nodes args = nodes(a, count, cops); + cops[i] = ops[i] ? l2s_convert_value(p, ops[i]) : string_lit_helper(a, "null"); + Nodes args = shd_nodes(a, count, cops); return args; } static const Node* convert_named_tuple_metadata(Parser* p, LLVMValueRef v, String node_name) { // printf("%s\n", name); - IrArena* a = get_module_arena(p->dst); + IrArena* a = shd_module_get_arena(p->dst); String name = LLVMGetValueName(v); if (!name || strlen(name) == 0) - name = unique_name(a, node_name); - Node* g = global_var(p->dst, singleton(annotation(a, (Annotation) { .name = "SkipOnInfer" })), NULL, name, AsDebugInfo); + name = shd_make_unique_name(a, node_name); + Node* g = global_var(p->dst, shd_singleton(annotation(a, (Annotation) { .name = "LLVMMetaData" })), unit_type(a), name, AsDebugInfo); const Node* r = ref_decl_helper(a, g); - insert_dict(LLVMValueRef, const Type*, p->map, v, r); + shd_dict_insert(LLVMValueRef, const Type*, p->map, v, r); Nodes args = convert_mdnode_operands(p, v); - args = prepend_nodes(a, args, string_lit_helper(a, node_name)); - g->payload.global_variable.init = tuple_helper(a, args); + args = shd_nodes_prepend(a, args, string_lit_helper(a, node_name)); + g->payload.global_variable.init = shd_tuple_helper(a, args); return r; } @@ -101,15 +101,15 @@ LLVM_DI_WITH_PARENT_SCOPES(N) return ops[1]; } -Nodes 
scope_to_string(Parser* p, LLVMMetadataRef dbgloc) { - IrArena* a = get_module_arena(p->dst); - Nodes str = empty(a); +Nodes l2s_scope_to_string(Parser* p, LLVMMetadataRef dbgloc) { + IrArena* a = shd_module_get_arena(p->dst); + Nodes str = shd_empty(a); LLVMMetadataRef scope = LLVMDILocationGetScope(dbgloc); while (true) { if (!scope) break; - str = prepend_nodes(a, str, convert_metadata(p, scope)); + str = shd_nodes_prepend(a, str, shd_uint32_literal(a, l2s_convert_metadata(p, scope)->id)); // LLVMDumpValue(LLVMMetadataAsValue(p->ctx, scope)); // printf("\n"); @@ -122,18 +122,18 @@ Nodes scope_to_string(Parser* p, LLVMMetadataRef dbgloc) { return str; } -const Node* convert_metadata(Parser* p, LLVMMetadataRef meta) { - IrArena* a = get_module_arena(p->dst); +const Node* l2s_convert_metadata(Parser* p, LLVMMetadataRef meta) { + IrArena* a = shd_module_get_arena(p->dst); LLVMMetadataKind kind = LLVMGetMetadataKind(meta); LLVMValueRef v = LLVMMetadataAsValue(p->ctx, meta); if (v) { - const Type** found = find_value_dict(LLVMTypeRef, const Type*, p->map, v); + const Type** found = shd_dict_find_value(LLVMTypeRef, const Type*, p->map, v); if (found) return *found; } switch (kind) { - case LLVMMDTupleMetadataKind: return tuple_helper(a, convert_mdnode_operands(p, v)); + case LLVMMDTupleMetadataKind: return shd_tuple_helper(a, convert_mdnode_operands(p, v)); case LLVMDICompileUnitMetadataKind: return string_lit_helper(a, "CompileUnit"); } @@ -147,7 +147,7 @@ const Node* convert_metadata(Parser* p, LLVMMetadataRef meta) { case LLVMLocalAsMetadataMetadataKind: { Nodes ops = convert_mdnode_operands(p, v); assert(ops.count == 1); - return first(ops); + return shd_first(ops); } case LLVMDistinctMDOperandPlaceholderMetadataKind: goto default_; @@ -155,9 +155,9 @@ const Node* convert_metadata(Parser* p, LLVMMetadataRef meta) { LLVM_DI_METADATA_NODES(N) #undef N default: default_: - error_print("Unknown metadata kind %d for ", kind); + shd_error_print("Unknown metadata kind %d 
for ", kind); LLVMDumpValue(v); - error_print(".\n"); - error_die(); + shd_error_print(".\n"); + shd_error_die(); } } diff --git a/src/frontend/llvm/l2s_postprocess.c b/src/frontend/llvm/l2s_postprocess.c new file mode 100644 index 000000000..85bde99c5 --- /dev/null +++ b/src/frontend/llvm/l2s_postprocess.c @@ -0,0 +1,148 @@ +#include "l2s_private.h" + +#include "shady/rewrite.h" + +#include "portability.h" +#include "dict.h" +#include "list.h" +#include "log.h" +#include "arena.h" + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +typedef struct { + Rewriter rewriter; + const CompilerConfig* config; + Parser* p; + Arena* arena; +} Context; + +static Nodes remake_params(Context* ctx, Nodes old) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + LARRAY(const Node*, nvars, old.count); + for (size_t i = 0; i < old.count; i++) { + const Node* node = old.nodes[i]; + const Type* t = NULL; + if (node->payload.param.type) { + if (node->payload.param.type->tag == QualifiedType_TAG) + t = shd_rewrite_node(r, node->payload.param.type); + else + t = shd_as_qualified_type(shd_rewrite_node(r, node->payload.param.type), false); + } + nvars[i] = param(a, t, node->payload.param.name); + assert(nvars[i]->tag == Param_TAG); + } + return shd_nodes(a, old.count, nvars); +} + +static const Node* process_node(Context* ctx, const Node* node) { + IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + switch (node->tag) { + case Param_TAG: { + assert(false); + } + case Constant_TAG: { + Node* new = (Node*) shd_recreate_node(r, node); + BodyBuilder* bb = shd_bld_begin_pure(a); + const Node* value = new->payload.constant.value; + value = prim_op_helper(a, subgroup_assume_uniform_op, shd_empty(a), shd_singleton(value)); + new->payload.constant.value = shd_bld_to_instr_pure_with_values(bb, shd_singleton(value)); + return new; + } + case Function_TAG: { + Nodes new_params = remake_params(ctx, node->payload.fun.params); + Nodes 
old_annotations = node->payload.fun.annotations; + ParsedAnnotation* an = l2s_find_annotation(ctx->p, node); + Op primop_intrinsic = PRIMOPS_COUNT; + while (an) { + if (strcmp(get_annotation_name(an->payload), "PrimOpIntrinsic") == 0) { + assert(!node->payload.fun.body); + Op op; + size_t i; + for (i = 0; i < PRIMOPS_COUNT; i++) { + if (strcmp(shd_get_primop_name(i), shd_get_annotation_string_payload(an->payload)) == 0) { + op = (Op) i; + break; + } + } + assert(i != PRIMOPS_COUNT); + primop_intrinsic = op; + } else if (strcmp(get_annotation_name(an->payload), "EntryPoint") == 0) { + for (size_t i = 0; i < new_params.count; i++) + new_params = shd_change_node_at_index(a, new_params, i, param(a, shd_as_qualified_type( + shd_get_unqualified_type(new_params.nodes[i]->payload.param.type), true), new_params.nodes[i]->payload.param.name)); + } + old_annotations = shd_nodes_append(a, old_annotations, an->payload); + an = an->next; + } + shd_register_processed_list(r, node->payload.fun.params, new_params); + Nodes new_annotations = shd_rewrite_nodes(r, old_annotations); + Node* decl = function(ctx->rewriter.dst_module, new_params, shd_get_abstraction_name(node), new_annotations, shd_rewrite_nodes(&ctx->rewriter, node->payload.fun.return_types)); + shd_register_processed(&ctx->rewriter, node, decl); + if (primop_intrinsic != PRIMOPS_COUNT) { + shd_set_abstraction_body(decl, fn_ret(a, (Return) { + .args = shd_singleton(prim_op_helper(a, primop_intrinsic, shd_empty(a), get_abstraction_params(decl))), + .mem = shd_get_abstraction_mem(decl), + })); + } else if (get_abstraction_body(node)) + shd_set_abstraction_body(decl, shd_rewrite_node(r, get_abstraction_body(node))); + return decl; + } + case GlobalVariable_TAG: { + // if (lookup_annotation(node, "LLVMMetaData")) + // return NULL; + AddressSpace as = node->payload.global_variable.address_space; + const Node* old_init = node->payload.global_variable.init; + Nodes annotations = shd_rewrite_nodes(r, 
node->payload.global_variable.annotations); + const Type* type = shd_rewrite_node(r, node->payload.global_variable.type); + ParsedAnnotation* an = l2s_find_annotation(ctx->p, node); + AddressSpace old_as = as; + while (an) { + annotations = shd_nodes_append(a, annotations, shd_rewrite_node(r, an->payload)); + if (strcmp(get_annotation_name(an->payload), "Builtin") == 0) + old_init = NULL; + if (strcmp(get_annotation_name(an->payload), "AddressSpace") == 0) + as = shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_annotation_value(an->payload)), false); + an = an->next; + } + Node* decl = global_var(ctx->rewriter.dst_module, annotations, type, get_declaration_name(node), as); + Node* result = decl; + if (old_as != as) { + const Type* pt = ptr_type(a, (PtrType) { .address_space = old_as, .pointed_type = type }); + Node* c = constant(ctx->rewriter.dst_module, shd_singleton(annotation(a, (Annotation) { + .name = "Inline" + })), pt, shd_fmt_string_irarena(a, "%s_proxy", get_declaration_name(decl))); + c->payload.constant.value = prim_op_helper(a, convert_op, shd_singleton(pt), shd_singleton( + ref_decl_helper(a, decl))); + result = c; + } + + shd_register_processed(r, node, result); + if (old_init) + decl->payload.global_variable.init = shd_rewrite_node(r, old_init); + return result; + } + default: break; + } + + return shd_recreate_node(&ctx->rewriter, node); +} + +void l2s_postprocess(Parser* p, Module* src, Module* dst) { + assert(src != dst); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process_node), + .config = p->config, + .p = p, + .arena = shd_new_arena(), + }; + + ctx.rewriter.rewrite_fn = (RewriteNodeFn) process_node; + + shd_rewrite_module(&ctx.rewriter); + shd_destroy_arena(ctx.arena); + shd_destroy_rewriter(&ctx.rewriter); +} diff --git a/src/frontend/llvm/l2s_private.h b/src/frontend/llvm/l2s_private.h new file mode 100644 index 000000000..68ce8cf0f --- /dev/null +++ b/src/frontend/llvm/l2s_private.h @@ 
-0,0 +1,100 @@ +#ifndef SHADY_L2S_PRIVATE_H +#define SHADY_L2S_PRIVATE_H + +#include "l2s.h" + +#include "shady/config.h" + +#include "arena.h" +#include "util.h" + +#include "llvm-c/Core.h" + +#include +#include + +typedef struct { + const CompilerConfig* config; + LLVMContextRef ctx; + struct Dict* map; + struct Dict* annotations; + Arena* annotations_arena; + LLVMModuleRef src; + Module* dst; +} Parser; + +typedef struct { + LLVMBasicBlockRef bb; + LLVMValueRef instr; + Node* nbb; + BodyBuilder* builder; + bool translated; +} BBParseCtx; + +typedef struct { + Node* fn; + struct Dict* phis; + struct Dict* bbs; + struct List* jumps_todo; +} FnParseCtx; + +#ifndef LLVM_VERSION_MAJOR +#error "Missing LLVM_VERSION_MAJOR" +#else +#define UNTYPED_POINTERS (LLVM_VERSION_MAJOR >= 15) +#endif + +typedef struct ParsedAnnotationContents_ { + const Node* payload; + struct ParsedAnnotationContents_* next; +} ParsedAnnotation; + +ParsedAnnotation* l2s_find_annotation(Parser* p, const Node* n); +ParsedAnnotation* next_annotation(ParsedAnnotation*); + +void l2s_process_llvm_annotations(Parser* p, LLVMValueRef global); + +AddressSpace l2s_convert_llvm_address_space(unsigned); +const Node* l2s_convert_value(Parser* p, LLVMValueRef v); +const Node* l2s_convert_function(Parser* p, LLVMValueRef fn); +const Type* l2s_convert_type(Parser* p, LLVMTypeRef t); +const Node* l2s_convert_metadata(Parser* p, LLVMMetadataRef meta); +const Node* l2s_convert_global(Parser* p, LLVMValueRef global); +const Node* l2s_convert_function(Parser* p, LLVMValueRef fn); +const Node* l2s_convert_basic_block_header(Parser* p, FnParseCtx* fn_ctx, LLVMBasicBlockRef bb); +const Node* l2s_convert_basic_block_body(Parser* p, FnParseCtx* fn_ctx, LLVMBasicBlockRef bb); + +typedef struct { + struct List* list; +} BBPhis; + +typedef struct { + Node* wrapper; + Node* src; + LLVMBasicBlockRef dst; +} JumpTodo; + +void convert_jump_finish(Parser* p, FnParseCtx*, JumpTodo todo); +const Node* 
l2s_convert_instruction(Parser* p, FnParseCtx* fn_ctx, Node* fn_or_bb, BodyBuilder* b, LLVMValueRef instr); + +Nodes l2s_scope_to_string(Parser* p, LLVMMetadataRef dbgloc); + +void l2s_postprocess(Parser* p, Module* src, Module* dst); + +inline static String is_llvm_intrinsic(LLVMValueRef fn) { + assert(LLVMIsAFunction(fn) || LLVMIsConstant(fn)); + String name = LLVMGetValueName(fn); + if (shd_string_starts_with(name, "llvm.")) + return name; + return NULL; +} + +inline static String is_shady_intrinsic(LLVMValueRef fn) { + assert(LLVMIsAFunction(fn) || LLVMIsConstant(fn)); + String name = LLVMGetValueName(fn); + if (shd_string_starts_with(name, "shady::")) + return name; + return NULL; +} + +#endif diff --git a/src/frontends/llvm/l2s_type.c b/src/frontend/llvm/l2s_type.c similarity index 53% rename from src/frontends/llvm/l2s_type.c rename to src/frontend/llvm/l2s_type.c index 4da900def..bebaf087c 100644 --- a/src/frontends/llvm/l2s_type.c +++ b/src/frontend/llvm/l2s_type.c @@ -5,16 +5,16 @@ #include "dict.h" #include "util.h" -const Type* convert_type(Parser* p, LLVMTypeRef t) { - const Type** found = find_value_dict(LLVMTypeRef, const Type*, p->map, t); +const Type* l2s_convert_type(Parser* p, LLVMTypeRef t) { + const Type** found = shd_dict_find_value(LLVMTypeRef, const Type*, p->map, t); if (found) return *found; - IrArena* a = get_module_arena(p->dst); + IrArena* a = shd_module_get_arena(p->dst); switch (LLVMGetTypeKind(t)) { case LLVMVoidTypeKind: return unit_type(a); - case LLVMHalfTypeKind: return fp16_type(a); - case LLVMFloatTypeKind: return fp32_type(a); - case LLVMDoubleTypeKind: return fp64_type(a); + case LLVMHalfTypeKind: return shd_fp16_type(a); + case LLVMFloatTypeKind: return shd_fp32_type(a); + case LLVMDoubleTypeKind: return shd_fp64_type(a); case LLVMX86_FP80TypeKind: case LLVMFP128TypeKind: break; @@ -23,11 +23,11 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) { case LLVMIntegerTypeKind: switch(LLVMGetIntTypeWidth(t)) { case 1: return 
bool_type(a); - case 8: return uint8_type(a); - case 16: return uint16_type(a); - case 32: return uint32_type(a); - case 64: return uint64_type(a); - default: error("Unsupported integer width: %d\n", LLVMGetIntTypeWidth(t)); break; + case 8: return shd_uint8_type(a); + case 16: return shd_uint16_type(a); + case 32: return shd_uint32_type(a); + case 64: return shd_uint64_type(a); + default: shd_error("Unsupported integer width: %d\n", LLVMGetIntTypeWidth(t)); break; } case LLVMFunctionTypeKind: { unsigned num_params = LLVMCountParamTypes(t); @@ -35,13 +35,15 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) { LLVMGetParamTypes(t, param_types); LARRAY(const Type*, cparam_types, num_params); for (size_t i = 0; i < num_params; i++) - cparam_types[i] = convert_type(p, param_types[i]); - const Type* ret_type = convert_type(p, LLVMGetReturnType(t)); + cparam_types[i] = shd_as_qualified_type(l2s_convert_type(p, param_types[i]), false); + const Type* ret_type = l2s_convert_type(p, LLVMGetReturnType(t)); if (LLVMGetTypeKind(LLVMGetReturnType(t)) == LLVMVoidTypeKind) ret_type = empty_multiple_return_type(a); + else + ret_type = shd_as_qualified_type(ret_type, false); return fn_type(a, (FnType) { - .param_types = nodes(a, num_params, cparam_types), - .return_types = ret_type == empty_multiple_return_type(a) ? empty(a) : singleton(ret_type) + .param_types = shd_nodes(a, num_params, cparam_types), + .return_types = ret_type == empty_multiple_return_type(a) ? 
shd_empty(a) : shd_singleton(ret_type) }); } case LLVMStructTypeKind: { @@ -49,20 +51,9 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) { Node* decl = NULL; const Node* result = NULL; if (name) { - if (strcmp(name, "struct.__shady_builtin_sampler2D") == 0) - return combined_image_sampler_type(a, (CombinedImageSamplerType) { .image_type = image_type(a, (ImageType) { - //.sampled_type = pack_type(a, (PackType) { .element_type = float_type(a, (Float) { .width = FloatTy32 }), .width = 4 }), - .sampled_type = float_type(a, (Float) { .width = FloatTy32 }), - .dim = 1, - .depth = 0, - .onion = 0, - .multisample = 0, - .sampled = 1, - } ) }); - - decl = nominal_type(p->dst, empty(a), name); + decl = nominal_type(p->dst, shd_empty(a), name); result = type_decl_ref_helper(a, decl); - insert_dict(LLVMTypeRef, const Type*, p->map, t, result); + shd_dict_insert(LLVMTypeRef, const Type*, p->map, t, result); } unsigned size = LLVMCountStructElementTypes(t); @@ -70,11 +61,11 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) { LLVMGetStructElementTypes(t, elements); LARRAY(const Type*, celements, size); for (size_t i = 0; i < size; i++) { - celements[i] = convert_type(p, elements[i]); + celements[i] = l2s_convert_type(p, elements[i]); } const Node* product = record_type(a, (RecordType) { - .members = nodes(a, size, celements) + .members = shd_nodes(a, size, celements) }); if (decl) decl->payload.nom_type.body = product; @@ -84,15 +75,42 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) { } case LLVMArrayTypeKind: { unsigned length = LLVMGetArrayLength(t); - const Type* elem_t = convert_type(p, LLVMGetElementType(t)); - return arr_type(a, (ArrType) { .element_type = elem_t, .size = uint32_literal(a, length)}); + const Type* elem_t = l2s_convert_type(p, LLVMGetElementType(t)); + return arr_type(a, (ArrType) { .element_type = elem_t, .size = shd_uint32_literal(a, length)}); } case LLVMPointerTypeKind: { - AddressSpace as = 
convert_llvm_address_space(LLVMGetPointerAddressSpace(t)); + unsigned int llvm_as = LLVMGetPointerAddressSpace(t); + if (llvm_as >= 0x1000 && llvm_as <= 0x2000) { + unsigned offset = llvm_as - 0x1000; + unsigned dim = offset & 0xF; + unsigned type_id = (offset >> 4) & 0x3; + const Type* sampled_type = NULL; + switch (type_id) { + case 0x0: sampled_type = float_type(a, (Float) {.width = FloatTy32}); break; + case 0x1: sampled_type = shd_int32_type(a); break; + case 0x2: sampled_type = shd_uint32_type(a); break; + default: assert(false); + } + bool arrayed = (offset >> 6) & 1; + + return sampled_image_type(a, (SampledImageType) {.image_type = image_type(a, (ImageType) { + //.sampled_type = pack_type(a, (PackType) { .element_type = float_type(a, (Float) { .width = FloatTy32 }), .width = 4 }), + .sampled_type = sampled_type, + .dim = dim, + .depth = 0, + .arrayed = arrayed, + .ms = 0, + .sampled = 1, + .imageformat = 0 + })}); + } + AddressSpace as = l2s_convert_llvm_address_space(llvm_as); const Type* pointee = NULL; #if !UNTYPED_POINTERS LLVMTypeRef element_type = LLVMGetElementType(t); pointee = convert_type(p, element_type); +#else + pointee = unit_type(a); #endif return ptr_type(a, (PtrType) { .address_space = as, @@ -101,7 +119,7 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) { } case LLVMVectorTypeKind: { unsigned width = LLVMGetVectorSize(t); - const Type* elem_t = convert_type(p, LLVMGetElementType(t)); + const Type* elem_t = l2s_convert_type(p, LLVMGetElementType(t)); return pack_type(a, (PackType) { .element_type = elem_t, .width = (size_t) width }); } case LLVMMetadataTypeKind: @@ -119,7 +137,7 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) { break; } - error_print("Unsupported type: "); + shd_error_print("Unsupported type: "); LLVMDumpType(t); - error_die(); + shd_error_die(); } diff --git a/src/frontends/llvm/l2s_value.c b/src/frontend/llvm/l2s_value.c similarity index 63% rename from src/frontends/llvm/l2s_value.c rename to 
src/frontend/llvm/l2s_value.c index 5fa779f47..bdf228e7a 100644 --- a/src/frontends/llvm/l2s_value.c +++ b/src/frontend/llvm/l2s_value.c @@ -3,8 +3,6 @@ #include "portability.h" #include "log.h" #include "dict.h" -#include "../../shady/transform/ir_gen_helpers.h" -#include "../../shady/type.h" static const Node* data_composite(const Type* t, size_t size, LLVMValueRef v) { IrArena* a = t->arena; @@ -12,14 +10,14 @@ static const Node* data_composite(const Type* t, size_t size, LLVMValueRef v) { size_t idc; const char* raw_bytes = LLVMGetAsString(v, &idc); for (size_t i = 0; i < size; i++) { - const Type* et = get_fill_type_element_type(t); + const Type* et = shd_get_fill_type_element_type(t); switch (et->tag) { case Int_TAG: { switch (et->payload.int_type.width) { - case IntTy8: elements[i] = uint8_literal(a, ((uint8_t*) raw_bytes)[i]); break; - case IntTy16: elements[i] = uint16_literal(a, ((uint16_t*) raw_bytes)[i]); break; - case IntTy32: elements[i] = uint32_literal(a, ((uint32_t*) raw_bytes)[i]); break; - case IntTy64: elements[i] = uint64_literal(a, ((uint64_t*) raw_bytes)[i]); break; + case IntTy8: elements[i] = shd_uint8_literal(a, ((uint8_t*) raw_bytes)[i]); break; + case IntTy16: elements[i] = shd_uint16_literal(a, ((uint16_t*) raw_bytes)[i]); break; + case IntTy32: elements[i] = shd_uint32_literal(a, ((uint32_t*) raw_bytes)[i]); break; + case IntTy64: elements[i] = shd_uint64_literal(a, ((uint64_t*) raw_bytes)[i]); break; } break; } @@ -40,16 +38,16 @@ static const Node* data_composite(const Type* t, size_t size, LLVMValueRef v) { default: assert(false); } } - return composite_helper(a, t, nodes(a, size, elements)); + return composite_helper(a, t, shd_nodes(a, size, elements)); } -const Node* convert_value(Parser* p, LLVMValueRef v) { - const Type** found = find_value_dict(LLVMTypeRef, const Type*, p->map, v); +const Node* l2s_convert_value(Parser* p, LLVMValueRef v) { + const Type** found = shd_dict_find_value(LLVMTypeRef, const Type*, p->map, v); if 
(found) return *found; - IrArena* a = get_module_arena(p->dst); + IrArena* a = shd_module_get_arena(p->dst); const Node* r = NULL; - const Type* t = LLVMGetValueKind(v) != LLVMMetadataAsValueValueKind ? convert_type(p, LLVMTypeOf(v)) : NULL; + const Type* t = LLVMGetValueKind(v) != LLVMMetadataAsValueValueKind ? l2s_convert_type(p, LLVMTypeOf(v)) : NULL; switch (LLVMGetValueKind(v)) { case LLVMArgumentValueKind: @@ -63,35 +61,33 @@ const Node* convert_value(Parser* p, LLVMValueRef v) { case LLVMMemoryPhiValueKind: break; case LLVMFunctionValueKind: - r = convert_function(p, v); + r = l2s_convert_function(p, v); break; case LLVMGlobalAliasValueKind: break; case LLVMGlobalIFuncValueKind: break; case LLVMGlobalVariableValueKind: - r = convert_global(p, v); + r = l2s_convert_global(p, v); break; case LLVMBlockAddressValueKind: break; case LLVMConstantExprValueKind: { String name = LLVMGetValueName(v); if (!name || strlen(name) == 0) - name = unique_name(a, "constant_expr"); - Nodes annotations = singleton(annotation(a, (Annotation) { .name = "SkipOnInfer" })); - annotations = empty(a); - Node* decl = constant(p->dst, annotations, NULL, name); + name = shd_make_unique_name(a, "constant_expr"); + Nodes annotations = shd_singleton(annotation(a, (Annotation) { .name = "Inline" })); + assert(t); + Node* decl = constant(p->dst, annotations, t, name); r = ref_decl_helper(a, decl); - insert_dict(LLVMTypeRef, const Type*, p->map, v, r); - BodyBuilder* bb = begin_body(a); - EmittedInstr emitted = convert_instruction(p, NULL, bb, v); - Nodes types = singleton(convert_type(p, LLVMTypeOf(v))); - decl->payload.constant.instruction = bind_last_instruction_and_wrap_in_block_explicit_return_types(bb, emitted.instruction, &types); + shd_dict_insert(LLVMTypeRef, const Type*, p->map, v, r); + BodyBuilder* bb = shd_bld_begin_pure(a); + decl->payload.constant.value = shd_bld_to_instr_yield_value(bb, l2s_convert_instruction(p, NULL, NULL, bb, v)); return r; } case 
LLVMConstantDataArrayValueKind: { assert(t->tag == ArrType_TAG); - size_t arr_size = get_int_literal_value(*resolve_to_int_literal(t->payload.arr_type.size), false); + size_t arr_size = shd_get_int_literal_value(*shd_resolve_to_int_literal(t->payload.arr_type.size), false); assert(arr_size >= 0 && arr_size < INT32_MAX && "sanity check"); return data_composite(t, arr_size, v); } @@ -102,15 +98,16 @@ const Node* convert_value(Parser* p, LLVMValueRef v) { return data_composite(t, width, v); } case LLVMConstantStructValueKind: { - assert(t->tag == RecordType_TAG); - size_t size = t->payload.record_type.members.count; + const Node* actual_t = shd_get_maybe_nominal_type_body(t); + assert(actual_t->tag == RecordType_TAG); + size_t size = actual_t->payload.record_type.members.count; LARRAY(const Node*, elements, size); for (size_t i = 0; i < size; i++) { LLVMValueRef value = LLVMGetOperand(v, i); assert(value); - elements[i] = convert_value(p, value); + elements[i] = l2s_convert_value(p, value); } - return composite_helper(a, t, nodes(a, size, elements)); + return composite_helper(a, t, shd_nodes(a, size, elements)); } case LLVMConstantVectorValueKind: { assert(t->tag == PackType_TAG); @@ -119,30 +116,25 @@ const Node* convert_value(Parser* p, LLVMValueRef v) { for (size_t i = 0; i < size; i++) { LLVMValueRef value = LLVMGetOperand(v, i); assert(value); - elements[i] = convert_value(p, value); + elements[i] = l2s_convert_value(p, value); } - return composite_helper(a, t, nodes(a, size, elements)); + return composite_helper(a, t, shd_nodes(a, size, elements)); } case LLVMUndefValueValueKind: - return undef(a, (Undef) { .type = convert_type(p, LLVMTypeOf(v)) }); + return undef(a, (Undef) { .type = l2s_convert_type(p, LLVMTypeOf(v)) }); case LLVMConstantAggregateZeroValueKind: - return get_default_zero_value(a, convert_type(p, LLVMTypeOf(v))); + return shd_get_default_value(a, l2s_convert_type(p, LLVMTypeOf(v))); case LLVMConstantArrayValueKind: { assert(t->tag == 
ArrType_TAG); - if (LLVMIsConstantString(v)) { - size_t idc; - r = string_lit_helper(a, LLVMGetAsString(v, &idc)); - break; - } - size_t arr_size = get_int_literal_value(*resolve_to_int_literal(t->payload.arr_type.size), false); + size_t arr_size = shd_get_int_literal_value(*shd_resolve_to_int_literal(t->payload.arr_type.size), false); assert(arr_size >= 0 && arr_size < INT32_MAX && "sanity check"); LARRAY(const Node*, elements, arr_size); for (size_t i = 0; i < arr_size; i++) { LLVMValueRef value = LLVMGetOperand(v, i); assert(value); - elements[i] = convert_value(p, value); + elements[i] = l2s_convert_value(p, value); } - return composite_helper(a, t, nodes(a, arr_size, elements)); + return composite_helper(a, t, shd_nodes(a, arr_size, elements)); } case LLVMConstantIntValueKind: { if (t->tag == Bool_TAG) { @@ -152,10 +144,10 @@ const Node* convert_value(Parser* p, LLVMValueRef v) { assert(t->tag == Int_TAG); unsigned long long value = LLVMConstIntGetZExtValue(v); switch (t->payload.int_type.width) { - case IntTy8: return uint8_literal(a, value); - case IntTy16: return uint16_literal(a, value); - case IntTy32: return uint32_literal(a, value); - case IntTy64: return uint64_literal(a, value); + case IntTy8: return shd_uint8_literal(a, value); + case IntTy16: return shd_uint16_literal(a, value); + case IntTy32: return shd_uint32_literal(a, value); + case IntTy64: return shd_uint64_literal(a, value); } } case LLVMConstantFPValueKind: { @@ -165,7 +157,7 @@ const Node* convert_value(Parser* p, LLVMValueRef v) { uint64_t u = 0; static_assert(sizeof(u) == sizeof(d), ""); switch (t->payload.float_type.width) { - case FloatTy16: error("todo") + case FloatTy16: shd_error("todo") case FloatTy32: { float f = (float) d; static_assert(sizeof(f) == sizeof(uint32_t), ""); @@ -185,23 +177,23 @@ const Node* convert_value(Parser* p, LLVMValueRef v) { break; case LLVMMetadataAsValueValueKind: { LLVMMetadataRef meta = LLVMValueAsMetadata(v); - r = convert_metadata(p, meta); + r = 
l2s_convert_metadata(p, meta); } case LLVMInlineAsmValueKind: break; case LLVMInstructionValueKind: break; case LLVMPoisonValueValueKind: - return undef(a, (Undef) { .type = convert_type(p, LLVMTypeOf(v)) }); + return undef(a, (Undef) { .type = l2s_convert_type(p, LLVMTypeOf(v)) }); } if (r) { - insert_dict(LLVMTypeRef, const Type*, p->map, v, r); + shd_dict_insert(LLVMTypeRef, const Type*, p->map, v, r); return r; } - error_print("Failed to find value "); + shd_error_print("Failed to find value "); LLVMDumpValue(v); - error_print(" in the already emitted map (kind=%d)\n", LLVMGetValueKind(v)); - error_die(); + shd_error_print(" in the already emitted map (kind=%d)\n", LLVMGetValueKind(v)); + shd_error_die(); } diff --git a/src/frontend/slim/CMakeLists.txt b/src/frontend/slim/CMakeLists.txt new file mode 100644 index 000000000..1ca8a23a1 --- /dev/null +++ b/src/frontend/slim/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library(slim_parser STATIC slim_driver.c parser.c token.c bind.c normalize.c infer.c) +target_link_libraries(slim_parser PUBLIC common api) +target_link_libraries(slim_parser PRIVATE shady) +target_include_directories(slim_parser PUBLIC $) +target_link_libraries(shady PRIVATE "$") + +generate_extinst_headers(SlimFrontendOps extinst.spv-shady-slim-frontend.grammar.json) +target_link_libraries(slim_parser PRIVATE SlimFrontendOps) diff --git a/src/frontend/slim/bind.c b/src/frontend/slim/bind.c new file mode 100644 index 000000000..c95557656 --- /dev/null +++ b/src/frontend/slim/bind.c @@ -0,0 +1,389 @@ +#include "SlimFrontendOps.h" + +#include "shady/pass.h" +#include "shady/fe/slim.h" + +#include "../shady/ir_private.h" +#include "../shady/analysis/uses.h" + +#include "list.h" +#include "log.h" +#include "portability.h" + +#include +#include + +typedef struct NamedBindEntry_ NamedBindEntry; +struct NamedBindEntry_ { + const char* name; + bool is_var; + Node* node; + NamedBindEntry* next; +}; + +typedef struct { + Rewriter rewriter; + const UsesMap* uses; + + 
const Node* current_function; + NamedBindEntry* local_variables; +} Context; + +typedef struct { + bool is_var; + const Node* node; +} Resolved; + +static Resolved resolve_using_name(Context* ctx, const char* name) { + for (NamedBindEntry* entry = ctx->local_variables; entry != NULL; entry = entry->next) { + if (strcmp(entry->name, name) == 0) { + return (Resolved) { + .is_var = entry->is_var, + .node = entry->node + }; + } + } + + Nodes new_decls = shd_module_get_declarations(ctx->rewriter.dst_module); + for (size_t i = 0; i < new_decls.count; i++) { + const Node* decl = new_decls.nodes[i]; + if (strcmp(get_declaration_name(decl), name) == 0) { + return (Resolved) { + .is_var = decl->tag == GlobalVariable_TAG, + .node = decl + }; + } + } + + Nodes old_decls = shd_module_get_declarations(ctx->rewriter.src_module); + for (size_t i = 0; i < old_decls.count; i++) { + const Node* old_decl = old_decls.nodes[i]; + if (strcmp(get_declaration_name(old_decl), name) == 0) { + Context top_ctx = *ctx; + top_ctx.current_function = NULL; + top_ctx.local_variables = NULL; + const Node* decl = shd_rewrite_node(&top_ctx.rewriter, old_decl); + return (Resolved) { + .is_var = decl->tag == GlobalVariable_TAG, + .node = decl + }; + } + } + + shd_error("could not resolve node %s", name) +} + +static void add_binding(Context* ctx, bool is_var, String name, const Node* node) { + assert(name); + NamedBindEntry* entry = shd_arena_alloc(ctx->rewriter.dst_arena->arena, sizeof(NamedBindEntry)); + *entry = (NamedBindEntry) { + .name = shd_string(ctx->rewriter.dst_arena, name), + .is_var = is_var, + .node = (Node*) node, + .next = NULL + }; + entry->next = ctx->local_variables; + ctx->local_variables = entry; +} + +static const Node* get_node_address(Context* ctx, const Node* node); + +static const Node* get_node_address_maybe(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + switch (node->tag) { + case ExtInstr_TAG: { + ExtInstr payload = 
node->payload.ext_instr; + if (strcmp(payload.set, "shady.frontend") == 0) { + if (payload.opcode == SlimFrontendOpsSlimSubscriptSHADY) { + assert(payload.operands.count == 2); + const Node* src_ptr = get_node_address_maybe(ctx, shd_first(payload.operands)); + if (src_ptr == NULL) + return NULL; + const Node* index = shd_rewrite_node(&ctx->rewriter, payload.operands.nodes[1]); + return mem_and_value(a, (MemAndValue) { + .mem = shd_rewrite_node(r, payload.mem), + .value = ptr_composite_element(a, (PtrCompositeElement) { .ptr = src_ptr, .index = index }), + }); + } else if (payload.opcode == SlimFrontendOpsSlimDereferenceSHADY) { + assert(payload.operands.count == 1); + return mem_and_value(a, (MemAndValue) { + .mem = shd_rewrite_node(r, payload.mem), + .value = shd_rewrite_node(&ctx->rewriter, shd_first(payload.operands)), + }); + } else if (payload.opcode == SlimFrontendOpsSlimUnboundSHADY) { + if (payload.mem) + shd_rewrite_node(&ctx->rewriter, payload.mem); + Resolved entry = resolve_using_name(ctx, shd_get_string_literal(a, shd_first(payload.operands))); + // can't take the address if it's not a var! + if (!entry.is_var) + return NULL; + return entry.node; + } + } + break; + } + default: break; + } + return NULL; +} + +static const Node* get_node_address(Context* ctx, const Node* node) { + const Node* got = get_node_address_maybe(ctx, node); + if (!got) shd_error("This doesn't really look like a place expression...") + return got; +} + +static const Node* desugar_bind_identifiers(Context* ctx, ExtInstr instr) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + BodyBuilder* bb = instr.mem ? 
shd_bld_begin(a, shd_rewrite_node(r, instr.mem)) : shd_bld_begin_pure(a); + + switch (instr.opcode) { + case SlimFrontendOpsSlimBindValSHADY: { + size_t names_count = instr.operands.count - 1; + const Node** names = &instr.operands.nodes[1]; + const Node* value = shd_rewrite_node(r, shd_first(instr.operands)); + Nodes results = shd_deconstruct_composite(a, value, names_count); + for (size_t i = 0; i < names_count; i++) { + String name = shd_get_string_literal(a, names[i]); + shd_log_fmt(DEBUGV, "Bound immutable variable '%s'\n", name); + add_binding(ctx, false, name, results.nodes[i]); + } + break; + } + case SlimFrontendOpsSlimBindVarSHADY: { + size_t names_count = (instr.operands.count - 1) / 2; + const Node** names = &instr.operands.nodes[1]; + const Node** types = &instr.operands.nodes[1 + names_count]; + const Node* value = shd_rewrite_node(r, shd_first(instr.operands)); + Nodes results = shd_deconstruct_composite(a, value, names_count); + for (size_t i = 0; i < names_count; i++) { + String name = shd_get_string_literal(a, names[i]); + const Type* type_annotation = types[i]; + assert(type_annotation); + const Node* alloca = stack_alloc(a, (StackAlloc) { .type = shd_rewrite_node(&ctx->rewriter, type_annotation), .mem = shd_bb_mem(bb) }); + const Node* ptr = shd_bld_add_instruction_extract_count(bb, alloca, 1).nodes[0]; + shd_set_value_name(ptr, name); + shd_bld_add_instruction_extract_count(bb, store(a, (Store) { .ptr = ptr, .value = results.nodes[0], .mem = shd_bb_mem(bb) }), 0); + + add_binding(ctx, true, name, ptr); + shd_log_fmt(DEBUGV, "Bound mutable variable '%s'\n", name); + } + break; + } + case SlimFrontendOpsSlimBindContinuationsSHADY: { + size_t names_count = (instr.operands.count ) / 2; + const Node** names = &instr.operands.nodes[0]; + const Node** conts = &instr.operands.nodes[0 + names_count]; + LARRAY(Node*, bbs, names_count); + for (size_t i = 0; i < names_count; i++) { + String name = shd_get_string_literal(a, names[i]); + Nodes nparams = 
shd_recreate_params(r, get_abstraction_params(conts[i])); + bbs[i] = basic_block(a, nparams, shd_get_abstraction_name_unsafe(conts[i])); + shd_register_processed(r, conts[i], bbs[i]); + add_binding(ctx, false, name, bbs[i]); + shd_log_fmt(DEBUGV, "Bound continuation '%s'\n", name); + } + for (size_t i = 0; i < names_count; i++) { + Context cont_ctx = *ctx; + Nodes bb_params = get_abstraction_params(bbs[i]); + for (size_t j = 0; j < bb_params.count; j++) { + const Node* bb_param = bb_params.nodes[j]; + assert(bb_param->tag == Param_TAG); + String param_name = bb_param->payload.param.name; + if (param_name) + add_binding(&cont_ctx, false, param_name, bb_param); + } + shd_set_abstraction_body(bbs[i], shd_rewrite_node(&cont_ctx.rewriter, get_abstraction_body(conts[i]))); + } + } + } + + return shd_bld_to_instr_yield_values(bb, shd_empty(a)); +} + +static const Node* rewrite_decl(Context* ctx, const Node* decl) { + assert(is_declaration(decl)); + switch (decl->tag) { + case GlobalVariable_TAG: { + const GlobalVariable* ogvar = &decl->payload.global_variable; + Node* bound = global_var(ctx->rewriter.dst_module, shd_rewrite_nodes(&ctx->rewriter, ogvar->annotations), shd_rewrite_node(&ctx->rewriter, ogvar->type), ogvar->name, ogvar->address_space); + shd_register_processed(&ctx->rewriter, decl, bound); + bound->payload.global_variable.init = shd_rewrite_node(&ctx->rewriter, decl->payload.global_variable.init); + return bound; + } + case Constant_TAG: { + const Constant* cnst = &decl->payload.constant; + Node* bound = constant(ctx->rewriter.dst_module, shd_rewrite_nodes(&ctx->rewriter, cnst->annotations), shd_rewrite_node(&ctx->rewriter, decl->payload.constant.type_hint), cnst->name); + shd_register_processed(&ctx->rewriter, decl, bound); + bound->payload.constant.value = shd_rewrite_node(&ctx->rewriter, decl->payload.constant.value); + return bound; + } + case Function_TAG: { + Nodes new_fn_params = shd_recreate_params(&ctx->rewriter, decl->payload.fun.params); + Node* 
bound = function(ctx->rewriter.dst_module, new_fn_params, decl->payload.fun.name, shd_rewrite_nodes(&ctx->rewriter, decl->payload.fun.annotations), shd_rewrite_nodes(&ctx->rewriter, decl->payload.fun.return_types)); + shd_register_processed(&ctx->rewriter, decl, bound); + Context fn_ctx = *ctx; + for (size_t i = 0; i < new_fn_params.count; i++) { + add_binding(&fn_ctx, false, decl->payload.fun.params.nodes[i]->payload.param.name, new_fn_params.nodes[i]); + } + shd_register_processed_list(&ctx->rewriter, decl->payload.fun.params, new_fn_params); + + if (decl->payload.fun.body) { + fn_ctx.current_function = bound; + shd_set_abstraction_body(bound, shd_rewrite_node(&fn_ctx.rewriter, decl->payload.fun.body)); + } + return bound; + } + case NominalType_TAG: { + Node* bound = nominal_type(ctx->rewriter.dst_module, shd_rewrite_nodes(&ctx->rewriter, decl->payload.nom_type.annotations), decl->payload.nom_type.name); + shd_register_processed(&ctx->rewriter, decl, bound); + bound->payload.nom_type.body = shd_rewrite_node(&ctx->rewriter, decl->payload.nom_type.body); + return bound; + } + default: shd_error("unknown declaration kind"); + } + + shd_error("unreachable") + //register_processed(&ctx->rewriter, decl, bound); + //return bound; +} + +static bool is_used_as_value(Context* ctx, const Node* node) { + const Use* use = shd_get_first_use(ctx->uses, node); + for (;use;use = use->next_use) { + if (use->operand_class != NcMem) { + if (use->user->tag == ExtInstr_TAG && strcmp(use->user->payload.ext_instr.set, "shady.frontend") == 0) { + if (use->user->payload.ext_instr.opcode == SlimFrontendOpsSlimAssignSHADY && use->operand_index == 0) + continue; + if (use->user->payload.ext_instr.opcode == SlimFrontendOpsSlimSubscriptSHADY && use->operand_index == 0) { + const Node* ptr = get_node_address_maybe(ctx, node); + if (ptr) + continue; + } + } + return true; + } + } + return false; +} + +static const Node* bind_node(Context* ctx, const Node* node) { + IrArena* a = 
ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + + // in case the node is an l-value, we load it + // const Node* lhs = get_node_address_safe(ctx, node); + // if (lhs) { + // return load(a, (Load) { lhs, .mem = rewrite_node() }); + // } + + switch (node->tag) { + case Function_TAG: + case Constant_TAG: + case GlobalVariable_TAG: + case NominalType_TAG: { + assert(is_declaration(node)); + return rewrite_decl(ctx, node); + } + case Param_TAG: shd_error("the binders should be handled such that this node is never reached"); + case BasicBlock_TAG: { + assert(is_basic_block(node)); + Nodes new_params = shd_recreate_params(&ctx->rewriter, node->payload.basic_block.params); + String name = node->payload.basic_block.name; + Node* new_bb = basic_block(a, new_params, name); + Context bb_ctx = *ctx; + ctx = &bb_ctx; + if (name) + add_binding(ctx, false, name, new_bb); + for (size_t i = 0; i < new_params.count; i++) { + String param_name = new_params.nodes[i]->payload.param.name; + if (param_name) + add_binding(ctx, false, param_name, new_params.nodes[i]); + } + shd_register_processed(&ctx->rewriter, node, new_bb); + shd_register_processed_list(&ctx->rewriter, node->payload.basic_block.params, new_params); + shd_set_abstraction_body(new_bb, shd_rewrite_node(&ctx->rewriter, node->payload.basic_block.body)); + return new_bb; + } + case ExtInstr_TAG: { + ExtInstr payload = node->payload.ext_instr; + if (strcmp("shady.frontend", payload.set) == 0) { + switch ((enum SlimFrontendOpsInstructions) payload.opcode) { + case SlimFrontendOpsSlimDereferenceSHADY: + if (!is_used_as_value(ctx, node)) + return shd_rewrite_node(r, payload.mem); + return load(a, (Load) { + .ptr = shd_rewrite_node(r, shd_first(payload.operands)), + .mem = shd_rewrite_node(r, payload.mem), + }); + case SlimFrontendOpsSlimAssignSHADY: { + const Node* target_ptr = get_node_address(ctx, payload.operands.nodes[0]); + assert(target_ptr); + const Node* value = shd_rewrite_node(r, payload.operands.nodes[1]); + 
return store(a, (Store) { .ptr = target_ptr, .value = value, .mem = shd_rewrite_node(r, payload.mem) }); + } + case SlimFrontendOpsSlimAddrOfSHADY: { + const Node* target_ptr = get_node_address(ctx, payload.operands.nodes[0]); + return mem_and_value(a, (MemAndValue) { .value = target_ptr, .mem = shd_rewrite_node(r, payload.mem) }); + } + case SlimFrontendOpsSlimSubscriptSHADY: { + const Node* ptr = get_node_address_maybe(ctx, node); + if (ptr) + return load(a, (Load) { + .ptr = ptr, + .mem = shd_rewrite_node(r, payload.mem) + }); + return mem_and_value(a, (MemAndValue) { + .value = prim_op(a, (PrimOp) { + .op = extract_op, + .operands = mk_nodes(a, shd_rewrite_node(r, payload.operands.nodes[0]), shd_rewrite_node(r, payload.operands.nodes[1])) + }), + .mem = shd_rewrite_node(r, payload.mem) } + ); + } + case SlimFrontendOpsSlimUnboundSHADY: { + const Node* mem = NULL; + if (payload.mem) { + if (!is_used_as_value(ctx, node)) + return shd_rewrite_node(r, payload.mem); + mem = shd_rewrite_node(r, payload.mem); + } + Resolved entry = resolve_using_name(ctx, shd_get_string_literal(a, shd_first(payload.operands))); + if (entry.is_var) { + return load(a, (Load) { .ptr = entry.node, .mem = mem }); + } else if (mem) { + return mem_and_value(a, (MemAndValue) { .value = entry.node, .mem = mem }); + } + return entry.node; + } + default: return desugar_bind_identifiers(ctx, payload); + } + } + break; + } + default: break; + } + return shd_recreate_node(&ctx->rewriter, node); +} + +Module* slim_pass_bind(SHADY_UNUSED const CompilerConfig* compiler_config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + assert(!src->arena->config.name_bound); + aconfig.name_bound = true; + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) bind_node), + .local_variables = NULL, + .current_function = NULL, + .uses = 
shd_new_uses_map_module(src, 0), + }; + + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_uses_map(ctx.uses); + return dst; +} diff --git a/src/frontend/slim/extinst.spv-shady-slim-frontend.grammar.json b/src/frontend/slim/extinst.spv-shady-slim-frontend.grammar.json new file mode 100644 index 000000000..bca5b1323 --- /dev/null +++ b/src/frontend/slim/extinst.spv-shady-slim-frontend.grammar.json @@ -0,0 +1,67 @@ +{ + "revision" : 1, + "instructions" : [ + { + "opname" : "SlimDereferenceSHADY", + "opcode" : 1, + "operands" : [ + { "kind" : "IdRef", "name" : "'Target'" } + ] + }, + { + "opname" : "SlimAssignSHADY", + "opcode" : 2, + "operands" : [ + { "kind" : "IdRef", "name" : "'Target'" }, + { "kind" : "IdRef", "name" : "'Value'" } + ] + }, + { + "opname" : "SlimAddrOfSHADY", + "opcode" : 3, + "operands" : [ + { "kind" : "IdRef", "name" : "'Target'" } + ] + }, + { + "opname" : "SlimSubscriptSHADY", + "opcode" : 4, + "operands" : [ + { "kind" : "IdRef", "name" : "'Value'" }, + { "kind" : "IdRef", "name" : "'Index'" } + ] + }, + { + "opname" : "SlimBindValSHADY", + "opcode" : 5, + "operands" : [ + { "kind" : "IdRef", "name" : "'Value'" }, + { "kind" : "IdRef", "name" : "'Names'", "quantifier" : "*" } + ] + }, + { + "opname" : "SlimBindVarSHADY", + "opcode" : 6, + "operands" : [ + { "kind" : "IdRef", "name" : "'Value'" }, + { "kind" : "IdRef", "name" : "'Names'", "quantifier" : "*" }, + { "kind" : "IdRef", "name" : "'Types'", "quantifier" : "*" } + ] + }, + { + "opname" : "SlimBindContinuationsSHADY", + "opcode" : 7, + "operands" : [ + { "kind" : "IdRef", "name" : "'Name'" }, + { "kind" : "IdRef", "name" : "'Continuations'", "quantifier" : "*" } + ] + }, + { + "opname" : "SlimUnboundSHADY", + "opcode" : 8, + "operands" : [ + { "kind" : "IdRef", "name" : "'Identifier'" } + ] + } + ] +} diff --git a/src/frontend/slim/infer.c b/src/frontend/slim/infer.c new file mode 100644 index 000000000..5f22073a7 --- /dev/null +++ 
b/src/frontend/slim/infer.c
@@ -0,0 +1,622 @@
+#include "shady/pass.h"
+
+#include "../shady/check.h"
+
+#include "log.h"
+#include "portability.h"
+
+#include <string.h>
+#include <stdlib.h>
+
+#pragma GCC diagnostic error "-Wswitch"
+
+static Nodes annotate_all_types(IrArena* a, Nodes types, bool uniform_by_default) {
+    LARRAY(const Type*, ntypes, types.count);
+    for (size_t i = 0; i < types.count; i++) {
+        if (shd_is_data_type(types.nodes[i]))
+            ntypes[i] = qualified_type(a, (QualifiedType) {
+                .type = types.nodes[i],
+                .is_uniform = uniform_by_default,
+            });
+        else
+            ntypes[i] = types.nodes[i];
+    }
+    return shd_nodes(a, types.count, ntypes);
+}
+
+typedef struct {
+    Rewriter rewriter;
+
+    const Node* current_fn;
+    const Type* expected_type;
+} Context;
+
+static const Node* infer_value(Context* ctx, const Node* node, const Type* expected_type);
+static const Node* infer_instruction(Context* ctx, const Node* node, const Node* expected_type);
+
+static const Node* infer(Context* ctx, const Node* node, const Type* expect) {
+    Context ctx2 = *ctx;
+    ctx2.expected_type = expect;
+    return shd_rewrite_node(&ctx2.rewriter, node);
+}
+
+static Nodes infer_nodes(Context* ctx, Nodes nodes) {
+    Context ctx2 = *ctx;
+    ctx2.expected_type = NULL;
+    // NOTE(review): was `&ctx->rewriter`, leaving the freshly-cleared ctx2 unused;
+    // the intent (mirroring infer() above) is to rewrite with the cleared context.
+    return shd_rewrite_nodes(&ctx2.rewriter, nodes);
+}
+
+#define rewrite_node shd_error("don't use this directly, use the 'infer' and 'infer_node' helpers")
+#define rewrite_nodes rewrite_node
+
+static const Node* infer_annotation(Context* ctx, const Node* node) {
+    IrArena* a = ctx->rewriter.dst_arena;
+    assert(is_annotation(node));
+    switch (node->tag) {
+        case Annotation_TAG: return annotation(a, (Annotation) { .name = node->payload.annotation.name });
+        case AnnotationValue_TAG: return annotation_value(a, (AnnotationValue) { .name = node->payload.annotation_value.name, .value = infer(ctx, node->payload.annotation_value.value, NULL) });
+        case AnnotationValues_TAG: return annotation_values(a, (AnnotationValues) { .name =
node->payload.annotation_values.name, .values = infer_nodes(ctx, node->payload.annotation_values.values) }); + case AnnotationCompound_TAG: return annotation_compound(a, (AnnotationCompound) { .name = node->payload.annotation_compound.name, .entries = infer_nodes(ctx, node->payload.annotation_compound.entries) }); + default: shd_error("Not an annotation"); + } +} + +static const Node* infer_type(Context* ctx, const Type* type) { + IrArena* a = ctx->rewriter.dst_arena; + switch (type->tag) { + case ArrType_TAG: { + const Node* size = infer(ctx, type->payload.arr_type.size, NULL); + return arr_type(a, (ArrType) { + .size = size, + .element_type = infer(ctx, type->payload.arr_type.element_type, NULL) + }); + } + case PtrType_TAG: { + const Node* element_type = infer(ctx, type->payload.ptr_type.pointed_type, NULL); + assert(element_type); + //if (!element_type) + // element_type = unit_type(a); + return ptr_type(a, (PtrType) { .pointed_type = element_type, .address_space = type->payload.ptr_type.address_space }); + } + default: return shd_recreate_node(&ctx->rewriter, type); + } +} + +static const Node* infer_decl(Context* ctx, const Node* node) { + assert(is_declaration(node)); + if (shd_lookup_annotation(node, "SkipOnInfer")) + return NULL; + + IrArena* a = ctx->rewriter.dst_arena; + switch (is_declaration(node)) { + case Function_TAG: { + Context body_context = *ctx; + + LARRAY(const Node*, nparams, node->payload.fun.params.count); + for (size_t i = 0; i < node->payload.fun.params.count; i++) { + const Param* old_param = &node->payload.fun.params.nodes[i]->payload.param; + const Type* imported_param_type = infer(ctx, old_param->type, NULL); + nparams[i] = param(a, imported_param_type, old_param->name); + shd_register_processed(&body_context.rewriter, node->payload.fun.params.nodes[i], nparams[i]); + } + + Nodes nret_types = annotate_all_types(a, infer_nodes(ctx, node->payload.fun.return_types), false); + Node* fun = function(ctx->rewriter.dst_module, shd_nodes(a, 
node->payload.fun.params.count, nparams), shd_string(a, node->payload.fun.name), infer_nodes(ctx, node->payload.fun.annotations), nret_types); + shd_register_processed(&ctx->rewriter, node, fun); + body_context.current_fn = fun; + shd_set_abstraction_body(fun, infer(&body_context, node->payload.fun.body, NULL)); + return fun; + } + case Constant_TAG: { + const Constant* oconstant = &node->payload.constant; + const Type* imported_hint = infer(ctx, oconstant->type_hint, NULL); + const Node* instruction = NULL; + if (imported_hint) { + assert(shd_is_data_type(imported_hint)); + const Node* s = shd_as_qualified_type(imported_hint, true); + if (oconstant->value) + instruction = infer(ctx, oconstant->value, s); + } else if (oconstant->value) { + instruction = infer(ctx, oconstant->value, NULL); + } + if (instruction) + imported_hint = shd_get_unqualified_type(instruction->type); + assert(imported_hint); + + Node* nconstant = constant(ctx->rewriter.dst_module, infer_nodes(ctx, oconstant->annotations), imported_hint, oconstant->name); + shd_register_processed(&ctx->rewriter, node, nconstant); + nconstant->payload.constant.value = instruction; + + return nconstant; + } + case GlobalVariable_TAG: { + const GlobalVariable* old_gvar = &node->payload.global_variable; + const Type* imported_ty = infer(ctx, old_gvar->type, NULL); + Node* ngvar = global_var(ctx->rewriter.dst_module, infer_nodes(ctx, old_gvar->annotations), imported_ty, old_gvar->name, old_gvar->address_space); + shd_register_processed(&ctx->rewriter, node, ngvar); + + ngvar->payload.global_variable.init = infer(ctx, old_gvar->init, shd_as_qualified_type(imported_ty, true)); + return ngvar; + } + case NominalType_TAG: { + const NominalType* onom_type = &node->payload.nom_type; + Node* nnominal_type = nominal_type(ctx->rewriter.dst_module, infer_nodes(ctx, onom_type->annotations), onom_type->name); + shd_register_processed(&ctx->rewriter, node, nnominal_type); + nnominal_type->payload.nom_type.body = infer(ctx, 
onom_type->body, NULL); + return nnominal_type; + } + case NotADeclaration: shd_error("not a decl"); + } +} + +/// Like get_unqualified_type but won't error out if type wasn't qualified to begin with +static const Type* remove_uniformity_qualifier(const Node* type) { + if (shd_is_value_type(type)) + return shd_get_unqualified_type(type); + return type; +} + +static const Node* infer_value(Context* ctx, const Node* node, const Type* expected_type) { + if (!node) return NULL; + + IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + switch (is_value(node)) { + case NotAValue: shd_error(""); + case Param_TAG: + case Value_ConstrainedValue_TAG: { + const Type* type = infer(ctx, node->payload.constrained.type, NULL); + bool expect_uniform = false; + if (expected_type) { + expect_uniform = shd_deconstruct_qualified_type(&expected_type); + assert(shd_is_subtype(expected_type, type)); + } + return infer(ctx, node->payload.constrained.value, shd_as_qualified_type(type, expect_uniform)); + } + case IntLiteral_TAG: { + if (expected_type) { + expected_type = remove_uniformity_qualifier(expected_type); + assert(expected_type->tag == Int_TAG); + assert(expected_type->payload.int_type.width == node->payload.int_literal.width); + } + return int_literal(a, (IntLiteral) { + .width = node->payload.int_literal.width, + .is_signed = node->payload.int_literal.is_signed, + .value = node->payload.int_literal.value}); + } + case UntypedNumber_TAG: { + char* endptr; + int64_t i = strtoll(node->payload.untyped_number.plaintext, &endptr, 10); + if (!expected_type) { + bool valid_int = *endptr == '\0'; + expected_type = valid_int ? shd_int32_type(a) : shd_fp32_type(a); + } + expected_type = remove_uniformity_qualifier(expected_type); + if (expected_type->tag == Int_TAG) { + // TODO chop off extra bits based on width ? 
+                return int_literal(a, (IntLiteral) {
+                    .width = expected_type->payload.int_type.width,
+                    // NOTE(review): was `payload.int_literal.is_signed` — expected_type is a
+                    // *type* node here (Int_TAG), so the field lives in the int_type payload.
+                    .is_signed = expected_type->payload.int_type.is_signed,
+                    .value = i
+                });
+            } else if (expected_type->tag == Float_TAG) {
+                uint64_t v;
+                switch (expected_type->payload.float_type.width) {
+                    case FloatTy16:
+                        shd_error("TODO: implement fp16 parsing");
+                    case FloatTy32:
+                        assert(sizeof(float) == sizeof(uint32_t));
+                        float f = strtof(node->payload.untyped_number.plaintext, NULL);
+                        memcpy(&v, &f, sizeof(uint32_t));
+                        break;
+                    case FloatTy64:
+                        assert(sizeof(double) == sizeof(uint64_t));
+                        double d = strtod(node->payload.untyped_number.plaintext, NULL);
+                        memcpy(&v, &d, sizeof(uint64_t));
+                        break;
+                }
+                return float_literal(a, (FloatLiteral) {.value = v, .width = expected_type->payload.float_type.width});
+            }
+        }
+        case FloatLiteral_TAG: {
+            if (expected_type) {
+                expected_type = remove_uniformity_qualifier(expected_type);
+                assert(expected_type->tag == Float_TAG);
+                assert(expected_type->payload.float_type.width == node->payload.float_literal.width);
+            }
+            return float_literal(a, (FloatLiteral) { .width = node->payload.float_literal.width, .value = node->payload.float_literal.value });
+        }
+        case True_TAG: return true_lit(a);
+        case False_TAG: return false_lit(a);
+        case StringLiteral_TAG: return string_lit(a, (StringLiteral) { .string = shd_string(a, node->payload.string_lit.string )});
+        case RefDecl_TAG: break;
+        case FnAddr_TAG: break;
+        case Value_Undef_TAG: break;
+        case Value_Composite_TAG: {
+            const Node* elem_type = infer(ctx, node->payload.composite.type, NULL);
+            bool uniform = false;
+            if (elem_type && expected_type) {
+                assert(shd_is_subtype(shd_get_unqualified_type(expected_type), elem_type));
+            } else if (expected_type) {
+                uniform = shd_deconstruct_qualified_type(&elem_type);
+                elem_type = expected_type;
+            }
+
+            Nodes omembers = node->payload.composite.contents;
+            LARRAY(const Node*, inferred, omembers.count);
+            if (elem_type) {
+                Nodes expected_members =
shd_get_composite_type_element_types(elem_type); + for (size_t i = 0; i < omembers.count; i++) + inferred[i] = infer(ctx, omembers.nodes[i], qualified_type(a, (QualifiedType) { .is_uniform = uniform, .type = expected_members.nodes[i] })); + } else { + for (size_t i = 0; i < omembers.count; i++) + inferred[i] = infer(ctx, omembers.nodes[i], NULL); + } + Nodes nmembers = shd_nodes(a, omembers.count, inferred); + + // Composites are tuples by default + if (!elem_type) + elem_type = record_type(a, (RecordType) { .members = shd_strip_qualifiers(a, shd_get_values_types(a, nmembers)) }); + + return composite_helper(a, elem_type, nmembers); + } + case Value_Fill_TAG: { + const Node* composite_t = infer(ctx, node->payload.fill.type, NULL); + assert(composite_t); + bool uniform = false; + if (composite_t && expected_type) { + assert(shd_is_subtype(shd_get_unqualified_type(expected_type), composite_t)); + } else if (expected_type) { + uniform = shd_deconstruct_qualified_type(&composite_t); + composite_t = expected_type; + } + assert(composite_t); + const Node* element_t = shd_get_fill_type_element_type(composite_t); + const Node* value = infer(ctx, node->payload.fill.value, qualified_type(a, (QualifiedType) { .is_uniform = uniform, .type = element_t })); + return fill(a, (Fill) { .type = composite_t, .value = value }); + } + default: break; + } + return shd_recreate_node(&ctx->rewriter, node); +} + +static const Node* infer_case(Context* ctx, const Node* node, Nodes inferred_arg_type) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + assert(inferred_arg_type.count == node->payload.basic_block.params.count || node->payload.basic_block.params.count == 0); + + Context body_context = *ctx; + LARRAY(const Node*, nparams, inferred_arg_type.count); + for (size_t i = 0; i < inferred_arg_type.count; i++) { + if (node->payload.basic_block.params.count == 0) { + // syntax sugar: make up a parameter if there was none + nparams[i] = param(a, inferred_arg_type.nodes[i], 
shd_make_unique_name(a, "_")); + } else { + const Param* old_param = &node->payload.basic_block.params.nodes[i]->payload.param; + // for the param type: use the inferred one if none is already provided + // if one is provided, check the inferred argument type is a subtype of the param type + const Type* param_type = old_param->type ? infer_type(ctx, old_param->type) : NULL; + // and do not use the provided param type if it is an untyped ptr + if (!param_type || param_type->tag != PtrType_TAG || param_type->payload.ptr_type.pointed_type) + param_type = inferred_arg_type.nodes[i]; + assert(shd_is_subtype(param_type, inferred_arg_type.nodes[i])); + nparams[i] = param(a, param_type, old_param->name); + shd_register_processed(&body_context.rewriter, node->payload.basic_block.params.nodes[i], nparams[i]); + } + } + + Node* new_case = basic_block(a, shd_nodes(a, inferred_arg_type.count, nparams), shd_get_abstraction_name_unsafe(node)); + shd_register_processed(r, node, new_case); + shd_set_abstraction_body(new_case, infer(&body_context, node->payload.basic_block.body, NULL)); + return new_case; +} + +static const Node* _infer_basic_block(Context* ctx, const Node* node) { + assert(is_basic_block(node)); + IrArena* a = ctx->rewriter.dst_arena; + + Context body_context = *ctx; + LARRAY(const Node*, nparams, node->payload.basic_block.params.count); + for (size_t i = 0; i < node->payload.basic_block.params.count; i++) { + const Param* old_param = &node->payload.basic_block.params.nodes[i]->payload.param; + // for the param type: use the inferred one if none is already provided + // if one is provided, check the inferred argument type is a subtype of the param type + const Type* param_type = infer(ctx, old_param->type, NULL); + assert(param_type); + nparams[i] = param(a, param_type, old_param->name); + shd_register_processed(&body_context.rewriter, node->payload.basic_block.params.nodes[i], nparams[i]); + } + + Node* bb = basic_block(a, shd_nodes(a, 
node->payload.basic_block.params.count, nparams), node->payload.basic_block.name); + assert(bb); + shd_register_processed(&ctx->rewriter, node, bb); + + shd_set_abstraction_body(bb, infer(&body_context, node->payload.basic_block.body, NULL)); + return bb; +} + +static const Node* infer_primop(Context* ctx, const Node* node, const Node* expected_type) { + assert(node->tag == PrimOp_TAG); + IrArena* a = ctx->rewriter.dst_arena; + + for (size_t i = 0; i < node->payload.prim_op.type_arguments.count; i++) + assert(node->payload.prim_op.type_arguments.nodes[i] && is_type(node->payload.prim_op.type_arguments.nodes[i])); + for (size_t i = 0; i < node->payload.prim_op.operands.count; i++) + assert(node->payload.prim_op.operands.nodes[i] && is_value(node->payload.prim_op.operands.nodes[i])); + + Nodes old_type_args = node->payload.prim_op.type_arguments; + Nodes type_args = infer_nodes(ctx, old_type_args); + Nodes old_operands = node->payload.prim_op.operands; + + BodyBuilder* bb = shd_bld_begin_pure(a); + Op op = node->payload.prim_op.op; + LARRAY(const Node*, new_operands, old_operands.count); + Nodes input_types = shd_empty(a); + switch (node->payload.prim_op.op) { + case reinterpret_op: + case convert_op: { + new_operands[0] = infer(ctx, old_operands.nodes[0], NULL); + const Type* src_pointer_type = shd_get_unqualified_type(new_operands[0]->type); + const Type* old_dst_pointer_type = shd_first(old_type_args); + const Type* dst_pointer_type = shd_first(type_args); + + if (shd_is_generic_ptr_type(src_pointer_type) != shd_is_generic_ptr_type(dst_pointer_type)) + op = convert_op; + + goto rebuild; + } + case empty_mask_op: + case mask_is_thread_active_op: { + input_types = mk_nodes(a, shd_as_qualified_type(mask_type(a), false), + shd_as_qualified_type(shd_uint32_type(a), false)); + break; + } + default: { + for (size_t i = 0; i < old_operands.count; i++) { + new_operands[i] = old_operands.nodes[i] ? 
infer(ctx, old_operands.nodes[i], NULL) : NULL;
+            }
+            goto rebuild;
+        }
+    }
+
+    assert(input_types.count == old_operands.count);
+    for (size_t i = 0; i < input_types.count; i++)
+        new_operands[i] = infer(ctx, old_operands.nodes[i], input_types.nodes[i]);
+
+    rebuild: {
+        const Node* new_instruction = prim_op(a, (PrimOp) {
+            .op = op,
+            .type_arguments = type_args,
+            .operands = shd_nodes(a, old_operands.count, new_operands)
+        });
+        return shd_bld_to_instr_with_last_instr(bb, new_instruction);
+    }
+}
+
+static const Node* infer_indirect_call(Context* ctx, const Node* node, const Node* expected_type) {
+    assert(node->tag == Call_TAG);
+    IrArena* a = ctx->rewriter.dst_arena;
+
+    const Node* new_callee = infer(ctx, node->payload.call.callee, NULL);
+    assert(is_value(new_callee));
+    LARRAY(const Node*, new_args, node->payload.call.args.count);
+
+    const Type* callee_type = shd_get_unqualified_type(new_callee->type);
+    if (callee_type->tag != PtrType_TAG)
+        shd_error("functions are called through function pointers");
+    callee_type = callee_type->payload.ptr_type.pointed_type;
+
+    if (callee_type->tag != FnType_TAG)
+        shd_error("Callees must have a function type");
+    if (callee_type->payload.fn_type.param_types.count != node->payload.call.args.count)
+        shd_error("Mismatched argument counts");
+    for (size_t i = 0; i < node->payload.call.args.count; i++) {
+        const Node* arg = node->payload.call.args.nodes[i];
+        assert(arg);
+        new_args[i] = infer(ctx, node->payload.call.args.nodes[i], callee_type->payload.fn_type.param_types.nodes[i]);
+        assert(new_args[i]->type);
+    }
+
+    return call(a, (Call) {
+        .callee = new_callee,
+        .args = shd_nodes(a, node->payload.call.args.count, new_args),
+        // NOTE(review): was `node->payload.if_instr.mem` — this is a Call node, so the
+        // mem operand must be read from the call payload, not the If payload.
+        .mem = infer(ctx, node->payload.call.mem, NULL),
+    });
+}
+
+static const Node* infer_if(Context* ctx, const Node* node) {
+    assert(node->tag == If_TAG);
+    IrArena* a = ctx->rewriter.dst_arena;
+    const Node* condition = infer(ctx, node->payload.if_instr.condition,
shd_as_qualified_type(bool_type(a), false)); + + Nodes join_types = infer_nodes(ctx, node->payload.if_instr.yield_types); + Context infer_if_body_ctx = *ctx; + // When we infer the types of the arguments to a call to merge(), they are expected to be varying + Nodes expected_join_types = shd_add_qualifiers(a, join_types, false); + + const Node* true_body = infer_case(&infer_if_body_ctx, node->payload.if_instr.if_true, shd_nodes(a, 0, NULL)); + // don't allow seeing the variables made available in the true branch + infer_if_body_ctx.rewriter = ctx->rewriter; + const Node* false_body = node->payload.if_instr.if_false ? infer_case(&infer_if_body_ctx, node->payload.if_instr.if_false, shd_nodes(a, 0, NULL)) : NULL; + + return if_instr(a, (If) { + .yield_types = join_types, + .condition = condition, + .if_true = true_body, + .if_false = false_body, + //.tail = infer_case(ctx, node->payload.if_instr.tail, expected_join_types) + .tail = infer(ctx, node->payload.if_instr.tail, NULL), + .mem = infer(ctx, node->payload.if_instr.mem, NULL), + }); +} + +static const Node* infer_loop(Context* ctx, const Node* node) { + assert(node->tag == Loop_TAG); + IrArena* a = ctx->rewriter.dst_arena; + Context loop_body_ctx = *ctx; + const Node* old_body = node->payload.loop_instr.body; + + Nodes old_params = get_abstraction_params(old_body); + Nodes old_params_types = shd_get_param_types(a, old_params); + Nodes new_params_types = infer_nodes(ctx, old_params_types); + new_params_types = annotate_all_types(a, new_params_types, false); + + Nodes old_initial_args = node->payload.loop_instr.initial_args; + LARRAY(const Node*, new_initial_args, old_params.count); + for (size_t i = 0; i < old_params.count; i++) + new_initial_args[i] = infer(ctx, old_initial_args.nodes[i], new_params_types.nodes[i]); + + Nodes loop_yield_types = infer_nodes(ctx, node->payload.loop_instr.yield_types); + Nodes qual_yield_types = shd_add_qualifiers(a, loop_yield_types, false); + + const Node* nbody = 
infer_case(&loop_body_ctx, old_body, new_params_types); + // TODO check new body params match continue types + + return loop_instr(a, (Loop) { + .yield_types = loop_yield_types, + .initial_args = shd_nodes(a, old_params.count, new_initial_args), + .body = nbody, + //.tail = infer_case(ctx, node->payload.loop_instr.tail, qual_yield_types) + .tail = infer(ctx, node->payload.loop_instr.tail, NULL), + .mem = infer(ctx, node->payload.if_instr.mem, NULL), + }); +} + +static const Node* infer_control(Context* ctx, const Node* node) { + assert(node->tag == Control_TAG); + IrArena* a = ctx->rewriter.dst_arena; + + Nodes yield_types = infer_nodes(ctx, node->payload.control.yield_types); + + const Node* olam = node->payload.control.inside; + const Node* ojp = shd_first(get_abstraction_params(olam)); + + Context joinable_ctx = *ctx; + const Type* jpt = join_point_type(a, (JoinPointType) { + .yield_types = yield_types + }); + jpt = qualified_type(a, (QualifiedType) { .is_uniform = true, .type = jpt }); + const Node* jp = param(a, jpt, ojp->payload.param.name); + shd_register_processed(&joinable_ctx.rewriter, ojp, jp); + + Node* new_case = basic_block(a, shd_singleton(jp), NULL); + shd_register_processed(&joinable_ctx.rewriter, olam, new_case); + shd_set_abstraction_body(new_case, infer(&joinable_ctx, get_abstraction_body(olam), NULL)); + + return control(a, (Control) { + .yield_types = yield_types, + .inside = new_case, + .tail = infer(ctx, get_structured_construct_tail(node), NULL /*add_qualifiers(a, yield_types, false)*/), + .mem = infer(ctx, node->payload.if_instr.mem, NULL), + }); +} + +static const Node* infer_instruction(Context* ctx, const Node* node, const Type* expected_type) { + IrArena* a = ctx->rewriter.dst_arena; + switch (is_instruction(node)) { + case PrimOp_TAG: return infer_primop(ctx, node, expected_type); + case Call_TAG: return infer_indirect_call(ctx, node, expected_type); + case Instruction_Comment_TAG: return shd_recreate_node(&ctx->rewriter, node); + 
case Instruction_Load_TAG: { + return load(a, (Load) { .ptr = infer(ctx, node->payload.load.ptr, NULL), .mem = infer(ctx, node->payload.load.mem, NULL) }); + } + case Instruction_Store_TAG: { + Store payload = node->payload.store; + const Node* ptr = infer(ctx, payload.ptr, NULL); + const Type* ptr_type = shd_get_unqualified_type(ptr->type); + assert(ptr_type->tag == PtrType_TAG); + const Type* element_t = ptr_type->payload.ptr_type.pointed_type; + assert(element_t); + const Node* value = infer(ctx, payload.value, shd_as_qualified_type(element_t, false)); + return store(a, (Store) { .ptr = ptr, .value = value, .mem = infer(ctx, node->payload.store.mem, NULL) }); + } + case Instruction_StackAlloc_TAG: { + const Type* element_type = node->payload.stack_alloc.type; + assert(is_type(element_type)); + assert(shd_is_data_type(element_type)); + return stack_alloc(a, (StackAlloc) { .type = infer_type(ctx, element_type), .mem = infer(ctx, node->payload.stack_alloc.mem, NULL) }); + } + default: break; + case NotAnInstruction: shd_error("not an instruction"); + } + return shd_recreate_node(&ctx->rewriter, node); +} + +static const Node* infer_terminator(Context* ctx, const Node* node) { + IrArena* a = ctx->rewriter.dst_arena; + switch (is_terminator(node)) { + case NotATerminator: assert(false); + case If_TAG: return infer_if (ctx, node); + case Match_TAG: shd_error("TODO") + case Loop_TAG: return infer_loop (ctx, node); + case Control_TAG: return infer_control(ctx, node); + case Return_TAG: { + const Node* imported_fn = ctx->current_fn; + Nodes return_types = imported_fn->payload.fun.return_types; + + Return payload = node->payload.fn_ret; + LARRAY(const Node*, nvalues, payload.args.count); + for (size_t i = 0; i < payload.args.count; i++) + nvalues[i] = infer(ctx, payload.args.nodes[i], return_types.nodes[i]); + return fn_ret(a, (Return) { + .args = shd_nodes(a, payload.args.count, nvalues), + .mem = infer(ctx, payload.mem, NULL), + }); + } + default: break; + } + return 
shd_recreate_node(&ctx->rewriter, node); +} + +static const Node* process(Context* src_ctx, const Node* node) { + const Node* expected_type = src_ctx->expected_type; + Context ctx = *src_ctx; + ctx.expected_type = NULL; + + IrArena* a = ctx.rewriter.dst_arena; + + if (is_type(node)) { + assert(expected_type == NULL); + return infer_type(&ctx, node); + } else if (is_instruction(node)) { + if (expected_type) { + return infer_instruction(&ctx, node, expected_type); + } + return infer_instruction(&ctx, node, NULL); + } else if (is_value(node)) { + const Node* value = infer_value(&ctx, node, expected_type); + assert(shd_is_value_type(value->type)); + return value; + } else if (is_terminator(node)) { + assert(expected_type == NULL); + return infer_terminator(&ctx, node); + } else if (is_declaration(node)) { + return infer_decl(&ctx, node); + } else if (is_annotation(node)) { + assert(expected_type == NULL); + return infer_annotation(&ctx, node); + } else if (is_basic_block(node)) { + return _infer_basic_block(&ctx, node); + }else if (is_mem(node)) { + return shd_recreate_node(&ctx.rewriter, node); + } + assert(false); +} + +Module* slim_pass_infer(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + assert(!aconfig.check_types); + aconfig.check_types = true; + aconfig.allow_fold = true; // TODO was moved here because a refactor, does this cause issues ? 
+ IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + }; + //ctx.rewriter.config.search_map = false; + //ctx.rewriter.config.write_map = false; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/frontend/slim/normalize.c b/src/frontend/slim/normalize.c new file mode 100644 index 000000000..95205462f --- /dev/null +++ b/src/frontend/slim/normalize.c @@ -0,0 +1,126 @@ +#include "shady/pass.h" + +#include "log.h" +#include "portability.h" +#include "dict.h" + +#include + +typedef struct Context_ { + Rewriter rewriter; +} Context; + +static const Node* process_node(Context* ctx, const Node* node); + +static const Node* force_to_be_value(Context* ctx, const Node* node) { + if (node == NULL) return NULL; + IrArena* a = ctx->rewriter.dst_arena; + + switch (node->tag) { + // All decls map to refdecl/fnaddr + case Constant_TAG: + case GlobalVariable_TAG: { + return ref_decl_helper(a, process_node(ctx, node)); + } + case Function_TAG: { + return fn_addr_helper(a, process_node(ctx, node)); + } + case Param_TAG: return shd_find_processed(&ctx->rewriter, node); + default: + break; + } + + assert(is_value(node)); + const Node* value = process_node(ctx, node); + assert(is_value(value)); + return value; +} + +static const Node* process_op(Context* ctx, NodeClass op_class, SHADY_UNUSED String op_name, const Node* node) { + if (node == NULL) return NULL; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + switch (op_class) { + case NcType: { + switch (node->tag) { + case NominalType_TAG: { + return type_decl_ref(ctx->rewriter.dst_arena, (TypeDeclRef) { + .decl = process_node(ctx, node), + }); + } + default: break; + } + assert(is_type(node)); + const Node* type = process_node(ctx, node); + assert(is_type(type)); + return type; + } + case NcValue: + return 
force_to_be_value(ctx, node); + case NcParam: + break; + case NcInstruction: { + if (is_instruction(node)) { + const Node* new = process_node(ctx, node); + //register_processed(r, node, new); + return new; + } + const Node* val = force_to_be_value(ctx, node); + return val; + } + case NcTerminator: + break; + case NcDeclaration: + break; + case NcBasic_block: + break; + case NcAnnotation: + break; + case NcJump: + break; + case NcStructured_construct: + break; + } + return process_node(ctx, node); +} + +static const Node* process_node(Context* ctx, const Node* node) { + const Node** already_done = shd_search_processed(&ctx->rewriter, node); + if (already_done) + return *already_done; + + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + + // add a builder to each abstraction... + switch (node->tag) { + // case Let_TAG: { + // const Node* ninstr = rewrite_op(r, NcInstruction, "instruction", get_let_instruction(node)); + // register_processed(r, get_let_instruction(node), ninstr); + // return let(a, ninstr, rewrite_op(r, NcTerminator, "in", node->payload.let.in)); + // } + default: break; + } + + const Node* new = shd_recreate_node(&ctx->rewriter, node); + if (is_instruction(new)) + shd_register_processed(r, node, new); + return new; +} + +Module* slim_pass_normalize(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + aconfig.check_op_classes = true; + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_op_rewriter(src, dst, (RewriteOpFn) process_op), + }; + + ctx.rewriter.config.search_map = false; + ctx.rewriter.config.write_map = false; + + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/frontend/slim/parser.c b/src/frontend/slim/parser.c new file mode 100644 index 000000000..ae5af81a0 --- /dev/null +++ 
b/src/frontend/slim/parser.c @@ -0,0 +1,1293 @@ +#include "parser.h" +#include "token.h" +#include "SlimFrontendOps.h" + +#include "shady/ir/ext.h" + +#include "list.h" +#include "portability.h" +#include "log.h" +#include "util.h" + +#include "ir_private.h" + +#include +#include +#include +#include + +typedef enum DivergenceQualifier_ { + Unknown, + Uniform, + Varying +} DivergenceQualifier; + +static int max_precedence() { + return 10; +} + +static int get_precedence(InfixOperators op) { + switch (op) { +#define INFIX_OPERATOR(name, token, primop_op, precedence) case Infix##name: return precedence; +INFIX_OPERATORS() +#undef INFIX_OPERATOR + default: shd_error("unknown operator"); + } +} +static bool is_primop_op(InfixOperators op, Op* out) { + switch (op) { +#define INFIX_OPERATOR(name, token, primop_op, precedence) case Infix##name: if (primop_op != -1) { *out = primop_op; return true; } else return false; +INFIX_OPERATORS() +#undef INFIX_OPERATOR + default: shd_error("unknown operator"); + } +} + +static bool is_infix_operator(TokenTag token_tag, InfixOperators* out) { + switch (token_tag) { +#define INFIX_OPERATOR(name, token, primop_op, precedence) case token: { *out = Infix##name; return true; } +INFIX_OPERATORS() +#undef INFIX_OPERATOR + default: return false; + } +} + +// to avoid some repetition +#define ctxparams SHADY_UNUSED const SlimParserConfig* config, SHADY_UNUSED const char* contents, SHADY_UNUSED Module* mod, SHADY_UNUSED IrArena* arena, SHADY_UNUSED Tokenizer* tokenizer +#define ctx config, contents, mod, arena, tokenizer + +static void error_with_loc(ctxparams) { + Loc loc = shd_current_loc(tokenizer); + size_t startline = loc.line - 2; + if (startline < 1) startline = 1; + size_t endline = startline + 5; + + int numdigits = 1; + int e = endline; + while (e >= 10) { + numdigits++; + e /= 10; + } + LARRAY(char, digits, numdigits); + // char* digits = malloc(sizeof(char) * numdigits); + + size_t line = 1; + size_t len = strlen(contents); + for 
(size_t i = 0; i < len; i++) { + if (line >= startline && line <= endline) { + shd_log_fmt(ERROR, "%c", contents[i]); + } + if (contents[i] == '\n') { + if (line == loc.line) { + for (size_t digit = 0; digit < numdigits; digit++) { + shd_log_fmt(ERROR, " "); + } + shd_log_fmt(ERROR, " "); + for (size_t j = 1; j < loc.column; j++) { + shd_log_fmt(ERROR, " "); + } + shd_log_fmt(ERROR, "^"); + shd_log_fmt(ERROR, "\n"); + } + line++; + + if (line >= startline && line <= endline) { + size_t l = line, digit; + for (digit = 0; digit < numdigits; digit++) { + if (l == 0) + break; + digits[numdigits - 1 - digit] = (char) ('0' + (l % 10)); + l /= 10; + } + for (; digit < numdigits; digit++) { + digits[numdigits - 1 - digit] = (char) ' '; + } + for (digit = 0; digit < numdigits; digit++) { + shd_log_fmt(ERROR, "%c", digits[numdigits - 1 - digit]); + } + shd_log_fmt(ERROR, ": "); + } + } + } + shd_log_fmt(ERROR, "At %d:%d, ", loc.line, loc.column); +} + +#define syntax_error(condition) syntax_error_impl(ctx, condition) +#define syntax_error_fmt(condition, ...) syntax_error_impl(ctx, condition, __VA_ARGS__) +static void syntax_error_impl(ctxparams, const char* format, ...) { + va_list args; + va_start(args, format); + error_with_loc(ctx); + shd_log_fmt_va_list(ERROR, format, args); + shd_log_fmt(ERROR, "\n"); + exit(-4); + va_end(args); +} + +#define expect(condition, format, ...) expect_impl(ctx, condition, format) +#define expect_fmt(condition, format, ...) expect_impl(ctx, condition, format, __VA_ARGS__) +static void expect_impl(ctxparams, bool condition, const char* format, ...) 
{ + if (!condition) { + va_list args; + va_start(args, format); + error_with_loc(ctx); + shd_log_fmt(ERROR, "expected "); + shd_log_fmt_va_list(ERROR, format, args); + shd_log_fmt(ERROR, "\n"); + exit(-4); + va_end(args); + } +} + +static bool accept_token(ctxparams, TokenTag tag) { + if (shd_curr_token(tokenizer).tag == tag) { + shd_next_token(tokenizer); + return true; + } + return false; +} + +static const char* accept_identifier(ctxparams) { + Token tok = shd_curr_token(tokenizer); + if (tok.tag == identifier_tok) { + shd_next_token(tokenizer); + size_t size = tok.end - tok.start; + return shd_string_sized(arena, (int) size, &contents[tok.start]); + } + return NULL; +} + +static const Node* expect_body(ctxparams, const Node* mem, const Node* default_terminator(const Node*)); +static const Node* accept_value(ctxparams, BodyBuilder*); +static const Type* accept_unqualified_type(ctxparams); +static const Node* accept_expr(ctxparams, BodyBuilder*, int); +static Nodes expect_operands(ctxparams, BodyBuilder*); +static const Node* expect_operand(ctxparams, BodyBuilder*); +static const Type* accept_qualified_type(ctxparams); + +static const Type* accept_numerical_type(ctxparams) { + if (accept_token(ctx, i8_tok)) { + return shd_int8_type(arena); + } else if (accept_token(ctx, i16_tok)) { + return shd_int16_type(arena); + } else if (accept_token(ctx, i32_tok)) { + return shd_int32_type(arena); + } else if (accept_token(ctx, i64_tok)) { + return shd_int64_type(arena); + } else if (accept_token(ctx, u8_tok)) { + return shd_uint8_type(arena); + } else if (accept_token(ctx, u16_tok)) { + return shd_uint16_type(arena); + } else if (accept_token(ctx, u32_tok)) { + return shd_uint32_type(arena); + } else if (accept_token(ctx, u64_tok)) { + return shd_uint64_type(arena); + } else if (accept_token(ctx, f16_tok)) { + return shd_fp16_type(arena); + } else if (accept_token(ctx, f32_tok)) { + return shd_fp32_type(arena); + } else if (accept_token(ctx, f64_tok)) { + return 
shd_fp64_type(arena); + } + return NULL; +} + +static const Node* accept_numerical_literal(ctxparams) { + const Type* num_type = accept_numerical_type(ctx); + + bool negate = accept_token(ctx, minus_tok); + + Token tok = shd_curr_token(tokenizer); + size_t size = tok.end - tok.start; + String str = shd_string_sized(arena, (int) size, &contents[tok.start]); + + switch (tok.tag) { + case hex_lit_tok: + if (negate) + syntax_error("hexadecimal literals can't start with '-'"); + case dec_lit_tok: { + shd_next_token(tokenizer); + break; + } + default: { + if (negate || num_type) + syntax_error("expected numerical literal"); + return NULL; + } + } + + if (negate) // add back the - in front + str = shd_format_string_arena(arena->arena, "-%s", str); + + const Node* n = untyped_number(arena, (UntypedNumber) { + .plaintext = str + }); + + if (num_type) + n = constrained(arena, (ConstrainedValue) { + .type = num_type, + .value = n + }); + + return n; +} + +static Nodes accept_type_arguments(ctxparams) { + Nodes ty_args = shd_empty(arena); + if (accept_token(ctx, lsbracket_tok)) { + while (true) { + const Type* t = accept_unqualified_type(ctx); + expect(t, "unqualified type"); + ty_args = shd_nodes_append(arena, ty_args, t); + if (accept_token(ctx, comma_tok)) + continue; + if (accept_token(ctx, rsbracket_tok)) + break; + } + } + return ty_args; +} + +static const Node* make_unbound(IrArena* a, const Node* mem, String identifier) { + return ext_instr(a, (ExtInstr) { + .mem = mem, + .set = "shady.frontend", + .opcode = SlimFrontendOpsSlimUnboundSHADY, + .result_t = unit_type(a), + .operands = shd_singleton(string_lit_helper(a, identifier)), + }); +} + +static const Node* accept_value(ctxparams, BodyBuilder* bb) { + Token tok = shd_curr_token(tokenizer); + size_t size = tok.end - tok.start; + + const Node* number = accept_numerical_literal(ctx); + if (number) + return number; + + switch (tok.tag) { + case identifier_tok: { + const char* id = shd_string_sized(arena, (int) size, 
&contents[tok.start]); + shd_next_token(tokenizer); + + Op op = PRIMOPS_COUNT; + for (size_t i = 0; i < PRIMOPS_COUNT; i++) { + if (strcmp(id, shd_get_primop_name(i)) == 0) { + op = i; + break; + } + } + + if (op != PRIMOPS_COUNT) { + if (!bb) + syntax_error("primops cannot be used outside of a function"); + return shd_bld_add_instruction(bb, prim_op(arena, (PrimOp) { + .op = op, + .type_arguments = accept_type_arguments(ctx), + .operands = expect_operands(ctx, bb) + })); + } else if (strcmp(id, "ext_instr") == 0) { + expect(accept_token(ctx, lsbracket_tok), "'['"); + const Node* set = accept_value(ctx, NULL); + expect(set->tag == StringLiteral_TAG, "string literal"); + expect(accept_token(ctx, comma_tok), "','"); + const Node* opcode = accept_value(ctx, NULL); + expect(opcode->tag == UntypedNumber_TAG, "number"); + expect(accept_token(ctx, comma_tok), "','"); + const Type* type = accept_qualified_type(ctx); + expect(type, "type"); + expect(accept_token(ctx, rsbracket_tok), "]"); + Nodes ops = expect_operands(ctx, bb); + return shd_bld_add_instruction(bb, ext_instr(arena, (ExtInstr) { + .result_t = type, + .set = set->payload.string_lit.string, + .opcode = strtoll(opcode->payload.untyped_number.plaintext, NULL, 10), + .mem = shd_bb_mem(bb), + .operands = ops, + })); + } else if (strcmp(id, "alloca") == 0) { + const Node* type = shd_first(accept_type_arguments(ctx)); + Nodes ops = expect_operands(ctx, bb); + expect(ops.count == 0, "no operands"); + return shd_bld_add_instruction(bb, stack_alloc(arena, (StackAlloc) { + .type = type, + .mem = shd_bb_mem(bb), + })); + } else if (strcmp(id, "debug_printf") == 0) { + Nodes ops = expect_operands(ctx, bb); + return shd_bld_add_instruction(bb, debug_printf(arena, (DebugPrintf) { + .string = shd_get_string_literal(arena, shd_first(ops)), + .args = shd_nodes(arena, ops.count - 1, &ops.nodes[1]), + .mem = shd_bb_mem(bb), + })); + } + + if (bb) + return shd_bld_add_instruction(bb, make_unbound(arena, shd_bb_mem(bb), id)); + 
return make_unbound(arena, NULL, id); + } + case hex_lit_tok: + case dec_lit_tok: { + shd_next_token(tokenizer); + return untyped_number(arena, (UntypedNumber) { + .plaintext = shd_string_sized(arena, (int) size, &contents[tok.start]) + }); + } + case string_lit_tok: { + shd_next_token(tokenizer); + char* unescaped = calloc(size + 1, 1); + size_t j = shd_apply_escape_codes(&contents[tok.start], size, unescaped); + const Node* lit = string_lit(arena, (StringLiteral) {.string = shd_string_sized(arena, (int) j, unescaped) }); + free(unescaped); + return lit; + } + case true_tok: + shd_next_token(tokenizer); return true_lit(arena); + case false_tok: + shd_next_token(tokenizer); return false_lit(arena); + case lpar_tok: { + shd_next_token(tokenizer); + if (accept_token(ctx, rpar_tok)) { + return shd_tuple_helper(arena, shd_empty(arena)); + } + const Node* atom = expect_operand(ctx, bb); + if (shd_curr_token(tokenizer).tag == rpar_tok) { + shd_next_token(tokenizer); + } else { + struct List* elements = shd_new_list(const Node*); + shd_list_append(const Node*, elements, atom); + + while (!accept_token(ctx, rpar_tok)) { + expect(accept_token(ctx, comma_tok), "','"); + const Node* element = expect_operand(ctx, bb); + shd_list_append(const Node*, elements, element); + } + + Nodes tcontents = shd_nodes(arena, shd_list_count(elements), shd_read_list(const Node*, elements)); + shd_destroy_list(elements); + atom = shd_tuple_helper(arena, tcontents); + } + return atom; + } + case composite_tok: { + shd_next_token(tokenizer); + const Type* elem_type = accept_unqualified_type(ctx); + expect(elem_type, "composite data type"); + Nodes elems = expect_operands(ctx, bb); + return composite_helper(arena, elem_type, elems); + } + default: return NULL; + } +} + +static AddressSpace accept_address_space(ctxparams) { + switch (shd_curr_token(tokenizer).tag) { + case global_tok: + shd_next_token(tokenizer); return AsGlobal; + case private_tok: + shd_next_token(tokenizer); return AsPrivate; + 
case shared_tok: + shd_next_token(tokenizer); return AsShared; + case subgroup_tok: + shd_next_token(tokenizer); return AsSubgroup; + case generic_tok: + shd_next_token(tokenizer); return AsGeneric; + case input_tok: + shd_next_token(tokenizer); return AsInput; + case output_tok: + shd_next_token(tokenizer); return AsOutput; + case extern_tok: + shd_next_token(tokenizer); return AsExternal; + default: + break; + } + return NumAddressSpaces; +} + +static const Type* accept_unqualified_type(ctxparams) { + const Type* prim_type = accept_numerical_type(ctx); + if (prim_type) return prim_type; + else if (accept_token(ctx, bool_tok)) { + return bool_type(arena); + } else if (accept_token(ctx, mask_t_tok)) { + return mask_type(arena); + } else if (accept_token(ctx, ptr_tok)) { + AddressSpace as = accept_address_space(ctx); + expect(as != NumAddressSpaces, "address space"); + const Type* elem_type = accept_unqualified_type(ctx); + expect(elem_type, "data type"); + return ptr_type(arena, (PtrType) { + .address_space = as, + .pointed_type = elem_type, + }); + } else if (accept_token(ctx, ref_tok)) { + AddressSpace as = accept_address_space(ctx); + expect(as != NumAddressSpaces, "address space"); + const Type* elem_type = accept_unqualified_type(ctx); + expect(elem_type, "data type"); + return ptr_type(arena, (PtrType) { + .address_space = as, + .pointed_type = elem_type, + .is_reference = true, + }); + } else if (config->front_end && accept_token(ctx, lsbracket_tok)) { + const Type* elem_type = accept_unqualified_type(ctx); + expect(elem_type, "type"); + const Node* size = NULL; + if (accept_token(ctx, semi_tok)) { + size = accept_value(ctx, NULL); + expect(size, "value"); + } + expect(accept_token(ctx, rsbracket_tok), "']'"); + return arr_type(arena, (ArrType) { + .element_type = elem_type, + .size = size + }); + } else if (accept_token(ctx, pack_tok)) { + expect(accept_token(ctx, lsbracket_tok), "'['"); + const Type* elem_type = accept_unqualified_type(ctx); + 
expect(elem_type, "packed element type"); + const Node* size = NULL; + expect(accept_token(ctx, semi_tok), "';'"); + size = accept_numerical_literal(ctx); + expect(size && size->tag == UntypedNumber_TAG, "number"); + expect(accept_token(ctx, rsbracket_tok), "']'"); + return pack_type(arena, (PackType) { + .element_type = elem_type, + .width = strtoll(size->payload.untyped_number.plaintext, NULL, 10) + }); + } else if (accept_token(ctx, struct_tok)) { + expect(accept_token(ctx, lbracket_tok), "'{'"); + struct List* names = shd_new_list(String); + struct List* types = shd_new_list(const Type*); + while (true) { + if (accept_token(ctx, rbracket_tok)) + break; + const Type* elem = accept_unqualified_type(ctx); + expect(elem, "struct member type"); + String id = accept_identifier(ctx); + expect(id, "struct member name"); + shd_list_append(String, names, id); + shd_list_append(const Type*, types, elem); + expect(accept_token(ctx, semi_tok), "';'"); + } + Nodes elem_types = shd_nodes(arena, shd_list_count(types), shd_read_list(const Type*, types)); + Strings names2 = shd_strings(arena, shd_list_count(names), shd_read_list(String, names)); + shd_destroy_list(names); + shd_destroy_list(types); + return record_type(arena, (RecordType) { + .names = names2, + .members = elem_types, + .special = NotSpecial, + }); + } else { + String id = accept_identifier(ctx); + if (id) + return make_unbound(arena, NULL, id); + + return NULL; + } +} + +static DivergenceQualifier accept_uniformity_qualifier(ctxparams) { + DivergenceQualifier divergence = Unknown; + if (accept_token(ctx, uniform_tok)) + divergence = Uniform; + else if (accept_token(ctx, varying_tok)) + divergence = Varying; + return divergence; +} + +static const Type* accept_maybe_qualified_type(ctxparams) { + DivergenceQualifier qualifier = accept_uniformity_qualifier(ctx); + const Type* unqualified = accept_unqualified_type(ctx); + if (qualifier != Unknown) + expect(unqualified, "unqualified type"); + if (qualifier == 
Unknown) + return unqualified; + else + return qualified_type(arena, (QualifiedType) { .is_uniform = qualifier == Uniform, .type = unqualified }); +} + +static const Type* accept_qualified_type(ctxparams) { + DivergenceQualifier qualifier = accept_uniformity_qualifier(ctx); + if (qualifier == Unknown) + return NULL; + const Type* unqualified = accept_unqualified_type(ctx); + expect(unqualified, "unqualified type"); + return qualified_type(arena, (QualifiedType) { .is_uniform = qualifier == Uniform, .type = unqualified }); +} + +static const Node* accept_operand(ctxparams, BodyBuilder* bb) { + return config->front_end ? accept_expr(ctx, bb, max_precedence()) : accept_value(ctx, bb); +} + +static const Node* expect_operand(ctxparams, BodyBuilder* bb) { + const Node* operand = accept_operand(ctx, bb); + expect(operand, "value operand"); + return operand; +} + +static void expect_parameters(ctxparams, Nodes* parameters, Nodes* default_values, BodyBuilder* bb) { + expect(accept_token(ctx, lpar_tok), "'('"); + struct List* params = shd_new_list(Node*); + struct List* default_vals = default_values ? 
shd_new_list(Node*) : NULL; + + while (true) { + if (accept_token(ctx, rpar_tok)) + break; + + next: { + const Type* qtype = accept_qualified_type(ctx); + expect(qtype, "qualified type"); + const char* id = accept_identifier(ctx); + expect(id, "parameter name"); + + const Node* node = param(arena, qtype, id); + shd_list_append(Node*, params, node); + + if (default_values) { + expect(accept_token(ctx, equal_tok), "'='"); + const Node* default_val = accept_operand(ctx, bb); + shd_list_append(const Node*, default_vals, default_val); + } + + if (accept_token(ctx, comma_tok)) + goto next; + } + } + + size_t count = shd_list_count(params); + *parameters = shd_nodes(arena, count, shd_read_list(const Node*, params)); + shd_destroy_list(params); + if (default_values) { + *default_values = shd_nodes(arena, count, shd_read_list(const Node*, default_vals)); + shd_destroy_list(default_vals); + } +} + +typedef enum { MustQualified, MaybeQualified, NeverQualified } Qualified; + +static Nodes accept_types(ctxparams, TokenTag separator, Qualified qualified) { + struct List* tmp = shd_new_list(Type*); + while (true) { + const Type* type; + switch (qualified) { + case MustQualified: type = accept_qualified_type(ctx); break; + case MaybeQualified: type = accept_maybe_qualified_type(ctx); break; + case NeverQualified: type = accept_unqualified_type(ctx); break; + } + if (!type) + break; + + shd_list_append(Type*, tmp, type); + + if (separator != 0) + accept_token(ctx, separator); + } + + Nodes types2 = shd_nodes(arena, tmp->elements_count, (const Type**) tmp->alloc); + shd_destroy_list(tmp); + return types2; +} + +static const Node* accept_primary_expr(ctxparams, BodyBuilder* bb) { + assert(bb); + if (accept_token(ctx, minus_tok)) { + const Node* expr = accept_primary_expr(ctx, bb); + expect(expr, "expression"); + if (expr->tag == IntLiteral_TAG) { + return int_literal(arena, (IntLiteral) { + // We always treat that value like an signed integer, because it makes no sense to negate an 
unsigned number ! + .value = -shd_get_int_literal_value(*shd_resolve_to_int_literal(expr), true) + }); + } else { + return shd_bld_add_instruction(bb, prim_op(arena, (PrimOp) { + .op = neg_op, + .operands = shd_nodes(arena, 1, (const Node* []) { expr }) + })); + } + } else if (accept_token(ctx, unary_excl_tok)) { + const Node* expr = accept_primary_expr(ctx, bb); + expect(expr, "expression"); + return shd_bld_add_instruction(bb, prim_op(arena, (PrimOp) { + .op = not_op, + .operands = shd_singleton(expr), + })); + } else if (accept_token(ctx, star_tok)) { + const Node* expr = accept_primary_expr(ctx, bb); + expect(expr, "expression"); + return shd_bld_add_instruction(bb, ext_instr(arena, (ExtInstr) { .set = "shady.frontend", .result_t = unit_type(arena), .opcode = SlimFrontendOpsSlimDereferenceSHADY, .operands = shd_singleton(expr), .mem = shd_bb_mem(bb) })); + } else if (accept_token(ctx, infix_and_tok)) { + const Node* expr = accept_primary_expr(ctx, bb); + expect(expr, "expression"); + return shd_bld_add_instruction(bb, ext_instr(arena, (ExtInstr) { + .set = "shady.frontend", + .result_t = unit_type(arena), + .opcode = SlimFrontendOpsSlimAddrOfSHADY, + .operands = shd_singleton(expr), + .mem = shd_bb_mem(bb), + })); + } + + return accept_value(ctx, bb); +} + +static const Node* accept_expr(ctxparams, BodyBuilder* bb, int outer_precedence) { + assert(bb); + const Node* expr = accept_primary_expr(ctx, bb); + while (expr) { + InfixOperators infix; + if (is_infix_operator(shd_curr_token(tokenizer).tag, &infix)) { + int precedence = get_precedence(infix); + if (precedence > outer_precedence) break; + shd_next_token(tokenizer); + + const Node* rhs = accept_expr(ctx, bb, precedence - 1); + expect(rhs, "expression"); + Op primop_op; + if (is_primop_op(infix, &primop_op)) { + expr = shd_bld_add_instruction(bb, prim_op(arena, (PrimOp) { + .op = primop_op, + .operands = shd_nodes(arena, 2, (const Node* []) { expr, rhs }) + })); + } else switch (infix) { + case InfixAss: { + 
expr = shd_bld_add_instruction(bb, ext_instr(arena, (ExtInstr) { + .set = "shady.frontend", + .opcode = SlimFrontendOpsSlimAssignSHADY, + .result_t = unit_type(arena), + .operands = shd_nodes(arena, 2, (const Node* []) { expr, rhs }), + .mem = shd_bb_mem(bb), + })); + break; + } + case InfixSbs: { + expr = shd_bld_add_instruction(bb, ext_instr(arena, (ExtInstr) { + .set = "shady.frontend", + .opcode = SlimFrontendOpsSlimSubscriptSHADY, + .result_t = unit_type(arena), + .operands = shd_nodes(arena, 2, (const Node* []) { expr, rhs }), + .mem = shd_bb_mem(bb), + })); + break; + } + default: syntax_error("unknown infix operator"); + } + continue; + } + + switch (shd_curr_token(tokenizer).tag) { + case lpar_tok: { + Nodes ops = expect_operands(ctx, bb); + expr = shd_bld_add_instruction(bb, call(arena, (Call) { + .callee = expr, + .args = ops, + .mem = shd_bb_mem(bb), + })); + continue; + } + default: + break; + } + + break; + } + return expr; +} + +static Nodes expect_operands(ctxparams, BodyBuilder* bb) { + expect(accept_token(ctx, lpar_tok), "'('"); + + struct List* list = shd_new_list(Node*); + + bool expect = false; + while (true) { + const Node* val = accept_operand(ctx, bb); + if (!val) { + if (expect) + syntax_error("expected value but got none"); + else if (accept_token(ctx, rpar_tok)) + break; + else + syntax_error("Expected value or ')'"); + } + + shd_list_append(Node*, list, val); + + if (accept_token(ctx, comma_tok)) + expect = true; + else if (accept_token(ctx, rpar_tok)) + break; + else + syntax_error("Expected ',' or ')'"); + } + + Nodes final = shd_nodes(arena, list->elements_count, (const Node**) list->alloc); + shd_destroy_list(list); + return final; +} + +static const Node* make_selection_merge(const Node* mem) { + IrArena* a = mem->arena; + return merge_selection(a, (MergeSelection) { .args = shd_nodes(a, 0, NULL), .mem = mem }); +} + +static const Node* make_loop_continue(const Node* mem) { + IrArena* a = mem->arena; + return merge_continue(a, 
(MergeContinue) { .args = shd_nodes(a, 0, NULL), .mem = mem }); +} + +static const Node* accept_control_flow_instruction(ctxparams, BodyBuilder* bb) { + Token current_token = shd_curr_token(tokenizer); + switch (current_token.tag) { + case if_tok: { + shd_next_token(tokenizer); + Nodes yield_types = accept_types(ctx, 0, NeverQualified); + expect(accept_token(ctx, lpar_tok), "'('"); + const Node* condition = accept_operand(ctx, bb); + expect(condition, "condition value"); + expect(accept_token(ctx, rpar_tok), "')'"); + const Node* (*merge)(const Node*) = config->front_end ? make_selection_merge : NULL; + + Node* true_case = case_(arena, shd_nodes(arena, 0, NULL)); + shd_set_abstraction_body(true_case, expect_body(ctx, shd_get_abstraction_mem(true_case), merge)); + + // else defaults to an empty body + bool has_else = accept_token(ctx, else_tok); + Node* false_case = NULL; + if (has_else) { + false_case = case_(arena, shd_nodes(arena, 0, NULL)); + shd_set_abstraction_body(false_case, expect_body(ctx, shd_get_abstraction_mem(false_case), merge)); + } + return shd_maybe_tuple_helper(arena, shd_bld_if(bb, yield_types, condition, true_case, false_case)); + } + case loop_tok: { + shd_next_token(tokenizer); + Nodes yield_types = accept_types(ctx, 0, NeverQualified); + Nodes parameters; + Nodes initial_arguments; + expect_parameters(ctx, ¶meters, &initial_arguments, bb); + // by default loops continue forever + const Node* (*default_loop_end_behaviour)(const Node*) = config->front_end ? 
make_loop_continue : NULL; + Node* loop_case = case_(arena, parameters); + shd_set_abstraction_body(loop_case, expect_body(ctx, shd_get_abstraction_mem(loop_case), default_loop_end_behaviour)); + return shd_maybe_tuple_helper(arena, shd_bld_loop(bb, yield_types, initial_arguments, loop_case)); + } + case control_tok: { + shd_next_token(tokenizer); + Nodes yield_types = accept_types(ctx, 0, NeverQualified); + expect(accept_token(ctx, lpar_tok), "'('"); + String str = accept_identifier(ctx); + expect(str, "control parameter name"); + const Node* jp = param(arena, join_point_type(arena, (JoinPointType) { + .yield_types = yield_types, + }), str); + expect(accept_token(ctx, rpar_tok), "')'"); + Node* control_case = case_(arena, shd_singleton(jp)); + shd_set_abstraction_body(control_case, expect_body(ctx, shd_get_abstraction_mem(control_case), NULL)); + return shd_maybe_tuple_helper(arena, shd_bld_control(bb, yield_types, control_case)); + } + default: break; + } + return NULL; +} + +static const Node* accept_instruction(ctxparams, BodyBuilder* bb) { + const Node* instr = accept_expr(ctx, bb, max_precedence()); + + switch(shd_curr_token(tokenizer).tag) { + case call_tok: { + shd_next_token(tokenizer); + expect(accept_token(ctx, lpar_tok), "'('"); + const Node* callee = accept_operand(ctx, bb); + expect(accept_token(ctx, rpar_tok), "')'"); + Nodes args = expect_operands(ctx, bb); + return call(arena, (Call) { + .callee = callee, + .args = args, + .mem = shd_bb_mem(bb) + }); + } + default: break; + } + + if (instr) + expect(accept_token(ctx, semi_tok), "';'"); + + if (!instr) instr = accept_control_flow_instruction(ctx, bb); + return instr; +} + +static void expect_identifiers(ctxparams, Strings* out_strings) { + struct List* list = shd_new_list(const char*); + while (true) { + const char* id = accept_identifier(ctx); + expect(id, "identifier"); + + shd_list_append(const char*, list, id); + + if (accept_token(ctx, comma_tok)) + continue; + else + break; + } + + 
*out_strings = shd_strings(arena, list->elements_count, (const char**) list->alloc); + shd_destroy_list(list); +} + +static void expect_types_and_identifiers(ctxparams, Strings* out_strings, Nodes* out_types) { + struct List* slist = shd_new_list(const char*); + struct List* tlist = shd_new_list(const char*); + + while (true) { + const Type* type = accept_unqualified_type(ctx); + expect(type, "type"); + const char* id = accept_identifier(ctx); + expect(id, "identifier"); + + shd_list_append(const char*, tlist, type); + shd_list_append(const char*, slist, id); + + if (accept_token(ctx, comma_tok)) + continue; + else + break; + } + + *out_strings = shd_strings(arena, slist->elements_count, (const char**) slist->alloc); + *out_types = shd_nodes(arena, tlist->elements_count, (const Node**) tlist->alloc); + shd_destroy_list(slist); + shd_destroy_list(tlist); +} + +static Nodes strings2nodes(IrArena* a, Strings strings) { + LARRAY(const Node*, arr, strings.count); + for (size_t i = 0; i < strings.count; i++) + arr[i] = string_lit_helper(a, strings.strings[i]); + return shd_nodes(a, strings.count, arr); +} + +static bool accept_statement(ctxparams, BodyBuilder* bb) { + Strings ids; + if (accept_token(ctx, val_tok)) { + expect_identifiers(ctx, &ids); + expect(accept_token(ctx, equal_tok), "'='"); + const Node* instruction = accept_instruction(ctx, bb); + shd_bld_ext_instruction(bb, "shady.frontend", SlimFrontendOpsSlimBindValSHADY, unit_type(arena), shd_nodes_prepend(arena, strings2nodes(arena, ids), instruction)); + } else if (accept_token(ctx, var_tok)) { + Nodes types; + expect_types_and_identifiers(ctx, &ids, &types); + expect(accept_token(ctx, equal_tok), "'='"); + const Node* instruction = accept_instruction(ctx, bb); + shd_bld_ext_instruction(bb, "shady.frontend", SlimFrontendOpsSlimBindVarSHADY, unit_type(arena), shd_nodes_prepend(arena, shd_concat_nodes(arena, strings2nodes(arena, ids), types), instruction)); + } else { + const Node* instr = 
accept_instruction(ctx, bb); + if (!instr) return false; + //bind_instruction_outputs_count(bb, instr, 0); + } + return true; +} + +static const Node* expect_jump(ctxparams, BodyBuilder* bb) { + String target = accept_identifier(ctx); + expect(target, "jump target name"); + Nodes args = expect_operands(ctx, bb); + const Node* tgt = make_unbound(arena, shd_bb_mem(bb), target); + shd_bld_add_instruction(bb, tgt); + return jump(arena, (Jump) { + .target = tgt, + .args = args, + .mem = shd_bb_mem(bb) + }); +} + +static const Node* accept_terminator(ctxparams, BodyBuilder* bb) { + TokenTag tag = shd_curr_token(tokenizer).tag; + switch (tag) { + case jump_tok: { + shd_next_token(tokenizer); + return expect_jump(ctx, bb); + } + case branch_tok: { + shd_next_token(tokenizer); + + expect(accept_token(ctx, lpar_tok), "'('"); + const Node* condition = accept_value(ctx, bb); + expect(condition, "branch condition value"); + expect(accept_token(ctx, comma_tok), "','"); + const Node* true_target = expect_jump(ctx, bb); + expect(accept_token(ctx, comma_tok), "','"); + const Node* false_target = expect_jump(ctx, bb); + expect(accept_token(ctx, rpar_tok), "')'"); + + return branch(arena, (Branch) { + .condition = condition, + .true_jump = true_target, + .false_jump = false_target, + .mem = shd_bb_mem(bb) + }); + } + case switch_tok: { + shd_next_token(tokenizer); + + expect(accept_token(ctx, lpar_tok), "'('"); + const Node* inspectee = accept_value(ctx, bb); + expect(inspectee, "value"); + expect(accept_token(ctx, comma_tok), "','"); + Nodes values = shd_empty(arena); + Nodes cases = shd_empty(arena); + const Node* default_jump; + while (true) { + if (accept_token(ctx, default_tok)) { + default_jump = expect_jump(ctx, bb); + break; + } + expect(accept_token(ctx, case_tok), "'case'"); + const Node* value = accept_value(ctx, bb); + expect(value, "case value"); + expect(accept_token(ctx, comma_tok), "','"); + const Node* j = expect_jump(ctx, bb); + expect(accept_token(ctx, comma_tok), 
"','"); + values = shd_nodes_append(arena, values, value); + cases = shd_nodes_append(arena, cases, j); + } + expect(accept_token(ctx, rpar_tok), "')'"); + + return br_switch(arena, (Switch) { + .switch_value = shd_first(values), + .case_values = values, + .case_jumps = cases, + .default_jump = default_jump, + .mem = shd_bb_mem(bb) + }); + } + case return_tok: { + shd_next_token(tokenizer); + Nodes args = expect_operands(ctx, bb); + return fn_ret(arena, (Return) { + .args = args, + .mem = shd_bb_mem(bb) + }); + } + case merge_selection_tok: { + shd_next_token(tokenizer); + Nodes args = shd_curr_token(tokenizer).tag == lpar_tok ? expect_operands(ctx, bb) : shd_nodes(arena, 0, NULL); + return merge_selection(arena, (MergeSelection) { + .args = args, + .mem = shd_bb_mem(bb) + }); + } + case continue_tok: { + shd_next_token(tokenizer); + Nodes args = shd_curr_token(tokenizer).tag == lpar_tok ? expect_operands(ctx, bb) : shd_nodes(arena, 0, NULL); + return merge_continue(arena, (MergeContinue) { + .args = args, + .mem = shd_bb_mem(bb) + }); + } + case break_tok: { + shd_next_token(tokenizer); + Nodes args = shd_curr_token(tokenizer).tag == lpar_tok ? 
expect_operands(ctx, bb) : shd_nodes(arena, 0, NULL); + return merge_break(arena, (MergeBreak) { + .args = args, + .mem = shd_bb_mem(bb) + }); + } + case join_tok: { + shd_next_token(tokenizer); + expect(accept_token(ctx, lpar_tok), "'('"); + const Node* jp = accept_operand(ctx, bb); + expect(accept_token(ctx, rpar_tok), "')'"); + Nodes args = expect_operands(ctx, bb); + return join(arena, (Join) { + .join_point = jp, + .args = args, + .mem = shd_bb_mem(bb) + }); + } + case tailcall_tok: { + shd_next_token(tokenizer); + expect(accept_token(ctx, lpar_tok), "'('"); + const Node* callee = accept_operand(ctx, bb); + expect(accept_token(ctx, rpar_tok), "')'"); + Nodes args = expect_operands(ctx, bb); + return tail_call(arena, (TailCall) { + .callee = callee, + .args = args, + .mem = shd_bb_mem(bb) + }); + } + case unreachable_tok: { + shd_next_token(tokenizer); + expect(accept_token(ctx, lpar_tok), "'('"); + expect(accept_token(ctx, rpar_tok), "')'"); + return unreachable(arena, (Unreachable) { .mem = shd_bb_mem(bb) }); + } + default: break; + } + return NULL; +} + +static const Node* expect_body(ctxparams, const Node* mem, const Node* default_terminator(const Node*)) { + expect(accept_token(ctx, lbracket_tok), "'['"); + BodyBuilder* bb = shd_bld_begin(arena, mem); + + while (true) { + if (!accept_statement(ctx, bb)) + break; + } + + Node* terminator_case = case_(arena, shd_empty(arena)); + BodyBuilder* terminator_bb = shd_bld_begin(arena, shd_get_abstraction_mem(terminator_case)); + const Node* terminator = accept_terminator(ctx, terminator_bb); + + if (terminator) + expect(accept_token(ctx, semi_tok), "';'"); + + if (!terminator) { + if (default_terminator) + terminator = default_terminator(shd_bb_mem(terminator_bb)); + else + syntax_error("expected terminator: return, jump, branch ..."); + } + + shd_set_abstraction_body(terminator_case, shd_bld_finish(terminator_bb, terminator)); + + Node* cont_wrapper_case = case_(arena, shd_empty(arena)); + BodyBuilder* 
cont_wrapper_bb = shd_bld_begin(arena, shd_get_abstraction_mem(cont_wrapper_case)); + + Nodes ids = shd_empty(arena); + Nodes conts = shd_empty(arena); + if (shd_curr_token(tokenizer).tag == cont_tok) { + while (true) { + if (!accept_token(ctx, cont_tok)) + break; + const char* name = accept_identifier(ctx); + + Nodes parameters; + expect_parameters(ctx, &parameters, NULL, bb); + Node* continuation = basic_block(arena, parameters, name); + shd_set_abstraction_body(continuation, expect_body(ctx, shd_get_abstraction_mem(continuation), NULL)); + ids = shd_nodes_append(arena, ids, string_lit_helper(arena, name)); + conts = shd_nodes_append(arena, conts, continuation); + } + } + + shd_bld_ext_instruction(cont_wrapper_bb, "shady.frontend", SlimFrontendOpsSlimBindContinuationsSHADY, unit_type(arena), shd_concat_nodes(arena, ids, conts)); + expect(accept_token(ctx, rbracket_tok), "']'"); + + shd_set_abstraction_body(cont_wrapper_case, shd_bld_jump(cont_wrapper_bb, terminator_case, shd_empty(arena))); + return shd_bld_jump(bb, cont_wrapper_case, shd_empty(arena)); +} + +static Nodes accept_annotations(ctxparams) { + struct List* list = shd_new_list(const Node*); + + while (true) { + if (accept_token(ctx, at_tok)) { + const char* id = accept_identifier(ctx); + const Node* annot = NULL; + if (accept_token(ctx, lpar_tok)) { + const Node* first_value = accept_value(ctx, NULL); + if (!first_value) { + expect(accept_token(ctx, rpar_tok), "value"); + goto no_params; + } + + // TODO: AnnotationCompound ? 
+ if (shd_curr_token(tokenizer).tag == comma_tok) { + shd_next_token(tokenizer); + struct List* values = shd_new_list(const Node*); + shd_list_append(const Node*, values, first_value); + while (true) { + const Node* next_value = accept_value(ctx, NULL); + expect(next_value, "value"); + shd_list_append(const Node*, values, next_value); + if (accept_token(ctx, comma_tok)) + continue; + else break; + } + annot = annotation_values(arena, (AnnotationValues) { + .name = id, + .values = shd_nodes(arena, shd_list_count(values), shd_read_list(const Node*, values)) + }); + shd_destroy_list(values); + } else { + annot = annotation_value(arena, (AnnotationValue) { + .name = id, + .value = first_value + }); + } + + expect(accept_token(ctx, rpar_tok), "')'"); + } else { + no_params: + annot = annotation(arena, (Annotation) { + .name = id, + }); + } + expect(annot, "annotation"); + shd_list_append(const Node*, list, annot); + continue; + } + break; + } + + Nodes annotations = shd_nodes(arena, shd_list_count(list), shd_read_list(const Node*, list)); + shd_destroy_list(list); + return annotations; +} + +static const Node* accept_const(ctxparams, Nodes annotations) { + if (!accept_token(ctx, const_tok)) + return NULL; + + const Type* type = accept_unqualified_type(ctx); + const char* id = accept_identifier(ctx); + expect(id, "constant name"); + expect(accept_token(ctx, equal_tok), "'='"); + BodyBuilder* bb = shd_bld_begin_pure(arena); + const Node* definition = accept_expr(ctx, bb, max_precedence()); + expect(definition, "expression"); + + expect(accept_token(ctx, semi_tok), "';'"); + + Node* cnst = constant(mod, annotations, type, id); + cnst->payload.constant.value = shd_bld_to_instr_pure_with_values(bb, shd_singleton(definition)); + return cnst; +} + +static const Node* make_return_void(const Node* mem) { + IrArena* a = mem->arena; + return fn_ret(a, (Return) { .args = shd_empty(a), .mem = mem }); +} + +static const Node* accept_fn_decl(ctxparams, Nodes annotations) { + if 
(!accept_token(ctx, fn_tok)) + return NULL; + + const char* name = accept_identifier(ctx); + expect(name, "function name"); + Nodes types = accept_types(ctx, comma_tok, MaybeQualified); + expect(shd_curr_token(tokenizer).tag == lpar_tok, "')'"); + Nodes parameters; + expect_parameters(ctx, &parameters, NULL, NULL); + + Node* fn = function(mod, parameters, name, annotations, types); + if (!accept_token(ctx, semi_tok)) + shd_set_abstraction_body(fn, expect_body(ctx, shd_get_abstraction_mem(fn), types.count == 0 ? make_return_void : NULL)); + + return fn; +} + +static const Node* accept_global_var_decl(ctxparams, Nodes annotations) { + if (!accept_token(ctx, var_tok)) + return NULL; + + AddressSpace as = NumAddressSpaces; + bool uniform = false, logical = false; + while (true) { + if (accept_token(ctx, logical_tok)) { + logical = true; + continue; + } + if (accept_token(ctx, uniform_tok)) { + uniform = true; + continue; + } + AddressSpace nas = accept_address_space(ctx); + if (nas != NumAddressSpaces) { + if (as != NumAddressSpaces && as != nas) { + syntax_error_fmt("Conflicting address spaces for definition: %s and %s", shd_get_address_space_name(as), shd_get_address_space_name(nas)); + } + as = nas; + continue; + } + break; + } + + if (as == NumAddressSpaces) { + syntax_error("Address space required for global variable declaration."); + } + + if (uniform) { + if (as == AsInput) + as = AsUInput; + else { + syntax_error("'uniform' can only be used with 'input'"); + } + } + + if (logical) { + annotations = shd_nodes_append(arena, annotations, annotation(arena, (Annotation) { + .name = "Logical" + })); + } + + const Type* type = accept_unqualified_type(ctx); + expect(type, "global variable type"); + const char* id = accept_identifier(ctx); + expect(id, "global variable name"); + + const Node* initial_value = NULL; + if (accept_token(ctx, equal_tok)) { + initial_value = accept_value(ctx, NULL); + expect_fmt(initial_value, "value for global variable '%s'", id); + } + + 
expect(accept_token(ctx, semi_tok), "';'"); + + Node* gv = global_var(mod, annotations, type, id, as); + gv->payload.global_variable.init = initial_value; + return gv; +} + +static const Node* accept_nominal_type_decl(ctxparams, Nodes annotations) { + if (!accept_token(ctx, type_tok)) + return NULL; + + const char* id = accept_identifier(ctx); + expect(id, "nominal type name"); + + expect(accept_token(ctx, equal_tok), "'='"); + + Node* nom = nominal_type(mod, annotations, id); + nom->payload.nom_type.body = accept_unqualified_type(ctx); + expect(nom->payload.nom_type.body, "nominal type body"); + + expect(accept_token(ctx, semi_tok), "';'"); + return nom; +} + +void slim_parse_string(const SlimParserConfig* config, const char* contents, Module* mod) { + IrArena* arena = shd_module_get_arena(mod); + Tokenizer* tokenizer = shd_new_tokenizer(contents); + + while (true) { + Token token = shd_curr_token(tokenizer); + if (token.tag == EOF_tok) + break; + + Nodes annotations = accept_annotations(ctx); + + const Node* decl = accept_const(ctx, annotations); + if (!decl) decl = accept_fn_decl(ctx, annotations); + if (!decl) decl = accept_global_var_decl(ctx, annotations); + if (!decl) decl = accept_nominal_type_decl(ctx, annotations); + + if (decl) { + shd_log_fmt(DEBUGVV, "decl parsed : "); + shd_log_node(DEBUGVV, decl); + shd_log_fmt(DEBUGVV, "\n"); + continue; + } + + syntax_error("expected a declaration"); + } + + shd_destroy_tokenizer(tokenizer); +} + diff --git a/src/frontends/slim/parser.h b/src/frontend/slim/parser.h similarity index 84% rename from src/frontends/slim/parser.h rename to src/frontend/slim/parser.h index ac2f21150..7e2e37884 100644 --- a/src/frontends/slim/parser.h +++ b/src/frontend/slim/parser.h @@ -1,13 +1,9 @@ #ifndef SHADY_PARSER_H - #define SHADY_PARSER_H +#include "shady/fe/slim.h" #include "shady/ir.h" -typedef struct { - bool front_end; -} ParserConfig; - #define INFIX_OPERATORS() \ INFIX_OPERATOR(Mul, star_tok, mul_op, 1) \ 
INFIX_OPERATOR(Sub, minus_tok, sub_op, 2) \ @@ -26,8 +22,8 @@ INFIX_OPERATOR(Gt, infix_gt_tok, gt_op, 7) \ INFIX_OPERATOR(Ge, infix_geq_tok, gte_op, 7) \ INFIX_OPERATOR(Lt, infix_ls_tok, lt_op, 7) \ INFIX_OPERATOR(Le, infix_leq_tok, lte_op, 7) \ -INFIX_OPERATOR(Sbs, pound_tok, subscript_op, 0) \ -INFIX_OPERATOR(Ass, equal_tok, assign_op, 10) \ +INFIX_OPERATOR(Sbs, pound_tok, -1, 0) \ +INFIX_OPERATOR(Ass, equal_tok, -1, 10) \ typedef enum { #define INFIX_OPERATOR(name, token, primop, precedence) Infix##name, @@ -36,6 +32,4 @@ INFIX_OPERATORS() InfixOperatorsCount } InfixOperators; -void parse_shady_ir(ParserConfig config, const char* contents, Module* mod); - #endif diff --git a/src/frontend/slim/slim_driver.c b/src/frontend/slim/slim_driver.c new file mode 100644 index 000000000..95450e4be --- /dev/null +++ b/src/frontend/slim/slim_driver.c @@ -0,0 +1,46 @@ +#include "parser.h" + +#include "shady/pass.h" + +#include "../shady/transform/internal_constants.h" +#include "../shady/passes/passes.h" + +#include "log.h" + +/// Removes all Unresolved nodes and replaces them with the appropriate decl/value +RewritePass slim_pass_bind; +/// Enforces the grammar, notably by let-binding any intermediary result +RewritePass slim_pass_normalize; +/// Makes sure every node is well-typed +RewritePass slim_pass_infer; + +void slim_parse_string(const SlimParserConfig* config, const char* contents, Module* mod); + +Module* shd_parse_slim_module(const CompilerConfig* config, const SlimParserConfig* pconfig, const char* contents, String name) { + ArenaConfig aconfig = shd_default_arena_config(&config->target); + aconfig.name_bound = false; + aconfig.check_op_classes = false; + aconfig.check_types = false; + aconfig.validate_builtin_types = false; + aconfig.allow_fold = false; + IrArena* initial_arena = shd_new_ir_arena(&aconfig); + Module* m = shd_new_module(initial_arena, name); + slim_parse_string(pconfig, contents, m); + Module** pmod = &m; + Module* old_mod = NULL; + + 
shd_debugv_print("Parsed slim module:\n"); + shd_log_module(DEBUGV, config, *pmod); + + shd_generate_dummy_constants(config, *pmod); + + RUN_PASS(slim_pass_bind) + RUN_PASS(slim_pass_normalize) + + RUN_PASS(shd_pass_normalize_builtins) + RUN_PASS(slim_pass_infer) + RUN_PASS(shd_pass_lower_cf_instrs) + + shd_destroy_ir_arena(initial_arena); + return *pmod; +} diff --git a/src/frontends/slim/token.c b/src/frontend/slim/token.c similarity index 82% rename from src/frontends/slim/token.c rename to src/frontend/slim/token.c index a20c78e47..aa57e58dc 100644 --- a/src/frontends/slim/token.c +++ b/src/frontend/slim/token.c @@ -15,7 +15,7 @@ static const char* token_strings[] = { #undef TOKEN }; -const char* token_tags[] = { +static const char* token_tags[] = { #define TOKEN(name, str) #name, TOKENS() #undef TOKEN @@ -37,10 +37,12 @@ typedef struct Tokenizer_ { const size_t source_size; size_t pos; + size_t line; + size_t last_line_pos; Token current; } Tokenizer; -Tokenizer* new_tokenizer(const char* source) { +Tokenizer* shd_new_tokenizer(const char* source) { if (!constants_initialized) { init_tokenizer_constants(); constants_initialized = true; @@ -50,14 +52,15 @@ Tokenizer* new_tokenizer(const char* source) { Tokenizer tokenizer = (Tokenizer) { .source = source, .source_size = strlen(source), - .pos = 0 + .pos = 0, + .line = 1, }; memcpy(alloc, &tokenizer, sizeof(Tokenizer)); - next_token(alloc); + shd_next_token(alloc); return alloc; } -void destroy_tokenizer(Tokenizer* tokenizer) { +void shd_destroy_tokenizer(Tokenizer* tokenizer) { free(tokenizer); } @@ -65,7 +68,7 @@ static bool in_bounds(Tokenizer* tokenizer, size_t offset_to_slice) { return (tokenizer->pos + offset_to_slice) <= tokenizer->source_size; } -const char whitespace[] = { ' ', '\t', '\n', '\r' }; +static const char whitespace[] = { ' ', '\t', '\r' }; static inline bool is_alpha(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); } static inline bool is_digit(char c) { return c >= '0' && 
c <= '9'; } @@ -75,7 +78,11 @@ static inline bool can_make_up_identifier(char c) { return can_start_identifier( static void eat_whitespace_and_comments(Tokenizer* tokenizer) { while (tokenizer->pos < tokenizer->source_size) { - if (is_whitespace(tokenizer->source[tokenizer->pos])) { + if (tokenizer->source[tokenizer->pos] == '\n') { + tokenizer->line++; + tokenizer->pos++; + tokenizer->last_line_pos = tokenizer->pos; + } else if (is_whitespace(tokenizer->source[tokenizer->pos])) { tokenizer->pos++; } else if (tokenizer->pos + 2 <= tokenizer->source_size && tokenizer->source[tokenizer->pos] == '/' && tokenizer->source[tokenizer->pos + 1] == '/') { while (tokenizer->pos < tokenizer->source_size) { @@ -97,10 +104,10 @@ static void eat_whitespace_and_comments(Tokenizer* tokenizer) { } } -Token next_token(Tokenizer* tokenizer) { +Token shd_next_token(Tokenizer* tokenizer) { eat_whitespace_and_comments(tokenizer); if (tokenizer->pos == tokenizer->source_size) { - debugvv_print("EOF\n"); + shd_debugvv_print("EOF\n"); Token token = { .tag = EOF_tok }; @@ -185,7 +192,7 @@ Token next_token(Tokenizer* tokenizer) { goto parsed_successfully; } - error_print("We don't know how to tokenize %.16s...\n", slice); + shd_error_print("We don't know how to tokenize %.16s...\n", slice); exit(-2); parsed_successfully: @@ -195,16 +202,23 @@ Token next_token(Tokenizer* tokenizer) { token.end = token.start + token_size; tokenizer->current = token; - debugvv_print("Token parsed: (tag = %s, pos = %zu", token_tags[token.tag], token.start); - if (token.tag == identifier_tok || token.tag == string_lit_tok) { - debugvv_print(", str="); - for (size_t i = token.start; i < token.end; i++) - debugvv_print("%c", tokenizer->source[i]); - } - debugvv_print(")\n"); + // debugvv_print("Token parsed: (tag = %s, pos = %zu", token_tags[token.tag], token.start); + // if (token.tag == identifier_tok || token.tag == string_lit_tok) { + // debugvv_print(", str="); + // for (size_t i = token.start; i < token.end; 
i++) + // debugvv_print("%c", tokenizer->source[i]); + // } + // debugvv_print(")\n"); return token; } -Token curr_token(Tokenizer* tokenizer) { +Loc shd_current_loc(Tokenizer* tokenizer) { + return (Loc) { + .line = tokenizer->line, + .column = tokenizer->pos - tokenizer->last_line_pos + }; +} + +Token shd_curr_token(Tokenizer* tokenizer) { return tokenizer->current; } diff --git a/src/frontends/slim/token.h b/src/frontend/slim/token.h similarity index 86% rename from src/frontends/slim/token.h rename to src/frontend/slim/token.h index 860ce03b7..b53d9e348 100644 --- a/src/frontends/slim/token.h +++ b/src/frontend/slim/token.h @@ -26,11 +26,11 @@ TEXT_TOKEN(input) \ TEXT_TOKEN(output) \ TOKEN(extern, "extern") \ TEXT_TOKEN(generic) \ +TEXT_TOKEN(logical) \ TEXT_TOKEN(var) \ TEXT_TOKEN(val) \ -TEXT_TOKEN(let) \ -TEXT_TOKEN(in) \ TEXT_TOKEN(ptr) \ +TEXT_TOKEN(ref) \ TEXT_TOKEN(type) \ TEXT_TOKEN(fn) \ TEXT_TOKEN(cont) \ @@ -56,7 +56,7 @@ TOKEN(false, "false") \ TOKEN(if, "if") \ TOKEN(else, "else") \ TEXT_TOKEN(control) \ -TEXT_TOKEN(yield) \ +TEXT_TOKEN(merge_selection) \ TEXT_TOKEN(loop) \ TOKEN(continue, "continue") \ TOKEN(break, "break") \ @@ -67,6 +67,7 @@ TOKEN(case, "case") \ TOKEN(default, "default") \ TEXT_TOKEN(join) \ TEXT_TOKEN(call) \ +TEXT_TOKEN(tailcall) \ TOKEN(return, "return") \ TEXT_TOKEN(unreachable) \ TOKEN(infix_rshift_logical, ">>>") \ @@ -103,8 +104,8 @@ TOKEN(equal, "=") \ TOKEN(LIST_END, NULL) typedef struct Tokenizer_ Tokenizer; -Tokenizer* new_tokenizer(const char* source); -void destroy_tokenizer(Tokenizer*); +Tokenizer* shd_new_tokenizer(const char* source); +void shd_destroy_tokenizer(Tokenizer* tokenizer); typedef enum { #define TOKEN(name, str) name##_tok, @@ -112,16 +113,20 @@ typedef enum { #undef TOKEN } TokenTag; -extern const char* token_tags[]; - typedef struct { TokenTag tag; size_t start; size_t end; } Token; -Token curr_token(Tokenizer* tokenizer); -Token next_token(Tokenizer* tokenizer); +typedef struct { + size_t line, 
column; +} Loc; + +Loc shd_current_loc(Tokenizer* tokenizer); + +Token shd_curr_token(Tokenizer* tokenizer); +Token shd_next_token(Tokenizer* tokenizer); #define SHADY_TOKEN_H diff --git a/src/frontend/spirv/CMakeLists.txt b/src/frontend/spirv/CMakeLists.txt new file mode 100644 index 000000000..7f905245d --- /dev/null +++ b/src/frontend/spirv/CMakeLists.txt @@ -0,0 +1,6 @@ +add_library(shady_s2s STATIC s2s.c) +target_link_libraries(shady_s2s PRIVATE api common) +target_link_libraries(shady_s2s PRIVATE "$") + +target_compile_definitions(driver PUBLIC SPV_PARSER_PRESENT) +target_link_libraries(driver PUBLIC "$") diff --git a/src/frontends/spirv/s2s.c b/src/frontend/spirv/s2s.c similarity index 73% rename from src/frontends/spirv/s2s.c rename to src/frontend/spirv/s2s.c index 978e25f5d..f0618169e 100644 --- a/src/frontends/spirv/s2s.c +++ b/src/frontend/spirv/s2s.c @@ -1,5 +1,15 @@ #include "s2s.h" -#include "shady/builtins.h" + +#include "shady/ir/builtin.h" +#include "shady/ir/memory_layout.h" + +#include "../shady/ir_private.h" + +#include "log.h" +#include "arena.h" +#include "portability.h" +#include "dict.h" +#include "util.h" // this avoids polluting the namespace #define SpvHasResultAndType ShadySpvHasResultAndType @@ -13,7 +23,9 @@ static void SpvHasResultAndType(SpvOp opcode, bool *hasResult, bool *hasResultTy #include "spirv/unified1/OpenCL.std.h" #include "spirv/unified1/GLSL.std.450.h" -extern SpvBuiltIn spv_builtins[]; +#include +#include +#include // TODO: reserve real decoration IDs typedef enum { @@ -23,20 +35,6 @@ typedef enum { ShdDecorationEntryPointName = 999996, } ShdDecoration; -#include "log.h" -#include "arena.h" -#include "portability.h" -#include "dict.h" -#include "util.h" - -#include "../shady/type.h" -#include "../shady/ir_private.h" -#include "../shady/transform/ir_gen_helpers.h" - -#include -#include -#include - typedef struct { struct { uint8_t major, minor; @@ -97,26 +95,26 @@ typedef struct { struct Dict* phi_arguments; } SpvParser; 
-SpvDef* get_definition_by_id(SpvParser* parser, size_t id); +static SpvDef* get_definition_by_id(SpvParser* parser, size_t id); -SpvDef* new_def(SpvParser* parser) { - SpvDef* interned = arena_alloc(parser->decorations_arena, sizeof(SpvDef)); +static SpvDef* new_def(SpvParser* parser) { + SpvDef* interned = shd_arena_alloc(parser->decorations_arena, sizeof(SpvDef)); SpvDef empty = {0}; memcpy(interned, &empty, sizeof(SpvDef)); return interned; } -void add_decoration(SpvParser* parser, SpvId id, SpvDeco decoration) { +static void add_decoration(SpvParser* parser, SpvId id, SpvDeco decoration) { SpvDef* tgt_def = &parser->defs[id]; while (tgt_def->next_decoration) { tgt_def = &tgt_def->next_decoration->payload; } - SpvDeco* interned = arena_alloc(parser->decorations_arena, sizeof(SpvDeco)); + SpvDeco* interned = shd_arena_alloc(parser->decorations_arena, sizeof(SpvDeco)); memcpy(interned, &decoration, sizeof(SpvDeco)); tgt_def->next_decoration = interned; } -SpvDeco* find_decoration(SpvParser* parser, SpvId id, int member, SpvDecoration tag) { +static SpvDeco* find_decoration(SpvParser* parser, SpvId id, int member, SpvDecoration tag) { SpvDef* tgt_def = &parser->defs[id]; while (tgt_def->next_decoration) { if (tgt_def->next_decoration->decoration == tag && (member < 0 || tgt_def->next_decoration->member == member)) @@ -126,21 +124,21 @@ SpvDeco* find_decoration(SpvParser* parser, SpvId id, int member, SpvDecoration return NULL; } -String get_name(SpvParser* parser, SpvId id) { +static String get_name(SpvParser* parser, SpvId id) { SpvDeco* deco = find_decoration(parser, id, -1, ShdDecorationName); if (!deco) return NULL; return deco->payload.str; } -String get_member_name(SpvParser* parser, SpvId id, int member_id) { +static String get_member_name(SpvParser* parser, SpvId id, int member_id) { SpvDeco* deco = find_decoration(parser, id, member_id, ShdDecorationName); if (!deco) return NULL; return deco->payload.str; } -const Type* get_def_type(SpvParser* parser, 
SpvId id) { +static const Type* get_def_type(SpvParser* parser, SpvId id) { SpvDef* def = get_definition_by_id(parser, id); assert(def->type == Typ); const Node* t = def->node; @@ -148,7 +146,7 @@ const Type* get_def_type(SpvParser* parser, SpvId id) { return t; } -const Type* get_def_decl(SpvParser* parser, SpvId id) { +static const Type* get_def_decl(SpvParser* parser, SpvId id) { SpvDef* def = get_definition_by_id(parser, id); assert(def->type == Decl); const Node* n = def->node; @@ -156,13 +154,13 @@ const Type* get_def_decl(SpvParser* parser, SpvId id) { return n; } -String get_def_string(SpvParser* parser, SpvId id) { +static String get_def_string(SpvParser* parser, SpvId id) { SpvDef* def = get_definition_by_id(parser, id); assert(def->type == Str); return def->str; } -const Type* get_def_ssa_value(SpvParser* parser, SpvId id) { +static const Type* get_def_ssa_value(SpvParser* parser, SpvId id) { SpvDef* def = get_definition_by_id(parser, id); const Node* n = def->node; if (is_declaration(n)) @@ -171,7 +169,7 @@ const Type* get_def_ssa_value(SpvParser* parser, SpvId id) { return n; } -const Type* get_def_block(SpvParser* parser, SpvId id) { +static const Type* get_def_block(SpvParser* parser, SpvId id) { SpvDef* def = get_definition_by_id(parser, id); assert(def->type == BB); const Node* n = def->node; @@ -179,7 +177,7 @@ const Type* get_def_block(SpvParser* parser, SpvId id) { return n; } -bool parse_spv_header(SpvParser* parser) { +static bool parse_spv_header(SpvParser* parser) { assert(parser->cursor == 0); assert(parser->len >= 4); assert(parser->words[0] == SpvMagicNumber); @@ -194,42 +192,45 @@ bool parse_spv_header(SpvParser* parser) { return true; } -String decode_spv_string_literal(SpvParser* parser, uint32_t* at) { +static String decode_spv_string_literal(SpvParser* parser, uint32_t* at) { // TODO: assumes little endian - return string(get_module_arena(parser->mod), (const char*) at); + return shd_string(shd_module_get_arena(parser->mod), (const 
char*) at); } -AddressSpace convert_storage_class(SpvStorageClass class) { +static AddressSpace convert_storage_class(SpvStorageClass class) { switch (class) { case SpvStorageClassInput: return AsInput; case SpvStorageClassOutput: return AsOutput; - case SpvStorageClassWorkgroup: return AsSharedPhysical; - case SpvStorageClassCrossWorkgroup: return AsGlobalPhysical; - case SpvStorageClassPhysicalStorageBuffer: return AsGlobalPhysical; - case SpvStorageClassPrivate: return AsPrivatePhysical; - case SpvStorageClassFunction: return AsPrivatePhysical; + case SpvStorageClassWorkgroup: return AsShared; + case SpvStorageClassCrossWorkgroup: return AsGlobal; + case SpvStorageClassPhysicalStorageBuffer: return AsGlobal; + case SpvStorageClassPrivate: return AsPrivate; + case SpvStorageClassFunction: return AsPrivate; case SpvStorageClassGeneric: return AsGeneric; case SpvStorageClassPushConstant: return AsPushConstant; case SpvStorageClassAtomicCounter: - error("TODO"); + break; case SpvStorageClassImage: return AsImage; - error("TODO"); - case SpvStorageClassStorageBuffer: return AsGlobalLogical; + break; + case SpvStorageClassStorageBuffer: return AsShaderStorageBufferObject; case SpvStorageClassUniformConstant: - case SpvStorageClassUniform: return AsGlobalPhysical; // TODO: should probably depend on CL/VK flavours! + case SpvStorageClassUniform: return AsGlobal; // TODO: should probably depend on CL/VK flavours! 
case SpvStorageClassCallableDataKHR: case SpvStorageClassIncomingCallableDataKHR: case SpvStorageClassRayPayloadKHR: case SpvStorageClassHitAttributeKHR: case SpvStorageClassIncomingRayPayloadKHR: case SpvStorageClassShaderRecordBufferKHR: - error("Unsupported"); + break; case SpvStorageClassCodeSectionINTEL: case SpvStorageClassDeviceOnlyINTEL: case SpvStorageClassHostOnlyINTEL: case SpvStorageClassMax: - error("Unsupported"); + break; + default: + break; } + shd_error("s2s: Unsupported storage class: %d\n", class); } typedef struct { @@ -298,7 +299,7 @@ static SpvShdOpMapping spv_shd_op_mapping[] = { [SpvOpOrdered] = { 0 }, [SpvOpUnordered] = { 0 }, [SpvOpLogicalEqual] = { 1, eq_op, 3 }, - [SpvOpLogicalNotEqual] = { 1, eq_op, 3 }, + [SpvOpLogicalNotEqual] = { 1, neq_op, 3 }, [SpvOpLogicalOr] = { 1, or_op, 3 }, [SpvOpLogicalAnd] = { 1, and_op, 3 }, [SpvOpLogicalNot] = { 1, not_op, 3 }, @@ -329,7 +330,7 @@ static SpvShdOpMapping spv_shd_op_mapping[] = { // honestly none of those are implemented ... 
}; -const SpvShdOpMapping* convert_spv_op(SpvOp src) { +static const SpvShdOpMapping* convert_spv_op(SpvOp src) { const int nentries = sizeof(spv_shd_op_mapping) / sizeof(*spv_shd_op_mapping); if (src >= nentries) return NULL; @@ -338,7 +339,7 @@ const SpvShdOpMapping* convert_spv_op(SpvOp src) { return NULL; } -SpvId get_result_defined_at(SpvParser* parser, size_t instruction_offset) { +static SpvId get_result_defined_at(SpvParser* parser, size_t instruction_offset) { uint32_t* instruction = parser->words + instruction_offset; SpvOp op = instruction[0] & 0xFFFF; @@ -352,10 +353,10 @@ SpvId get_result_defined_at(SpvParser* parser, size_t instruction_offset) { result = instruction[1]; return result; } - error("no result defined at offset %zu", instruction_offset); + shd_error("no result defined at offset %zu", instruction_offset); } -void scan_definitions(SpvParser* parser) { +static void scan_definitions(SpvParser* parser) { size_t old_cursor = parser->cursor; while (true) { size_t available = parser->len - parser->cursor; @@ -383,7 +384,7 @@ void scan_definitions(SpvParser* parser) { parser->cursor = old_cursor; } -Nodes get_args_from_phi(SpvParser* parser, SpvId block, SpvId predecessor) { +static Nodes get_args_from_phi(SpvParser* parser, SpvId block, SpvId predecessor) { SpvDef* block_def = get_definition_by_id(parser, block); assert(block_def->type == BB && block_def->node); int params_count = block_def->node->payload.basic_block.params.count; @@ -393,9 +394,9 @@ Nodes get_args_from_phi(SpvParser* parser, SpvId block, SpvId predecessor) { params[i] = NULL; if (params_count == 0) - return empty(parser->arena); + return shd_empty(parser->arena); - SpvPhiArgs** found = find_value_dict(SpvId, SpvPhiArgs*, parser->phi_arguments, block); + SpvPhiArgs** found = shd_dict_find_value(SpvId, SpvPhiArgs*, parser->phi_arguments, block); assert(found); SpvPhiArgs* arg = *found; while (true) { @@ -412,10 +413,11 @@ Nodes get_args_from_phi(SpvParser* parser, SpvId block, 
SpvId predecessor) { for (size_t i = 0; i < params_count; i++) assert(params[i]); - return nodes(parser->arena, params_count, params); + return shd_nodes(parser->arena, params_count, params); } -size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { +static size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { + IrArena* a = parser->arena; uint32_t* instruction = parser->words + instruction_offset; SpvOp op = instruction[0] & 0xFFFF; int size = (int) ((instruction[0] >> 16u) & 0xFFFFu); @@ -454,14 +456,14 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { for (size_t i = 0; i < num_ops; i++) ops[i] = get_def_ssa_value(parser, instruction[shd_op.ops_offset + i]); int results_count = has_result ? 1 : 0; - Nodes results = bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { + Nodes results = shd_bld_add_instruction_extract_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { .op = shd_op.op, - .type_arguments = empty(parser->arena), - .operands = nodes(parser->arena, num_ops, ops) - }), results_count, NULL, false); + .type_arguments = shd_empty(parser->arena), + .operands = shd_nodes(parser->arena, num_ops, ops) + }), results_count); if (has_result) { parser->defs[result].type = Value; - parser->defs[result].node = first(results); + parser->defs[result].node = shd_first(results); } return size; } @@ -514,7 +516,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { type = "Vertex"; break; default: - error("Unsupported execution model %d", instruction[1]) + shd_error("Unsupported execution model %d", instruction[1]) } add_decoration(parser, instruction[2], (SpvDeco) { .decoration = ShdDecorationEntryPointType, @@ -569,7 +571,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case 16: w = IntTy16; break; case 32: w = IntTy32; break; case 64: w = IntTy64; break; - 
default: error("unhandled int width"); + default: shd_error("unhandled int width"); } parser->defs[result].type = Typ; parser->defs[result].node = int_type(parser->arena, (Int) { @@ -585,7 +587,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case 16: w = FloatTy16; break; case 32: w = FloatTy32; break; case 64: w = FloatTy64; break; - default: error("unhandled float width"); + default: shd_error("unhandled float width"); } parser->defs[result].type = Typ; parser->defs[result].node = float_type(parser->arena, (Float) { @@ -615,16 +617,16 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { for (size_t i = 0; i < size - 3; i++) param_ts[i] = get_def_type(parser, instruction[3 + i]); parser->defs[result].node = fn_type(parser->arena, (FnType) { - .return_types = (return_t == unit_type(parser->arena)) ? empty(parser->arena) : singleton(return_t), - .param_types = nodes(parser->arena, size - 3, param_ts) + .return_types = (return_t == unit_type(parser->arena)) ? shd_empty(parser->arena) : shd_singleton(return_t), + .param_types = shd_nodes(parser->arena, size - 3, param_ts) }); break; } case SpvOpTypeStruct: { parser->defs[result].type = Typ; String name = get_name(parser, result); - name = name ? name : unique_name(parser->arena, "struct_type"); - Node* nominal_type_decl = nominal_type(parser->mod, empty(parser->arena), name); + name = name ? 
name : shd_make_unique_name(parser->arena, "struct_type"); + Node* nominal_type_decl = nominal_type(parser->mod, shd_empty(parser->arena), name); const Node* nom_t_ref = type_decl_ref(parser->arena, (TypeDeclRef) { .decl = nominal_type_decl }); @@ -635,12 +637,12 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { for (size_t i = 0; i < members_count; i++) { member_names[i] = get_member_name(parser, result, i); if (!member_names[i]) - member_names[i] = format_string_arena(parser->arena->arena, "member%d", i); + member_names[i] = shd_format_string_arena(parser->arena->arena, "member%d", i); member_tys[i] = get_def_type(parser, instruction[2 + i]); } nominal_type_decl->payload.nom_type.body = record_type(parser->arena, (RecordType) { - .members = nodes(parser->arena, members_count, member_tys), - .names = strings(parser->arena, members_count, member_names), + .members = shd_nodes(parser->arena, members_count, member_tys), + .names = shd_strings(parser->arena, members_count, member_names), }); break; } @@ -669,7 +671,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case SpvOpConstant: { parser->defs[result].type = Value; const Type* t = get_def_type(parser, result_t); - int width = get_type_bitwidth(t); + int width = shd_get_type_bitwidth(t); switch (is_type(t)) { case Int_TAG: { uint64_t v; @@ -696,7 +698,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { }); break; } - default: error("OpConstant must produce an int or a float"); + default: shd_error("OpConstant must produce an int or a float"); } break; } @@ -704,7 +706,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case SpvOpConstantNull: { const Type* element_t = get_def_type(parser, result_t); parser->defs[result].type = Value; - parser->defs[result].node = get_default_zero_value(parser->arena, element_t); + parser->defs[result].node = shd_get_default_value(parser->arena, 
element_t); break; } case SpvOpConstantFalse: { @@ -724,42 +726,36 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { LARRAY(const Node*, contents, size - 3); for (size_t i = 0; i < size - 3; i++) contents[i] = get_def_ssa_value(parser, instruction[3 + i]); - parser->defs[result].node = composite_helper(parser->arena, t, nodes(parser->arena, size - 3, contents)); + parser->defs[result].node = composite_helper(parser->arena, t, shd_nodes(parser->arena, size - 3, contents)); break; } case SpvOpVariable: { String name = get_name(parser, result); - name = name ? name : unique_name(parser->arena, "global_variable"); + name = name ? name : shd_make_unique_name(parser->arena, "global_variable"); AddressSpace as = convert_storage_class(instruction[3]); const Type* contents_t = get_def_type(parser, result_t); - AddressSpace as2 = deconstruct_pointer_type(&contents_t); + AddressSpace as2 = shd_deconstruct_pointer_type(&contents_t); assert(as == as2); - assert(is_data_type(contents_t)); + assert(shd_is_data_type(contents_t)); if (parser->fun) { - const Node* ptr = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = alloca_op, - .type_arguments = singleton(contents_t), - .operands = empty(parser->arena) - }), 1, NULL, false)); + const Node* ptr = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, stack_alloc(parser->arena, (StackAlloc) { .type = contents_t, .mem = shd_bb_mem(parser->current_block.builder) }), 1)); parser->defs[result].type = Value; parser->defs[result].node = ptr; if (size == 5) - bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = store_op, - .type_arguments = empty(parser->arena), - .operands = mk_nodes(parser->arena, ptr, get_def_ssa_value(parser, instruction[4])) - }), 1, NULL, false); + shd_bld_add_instruction_extract_count(parser->current_block.builder, store(parser->arena, (Store) { 
.ptr = ptr, .value = get_def_ssa_value(parser, instruction[4]), .mem = shd_bb_mem(parser->current_block.builder) }), 1); } else { - Nodes annotations = empty(parser->arena); + Nodes annotations = shd_empty(parser->arena); SpvDeco* builtin = find_decoration(parser, result, -1, SpvDecorationBuiltIn); if (builtin) { - Builtin b = get_builtin_by_spv_id(*builtin->payload.literals.data); + Builtin b = shd_get_builtin_by_spv_id(*builtin->payload.literals.data); assert(b != BuiltinsCount && "Unsupported builtin"); - annotations = append_nodes(parser->arena, annotations, annotation_value_helper(parser->arena, "Builtin", string_lit_helper(parser->arena, get_builtin_name(b)))); + annotations = shd_nodes_append(parser->arena, annotations, annotation_value_helper(parser->arena, "Builtin", string_lit_helper(parser->arena, + shd_get_builtin_name( + b)))); } parser->defs[result].type = Decl; @@ -779,18 +775,18 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { String name = get_name(parser, result); if (!name) - name = unique_name(parser->arena, "function"); + name = shd_make_unique_name(parser->arena, "function"); else - name = unique_name(parser->arena, name); + name = shd_make_unique_name(parser->arena, name); - Nodes annotations = empty(parser->arena); - annotations = append_nodes(parser->arena, annotations, annotation(parser->arena, (Annotation) { .name = "Restructure" })); + Nodes annotations = shd_empty(parser->arena); + annotations = shd_nodes_append(parser->arena, annotations, annotation(parser->arena, (Annotation) { .name = "Restructure" })); SpvDeco* entry_point_type = find_decoration(parser, result, -1, ShdDecorationEntryPointType); parser->is_entry_pt = entry_point_type; if (entry_point_type) { - annotations = append_nodes(parser->arena, annotations, annotation_value(parser->arena, (AnnotationValue) { + annotations = shd_nodes_append(parser->arena, annotations, annotation_value(parser->arena, (AnnotationValue) { .name = "EntryPoint", - 
.value = string_lit(parser->arena, (StringLiteral) { entry_point_type->payload.str }) + .value = string_lit(parser->arena, (StringLiteral) { .string = entry_point_type->payload.str }) })); SpvDeco* entry_point_name = find_decoration(parser, result, -1, ShdDecorationEntryPointName); @@ -800,19 +796,19 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { if (strcmp(entry_point_type->payload.str, "Compute") == 0) { SpvDeco* wg_size_dec = find_decoration(parser, result, -2, SpvExecutionModeLocalSize); assert(wg_size_dec && wg_size_dec->payload.literals.count == 3 && "we require kernels decorated with a workgroup size"); - annotations = append_nodes(parser->arena, annotations, annotation_values(parser->arena, (AnnotationValues) { - .name = "WorkgroupSize", - .values = mk_nodes(parser->arena, - int32_literal(parser->arena, wg_size_dec->payload.literals.data[0]), - int32_literal(parser->arena, wg_size_dec->payload.literals.data[1]), - int32_literal(parser->arena, wg_size_dec->payload.literals.data[2])) + annotations = shd_nodes_append(parser->arena, annotations, annotation_values(parser->arena, (AnnotationValues) { + .name = "WorkgroupSize", + .values = mk_nodes(parser->arena, + shd_int32_literal(parser->arena, wg_size_dec->payload.literals.data[0]), + shd_int32_literal(parser->arena, wg_size_dec->payload.literals.data[1]), + shd_int32_literal(parser->arena, wg_size_dec->payload.literals.data[2])) })); } else if (strcmp(entry_point_type->payload.str, "Fragment") == 0) { } else if (strcmp(entry_point_type->payload.str, "Vertex") == 0) { } else { - warn_print("Unknown entry point type '%s' for '%s'\n", entry_point_type->payload.str, name); + shd_warn_print("Unknown entry point type '%s' for '%s'\n", entry_point_type->payload.str, name); } } @@ -827,7 +823,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { instruction_offset += s; } - Node* fun = function(parser->mod, nodes(parser->arena, params_count, params), 
name, annotations, t->payload.fn_type.return_types); + Node* fun = function(parser->mod, shd_nodes(parser->arena, params_count, params), name, annotations, t->payload.fn_type.return_types); parser->defs[result].node = fun; Node* old_fun = parser->fun; parser->fun = fun; @@ -850,7 +846,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { // steal the body of the first block, it can't be jumped to anyways! if (first_block) - fun->payload.fun.body = first_block->payload.basic_block.body; + shd_set_abstraction_body(fun, first_block->payload.basic_block.body); parser->fun = old_fun; break; } @@ -860,15 +856,16 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case SpvOpFunctionParameter: { parser->defs[result].type = Value; String param_name = get_name(parser, result); - param_name = param_name ? param_name : format_string_arena(parser->arena->arena, "param%d", parser->fun_arg_i); - parser->defs[result].node = var(parser->arena, qualified_type_helper(get_def_type(parser, result_t), parser->is_entry_pt), param_name); + param_name = param_name ? 
param_name : shd_format_string_arena(parser->arena->arena, "param%d", parser->fun_arg_i); + parser->defs[result].node = param(parser->arena, + shd_as_qualified_type(get_def_type(parser, result_t), parser->is_entry_pt), param_name); break; } case SpvOpLabel: { struct CurrBlock old = parser->current_block; parser->current_block.id = result; - Nodes params = empty(parser->arena); + Nodes params = shd_empty(parser->arena); parser->fun_arg_i = 0; while (true) { SpvOp param_op = (parser->words + instruction_offset)[0] & 0xFFFF; @@ -887,8 +884,8 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { assert(s > 0); if (is_param) { const Node* param = get_definition_by_id(parser, get_result_defined_at(parser, instruction_offset))->node; - assert(param && param->tag == Variable_TAG); - params = concat_nodes(parser->arena, params, singleton(param)); + assert(param && param->tag == Param_TAG); + params = shd_concat_nodes(parser->arena, params, shd_singleton(param)); } size += s; instruction_offset += s; @@ -898,11 +895,11 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { parser->defs[result].type = BB; String bb_name = get_name(parser, result); - bb_name = bb_name ? bb_name : unique_name(parser->arena, "basic_block"); - Node* block = basic_block(parser->arena, parser->fun, params, bb_name); + bb_name = bb_name ? 
bb_name : shd_make_unique_name(parser->arena, "basic_block"); + Node* block = basic_block(parser->arena, params, bb_name); parser->defs[result].node = block; - BodyBuilder* bb = begin_body(parser->arena); + BodyBuilder* bb = shd_bld_begin(parser->arena, shd_get_abstraction_mem(block)); parser->current_block.builder = bb; parser->current_block.finished = NULL; while (parser->current_block.builder) { @@ -912,22 +909,23 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { instruction_offset += s; } assert(parser->current_block.finished); - block->payload.basic_block.body = parser->current_block.finished; + shd_set_abstraction_body(block, parser->current_block.finished); parser->current_block = old; break; } case SpvOpPhi: { parser->defs[result].type = Value; String phi_name = get_name(parser, result); - phi_name = phi_name ? phi_name : unique_name(parser->arena, "phi"); - parser->defs[result].node = var(parser->arena, qualified_type_helper(get_def_type(parser, result_t), false), phi_name); + phi_name = phi_name ? 
phi_name : shd_make_unique_name(parser->arena, "phi"); + parser->defs[result].node = param(parser->arena, + shd_as_qualified_type(get_def_type(parser, result_t), false), phi_name); assert(size % 2 == 1); int num_callsites = (size - 3) / 2; for (size_t i = 0; i < num_callsites; i++) { SpvId argument_value = instruction[3 + i * 2 + 0]; SpvId predecessor_block = instruction[3 + i * 2 + 1]; - SpvPhiArgs* new = arena_alloc(parser->decorations_arena, sizeof(SpvPhiArgs)); + SpvPhiArgs* new = shd_arena_alloc(parser->decorations_arena, sizeof(SpvPhiArgs)); *new = (SpvPhiArgs) { .predecessor = predecessor_block, .arg_i = parser->fun_arg_i, @@ -935,7 +933,7 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { .next_ = NULL, }; - SpvPhiArgs** found = find_value_dict(SpvId, SpvPhiArgs*, parser->phi_arguments, parser->current_block.id); + SpvPhiArgs** found = shd_dict_find_value(SpvId, SpvPhiArgs*, parser->phi_arguments, parser->current_block.id); if (found) { SpvPhiArgs* arg = *found; while (arg->next_) { @@ -943,10 +941,10 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { } arg->next_ = new; } else { - insert_dict(SpvId, SpvPhiArgs*, parser->phi_arguments, parser->current_block.id, new); + shd_dict_insert(SpvId, SpvPhiArgs*, parser->phi_arguments, parser->current_block.id, new); } - debugv_print("s2s: recorded argument %d (value id=%d) for block %d with predecessor %d\n", parser->fun_arg_i, argument_value, parser->current_block.id, predecessor_block); + shd_debugv_print("s2s: recorded argument %d (value id=%d) for block %d with predecessor %d\n", parser->fun_arg_i, argument_value, parser->current_block.id, predecessor_block); } parser->fun_arg_i++; break; @@ -967,11 +965,11 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { const Type* src = get_def_ssa_value(parser, instruction[3]); const Type* dst_t = get_def_type(parser, result_t); parser->defs[result].type = Value; - 
parser->defs[result].node = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = convert_op, - .type_arguments = singleton(dst_t), - .operands = singleton(src) - }), 1, NULL, false)); + parser->defs[result].node = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { + .op = convert_op, + .type_arguments = shd_singleton(dst_t), + .operands = shd_singleton(src) + }), 1)); break; } case SpvOpConvertPtrToU: @@ -980,11 +978,11 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { const Type* src = get_def_ssa_value(parser, instruction[3]); const Type* dst_t = get_def_type(parser, result_t); parser->defs[result].type = Value; - parser->defs[result].node = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = reinterpret_op, - .type_arguments = singleton(dst_t), - .operands = singleton(src) - }), 1, NULL, false)); + parser->defs[result].node = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { + .op = reinterpret_op, + .type_arguments = shd_singleton(dst_t), + .operands = shd_singleton(src) + }), 1)); break; } case SpvOpInBoundsPtrAccessChain: @@ -994,20 +992,17 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { bool has_element = op == SpvOpInBoundsPtrAccessChain || op == SpvOpPtrAccessChain; int indices_start = has_element ? 
5 : 4; int num_indices = size - indices_start; - LARRAY(const Node*, ops, 2 + num_indices); - ops[0] = get_def_ssa_value(parser, instruction[3]); + LARRAY(const Node*, indices, num_indices); + const Node* ptr = get_def_ssa_value(parser, instruction[3]); + const Node* offset = NULL; if (has_element) - ops[1] = get_def_ssa_value(parser, instruction[4]); + offset = get_def_ssa_value(parser, instruction[4]); else - ops[1] = int32_literal(parser->arena, 0); + offset = shd_int32_literal(parser->arena, 0); for (size_t i = 0; i < num_indices; i++) - ops[2 + i] = get_def_ssa_value(parser, instruction[indices_start + i]); + indices[i] = get_def_ssa_value(parser, instruction[indices_start + i]); parser->defs[result].type = Value; - parser->defs[result].node = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = lea_op, - .type_arguments = empty(parser->arena), - .operands = nodes(parser->arena, 2 + num_indices, ops) - }), 1, NULL, false)); + parser->defs[result].node = lea_helper(a, ptr, offset, shd_nodes(a, num_indices, indices)); break; } case SpvOpCompositeExtract: { @@ -1015,13 +1010,13 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { LARRAY(const Node*, ops, 1 + num_indices); ops[0] = get_def_ssa_value(parser, instruction[3]); for (size_t i = 0; i < num_indices; i++) - ops[1 + i] = int32_literal(parser->arena, instruction[4 + i]); + ops[1 + i] = shd_int32_literal(parser->arena, instruction[4 + i]); parser->defs[result].type = Value; - parser->defs[result].node = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = extract_op, - .type_arguments = empty(parser->arena), - .operands = nodes(parser->arena, 1 + num_indices, ops) - }), 1, NULL, false)); + parser->defs[result].node = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { + .op = extract_op, + .type_arguments = 
shd_empty(parser->arena), + .operands = shd_nodes(parser->arena, 1 + num_indices, ops) + }), 1)); break; } case SpvOpCompositeInsert: { @@ -1030,13 +1025,13 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { ops[0] = get_def_ssa_value(parser, instruction[4]); ops[1] = get_def_ssa_value(parser, instruction[3]); for (size_t i = 0; i < num_indices; i++) - ops[2 + i] = int32_literal(parser->arena, instruction[5 + i]); + ops[2 + i] = shd_int32_literal(parser->arena, instruction[5 + i]); parser->defs[result].type = Value; - parser->defs[result].node = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = insert_op, - .type_arguments = empty(parser->arena), - .operands = nodes(parser->arena, 2 + num_indices, ops) - }), 1, NULL, false)); + parser->defs[result].node = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { + .op = insert_op, + .type_arguments = shd_empty(parser->arena), + .operands = shd_nodes(parser->arena, 2 + num_indices, ops) + }), 1)); break; } case SpvOpVectorShuffle: { @@ -1057,38 +1052,30 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { index -= num_components_a; src = src_b; } - components[i] = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { + components[i] = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { .op = extract_op, - .type_arguments = empty(parser->arena), - .operands = mk_nodes(parser->arena, src, int32_literal(parser->arena, index)) - }), 1, NULL, false)); + .type_arguments = shd_empty(parser->arena), + .operands = mk_nodes(parser->arena, src, shd_int32_literal(parser->arena, index)) + }), 1)); } parser->defs[result].type = Value; parser->defs[result].node = composite_helper(parser->arena, pack_type(parser->arena, (PackType) { .element_type = 
src_a_t->payload.pack_type.element_type, .width = num_components, - }), nodes(parser->arena, num_components, components)); + }), shd_nodes(parser->arena, num_components, components)); break; } case SpvOpLoad: { const Type* src = get_def_ssa_value(parser, instruction[3]); parser->defs[result].type = Value; - parser->defs[result].node = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = load_op, - .type_arguments = empty(parser->arena), - .operands = singleton(src) - }), 1, NULL, false)); + parser->defs[result].node = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, load(a, (Load) { .ptr = src, .mem = shd_bb_mem(parser->current_block.builder) }), 1)); break; } case SpvOpStore: { const Type* ptr = get_def_ssa_value(parser, instruction[1]); const Type* value = get_def_ssa_value(parser, instruction[2]); - bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = store_op, - .type_arguments = empty(parser->arena), - .operands = mk_nodes(parser->arena, ptr, value) - }), 0, NULL, false); + shd_bld_add_instruction_extract_count(parser->current_block.builder, store(a, (Store) { .ptr = ptr, .value = value, .mem = shd_bb_mem(parser->current_block.builder) }), 0); break; } case SpvOpCopyMemory: @@ -1098,21 +1085,21 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { const Node* cnt; if (op == SpvOpCopyMemory) { const Type* elem_t = src->type; - deconstruct_qualified_type(&elem_t); - deconstruct_pointer_type(&elem_t); - cnt = first(bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { + shd_deconstruct_qualified_type(&elem_t); + shd_deconstruct_pointer_type(&elem_t); + cnt = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { .op = size_of_op, - .type_arguments = singleton(elem_t), - .operands = 
empty(parser->arena) - }), 1, NULL, false)); + .type_arguments = shd_singleton(elem_t), + .operands = shd_empty(parser->arena) + }), 1)); } else { cnt = get_def_ssa_value(parser, instruction[3]); } - bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { - .op = memcpy_op, - .type_arguments = empty(parser->arena), - .operands = mk_nodes(parser->arena, dst, src, cnt) - }), 0, NULL, false); + shd_bld_add_instruction_extract_count(parser->current_block.builder, copy_bytes(parser->arena, (CopyBytes) { + .src = src, + .dst = dst, + .count = cnt, + }), 0); break; } case SpvOpSelectionMerge: @@ -1133,8 +1120,8 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { if (callee->tag == Function_TAG) { const Node* fn = callee; //callee->payload.fn_addr.fn; - String fn_name = get_abstraction_name(fn); - if (string_starts_with(fn_name, "__shady")) { + String fn_name = shd_get_abstraction_name(fn); + if (shd_string_starts_with(fn_name, "__shady")) { char* copy = malloc(strlen(fn_name) + 1); memcpy(copy, fn_name, strlen(fn_name) + 1); strtok(copy, ":"); @@ -1143,32 +1130,32 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { char* primop = strtok(NULL, ":"); Op op = PRIMOPS_COUNT; for (size_t i = 0; i < PRIMOPS_COUNT; i++) { - if (strcmp(get_primop_name(i), primop) == 0) { + if (strcmp(shd_get_primop_name(i), primop) == 0) { op = i; break; } } assert(op != PRIMOPS_COUNT); //assert(false && intrinsic); - Nodes rslts = bind_instruction_outputs_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { + Nodes rslts = shd_bld_add_instruction_extract_count(parser->current_block.builder, prim_op(parser->arena, (PrimOp) { .op = op, - .type_arguments = empty(parser->arena), - .operands = nodes(parser->arena, num_args, args) - }), rslts_count, NULL, false); + .type_arguments = shd_empty(parser->arena), + .operands = shd_nodes(parser->arena, num_args, args) + }), rslts_count); 
if (rslts_count == 1) - parser->defs[result].node = first(rslts); + parser->defs[result].node = shd_first(rslts); break; } } - Nodes rslts = bind_instruction_outputs_count(parser->current_block.builder, call(parser->arena, (Call) { + Nodes rslts = shd_bld_add_instruction_extract_count(parser->current_block.builder, call(parser->arena, (Call) { .callee = fn_addr_helper(parser->arena, callee), - .args = nodes(parser->arena, num_args, args) - }), rslts_count, NULL, false); + .args = shd_nodes(parser->arena, num_args, args) + }), rslts_count); if (rslts_count == 1) - parser->defs[result].node = first(rslts); + parser->defs[result].node = shd_first(rslts); break; } @@ -1198,34 +1185,34 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case OpenCLstd_Floor: instr = prim_op(parser->arena, (PrimOp) { .op = floor_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case OpenCLstd_Sqrt: instr = prim_op(parser->arena, (PrimOp) { .op = sqrt_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case OpenCLstd_Fabs: instr = prim_op(parser->arena, (PrimOp) { .op = abs_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case OpenCLstd_Sin: instr = prim_op(parser->arena, (PrimOp) { .op = sin_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case OpenCLstd_Cos: instr = prim_op(parser->arena, (PrimOp) { .op = cos_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; - default: error("unhandled extended instruction %d in set '%s'", ext_instr, set); + default: shd_error("unhandled extended instruction %d in set '%s'", ext_instr, set); } } else if (strcmp(set, "GLSL.std.450") == 0) { switch (ext_instr) { @@ -1243,31 +1230,31 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case GLSLstd450Floor: instr = prim_op(parser->arena, (PrimOp) { .op = floor_op, - 
.operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case GLSLstd450Sqrt: instr = prim_op(parser->arena, (PrimOp) { .op = sqrt_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case GLSLstd450FAbs: instr = prim_op(parser->arena, (PrimOp) { .op = abs_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case GLSLstd450Sin: instr = prim_op(parser->arena, (PrimOp) { .op = sin_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case GLSLstd450Cos: instr = prim_op(parser->arena, (PrimOp) { .op = cos_op, - .operands = singleton(args[0]) + .operands = shd_singleton(args[0]) }); break; case GLSLstd450FMin: instr = prim_op(parser->arena, (PrimOp) { .op = min_op, .operands = mk_nodes(parser->arena, args[0], args[1]) }); break; @@ -1276,23 +1263,24 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case GLSLstd450FMax: instr = prim_op(parser->arena, (PrimOp) { .op = max_op, .operands = mk_nodes(parser->arena, args[0], args[1]) }); break; case GLSLstd450SMax: instr = prim_op(parser->arena, (PrimOp) { .op = max_op, .operands = mk_nodes(parser->arena, args[0], args[1]) }); break; case GLSLstd450UMax: instr = prim_op(parser->arena, (PrimOp) { .op = max_op, .operands = mk_nodes(parser->arena, args[0], args[1]) }); break; - case GLSLstd450Exp: instr = prim_op(parser->arena, (PrimOp) { .op = exp_op, .operands = singleton(args[0]) }); break; + case GLSLstd450Exp: instr = prim_op(parser->arena, (PrimOp) { .op = exp_op, .operands = shd_singleton(args[0]) }); break; case GLSLstd450Pow: instr = prim_op(parser->arena, (PrimOp) { .op = pow_op, .operands = mk_nodes(parser->arena, args[0], args[1]) }); break; - default: error("unhandled extended instruction %d in set '%s'", ext_instr, set); + default: shd_error("unhandled extended instruction %d in set '%s'", ext_instr, set); } } else { - error("Unknown extended instruction set '%s'", 
set); + shd_error("Unknown extended instruction set '%s'", set); } parser->defs[result].type = Value; - parser->defs[result].node = first(bind_instruction_outputs_count(parser->current_block.builder, instr, 1, NULL, false)); + parser->defs[result].node = shd_first(shd_bld_add_instruction_extract_count(parser->current_block.builder, instr, 1)); break; } case SpvOpBranch: { BodyBuilder* bb = parser->current_block.builder; - parser->current_block.finished = finish_body(bb, jump(parser->arena, (Jump) { - .target = get_def_block(parser, instruction[1]), - .args = get_args_from_phi(parser, instruction[1], parser->current_block.id), + parser->current_block.finished = shd_bld_finish(bb, jump(parser->arena, (Jump) { + .target = get_def_block(parser, instruction[1]), + .args = get_args_from_phi(parser, instruction[1], parser->current_block.id), + .mem = shd_bb_mem(bb) })); parser->current_block.builder = NULL; break; @@ -1300,10 +1288,12 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case SpvOpBranchConditional: { SpvId destinations[2] = { instruction[2], instruction[3] }; BodyBuilder* bb = parser->current_block.builder; - parser->current_block.finished = finish_body(bb, branch(parser->arena, (Branch) { - .true_jump = jump_helper(parser->arena, get_def_block(parser, destinations[0]), get_args_from_phi(parser, destinations[0], parser->current_block.id)), - .false_jump = jump_helper(parser->arena, get_def_block(parser, destinations[1]), get_args_from_phi(parser, destinations[1], parser->current_block.id)), - .branch_condition = get_def_ssa_value(parser, instruction[1]), + parser->current_block.finished = shd_bld_finish(bb, branch(parser->arena, (Branch) { + .true_jump = jump_helper(parser->arena, shd_bb_mem(bb), get_def_block(parser, destinations[0]), + get_args_from_phi(parser, destinations[0], parser->current_block.id)), + .false_jump = jump_helper(parser->arena, shd_bb_mem(bb), get_def_block(parser, destinations[1]), + 
get_args_from_phi(parser, destinations[1], parser->current_block.id)), + .condition = get_def_ssa_value(parser, instruction[1]), })); parser->current_block.builder = NULL; break; @@ -1312,18 +1302,17 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { case SpvOpReturnValue: { Nodes args; if (op == SpvOpReturn) - args = empty(parser->arena); + args = shd_empty(parser->arena); else - args = singleton(get_def_ssa_value(parser, instruction[1])); + args = shd_singleton(get_def_ssa_value(parser, instruction[1])); BodyBuilder* bb = parser->current_block.builder; - parser->current_block.finished = finish_body(bb, fn_ret(parser->arena, (Return) { - .fn = parser->fun, + parser->current_block.finished = shd_bld_finish(bb, fn_ret(parser->arena, (Return) { .args = args, })); parser->current_block.builder = NULL; break; } - default: error("Unsupported op: %d, size: %d", op, size); + default: shd_error("Unsupported op: %d, size: %d", op, size); } if (has_result) { @@ -1333,36 +1322,40 @@ size_t parse_spv_instruction_at(SpvParser* parser, size_t instruction_offset) { return size; } -SpvDef* get_definition_by_id(SpvParser* parser, size_t id) { +static SpvDef* get_definition_by_id(SpvParser* parser, size_t id) { assert(id > 0 && id < parser->header.bound); if (parser->defs[id].type == Nothing) - error("there is no Op that defines result %zu", id); + shd_error("there is no Op that defines result %zu", id); if (parser->defs[id].type == Forward) parse_spv_instruction_at(parser, parser->defs[id].instruction_offset); assert(parser->defs[id].type != Forward); return &parser->defs[id]; } -KeyHash hash_spvid(SpvId* p) { - return hash_murmur(p, sizeof(SpvId)); +static KeyHash hash_spvid(SpvId* p) { + return shd_hash(p, sizeof(SpvId)); } -bool compare_spvid(SpvId* pa, SpvId* pb) { +static bool compare_spvid(SpvId* pa, SpvId* pb) { if (pa == pb) return true; if (!pa || !pb) return false; return *pa == *pb; } -S2SError parse_spirv_into_shady(Module* dst, size_t 
len, const char* data) { +S2SError shd_parse_spirv(const CompilerConfig* config, size_t len, const char* data, String name, Module** dst) { + ArenaConfig aconfig = shd_default_arena_config(&config->target); + IrArena* a = shd_new_ir_arena(&aconfig); + *dst = shd_new_module(a, name); + SpvParser parser = { .cursor = 0, .len = len / sizeof(uint32_t), .words = (uint32_t*) data, - .mod = dst, - .arena = get_module_arena(dst), + .mod = *dst, + .arena = shd_module_get_arena(*dst), - .decorations_arena = new_arena(), - .phi_arguments = new_dict(SpvId, SpvPhiArgs*, (HashFn) hash_spvid, (CmpFn) compare_spvid), + .decorations_arena = shd_new_arena(), + .phi_arguments = shd_new_dict(SpvId, SpvPhiArgs*, (HashFn) hash_spvid, (CmpFn) compare_spvid), }; if (!parse_spv_header(&parser)) @@ -1376,8 +1369,8 @@ S2SError parse_spirv_into_shady(Module* dst, size_t len, const char* data) { parser.cursor += parse_spv_instruction_at(&parser, parser.cursor); } - destroy_dict(parser.phi_arguments); - destroy_arena(parser.decorations_arena); + shd_destroy_dict(parser.phi_arguments); + shd_destroy_arena(parser.decorations_arena); free(parser.defs); return S2S_Success; diff --git a/src/frontend/spirv/s2s.h b/src/frontend/spirv/s2s.h new file mode 100644 index 000000000..2f095b48a --- /dev/null +++ b/src/frontend/spirv/s2s.h @@ -0,0 +1,14 @@ +#ifndef SHADY_S2S +#define SHADY_S2S + +#include "shady/ir.h" + +typedef enum { + S2S_Success, + S2S_FailedParsingGeneric, +} S2SError; + +typedef struct CompilerConfig_ CompilerConfig; +S2SError shd_parse_spirv(const CompilerConfig* config, size_t len, const char* data, String name, Module** dst); + +#endif diff --git a/src/frontends/CMakeLists.txt b/src/frontends/CMakeLists.txt deleted file mode 100644 index 509f90bcd..000000000 --- a/src/frontends/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -add_subdirectory(slim) -add_subdirectory(spirv) - -find_package(LLVM) -if(LLVM_FOUND) - message("LLVM ${LLVM_VERSION} found") - add_subdirectory(llvm) -endif () 
diff --git a/src/frontends/llvm/CMakeLists.txt b/src/frontends/llvm/CMakeLists.txt deleted file mode 100644 index 0e46cabd2..000000000 --- a/src/frontends/llvm/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_generated_file(FILE_NAME l2s_generated.c SOURCES generator_l2s.c) - -add_library(shady_fe_llvm STATIC l2s.c l2s_type.c l2s_value.c l2s_instr.c l2s_meta.c l2s_postprocess.c l2s_annotations.c ${CMAKE_CURRENT_BINARY_DIR}/l2s_generated.c) - -target_include_directories(shady_fe_llvm PRIVATE ${LLVM_INCLUDE_DIRS}) -target_include_directories(shady_fe_llvm PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) # for l2s_generated.c -separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) -add_definitions(${LLVM_DEFINITIONS_LIST}) -target_compile_definitions(shady_fe_llvm PRIVATE "LLVM_VERSION_MAJOR=${LLVM_VERSION_MAJOR}") - -if (TARGET LLVM) - message("LLVM shared library target exists, major version = ${LLVM_VERSION_MAJOR}") - target_link_libraries(shady_fe_llvm PRIVATE LLVM) -else () - message(FATAL_ERROR "Failed to find LLVM target, but found LLVM module earlier") -endif() - -target_link_libraries(shady_fe_llvm PRIVATE api common shady) -set_property(TARGET shady_fe_llvm PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/frontends/llvm/generator_l2s.c b/src/frontends/llvm/generator_l2s.c deleted file mode 100644 index 7fc35b42f..000000000 --- a/src/frontends/llvm/generator_l2s.c +++ /dev/null @@ -1,31 +0,0 @@ -#include "generator.h" - -void generate_llvm_shady_address_space_conversion(Growy* g, json_object* address_spaces) { - growy_append_formatted(g, "AddressSpace convert_llvm_address_space(unsigned as) {\n"); - growy_append_formatted(g, "\tstatic bool warned = false;\n"); - growy_append_formatted(g, "\tswitch (as) {\n"); - for (size_t i = 0; i < json_object_array_length(address_spaces); i++) { - json_object* as = json_object_array_get_idx(address_spaces, i); - String name = json_object_get_string(json_object_object_get(as, "name")); - json_object* 
llvm_id = json_object_object_get(as, "llvm-id"); - if (!llvm_id || json_object_get_type(llvm_id) != json_type_int) - continue; - growy_append_formatted(g, "\t\t case %d: return As%s;\n", json_object_get_int(llvm_id), name); - } - growy_append_formatted(g, "\t\tdefault:\n"); - growy_append_formatted(g, "\t\t\tif (!warned)\n"); - growy_append_string(g, "\t\t\t\twarn_print(\"Warning: unrecognised address space %d\", as);\n"); - growy_append_formatted(g, "\t\t\twarned = true;\n"); - growy_append_formatted(g, "\t\t\treturn AsGeneric;\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "}\n"); -} - -void generate(Growy* g, Data data) { - generate_header(g, data); - growy_append_formatted(g, "#include \"l2s_private.h\"\n"); - growy_append_formatted(g, "#include \"log.h\"\n"); - growy_append_formatted(g, "#include \n"); - - generate_llvm_shady_address_space_conversion(g, json_object_object_get(data.shd, "address-spaces")); -} diff --git a/src/frontends/llvm/l2s.c b/src/frontends/llvm/l2s.c deleted file mode 100644 index 0597654da..000000000 --- a/src/frontends/llvm/l2s.c +++ /dev/null @@ -1,214 +0,0 @@ -#include "l2s_private.h" - -#include "log.h" -#include "dict.h" -#include "util.h" - -#include "llvm-c/IRReader.h" -#include "portability.h" - -#include -#include -#include - -typedef struct OpaqueRef* OpaqueRef; - -static KeyHash hash_opaque_ptr(OpaqueRef* pvalue) { - if (!pvalue) - return 0; - size_t ptr = *(size_t*) pvalue; - return hash_murmur(&ptr, sizeof(size_t)); -} - -static bool cmp_opaque_ptr(OpaqueRef* a, OpaqueRef* b) { - if (a == b) - return true; - if (!a ^ !b) - return false; - return *a == *b; -} - -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); - -static const Node* write_bb_tail(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVMBasicBlockRef bb, LLVMValueRef first_instr) { - LLVMValueRef instr; - for (instr = first_instr; instr; instr = LLVMGetNextInstruction(instr)) { - bool last = instr == LLVMGetLastInstruction(bb); - if 
(last) - assert(LLVMGetBasicBlockTerminator(bb) == instr); - // LLVMDumpValue(instr); - // printf("\n"); - EmittedInstr emitted = convert_instruction(p, fn_or_bb, b, instr); - if (emitted.terminator) - return finish_body(b, emitted.terminator); - if (!emitted.instruction) - continue; - String names[] = { LLVMGetValueName(instr) }; - Nodes results = bind_instruction_explicit_result_types(b, emitted.instruction, emitted.result_types, names, false); - if (emitted.result_types.count == 1) { - const Node* result = first(results); - insert_dict(LLVMValueRef, const Node*, p->map, instr, result); - } - } - assert(false); -} - -const Node* convert_basic_block(Parser* p, Node* fn, LLVMBasicBlockRef bb) { - const Node** found = find_value_dict(LLVMValueRef, const Node*, p->map, bb); - if (found) return *found; - IrArena* a = get_module_arena(p->dst); - - Nodes params = empty(a); - LLVMValueRef instr = LLVMGetFirstInstruction(bb); - while (instr) { - switch (LLVMGetInstructionOpcode(instr)) { - case LLVMPHI: { - assert(false); - break; - } - default: goto after_phis; - } - instr = LLVMGetNextInstruction(instr); - } - after_phis: - { - String name = LLVMGetBasicBlockName(bb); - if (!name || strlen(name) == 0) - name = unique_name(a, "bb"); - Node* nbb = basic_block(a, fn, params, name); - insert_dict(LLVMValueRef, const Node*, p->map, bb, nbb); - BodyBuilder* b = begin_body(a); - nbb->payload.basic_block.body = write_bb_tail(p, nbb, b, bb, instr); - return nbb; - } -} - -const Node* convert_function(Parser* p, LLVMValueRef fn) { - if (is_llvm_intrinsic(fn)) { - warn_print("Skipping unknown LLVM intrinsic function: %s\n", LLVMGetValueName(fn)); - return NULL; - } - if (is_shady_intrinsic(fn)) { - warn_print("Skipping shady intrinsic function: %s\n", LLVMGetValueName(fn)); - return NULL; - } - - const Node** found = find_value_dict(LLVMValueRef, const Node*, p->map, fn); - if (found) return *found; - IrArena* a = get_module_arena(p->dst); - debug_print("Converting function: 
%s\n", LLVMGetValueName(fn)); - - Nodes params = empty(a); - for (LLVMValueRef oparam = LLVMGetFirstParam(fn); oparam && oparam <= LLVMGetLastParam(fn); oparam = LLVMGetNextParam(oparam)) { - LLVMTypeRef ot = LLVMTypeOf(oparam); - const Type* t = convert_type(p, ot); - const Node* param = var(a, t, LLVMGetValueName(oparam)); - insert_dict(LLVMValueRef, const Node*, p->map, oparam, param); - params = append_nodes(a, params, param); - } - const Type* fn_type = convert_type(p, LLVMGlobalGetValueType(fn)); - assert(fn_type->tag == FnType_TAG); - assert(fn_type->payload.fn_type.param_types.count == params.count); - Node* f = function(p->dst, params, LLVMGetValueName(fn), empty(a), fn_type->payload.fn_type.return_types); - const Node* r = fn_addr_helper(a, f); - insert_dict(LLVMValueRef, const Node*, p->map, fn, r); - - if (LLVMCountBasicBlocks(fn) > 0) { - LLVMBasicBlockRef first_bb = LLVMGetEntryBasicBlock(fn); - BodyBuilder* b = begin_body(a); - insert_dict(LLVMValueRef, const Node*, p->map, first_bb, f); - f->payload.fun.body = write_bb_tail(p, f, b, first_bb, LLVMGetFirstInstruction(first_bb)); - } - - return r; -} - -const Node* convert_global(Parser* p, LLVMValueRef global) { - const Node** found = find_value_dict(LLVMValueRef, const Node*, p->map, global); - if (found) return *found; - IrArena* a = get_module_arena(p->dst); - - String name = LLVMGetValueName(global); - String intrinsic = is_llvm_intrinsic(global); - if (intrinsic) { - if (strcmp(intrinsic, "llvm.global.annotations") == 0) { - process_llvm_annotations(p, global); - return NULL; - } - warn_print("Skipping unknown LLVM intrinsic function: %s\n", name); - return NULL; - } - debug_print("Converting global: %s\n", name); - - Node* decl = NULL; - - if (LLVMIsAGlobalVariable(global)) { - LLVMValueRef value = LLVMGetInitializer(global); - const Type* type = convert_type(p, LLVMGlobalGetValueType(global)); - // nb: even if we have untyped pointers, they still carry useful address space info - const Type* 
ptr_t = convert_type(p, LLVMTypeOf(global)); - assert(ptr_t->tag == PtrType_TAG); - AddressSpace as = ptr_t->payload.ptr_type.address_space; - decl = global_var(p->dst, empty(a), type, name, as); - if (value && as != AsUniformConstant) - decl->payload.global_variable.init = convert_value(p, value); - } else { - const Type* type = convert_type(p, LLVMTypeOf(global)); - decl = constant(p->dst, empty(a), type, name); - decl->payload.constant.instruction = convert_value(p, global); - } - - assert(decl && is_declaration(decl)); - const Node* r = ref_decl_helper(a, decl); - - insert_dict(LLVMValueRef, const Node*, p->map, global, r); - return r; -} - -bool parse_llvm_into_shady(Module* dst, size_t len, const char* data) { - LLVMContextRef context = LLVMContextCreate(); - LLVMModuleRef src; - LLVMMemoryBufferRef mem = LLVMCreateMemoryBufferWithMemoryRange(data, len, "my_great_buffer", false); - char* parsing_diagnostic = ""; - if (LLVMParseIRInContext(context, mem, &src, &parsing_diagnostic)) { - error_print("Failed to parse LLVM IR\n"); - error_print(parsing_diagnostic); - error_die(); - } - info_print("LLVM IR parsed successfully\n"); - - Module* dirty = new_module(get_module_arena(dst), "dirty"); - Parser p = { - .ctx = context, - .map = new_dict(LLVMValueRef, const Node*, (HashFn) hash_opaque_ptr, (CmpFn) cmp_opaque_ptr), - .annotations = new_dict(LLVMValueRef, ParsedAnnotation, (HashFn) hash_opaque_ptr, (CmpFn) cmp_opaque_ptr), - .scopes = new_dict(const Node*, Nodes, (HashFn) hash_node, (CmpFn) compare_node), - .annotations_arena = new_arena(), - .src = src, - .dst = dirty, - }; - - for (LLVMValueRef fn = LLVMGetFirstFunction(src); fn && fn <= LLVMGetNextFunction(fn); fn = LLVMGetLastFunction(src)) { - convert_function(&p, fn); - } - - LLVMValueRef global = LLVMGetFirstGlobal(src); - while (global) { - convert_global(&p, global); - if (global == LLVMGetLastGlobal(src)) - break; - global = LLVMGetNextGlobal(global); - } - - postprocess(&p, dirty, dst); - - 
destroy_dict(p.map); - destroy_dict(p.annotations); - destroy_dict(p.scopes); - destroy_arena(p.annotations_arena); - - LLVMContextDispose(context); - - return true; -} diff --git a/src/frontends/llvm/l2s.h b/src/frontends/llvm/l2s.h deleted file mode 100644 index 8ec275b98..000000000 --- a/src/frontends/llvm/l2s.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef SHADY_FE_LLVM_H -#define SHADY_FE_LLVM_H - -#include "shady/ir.h" -#include - -bool parse_llvm_into_shady(Module* dst, size_t len, const char* data); - -#endif diff --git a/src/frontends/llvm/l2s_instr.c b/src/frontends/llvm/l2s_instr.c deleted file mode 100644 index b806776e6..000000000 --- a/src/frontends/llvm/l2s_instr.c +++ /dev/null @@ -1,578 +0,0 @@ -#include "l2s_private.h" - -#include "portability.h" -#include "log.h" -#include "dict.h" - -#include "../shady/type.h" - -#include "llvm-c/DebugInfo.h" - -static Nodes convert_operands(Parser* p, size_t num_ops, LLVMValueRef v) { - IrArena* a = get_module_arena(p->dst); - LARRAY(const Node*, ops, num_ops); - for (size_t i = 0; i < num_ops; i++) { - LLVMValueRef op = LLVMGetOperand(v, i); - if (LLVMIsAFunction(op) && (is_llvm_intrinsic(op) || is_shady_intrinsic(op))) - ops[i] = NULL; - else - ops[i] = convert_value(p, op); - } - Nodes operands = nodes(a, num_ops, ops); - return operands; -} - -static const Type* change_int_t_sign(const Type* t, bool as_signed) { - assert(t); - assert(t->tag == Int_TAG); - return int_type(t->arena, (Int) { - .width = t->payload.int_type.width, - .is_signed = as_signed - }); -} - -static Nodes reinterpret_operands(BodyBuilder* b, Nodes ops, const Type* dst_t) { - assert(ops.count > 0); - IrArena* a = dst_t->arena; - LARRAY(const Node*, nops, ops.count); - for (size_t i = 0; i < ops.count; i++) - nops[i] = first(bind_instruction_explicit_result_types(b, prim_op_helper(a, reinterpret_op, singleton(dst_t), singleton(ops.nodes[i])), singleton(dst_t), NULL, false)); - return nodes(a, ops.count, nops); -} - -LLVMValueRef 
remove_ptr_bitcasts(Parser* p, LLVMValueRef v) { - while (true) { - if (LLVMIsAInstruction(v) || LLVMIsAConstantExpr(v)) { - if (LLVMGetInstructionOpcode(v) == LLVMBitCast) { - LLVMTypeRef t = LLVMTypeOf(v); - if (LLVMGetTypeKind(t) == LLVMPointerTypeKind) - v = LLVMGetOperand(v, 0); - } - } - break; - } - return v; -} - -/// instr may be an instruction or a constantexpr -EmittedInstr convert_instruction(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVMValueRef instr) { - Node* fn = fn_or_bb; - if (fn) { - if (fn_or_bb->tag == BasicBlock_TAG) - fn = (Node*) fn_or_bb->payload.basic_block.fn; - assert(fn->tag == Function_TAG); - } - - IrArena* a = get_module_arena(p->dst); - int num_ops = LLVMGetNumOperands(instr); - size_t num_results = 1; - Nodes result_types = empty(a); - const Node* r = NULL; - - LLVMOpcode opcode; - if (LLVMIsAInstruction(instr)) - opcode = LLVMGetInstructionOpcode(instr); - else if (LLVMIsAConstantExpr(instr)) - opcode = LLVMGetConstOpcode(instr); - else - assert(false); - - const Type* t = convert_type(p, LLVMTypeOf(instr)); - -#define BIND_PREV_R(t) bind_instruction_explicit_result_types(b, r, singleton(t), NULL, false) - - //if (LLVMIsATerminatorInst(instr)) { - if (LLVMIsAInstruction(instr)) { - assert(fn && fn_or_bb); - LLVMMetadataRef dbgloc = LLVMInstructionGetDebugLoc(instr); - if (dbgloc) { - Nodes* found = find_value_dict(const Node*, Nodes, p->scopes, fn_or_bb); - if (!found) { - Nodes str = scope_to_string(p, dbgloc); - insert_dict(const Node*, Nodes, p->scopes, fn_or_bb, str); - debug_print("Found a debug location for "); - log_node(DEBUG, fn_or_bb); - for (size_t i = 0; i < str.count; i++) { - log_node(DEBUG, str.nodes[i]); - debug_print(" -> "); - } - debug_print(" (depth= %zu)\n", str.count); - } - } - } - - switch (opcode) { - case LLVMRet: return (EmittedInstr) { - .terminator = fn_ret(a, (Return) { - .fn = NULL, - .args = num_ops == 0 ? 
empty(a) : convert_operands(p, num_ops, instr) - }) - }; - case LLVMBr: { - unsigned n_successors = LLVMGetNumSuccessors(instr); - LARRAY(const Node*, targets, n_successors); - for (size_t i = 0; i < n_successors; i++) - targets[i] = convert_basic_block(p, fn, LLVMGetSuccessor(instr, i)); - if (LLVMIsConditional(instr)) { - assert(n_successors == 2); - const Node* condition = convert_value(p, LLVMGetCondition(instr)); - return (EmittedInstr) { - .terminator = branch(a, (Branch) { - .branch_condition = condition, - .true_jump = jump_helper(a, targets[0], empty(a)), - .false_jump = jump_helper(a, targets[1], empty(a)), - }) - }; - } else { - assert(n_successors == 1); - return (EmittedInstr) { - .terminator = jump_helper(a, targets[0], empty(a)) - }; - } - } - case LLVMSwitch: - goto unimplemented; - case LLVMIndirectBr: - goto unimplemented; - case LLVMInvoke: - goto unimplemented; - case LLVMUnreachable: return (EmittedInstr) { - .terminator = unreachable(a) - }; - case LLVMCallBr: - goto unimplemented; - case LLVMFNeg: - r = prim_op_helper(a, neg_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMFAdd: - case LLVMAdd: - r = prim_op_helper(a, add_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMSub: - case LLVMFSub: - r = prim_op_helper(a, sub_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMMul: - case LLVMFMul: - r = prim_op_helper(a, mul_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMUDiv: - case LLVMFDiv: - r = prim_op_helper(a, div_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMSDiv: { - const Type* int_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); - const Type* signed_t = change_int_t_sign(int_t, true); - r = prim_op_helper(a, div_op, empty(a), reinterpret_operands(b, convert_operands(p, num_ops, instr), signed_t)); - r = prim_op_helper(a, reinterpret_op, singleton(int_t), BIND_PREV_R(signed_t)); - break; - } case LLVMURem: - case 
LLVMFRem: - r = prim_op_helper(a, mod_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMSRem: { - const Type* int_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); - const Type* signed_t = change_int_t_sign(int_t, true); - r = prim_op_helper(a, mod_op, empty(a), reinterpret_operands(b, convert_operands(p, num_ops, instr), signed_t)); - r = prim_op_helper(a, reinterpret_op, singleton(int_t), BIND_PREV_R(signed_t)); - break; - } case LLVMShl: - r = prim_op_helper(a, lshift_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMLShr: - r = prim_op_helper(a, rshift_logical_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMAShr: - r = prim_op_helper(a, rshift_arithm_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMAnd: - r = prim_op_helper(a, and_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMOr: - r = prim_op_helper(a, or_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMXor: - r = prim_op_helper(a, xor_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMAlloca: { - assert(t->tag == PtrType_TAG); - const Type* allocated_t = convert_type(p, LLVMGetAllocatedType(instr)); - const Type* allocated_ptr_t = ptr_type(a, (PtrType) { .pointed_type = allocated_t, .address_space = AsPrivatePhysical }); - r = first(bind_instruction_explicit_result_types(b, prim_op_helper(a, alloca_op, singleton(allocated_t), empty(a)), singleton(allocated_ptr_t), NULL, false)); - if (UNTYPED_POINTERS) { - const Type* untyped_ptr_t = ptr_type(a, (PtrType) { .pointed_type = unit_type(a), .address_space = AsPrivatePhysical }); - r = first(bind_instruction_outputs_count(b, prim_op_helper(a, reinterpret_op, singleton(untyped_ptr_t), singleton(r)), 1, NULL, false)); - } - r = prim_op_helper(a, convert_op, singleton(t), singleton(r)); - break; - } - case LLVMLoad: { - Nodes ops = convert_operands(p, num_ops, instr); - assert(ops.count == 1); - 
const Node* ptr = first(ops); - r = prim_op_helper(a, load_op, singleton(t), singleton(ptr)); - break; - } - case LLVMStore: { - num_results = 0; - Nodes ops = convert_operands(p, num_ops, instr); - assert(ops.count == 2); - r = prim_op_helper(a, store_op, UNTYPED_POINTERS ? singleton(convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0)))) : empty(a), mk_nodes(a, ops.nodes[1], ops.nodes[0])); - break; - } - case LLVMGetElementPtr: { - Nodes ops = convert_operands(p, num_ops, instr); - r = prim_op_helper(a, lea_op, UNTYPED_POINTERS ? singleton(convert_type(p, LLVMGetGEPSourceElementType(instr))) : empty(a), ops); - break; - } - case LLVMTrunc: - case LLVMZExt: { - const Type* src_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); - Nodes ops = convert_operands(p, num_ops, instr); - if (src_t->tag == Bool_TAG) { - assert(t->tag == Int_TAG); - const Node* zero = int_literal(a, (IntLiteral) { .value = 0, .width = t->payload.int_type.width, .is_signed = t->payload.int_type.is_signed }); - const Node* one = int_literal(a, (IntLiteral) { .value = 1, .width = t->payload.int_type.width, .is_signed = t->payload.int_type.is_signed }); - r = prim_op_helper(a, select_op, empty(a), mk_nodes(a, first(ops), one, zero)); - } else { - // reinterpret as unsigned, convert to change size, reinterpret back to target T - const Type* unsigned_src_t = change_int_t_sign(src_t, false); - const Type* unsigned_dst_t = change_int_t_sign(t, false); - r = prim_op_helper(a, convert_op, singleton(unsigned_dst_t), reinterpret_operands(b, ops, unsigned_src_t)); - r = prim_op_helper(a, reinterpret_op, singleton(t), BIND_PREV_R(unsigned_dst_t)); - } - break; - } case LLVMSExt: { - // reinterpret as signed, convert to change size, reinterpret back to target T - const Type* src_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); - const Type* signed_src_t = change_int_t_sign(src_t, true); - const Type* signed_dst_t = change_int_t_sign(t, true); - r = prim_op_helper(a, convert_op, 
singleton(signed_dst_t), reinterpret_operands(b, convert_operands(p, num_ops, instr), signed_src_t)); - r = prim_op_helper(a, reinterpret_op, singleton(t), BIND_PREV_R(signed_dst_t)); - break; - } case LLVMFPToUI: - case LLVMFPToSI: - case LLVMUIToFP: - case LLVMSIToFP: - r = prim_op_helper(a, convert_op, singleton(t), convert_operands(p, num_ops, instr)); - break; - case LLVMFPTrunc: - goto unimplemented; - case LLVMFPExt: - goto unimplemented; - case LLVMPtrToInt: - case LLVMIntToPtr: - case LLVMBitCast: - case LLVMAddrSpaceCast: { - // when constructing or deconstructing generic pointers, we need to emit a convert_op instead - assert(num_ops == 1); - const Node* src = first(convert_operands(p, num_ops, instr)); - Op op = reinterpret_op; - const Type* src_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); - if (src_t->tag == PtrType_TAG && t->tag == PtrType_TAG) { - if ((t->payload.ptr_type.address_space == AsGeneric)) { - switch (src_t->payload.ptr_type.address_space) { - case AsPrivatePhysical: - case AsSubgroupPhysical: - case AsSharedPhysical: - case AsGlobalPhysical: - op = convert_op; - break; - default: { - warn_print("Cannot cast address space %s to Generic! Ignoring.\n", get_address_space_name(src_t->payload.ptr_type.address_space)); - r = quote_helper(a, singleton(src)); - goto shortcut; - } - } - } else if (!is_physical_as(t->payload.ptr_type.address_space)) { - warn_print("Cannot cast address space %s since it's non-physical. 
Ignoring.\n", get_address_space_name(src_t->payload.ptr_type.address_space)); - r = quote_helper(a, singleton(src)); - goto shortcut; - } - } else { - assert(opcode != LLVMAddrSpaceCast); - } - r = prim_op_helper(a, op, singleton(t), singleton(src)); - break; - } - case LLVMICmp: { - Op op; - bool cast_to_signed = false; - switch(LLVMGetICmpPredicate(instr)) { - case LLVMIntEQ: - op = eq_op; - break; - case LLVMIntNE: - op = neq_op; - break; - case LLVMIntUGT: - op = gt_op; - break; - case LLVMIntUGE: - op = gte_op; - break; - case LLVMIntULT: - op = lt_op; - break; - case LLVMIntULE: - op = lte_op; - break; - case LLVMIntSGT: - op = gt_op; - cast_to_signed = true; - break; - case LLVMIntSGE: - op = gte_op; - cast_to_signed = true; - break; - case LLVMIntSLT: - op = lt_op; - cast_to_signed = true; - break; - case LLVMIntSLE: - op = lte_op; - cast_to_signed = true; - break; - } - Nodes ops = convert_operands(p, num_ops, instr); - if (cast_to_signed) { - const Type* unsigned_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))); - assert(unsigned_t->tag == Int_TAG); - const Type* signed_t = change_int_t_sign(unsigned_t, true); - ops = reinterpret_operands(b, ops, signed_t); - } - r = prim_op_helper(a, op, empty(a), ops); - break; - } - case LLVMFCmp: { - Op op; - bool cast_to_signed = false; - switch(LLVMGetFCmpPredicate(instr)) { - case LLVMRealOEQ: - op = eq_op; - break; - case LLVMRealONE: - op = neq_op; - break; - case LLVMRealOGT: - op = gt_op; - break; - case LLVMRealOGE: - op = gte_op; - break; - case LLVMRealOLT: - op = lt_op; - break; - case LLVMRealOLE: - op = lte_op; - break; - default: goto unimplemented; - } - Nodes ops = convert_operands(p, num_ops, instr); - r = prim_op_helper(a, op, empty(a), ops); - break; - } - case LLVMPHI: - assert(false && "We deal with phi nodes before, there shouldn't be one here"); - break; - case LLVMCall: { - unsigned num_args = LLVMGetNumArgOperands(instr); - LLVMValueRef callee = LLVMGetCalledValue(instr); - callee = 
remove_ptr_bitcasts(p, callee); - assert(num_args + 1 == num_ops); - String intrinsic = is_llvm_intrinsic(callee); - if (!intrinsic) - intrinsic = is_shady_intrinsic(callee); - if (intrinsic) { - assert(LLVMIsAFunction(callee)); - if (strcmp(intrinsic, "llvm.dbg.declare") == 0) { - const Node* target = convert_value(p, LLVMGetOperand(instr, 0)); - assert(target->tag == Variable_TAG); - const Node* meta = convert_value(p, LLVMGetOperand(instr, 1)); - assert(meta->tag == RefDecl_TAG); - meta = meta->payload.ref_decl.decl; - assert(meta->tag == GlobalVariable_TAG); - meta = meta->payload.global_variable.init; - assert(meta && meta->tag == Composite_TAG); - const Node* name_node = meta->payload.composite.contents.nodes[2]; - String name = get_string_literal(target->arena, name_node); - assert(name); - set_variable_name((Node*) target, name); - return (EmittedInstr) { 0 }; - } - if (strcmp(intrinsic, "llvm.dbg.label") == 0) { - // TODO - return (EmittedInstr) { 0 }; - } - if (string_starts_with(intrinsic, "llvm.memcpy")) { - Nodes ops = convert_operands(p, num_ops, instr); - num_results = 0; - r = prim_op_helper(a, memcpy_op, empty(a), nodes(a, 3, ops.nodes)); - break; - } else if (string_starts_with(intrinsic, "llvm.memset")) { - Nodes ops = convert_operands(p, num_ops, instr); - num_results = 0; - r = prim_op_helper(a, memset_op, empty(a), nodes(a, 3, ops.nodes)); - break; - } else if (string_starts_with(intrinsic, "llvm.fmuladd")) { - Nodes ops = convert_operands(p, num_ops, instr); - num_results = 1; - r = prim_op_helper(a, mul_op, empty(a), nodes(a, 2, ops.nodes)); - r = prim_op_helper(a, add_op, empty(a), mk_nodes(a, first(BIND_PREV_R(convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))))), ops.nodes[2])); - break; - } else if (string_starts_with(intrinsic, "llvm.fabs")) { - Nodes ops = convert_operands(p, num_ops, instr); - num_results = 1; - r = prim_op_helper(a, abs_op, empty(a), nodes(a, 1, ops.nodes)); - break; - } else if (string_starts_with(intrinsic, 
"llvm.floor")) { - Nodes ops = convert_operands(p, num_ops, instr); - num_results = 1; - r = prim_op_helper(a, floor_op, empty(a), nodes(a, 1, ops.nodes)); - break; - } - - typedef struct { - bool is_byval; - } DecodedParamAttr; - - size_t params_count = 0; - for (LLVMValueRef oparam = LLVMGetFirstParam(callee); oparam && oparam <= LLVMGetLastParam(callee); oparam = LLVMGetNextParam(oparam)) { - params_count++; - } - LARRAY(DecodedParamAttr, decoded, params_count); - memset(decoded, 0, sizeof(DecodedParamAttr) * params_count); - size_t param_index = 0; - for (LLVMValueRef oparam = LLVMGetFirstParam(callee); oparam && oparam <= LLVMGetLastParam(callee); oparam = LLVMGetNextParam(oparam)) { - size_t num_attrs = LLVMGetAttributeCountAtIndex(callee, param_index + 1); - LARRAY(LLVMAttributeRef, attrs, num_attrs); - LLVMGetAttributesAtIndex(callee, param_index + 1, attrs); - bool is_byval = false; - for (size_t i = 0; i < num_attrs; i++) { - LLVMAttributeRef attr = attrs[i]; - size_t k = LLVMGetEnumAttributeKind(attr); - size_t e = LLVMGetEnumAttributeKindForName("byval", 5); - // printf("p = %zu, i = %zu, k = %zu, e = %zu\n", param_index, i, k, e); - if (k == e) - decoded[param_index].is_byval = true; - } - param_index++; - } - - String ostr = intrinsic; - char* str = calloc(strlen(ostr) + 1, 1); - memcpy(str, ostr, strlen(ostr) + 1); - - if (strcmp(strtok(str, "::"), "shady") == 0) { - char* keyword = strtok(NULL, "::"); - if (strcmp(keyword, "prim_op") == 0) { - char* opname = strtok(NULL, "::");Op op; - size_t i; - for (i = 0; i < PRIMOPS_COUNT; i++) { - if (strcmp(get_primop_name(i), opname) == 0) { - op = (Op) i; - break; - } - } - assert(i != PRIMOPS_COUNT); - Nodes ops = convert_operands(p, num_args, instr); - LARRAY(const Node*, processed_ops, ops.count); - for (i = 0; i < num_args; i++) { - if (decoded[i].is_byval) - processed_ops[i] = first(bind_instruction_outputs_count(b, prim_op_helper(a, load_op, empty(a), singleton(ops.nodes[i])), 1, NULL, false)); - else 
- processed_ops[i] = ops.nodes[i]; - } - r = prim_op_helper(a, op, empty(a), nodes(a, num_args, processed_ops)); - free(str); - break; - } else { - error_print("Unrecognised shady intrinsic '%s'\n", keyword); - error_die(); - } - } - - error_print("Unhandled intrinsic '%s'\n", intrinsic); - error_die(); - } - if (r) - break; - - Nodes ops = convert_operands(p, num_ops, instr); - r = call(a, (Call) { - .callee = ops.nodes[num_args], - .args = nodes(a, num_args, ops.nodes), - }); - if (t == unit_type(a)) - num_results = 0; - break; - } - case LLVMSelect: - r = prim_op_helper(a, select_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMUserOp1: - goto unimplemented; - case LLVMUserOp2: - goto unimplemented; - case LLVMVAArg: - goto unimplemented; - case LLVMExtractElement: - r = prim_op_helper(a, extract_dynamic_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMInsertElement: - r = prim_op_helper(a, insert_op, empty(a), convert_operands(p, num_ops, instr)); - break; - case LLVMShuffleVector: { - Nodes ops = convert_operands(p, num_ops, instr); - unsigned num_indices = LLVMGetNumMaskElements(instr); - LARRAY(const Node*, cindices, num_indices); - for (size_t i = 0; i < num_indices; i++) - cindices[i] = uint32_literal(a, LLVMGetMaskValue(instr, i)); - ops = concat_nodes(a, ops, nodes(a, num_indices, cindices)); - r = prim_op_helper(a, shuffle_op, empty(a), ops); - break; - } - case LLVMExtractValue: - goto unimplemented; - case LLVMInsertValue: - goto unimplemented; - case LLVMFreeze: - goto unimplemented; - case LLVMFence: - goto unimplemented; - case LLVMAtomicCmpXchg: - goto unimplemented; - case LLVMAtomicRMW: - goto unimplemented; - case LLVMResume: - goto unimplemented; - case LLVMLandingPad: - goto unimplemented; - case LLVMCleanupRet: - goto unimplemented; - case LLVMCatchRet: - goto unimplemented; - case LLVMCatchPad: - goto unimplemented; - case LLVMCleanupPad: - goto unimplemented; - case LLVMCatchSwitch: - goto 
unimplemented; - } - shortcut: - if (r) { - if (num_results == 1) - result_types = singleton(convert_type(p, LLVMTypeOf(instr))); - assert(result_types.count == num_results); - return (EmittedInstr) { - .instruction = r, - .result_types = result_types, - }; - } - - unimplemented: - error_print("Shady: unimplemented LLVM instruction "); - LLVMDumpValue(instr); - error_print(" (opcode=%d)\n", opcode); - error_die(); -} \ No newline at end of file diff --git a/src/frontends/llvm/l2s_postprocess.c b/src/frontends/llvm/l2s_postprocess.c deleted file mode 100644 index 241448b8e..000000000 --- a/src/frontends/llvm/l2s_postprocess.c +++ /dev/null @@ -1,246 +0,0 @@ -#include "l2s_private.h" - -#include "portability.h" -#include "dict.h" -#include "log.h" - -#include "../shady/rewrite.h" -#include "../shady/type.h" -#include "../shady/ir_private.h" -#include "../shady/analysis/scope.h" - -typedef struct { - Rewriter rewriter; - Parser* p; - Scope* curr_scope; - const Node* old_fn_or_bb; - struct Dict* controls; -} Context; - -typedef struct { - Nodes tokens, destinations; -} Controls; - -static void initialize_controls(Context* ctx, Controls* controls, const Node* fn_or_bb) { - IrArena* a = ctx->rewriter.dst_arena; - *controls = (Controls) { - .destinations = empty(a), - .tokens = empty(a), - }; - insert_dict(const Node*, Controls*, ctx->controls, fn_or_bb, controls); -} - -static const Node* wrap_in_controls(Context* ctx, Controls* controls, const Node* body) { - IrArena* a = ctx->rewriter.dst_arena; - if (!body) - return NULL; - for (size_t i = 0; i < controls->destinations.count; i++) { - const Node* token = controls->tokens.nodes[i]; - const Node* dst = controls->destinations.nodes[i]; - Nodes o_dst_params = get_abstraction_params(dst); - LARRAY(const Node*, new_control_params, o_dst_params.count); - for (size_t j = 0; j < o_dst_params.count; j++) - new_control_params[j] = var(a, o_dst_params.nodes[j]->payload.var.type, unique_name(a, "v")); - Nodes nparams = nodes(a, 
o_dst_params.count, new_control_params); - body = let(a, control(a, (Control) { - .yield_types = get_variables_types(a, o_dst_params), - .inside = case_(a, singleton(token), body) - }), case_(a, nparams, jump_helper(a, rewrite_node(&ctx->rewriter, dst), nparams))); - } - return body; -} - -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); - -bool lexical_scope_is_nested(Nodes scope, Nodes parentMaybe) { - if (scope.count <= parentMaybe.count) - return false; - for (size_t i = 0; i < parentMaybe.count; i++) { - if (scope.nodes[i] != parentMaybe.nodes[i]) - return false; - } - return true; -} - -bool compare_nodes(Nodes* a, Nodes* b); - -static const Node* process_op(Context* ctx, NodeClass op_class, String op_name, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; - switch (node->tag) { - case Variable_TAG: return var(a, node->payload.var.type ? qualified_type_helper(rewrite_node(&ctx->rewriter, node->payload.var.type), false) : NULL, node->payload.var.name); - case Function_TAG: { - Context fn_ctx = *ctx; - fn_ctx.curr_scope = new_scope(node); - fn_ctx.old_fn_or_bb = node; - Controls controls; - initialize_controls(ctx, &controls, node); - Node* decl = (Node*) recreate_node_identity(&fn_ctx.rewriter, node); - Nodes annotations = decl->payload.fun.annotations; - ParsedAnnotation* an = find_annotation(ctx->p, node); - while (an) { - if (strcmp(get_annotation_name(an->payload), "PrimOpIntrinsic") == 0) { - assert(!decl->payload.fun.body); - Op op; - size_t i; - for (i = 0; i < PRIMOPS_COUNT; i++) { - if (strcmp(get_primop_name(i), get_annotation_string_payload(an->payload)) == 0) { - op = (Op) i; - break; - } - } - assert(i != PRIMOPS_COUNT); - decl->payload.fun.body = fn_ret(a, (Return) { - .args = singleton(prim_op_helper(a, op, empty(a), get_abstraction_params(decl))) - }); - } - annotations = append_nodes(a, annotations, an->payload); - an = an->next; - } - decl->payload.fun.annotations = annotations; - // decl->payload.fun.body = 
wrap_in_controls(ctx, &controls, decl->payload.fun.body); - destroy_scope(fn_ctx.curr_scope); - return decl; - } - case BasicBlock_TAG: { - Context bb_ctx = *ctx; - bb_ctx.old_fn_or_bb = node; - Controls controls; - initialize_controls(ctx, &controls, node); - Node* new_bb = (Node*) recreate_node_identity(&bb_ctx.rewriter, node); - // new_bb->payload.basic_block.body = wrap_in_controls(ctx, &controls, new_bb->payload.basic_block.body); - return new_bb; - } - case Jump_TAG: { - const Node* src = ctx->old_fn_or_bb; - const Node* dst = node->payload.jump.target; - assert(src && dst); - rewrite_node(&ctx->rewriter, dst); - - Nodes* src_lexical_scope = find_value_dict(const Node*, Nodes, ctx->p->scopes, src); - Nodes* dst_lexical_scope = find_value_dict(const Node*, Nodes, ctx->p->scopes, dst); - if (!src_lexical_scope) { - warn_print("Failed to find jump source node "); - log_node(WARN, src); - warn_print(" in lexical_scopes map. Is debug information enabled ?\n"); - } else if (!dst_lexical_scope) { - warn_print("Failed to find jump target node "); - log_node(WARN, dst); - warn_print(" in lexical_scopes map. 
Is debug information enabled ?\n"); - } else if (lexical_scope_is_nested(*src_lexical_scope, *dst_lexical_scope)) { - debug_print("Jump from %s to %s exits one or more nested lexical scopes, it might reconverge.\n", get_abstraction_name(src), get_abstraction_name(dst)); - - CFNode* src_cfnode = scope_lookup(ctx->curr_scope, src); - assert(src_cfnode->node); - CFNode* target_cfnode = scope_lookup(ctx->curr_scope, dst); - assert(src_cfnode && target_cfnode); - CFNode* dom = src_cfnode->idom; - while (dom) { - if (dom->node->tag == BasicBlock_TAG || dom->node->tag == Function_TAG) { - debug_print("Considering %s as a location for control\n", get_abstraction_name(dom->node)); - Nodes* dom_lexical_scope = find_value_dict(const Node*, Nodes, ctx->p->scopes, dom->node); - if (!dom_lexical_scope) { - warn_print("Basic block %s did not have an entry in the lexical_scopes map. Is debug information enabled ?\n", get_abstraction_name(dom->node)); - } else if (lexical_scope_is_nested(*dst_lexical_scope, *dom_lexical_scope)) { - error_print("We went up too far: %s is a parent of the jump destination scope.\n", get_abstraction_name(dom->node)); - } else if (compare_nodes(dom_lexical_scope, dst_lexical_scope)) { - debug_print("We need to introduce a control() block at %s, pointing at %s\n.", get_abstraction_name(dom->node), get_abstraction_name(dst)); - Controls** found = find_value_dict(const Node, Controls*, ctx->controls, dom->node); - assert(found); - if (found) { - Controls* controls = *found; - const Node* join_token = NULL; - for (size_t i = 0; i < controls->destinations.count; i++) { - if (controls->destinations.nodes[i] == dst) { - join_token = controls->tokens.nodes[i]; - break; - } - } - if (!join_token) { - const Type* jp_type = join_point_type(a, (JoinPointType) { - .yield_types = get_variables_types(a, get_abstraction_params(dst)) - }); - join_token = var(a, jp_type, get_abstraction_name(dst)); - controls->tokens = append_nodes(a, controls->tokens, join_token); - 
controls->destinations = append_nodes(a, controls->destinations, dst); - } - Nodes nargs = recreate_variables(&ctx->rewriter, get_abstraction_params(dst)); - - Node* fn = src; - if (fn->tag == BasicBlock_TAG) - fn = (Node*) fn->payload.basic_block.fn; - assert(fn->tag == Function_TAG); - Node* wrapper = basic_block(a, fn, nargs, format_string_arena(a->arena, "wrapper_to_%s", get_abstraction_name(dst))); - wrapper->payload.basic_block.body = join(a, (Join) { - .args = nargs, - .join_point = join_token - }); - return jump_helper(a, wrapper, rewrite_nodes(&ctx->rewriter, node->payload.jump.args)); - } else { - assert(false); - } - } else { - dom = dom->idom; - continue; - } - break; - } - dom = dom->idom; - } - } - break; - } - case GlobalVariable_TAG: { - AddressSpace as = node->payload.global_variable.address_space; - const Node* old_init = node->payload.global_variable.init; - Nodes annotations = rewrite_nodes(&ctx->rewriter, node->payload.global_variable.annotations); - const Type* type = rewrite_node(&ctx->rewriter, node->payload.global_variable.type); - ParsedAnnotation* an = find_annotation(ctx->p, node); - while (an) { - annotations = append_nodes(a, annotations, an->payload); - if (strcmp(get_annotation_name(an->payload), "Builtin") == 0) - old_init = NULL; - if (strcmp(get_annotation_name(an->payload), "UniformConstant") == 0) - as = AsUniformConstant; - an = an->next; - } - Node* decl = global_var(ctx->rewriter.dst_module, annotations, type, get_decl_name(node), as); - register_processed(&ctx->rewriter, node, decl); - if (old_init) - decl->payload.global_variable.init = rewrite_node(&ctx->rewriter, old_init); - return decl; - } - default: break; - } - - if (op_class == NcTerminator && node->tag != Let_TAG) { - Controls** found = find_value_dict(const Node, Controls*, ctx->controls, ctx->old_fn_or_bb); - assert(found); - Controls* controls = *found; - return wrap_in_controls(ctx, controls, recreate_node_identity(&ctx->rewriter, node)); - } - - return 
recreate_node_identity(&ctx->rewriter, node); -} - -static const Node* process_node(Context* ctx, const Node* old) { - return process_op(ctx, 0, NULL, old); -} - -void postprocess(Parser* p, Module* src, Module* dst) { - assert(src != dst); - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node), - .p = p, - .controls = new_dict(const Node*, Controls*, (HashFn) hash_node, (CmpFn) compare_node), - }; - - ctx.rewriter.rewrite_op_fn = (RewriteOpFn) process_op; - ctx.rewriter.config.process_variables = true; - // ctx.rewriter.config.search_map = false; - // ctx.rewriter.config.write_map = false; - - rewrite_module(&ctx.rewriter); - destroy_dict(ctx.controls); - destroy_rewriter(&ctx.rewriter); -} diff --git a/src/frontends/llvm/l2s_private.h b/src/frontends/llvm/l2s_private.h deleted file mode 100644 index b2e7b2283..000000000 --- a/src/frontends/llvm/l2s_private.h +++ /dev/null @@ -1,77 +0,0 @@ -#ifndef SHADY_L2S_PRIVATE_H -#define SHADY_L2S_PRIVATE_H - -#include "l2s.h" -#include "arena.h" -#include "util.h" - -#include "llvm-c/Core.h" - -#include -#include - -typedef struct { - LLVMContextRef ctx; - struct Dict* map; - struct Dict* annotations; - struct Dict* scopes; - Arena* annotations_arena; - LLVMModuleRef src; - Module* dst; -} Parser; - -#ifndef LLVM_VERSION_MAJOR -#error "Missing LLVM_VERSION_MAJOR" -#else -#define UNTYPED_POINTERS (LLVM_VERSION_MAJOR >= 15) -#endif - -typedef struct ParsedAnnotationContents_ { - const Node* payload; - struct ParsedAnnotationContents_* next; -} ParsedAnnotation; - -ParsedAnnotation* find_annotation(Parser*, const Node*); -ParsedAnnotation* next_annotation(ParsedAnnotation*); -void add_annotation(Parser*, const Node*, ParsedAnnotation); - -void process_llvm_annotations(Parser* p, LLVMValueRef global); - -AddressSpace convert_llvm_address_space(unsigned); -const Node* convert_value(Parser* p, LLVMValueRef v); -const Node* convert_function(Parser* p, LLVMValueRef fn); -const Type* 
convert_type(Parser* p, LLVMTypeRef t); -const Node* convert_metadata(Parser* p, LLVMMetadataRef meta); -const Node* convert_global(Parser* p, LLVMValueRef global); -const Node* convert_function(Parser* p, LLVMValueRef fn); -const Node* convert_basic_block(Parser* p, Node* fn, LLVMBasicBlockRef bb); - -typedef struct { - const Node* terminator; - const Node* instruction; - Nodes result_types; -} EmittedInstr; - -EmittedInstr convert_instruction(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVMValueRef instr); - -Nodes scope_to_string(Parser* p, LLVMMetadataRef dbgloc); - -void postprocess(Parser*, Module* src, Module* dst); - -static String is_llvm_intrinsic(LLVMValueRef fn) { - assert(LLVMIsAFunction(fn) || LLVMIsConstant(fn)); - String name = LLVMGetValueName(fn); - if (string_starts_with(name, "llvm.")) - return name; - return NULL; -} - -static String is_shady_intrinsic(LLVMValueRef fn) { - assert(LLVMIsAFunction(fn) || LLVMIsConstant(fn)); - String name = LLVMGetValueName(fn); - if (string_starts_with(name, "shady::")) - return name; - return NULL; -} - -#endif diff --git a/src/frontends/slim/CMakeLists.txt b/src/frontends/slim/CMakeLists.txt deleted file mode 100644 index a9865e843..000000000 --- a/src/frontends/slim/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_library(slim_parser parser.c token.c) -target_link_libraries(slim_parser common api) -target_include_directories(slim_parser INTERFACE "$") -set_property(TARGET slim_parser PROPERTY POSITION_INDEPENDENT_CODE ON) -target_link_libraries(shady PUBLIC "$") diff --git a/src/frontends/slim/parser.c b/src/frontends/slim/parser.c deleted file mode 100644 index 87be4c718..000000000 --- a/src/frontends/slim/parser.c +++ /dev/null @@ -1,1068 +0,0 @@ -#include "token.h" -#include "parser.h" - -#include "list.h" -#include "portability.h" -#include "log.h" -#include "util.h" - -#include "type.h" -#include "ir_private.h" - -#include -#include -#include -#include - -static int max_precedence() { - return 10; -} - 
-static int get_precedence(InfixOperators op) { - switch (op) { -#define INFIX_OPERATOR(name, token, primop_op, precedence) case Infix##name: return precedence; -INFIX_OPERATORS() -#undef INFIX_OPERATOR - default: error("unknown operator"); - } -} -static bool is_primop_op(InfixOperators op, Op* out) { - switch (op) { -#define INFIX_OPERATOR(name, token, primop_op, precedence) case Infix##name: if (primop_op != -1) { *out = primop_op; return true; } else return false; -INFIX_OPERATORS() -#undef INFIX_OPERATOR - default: error("unknown operator"); - } -} - -static bool is_infix_operator(TokenTag token_tag, InfixOperators* out) { - switch (token_tag) { -#define INFIX_OPERATOR(name, token, primop_op, precedence) case token: { *out = Infix##name; return true; } -INFIX_OPERATORS() -#undef INFIX_OPERATOR - default: return false; - } -} - -// to avoid some repetition -#define ctxparams SHADY_UNUSED ParserConfig config, SHADY_UNUSED const char* contents, SHADY_UNUSED Module* mod, SHADY_UNUSED IrArena* arena, SHADY_UNUSED Tokenizer* tokenizer -#define ctx config, contents, mod, arena, tokenizer - -#define expect(condition) expect_impl(condition, #condition) -static void expect_impl(bool condition, const char* err) { - if (!condition) { - error_print("expected to parse: %s\n", err); - exit(-4); - } -} - -static bool accept_token(ctxparams, TokenTag tag) { - if (curr_token(tokenizer).tag == tag) { - next_token(tokenizer); - return true; - } - return false; -} - -static const char* accept_identifier(ctxparams) { - Token tok = curr_token(tokenizer); - if (tok.tag == identifier_tok) { - next_token(tokenizer); - size_t size = tok.end - tok.start; - return string_sized(arena, (int) size, &contents[tok.start]); - } - return NULL; -} - -static const Node* expect_body(ctxparams, Node* fn, const Node* default_terminator); -static const Node* accept_value(ctxparams); -static const Type* accept_unqualified_type(ctxparams); -static const Node* accept_expr(ctxparams, int); -static Nodes 
expect_operands(ctxparams); - -static const Type* accept_numerical_type(ctxparams) { - if (accept_token(ctx, i8_tok)) { - return int8_type(arena); - } else if (accept_token(ctx, i16_tok)) { - return int16_type(arena); - } else if (accept_token(ctx, i32_tok)) { - return int32_type(arena); - } else if (accept_token(ctx, i64_tok)) { - return int64_type(arena); - } else if (accept_token(ctx, u8_tok)) { - return uint8_type(arena); - } else if (accept_token(ctx, u16_tok)) { - return uint16_type(arena); - } else if (accept_token(ctx, u32_tok)) { - return uint32_type(arena); - } else if (accept_token(ctx, u64_tok)) { - return uint64_type(arena); - } else if (accept_token(ctx, f16_tok)) { - return fp16_type(arena); - } else if (accept_token(ctx, f32_tok)) { - return fp32_type(arena); - } else if (accept_token(ctx, f64_tok)) { - return fp64_type(arena); - } - return NULL; -} - -static const Node* accept_numerical_literal(ctxparams) { - const Type* num_type = accept_numerical_type(ctx); - - bool negate = accept_token(ctx, minus_tok); - - Token tok = curr_token(tokenizer); - size_t size = tok.end - tok.start; - String str = string_sized(arena, (int) size, &contents[tok.start]); - - switch (tok.tag) { - case hex_lit_tok: - if (negate) - error("hexadecimal literals can't start with '-'"); - case dec_lit_tok: { - next_token(tokenizer); - break; - } - default: { - if (negate || num_type) - error("expected numerical literal"); - return NULL; - } - } - - if (negate) // add back the - in front - str = format_string_arena(arena->arena, "-%s", str); - - const Node* n = untyped_number(arena, (UntypedNumber) { - .plaintext = str - }); - - if (num_type) - n = constrained(arena, (ConstrainedValue) { - .type = num_type, - .value = n - }); - - return n; -} - -static const Node* accept_value(ctxparams) { - Token tok = curr_token(tokenizer); - size_t size = tok.end - tok.start; - - const Node* number = accept_numerical_literal(ctx); - if (number) - return number; - - switch (tok.tag) { - case 
identifier_tok: { - const char* id = string_sized(arena, (int) size, &contents[tok.start]); - next_token(tokenizer); - return unbound(arena, (Unbound) { .name = id }); - } - case hex_lit_tok: - case dec_lit_tok: { - next_token(tokenizer); - return untyped_number(arena, (UntypedNumber) { - .plaintext = string_sized(arena, (int) size, &contents[tok.start]) - }); - } - case string_lit_tok: { - next_token(tokenizer); - char* unescaped = calloc(size + 1, 1); - size_t j = apply_escape_codes(&contents[tok.start], size, unescaped); - const Node* lit = string_lit(arena, (StringLiteral) {.string = string_sized(arena, (int) j, unescaped) }); - free(unescaped); - return lit; - } - case true_tok: next_token(tokenizer); return true_lit(arena); - case false_tok: next_token(tokenizer); return false_lit(arena); - case lpar_tok: { - next_token(tokenizer); - if (accept_token(ctx, rpar_tok)) { - return quote_helper(arena, empty(arena)); - } - const Node* atom = config.front_end ? accept_expr(ctx, max_precedence()) : accept_value(ctx); - expect(atom); - if (curr_token(tokenizer).tag == rpar_tok) { - next_token(tokenizer); - } else { - struct List* elements = new_list(const Node*); - append_list(const Node*, elements, atom); - - while (!accept_token(ctx, rpar_tok)) { - expect(accept_token(ctx, comma_tok)); - const Node* element = config.front_end ? 
accept_expr(ctx, max_precedence()) : accept_value(ctx); - expect(elements); - append_list(const Node*, elements, element); - } - - Nodes tcontents = nodes(arena, entries_count_list(elements), read_list(const Node*, elements)); - destroy_list(elements); - atom = tuple_helper(arena, tcontents); - } - return atom; - } - case composite_tok: { - next_token(tokenizer); - const Type* elem_type = accept_unqualified_type(ctx); - expect(elem_type); - Nodes elems = expect_operands(ctx); - return composite_helper(arena, elem_type, elems); - } - default: return NULL; - } -} - -static AddressSpace expect_ptr_address_space(ctxparams) { - switch (curr_token(tokenizer).tag) { - case global_tok: next_token(tokenizer); return AsGlobalPhysical; - case private_tok: next_token(tokenizer); return AsPrivatePhysical; - case shared_tok: next_token(tokenizer); return AsSharedPhysical; - case subgroup_tok: next_token(tokenizer); return AsSubgroupPhysical; - case generic_tok: next_token(tokenizer); return AsGeneric; - default: error("expected address space qualifier"); - } - SHADY_UNREACHABLE; -} - -static const Type* accept_unqualified_type(ctxparams) { - const Type* prim_type = accept_numerical_type(ctx); - if (prim_type) return prim_type; - else if (accept_token(ctx, bool_tok)) { - return bool_type(arena); - } else if (accept_token(ctx, mask_t_tok)) { - return mask_type(arena); - } else if (accept_token(ctx, ptr_tok)) { - AddressSpace as = expect_ptr_address_space(ctx); - const Type* elem_type = accept_unqualified_type(ctx); - expect(elem_type); - return ptr_type(arena, (PtrType) { - .address_space = as, - .pointed_type = elem_type, - }); - } else if (config.front_end && accept_token(ctx, lsbracket_tok)) { - const Type* elem_type = accept_unqualified_type(ctx); - expect(elem_type); - const Node* size = NULL; - if(accept_token(ctx, semi_tok)) { - size = accept_value(ctx); - expect(size); - } - expect(accept_token(ctx, rsbracket_tok)); - return arr_type(arena, (ArrType) { - .element_type = 
elem_type, - .size = size - }); - } else if (accept_token(ctx, pack_tok)) { - expect(accept_token(ctx, lsbracket_tok)); - const Type* elem_type = accept_unqualified_type(ctx); - expect(elem_type); - const Node* size = NULL; - expect(accept_token(ctx, semi_tok)); - size = accept_numerical_literal(ctx); - expect(size && size->tag == UntypedNumber_TAG); - expect(accept_token(ctx, rsbracket_tok)); - return pack_type(arena, (PackType) { - .element_type = elem_type, - .width = strtoll(size->payload.untyped_number.plaintext, NULL, 10) - }); - } else if (accept_token(ctx, struct_tok)) { - expect(accept_token(ctx, lbracket_tok)); - struct List* names = new_list(String); - struct List* types = new_list(const Type*); - while (true) { - if (accept_token(ctx, rbracket_tok)) - break; - const Type* elem = accept_unqualified_type(ctx); - expect(elem); - String id = accept_identifier(ctx); - expect(id); - append_list(String, names, id); - append_list(const Type*, types, elem); - expect(accept_token(ctx, semi_tok)); - } - Nodes elem_types = nodes(arena, entries_count_list(types), read_list(const Type*, types)); - Strings names2 = strings(arena, entries_count_list(names), read_list(String, names)); - destroy_list(names); - destroy_list(types); - return record_type(arena, (RecordType) { - .names = names2, - .members = elem_types, - .special = NotSpecial, - }); - } else { - String id = accept_identifier(ctx); - if (id) - return unbound(arena, (Unbound) { .name = id }); - - return NULL; - } -} - -static DivergenceQualifier accept_uniformity_qualifier(ctxparams) { - DivergenceQualifier divergence = Unknown; - if (accept_token(ctx, uniform_tok)) - divergence = Uniform; - else if (accept_token(ctx, varying_tok)) - divergence = Varying; - return divergence; -} - -static const Type* accept_maybe_qualified_type(ctxparams) { - DivergenceQualifier qualifier = accept_uniformity_qualifier(ctx); - const Type* unqualified = accept_unqualified_type(ctx); - if (qualifier != Unknown) - 
expect(unqualified && "we read a uniformity qualifier and expected a type to follow"); - if (qualifier == Unknown) - return unqualified; - else - return qualified_type(arena, (QualifiedType) { .is_uniform = qualifier == Uniform, .type = unqualified }); -} - -static const Type* accept_qualified_type(ctxparams) { - DivergenceQualifier qualifier = accept_uniformity_qualifier(ctx); - if (qualifier == Unknown) - return NULL; - const Type* unqualified = accept_unqualified_type(ctx); - expect(unqualified); - return qualified_type(arena, (QualifiedType) { .is_uniform = qualifier == Uniform, .type = unqualified }); -} - -static const Node* accept_operand(ctxparams) { - return config.front_end ? accept_expr(ctx, max_precedence()) : accept_value(ctx); -} - -static void expect_parameters(ctxparams, Nodes* parameters, Nodes* default_values) { - expect(accept_token(ctx, lpar_tok)); - struct List* params = new_list(Node*); - struct List* default_vals = default_values ? new_list(Node*) : NULL; - - while (true) { - if (accept_token(ctx, rpar_tok)) - break; - - next: { - const Type* qtype = accept_qualified_type(ctx); - expect(qtype); - const char* id = accept_identifier(ctx); - expect(id); - - const Node* node = var(arena, qtype, id); - append_list(Node*, params, node); - - if (default_values) { - expect(accept_token(ctx, equal_tok)); - const Node* default_val = accept_operand(ctx); - append_list(const Node*, default_vals, default_val); - } - - if (accept_token(ctx, comma_tok)) - goto next; - } - } - - size_t count = entries_count_list(params); - *parameters = nodes(arena, count, read_list(const Node*, params)); - destroy_list(params); - if (default_values) { - *default_values = nodes(arena, count, read_list(const Node*, default_vals)); - destroy_list(default_vals); - } -} - -typedef enum { MustQualified, MaybeQualified, NeverQualified } Qualified; - -static Nodes accept_types(ctxparams, TokenTag separator, Qualified qualified) { - struct List* tmp = new_list(Type*); - while (true) 
{ - const Type* type; - switch (qualified) { - case MustQualified: type = accept_qualified_type(ctx); break; - case MaybeQualified: type = accept_maybe_qualified_type(ctx); break; - case NeverQualified: type = accept_unqualified_type(ctx); break; - } - if (!type) - break; - - append_list(Type*, tmp, type); - - if (separator != 0) - accept_token(ctx, separator); - } - - Nodes types2 = nodes(arena, tmp->elements_count, (const Type**) tmp->alloc); - destroy_list(tmp); - return types2; -} - -static const Node* accept_primary_expr(ctxparams) { - if (accept_token(ctx, minus_tok)) { - const Node* expr = accept_primary_expr(ctx); - expect(expr); - if (expr->tag == IntLiteral_TAG) { - return int_literal(arena, (IntLiteral) { - // We always treat that value like an signed integer, because it makes no sense to negate an unsigned number ! - .value = -get_int_literal_value(*resolve_to_int_literal(expr), true) - }); - } else { - return prim_op(arena, (PrimOp) { - .op = neg_op, - .operands = nodes(arena, 1, (const Node* []) {expr}) - }); - } - } else if (accept_token(ctx, unary_excl_tok)) { - const Node* expr = accept_primary_expr(ctx); - expect(expr); - return prim_op(arena, (PrimOp) { - .op = not_op, - .operands = singleton(expr), - }); - } else if (accept_token(ctx, star_tok)) { - const Node* expr = accept_primary_expr(ctx); - expect(expr); - return prim_op(arena, (PrimOp) { - .op = deref_op, - .operands = singleton(expr), - }); - } else if (accept_token(ctx, infix_and_tok)) { - const Node* expr = accept_primary_expr(ctx); - expect(expr); - return prim_op(arena, (PrimOp) { - .op = addrof_op, - .operands = singleton(expr), - }); - } - - const Node* expr = accept_value(ctx); - while (expr) { - Nodes ty_args = nodes(arena, 0, NULL); - bool parse_ty_args = false; - if (accept_token(ctx, lsbracket_tok)) { - parse_ty_args = true; - while (true) { - const Type* t = accept_unqualified_type(ctx); - expect(t); - ty_args = append_nodes(arena, ty_args, t); - if (accept_token(ctx, 
comma_tok)) - continue; - if (accept_token(ctx, rsbracket_tok)) - break; - } - } - switch (curr_token(tokenizer).tag) { - case lpar_tok: { - Op op = PRIMOPS_COUNT; - if (expr->tag == Unbound_TAG) { - String s = expr->payload.unbound.name; - for (size_t i = 0; i < PRIMOPS_COUNT; i++) { - if (strcmp(s, get_primop_name(i)) == 0) { - op = i; - break; - } - } - } - if (op != PRIMOPS_COUNT) { - return prim_op(arena, (PrimOp) { - .op = op, - .type_arguments = ty_args, - .operands = expect_operands(ctx) - }); - } - - assert(ty_args.count == 0 && "Function calls do not support type arguments"); - Nodes args = expect_operands(ctx); - expr = call(arena, (Call) { - .callee = expr, - .args = args - }); - continue; - } - default: - if (parse_ty_args) - expect(false && "expected function call arguments"); - break; - } - break; - } - return expr; -} - -static const Node* accept_expr(ctxparams, int outer_precedence) { - const Node* expr = accept_primary_expr(ctx); - while (expr) { - InfixOperators infix; - if (is_infix_operator(curr_token(tokenizer).tag, &infix)) { - int precedence = get_precedence(infix); - if (precedence > outer_precedence) break; - next_token(tokenizer); - - const Node* rhs = accept_expr(ctx, precedence - 1); - expect(rhs); - Op primop_op; - if (is_primop_op(infix, &primop_op)) { - expr = prim_op(arena, (PrimOp) { - .op = primop_op, - .operands = nodes(arena, 2, (const Node* []) {expr, rhs}) - }); - } else switch (infix) { - default: error("unknown infix operator") - } - continue; - } - break; - } - return expr; -} - -static Nodes expect_operands(ctxparams) { - if (!accept_token(ctx, lpar_tok)) - error("Expected left parenthesis") - - struct List* list = new_list(Node*); - - bool expect = false; - while (true) { - const Node* val = accept_operand(ctx); - if (!val) { - if (expect) - error("expected value but got none") - else if (accept_token(ctx, rpar_tok)) - break; - else - error("Expected value or closing parenthesis") - } - - append_list(Node*, list, val); - 
- if (accept_token(ctx, comma_tok)) - expect = true; - else if (accept_token(ctx, rpar_tok)) - break; - else - error("Expected comma or closing parenthesis") - } - - Nodes final = nodes(arena, list->elements_count, (const Node**) list->alloc); - destroy_list(list); - return final; -} - -static const Node* accept_control_flow_instruction(ctxparams, Node* fn) { - Token current_token = curr_token(tokenizer); - switch (current_token.tag) { - case if_tok: { - next_token(tokenizer); - Nodes yield_types = accept_types(ctx, 0, NeverQualified); - expect(accept_token(ctx, lpar_tok)); - const Node* condition = accept_operand(ctx); - expect(condition); - expect(accept_token(ctx, rpar_tok)); - const Node* merge = config.front_end ? yield(arena, (Yield) { .args = nodes(arena, 0, NULL) }) : NULL; - - const Node* if_true = case_(arena, nodes(arena, 0, NULL), expect_body(ctx, fn, merge)); - - // else defaults to an empty body - bool has_else = accept_token(ctx, else_tok); - const Node* if_false = NULL; - if (has_else) { - if_false = case_(arena, nodes(arena, 0, NULL), expect_body(ctx, fn, merge)); - } - return if_instr(arena, (If) { - .yield_types = yield_types, - .condition = condition, - .if_true = if_true, - .if_false = if_false - }); - } - case loop_tok: { - next_token(tokenizer); - Nodes yield_types = accept_types(ctx, 0, NeverQualified); - Nodes parameters; - Nodes default_values; - expect_parameters(ctx, ¶meters, &default_values); - // by default loops continue forever - const Node* default_loop_end_behaviour = config.front_end ? 
merge_continue(arena, (MergeContinue) { .args = nodes(arena, 0, NULL) }) : NULL; - const Node* body = case_(arena, parameters, expect_body(ctx, fn, default_loop_end_behaviour)); - - return loop_instr(arena, (Loop) { - .initial_args = default_values, - .yield_types = yield_types, - .body = body - }); - } - case control_tok: { - next_token(tokenizer); - Nodes yield_types = accept_types(ctx, 0, NeverQualified); - expect(accept_token(ctx, lpar_tok)); - String str = accept_identifier(ctx); - expect(str); - const Node* param = var(arena, join_point_type(arena, (JoinPointType) { - .yield_types = yield_types, - }), str); - expect(accept_token(ctx, rpar_tok)); - const Node* body = case_(arena, singleton(param), expect_body(ctx, fn, NULL)); - return control(arena, (Control) { - .inside = body, - .yield_types = yield_types - }); - } - default: break; - } - return NULL; -} - -static const Node* accept_instruction(ctxparams, Node* fn, bool in_list) { - const Node* instr = accept_expr(ctx, max_precedence()); - - if (in_list && instr) - expect(accept_token(ctx, semi_tok) && "Non-control flow instructions must be followed by a semicolon"); - - if (!instr) instr = accept_control_flow_instruction(ctx, fn); - return instr; -} - -static void expect_identifiers(ctxparams, Strings* out_strings) { - struct List* list = new_list(const char*); - while (true) { - const char* id = accept_identifier(ctx); - expect(id); - - append_list(const char*, list, id); - - if (accept_token(ctx, comma_tok)) - continue; - else - break; - } - - *out_strings = strings(arena, list->elements_count, (const char**) list->alloc); - destroy_list(list); -} - -static void expect_types_and_identifiers(ctxparams, Strings* out_strings, Nodes* out_types) { - struct List* slist = new_list(const char*); - struct List* tlist = new_list(const char*); - - while (true) { - const Type* type = accept_unqualified_type(ctx); - expect(type); - const char* id = accept_identifier(ctx); - expect(id); - - append_list(const char*, 
tlist, type); - append_list(const char*, slist, id); - - if (accept_token(ctx, comma_tok)) - continue; - else - break; - } - - *out_strings = strings(arena, slist->elements_count, (const char**) slist->alloc); - *out_types = nodes(arena, tlist->elements_count, (const Node**) tlist->alloc); - destroy_list(slist); - destroy_list(tlist); -} - -static bool accept_non_terminator_instr(ctxparams, BodyBuilder* bb, Node* fn) { - Strings ids; - if (accept_token(ctx, val_tok)) { - expect_identifiers(ctx, &ids); - expect(accept_token(ctx, equal_tok)); - const Node* instruction = accept_instruction(ctx, fn, true); - bind_instruction_outputs_count(bb, instruction, ids.count, ids.strings, false); - } else if (accept_token(ctx, var_tok)) { - Nodes types; - expect_types_and_identifiers(ctx, &ids, &types); - expect(accept_token(ctx, equal_tok)); - const Node* instruction = accept_instruction(ctx, fn, true); - bind_instruction_explicit_result_types(bb, instruction, types, ids.strings, true); - } else { - const Node* instr = accept_instruction(ctx, fn, true); - if (!instr) return false; - bind_instruction_outputs_count(bb, instr, 0, NULL, false); - } - return true; -} - -static const Node* accept_case(ctxparams, Node* fn) { - if (!accept_token(ctx, lambda_tok)) - return NULL; - - Nodes params; - expect_parameters(ctx, ¶ms, NULL); - const Node* body = expect_body(ctx, fn, NULL); - return case_(arena, params, body); -} - -static const Node* expect_jump(ctxparams) { - String target = accept_identifier(ctx); - expect(target); - Nodes args = curr_token(tokenizer).tag == lpar_tok ? 
expect_operands(ctx) : nodes(arena, 0, NULL); - return jump(arena, (Jump) { - .target = unbound(arena, (Unbound) { .name = target }), - .args = args - }); -} - -static const Node* accept_terminator(ctxparams, Node* fn) { - TokenTag tag = curr_token(tokenizer).tag; - switch (tag) { - case let_tok: { - next_token(tokenizer); - const Node* instruction = accept_instruction(ctx, fn, false); - expect(instruction); - expect(accept_token(ctx, in_tok)); - switch (tag) { - case let_tok: { - const Node* lam = accept_case(ctx, fn); - expect(lam); - return let(arena, instruction, lam); - } - default: SHADY_UNREACHABLE; - } - } - case jump_tok: { - next_token(tokenizer); - return expect_jump(ctx); - } - case branch_tok: { - next_token(tokenizer); - - expect(accept_token(ctx, lpar_tok)); - const Node* condition = accept_value(ctx); - expect(condition); - expect(accept_token(ctx, comma_tok)); - const Node* true_target = expect_jump(ctx); - expect(accept_token(ctx, comma_tok)); - const Node* false_target = expect_jump(ctx); - expect(accept_token(ctx, rpar_tok)); - - Nodes args = curr_token(tokenizer).tag == lpar_tok ? 
expect_operands(ctx) : nodes(arena, 0, NULL); - return branch(arena, (Branch) { - .branch_condition = condition, - .true_jump = true_target, - .false_jump = false_target, - }); - } - case switch_tok: { - next_token(tokenizer); - - expect(accept_token(ctx, lpar_tok)); - const Node* inspectee = accept_value(ctx); - expect(inspectee); - expect(accept_token(ctx, comma_tok)); - Nodes values = empty(arena); - Nodes cases = empty(arena); - const Node* default_jump; - while (true) { - if (accept_token(ctx, default_tok)) { - default_jump = expect_jump(ctx); - break; - } - expect(accept_token(ctx, case_tok)); - const Node* value = accept_value(ctx); - expect(value); - expect(accept_token(ctx, comma_tok) && 1); - const Node* j = expect_jump(ctx); - expect(accept_token(ctx, comma_tok) && true); - values = append_nodes(arena, values, value); - cases = append_nodes(arena, cases, j); - } - expect(accept_token(ctx, rpar_tok)); - - return br_switch(arena, (Switch) { - .switch_value = first(values), - .case_values = values, - .case_jumps = cases, - .default_jump = default_jump, - }); - } - case return_tok: { - next_token(tokenizer); - Nodes args = expect_operands(ctx); - return fn_ret(arena, (Return) { - .fn = NULL, - .args = args - }); - } - case yield_tok: { - next_token(tokenizer); - Nodes args = curr_token(tokenizer).tag == lpar_tok ? expect_operands(ctx) : nodes(arena, 0, NULL); - return yield(arena, (Yield) { - .args = args - }); - } - case continue_tok: { - next_token(tokenizer); - Nodes args = curr_token(tokenizer).tag == lpar_tok ? expect_operands(ctx) : nodes(arena, 0, NULL); - return merge_continue(arena, (MergeContinue) { - .args = args - }); - } - case break_tok: { - next_token(tokenizer); - Nodes args = curr_token(tokenizer).tag == lpar_tok ? 
expect_operands(ctx) : nodes(arena, 0, NULL); - return merge_break(arena, (MergeBreak) { - .args = args - }); - } - case join_tok: { - next_token(tokenizer); - expect(accept_token(ctx, lpar_tok)); - const Node* jp = accept_operand(ctx); - expect(accept_token(ctx, rpar_tok)); - Nodes args = expect_operands(ctx); - return join(arena, (Join) { - .join_point = jp, - .args = args - }); - } - case unreachable_tok: { - next_token(tokenizer); - expect(accept_token(ctx, lpar_tok)); - expect(accept_token(ctx, rpar_tok)); - return unreachable(arena); - } - default: break; - } - return NULL; -} - -static const Node* expect_body(ctxparams, Node* fn, const Node* default_terminator) { - assert(fn->tag == Function_TAG); - expect(accept_token(ctx, lbracket_tok)); - BodyBuilder* bb = begin_body(arena); - - while (true) { - if (!accept_non_terminator_instr(ctx, bb, fn)) - break; - } - - const Node* terminator = accept_terminator(ctx, fn); - - if (terminator) - expect(accept_token(ctx, semi_tok)); - - if (!terminator) { - if (default_terminator) - terminator = default_terminator; - else - error("expected terminator: return, jump, branch ..."); - } - - if (curr_token(tokenizer).tag == cont_tok) { - struct List* conts = new_list(Node*); - while (true) { - if (!accept_token(ctx, cont_tok)) - break; - assert(fn); - const char* name = accept_identifier(ctx); - - Nodes parameters; - expect_parameters(ctx, ¶meters, NULL); - Node* continuation = basic_block(arena, fn, parameters, name); - continuation->payload.basic_block.body = expect_body(ctx, fn, NULL); - append_list(Node*, conts, continuation); - } - - terminator = unbound_bbs(arena, (UnboundBBs) { .body = terminator, .children_blocks = nodes(arena, entries_count_list(conts), read_list(const Node*, conts)) }); - destroy_list(conts); - } - - expect(accept_token(ctx, rbracket_tok)); - - return finish_body(bb, terminator); -} - -static Nodes accept_annotations(ctxparams) { - struct List* list = new_list(const Node*); - - while (true) { - if 
(accept_token(ctx, at_tok)) { - const char* id = accept_identifier(ctx); - const Node* annot = NULL; - if (accept_token(ctx, lpar_tok)) { - const Node* first_value = accept_value(ctx); - if (!first_value) { - expect(accept_token(ctx, rpar_tok)); - goto no_params; - } - - // this is a map - if (first_value->tag == Unbound_TAG && accept_token(ctx, equal_tok)) { - error("TODO: parse map") - } else if (curr_token(tokenizer).tag == comma_tok) { - next_token(tokenizer); - struct List* values = new_list(const Node*); - append_list(const Node*, values, first_value); - while (true) { - const Node* next_value = accept_value(ctx); - expect(next_value); - append_list(const Node*, values, next_value); - if (accept_token(ctx, comma_tok)) - continue; - else break; - } - annot = annotation_values(arena, (AnnotationValues) { - .name = id, - .values = nodes(arena, entries_count_list(values), read_list(const Node*, values)) - }); - destroy_list(values); - } else { - annot = annotation_value(arena, (AnnotationValue) { - .name = id, - .value = first_value - }); - } - - expect(accept_token(ctx, rpar_tok)); - } else { - no_params: - annot = annotation(arena, (Annotation) { - .name = id, - }); - } - expect(annot); - append_list(const Node*, list, annot); - continue; - } - break; - } - - Nodes annotations = nodes(arena, entries_count_list(list), read_list(const Node*, list)); - destroy_list(list); - return annotations; -} - -static const Node* accept_const(ctxparams, Nodes annotations) { - if (!accept_token(ctx, const_tok)) - return NULL; - - const Type* type = accept_unqualified_type(ctx); - const char* id = accept_identifier(ctx); - expect(id); - expect(accept_token(ctx, equal_tok)); - const Node* definition = accept_instruction(ctx, NULL, false); - expect(definition); - - expect(accept_token(ctx, semi_tok)); - - Node* cnst = constant(mod, annotations, type, id); - cnst->payload.constant.instruction = definition; - return cnst; -} - -static const Node* accept_fn_decl(ctxparams, Nodes 
annotations) { - if (!accept_token(ctx, fn_tok)) - return NULL; - - const char* name = accept_identifier(ctx); - expect(name); - Nodes types = accept_types(ctx, comma_tok, MaybeQualified); - expect(curr_token(tokenizer).tag == lpar_tok); - Nodes parameters; - expect_parameters(ctx, ¶meters, NULL); - - Node* fn = function(mod, parameters, name, annotations, types); - if (!accept_token(ctx, semi_tok)) - fn->payload.fun.body = expect_body(ctx, fn, types.count == 0 ? fn_ret(arena, (Return) { .args = types }) : NULL); - - const Node* declaration = fn; - expect(declaration); - - return declaration; -} - -static const Node* accept_global_var_decl(ctxparams, Nodes annotations) { - AddressSpace as; - if (accept_token(ctx, private_tok)) - as = AsPrivateLogical; - else if (accept_token(ctx, shared_tok)) - as = AsSharedLogical; - else if (accept_token(ctx, subgroup_tok)) - as = AsSubgroupLogical; - else if (accept_token(ctx, global_tok)) - as = AsGlobalLogical; - else if (accept_token(ctx, extern_tok)) - as = AsExternal; - else if (accept_token(ctx, input_tok)) - as = AsInput; - else if (accept_token(ctx, output_tok)) - as = AsOutput; - else if (accept_token(ctx, uniform_tok)) { - if (accept_token(ctx, input_tok)) { - as = AsUInput; - } else { - expect(false && "expected 'input'"); - return NULL; - } - } else - return NULL; - - const Type* type = accept_unqualified_type(ctx); - expect(type); - const char* id = accept_identifier(ctx); - expect(id); - - const Node* initial_value = NULL; - if (accept_token(ctx, equal_tok)) { - initial_value = accept_value(ctx); - expect(initial_value); - } - - expect(accept_token(ctx, semi_tok)); - - Node* gv = global_var(mod, annotations, type, id, as); - gv->payload.global_variable.init = initial_value; - return gv; -} - -static const Node* accept_nominal_type_decl(ctxparams, Nodes annotations) { - if (!accept_token(ctx, type_tok)) - return NULL; - - const char* id = accept_identifier(ctx); - expect(id); - - expect(accept_token(ctx, 
equal_tok)); - - Node* nom = nominal_type(mod, annotations, id); - nom->payload.nom_type.body = accept_unqualified_type(ctx); - expect(nom->payload.nom_type.body); - - expect(accept_token(ctx, semi_tok)); - return nom; -} - -void parse_shady_ir(ParserConfig config, const char* contents, Module* mod) { - IrArena* arena = get_module_arena(mod); - Tokenizer* tokenizer = new_tokenizer(contents); - - while (true) { - Token token = curr_token(tokenizer); - if (token.tag == EOF_tok) - break; - - Nodes annotations = accept_annotations(ctx); - - const Node* decl = accept_const(ctx, annotations); - if (!decl) decl = accept_fn_decl(ctx, annotations); - if (!decl) decl = accept_global_var_decl(ctx, annotations); - if (!decl) decl = accept_nominal_type_decl(ctx, annotations); - - if (decl) { - debugv_print("decl parsed : "); - log_node(DEBUGV, decl); - debugv_print("\n"); - continue; - } - - error_print("No idea what to parse here... (tok=(tag = %s, pos = %zu))\n", token_tags[token.tag], token.start); - exit(-3); - } - - destroy_tokenizer(tokenizer); -} diff --git a/src/frontends/spirv/CMakeLists.txt b/src/frontends/spirv/CMakeLists.txt deleted file mode 100644 index ae1786be6..000000000 --- a/src/frontends/spirv/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_library(shady_s2s STATIC s2s.c) -target_link_libraries(shady_s2s PRIVATE api common shady) -target_link_libraries(shady_s2s PRIVATE "$") -set_property(TARGET shady_s2s PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/frontends/spirv/s2s.h b/src/frontends/spirv/s2s.h deleted file mode 100644 index 5f86e26f7..000000000 --- a/src/frontends/spirv/s2s.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef SHADY_S2S -#define SHADY_S2S - -#include "shady/ir.h" - -typedef enum { - S2S_Success, - S2S_FailedParsingGeneric, -} S2SError; - -S2SError parse_spirv_into_shady(Module* dst, size_t len, const char* data); - -#endif diff --git a/src/runtime/CMakeLists.txt b/src/runtime/CMakeLists.txt index c03ac06a9..4eb71e028 100644 --- 
a/src/runtime/CMakeLists.txt +++ b/src/runtime/CMakeLists.txt @@ -1,10 +1,15 @@ -add_library(runtime SHARED runtime.c runtime_program.c) -target_link_libraries(runtime PUBLIC shady) -target_link_libraries(runtime PUBLIC "$") -set_property(TARGET runtime PROPERTY POSITION_INDEPENDENT_CODE ON) -set_target_properties(runtime PROPERTIES OUTPUT_NAME "shady_runtime") +option(SHADY_ENABLE_RUNTIME "Offers helpful utilities for building applications with shady. Some samples and tests depend on it." ON) -add_subdirectory(vulkan) +if (SHADY_ENABLE_RUNTIME) + add_library(runtime runtime.c runtime_program.c runtime_cli.c) + target_link_libraries(runtime PUBLIC driver) + set_target_properties(runtime PROPERTIES OUTPUT_NAME "shady_runtime") -add_executable(runtime_test runtime_test.c) -target_link_libraries(runtime_test runtime) + add_subdirectory(vulkan) + add_subdirectory(cuda) + + add_executable(runtime_test runtime_test.c) + target_link_libraries(runtime_test runtime) + + install(TARGETS runtime EXPORT shady_export_set ARCHIVE DESTINATION ${CMAKE_ARCHIVE_OUTPUT_DIRECTORY}) +endif() diff --git a/src/runtime/cuda/CMakeLists.txt b/src/runtime/cuda/CMakeLists.txt new file mode 100644 index 000000000..20378e635 --- /dev/null +++ b/src/runtime/cuda/CMakeLists.txt @@ -0,0 +1,18 @@ +find_package(CUDAToolkit) + +if (CUDAToolkit_FOUND) + message("CUDA toolkit found.") + option(SHADY_ENABLE_RUNTIME_CUDA "CUDA support for the 'runtime' component" ON) +else() + message("CUDA toolkit not found, CUDA runtime component cannot be built.") +endif () + +if (SHADY_ENABLE_RUNTIME_CUDA) + add_library(cuda_runtime STATIC cuda_runtime.c cuda_runtime_buffer.c cuda_runtime_program.c) + target_link_libraries(cuda_runtime PRIVATE api) + target_link_libraries(cuda_runtime PRIVATE "$") + target_link_libraries(cuda_runtime PRIVATE CUDA::cudart CUDA::cuda_driver CUDA::nvrtc) + + target_link_libraries(runtime PRIVATE "$") + target_compile_definitions(runtime PUBLIC CUDA_BACKEND_PRESENT=1) +endif() \ No 
newline at end of file diff --git a/src/runtime/cuda/cuda_runtime.c b/src/runtime/cuda/cuda_runtime.c new file mode 100644 index 000000000..6d079528d --- /dev/null +++ b/src/runtime/cuda/cuda_runtime.c @@ -0,0 +1,128 @@ +#include "cuda_runtime_private.h" +#include "shady/config.h" + +#include "log.h" +#include "portability.h" +#include "list.h" +#include "dict.h" + +#include + +static void shutdown_cuda_runtime(CudaBackend* b) { + +} + +static const char* cuda_device_get_name(CudaDevice* device) { return device->name; } + +static void cuda_device_cleanup(CudaDevice* device) { + size_t i = 0; + CudaKernel* kernel; + while (dict_iter(device->specialized_programs, &i, NULL, &kernel)) { + shd_cuda_destroy_specialized_kernel(kernel); + } + destroy_dict(device->specialized_programs); +} + +bool cuda_command_wait(CudaCommand* command) { + CHECK_CUDA(cuCtxSynchronize(), return false); + if (command->profiled_gpu_time) { + cudaEventSynchronize(command->stop); + float ms; + cudaEventElapsedTime(&ms, command->start, command->stop); + *command->profiled_gpu_time = (uint64_t) ((double) ms * 1000000); + } + return true; +} + +static CudaCommand* shd_cuda_launch_kernel(CudaDevice* device, Program* p, String entry_point, int dimx, int dimy, int dimz, int args_count, void** args, ExtraKernelOptions* options) { + CudaKernel* kernel = shd_cuda_get_specialized_program(device, p, entry_point); + + CudaCommand* cmd = calloc(sizeof(CudaCommand), 1); + *cmd = (CudaCommand) { + .base = { + .wait_for_completion = (bool(*)(Command*)) cuda_command_wait + } + }; + + if (options && options->profiled_gpu_time) { + cmd->profiled_gpu_time = options->profiled_gpu_time; + cudaEventCreate(&cmd->start); + cudaEventCreate(&cmd->stop); + cudaEventRecord(cmd->start, 0); + } + + ArenaConfig final_config = *get_arena_config(get_module_arena(kernel->final_module)); + unsigned int gx = final_config.specializations.workgroup_size[0]; + unsigned int gy = final_config.specializations.workgroup_size[1]; + 
unsigned int gz = final_config.specializations.workgroup_size[2]; + CHECK_CUDA(cuLaunchKernel(kernel->entry_point_function, dimx, dimy, dimz, gx, gy, gz, 0, 0, args, NULL), return NULL); + cudaEventRecord(cmd->stop, 0); + return cmd; +} + +static KeyHash hash_spec_program_key(SpecProgramKey* ptr) { + return hash_murmur(ptr, sizeof(SpecProgramKey)); +} + +static bool cmp_spec_program_keys(SpecProgramKey* a, SpecProgramKey* b) { + return memcmp(a, b, sizeof(SpecProgramKey)) == 0; +} + +static CudaDevice* create_cuda_device(CudaBackend* b, int ordinal) { + CUdevice handle; + CHECK_CUDA(cuDeviceGet(&handle, ordinal), return NULL); + CudaDevice* device = calloc(sizeof(CudaDevice), 1); + *device = (CudaDevice) { + .base = { + .get_name = (const char*(*)(Device*)) cuda_device_get_name, + .cleanup = (void(*)(Device*)) cuda_device_cleanup, + .allocate_buffer = (Buffer* (*)(Device*, size_t)) shd_rt_cuda_allocate_buffer, + .can_import_host_memory = (bool (*)(Device*)) shd_rt_cuda_can_import_host_memory, + .import_host_memory_as_buffer = (Buffer* (*)(Device*, void*, size_t)) shd_rt_cuda_import_host_memory, + .launch_kernel = (Command*(*)(Device*, Program*, String, int, int, int, int, void**, ExtraKernelOptions*)) shd_cuda_launch_kernel, + }, + .handle = handle, + .specialized_programs = new_dict(SpecProgramKey, CudaKernel*, (HashFn) hash_spec_program_key, (CmpFn) cmp_spec_program_keys), + }; + CHECK_CUDA(cuDeviceGetName(device->name, 255, handle), goto dealloc_and_return_null); + CHECK_CUDA(cuDeviceGetAttribute(&device->cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device->handle), goto dealloc_and_return_null); + CHECK_CUDA(cuDeviceGetAttribute(&device->cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device->handle), goto dealloc_and_return_null); + CHECK_CUDA(cuCtxCreate(&device->context, 0, handle), goto dealloc_and_return_null); + return device; + + dealloc_and_return_null: + free(device); + return NULL; +} + +static bool probe_cuda_devices(CudaBackend* 
b) { + int count; + CHECK_CUDA(cuDeviceGetCount(&count), return false); + for (size_t i = 0; i < count; i++) { + CudaDevice* device = create_cuda_device(b, i); + if (!device) + continue; + b->num_devices++; + append_list(CudaDevice*, b->base.runtime->devices, device); + } + return true; +} + +Backend* shd_rt_initialize_cuda_backend(Runtime* base) { + CudaBackend* backend = malloc(sizeof(CudaBackend)); + memset(backend, 0, sizeof(CudaBackend)); + backend->base = (Backend) { + .runtime = base, + .cleanup = (void(*)()) shutdown_cuda_runtime, + }; + + CHECK_CUDA(cuInit(0), goto init_fail_free); + CHECK(probe_cuda_devices(backend), goto init_fail_free); + info_print("Shady CUDA backend successfully initialized, found %d devices\n", backend->num_devices); + return &backend->base; + + init_fail_free: + shd_error_print("Failed to initialise the CUDA back-end.\n"); + free(backend); + return NULL; +} \ No newline at end of file diff --git a/src/runtime/cuda/cuda_runtime_buffer.c b/src/runtime/cuda/cuda_runtime_buffer.c new file mode 100644 index 000000000..c1af084ca --- /dev/null +++ b/src/runtime/cuda/cuda_runtime_buffer.c @@ -0,0 +1,73 @@ +#include "cuda_runtime_private.h" + +#include "log.h" +#include "portability.h" + +static void cuda_destroy_buffer(CudaBuffer* buffer) { + if (buffer->is_allocated) + CHECK_CUDA(cuMemFree(buffer->device_ptr), {}); + if (buffer->is_imported) + CHECK_CUDA(cuMemHostUnregister(buffer->host_ptr), {}); + free(buffer); +} + +static uint64_t cuda_get_deviceptr(CudaBuffer* buffer) { + return (uint64_t) buffer->device_ptr; +} + +static void* cuda_get_host_ptr(CudaBuffer* buffer) { + return (void*) buffer->host_ptr; +} + +static bool cuda_copy_to_buffer_fallback(CudaBuffer* dst, size_t dst_offset, void* src, size_t size) { + CHECK_CUDA(cuMemcpyHtoD(dst->device_ptr + dst_offset, src, size), return false); + return true; +} + +static bool cuda_copy_from_buffer_fallback(CudaBuffer* src, size_t src_offset, void* dst, size_t size) { + 
CHECK_CUDA(cuMemcpyDtoH(dst, src->device_ptr + src_offset, size), return false); + return true; +} + +static CudaBuffer* new_buffer_common(size_t size) { + CudaBuffer* buffer = calloc(sizeof(CudaBuffer), 1); + *buffer = (CudaBuffer) { + .base = { + .backend_tag = CUDARuntimeBackend, + .get_host_ptr = (void*(*)(Buffer*)) cuda_get_host_ptr, + .get_device_ptr = (uint64_t(*)(Buffer*)) cuda_get_deviceptr, + .destroy = (void(*)(Buffer*)) cuda_destroy_buffer, + .copy_into = (bool(*)(Buffer*, size_t, void*, size_t)) cuda_copy_to_buffer_fallback, + .copy_from = (bool(*)(Buffer*, size_t, void*, size_t)) cuda_copy_from_buffer_fallback, + }, + .size = size, + }; + return buffer; +} + +CudaBuffer* shd_rt_cuda_allocate_buffer(CudaDevice* device, size_t size) { + CUdeviceptr device_ptr; + CHECK_CUDA(cuMemAlloc(&device_ptr, size), return NULL); + CudaBuffer* buffer = new_buffer_common(size); + buffer->is_allocated = true; + buffer->device_ptr = device_ptr; + // TODO: check the assumptions of unified virtual addressing + buffer->host_ptr = (void*) device_ptr; + return buffer; +} + +CudaBuffer* shd_rt_cuda_import_host_memory(CudaDevice* device, void* host_ptr, size_t size) { + CUdeviceptr device_ptr; + CHECK_CUDA(cuMemHostRegister(host_ptr, size, CU_MEMHOSTREGISTER_DEVICEMAP), return NULL); + CHECK_CUDA(cuMemHostGetDevicePointer(&device_ptr, host_ptr, 0), return NULL); + CudaBuffer* buffer = new_buffer_common(size); + buffer->is_imported = true; + buffer->device_ptr = device_ptr; + // TODO: check the assumptions of unified virtual addressing + buffer->host_ptr = (void*) host_ptr; + return buffer; +} + +bool shd_rt_cuda_can_import_host_memory(CudaDevice* d) { + return true; +} diff --git a/src/runtime/cuda/cuda_runtime_private.h b/src/runtime/cuda/cuda_runtime_private.h new file mode 100644 index 000000000..6e17f677d --- /dev/null +++ b/src/runtime/cuda/cuda_runtime_private.h @@ -0,0 +1,72 @@ +#ifndef SHADY_CUDA_RUNTIME_PRIVATE_H +#define SHADY_CUDA_RUNTIME_PRIVATE_H + +#include 
"../runtime_private.h" + +#include +#include +#include + +#define CHECK_NVRTC(x, failure_handler) { nvrtcResult the_result_ = x; if (the_result_ != NVRTC_SUCCESS) { const char* msg = nvrtcGetErrorString(the_result_); shd_error_print(#x " failed (%s)\n", msg); failure_handler; } } +#define CHECK_CUDA(x, failure_handler) { CUresult the_result_ = x; if (the_result_ != CUDA_SUCCESS) { const char* msg; cuGetErrorName(the_result_, &msg); shd_error_print(#x " failed (%s)\n", msg); failure_handler; } } + +typedef struct { + Program* base; + String entry_point; +} SpecProgramKey; + +typedef struct CudaBackend_ { + Backend base; + size_t num_devices; +} CudaBackend; + +typedef struct { + Device base; + CUdevice handle; + CUcontext context; + char name[256]; + int cc_major; + int cc_minor; + struct Dict* specialized_programs; +} CudaDevice; + +typedef struct { + Buffer base; + size_t size; + CUdeviceptr device_ptr; + void* host_ptr; + bool is_allocated; + bool is_imported; +} CudaBuffer; + +typedef struct { + Command base; + + uint64_t* profiled_gpu_time; + cudaEvent_t start, stop; +} CudaCommand; + +typedef struct { + SpecProgramKey key; + CudaDevice* device; + Module* final_module; + struct { + char* cuda_code; + size_t cuda_code_size; + }; + struct { + char* ptx; + size_t ptx_size; + }; + CUmodule cuda_module; + CUfunction entry_point_function; +} CudaKernel; + +CudaBuffer* shd_rt_cuda_allocate_buffer(CudaDevice*, size_t size); +CudaBuffer* shd_rt_cuda_import_host_memory(CudaDevice*, void* host_ptr, size_t size); +bool shd_rt_cuda_can_import_host_memory(CudaDevice*); + +CudaKernel* shd_rt_cuda_get_specialized_program(CudaDevice*, Program*, String ep); +bool shd_rt_cuda_destroy_specialized_kernel(CudaKernel*); + +#endif diff --git a/src/runtime/cuda/cuda_runtime_program.c b/src/runtime/cuda/cuda_runtime_program.c new file mode 100644 index 000000000..4c3a460e4 --- /dev/null +++ b/src/runtime/cuda/cuda_runtime_program.c @@ -0,0 +1,167 @@ +#include "cuda_runtime_private.h" + 
+#include "shady/driver.h" + +#include "log.h" +#include "portability.h" +#include "dict.h" +#include "util.h" + +static CompilerConfig get_compiler_config_for_device(CudaDevice* device, const CompilerConfig* base_config) { + CompilerConfig config = *base_config; + config.specialization.subgroup_size = 32; + + return config; +} + +static bool emit_cuda_c_code(CudaKernel* spec) { + CompilerConfig config = get_compiler_config_for_device(spec->device, spec->key.base->base_config); + config.specialization.entry_point = spec->key.entry_point; + + Module* dst_mod = spec->key.base->module; + CHECK(run_compiler_passes(&config, &dst_mod) == CompilationNoError, return false); + + CEmitterConfig emitter_config = { + .dialect = CDialect_CUDA, + .explicitly_sized_types = false, + .allow_compound_literals = false, + .decay_unsized_arrays = true, + }; + Module* final_mod; + emit_c(&config, emitter_config, dst_mod, &spec->cuda_code_size, &spec->cuda_code, &final_mod); + spec->final_module = final_mod; + + if (get_log_level() <= DEBUG) + write_file("cuda_dump.cu", spec->cuda_code_size - 1, spec->cuda_code); + + return true; +} + +static bool cuda_c_to_ptx(CudaKernel* kernel) { + String override_file = getenv("SHADY_OVERRIDE_PTX"); + if (override_file) { + read_file(override_file, &kernel->ptx_size, &kernel->ptx); + return true; + } + + nvrtcProgram program; + CHECK_NVRTC(nvrtcCreateProgram(&program, kernel->cuda_code, kernel->key.entry_point, 0, NULL, NULL), return false); + + assert(kernel->device->cc_major < 10 && kernel->device->cc_minor < 10); + + char arch_flag[] = "-arch=compute_00"; + arch_flag[14] = '0' + kernel->device->cc_major; + arch_flag[15] = '0' + kernel->device->cc_minor; + + const char* options[] = { + arch_flag, + "--use_fast_math" + }; + + nvrtcResult compile_result = nvrtcCompileProgram(program, sizeof(options)/sizeof(*options), options); + if (compile_result != NVRTC_SUCCESS) { + shd_error_print("NVRTC compilation failed: %s\n", 
nvrtcGetErrorString(compile_result)); + debug_print("Dumping source:\n%s", kernel->cuda_code); + } + + size_t log_size; + CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size), return false); + char* log_buffer = calloc(log_size, 1); + CHECK_NVRTC(nvrtcGetProgramLog(program, log_buffer), return false); + shd_log_fmt(compile_result == NVRTC_SUCCESS ? DEBUG : ERROR, "NVRTC compilation log: %s\n", log_buffer); + free(log_buffer); + + CHECK_NVRTC(nvrtcGetPTXSize(program, &kernel->ptx_size), return false); + kernel->ptx = calloc(kernel->ptx_size, 1); + CHECK_NVRTC(nvrtcGetPTX(program, kernel->ptx), return false); + CHECK_NVRTC(nvrtcDestroyProgram(&program), return false); + + if (get_log_level() <= DEBUG) + write_file("cuda_dump.ptx", kernel->ptx_size - 1, kernel->ptx); + + return true; +} + +static bool load_ptx_into_cuda_program(CudaKernel* kernel) { + char info_log[10240] = {}; + char error_log[10240] = {}; + + CUjit_option options[] = { + CU_JIT_INFO_LOG_BUFFER, CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + CU_JIT_TARGET + }; + + void* option_values[] = { + info_log, (void*)(uintptr_t)sizeof(info_log), + error_log, (void*)(uintptr_t)sizeof(error_log), + (void*)(uintptr_t)(kernel->device->cc_major * 10 + kernel->device->cc_minor) + }; + + CUlinkState linker; + CHECK_CUDA(cuLinkCreate(sizeof(options)/sizeof(options[0]), options, option_values, &linker), goto err_linker_create); + CHECK_CUDA(cuLinkAddData(linker, CU_JIT_INPUT_PTX, kernel->ptx, kernel->ptx_size, NULL, 0U, NULL, NULL), goto err_post_linker_create); + + void* binary; + size_t binary_size; + CHECK_CUDA(cuLinkComplete(linker, &binary, &binary_size), goto err_post_linker_create); + + if (*info_log) + info_print("CUDA JIT info: %s\n", info_log); + + if (get_log_level() <= DEBUG) + write_file("cuda_dump.cubin", binary_size, binary); + + CHECK_CUDA(cuModuleLoadData(&kernel->cuda_module, binary), goto err_post_linker_create); + 
CHECK_CUDA(cuModuleGetFunction(&kernel->entry_point_function, kernel->cuda_module, kernel->key.entry_point), goto err_post_module_load); + + cuLinkDestroy(linker); + return true; + +err_post_module_load: + cuModuleUnload(kernel->cuda_module); +err_post_linker_create: + cuLinkDestroy(linker); + if (*info_log) + info_print("CUDA JIT info: %s\n", info_log); + if (*error_log) + shd_error_print("CUDA JIT failed: %s\n", error_log); +err_linker_create: + return false; +} + +static CudaKernel* create_specialized_program(CudaDevice* device, SpecProgramKey key) { + CudaKernel* kernel = calloc(1, sizeof(CudaKernel)); + if (!kernel) + return NULL; + *kernel = (CudaKernel) { + .key = key, + .device = device, + }; + + CHECK(emit_cuda_c_code(kernel), return NULL); + CHECK(cuda_c_to_ptx(kernel), return NULL); + CHECK(load_ptx_into_cuda_program(kernel), return NULL); + + return kernel; +} + +CudaKernel* shd_rt_cuda_get_specialized_program(CudaDevice* device, Program* program, String entry_point) { + SpecProgramKey key = { .base = program, .entry_point = entry_point }; + CudaKernel** found = find_value_dict(SpecProgramKey, CudaKernel*, device->specialized_programs, key); + if (found) + return *found; + CudaKernel* spec = create_specialized_program(device, key); + assert(spec); + insert_dict(SpecProgramKey, CudaKernel*, device->specialized_programs, key, spec); + return spec; +} + +bool shd_rt_cuda_destroy_specialized_kernel(CudaKernel* kernel) { + free(kernel->cuda_code); + free(kernel->ptx); + CHECK_CUDA(cuModuleUnload(kernel->cuda_module), return false); + + free(kernel); + return true; +} diff --git a/src/runtime/runtime.c b/src/runtime/runtime.c index b2875c37f..a3e8d08db 100644 --- a/src/runtime/runtime.c +++ b/src/runtime/runtime.c @@ -6,88 +6,89 @@ #include #include -Runtime* initialize_runtime(RuntimeConfig config) { +Runtime* shd_rt_initialize(RuntimeConfig config) { Runtime* runtime = malloc(sizeof(Runtime)); memset(runtime, 0, sizeof(Runtime)); runtime->config = config; - 
runtime->backends = new_list(Backend*); - runtime->devices = new_list(Device*); - runtime->programs = new_list(Program*); + runtime->backends = shd_new_list(Backend*); + runtime->devices = shd_new_list(Device*); + runtime->programs = shd_new_list(Program*); #if VK_BACKEND_PRESENT - Backend* vk_backend = initialize_vk_backend(runtime); + Backend* vk_backend = shd_rt_initialize_vk_backend(runtime); CHECK(vk_backend, goto init_fail_free); - append_list(Backend*, runtime->backends, vk_backend); + shd_list_append(Backend*, runtime->backends, vk_backend); #endif - info_print("Shady runtime successfully initialized !\n"); +#if CUDA_BACKEND_PRESENT + Backend* cuda_backend = shd_rt_initialize_cuda_backend(runtime); + CHECK(cuda_backend, goto init_fail_free); + append_list(Backend*, runtime->backends, cuda_backend); +#endif + + shd_info_print("Shady runtime successfully initialized !\n"); return runtime; init_fail_free: - error_print("Failed to initialise the runtime.\n"); + shd_error_print("Failed to initialise the runtime.\n"); free(runtime); return NULL; } -void shutdown_runtime(Runtime* runtime) { +void shd_rt_shutdown(Runtime* runtime) { if (!runtime) return; // TODO force wait outstanding dispatches ? 
- for (size_t i = 0; i < entries_count_list(runtime->devices); i++) { - Device* dev = read_list(Device*, runtime->devices)[i]; + for (size_t i = 0; i < shd_list_count(runtime->devices); i++) { + Device* dev = shd_read_list(Device*, runtime->devices)[i]; dev->cleanup(dev); } - destroy_list(runtime->devices); + shd_destroy_list(runtime->devices); - for (size_t i = 0; i < entries_count_list(runtime->programs); i++) { - unload_program(read_list(Program*, runtime->programs)[i]); + for (size_t i = 0; i < shd_list_count(runtime->programs); i++) { + shd_rt_unload_program(shd_read_list(Program*, runtime->programs)[i]); } - destroy_list(runtime->programs); + shd_destroy_list(runtime->programs); - for (size_t i = 0; i < entries_count_list(runtime->backends); i++) { - Backend* bk = read_list(Backend*, runtime->backends)[i]; + for (size_t i = 0; i < shd_list_count(runtime->backends); i++) { + Backend* bk = shd_read_list(Backend*, runtime->backends)[i]; bk->cleanup(bk); } free(runtime); } -size_t device_count(Runtime* r) { - return entries_count_list(r->devices); +size_t shd_rt_device_count(Runtime* r) { + return shd_list_count(r->devices); } -Device* get_device(Runtime* r, size_t i) { - assert(i < device_count(r)); - return read_list(Device*, r->devices)[i]; +Device* shd_rt_get_device(Runtime* r, size_t i) { + assert(i < shd_rt_device_count(r)); + return shd_read_list(Device*, r->devices)[i]; } -Device* get_an_device(Runtime* r) { - assert(device_count(r) > 0); - return get_device(r, 0); +Device* shd_rt_get_an_device(Runtime* r) { + assert(shd_rt_device_count(r) > 0); + return shd_rt_get_device(r, 0); } // Virtual functions ... 
-const char* get_device_name(Device* d) { return d->get_name(d); } +const char* shd_rt_get_device_name(Device* d) { return d->get_name(d); } -Command* launch_kernel(Program* p, Device* d, const char* entry_point, int dimx, int dimy, int dimz, int args_count, void** args) { - return d->launch_kernel(d, p, entry_point, dimx, dimy, dimz, args_count, args); +Command* shd_rt_launch_kernel(Program* p, Device* d, const char* entry_point, int dimx, int dimy, int dimz, int args_count, void** args, ExtraKernelOptions* extra_options) { + return d->launch_kernel(d, p, entry_point, dimx, dimy, dimz, args_count, args, extra_options); } -bool wait_completion(Command* cmd) { return cmd->wait_for_completion(cmd); } +bool shd_rt_wait_completion(Command* cmd) { return cmd->wait_for_completion(cmd); } -bool can_import_host_memory(Device* device) { return device->can_import_host_memory(device); } +bool shd_rt_can_import_host_memory(Device* device) { return device->can_import_host_memory(device); } -Buffer* allocate_buffer_device(Device* device, size_t bytes) { return device->allocate_buffer(device, bytes); } -Buffer* import_buffer_host(Device* device, void* ptr, size_t bytes) { return device->import_host_memory_as_buffer(device, ptr, bytes); } +Buffer* shd_rt_allocate_buffer_device(Device* device, size_t bytes) { return device->allocate_buffer(device, bytes); } +Buffer* shd_rt_import_buffer_host(Device* device, void* ptr, size_t bytes) { return device->import_host_memory_as_buffer(device, ptr, bytes); } -void destroy_buffer(Buffer* buf) { buf->destroy(buf); }; +void shd_rt_destroy_buffer(Buffer* buf) { buf->destroy(buf); } -void* get_buffer_host_pointer(Buffer* buf) { return buf->get_host_ptr(buf); } -uint64_t get_buffer_device_pointer(Buffer* buf) { return buf->get_device_ptr(buf); }; +void* shd_rt_get_buffer_host_pointer(Buffer* buf) { return buf->get_host_ptr(buf); } +uint64_t shd_rt_get_buffer_device_pointer(Buffer* buf) { return buf->get_device_ptr(buf); } -bool 
copy_to_buffer(Buffer* dst, size_t buffer_offset, void* src, size_t size) { - return dst->copy_into(dst, buffer_offset, src, size); -} - -bool copy_from_buffer(Buffer* src, size_t buffer_offset, void* dst, size_t size) { - return src->copy_from(src, buffer_offset, dst, size); -} +bool shd_rt_copy_to_buffer(Buffer* dst, size_t buffer_offset, void* src, size_t size) { return dst->copy_into(dst, buffer_offset, src, size); } +bool shd_rt_copy_from_buffer(Buffer* src, size_t buffer_offset, void* dst, size_t size) { return src->copy_from(src, buffer_offset, dst, size); } diff --git a/src/runtime/runtime_app_common.h b/src/runtime/runtime_app_common.h new file mode 100644 index 000000000..eca75e227 --- /dev/null +++ b/src/runtime/runtime_app_common.h @@ -0,0 +1,53 @@ +#ifndef SHADY_RUNTIME_CLI +#define SHADY_RUNTIME_CLI + +#include "shady/driver.h" +#include "runtime_private.h" + +#include "log.h" + +#include +#include + +typedef struct { + size_t device; +} CommonAppArgs; + +static void cli_parse_common_app_arguments(CommonAppArgs* args, int* pargc, char** argv) { + int argc = *pargc; + + bool help = false; + for (int i = 1; i < argc; i++) { + if (argv[i] == NULL) + continue; + if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) { + help = true; + continue; + } else if (strcmp(argv[i], "--device") == 0 || strcmp(argv[i], "-d") == 0) { + argv[i] = NULL; + i++; + if (i >= argc) { + shd_error_print("Missing device number for --device\n"); + exit(1); + } + args->device = strtol(argv[i], NULL, 10); + } else { + continue; + } + argv[i] = NULL; + } + + if (help) { + shd_error_print("Usage: runtime_test [source.slim]\n"); + shd_error_print("Available arguments: \n"); + shd_error_print(" --log-level debug[v[v]], info, warn, error]\n"); + shd_error_print(" --shd_print-builtin\n"); + shd_error_print(" --shd_print-generated\n"); + shd_error_print(" --device n\n"); + exit(0); + } + + shd_pack_remaining_args(pargc, argv); +} + +#endif diff --git 
a/src/runtime/runtime_cli.c b/src/runtime/runtime_cli.c new file mode 100644 index 000000000..8c9f53e09 --- /dev/null +++ b/src/runtime/runtime_cli.c @@ -0,0 +1,49 @@ +#include "runtime_private.h" +#include "../driver/cli.h" + +#include "log.h" + +RuntimeConfig shd_rt_default_config() { + return (RuntimeConfig) { +#ifndef NDEBUG + .dump_spv = true, + .use_validation = true, +#else + 0 +#endif + }; +} + +#define DRIVER_CONFIG_OPTIONS(F) \ +F(config->use_validation, api-validation) \ +F(config->dump_spv, dump-spv) \ + +void shd_rt_cli_parse_runtime_config(RuntimeConfig* config, int* pargc, char** argv) { + int argc = *pargc; + + bool help = false; + for (int i = 1; i < argc; i++) { + + DRIVER_CONFIG_OPTIONS(PARSE_TOGGLE_OPTION) + + if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) { + help = true; + continue; + } else { + continue; + } + argv[i] = NULL; + } + + if (help) { + // shd_error_print("Usage: slim source.slim\n"); + // shd_error_print("Available arguments: \n"); + shd_error_print(" --target \n"); + shd_error_print(" --output , -o \n"); + shd_error_print(" --dump-cfg Dumps the control flow graph of the final IR\n"); + shd_error_print(" --dump-loop-tree \n"); + shd_error_print(" --dump-ir Dumps the final IR\n"); + } + + shd_pack_remaining_args(pargc, argv); +} \ No newline at end of file diff --git a/src/runtime/runtime_private.h b/src/runtime/runtime_private.h index aef273ffe..8b8a504a4 100644 --- a/src/runtime/runtime_private.h +++ b/src/runtime/runtime_private.h @@ -3,7 +3,7 @@ #include "shady/runtime.h" #include "shady/ir.h" -#define CHECK(x, failure_handler) { if (!(x)) { error_print(#x " failed\n"); failure_handler; } } +#define CHECK(x, failure_handler) { if (!(x)) { shd_error_print(#x " failed\n"); failure_handler; } } // typedef struct SpecProgram_ SpecProgram; @@ -15,6 +15,11 @@ struct Runtime_ { struct List* programs; }; +typedef enum { + VulkanRuntimeBackend, + CUDARuntimeBackend, +} ShdRuntimeBackend; + typedef struct Backend_ 
Backend; struct Backend_ { Runtime* runtime; @@ -25,12 +30,19 @@ struct Device_ { void (*cleanup)(Device*); String (*get_name)(Device*); - Command* (*launch_kernel)(Device*, Program*, const char* entry_point, int dimx, int dimy, int dimz, int args_count, void** args); + Command* (*launch_kernel)(Device*, Program*, const char* entry_point, int dimx, int dimy, int dimz, int args_count, void** args, ExtraKernelOptions*); Buffer* (*allocate_buffer)(Device*, size_t bytes); Buffer* (*import_host_memory_as_buffer)(Device*, void* base, size_t bytes); bool (*can_import_host_memory)(Device*); }; +typedef struct { + size_t num_args; + const size_t* arg_offset; + const size_t* arg_size; + size_t args_size; +} ProgramParamsInfo; + struct Program_ { Runtime* runtime; const CompilerConfig* base_config; @@ -44,15 +56,17 @@ struct Command_ { }; struct Buffer_ { - void (*destroy)(Buffer*); + ShdRuntimeBackend backend_tag; + void (*destroy)(Buffer*); void* (*get_host_ptr)(Buffer*); uint64_t (*get_device_ptr)(Buffer*); - - bool (*copy_into)(Buffer* dst, size_t buffer_offset, void* src, size_t bytes); - bool (*copy_from)(Buffer* src, size_t buffer_offset, void* dst, size_t bytes); + bool (*copy_into)(Buffer* dst, size_t buffer_offset, void* src, size_t bytes); + bool (*copy_from)(Buffer* src, size_t buffer_offset, void* dst, size_t bytes); }; -void unload_program(Program*); +void shd_rt_unload_program(Program* program); + +Backend* shd_rt_initialize_vk_backend(Runtime*); +Backend* shd_rt_shd_rt_initialize_cuda_backend(Runtime*); -Backend* initialize_vk_backend(Runtime*); #endif diff --git a/src/runtime/runtime_program.c b/src/runtime/runtime_program.c index caf344e8d..8f8792689 100644 --- a/src/runtime/runtime_program.c +++ b/src/runtime/runtime_program.c @@ -4,12 +4,13 @@ #include "log.h" #include "list.h" #include "util.h" +#include "portability.h" #include #include #include -Program* new_program_from_module(Runtime* runtime, const CompilerConfig* base_config, Module* mod) { 
+Program* shd_rt_new_program_from_module(Runtime* runtime, const CompilerConfig* base_config, Module* mod) { Program* program = calloc(1, sizeof(Program)); program->runtime = runtime; program->base_config = base_config; @@ -17,41 +18,40 @@ Program* new_program_from_module(Runtime* runtime, const CompilerConfig* base_co program->module = mod; // TODO split the compilation pipeline into generic and non-generic parts - append_list(Program*, runtime->programs, program); + shd_list_append(Program*, runtime->programs, program); return program; } -Program* load_program(Runtime* runtime, const CompilerConfig* base_config, const char* program_src) { - IrArena* arena = new_ir_arena(default_arena_config()); - Module* module = new_module(arena, "my_module"); +Program* shd_rt_load_program(Runtime* runtime, const CompilerConfig* base_config, const char* program_src) { + Module* module; - int err = driver_load_source_file(SrcShadyIR, strlen(program_src), program_src, module); + int err = shd_driver_load_source_file(base_config, SrcShadyIR, strlen(program_src), program_src, "my_module", + &module); if (err != NoError) { return NULL; } - Program* program = new_program_from_module(runtime, base_config, module); - program->arena = arena; + Program* program = shd_rt_new_program_from_module(runtime, base_config, module); + program->arena = shd_module_get_arena(module); return program; } -Program* load_program_from_disk(Runtime* runtime, const CompilerConfig* base_config, const char* path) { - IrArena* arena = new_ir_arena(default_arena_config()); - Module* module = new_module(arena, "my_module"); +Program* shd_rt_load_program_from_disk(Runtime* runtime, const CompilerConfig* base_config, const char* path) { + Module* module; - int err = driver_load_source_file_from_filename(path, module); + int err = shd_driver_load_source_file_from_filename(base_config, path, "my_module", &module); if (err != NoError) { return NULL; } - Program* program = new_program_from_module(runtime, base_config, 
module); - program->arena = arena; + Program* program = shd_rt_new_program_from_module(runtime, base_config, module); + program->arena = shd_module_get_arena(module); return program; } -void unload_program(Program* program) { +void shd_rt_unload_program(Program* program) { // TODO iterate over the specialized stuff if (program->arena) // if the program owns an arena - destroy_ir_arena(program->arena); + shd_destroy_ir_arena(program->arena); free(program); } diff --git a/src/runtime/runtime_test.c b/src/runtime/runtime_test.c index 301f93e9d..a5ce05432 100644 --- a/src/runtime/runtime_test.c +++ b/src/runtime/runtime_test.c @@ -2,6 +2,8 @@ #include "shady/ir.h" #include "shady/driver.h" +#include "runtime_app_common.h" + #include "log.h" #include "portability.h" #include "util.h" @@ -11,99 +13,70 @@ #include #include #include - -static const char* default_shader = -"@EntryPoint(\"Compute\") @WorkgroupSize(SUBGROUP_SIZE, 1, 1) fn main(uniform i32 a, uniform ptr global i32 b) {\n" -" val rb = reinterpret[u64](b);\n" -" debug_printf(\"hi %d 0x%lx\\n\", a, rb);\n" -" return ();\n" -"}"; +#include typedef struct { DriverConfig driver_config; RuntimeConfig runtime_config; - size_t device; + CommonAppArgs common_app_args; } Args; -static void parse_runtime_arguments(int* pargc, char** argv, Args* args) { - int argc = *pargc; - - bool help = false; - for (int i = 1; i < argc; i++) { - if (argv[i] == NULL) - continue; - if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) { - help = true; - continue; - } else if (strcmp(argv[i], "--device") == 0 || strcmp(argv[i], "-d") == 0) { - argv[i] = NULL; - i++; - args->device = strtol(argv[i], NULL, 10); - } else { - continue; - } - argv[i] = NULL; - } - - if (help) { - error_print("Usage: runtime_test [source.slim]\n"); - error_print("Available arguments: \n"); - error_print(" --log-level debug[v[v]], info, warn, error]\n"); - error_print(" --print-builtin\n"); - error_print(" --print-generated\n"); - error_print(" 
--device n\n"); - exit(0); - } -} +static const char* default_shader = +"@EntryPoint(\"Compute\") @Exported @WorkgroupSize(SUBGROUP_SIZE, 1, 1) fn my_kernel(uniform i32 a, uniform ptr global i32 b) {\n" +" val rb = reinterpret[u64](b);\n" +" debug_printf(\"hi %d 0x%lx\\n\", a, rb);\n" +" return ();\n" +"}"; int main(int argc, char* argv[]) { - set_log_level(INFO); + shd_log_set_level(INFO); Args args = { - .driver_config = default_driver_config(), + .driver_config = shd_default_driver_config(), + .runtime_config = shd_rt_default_config(), }; - args.runtime_config = (RuntimeConfig) { - .use_validation = true, - .dump_spv = true, - }; - parse_runtime_arguments(&argc, argv, &args); - cli_parse_common_args(&argc, argv); - cli_parse_compiler_config_args(&args.driver_config.config, &argc, argv); - cli_parse_input_files(args.driver_config.input_filenames, &argc, argv); + cli_parse_common_app_arguments(&args.common_app_args, &argc, argv); + shd_parse_common_args(&argc, argv); + shd_rt_cli_parse_runtime_config(&args.runtime_config, &argc, argv); + shd_parse_compiler_config_args(&args.driver_config.config, &argc, argv); + shd_driver_parse_input_files(args.driver_config.input_filenames, &argc, argv); - info_print("Shady runtime test starting...\n"); + shd_info_print("Shady runtime test starting...\n"); - Runtime* runtime = initialize_runtime(args.runtime_config); - Device* device = get_device(runtime, args.device); + Runtime* runtime = shd_rt_initialize(args.runtime_config); + Device* device = shd_rt_get_device(runtime, args.common_app_args.device); assert(device); Program* program; IrArena* arena = NULL; - if (entries_count_list(args.driver_config.input_filenames) == 0) { - program = load_program(runtime, &args.driver_config.config, default_shader); + ArenaConfig aconfig = shd_default_arena_config(&args.driver_config.config.target); + arena = shd_new_ir_arena(&aconfig); + if (shd_list_count(args.driver_config.input_filenames) == 0) { + Module* module; + 
shd_driver_load_source_file(&args.driver_config.config, SrcSlim, strlen(default_shader), default_shader, + "runtime_test", &module); + program = shd_rt_new_program_from_module(runtime, &args.driver_config.config, module); } else { - arena = new_ir_arena(default_arena_config()); - Module* module = new_module(arena, "my_module"); - - int err = driver_load_source_files(&args.driver_config, module); + Module* module = shd_new_module(arena, "my_module"); + int err = shd_driver_load_source_files(&args.driver_config, module); if (err) return err; - program = new_program_from_module(runtime, &args.driver_config.config, module); + program = shd_rt_new_program_from_module(runtime, &args.driver_config.config, module); } int32_t stuff[] = { 42, 42, 42, 42 }; - Buffer* buffer = allocate_buffer_device(device, sizeof(stuff)); - copy_to_buffer(buffer, 0, stuff, sizeof(stuff)); - copy_from_buffer(buffer, 0, stuff, sizeof(stuff)); + Buffer* buffer = shd_rt_allocate_buffer_device(device, sizeof(stuff)); + shd_rt_copy_to_buffer(buffer, 0, stuff, sizeof(stuff)); + shd_rt_copy_from_buffer(buffer, 0, stuff, sizeof(stuff)); int32_t a0 = 42; - uint64_t a1 = get_buffer_device_pointer(buffer); - wait_completion(launch_kernel(program, device, "main", 1, 1, 1, 2, (void*[]) { &a0, &a1 })); + uint64_t a1 = shd_rt_get_buffer_device_pointer(buffer); + shd_rt_wait_completion(shd_rt_launch_kernel(program, device, args.driver_config.config.specialization.entry_point ? 
args.driver_config.config.specialization.entry_point : "my_kernel", 1, 1, 1, 2, (void* []) { &a0, &a1 }, NULL)); - destroy_buffer(buffer); + shd_rt_destroy_buffer(buffer); - shutdown_runtime(runtime); + shd_rt_shutdown(runtime); if (arena) - destroy_ir_arena(arena); - destroy_driver_config(&args.driver_config); + shd_destroy_ir_arena(arena); + shd_destroy_driver_config(&args.driver_config); return 0; } diff --git a/src/runtime/vulkan/CMakeLists.txt b/src/runtime/vulkan/CMakeLists.txt index e7b2b6430..f3f0966d7 100644 --- a/src/runtime/vulkan/CMakeLists.txt +++ b/src/runtime/vulkan/CMakeLists.txt @@ -1,16 +1,17 @@ find_package(Vulkan) if (Vulkan_FOUND) - message("Vulkan found") + option(SHADY_ENABLE_RUNTIME_VULKAN "Vulkan support for the 'runtime' component" ON) +else() + message("Vulkan not found, runtime component cannot be built.") +endif() + +if (SHADY_ENABLE_RUNTIME_VULKAN) add_library(vk_runtime STATIC vk_runtime.c vk_runtime_device.c vk_runtime_program.c vk_runtime_dispatch.c vk_runtime_buffer.c) - target_link_libraries(vk_runtime PUBLIC api) - target_link_libraries(vk_runtime PUBLIC shady) + target_link_libraries(vk_runtime PRIVATE api) target_link_libraries(vk_runtime PRIVATE "$") - target_link_libraries(vk_runtime PRIVATE "$") target_link_libraries(vk_runtime PRIVATE Vulkan::Headers Vulkan::Vulkan) - set_property(TARGET vk_runtime PROPERTY POSITION_INDEPENDENT_CODE ON) - target_link_libraries(runtime PRIVATE vk_runtime) + target_compile_definitions(runtime PUBLIC VK_BACKEND_PRESENT=1) -else() - message("Vulkan not found, runtime component will not be built.") + target_link_libraries(runtime PRIVATE "$") endif() diff --git a/src/runtime/vulkan/vk_runtime.c b/src/runtime/vulkan/vk_runtime.c index e5d58a600..dcb324dde 100644 --- a/src/runtime/vulkan/vk_runtime.c +++ b/src/runtime/vulkan/vk_runtime.c @@ -9,7 +9,7 @@ #include static VKAPI_ATTR VkBool32 VKAPI_CALL the_callback(SHADY_UNUSED VkDebugUtilsMessageSeverityFlagBitsEXT messageSeverity, SHADY_UNUSED 
VkDebugUtilsMessageTypeFlagsEXT messageType, const VkDebugUtilsMessengerCallbackDataEXT* pCallbackData, SHADY_UNUSED void* pUserData) { - warn_print("Validation says: %s\n", pCallbackData->pMessage); + shd_warn_print("Validation says: %s\n", pCallbackData->pMessage); return VK_FALSE; } @@ -52,7 +52,7 @@ static bool initialize_vk_instance(VkrBackend* runtime) { // Enable validation if the config says so if (runtime->base.runtime->config.use_validation && strcmp(layer->layerName, "VK_LAYER_KHRONOS_validation") == 0) { - info_print("Enabling validation... \n"); + shd_info_print("Enabling validation... \n"); runtime->enabled_layers.validation.enabled = true; enabled_layers[enabled_layers_count++] = layer->layerName; } @@ -71,7 +71,7 @@ static bool initialize_vk_instance(VkrBackend* runtime) { #define X(is_required, name, _) \ if (strcmp(extension->extensionName, "VK_"#name) == 0) { \ - info_print("Enabling instance extension VK_"#name"\n"); \ + shd_info_print("Enabling instance extension VK_"#name"\n"); \ runtime->instance_exts.name.enabled = true; \ enabled_extensions[enabled_extensions_count++] = extension->extensionName; \ } @@ -106,12 +106,12 @@ static bool initialize_vk_instance(VkrBackend* runtime) { case VK_ERROR_INCOMPATIBLE_DRIVER: { // Vulkan 1.0 is not worth supporting. It has many API warts and 1.1 fixes many of them. // the hardware support is basically identical, so you're not cutting off any devices, just stinky old drivers. - error_print("vkCreateInstance reported VK_ERROR_INCOMPATIBLE_DRIVER. This most certainly means you're trying to run on a Vulkan 1.0 implementation.\n"); - error_print("This application is written with Vulkan 1.1 as the baseline, you will need to update your Vulkan loader and/or driver."); + shd_error_print("vkCreateInstance reported VK_ERROR_INCOMPATIBLE_DRIVER. 
This most certainly means you're trying to run on a Vulkan 1.0 implementation.\n"); + shd_error_print("This application is written with Vulkan 1.1 as the baseline, you will need to update your Vulkan loader and/or driver."); return false; } default: { - error_print("vkCreateInstanced failed (%u)\n", err_create_instance); + shd_error_print("vkCreateInstanced failed (%u)\n", err_create_instance); return false; } } @@ -134,7 +134,7 @@ static void shutdown_vulkan_runtime(VkrBackend* backend) { free(backend); } -Backend* initialize_vk_backend(Runtime* base) { +Backend* shd_rt_initialize_vk_backend(Runtime* base) { VkrBackend* backend = malloc(sizeof(VkrBackend)); memset(backend, 0, sizeof(VkrBackend)); backend->base = (Backend) { @@ -143,12 +143,12 @@ Backend* initialize_vk_backend(Runtime* base) { }; CHECK(initialize_vk_instance(backend), goto init_fail_free) - probe_vkr_devices(backend); - info_print("Shady Vulkan backend successfully initialized !\n"); + shd_rt_vk_probe_devices(backend); + shd_info_print("Shady Vulkan backend successfully initialized !\n"); return &backend->base; init_fail_free: - error_print("Failed to initialise the Vulkan back-end.\n"); + shd_error_print("Failed to initialise the Vulkan back-end.\n"); free(backend); return NULL; } diff --git a/src/runtime/vulkan/vk_runtime_buffer.c b/src/runtime/vulkan/vk_runtime_buffer.c index fd78b1f50..ccbdea544 100644 --- a/src/runtime/vulkan/vk_runtime_buffer.c +++ b/src/runtime/vulkan/vk_runtime_buffer.c @@ -32,14 +32,14 @@ static uint32_t find_suitable_memory_type(VkrDevice* device, uint32_t memory_typ } } } - assert(false && "Unable to find a suitable memory type"); + shd_error("Unable to find a suitable memory type") } static Buffer make_base_buffer(VkrDevice*); -VkrBuffer* vkr_allocate_buffer_device_(VkrDevice* device, size_t size, AllocHeap heap) { +static VkrBuffer* vkr_allocate_buffer_device_(VkrDevice* device, size_t size, AllocHeap heap) { if 
(!device->caps.features.buffer_device_address.bufferDeviceAddress) { - error_print("device buffers require VK_KHR_buffer_device_address\n"); + shd_error_print("device buffers require VK_KHR_buffer_device_address\n"); return NULL; } @@ -102,31 +102,31 @@ VkrBuffer* vkr_allocate_buffer_device_(VkrDevice* device, size_t size, AllocHeap return NULL; } -VkrBuffer* vkr_allocate_buffer_device(VkrDevice* device, size_t size) { +VkrBuffer* shd_rt_vk_allocate_buffer_device(VkrDevice* device, size_t size) { return vkr_allocate_buffer_device_(device, size, AllocDeviceLocal); } static bool vkr_can_import_host_memory_(VkrDevice* device, bool log) { if (!device->caps.supported_extensions[ShadySupportsEXT_external_memory_host]) { if (log) - error_print("host imported buffers require VK_EXT_external_memory_host\n"); + shd_error_print("host imported buffers require VK_EXT_external_memory_host\n"); return false; } if (!device->caps.features.buffer_device_address.bufferDeviceAddress) { if (log) - error_print("host imported buffers require VK_KHR_buffer_device_address\n"); + shd_error_print("host imported buffers require VK_KHR_buffer_device_address\n"); return false; } return true; } -bool vkr_can_import_host_memory(VkrDevice* device) { +bool shd_rt_vk_can_import_host_memory(VkrDevice* device) { return vkr_can_import_host_memory_(device, false); } -VkrBuffer* vkr_import_buffer_host(VkrDevice* device, void* ptr, size_t size) { +VkrBuffer* shd_rt_vk_import_buffer_host(VkrDevice* device, void* ptr, size_t size) { if (!vkr_can_import_host_memory_(device, true)) { - error_die(); + shd_error_die(); } VkrBuffer* buffer = calloc(sizeof(VkrBuffer), 1); @@ -140,7 +140,7 @@ VkrBuffer* vkr_import_buffer_host(VkrDevice* device, void* ptr, size_t size) { size_t aligned_addr = (unaligned_addr / desired_alignment) * desired_alignment; assert(unaligned_addr >= aligned_addr); buffer->offset = unaligned_addr - aligned_addr; - debug_print("desired alignment = %zu, offset = %zu\n", desired_alignment, 
buffer->offset); + shd_debug_print("desired alignment = %zu, offset = %zu\n", desired_alignment, buffer->offset); size_t unaligned_end = unaligned_addr + size; assert(unaligned_end >= aligned_addr); @@ -149,8 +149,8 @@ VkrBuffer* vkr_import_buffer_host(VkrDevice* device, void* ptr, size_t size) { size_t aligned_size = aligned_end - aligned_addr; assert(aligned_size >= size); assert(aligned_size % desired_alignment == 0); - debug_print("unaligned start %zu end %zu\n", unaligned_addr, unaligned_end); - debug_print("aligned start %zu end %zu\n", aligned_addr, aligned_end); + shd_debug_print("unaligned start %zu end %zu\n", unaligned_addr, unaligned_end); + shd_debug_print("aligned start %zu end %zu\n", aligned_addr, aligned_end); buffer->host_ptr = (void*) aligned_addr; buffer->size = aligned_size; @@ -176,11 +176,11 @@ VkrBuffer* vkr_import_buffer_host(VkrDevice* device, void* ptr, size_t size) { .sType = VK_STRUCTURE_TYPE_MEMORY_HOST_POINTER_PROPERTIES_EXT, .pNext = NULL }; - CHECK_VK(device->extensions.EXT_external_memory_host.vkGetMemoryHostPointerPropertiesEXT(device->device, VK_EXTERNAL_MEMORY_HANDLE_TYPE_HOST_ALLOCATION_BIT_EXT, ptr, &host_ptr_properties), goto err_post_buffer_create); + CHECK_VK(device->extensions.EXT_external_memory_host.vkGetMemoryHostPointerPropertiesEXT(device->device, VK_EXTERNAL_MEMORY_HANDLE_TYPE_HOST_ALLOCATION_BIT_EXT, (void*) aligned_addr, &host_ptr_properties), goto err_post_buffer_create); uint32_t memory_type_index = find_suitable_memory_type(device, host_ptr_properties.memoryTypeBits, AllocHostVisible); VkPhysicalDeviceMemoryProperties device_memory_properties; vkGetPhysicalDeviceMemoryProperties(device->caps.physical_device, &device_memory_properties); - debug_print("memory type index: %d heap: %d\n", memory_type_index, device_memory_properties.memoryTypes[memory_type_index].heapIndex); + shd_debug_print("memory type index: %d heap: %d\n", memory_type_index, device_memory_properties.memoryTypes[memory_type_index].heapIndex); 
VkMemoryAllocateInfo allocation_info = { .sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO, @@ -218,7 +218,7 @@ VkrBuffer* vkr_import_buffer_host(VkrDevice* device, void* ptr, size_t size) { return NULL; } -static void vkr_destroy_buffer(VkrBuffer* buffer) { +void shd_rt_vk_destroy_buffer(VkrBuffer* buffer) { vkDestroyBuffer(buffer->device->device, buffer->buffer, NULL); vkFreeMemory(buffer->device->device, buffer->memory, NULL); } @@ -236,23 +236,24 @@ static void* vkr_get_buffer_host_pointer(VkrBuffer* buf) { } static VkrCommand* submit_buffer_copy(VkrDevice* device, VkBuffer src, size_t src_offset, VkBuffer dst, size_t dst_offset, size_t size) { - VkrCommand* commands = vkr_begin_command(device); + VkrCommand* commands = shd_rt_vk_begin_command(device); if (!commands) return NULL; vkCmdCopyBuffer(commands->cmd_buf, src, dst, 1, (VkBufferCopy[]) { { .srcOffset = src_offset, .dstOffset = dst_offset, .size = size } }); - if (!vkr_submit_command(commands)) + if (!shd_rt_vk_submit_command(commands)) goto err_post_commands_create; return commands; err_post_commands_create: - vkr_destroy_command(commands); + shd_rt_vk_destroy_command(commands); return NULL; } static bool vkr_copy_to_buffer_fallback(VkrBuffer* dst, size_t buffer_offset, void* src, size_t size) { + CHECK(dst->base.backend_tag == VulkanRuntimeBackend, return false); VkrDevice* device = dst->device; VkrBuffer* src_buf = vkr_allocate_buffer_device_(device, size, AllocHostVisible); @@ -263,19 +264,20 @@ static bool vkr_copy_to_buffer_fallback(VkrBuffer* dst, size_t buffer_offset, vo CHECK_VK(vkMapMemory(device->device, src_buf->memory, src_buf->offset, src_buf->size, 0, &mapped), goto err_post_buffer_create); memcpy(mapped, src, size); - if (!wait_completion(submit_buffer_copy(device, src_buf->buffer, src_buf->offset, dst->buffer, dst->offset + buffer_offset, size))) + if (!shd_rt_wait_completion((Command*) submit_buffer_copy(device, src_buf->buffer, src_buf->offset, dst->buffer, dst->offset + buffer_offset, 
size))) goto err_post_buffer_create; vkUnmapMemory(device->device, src_buf->memory); - vkr_destroy_buffer(src_buf); + shd_rt_vk_destroy_buffer(src_buf); return true; err_post_buffer_create: - vkr_destroy_buffer(src_buf); + shd_rt_vk_destroy_buffer(src_buf); return false; } static bool vkr_copy_from_buffer_fallback(VkrBuffer* src, size_t buffer_offset, void* dst, size_t size) { + CHECK(src->base.backend_tag == VulkanRuntimeBackend, return false); VkrDevice* device = src->device; VkrBuffer* dst_buf = vkr_allocate_buffer_device_(device, size, AllocHostVisible); @@ -285,64 +287,67 @@ static bool vkr_copy_from_buffer_fallback(VkrBuffer* src, size_t buffer_offset, void* mapped; CHECK_VK(vkMapMemory(device->device, dst_buf->memory, dst_buf->offset, dst_buf->size, 0, &mapped), goto err_post_buffer_create); - if (!wait_completion(submit_buffer_copy(device, src->buffer, src->offset + buffer_offset, dst_buf->buffer, dst_buf->offset, size))) + if (!shd_rt_wait_completion((Command*) submit_buffer_copy(device, src->buffer, src->offset + buffer_offset, dst_buf->buffer, dst_buf->offset, size))) goto err_post_buffer_create; memcpy(dst, mapped, size); vkUnmapMemory(device->device, dst_buf->memory); - vkr_destroy_buffer(dst_buf); + shd_rt_vk_destroy_buffer(dst_buf); return true; err_post_buffer_create: - vkr_destroy_buffer(dst_buf); + shd_rt_vk_destroy_buffer(dst_buf); return false; } static bool vkr_copy_to_buffer_importing(VkrBuffer* dst, size_t buffer_offset, void* src, size_t size) { + CHECK(dst->base.backend_tag == VulkanRuntimeBackend, return false); VkrDevice* device = dst->device; - VkrBuffer* src_buf = vkr_import_buffer_host(device, src, size); + VkrBuffer* src_buf = shd_rt_vk_import_buffer_host(device, src, size); if (!src_buf) return false; - if (!wait_completion(submit_buffer_copy(device, src_buf->buffer, src_buf->offset, dst->buffer, dst->offset + buffer_offset, size))) + if (!shd_rt_wait_completion((Command*) submit_buffer_copy(device, src_buf->buffer, src_buf->offset, 
dst->buffer, dst->offset + buffer_offset, size))) goto err_post_buffer_import; - vkr_destroy_buffer(src_buf); + shd_rt_vk_destroy_buffer(src_buf); return true; err_post_buffer_import: - vkr_destroy_buffer(src_buf); + shd_rt_vk_destroy_buffer(src_buf); return false; } static bool vkr_copy_from_buffer_importing(VkrBuffer* src, size_t buffer_offset, void* dst, size_t size) { + CHECK(src->base.backend_tag == VulkanRuntimeBackend, return false); VkrDevice* device = src->device; - VkrBuffer* dst_buf = vkr_import_buffer_host(device, dst, size); + VkrBuffer* dst_buf = shd_rt_vk_import_buffer_host(device, dst, size); if (!dst_buf) return false; - if (!wait_completion(submit_buffer_copy(device, src->buffer, src->offset + buffer_offset, dst_buf->buffer, dst_buf->offset, size))) + if (!shd_rt_wait_completion((Command*) submit_buffer_copy(device, src->buffer, src->offset + buffer_offset, dst_buf->buffer, dst_buf->offset, size))) goto err_post_buffer_import; - vkr_destroy_buffer(dst_buf); + shd_rt_vk_destroy_buffer(dst_buf); return true; err_post_buffer_import: - vkr_destroy_buffer(dst_buf); + shd_rt_vk_destroy_buffer(dst_buf); return false; } static Buffer make_base_buffer(VkrDevice* device) { Buffer buffer = { - .destroy = (void(*)(Buffer*)) vkr_destroy_buffer, + .backend_tag = VulkanRuntimeBackend, + .destroy = (void (*)(Buffer*)) shd_rt_vk_destroy_buffer, .get_device_ptr = (uint64_t(*)(Buffer*)) vkr_get_buffer_device_pointer, .get_host_ptr = (void*(*)(Buffer*)) vkr_get_buffer_host_pointer, .copy_into = (bool(*)(Buffer*, size_t, void*, size_t)) vkr_copy_to_buffer_fallback, .copy_from = (bool(*)(Buffer*, size_t, void*, size_t)) vkr_copy_from_buffer_fallback, }; - if (vkr_can_import_host_memory(device)) { + if (shd_rt_vk_can_import_host_memory(device)) { buffer.copy_from = (bool(*)(Buffer*, size_t, void*, size_t)) vkr_copy_from_buffer_importing; buffer.copy_into = (bool(*)(Buffer*, size_t, void*, size_t)) vkr_copy_to_buffer_importing; } diff --git 
a/src/runtime/vulkan/vk_runtime_device.c b/src/runtime/vulkan/vk_runtime_device.c index 6f8bca4c5..b03e9c569 100644 --- a/src/runtime/vulkan/vk_runtime_device.c +++ b/src/runtime/vulkan/vk_runtime_device.c @@ -62,7 +62,7 @@ static void figure_out_spirv_version(VkrDeviceCaps* caps) { caps->spirv_version.minor = 6; } - debug_print("Using SPIR-V version %d.%d, on Vulkan %d.%d\n", caps->spirv_version.major, caps->spirv_version.minor, major, minor); + shd_debug_print("Using SPIR-V version %d.%d, on Vulkan %d.%d\n", caps->spirv_version.major, caps->spirv_version.minor, major, minor); } static bool fill_device_properties(VkrDeviceCaps* caps) { @@ -70,13 +70,13 @@ static bool fill_device_properties(VkrDeviceCaps* caps) { vkGetPhysicalDeviceProperties2(caps->physical_device, &caps->properties.base); if (caps->properties.base.properties.apiVersion < VK_MAKE_API_VERSION(0, 1, 1, 0)) { - info_print("Rejecting device '%s' because it does not support Vulkan 1.1 or later\n", caps->properties.base.properties.deviceName); + shd_info_print("Rejecting device '%s' because it does not support Vulkan 1.1 or later\n", caps->properties.base.properties.deviceName); return false; } String missing_ext; if (!fill_available_extensions(caps->physical_device, NULL, &missing_ext, caps->supported_extensions)) { - info_print("Rejecting device %s because it lacks support for '%s'\n", caps->properties.base.properties.deviceName, missing_ext); + shd_info_print("Rejecting device %s because it lacks support for '%s'\n", caps->properties.base.properties.deviceName, missing_ext); return false; } @@ -100,6 +100,11 @@ static bool fill_device_properties(VkrDeviceCaps* caps) { append_pnext((VkBaseOutStructure*) &caps->properties.base, &caps->properties.external_memory_host); } + if (caps->supported_extensions[ShadySupportsKHR_driver_properties]) { + caps->properties.driver_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES; + append_pnext((VkBaseOutStructure*) &caps->properties.base, 
&caps->properties.driver_properties); + } + vkGetPhysicalDeviceProperties2(caps->physical_device, &caps->properties.base); if (caps->supported_extensions[ShadySupportsEXT_subgroup_size_control] || caps->properties.base.properties.apiVersion >= VK_MAKE_VERSION(1, 3, 0)) { @@ -109,12 +114,7 @@ static bool fill_device_properties(VkrDeviceCaps* caps) { caps->subgroup_size.max = caps->properties.subgroup.subgroupSize; caps->subgroup_size.min = caps->properties.subgroup.subgroupSize; } - debug_print("Subgroup size range for device '%s' is [%d; %d]\n", caps->properties.base.properties.deviceName, caps->subgroup_size.min, caps->subgroup_size.max); - - if (caps->supported_extensions[ShadySupportsKHR_driver_properties]) { - caps->properties.driver_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES; - append_pnext((VkBaseOutStructure*) &caps->properties.base, &caps->properties.driver_properties); - } + shd_debug_print("Subgroup size range for device '%s' is [%d; %d]\n", caps->properties.base.properties.deviceName, caps->subgroup_size.min, caps->subgroup_size.max); return true; } @@ -157,7 +157,7 @@ static bool fill_device_features(VkrDeviceCaps* caps) { vkGetPhysicalDeviceFeatures2(caps->physical_device, &caps->features.base); if (!caps->features.subgroup_size_control.computeFullSubgroups) { - warn_print("Potentially broken behaviour on device %s because it does not support computeFullSubgroups", caps->properties.base.properties.deviceName); + shd_warn_print("Potentially broken behaviour on device %s because it does not support computeFullSubgroups", caps->properties.base.properties.deviceName); // TODO just outright reject such devices ? 
} @@ -183,7 +183,7 @@ static bool fill_queue_properties(VkrDeviceCaps* caps) { } } if (compute_queue_family >= queue_families_count) { - info_print("Rejecting device %s because it lacks a compute queue family\n", caps->properties.base.properties.deviceName); + shd_info_print("Rejecting device %s because it lacks a compute queue family\n", caps->properties.base.properties.deviceName); return false; } caps->compute_queue_family = compute_queue_family; @@ -208,12 +208,16 @@ static bool get_physical_device_caps(SHADY_UNUSED VkrBackend* runtime, VkPhysica return false; } -KeyHash hash_spec_program_key(SpecProgramKey* ptr) { - return hash_murmur(ptr, sizeof(SpecProgramKey)); +KeyHash shd_hash_string(const char** string); +bool shd_compare_string(const char** a, const char** b); + +static KeyHash hash_spec_program_key(SpecProgramKey* ptr) { + return shd_hash(ptr->base, sizeof(Program*)) ^ shd_hash_string(&ptr->entry_point); } -bool cmp_spec_program_keys(SpecProgramKey* a, SpecProgramKey* b) { - return memcmp(a, b, sizeof(SpecProgramKey)) == 0; +static bool cmp_spec_program_keys(SpecProgramKey* a, SpecProgramKey* b) { + assert(!!a & !!b); + return a->base == b->base && strcmp(a->entry_point, b->entry_point) == 0; } static void obtain_device_pointers(VkrDevice* device) { @@ -233,7 +237,7 @@ static VkrDevice* create_vkr_device(SHADY_UNUSED VkrBackend* runtime, VkPhysical VkrDevice* device = calloc(1, sizeof(VkrDevice)); device->runtime = runtime; CHECK(get_physical_device_caps(runtime, physical_device, &device->caps), assert(false)); - info_print("Initialising device %s\n", device->caps.properties.base.properties.deviceName); + shd_info_print("Initialising device %s\n", device->caps.properties.base.properties.deviceName); LARRAY(const char*, enabled_device_exts, ShadySupportedDeviceExtensionsCount); size_t enabled_device_exts_count; @@ -267,7 +271,7 @@ static VkrDevice* create_vkr_device(SHADY_UNUSED VkrBackend* runtime, VkPhysical .flags = 
VK_COMMAND_POOL_CREATE_TRANSIENT_BIT }, NULL, &device->cmd_pool), goto delete_device); - device->specialized_programs = new_dict(SpecProgramKey, VkrSpecProgram*, (HashFn) hash_spec_program_key, (CmpFn) cmp_spec_program_keys); + device->specialized_programs = shd_new_dict(SpecProgramKey, VkrSpecProgram*, (HashFn) hash_spec_program_key, (CmpFn) cmp_spec_program_keys); vkGetDeviceQueue(device->device, device->caps.compute_queue_family, 0, &device->compute_queue); @@ -286,10 +290,10 @@ static void shutdown_vkr_device(VkrDevice* device) { size_t i = 0; SpecProgramKey k; VkrSpecProgram* sp; - while (dict_iter(device->specialized_programs, &i, &k, &sp)) { - destroy_specialized_program(sp); + while (shd_dict_iter(device->specialized_programs, &i, &k, &sp)) { + shd_rt_vk_destroy_specialized_program(sp); } - destroy_dict(device->specialized_programs); + shd_destroy_dict(device->specialized_programs); vkDestroyCommandPool(device->device, device->cmd_pool, NULL); vkDestroyDevice(device->device, NULL); free(device); @@ -297,15 +301,15 @@ static void shutdown_vkr_device(VkrDevice* device) { static const char* get_vkr_device_name(VkrDevice* device) { return device->caps.properties.base.properties.deviceName; } -bool probe_vkr_devices(VkrBackend* runtime) { +bool shd_rt_vk_probe_devices(VkrBackend* runtime) { uint32_t devices_count; CHECK_VK(vkEnumeratePhysicalDevices(runtime->instance, &devices_count, NULL), return false) LARRAY(VkPhysicalDevice, available_devices, devices_count); CHECK_VK(vkEnumeratePhysicalDevices(runtime->instance, &devices_count, available_devices), return false) if (devices_count == 0 && !runtime->base.runtime->config.allow_no_devices) { - error_print("No vulkan devices found!\n"); - error_print("You may be able to diagnose this further using `VK_LOADER_DEBUG=all vulkaninfo`.\n"); + shd_error_print("No vulkan devices found!\n"); + shd_error_print("You may be able to diagnose this further using `VK_LOADER_DEBUG=all vulkaninfo`.\n"); return false; } @@ -317,22 
+321,22 @@ bool probe_vkr_devices(VkrBackend* runtime) { device->base = (Device) { .cleanup = (void(*)(Device*)) shutdown_vkr_device, .get_name = (String(*)(Device*)) get_vkr_device_name, - .allocate_buffer = (Buffer*(*)(Device*, size_t)) vkr_allocate_buffer_device, - .import_host_memory_as_buffer = (Buffer*(*)(Device*, void*, size_t)) vkr_import_buffer_host, - .launch_kernel = (Command*(*)(Device*, Program*, String, int, int, int, int, void**)) vkr_launch_kernel, - .can_import_host_memory = (bool(*)(Device*)) vkr_can_import_host_memory, + .allocate_buffer = (Buffer* (*)(Device*, size_t)) shd_rt_vk_allocate_buffer_device, + .import_host_memory_as_buffer = (Buffer* (*)(Device*, void*, size_t)) shd_rt_vk_import_buffer_host, + .launch_kernel = (Command* (*)(Device*, Program*, String, int, int, int, int, void**, ExtraKernelOptions*)) shd_rt_vk_launch_kernel, + .can_import_host_memory = (bool (*)(Device*)) shd_rt_vk_can_import_host_memory, }; - append_list(Device*, runtime->base.runtime->devices, device); + shd_list_append(Device*, runtime->base.runtime->devices, device); } } - if (entries_count_list(runtime->base.runtime->devices) == 0 && !runtime->base.runtime->config.allow_no_devices) { - error_print("No __suitable__ vulkan devices found!\n"); - error_print("This is caused by running on weird hardware configurations. Hardware support might get better in the future.\n"); + if (shd_list_count(runtime->base.runtime->devices) == 0 && !runtime->base.runtime->config.allow_no_devices) { + shd_error_print("No __suitable__ vulkan devices found!\n"); + shd_error_print("This is caused by running on weird hardware configurations. 
Hardware support might get better in the future.\n"); return false; } - info_print("Found %d usable devices\n", entries_count_list(runtime->base.runtime->devices)); + shd_info_print("Found %d usable devices\n", shd_list_count(runtime->base.runtime->devices)); return true; } diff --git a/src/runtime/vulkan/vk_runtime_dispatch.c b/src/runtime/vulkan/vk_runtime_dispatch.c index 5685e22c5..3207410f5 100644 --- a/src/runtime/vulkan/vk_runtime_dispatch.c +++ b/src/runtime/vulkan/vk_runtime_dispatch.c @@ -30,7 +30,7 @@ static void bind_program_resources(VkrCommand* cmd, VkrSpecProgram* prog) { write_descriptor_sets[write_descriptor_sets_count] = (VkWriteDescriptorSet) { .sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET, .pNext = NULL, - .descriptorType = as_to_descriptor_type(resource->as), + .descriptorType = shd_rt_vk_as_to_descriptor_type(resource->as), .descriptorCount = 1, .dstSet = prog->sets[resource->set], .dstBinding = resource->binding, @@ -53,18 +53,18 @@ static void bind_program_resources(VkrCommand* cmd, VkrSpecProgram* prog) { static Command make_command_base() { return (Command) { - .wait_for_completion = (bool(*)(Command*)) vkr_wait_completion, + .wait_for_completion = (bool (*)(Command*)) shd_rt_vk_wait_completion, }; } -VkrCommand* vkr_launch_kernel(VkrDevice* device, Program* program, String entry_point, int dimx, int dimy, int dimz, int args_count, void** args) { +VkrCommand* shd_rt_vk_launch_kernel(VkrDevice* device, Program* program, String entry_point, int dimx, int dimy, int dimz, int args_count, void** args, ExtraKernelOptions* options) { assert(program && device); - VkrSpecProgram* prog = get_specialized_program(program, entry_point, device); + VkrSpecProgram* prog = shd_rt_vk_get_specialized_program(program, entry_point, device); - debug_print("Dispatching kernel on %s\n", device->caps.properties.base.properties.deviceName); + shd_debug_print("Dispatching kernel on %s\n", device->caps.properties.base.properties.deviceName); - VkrCommand* cmd = 
vkr_begin_command(device); + VkrCommand* cmd = shd_rt_vk_begin_command(device); if (!cmd) return NULL; @@ -82,19 +82,37 @@ VkrCommand* vkr_launch_kernel(VkrDevice* device, Program* program, String entry_ vkCmdBindPipeline(cmd->cmd_buf, VK_PIPELINE_BIND_POINT_COMPUTE, prog->pipeline); bind_program_resources(cmd, prog); + + if (options && options->profiled_gpu_time) { + VkQueryPoolCreateInfo qpci = { + .sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO, + .pNext = NULL, + .queryType = VK_QUERY_TYPE_TIMESTAMP, + .queryCount = 2, + }; + CHECK_VK(vkCreateQueryPool(device->device, &qpci, NULL, &cmd->query_pool), {}); + cmd->profiled_gpu_time = options->profiled_gpu_time; + vkCmdResetQueryPool(cmd->cmd_buf, cmd->query_pool, 0, 2); + vkCmdWriteTimestamp(cmd->cmd_buf, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, cmd->query_pool, 0); + } + vkCmdDispatch(cmd->cmd_buf, dimx, dimy, dimz); - if (!vkr_submit_command(cmd)) + if (options && options->profiled_gpu_time) { + vkCmdWriteTimestamp(cmd->cmd_buf, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, cmd->query_pool, 1); + } + + if (!shd_rt_vk_submit_command(cmd)) goto err_post_commands_create; return cmd; err_post_commands_create: - vkr_destroy_command(cmd); + shd_rt_vk_destroy_command(cmd); return NULL; } -VkrCommand* vkr_begin_command(VkrDevice* device) { +VkrCommand* shd_rt_vk_begin_command(VkrDevice* device) { VkrCommand* cmd = calloc(1, sizeof(VkrCommand)); cmd->base = make_command_base(); cmd->device = device; @@ -123,7 +141,7 @@ VkrCommand* vkr_begin_command(VkrDevice* device) { return NULL; } -bool vkr_submit_command(VkrCommand* cmd) { +bool shd_rt_vk_submit_command(VkrCommand* cmd) { CHECK_VK(vkEndCommandBuffer(cmd->cmd_buf), return false); CHECK_VK(vkCreateFence(cmd->device->device, &(VkFenceCreateInfo) { @@ -150,16 +168,23 @@ bool vkr_submit_command(VkrCommand* cmd) { return false; } -bool vkr_wait_completion(VkrCommand* cmd) { +bool shd_rt_vk_wait_completion(VkrCommand* cmd) { assert(cmd->submitted && "Command must be submitted before 
they can be waited on"); CHECK_VK(vkWaitForFences(cmd->device->device, 1, (VkFence[]) { cmd->done_fence }, true, UINT32_MAX), return false); - vkr_destroy_command(cmd); + if (cmd->profiled_gpu_time) { + uint64_t ts[2]; + CHECK_VK(vkGetQueryPoolResults(cmd->device->device, cmd->query_pool, 0, 2, sizeof(uint64_t) * 2, ts, sizeof(uint64_t), VK_QUERY_RESULT_64_BIT), {}); + *cmd->profiled_gpu_time = (ts[1] - ts[0]) * cmd->device->caps.properties.base.properties.limits.timestampPeriod; + } + shd_rt_vk_destroy_command(cmd); return true; } -void vkr_destroy_command(VkrCommand* cmd) { +void shd_rt_vk_destroy_command(VkrCommand* cmd) { if (cmd->submitted) vkDestroyFence(cmd->device->device, cmd->done_fence, NULL); + if (cmd->query_pool) + vkDestroyQueryPool(cmd->device->device, cmd->query_pool, NULL); vkFreeCommandBuffers(cmd->device->device, cmd->device->cmd_pool, 1, &cmd->cmd_buf); free(cmd); } diff --git a/src/runtime/vulkan/vk_runtime_private.h b/src/runtime/vulkan/vk_runtime_private.h index 5df5a1511..4625cf1e6 100644 --- a/src/runtime/vulkan/vk_runtime_private.h +++ b/src/runtime/vulkan/vk_runtime_private.h @@ -62,7 +62,7 @@ SHADY_UNUSED static const bool is_instance_ext_required[] = { INSTANCE_EXTENSION SHADY_UNUSED static const bool is_device_ext_required[] = { DEVICE_EXTENSIONS(R) }; #undef R -#define CHECK_VK(x, failure_handler) { VkResult the_result_ = x; if (the_result_ != VK_SUCCESS) { error_print(#x " failed (code %d)\n", the_result_); failure_handler; } } +#define CHECK_VK(x, failure_handler) { VkResult the_result_ = x; if (the_result_ != VK_SUCCESS) { shd_error_print(#x " failed (code %d)\n", the_result_); failure_handler; } } typedef struct VkrSpecProgram_ VkrSpecProgram; @@ -156,7 +156,7 @@ struct VkrDevice_ { struct Dict* specialized_programs; }; -bool probe_vkr_devices(VkrBackend*); +bool shd_rt_vk_probe_devices(VkrBackend* runtime); typedef struct VkrBuffer_ { Buffer base; @@ -169,9 +169,10 @@ typedef struct VkrBuffer_ { void* host_ptr; } VkrBuffer; 
-VkrBuffer* vkr_allocate_buffer_device(VkrDevice* device, size_t size); -VkrBuffer* vkr_import_buffer_host(VkrDevice* device, void* ptr, size_t size); -bool vkr_can_import_host_memory(VkrDevice* device); +VkrBuffer* shd_rt_vk_allocate_buffer_device(VkrDevice* device, size_t size); +VkrBuffer* shd_rt_vk_import_buffer_host(VkrDevice* device, void* ptr, size_t size); +bool shd_rt_vk_can_import_host_memory(VkrDevice* device); +void shd_rt_vk_destroy_buffer(VkrBuffer* buffer); typedef struct VkrCommand_ VkrCommand; @@ -181,21 +182,17 @@ struct VkrCommand_ { VkCommandBuffer cmd_buf; VkFence done_fence; bool submitted; -}; -VkrCommand* vkr_begin_command(VkrDevice* device); -bool vkr_submit_command(VkrCommand* commands); -void vkr_destroy_command(VkrCommand* commands); -bool vkr_wait_completion(VkrCommand* cmd); + uint64_t* profiled_gpu_time; + VkQueryPool query_pool; +}; -VkrCommand* vkr_launch_kernel(VkrDevice* device, Program* program, String entry_point, int dimx, int dimy, int dimz, int args_count, void** args); +VkrCommand* shd_rt_vk_begin_command(VkrDevice* device); +bool shd_rt_vk_submit_command(VkrCommand* cmd); +void shd_rt_vk_destroy_command(VkrCommand* cmd); +bool shd_rt_vk_wait_completion(VkrCommand* cmd); -typedef struct { - size_t num_args; - const size_t* arg_offset; - const size_t* arg_size; - size_t args_size; -} ProgramParamsInfo; +VkrCommand* shd_rt_vk_launch_kernel(VkrDevice* device, Program* program, String entry_point, int dimx, int dimy, int dimz, int args_count, void** args, ExtraKernelOptions* options); typedef struct ProgramResourceInfo_ ProgramResourceInfo; struct ProgramResourceInfo_ { @@ -214,7 +211,7 @@ struct ProgramResourceInfo_ { size_t size; VkrBuffer* buffer; - char* default_data; + unsigned char* default_data; }; typedef struct { @@ -224,7 +221,7 @@ typedef struct { #define MAX_DESCRIPTOR_SETS 4 -VkDescriptorType as_to_descriptor_type(AddressSpace as); +VkDescriptorType shd_rt_vk_as_to_descriptor_type(AddressSpace as); struct 
VkrSpecProgram_ { SpecProgramKey key; @@ -251,8 +248,8 @@ struct VkrSpecProgram_ { VkDescriptorSet sets[MAX_DESCRIPTOR_SETS]; }; -VkrSpecProgram* get_specialized_program(Program*, String ep, VkrDevice*); -void destroy_specialized_program(VkrSpecProgram*); +VkrSpecProgram* shd_rt_vk_get_specialized_program(Program* program, String entry_point, VkrDevice* device); +void shd_rt_vk_destroy_specialized_program(VkrSpecProgram* spec); static inline void append_pnext(VkBaseOutStructure* s, void* n) { while (s->pNext != NULL) diff --git a/src/runtime/vulkan/vk_runtime_program.c b/src/runtime/vulkan/vk_runtime_program.c index a529b8d4e..a60354969 100644 --- a/src/runtime/vulkan/vk_runtime_program.c +++ b/src/runtime/vulkan/vk_runtime_program.c @@ -1,122 +1,18 @@ #include "vk_runtime_private.h" +#include "shady/driver.h" +#include "shady/ir/memory_layout.h" + #include "log.h" #include "portability.h" #include "dict.h" -#include "list.h" #include "growy.h" - #include "arena.h" #include "util.h" -#include "../../shady/transform/memory_layout.h" - #include #include -static bool extract_parameters_info(VkrSpecProgram* program) { - Nodes decls = get_module_declarations(program->specialized_module); - - const Node* args_struct_annotation; - const Node* args_struct_type = NULL; - const Node* entry_point_function = NULL; - - for (int i = 0; i < decls.count; ++i) { - const Node* node = decls.nodes[i]; - - switch (node->tag) { - case GlobalVariable_TAG: { - const Node* entry_point_args_annotation = lookup_annotation(node, "EntryPointArgs"); - if (entry_point_args_annotation) { - if (node->payload.global_variable.type->tag != RecordType_TAG) { - error_print("EntryPointArgs must be a struct\n"); - return false; - } - - if (args_struct_type) { - error_print("there cannot be more than one EntryPointArgs\n"); - return false; - } - - args_struct_annotation = entry_point_args_annotation; - args_struct_type = node->payload.global_variable.type; - } - break; - } - case Function_TAG: { - if 
(lookup_annotation(node, "EntryPoint")) { - if (node->payload.fun.params.count != 0) { - error_print("EntryPoint cannot have parameters\n"); - return false; - } - - if (entry_point_function) { - error_print("there cannot be more than one EntryPoint\n"); - return false; - } - - entry_point_function = node; - } - break; - } - default: break; - } - } - - if (!entry_point_function) { - error_print("could not find EntryPoint\n"); - return false; - } - - if (!args_struct_type) { - program->parameters = (ProgramParamsInfo) { .num_args = 0 }; - return true; - } - - if (args_struct_annotation->tag != AnnotationValue_TAG) { - error_print("EntryPointArgs annotation must contain exactly one value\n"); - return false; - } - - const Node* annotation_fn = args_struct_annotation->payload.annotation_value.value; - assert(annotation_fn->tag == FnAddr_TAG); - if (annotation_fn->payload.fn_addr.fn != entry_point_function) { - error_print("EntryPointArgs annotation refers to different EntryPoint\n"); - return false; - } - - size_t num_args = args_struct_type->payload.record_type.members.count; - - if (num_args == 0) { - error_print("EntryPointArgs cannot be empty\n"); - return false; - } - - IrArena* a = get_module_arena(program->specialized_module); - - LARRAY(FieldLayout, fields, num_args); - get_record_layout(a, args_struct_type, fields); - - size_t* offset_size_buffer = calloc(1, 2 * num_args * sizeof(size_t)); - if (!offset_size_buffer) { - error_print("failed to allocate EntryPointArgs offsets and sizes array\n"); - return false; - } - size_t* offsets = offset_size_buffer; - size_t* sizes = offset_size_buffer + num_args; - - for (int i = 0; i < num_args; ++i) { - offsets[i] = fields[i].offset_in_bytes; - sizes[i] = fields[i].mem_layout.size_in_bytes; - } - - program->parameters.num_args = num_args; - program->parameters.arg_offset = offsets; - program->parameters.arg_size = sizes; - program->parameters.args_size = offsets[num_args - 1] + sizes[num_args - 1]; - return true; -} - 
static void register_required_descriptors(VkrSpecProgram* program, VkDescriptorSetLayoutBinding* binding) { assert(binding->descriptorCount > 0); size_t i = 0; @@ -130,32 +26,72 @@ static void register_required_descriptors(VkrSpecProgram* program, VkDescriptorS static void add_binding(VkDescriptorSetLayoutCreateInfo* layout_create_info, Growy** bindings_lists, int set, VkDescriptorSetLayoutBinding binding) { if (bindings_lists[set] == NULL) { - bindings_lists[set] = new_growy(); - layout_create_info[set].pBindings = (const VkDescriptorSetLayoutBinding*) growy_data(bindings_lists[set]); + bindings_lists[set] = shd_new_growy(); + layout_create_info[set].pBindings = (const VkDescriptorSetLayoutBinding*) shd_growy_data(bindings_lists[set]); } layout_create_info[set].bindingCount += 1; - growy_append_object(bindings_lists[set], binding); + shd_growy_append_object(bindings_lists[set], binding); } -VkDescriptorType as_to_descriptor_type(AddressSpace as) { +VkDescriptorType shd_rt_vk_as_to_descriptor_type(AddressSpace as) { switch (as) { case AsUniform: return VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; case AsShaderStorageBufferObject: return VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; - default: error("No mapping to a descriptor type"); + default: shd_error("No mapping to a descriptor type"); + } +} + +static void write_value(unsigned char* tgt, const Node* value) { + IrArena* a = value->arena; + switch (value->tag) { + case IntLiteral_TAG: { + switch (value->payload.int_literal.width) { + case IntTy8: *((uint8_t*) tgt) = (uint8_t) (value->payload.int_literal.value & 0xFF); break; + case IntTy16: *((uint16_t*) tgt) = (uint16_t) (value->payload.int_literal.value & 0xFFFF); break; + case IntTy32: *((uint32_t*) tgt) = (uint32_t) (value->payload.int_literal.value & 0xFFFFFFFF); break; + case IntTy64: *((uint64_t*) tgt) = (uint64_t) (value->payload.int_literal.value); break; + } + break; + } + case Composite_TAG: { + Nodes values = value->payload.composite.contents; + const Type* struct_t = 
value->payload.composite.type; + struct_t = shd_get_maybe_nominal_type_body(struct_t); + + if (struct_t->tag == RecordType_TAG) { + LARRAY(FieldLayout, fields, values.count); + shd_get_record_layout(a, struct_t, fields); + for (size_t i = 0; i < values.count; i++) { + // TypeMemLayout layout = get_mem_layout(value->arena, get_unqualified_type(element->type)); + write_value(tgt + fields->offset_in_bytes, values.nodes[i]); + } + } else if (struct_t->tag == ArrType_TAG) { + for (size_t i = 0; i < values.count; i++) { + TypeMemLayout layout = shd_get_mem_layout(value->arena, shd_get_unqualified_type(values.nodes[i]->type)); + write_value(tgt, values.nodes[i]); + tgt += layout.size_in_bytes; + } + } else { + assert(false); + } + break; + } + default: + assert(false); } } static bool extract_resources_layout(VkrSpecProgram* program, VkDescriptorSetLayout layouts[]) { VkDescriptorSetLayoutCreateInfo layout_create_infos[MAX_DESCRIPTOR_SETS] = { 0 }; Growy* bindings_lists[MAX_DESCRIPTOR_SETS] = { 0 }; - Growy* resources = new_growy(); + Growy* resources = shd_new_growy(); - Nodes decls = get_module_declarations(program->specialized_module); + Nodes decls = shd_module_get_declarations(program->specialized_module); for (size_t i = 0; i < decls.count; i++) { const Node* decl = decls.nodes[i]; if (decl->tag != GlobalVariable_TAG) continue; - if (lookup_annotation(decl, "Constants")) { + if (shd_lookup_annotation(decl, "Constants")) { AddressSpace as = decl->payload.global_variable.address_space; switch (as) { case AsShaderStorageBufferObject: @@ -163,17 +99,17 @@ static bool extract_resources_layout(VkrSpecProgram* program, VkDescriptorSetLay default: continue; } - int set = get_int_literal_value(*resolve_to_int_literal(get_annotation_value(lookup_annotation(decl, "DescriptorSet"))), false); - int binding = get_int_literal_value(*resolve_to_int_literal(get_annotation_value(lookup_annotation(decl, "DescriptorBinding"))), false); + int set = 
shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_annotation_value(shd_lookup_annotation(decl, "DescriptorSet"))), false); + int binding = shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_annotation_value(shd_lookup_annotation(decl, "DescriptorBinding"))), false); - ProgramResourceInfo* res_info = arena_alloc(program->arena, sizeof(ProgramResourceInfo)); + ProgramResourceInfo* res_info = shd_arena_alloc(program->arena, sizeof(ProgramResourceInfo)); *res_info = (ProgramResourceInfo) { .is_bound = true, .as = as, .set = set, .binding = binding, }; - growy_append_object(resources, res_info); + shd_growy_append_object(resources, res_info); program->resources.num_resources++; const Type* struct_t = decl->payload.global_variable.type; @@ -181,14 +117,16 @@ static bool extract_resources_layout(VkrSpecProgram* program, VkDescriptorSetLay for (size_t j = 0; j < struct_t->payload.record_type.members.count; j++) { const Type* member_t = struct_t->payload.record_type.members.nodes[j]; - TypeMemLayout layout = get_mem_layout(program->specialized_module->arena, member_t); + assert(member_t->tag == PtrType_TAG); + member_t = shd_get_pointee_type(member_t->arena, member_t); + TypeMemLayout layout = shd_get_mem_layout(shd_module_get_arena(program->specialized_module), member_t); - ProgramResourceInfo* constant_res_info = arena_alloc(program->arena, sizeof(ProgramResourceInfo)); + ProgramResourceInfo* constant_res_info = shd_arena_alloc(program->arena, sizeof(ProgramResourceInfo)); *constant_res_info = (ProgramResourceInfo) { .parent = res_info, .as = as, }; - growy_append_object(resources, constant_res_info); + shd_growy_append_object(resources, constant_res_info); program->resources.num_resources++; constant_res_info->size = layout.size_in_bytes; @@ -196,16 +134,25 @@ static bool extract_resources_layout(VkrSpecProgram* program, VkDescriptorSetLay res_info->size += sizeof(void*); // TODO initial value + Nodes annotations = 
get_declaration_annotations(decl); + for (size_t k = 0; k < annotations.count; k++) { + const Node* a = annotations.nodes[k]; + if ((strcmp(get_annotation_name(a), "InitialValue") == 0) && shd_resolve_to_int_literal(shd_first(shd_get_annotation_values(a)))->value == j) { + constant_res_info->default_data = calloc(1, layout.size_in_bytes); + write_value(constant_res_info->default_data, shd_get_annotation_values(a).nodes[1]); + //printf("wowie"); + } + } } - if (vkr_can_import_host_memory(program->device)) + if (shd_rt_vk_can_import_host_memory(program->device)) res_info->host_backed_allocation = true; else res_info->staging = calloc(1, res_info->size); VkDescriptorSetLayoutBinding vk_binding = { .binding = binding, - .descriptorType = as_to_descriptor_type(as), + .descriptorType = shd_rt_vk_as_to_descriptor_type(as), .descriptorCount = 1, .stageFlags = VK_SHADER_STAGE_ALL, .pImmutableSamplers = NULL, @@ -216,25 +163,24 @@ static bool extract_resources_layout(VkrSpecProgram* program, VkDescriptorSetLay } for (size_t set = 0; set < MAX_DESCRIPTOR_SETS; set++) { - layouts[set] = NULL; + layouts[set] = 0; layout_create_infos[set].sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; layout_create_infos[set].flags = 0; layout_create_infos[set].pNext = NULL; vkCreateDescriptorSetLayout(program->device->device, &layout_create_infos[set], NULL, &layouts[set]); if (bindings_lists[set] != NULL) { - destroy_growy(bindings_lists[set]); + shd_destroy_growy(bindings_lists[set]); } } - program->resources.resources = (ProgramResourceInfo**) growy_deconstruct(resources); + program->resources.resources = (ProgramResourceInfo**) shd_growy_deconstruct(resources); return true; } static bool extract_layout(VkrSpecProgram* program) { - CHECK(extract_parameters_info(program), return false); if (program->parameters.args_size > program->device->caps.properties.base.properties.limits.maxPushConstantsSize) { - error_print("EntryPointArgs exceed available push constant space\n"); + 
shd_error_print("EntryPointArgs exceed available push constant space\n"); return false; } VkPushConstantRange push_constant_ranges[1] = { @@ -312,36 +258,148 @@ static CompilerConfig get_compiler_config_for_device(VkrDevice* device, const Co config.lower.int64 = !device->caps.features.base.features.shaderInt64; if (device->caps.implementation.is_moltenvk) { - warn_print("Hack: MoltenVK says they supported subgroup extended types, but it's a lie. 64-bit types are unaccounted for !\n"); + shd_warn_print("Hack: MoltenVK says they supported subgroup extended types, but it's a lie. 64-bit types are unaccounted for !\n"); config.lower.emulate_subgroup_ops_extended_types = true; - warn_print("Hack: MoltenVK does not support pointers to unsized arrays properly.\n"); + shd_warn_print("Hack: MoltenVK does not support pointers to unsized arrays properly.\n"); config.lower.decay_ptrs = true; } if (device->caps.properties.driver_properties.driverID == VK_DRIVER_ID_NVIDIA_PROPRIETARY) { - warn_print("Hack: NVidia somehow has unreliable broadcast_first. Emulating it with shuffles seemingly fixes the issue.\n"); + shd_warn_print("Hack: NVidia somehow has unreliable broadcast_first. 
Emulating it with shuffles seemingly fixes the issue.\n"); config.hacks.spv_shuffle_instead_of_broadcast_first = true; } return config; } +static bool extract_parameters_info(ProgramParamsInfo* parameters, Module* mod) { + Nodes decls = shd_module_get_declarations(mod); + + const Node* args_struct_annotation; + const Node* args_struct_type = NULL; + const Node* entry_point_function = NULL; + + for (int i = 0; i < decls.count; ++i) { + const Node* node = decls.nodes[i]; + + switch (node->tag) { + case GlobalVariable_TAG: { + const Node* entry_point_args_annotation = shd_lookup_annotation(node, "EntryPointArgs"); + if (entry_point_args_annotation) { + if (node->payload.global_variable.type->tag != RecordType_TAG) { + shd_error_print("EntryPointArgs must be a struct\n"); + return false; + } + + if (args_struct_type) { + shd_error_print("there cannot be more than one EntryPointArgs\n"); + return false; + } + + args_struct_annotation = entry_point_args_annotation; + args_struct_type = node->payload.global_variable.type; + } + break; + } + case Function_TAG: { + if (shd_lookup_annotation(node, "EntryPoint")) { + if (node->payload.fun.params.count != 0) { + shd_error_print("EntryPoint cannot have parameters\n"); + return false; + } + + if (entry_point_function) { + shd_error_print("there cannot be more than one EntryPoint\n"); + return false; + } + + entry_point_function = node; + } + break; + } + default: break; + } + } + + if (!entry_point_function) { + shd_error_print("could not find EntryPoint\n"); + return false; + } + + if (!args_struct_type) { + *parameters = (ProgramParamsInfo) { .num_args = 0 }; + return true; + } + + if (args_struct_annotation->tag != AnnotationValue_TAG) { + shd_error_print("EntryPointArgs annotation must contain exactly one value\n"); + return false; + } + + const Node* annotation_fn = args_struct_annotation->payload.annotation_value.value; + assert(annotation_fn->tag == FnAddr_TAG); + if (annotation_fn->payload.fn_addr.fn != 
entry_point_function) { + shd_error_print("EntryPointArgs annotation refers to different EntryPoint\n"); + return false; + } + + size_t num_args = args_struct_type->payload.record_type.members.count; + + if (num_args == 0) { + shd_error_print("EntryPointArgs cannot be shd_empty\n"); + return false; + } + + IrArena* a = shd_module_get_arena(mod); + + LARRAY(FieldLayout, fields, num_args); + shd_get_record_layout(a, args_struct_type, fields); + + size_t* offset_size_buffer = calloc(1, 2 * num_args * sizeof(size_t)); + if (!offset_size_buffer) { + shd_error_print("failed to allocate EntryPointArgs offsets and sizes array\n"); + return false; + } + size_t* offsets = offset_size_buffer; + size_t* sizes = offset_size_buffer + num_args; + + for (int i = 0; i < num_args; ++i) { + offsets[i] = fields[i].offset_in_bytes; + sizes[i] = fields[i].mem_layout.size_in_bytes; + } + + parameters->num_args = num_args; + parameters->arg_offset = offsets; + parameters->arg_size = sizes; + parameters->args_size = offsets[num_args - 1] + sizes[num_args - 1]; + return true; +} + static bool compile_specialized_program(VkrSpecProgram* spec) { CompilerConfig config = get_compiler_config_for_device(spec->device, spec->key.base->base_config); config.specialization.entry_point = spec->key.entry_point; - CHECK(run_compiler_passes(&config, &spec->specialized_module) == CompilationNoError, return false); + CHECK(shd_run_compiler_passes(&config, &spec->specialized_module) == CompilationNoError, return false); + + Module* final_mod; + shd_emit_spirv(&config, spec->specialized_module, &spec->spirv_size, &spec->spirv_bytes, &final_mod); + + CHECK(extract_parameters_info(&spec->parameters, final_mod), return false); - Module* new_mod; - emit_spirv(&config, spec->specialized_module, &spec->spirv_size, &spec->spirv_bytes, &new_mod); - spec->specialized_module = new_mod; + spec->specialized_module = final_mod; if (spec->key.base->runtime->config.dump_spv) { - String module_name = 
get_module_name(spec->specialized_module); - String file_name = format_string_new("%s.spv", module_name); - write_file(file_name, spec->spirv_size, (const char*) spec->spirv_bytes); + String module_name = shd_module_get_name(spec->specialized_module); + String file_name = shd_format_string_new("%s.spv", module_name); + shd_write_file(file_name, spec->spirv_size, (const char*) spec->spirv_bytes); free((void*) file_name); } + String override_file = getenv("SHADY_OVERRIDE_SPV"); + if (override_file) { + shd_read_file(override_file, &spec->spirv_size, &spec->spirv_bytes); + return true; + } + return true; } @@ -374,7 +432,7 @@ static void flush_staged_data(VkrSpecProgram* program) { for (size_t i = 0; i < program->resources.num_resources; i++) { ProgramResourceInfo* resource = program->resources.resources[i]; if (resource->staging) { - copy_to_buffer(resource->buffer, 0, resource->buffer, resource->size); + shd_rt_copy_to_buffer((Buffer*) resource->buffer, 0, resource->buffer, resource->size); free(resource->staging); } } @@ -385,18 +443,20 @@ static bool prepare_resources(VkrSpecProgram* program) { ProgramResourceInfo* resource = program->resources.resources[i]; if (resource->host_backed_allocation) { - assert(vkr_can_import_host_memory(program->device)); - resource->host_ptr = alloc_aligned(resource->size, program->device->caps.properties.external_memory_host.minImportedHostPointerAlignment); - resource->buffer = import_buffer_host(program->device, resource->host_ptr, resource->size); + assert(shd_rt_vk_can_import_host_memory(program->device)); + resource->host_ptr = shd_alloc_aligned(resource->size, program->device->caps.properties.external_memory_host.minImportedHostPointerAlignment); + resource->buffer = shd_rt_vk_import_buffer_host(program->device, resource->host_ptr, resource->size); } else { - resource->buffer = allocate_buffer_device(program->device, resource->size); + resource->buffer = shd_rt_vk_allocate_buffer_device(program->device, resource->size); } - // 
TODO: initial data! - // if (!resource->host_owned) - char* zeroes = calloc(1, resource->size); - copy_to_buffer(resource->buffer, 0, zeroes, resource->size); - free(zeroes); + if (resource->default_data) { + shd_rt_copy_to_buffer((Buffer*) resource->buffer, 0, resource->default_data, resource->size); + } else { + char* zeroes = calloc(1, resource->size); + shd_rt_copy_to_buffer((Buffer*) resource->buffer, 0, zeroes, resource->size); + free(zeroes); + } if (resource->parent) { char* dst = resource->parent->host_ptr; @@ -404,7 +464,7 @@ static bool prepare_resources(VkrSpecProgram* program) { dst = resource->parent->staging; } assert(dst); - *((uint64_t*) (dst + resource->offset)) = get_buffer_device_pointer(resource->buffer); + *((uint64_t*) (dst + resource->offset)) = shd_rt_get_buffer_device_pointer((Buffer*) resource->buffer); } } @@ -421,7 +481,7 @@ static VkrSpecProgram* create_specialized_program(SpecProgramKey key, VkrDevice* spec_program->key = key; spec_program->device = device; spec_program->specialized_module = key.base->module; - spec_program->arena = new_arena(); + spec_program->arena = shd_new_arena(); CHECK(compile_specialized_program(spec_program), return NULL); CHECK(extract_layout(spec_program), return NULL); @@ -431,36 +491,36 @@ static VkrSpecProgram* create_specialized_program(SpecProgramKey key, VkrDevice* return spec_program; } -VkrSpecProgram* get_specialized_program(Program* program, String entry_point, VkrDevice* device) { +VkrSpecProgram* shd_rt_vk_get_specialized_program(Program* program, String entry_point, VkrDevice* device) { SpecProgramKey key = { .base = program, .entry_point = entry_point }; - VkrSpecProgram** found = find_value_dict(SpecProgramKey, VkrSpecProgram*, device->specialized_programs, key); + VkrSpecProgram** found = shd_dict_find_value(SpecProgramKey, VkrSpecProgram*, device->specialized_programs, key); if (found) return *found; VkrSpecProgram* spec = create_specialized_program(key, device); assert(spec); - 
insert_dict(SpecProgramKey, VkrSpecProgram*, device->specialized_programs, key, spec); + shd_dict_insert(SpecProgramKey, VkrSpecProgram*, device->specialized_programs, key, spec); return spec; } -void destroy_specialized_program(VkrSpecProgram* spec) { +void shd_rt_vk_destroy_specialized_program(VkrSpecProgram* spec) { vkDestroyPipeline(spec->device->device, spec->pipeline, NULL); for (size_t set = 0; set < MAX_DESCRIPTOR_SETS; set++) vkDestroyDescriptorSetLayout(spec->device->device, spec->set_layouts[set], NULL); vkDestroyPipelineLayout(spec->device->device, spec->layout, NULL); vkDestroyShaderModule(spec->device->device, spec->shader_module, NULL); - free(spec->parameters.arg_offset); + free( (void*) spec->parameters.arg_offset); free(spec->spirv_bytes); - if (get_module_arena(spec->specialized_module) != get_module_arena(spec->key.base->module)) - destroy_ir_arena(get_module_arena(spec->specialized_module)); + if (shd_module_get_arena(spec->specialized_module) != shd_module_get_arena(spec->key.base->module)) + shd_destroy_ir_arena(shd_module_get_arena(spec->specialized_module)); for (size_t i = 0; i < spec->resources.num_resources; i++) { ProgramResourceInfo* resource = spec->resources.resources[i]; if (resource->buffer) - destroy_buffer(resource->buffer); + shd_rt_vk_destroy_buffer(resource->buffer); if (resource->host_ptr && resource->host_backed_allocation) - free_aligned(resource->host_ptr); + shd_free_aligned(resource->host_ptr); } free(spec->resources.resources); vkDestroyDescriptorPool(spec->device->device, spec->descriptor_pool, NULL); - destroy_arena(spec->arena); + shd_destroy_arena(spec->arena); free(spec); } diff --git a/src/shady/CMakeLists.txt b/src/shady/CMakeLists.txt index 40fbfb908..26a959700 100644 --- a/src/shady/CMakeLists.txt +++ b/src/shady/CMakeLists.txt @@ -8,93 +8,41 @@ add_subdirectory(api) get_target_property(SPIRV_HEADERS_INCLUDE_DIRS SPIRV-Headers::SPIRV-Headers INTERFACE_INCLUDE_DIRECTORIES) -add_generated_file(FILE_NAME 
type_generated.h TARGET_NAME type_generated SOURCES generator_type.c) +add_generated_file(FILE_NAME type_generated.c TARGET_NAME type_generated SOURCES generator_type.c) add_generated_file(FILE_NAME node_generated.c TARGET_NAME node_generated SOURCES generator_node.c) add_generated_file(FILE_NAME primops_generated.c TARGET_NAME primops_generated SOURCES generator_primops.c) add_generated_file(FILE_NAME constructors_generated.c TARGET_NAME constructors_generated SOURCES generator_constructors.c) add_generated_file(FILE_NAME visit_generated.c TARGET_NAME visit_generated SOURCES generator_visit.c) add_generated_file(FILE_NAME rewrite_generated.c TARGET_NAME rewrite_generated SOURCES generator_rewrite.c) +add_generated_file(FILE_NAME print_generated.c TARGET_NAME print_generated SOURCES generator_print.c) add_library(shady_generated INTERFACE) -add_dependencies(shady_generated node_generated primops_generated type_generated constructors_generated visit_generated rewrite_generated) +add_dependencies(shady_generated node_generated primops_generated type_generated constructors_generated visit_generated rewrite_generated print_generated) target_include_directories(shady_generated INTERFACE "$") target_link_libraries(api INTERFACE "$") -set(SHADY_SOURCES +add_library(shady STATIC) + +target_sources(shady PRIVATE ir.c node.c - constructors.c - type.c - type_helpers.c + check.c primops.c - builtins.c rewrite.c visit.c print.c fold.c body_builder.c compile.c - annotation.c - module.c - - analysis/scope.c - analysis/free_variables.c - analysis/verify.c - analysis/callgraph.c - analysis/uses.c - analysis/looptree.c - analysis/leak.c - - transform/memory_layout.c - transform/ir_gen_helpers.c - transform/internal_constants.c - - passes/import.c - passes/cleanup.c - passes/bind.c - passes/normalize.c - passes/infer.c - passes/lower_cf_instrs.c - passes/lift_indirect_targets.c - passes/lcssa.c - passes/lower_callf.c - passes/lower_stack.c - passes/lower_lea.c - 
passes/lower_physical_ptrs.c - passes/lower_generic_ptrs.c - passes/lower_memory_layout.c - passes/lower_memcpy.c - passes/lower_decay_ptrs.c - passes/lower_tailcalls.c - passes/lower_mask.c - passes/lower_fill.c - passes/lower_switch_btree.c - passes/setup_stack_frames.c - passes/eliminate_constants.c - passes/normalize_builtins.c - passes/lower_subgroup_ops.c - passes/lower_subgroup_vars.c - passes/lower_int64.c - passes/lower_vec_arr.c - passes/lower_workgroups.c - passes/lower_generic_globals.c - passes/mark_leaf_functions.c - passes/opt_inline.c - passes/opt_stack.c - passes/opt_restructure.c - passes/opt_mem2reg.c - passes/reconvergence_heuristics.c - passes/simt2d.c - passes/specialize_entry_point.c - passes/specialize_execution_model.c - - passes/lower_entrypoint_args.c - passes/spirv_map_entrypoint_args.c - passes/spirv_lift_globals_ssbo.c + config.c ) -add_library(shady STATIC ${SHADY_SOURCES}) -set_property(TARGET shady PROPERTY POSITION_INDEPENDENT_CODE ON) +add_subdirectory(analysis) +add_subdirectory(transform) +add_subdirectory(passes) +add_subdirectory(ir) + +target_include_directories(shady PUBLIC $) if (WIN32) if (MSVC) @@ -105,9 +53,10 @@ if (WIN32) endif() add_subdirectory(internal) -add_subdirectory(emit) -target_link_libraries(shady PUBLIC "api") -target_link_libraries(shady PUBLIC "$") -target_link_libraries(shady PRIVATE "$") +target_link_libraries(shady PRIVATE "api") +target_link_libraries(shady PRIVATE "common") target_link_libraries(shady PRIVATE "$") + +install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/shady DESTINATION include) +#install(TARGETS shady EXPORT shady_export_set ARCHIVE DESTINATION ${CMAKE_ARCHIVE_OUTPUT_DIRECTORY}) diff --git a/src/shady/analysis/CMakeLists.txt b/src/shady/analysis/CMakeLists.txt new file mode 100644 index 000000000..778f0f643 --- /dev/null +++ b/src/shady/analysis/CMakeLists.txt @@ -0,0 +1,12 @@ +target_sources(shady PRIVATE + cfg.c + cfg_dump.c + free_frontier.c + verify.c + callgraph.c + uses.c + 
looptree.c + leak.c + scheduler.c + literal.c +) diff --git a/src/shady/analysis/callgraph.c b/src/shady/analysis/callgraph.c index 69f3f7e6a..482a8a090 100644 --- a/src/shady/analysis/callgraph.c +++ b/src/shady/analysis/callgraph.c @@ -1,24 +1,25 @@ #include "callgraph.h" +#include "uses.h" #include "list.h" #include "dict.h" - #include "portability.h" #include "log.h" -#include "../visit.h" +#include "shady/visit.h" #include +#include #include -KeyHash hash_node(const Node**); -bool compare_node(const Node**, const Node**); +KeyHash shd_hash_node(const Node**); +bool shd_compare_node(const Node**, const Node**); -KeyHash hash_cgedge(CGEdge* n) { - return hash_murmur(n, sizeof(CGEdge)); +KeyHash shd_hash_cgedge(CGEdge* n) { + return shd_hash(n, sizeof(CGEdge)); } -bool compare_cgedge(CGEdge* a, CGEdge* b) { +bool shd_compare_cgedge(CGEdge* a, CGEdge* b) { return (a)->src_fn == (b)->src_fn && (a)->instr == (b)->instr; } @@ -28,7 +29,6 @@ typedef struct { Visitor visitor; CallGraph* graph; CGNode* root; - const Node* abs; } CGVisitor; static const Node* ignore_immediate_fn_addr(const Node* node) { @@ -38,7 +38,10 @@ static const Node* ignore_immediate_fn_addr(const Node* node) { return node; } +static CGNode* analyze_fn(CallGraph* graph, const Node* fn); + static void visit_callsite(CGVisitor* visitor, const Node* callee, const Node* instr) { + assert(visitor->root); assert(callee->tag == Function_TAG); CGNode* target = analyze_fn(visitor->graph, callee); // Immediate recursion @@ -48,66 +51,51 @@ static void visit_callsite(CGVisitor* visitor, const Node* callee, const Node* i .src_fn = visitor->root, .dst_fn = target, .instr = instr, - .abs = visitor->abs, }; - insert_set_get_result(CGEdge, visitor->root->callees, edge); - insert_set_get_result(CGEdge, target->callers, edge); + shd_set_insert_get_result(CGEdge, visitor->root->callees, edge); + shd_set_insert_get_result(CGEdge, target->callers, edge); } static void search_for_callsites(CGVisitor* visitor, const 
Node* node) { - assert(is_abstraction(visitor->abs)); + if (is_abstraction(node)) + search_for_callsites(visitor, get_abstraction_body(node)); switch (node->tag) { - case Function_TAG: { - assert(false); - break; - } - case BasicBlock_TAG: - case Case_TAG: { - const Node* old_abs = visitor->abs; - visit_node_operands(&visitor->visitor, IGNORE_ABSTRACTIONS_MASK, node); - visitor->abs = old_abs; - break; - } - case FnAddr_TAG: { - CGNode* callee_node = analyze_fn(visitor->graph, node->payload.fn_addr.fn); - callee_node->is_address_captured = true; - break; - } case Call_TAG: { + assert(visitor->root && "calls can only occur in functions"); const Node* callee = node->payload.call.callee; callee = ignore_immediate_fn_addr(callee); if (callee->tag == Function_TAG) visit_callsite(visitor, callee, node); else - visit_op(&visitor->visitor, NcValue, "callee", callee); - visit_ops(&visitor->visitor, NcValue, "args", node->payload.call.args); + visitor->root->calls_indirect = true; break; } case TailCall_TAG: { - const Node* callee = node->payload.tail_call.target; + assert(visitor->root && "tail calls can only occur in functions"); + const Node* callee = node->payload.tail_call.callee; callee = ignore_immediate_fn_addr(callee); if (callee->tag == Function_TAG) visit_callsite(visitor, callee, node); else - visit_node(&visitor->visitor, callee); - visit_nodes(&visitor->visitor, node->payload.tail_call.args); + visitor->root->calls_indirect = true; break; } - default: visit_node_operands(&visitor->visitor, IGNORE_ABSTRACTIONS_MASK, node); + default: break; } + shd_visit_node_operands(&visitor->visitor, ~NcMem, node); } static CGNode* analyze_fn(CallGraph* graph, const Node* fn) { assert(fn && fn->tag == Function_TAG); - CGNode** found = find_value_dict(const Node*, CGNode*, graph->fn2cgn, fn); + CGNode** found = shd_dict_find_value(const Node*, CGNode*, graph->fn2cgn, fn); if (found) return *found; CGNode* new = calloc(1, sizeof(CGNode)); new->fn = fn; - new->callees = 
new_set(CGEdge, (HashFn) hash_cgedge, (CmpFn) compare_cgedge); - new->callers = new_set(CGEdge, (HashFn) hash_cgedge, (CmpFn) compare_cgedge); + new->callees = shd_new_set(CGEdge, (HashFn) shd_hash_cgedge, (CmpFn) shd_compare_cgedge); + new->callers = shd_new_set(CGEdge, (HashFn) shd_hash_cgedge, (CmpFn) shd_compare_cgedge); new->tarjan.index = -1; - insert_dict_and_get_key(const Node*, CGNode*, graph->fn2cgn, fn, new); + shd_dict_insert_get_key(const Node*, CGNode*, graph->fn2cgn, fn, new); CGVisitor v = { .visitor = { @@ -115,12 +103,10 @@ static CGNode* analyze_fn(CallGraph* graph, const Node* fn) { }, .graph = graph, .root = new, - .abs = fn, }; - if (fn->payload.fun.body) { - search_for_callsites(&v, fn->payload.fun.body); - visit_function_rpo(&v.visitor, fn); + if (get_abstraction_body(fn)) { + shd_visit_function_rpo(&v.visitor, fn); } return new; @@ -141,21 +127,21 @@ static int min(int a, int b) { return a < b ? a : b; } // https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm static void strongconnect(CGNode* v, int* index, struct List* stack) { - debugv_print("strongconnect(%s) \n", v->fn->payload.fun.name); + shd_debugv_print("strongconnect(%s) \n", v->fn->payload.fun.name); v->tarjan.index = *index; v->tarjan.lowlink = *index; (*index)++; - append_list(const Node*, stack, v); + shd_list_append(const Node*, stack, v); v->tarjan.on_stack = true; // Consider successors of v { size_t iter = 0; CGEdge e; - debugv_print(" has %d successors\n", entries_count_dict(v->callees)); - while (dict_iter(v->callees, &iter, &e, NULL)) { - debugv_print(" %s\n", e.dst_fn->fn->payload.fun.name); + shd_debugv_print(" has %d successors\n", shd_dict_count(v->callees)); + while (shd_dict_iter(v->callees, &iter, &e, NULL)) { + shd_debugv_print(" %s\n", e.dst_fn->fn->payload.fun.name); if (e.dst_fn->tarjan.index == -1) { // Successor w has not yet been visited; recurse on it strongconnect(e.dst_fn, index, stack); @@ -172,13 +158,13 @@ static void 
strongconnect(CGNode* v, int* index, struct List* stack) { // If v is a root node, pop the stack and generate an SCC if (v->tarjan.lowlink == v->tarjan.index) { - LARRAY(CGNode*, scc, entries_count_list(stack)); + LARRAY(CGNode*, scc, shd_list_count(stack)); size_t scc_size = 0; { CGNode* w; - assert(entries_count_list(stack) > 0); + assert(shd_list_count(stack) > 0); do { - w = pop_last_list(CGNode*, stack); + w = shd_list_pop(CGNode*, stack); w->tarjan.on_stack = false; scc[scc_size++] = w; } while (v != w); @@ -187,7 +173,7 @@ static void strongconnect(CGNode* v, int* index, struct List* stack) { if (scc_size > 1) { for (size_t i = 0; i < scc_size; i++) { CGNode* w = scc[i]; - debugv_print("Function %s is part of a recursive call chain \n", w->fn->payload.fun.name); + shd_debugv_print("Function %s is part of a recursive call chain \n", w->fn->payload.fun.name); w->is_recursive = true; } } @@ -196,47 +182,79 @@ static void strongconnect(CGNode* v, int* index, struct List* stack) { static void tarjan(struct Dict* verts) { int index = 0; - struct List* stack = new_list(CGNode*); + struct List* stack = shd_new_list(CGNode*); size_t iter = 0; CGNode* n; - while (dict_iter(verts, &iter, NULL, &n)) { + while (shd_dict_iter(verts, &iter, NULL, &n)) { if (n->tarjan.index == -1) strongconnect(n, &index, stack); } - destroy_list(stack); + shd_destroy_list(stack); } -CallGraph* new_callgraph(Module* mod) { +CallGraph* shd_new_callgraph(Module* mod) { CallGraph* graph = calloc(sizeof(CallGraph), 1); *graph = (CallGraph) { - .fn2cgn = new_dict(const Node*, CGNode*, (HashFn) hash_node, (CmpFn) compare_node) + .fn2cgn = shd_new_dict(const Node*, CGNode*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node) }; - Nodes decls = get_module_declarations(mod); + const UsesMap* uses = shd_new_uses_map_module(mod, NcType); + + Nodes decls = shd_module_get_declarations(mod); for (size_t i = 0; i < decls.count; i++) { - if (decls.nodes[i]->tag == Function_TAG) { - analyze_fn(graph, 
decls.nodes[i]); + const Node* decl = decls.nodes[i]; + if (decl->tag == Function_TAG) { + CGNode* node = analyze_fn(graph, decl); + + const Use* use = shd_get_first_use(uses, fn_addr_helper(shd_module_get_arena(mod), decl)); + for (;use;use = use->next_use) { + if (use->user->tag == Call_TAG && strcmp(use->operand_name, "callee") == 0) + continue; + if (use->user->tag == TailCall_TAG && strcmp(use->operand_name, "callee") == 0) + continue; + node->is_address_captured = true; + } + } else if (decl->tag == GlobalVariable_TAG && decl->payload.global_variable.init) { + CGVisitor v = { + .visitor = { + .visit_node_fn = (VisitNodeFn) search_for_callsites + }, + .graph = graph, + .root = NULL, + }; + search_for_callsites(&v, decl->payload.global_variable.init); + } else if (decl->tag == Constant_TAG && decl->payload.constant.value) { + CGVisitor v = { + .visitor = { + .visit_node_fn = (VisitNodeFn) search_for_callsites + }, + .graph = graph, + .root = NULL, + }; + search_for_callsites(&v, decl->payload.constant.value); } } - debugv_print("CallGraph: done with CFG build, contains %d nodes\n", entries_count_dict(graph->fn2cgn)); + shd_destroy_uses_map(uses); + + shd_debugv_print("CallGraph: done with CFG build, contains %d nodes\n", shd_dict_count(graph->fn2cgn)); tarjan(graph->fn2cgn); return graph; } -void destroy_callgraph(CallGraph* graph) { +void shd_destroy_callgraph(CallGraph* graph) { size_t i = 0; CGNode* node; - while (dict_iter(graph->fn2cgn, &i, NULL, &node)) { - debugv_print("Freeing CG node: %s\n", node->fn->payload.fun.name); - destroy_dict(node->callers); - destroy_dict(node->callees); + while (shd_dict_iter(graph->fn2cgn, &i, NULL, &node)) { + shd_debugv_print("Freeing CG node: %s\n", node->fn->payload.fun.name); + shd_destroy_dict(node->callers); + shd_destroy_dict(node->callees); free(node); } - destroy_dict(graph->fn2cgn); + shd_destroy_dict(graph->fn2cgn); free(graph); } diff --git a/src/shady/analysis/callgraph.h b/src/shady/analysis/callgraph.h index 
13fd9bfcb..bf4a1f437 100644 --- a/src/shady/analysis/callgraph.h +++ b/src/shady/analysis/callgraph.h @@ -24,13 +24,14 @@ struct CGNode_ { bool is_recursive; /// set to true if the address of this is captured by a FnAddr node that is not immediately consumed by a call bool is_address_captured; + bool calls_indirect; }; typedef struct Callgraph_ { struct Dict* fn2cgn; } CallGraph; -CallGraph* new_callgraph(Module*); -void destroy_callgraph(CallGraph*); +CallGraph* shd_new_callgraph(Module* mod); +void shd_destroy_callgraph(CallGraph* graph); #endif diff --git a/src/shady/analysis/cfg.c b/src/shady/analysis/cfg.c new file mode 100644 index 000000000..4bd123339 --- /dev/null +++ b/src/shady/analysis/cfg.c @@ -0,0 +1,580 @@ +#include "cfg.h" +#include "looptree.h" +#include "log.h" + +#include "list.h" +#include "dict.h" +#include "arena.h" +#include "util.h" + +#include "../ir_private.h" + +#include +#include + +#pragma GCC diagnostic error "-Wswitch" + +struct List* shd_build_cfgs(Module* mod, CFGBuildConfig config) { + struct List* cfgs = shd_new_list(CFG*); + + Nodes decls = shd_module_get_declarations(mod); + for (size_t i = 0; i < decls.count; i++) { + const Node* decl = decls.nodes[i]; + if (decl->tag != Function_TAG) continue; + CFG* cfg = shd_new_cfg(decl, decl, config); + shd_list_append(CFG*, cfgs, cfg); + } + + return cfgs; +} + +KeyHash shd_hash_node(const Node**); +bool shd_compare_node(const Node**, const Node**); + +typedef struct { + Arena* arena; + const Node* function; + const Node* entry; + struct Dict* nodes; + struct List* contents; + + CFGBuildConfig config; + + const Node* selection_construct_tail; + const Node* loop_construct_head; + const Node* loop_construct_tail; + + struct Dict* join_point_values; +} CfgBuildContext; + +static void process_cf_node(CfgBuildContext* ctx, CFNode* node); + +CFNode* shd_cfg_lookup(CFG* cfg, const Node* abs) { + CFNode** found = shd_dict_find_value(const Node*, CFNode*, cfg->map, abs); + if (found) { + CFNode* 
cfnode = *found; + assert(cfnode->node); + assert(cfnode->node == abs); + return cfnode; + } + assert(false); + return NULL; +} + +static CFNode* new_cfnode(Arena* a) { + CFNode* new = shd_arena_alloc(a, sizeof(CFNode)); + *new = (CFNode) { + .succ_edges = shd_new_list(CFEdge), + .pred_edges = shd_new_list(CFEdge), + .rpo_index = SIZE_MAX, + .idom = NULL, + .dominates = NULL, + .structurally_dominates = shd_new_set(const Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + return new; +} + +static CFNode* get_or_enqueue(CfgBuildContext* ctx, const Node* abs) { + assert(is_abstraction(abs)); + assert(!is_function(abs) || abs == ctx->function); + CFNode** found = shd_dict_find_value(const Node*, CFNode*, ctx->nodes, abs); + if (found) return *found; + + CFNode* new = new_cfnode(ctx->arena); + new->node = abs; + assert(abs && new->node); + shd_dict_insert(const Node*, CFNode*, ctx->nodes, abs, new); + process_cf_node(ctx, new); + shd_list_append(Node*, ctx->contents, new); + return new; +} + +static bool in_loop(LoopTree* lt, const Node* fn, const Node* loopentry, const Node* block) { + LTNode* lt_node = shd_loop_tree_lookup(lt, block); + assert(lt_node); + LTNode* parent = lt_node->parent; + assert(parent); + + while (parent) { + // we're not in a loop like we're expected to + if (shd_list_count(parent->cf_nodes) == 0 && loopentry == fn) + return true; + + // we are in the loop we were expected to + if (shd_list_count(parent->cf_nodes) == 1 && shd_read_list(CFNode*, parent->cf_nodes)[0]->node == loopentry) + return true; + + parent = parent->parent; + } + + return false; +} + +/// Adds an edge to somewhere inside a basic block +static void add_edge(CfgBuildContext* ctx, const Node* src, const Node* dst, CFEdgeType type, const Node* term) { + assert(is_abstraction(src) && is_abstraction(dst)); + assert(term && is_terminator(term)); + assert(!is_function(dst)); + if (ctx->config.lt && !in_loop(ctx->config.lt, ctx->function, ctx->entry, dst)) + return; + if 
(ctx->config.lt && dst == ctx->entry) { + return; + } + + const Node* j = term->tag == Jump_TAG ? term : NULL; + + CFNode* src_node = get_or_enqueue(ctx, src); + CFNode* dst_node = get_or_enqueue(ctx, dst); + CFEdge edge = { + .type = type, + .src = src_node, + .dst = dst_node, + .jump = j, + .terminator = term, + }; + shd_list_append(CFEdge, src_node->succ_edges, edge); + shd_list_append(CFEdge, dst_node->pred_edges, edge); +} + +static void add_structural_edge(CfgBuildContext* ctx, CFNode* parent, const Node* dst, CFEdgeType type, const Node* term) { + add_edge(ctx, parent->node, dst, type, term); +} + +static void add_structural_dominance_edge(CfgBuildContext* ctx, CFNode* parent, const Node* dst, CFEdgeType type, const Node* term) { + add_edge(ctx, parent->node, dst, type, term); + shd_set_insert_get_result(const Node*, parent->structurally_dominates, dst); +} + +static void add_jump_edge(CfgBuildContext* ctx, const Node* src, const Node* j) { + assert(j->tag == Jump_TAG); + const Node* target = j->payload.jump.target; + if (target->tag == BasicBlock_TAG) + add_edge(ctx, src, target, JumpEdge, j); +} + +#pragma GCC diagnostic error "-Wswitch" + +static void process_cf_node(CfgBuildContext* ctx, CFNode* node) { + const Node* const abs = node->node; + assert(is_abstraction(abs)); + assert(!is_function(abs) || abs == ctx->function); + const Node* terminator = get_abstraction_body(abs); + if (!terminator) + return; + while (true) { + switch (is_terminator(terminator)) { + case Jump_TAG: { + add_jump_edge(ctx, abs, terminator); + return; + } + case Branch_TAG: { + add_jump_edge(ctx, abs, terminator->payload.branch.true_jump); + add_jump_edge(ctx, abs, terminator->payload.branch.false_jump); + return; + } + case Switch_TAG: { + for (size_t i = 0; i < terminator->payload.br_switch.case_jumps.count; i++) + add_jump_edge(ctx, abs, terminator->payload.br_switch.case_jumps.nodes[i]); + add_jump_edge(ctx, abs, terminator->payload.br_switch.default_jump); + return; + } + 
case If_TAG: { + if (ctx->config.include_structured_tails) + add_structural_dominance_edge(ctx, node, get_structured_construct_tail(terminator), StructuredTailEdge, terminator); + CfgBuildContext if_ctx = *ctx; + if_ctx.selection_construct_tail = get_structured_construct_tail(terminator); + add_structural_edge(&if_ctx, node, terminator->payload.if_instr.if_true, StructuredEnterBodyEdge, terminator); + if (terminator->payload.if_instr.if_false) + add_structural_edge(&if_ctx, node, terminator->payload.if_instr.if_false, StructuredEnterBodyEdge, terminator); + else + add_structural_edge(ctx, node, get_structured_construct_tail(terminator), StructuredLeaveBodyEdge, terminator); + + return; + } case Match_TAG: { + if (ctx->config.include_structured_tails) + add_structural_dominance_edge(ctx, node, get_structured_construct_tail(terminator), StructuredTailEdge, terminator); + CfgBuildContext match_ctx = *ctx; + match_ctx.selection_construct_tail = get_structured_construct_tail(terminator); + for (size_t i = 0; i < terminator->payload.match_instr.cases.count; i++) + add_structural_edge(&match_ctx, node, terminator->payload.match_instr.cases.nodes[i], StructuredEnterBodyEdge, terminator); + add_structural_edge(&match_ctx, node, terminator->payload.match_instr.default_case, StructuredEnterBodyEdge, terminator); + return; + } case Loop_TAG: { + if (ctx->config.include_structured_tails) + add_structural_dominance_edge(ctx, node, get_structured_construct_tail(terminator), StructuredTailEdge, terminator); + CfgBuildContext loop_ctx = *ctx; + loop_ctx.loop_construct_head = terminator->payload.loop_instr.body; + loop_ctx.loop_construct_tail = get_structured_construct_tail(terminator); + add_structural_edge(&loop_ctx, node, terminator->payload.loop_instr.body, StructuredEnterBodyEdge, terminator); + return; + } case Control_TAG: { + const Node* param = shd_first(get_abstraction_params(terminator->payload.control.inside)); + //CFNode* let_tail_cfnode = get_or_enqueue(ctx, 
get_structured_construct_tail(terminator)); + const Node* tail = get_structured_construct_tail(terminator); + shd_dict_insert(const Node*, const Node*, ctx->join_point_values, param, tail); + add_structural_dominance_edge(ctx, node, terminator->payload.control.inside, StructuredEnterBodyEdge, terminator); + if (ctx->config.include_structured_tails) + add_structural_dominance_edge(ctx, node, get_structured_construct_tail(terminator), StructuredTailEdge, terminator); + return; + } case Join_TAG: { + if (ctx->config.include_structured_exits) { + const Node** dst = shd_dict_find_value(const Node*, const Node*, ctx->join_point_values, terminator->payload.join.join_point); + if (dst) + add_edge(ctx, node->node, *dst, StructuredLeaveBodyEdge, terminator); + } + return; + } case MergeSelection_TAG: { + assert(ctx->selection_construct_tail); + if (ctx->config.include_structured_exits) + add_structural_edge(ctx, node, ctx->selection_construct_tail, StructuredLeaveBodyEdge, terminator); + return; + } case MergeContinue_TAG:{ + assert(ctx->loop_construct_head); + if (ctx->config.include_structured_exits) + add_structural_edge(ctx, node, ctx->loop_construct_head, StructuredLoopContinue, terminator); + return; + } case MergeBreak_TAG: { + assert(ctx->loop_construct_tail); + if (ctx->config.include_structured_exits) + add_structural_edge(ctx, node, ctx->loop_construct_tail, StructuredLeaveBodyEdge, terminator); + return; + } + case TailCall_TAG: + case Return_TAG: + case Unreachable_TAG: + return; + case NotATerminator: + shd_error("Grammar problem"); + return; + } + SHADY_UNREACHABLE; + } +} + +/** + * Invert all edges in this cfg. Used to compute a post dominance tree. 
+ */ +static void flip_cfg(CFG* cfg) { + cfg->entry = NULL; + + for (size_t i = 0; i < cfg->size; i++) { + CFNode* cur = shd_read_list(CFNode*, cfg->contents)[i]; + + struct List* tmp = cur->succ_edges; + cur->succ_edges = cur->pred_edges; + cur->pred_edges = tmp; + + for (size_t j = 0; j < shd_list_count(cur->succ_edges); j++) { + CFEdge* edge = &shd_read_list(CFEdge, cur->succ_edges)[j]; + + CFNode* tmp2 = edge->dst; + edge->dst = edge->src; + edge->src = tmp2; + } + + for (size_t j = 0; j < shd_list_count(cur->pred_edges); j++) { + CFEdge* edge = &shd_read_list(CFEdge, cur->pred_edges)[j]; + + CFNode* tmp2 = edge->dst; + edge->dst = edge->src; + edge->src = tmp2; + } + + if (shd_list_count(cur->pred_edges) == 0) { + if (cfg->entry != NULL) { + if (cfg->entry->node) { + CFNode* new_entry = new_cfnode(cfg->arena); + CFEdge prev_entry_edge = { + .type = JumpEdge, + .src = new_entry, + .dst = cfg->entry + }; + shd_list_append(CFEdge, new_entry->succ_edges, prev_entry_edge); + shd_list_append(CFEdge, cfg->entry->pred_edges, prev_entry_edge); + cfg->entry = new_entry; + } + + CFEdge new_edge = { + .type = JumpEdge, + .src = cfg->entry, + .dst = cur + }; + shd_list_append(CFEdge, cfg->entry->succ_edges, new_edge); + shd_list_append(CFEdge, cur->pred_edges, new_edge); + } else { + cfg->entry = cur; + } + } + } + + assert(cfg->entry); + if (!cfg->entry->node) { + cfg->size += 1; + shd_list_append(Node*, cfg->contents, cfg->entry); + } +} + +static void validate_cfg(CFG* cfg) { + for (size_t i = 0; i < cfg->size; i++) { + CFNode* node = shd_read_list(CFNode*, cfg->contents)[i]; + size_t structured_body_uses = 0; + size_t num_jumps = 0; + size_t num_exits = 0; + bool is_tail = false; + for (size_t j = 0; j < shd_list_count(node->pred_edges); j++) { + CFEdge edge = shd_read_list(CFEdge, node->pred_edges)[j]; + switch (edge.type) { + case JumpEdge: + num_jumps++; + break; + case StructuredLoopContinue: + break; + case StructuredEnterBodyEdge: + structured_body_uses += 1; + 
break; + case StructuredTailEdge: + structured_body_uses += 1; + is_tail = true; + break; + case StructuredLeaveBodyEdge: + num_exits += 1; + break; + } + } + if (node != cfg->entry /* this exception exists since we might build CFGs rooted in cases */) { + if (structured_body_uses > 0) { + if (structured_body_uses > 1) { + shd_error_print("Basic block %s is used as a structural target more than once (structured_body_uses: %zu)", shd_get_abstraction_name_safe(node->node), structured_body_uses); + shd_error_die(); + } + if (num_jumps > 0) { + shd_error_print("Basic block %s is used as structural target, but is also jumped into (num_jumps: %zu)", shd_get_abstraction_name_safe(node->node), num_jumps); + shd_error_die(); + } + if (!is_tail && num_exits > 0) { + shd_error_print("Basic block %s is not a merge target yet is used as once (num_exits: %zu)", shd_get_abstraction_name_safe(node->node), num_exits); + shd_error_die(); + } + } + } + } +} + +static void mark_reachable(CFNode* n) { + if (!n->reachable) { + n->reachable = true; + for (size_t i = 0; i < shd_list_count(n->succ_edges); i++) { + CFEdge e = shd_read_list(CFEdge, n->succ_edges)[i]; + if (e.type == StructuredTailEdge) + continue; + mark_reachable(e.dst); + } + } +} + +CFG* shd_new_cfg(const Node* function, const Node* entry, CFGBuildConfig config) { + assert(function && function->tag == Function_TAG); + assert(is_abstraction(entry)); + Arena* arena = shd_new_arena(); + + CfgBuildContext context = { + .arena = arena, + .function = function, + .entry = entry, + .nodes = shd_new_dict(const Node*, CFNode*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .join_point_values = shd_new_dict(const Node*, const Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .contents = shd_new_list(CFNode*), + .config = config, + }; + + CFNode* entry_node = get_or_enqueue(&context, entry); + mark_reachable(entry_node); + //process_cf_node(&context, entry_node); + + //while (entries_count_list(context.queue) > 0) { + 
// CFNode* this = pop_last_list(CFNode*, context.queue); + // process_cf_node(&context, this); + //} + + shd_destroy_dict(context.join_point_values); + + CFG* cfg = calloc(sizeof(CFG), 1); + *cfg = (CFG) { + .arena = arena, + .config = config, + .entry = entry_node, + .size = shd_list_count(context.contents), + .flipped = config.flipped, + .contents = context.contents, + .map = context.nodes, + .rpo = NULL + }; + + validate_cfg(cfg); + + if (config.flipped) + flip_cfg(cfg); + + shd_cfg_compute_rpo(cfg); + shd_cfg_compute_domtree(cfg); + + return cfg; +} + +void shd_destroy_cfg(CFG* cfg) { + bool entry_destroyed = false; + for (size_t i = 0; i < cfg->size; i++) { + CFNode* node = shd_read_list(CFNode*, cfg->contents)[i]; + entry_destroyed |= node == cfg->entry; + shd_destroy_list(node->pred_edges); + shd_destroy_list(node->succ_edges); + if (node->dominates) + shd_destroy_list(node->dominates); + if (node->structurally_dominates) + shd_destroy_dict(node->structurally_dominates); + } + if (!entry_destroyed) { + shd_destroy_list(cfg->entry->pred_edges); + shd_destroy_list(cfg->entry->succ_edges); + if (cfg->entry->dominates) + shd_destroy_list(cfg->entry->dominates); + } + shd_destroy_dict(cfg->map); + shd_destroy_arena(cfg->arena); + free(cfg->rpo); + shd_destroy_list(cfg->contents); + free(cfg); +} + +static size_t post_order_visit(CFG* cfg, CFNode* n, size_t i) { + n->rpo_index = -2; + + for (int phase = 0; phase < 2; phase++) { + for (size_t j = 0; j < shd_list_count(n->succ_edges); j++) { + CFEdge edge = shd_read_list(CFEdge, n->succ_edges)[j]; + // always visit structured tail edges last + if ((edge.type == StructuredTailEdge) == (phase == 0)) + continue; + if (edge.dst->rpo_index == SIZE_MAX) + i = post_order_visit(cfg, edge.dst, i); + } + } + + n->rpo_index = i - 1; + cfg->rpo[n->rpo_index] = n; + return n->rpo_index; +} + +void shd_cfg_compute_rpo(CFG* cfg) { + /*cfg->reachable_size = 0; + for (size_t i = 0; i < entries_count_list(cfg->contents); i++) { + 
CFNode* n = read_list(CFNode*, cfg->contents)[i]; + if (n->reachable) + cfg->reachable_size++; + }*/ + cfg->reachable_size = cfg->size; + + cfg->rpo = malloc(sizeof(const CFNode*) * cfg->size); + size_t index = post_order_visit(cfg, cfg->entry, cfg->reachable_size); + assert(index == 0); + + // debug_print("RPO: "); + // for (size_t i = 0; i < cfg->size; i++) { + // debug_print("%s, ", cfg->rpo[i]->node->payload.lam.name); + // } + // debug_print("\n"); +} + +bool shd_cfg_is_node_structural_target(CFNode* cfn) { + for (size_t i = 0; i < shd_list_count(cfn->pred_edges); i++) { + if (shd_read_list(CFEdge, cfn->pred_edges)[i].type != JumpEdge) + return true; + } + return false; +} + +CFNode* shd_cfg_least_common_ancestor(CFNode* i, CFNode* j) { + assert(i && j); + while (i->rpo_index != j->rpo_index) { + while (i->rpo_index < j->rpo_index) j = j->idom; + while (i->rpo_index > j->rpo_index) i = i->idom; + } + return i; +} + +bool shd_cfg_is_dominated(CFNode* a, CFNode* b) { + while (a) { + if (a == b) + return true; + if (a->idom) + a = a->idom; + else if (a->structured_idom) + a = a->structured_idom; + else + break; + } + return false; +} + +void shd_cfg_compute_domtree(CFG* cfg) { + for (size_t i = 0; i < cfg->size; i++) { + CFNode* n = shd_read_list(CFNode*, cfg->contents)[i]; + if (n == cfg->entry/* || !n->reachable*/) + continue; + CFNode* structured_idom = NULL; + for (size_t j = 0; j < shd_list_count(n->pred_edges); j++) { + CFEdge e = shd_read_list(CFEdge, n->pred_edges)[j]; + if (e.type == StructuredTailEdge) { + structured_idom = n->structured_idom = e.src; + n->structured_idom_edge = e; + continue; + } + } + for (size_t j = 0; j < shd_list_count(n->pred_edges); j++) { + CFEdge e = shd_read_list(CFEdge, n->pred_edges)[j]; + if (e.src->rpo_index < n->rpo_index) { + n->idom = e.src; + goto outer_loop; + } + } + if (structured_idom) { + continue; + } + shd_error("no idom found"); + outer_loop:; + } + + bool todo = true; + while (todo) { + todo = false; + for 
(size_t i = 0; i < cfg->size; i++) { + CFNode* n = shd_read_list(CFNode*, cfg->contents)[i]; + if (n == cfg->entry || n->structured_idom) + continue; + CFNode* new_idom = NULL; + for (size_t j = 0; j < shd_list_count(n->pred_edges); j++) { + CFEdge e = shd_read_list(CFEdge, n->pred_edges)[j]; + if (e.type == StructuredTailEdge) + continue; + CFNode* p = e.src; + new_idom = new_idom ? shd_cfg_least_common_ancestor(new_idom, p) : p; + } + assert(new_idom); + if (n->idom != new_idom) { + n->idom = new_idom; + todo = true; + } + } + } + + for (size_t i = 0; i < cfg->size; i++) { + CFNode* n = cfg->rpo[i]; + n->dominates = shd_new_list(CFNode*); + } + for (size_t i = 0; i < cfg->size; i++) { + CFNode* n = cfg->rpo[i]; + if (!n->idom) + continue; + shd_list_append(CFNode*, n->idom->dominates, n); + } +} diff --git a/src/shady/analysis/cfg.h b/src/shady/analysis/cfg.h new file mode 100644 index 000000000..29d9af834 --- /dev/null +++ b/src/shady/analysis/cfg.h @@ -0,0 +1,151 @@ +#ifndef SHADY_CFG_H +#define SHADY_CFG_H + +#include "shady/ir.h" + +#include + +typedef struct CFNode_ CFNode; + +typedef enum { + JumpEdge, + StructuredEnterBodyEdge, + StructuredLoopContinue, + StructuredLeaveBodyEdge, + /// Join points might leak, and as a consequence, there might be no static edge to the + /// tail of the enclosing let, which would make it look like dead code. 
+ /// This edge type accounts for that risk, they can be ignored where more precise info is available + /// (see shd_is_control_static for example) + StructuredTailEdge, +} CFEdgeType; + +typedef struct { + CFEdgeType type; + CFNode* src; + CFNode* dst; + const Node* jump; + const Node* terminator; +} CFEdge; + +struct CFNode_ { + const Node* node; + + bool reachable; + + /** @brief Edges where this node is the source + * + * @ref List of @ref CFEdge + */ + struct List* succ_edges; + + /** @brief Edges where this node is the destination + * + * @ref List of @ref CFEdge + */ + struct List* pred_edges; + + // set by compute_rpo + size_t rpo_index; + + // set by compute_domtree + CFNode* idom; + CFNode* structured_idom; + CFEdge structured_idom_edge; + + /** @brief All Nodes directly dominated by this CFNode. + * + * @ref List of @ref CFNode* + */ + struct List* dominates; + struct Dict* structurally_dominates; +}; + +typedef struct Arena_ Arena; +typedef struct LoopTree_ LoopTree; + +typedef struct { + bool include_structured_exits; + bool include_structured_tails; + LoopTree* lt; + bool flipped; +} CFGBuildConfig; + +typedef struct CFG_ { + Arena* arena; + CFGBuildConfig config; + size_t size; + + bool flipped; + + /** + * @ref List of @ref CFNode* + */ + struct List* contents; + + /** + * @ref Dict from const @ref Node* to @ref CFNode* + */ + struct Dict* map; + + CFNode* entry; + // set by compute_rpo + size_t reachable_size; + CFNode** rpo; +} CFG; + +/** + * @returns @ref List of @ref CFG* + */ +struct List* shd_build_cfgs(Module* mod, CFGBuildConfig config); + +/** Construct the CFG starting in node. + */ +CFG* shd_new_cfg(const Node* fn, const Node* entry, CFGBuildConfig); + +/** Construct the CFG starting in node. + * Dominance will only be computed with respect to the nodes reachable by @p entry. 
+ */
+
+static inline CFGBuildConfig default_forward_cfg_build(void) {
+    return (CFGBuildConfig) {
+        .include_structured_exits = true,
+        .include_structured_tails = true,
+    };
+}
+
+static inline CFGBuildConfig structured_scope_cfg_build(void) {
+    return (CFGBuildConfig) {
+        .include_structured_exits = false,
+        .include_structured_tails = true,
+    };
+}
+
+static inline CFGBuildConfig flipped_cfg_build(void) {
+    return (CFGBuildConfig) {
+        //.include_structured_tails = include_structured_tails,
+        .lt = NULL,
+        .flipped = true,
+    };
+}
+
+#define build_fn_cfg(node) shd_new_cfg(node, node, default_forward_cfg_build())
+
+/** Construct the CFG starting in Node.
+ * Dominance will only be computed with respect to the nodes reachable by @p entry.
+ * This CFG will contain post dominance information instead of regular dominance!
+ */
+#define build_fn_cfg_flipped(node) shd_new_cfg(node, node, flipped_cfg_build())
+
+CFNode* shd_cfg_lookup(CFG* cfg, const Node* abs);
+void shd_cfg_compute_rpo(CFG* cfg);
+void shd_cfg_compute_domtree(CFG* cfg);
+
+bool shd_cfg_is_dominated(CFNode* a, CFNode* b);
+
+bool shd_cfg_is_node_structural_target(CFNode* cfn);
+
+CFNode* shd_cfg_least_common_ancestor(CFNode* i, CFNode* j);
+
+void shd_destroy_cfg(CFG* cfg);
+
+#endif
diff --git a/src/shady/analysis/cfg_dump.c b/src/shady/analysis/cfg_dump.c
new file mode 100644
index 000000000..e681b9985
--- /dev/null
+++ b/src/shady/analysis/cfg_dump.c
@@ -0,0 +1,210 @@
+#include "cfg.h"
+
+#include "ir_private.h"
+#include "shady/print.h"
+
+#include "list.h"
+#include "dict.h"
+#include "util.h"
+#include "printer.h"
+
+#include 
+#include 
+
+static int extra_uniqueness = 0;
+
+static void print_node_helper(Printer* p, const Node* n) {
+    Growy* tmp_g = shd_new_growy();
+    Printer* tmp_p = shd_new_printer_from_growy(tmp_g);
+
+    NodePrintConfig config = {
+        .color = false,
+        .in_cfg = true,
+    };
+
+    shd_print_node(tmp_p, config, n);
+
+    String label = shd_printer_growy_unwrap(tmp_p);
+    char*
escaped_label = calloc(strlen(label) * 2, 1); + shd_unapply_escape_codes(label, strlen(label), escaped_label); + + shd_print(p, "%s", escaped_label); + free(escaped_label); + free((void*)label); +} + +static const Nodes* find_scope_info(const Node* abs) { + assert(is_abstraction(abs)); + const Node* terminator = get_abstraction_body(abs); + const Node* mem = get_terminator_mem(terminator); + Nodes* info = NULL; + while (mem) { + if (mem->tag == ExtInstr_TAG && strcmp(mem->payload.ext_instr.set, "shady.scope") == 0) { + if (!info || info->count > mem->payload.ext_instr.operands.count) + info = &mem->payload.ext_instr.operands; + } + mem = shd_get_parent_mem(mem); + } + return info; +} + +static void dump_cf_node(FILE* output, const CFNode* n) { + const Node* bb = n->node; + const Node* body = get_abstraction_body(bb); + if (!body) + return; + + String color = "black"; + switch (body->tag) { + case If_TAG: color = "blue"; break; + case Loop_TAG: color = "red"; break; + case Control_TAG: color = "orange"; break; + case Return_TAG: color = "teal"; break; + case Unreachable_TAG: color = "teal"; break; + default: break; + } + + Growy* g = shd_new_growy(); + Printer* p = shd_new_printer_from_growy(g); + + String abs_name = shd_get_abstraction_name_safe(bb); + + shd_print(p, "%s: \n%d: ", abs_name, bb->id); + + if (getenv("SHADY_CFG_SCOPE_ONLY")) { + const Nodes* scope = find_scope_info(bb); + if (scope) { + for (size_t i = 0; i < scope->count; i++) { + shd_print(p, "%d", scope->nodes[i]->id); + if (i + 1 < scope->count) + shd_print(p, ", "); + } + } + } else { + print_node_helper(p, body); + shd_print(p, "\\l"); + } + shd_print(p, "rpo: %d, idom: %s, sdom: %s", n->rpo_index, n->idom ? shd_get_abstraction_name_safe(n->idom->node) : "null", n->structured_idom ? 
shd_get_abstraction_name_safe(n->structured_idom->node) : "null"); + + String label = shd_printer_growy_unwrap(p); + fprintf(output, "bb_%zu [nojustify=true, label=\"%s\", color=\"%s\", shape=box];\n", (size_t) n, label, color); + free((void*) label); + + //for (size_t i = 0; i < entries_count_list(n->dominates); i++) { + // CFNode* d = read_list(CFNode*, n->dominates)[i]; + // if (!find_key_dict(const Node*, n->structurally_dominates, d->node)) + // dump_cf_node(output, d); + //} +} + +static void dump_cfg(FILE* output, CFG* cfg) { + extra_uniqueness++; + + const Node* entry = cfg->entry->node; + fprintf(output, "subgraph cluster_%d {\n", entry->id); + fprintf(output, "label = \"%s\";\n", shd_get_abstraction_name_safe(entry)); + for (size_t i = 0; i < shd_list_count(cfg->contents); i++) { + const CFNode* n = shd_read_list(const CFNode*, cfg->contents)[i]; + dump_cf_node(output, n); + } + for (size_t i = 0; i < shd_list_count(cfg->contents); i++) { + const CFNode* bb_node = shd_read_list(const CFNode*, cfg->contents)[i]; + const CFNode* src_node = bb_node; + + for (size_t j = 0; j < shd_list_count(bb_node->succ_edges); j++) { + CFEdge edge = shd_read_list(CFEdge, bb_node->succ_edges)[j]; + const CFNode* target_node = edge.dst; + String edge_color = "black"; + String edge_style = "solid"; + switch (edge.type) { + case StructuredEnterBodyEdge: edge_color = "blue"; break; + case StructuredLeaveBodyEdge: edge_color = "red"; break; + case StructuredTailEdge: edge_style = "dashed"; break; + case StructuredLoopContinue: edge_style = "dotted"; edge_color = "orange"; break; + default: break; + } + + fprintf(output, "bb_%zu -> bb_%zu [color=\"%s\", style=\"%s\"];\n", (size_t) (src_node), (size_t) (target_node), edge_color, edge_style); + } + } + fprintf(output, "}\n"); +} + +void shd_dump_existing_cfg_auto(CFG* cfg) { + FILE* f = fopen("cfg.dot", "wb"); + fprintf(f, "digraph G {\n"); + dump_cfg(f, cfg); + fprintf(f, "}\n"); + fclose(f); +} + +void shd_dump_cfg_auto(const 
Node* fn) { + FILE* f = fopen("cfg.dot", "wb"); + fprintf(f, "digraph G {\n"); + CFG* cfg = build_fn_cfg(fn); + dump_cfg(f, cfg); + shd_destroy_cfg(cfg); + fprintf(f, "}\n"); + fclose(f); +} + +void shd_dump_cfgs(FILE* output, Module* mod) { + if (output == NULL) + output = stderr; + + fprintf(output, "digraph G {\n"); + struct List* cfgs = shd_build_cfgs(mod, default_forward_cfg_build()); + for (size_t i = 0; i < shd_list_count(cfgs); i++) { + CFG* cfg = shd_read_list(CFG*, cfgs)[i]; + dump_cfg(output, cfg); + shd_destroy_cfg(cfg); + } + shd_destroy_list(cfgs); + fprintf(output, "}\n"); +} + +void shd_dump_cfgs_auto(Module* mod) { + FILE* f = fopen("cfg.dot", "wb"); + shd_dump_cfgs(f, mod); + fclose(f); +} + +static void dump_domtree_cfnode(Printer* p, CFNode* idom) { + String name = shd_get_abstraction_name_safe(idom->node); + if (name) + shd_print(p, "bb_%zu [label=\"%s\", shape=box];\n", (size_t) idom, name); + else + shd_print(p, "bb_%zu [label=\"%%%d\", shape=box];\n", (size_t) idom, idom->node->id); + + for (size_t i = 0; i < shd_list_count(idom->dominates); i++) { + CFNode* child = shd_read_list(CFNode*, idom->dominates)[i]; + dump_domtree_cfnode(p, child); + shd_print(p, "bb_%zu -> bb_%zu;\n", (size_t) (idom), (size_t) (child)); + } +} + +void shd_dump_domtree_cfg(Printer* p, CFG* s) { + shd_print(p, "subgraph cluster_%s {\n", shd_get_abstraction_name_safe(s->entry->node)); + dump_domtree_cfnode(p, s->entry); + shd_print(p, "}\n"); +} + +void shd_dump_domtree_module(Printer* p, Module* mod) { + shd_print(p, "digraph G {\n"); + struct List* cfgs = shd_build_cfgs(mod, default_forward_cfg_build()); + for (size_t i = 0; i < shd_list_count(cfgs); i++) { + CFG* cfg = shd_read_list(CFG*, cfgs)[i]; + shd_dump_domtree_cfg(p, cfg); + shd_destroy_cfg(cfg); + } + shd_destroy_list(cfgs); + shd_print(p, "}\n"); +} + +void shd_dump_domtree_auto(Module* mod) { + FILE* f = fopen("domtree.dot", "wb"); + Printer* p = shd_new_printer_from_file(f); + shd_dump_domtree_module(p, 
mod); + shd_destroy_printer(p); + fclose(f); +} diff --git a/src/shady/analysis/free_frontier.c b/src/shady/analysis/free_frontier.c new file mode 100644 index 000000000..8df0aa952 --- /dev/null +++ b/src/shady/analysis/free_frontier.c @@ -0,0 +1,59 @@ +#include "free_frontier.h" + +#include "shady/visit.h" +#include "dict.h" + +typedef struct { + Visitor v; + Scheduler* scheduler; + CFG* cfg; + CFNode* start; + struct Dict* seen; + struct Dict* frontier; +} FreeFrontierVisitor; + +static void visit_free_frontier(FreeFrontierVisitor* v, const Node* node) { + if (shd_dict_find_key(const Node*, v->seen, node)) + return; + shd_set_insert_get_result(const Node*, v->seen, node); + CFNode* where = shd_schedule_instruction(v->scheduler, node); + if (where) { + FreeFrontierVisitor vv = *v; + if (shd_cfg_is_dominated(where, v->start)) { + shd_visit_node_operands(&vv.v, NcAbstraction | NcDeclaration | NcType, node); + } else { + if (is_abstraction(node)) { + struct Dict* other_ff = shd_free_frontier(v->scheduler, v->cfg, node); + size_t i = 0; + const Node* f; + while (shd_dict_iter(other_ff, &i, &f, NULL)) { + shd_set_insert_get_result(const Node*, v->frontier, f); + } + shd_destroy_dict(other_ff); + } + if (is_value(node)) { + shd_set_insert_get_result(const Node*, v->frontier, node); + } + } + } +} + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +struct Dict* shd_free_frontier(Scheduler* scheduler, CFG* cfg, const Node* abs) { + FreeFrontierVisitor ffv = { + .v = { + .visit_node_fn = (VisitNodeFn) visit_free_frontier, + }, + .scheduler = scheduler, + .cfg = cfg, + .start = shd_cfg_lookup(cfg, abs), + .frontier = shd_new_set(const Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .seen = shd_new_set(const Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + if (get_abstraction_body(abs)) + visit_free_frontier(&ffv, get_abstraction_body(abs)); + shd_destroy_dict(ffv.seen); + return ffv.frontier; +} \ No newline 
at end of file diff --git a/src/shady/analysis/free_frontier.h b/src/shady/analysis/free_frontier.h new file mode 100644 index 000000000..ad389df50 --- /dev/null +++ b/src/shady/analysis/free_frontier.h @@ -0,0 +1,10 @@ +#ifndef SHADY_FREE_FRONTIER_H +#define SHADY_FREE_FRONTIER_H + +#include "shady/ir.h" +#include "cfg.h" +#include "scheduler.h" + +struct Dict* shd_free_frontier(Scheduler* scheduler, CFG* cfg, const Node* abs); + +#endif diff --git a/src/shady/analysis/free_variables.c b/src/shady/analysis/free_variables.c deleted file mode 100644 index c8412a6b4..000000000 --- a/src/shady/analysis/free_variables.c +++ /dev/null @@ -1,100 +0,0 @@ -#include "free_variables.h" - -#include "log.h" -#include "../visit.h" - -#include "../analysis/scope.h" - -#include "list.h" -#include "dict.h" - -#include - -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); - -typedef struct { - Visitor visitor; - struct Dict* bound_set; - struct Dict* set; - struct List* free_list; -} Context; - -static void search_op_for_free_variables(Context* visitor, NodeClass class, String op_name, const Node* node) { - assert(node); - switch (node->tag) { - case Variable_TAG: { - if (find_key_dict(const Node*, visitor->bound_set, node)) - return; - if (insert_set_get_result(const Node*, visitor->set, node)) { - append_list(const Node*, visitor->free_list, node); - } - break; - } - case Function_TAG: - case Case_TAG: - case BasicBlock_TAG: assert(false); - default: visit_node_operands(&visitor->visitor, IGNORE_ABSTRACTIONS_MASK, node); break; - } -} - -static void visit_domtree(Context* ctx, CFNode* cfnode, int depth) { - const Node* abs = cfnode->node; - - bool is_named = abs->tag != Case_TAG; - - if (is_named) { - for (int i = 0; i < depth; i++) - debugvv_print(" "); - debugvv_print("%s\n", get_abstraction_name(abs)); - } - - // Bind parameters - Nodes params = get_abstraction_params(abs); - for (size_t j = 0; j < params.count; j++) { - const Node* param = params.nodes[j]; - bool 
r = insert_set_get_result(const Node*, ctx->bound_set, param); - // assert(r); - // this can happen if you visit the domtree of a CFG starting _inside_ a loop - // we will meet some unbound params but eventually we'll enter their definition after the fact - // those params should still be considered free in this case. - } - - const Node* body = get_abstraction_body(abs); - if (body) - visit_op(&ctx->visitor, NcTerminator, "body", body); - - for (size_t i = 0; i < entries_count_list(cfnode->dominates); i++) { - CFNode* child = read_list(CFNode*, cfnode->dominates)[i]; - visit_domtree(ctx, child, depth + (is_named ? 1 : 0)); - } - - // Unbind parameters - for (size_t j = 0; j < params.count; j++) { - const Node* param = params.nodes[j]; - bool r = remove_dict(const Node*, ctx->bound_set, param); - assert(r); - } -} - -struct List* compute_free_variables(const Scope* scope, const Node* at) { - struct Dict* bound_set = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node); - struct Dict* set = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node); - struct List* free_list = new_list(const Node*); - - Context ctx = { - .visitor = { - .visit_op_fn = (VisitOpFn) search_op_for_free_variables, - }, - .bound_set = bound_set, - .set = set, - .free_list = free_list, - }; - - debugv_print("Computing free variables...\n"); - visit_domtree(&ctx, scope_lookup(scope, at), 0); - - destroy_dict(bound_set); - destroy_dict(set); - return free_list; -} diff --git a/src/shady/analysis/free_variables.h b/src/shady/analysis/free_variables.h deleted file mode 100644 index f6f92f4be..000000000 --- a/src/shady/analysis/free_variables.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef SHADY_FREE_VARIABLES_H -#define SHADY_FREE_VARIABLES_H - -#include "shady/ir.h" - -typedef struct Scope_ Scope; - -struct List* compute_free_variables(const Scope* scope, const Node*); - -#endif diff --git a/src/shady/analysis/leak.c b/src/shady/analysis/leak.c index 2c8dea13a..db2c52858 100644 --- 
a/src/shady/analysis/leak.c +++ b/src/shady/analysis/leak.c @@ -3,10 +3,10 @@ #include #include -#include "../visit.h" +#include "shady/visit.h" -void visit_enclosing_abstractions(UsesMap* map, const Node* n, void* uptr, VisitEnclosingAbsCallback fn) { - const Use* use = get_first_use(map, n); +void shd_visit_enclosing_abstractions(UsesMap* map, const Node* n, void* uptr, VisitEnclosingAbsCallback fn) { + const Use* use = shd_get_first_use(map, n); for (;use; use = use->next_use) { if (is_abstraction(use->user)) { fn(uptr, use); @@ -16,33 +16,17 @@ void visit_enclosing_abstractions(UsesMap* map, const Node* n, void* uptr, Visit if (is_declaration(use->user)) continue; - visit_enclosing_abstractions(map, n, uptr, fn); + shd_visit_enclosing_abstractions(map, n, uptr, fn); } } -const Node* get_binding_abstraction(const UsesMap* map, const Node* var) { - assert(var->tag == Variable_TAG); - const Use* use = get_first_use(map, var); - assert(use); - const Use* binding_use = NULL; - for (;use; use = use->next_use) { - if (is_abstraction(use->user) && use->operand_class == NcVariable) { - assert(!binding_use); - binding_use = use; - } - } - assert(binding_use && "Failed to find the binding abstraction in the uses map"); - return binding_use->user; -} - -bool is_control_static(const UsesMap* map, const Node* control) { +bool shd_is_control_static(const UsesMap* map, const Node* control) { assert(control->tag == Control_TAG); const Node* inside = control->payload.control.inside; - assert(is_case(inside)); - const Node* jp = first(get_abstraction_params(inside)); + const Node* jp = shd_first(get_abstraction_params(inside)); bool found_binding_abs = false; - const Use* use = get_first_use(map, jp); + const Use* use = shd_get_first_use(map, jp); assert(use && "we expected at least one use ... 
"); for (;use; use = use->next_use) { if (use->user == control->payload.control.inside) { @@ -57,3 +41,18 @@ bool is_control_static(const UsesMap* map, const Node* control) { assert(found_binding_abs); return true; } + +const Node* shd_get_control_for_jp(const UsesMap* map, const Node* jp) { + if (!is_param(jp)) + return NULL; + const Node* abs = jp->payload.param.abs; + assert(is_abstraction(abs)); + + const Use* use = shd_get_first_use(map, abs); + for (;use; use = use->next_use) { + if (use->user->tag == Control_TAG && use->operand_class == NcBasic_block && strcmp(use->operand_name, "inside") == 0) + return use->user; + } + + return NULL; +} diff --git a/src/shady/analysis/leak.h b/src/shady/analysis/leak.h index 21ffe7656..616891352 100644 --- a/src/shady/analysis/leak.h +++ b/src/shady/analysis/leak.h @@ -4,13 +4,13 @@ #include #include "uses.h" -#include "scope.h" +#include "cfg.h" typedef void (VisitEnclosingAbsCallback)(void*, const Use*); -void visit_enclosing_abstractions(UsesMap*, const Node*, void* uptr, VisitEnclosingAbsCallback fn); +void shd_visit_enclosing_abstractions(UsesMap* map, const Node* n, void* uptr, VisitEnclosingAbsCallback fn); -const Node* get_binding_abstraction(const UsesMap*, const Node* var); - -bool is_control_static(const UsesMap*, const Node* control); +bool shd_is_control_static(const UsesMap* map, const Node* control); +/// Returns the Control node that defines the join point, or NULL if it's defined by something else +const Node* shd_get_control_for_jp(const UsesMap* map, const Node* jp); #endif diff --git a/src/shady/analysis/literal.c b/src/shady/analysis/literal.c new file mode 100644 index 000000000..632e662e9 --- /dev/null +++ b/src/shady/analysis/literal.c @@ -0,0 +1,183 @@ +#include "shady/analysis/literal.h" + +#include "shady/ir/int.h" +#include "shady/ir/type.h" + +#include "portability.h" + +#include + +static bool is_zero(const Node* node) { + const IntLiteral* lit = shd_resolve_to_int_literal(node); + if (lit && 
shd_get_int_literal_value(*lit, false) == 0) + return true; + return false; +} + +const Node* shd_chase_ptr_to_source(const Node* ptr, NodeResolveConfig config) { + while (true) { + ptr = shd_resolve_node_to_definition(ptr, config); + switch (ptr->tag) { + case PtrArrayElementOffset_TAG: break; + case PtrCompositeElement_TAG: { + PtrCompositeElement payload = ptr->payload.ptr_composite_element; + if (!is_zero(payload.index)) + break; + ptr = payload.ptr; + continue; + } + case PrimOp_TAG: { + switch (ptr->payload.prim_op.op) { + case convert_op: { + // chase generic pointers to their source + if (shd_first(ptr->payload.prim_op.type_arguments)->tag == PtrType_TAG) { + ptr = shd_first(ptr->payload.prim_op.operands); + continue; + } + break; + } + case reinterpret_op: { + // chase ptr casts to their source + // TODO: figure out round-trips through integer casts? + if (shd_first(ptr->payload.prim_op.type_arguments)->tag == PtrType_TAG) { + ptr = shd_first(ptr->payload.prim_op.operands); + continue; + } + break; + } + default: break; + } + break; + } + default: break; + } + break; + } + return ptr; +} + +const Node* shd_resolve_ptr_to_value(const Node* ptr, NodeResolveConfig config) { + while (ptr) { + ptr = shd_resolve_node_to_definition(ptr, config); + switch (ptr->tag) { + case PrimOp_TAG: { + switch (ptr->payload.prim_op.op) { + case convert_op: { // allow address space conversions + ptr = shd_first(ptr->payload.prim_op.operands); + continue; + } + default: break; + } + } + case GlobalVariable_TAG: + if (config.assume_globals_immutability) + return ptr->payload.global_variable.init; + break; + default: break; + } + ptr = NULL; + } + return NULL; +} + +NodeResolveConfig shd_default_node_resolve_config(void) { + return (NodeResolveConfig) { + .enter_loads = true, + .allow_incompatible_types = false, + .assume_globals_immutability = false, + }; +} + +const Node* shd_resolve_node_to_definition(const Node* node, NodeResolveConfig config) { + while (node) { + switch 
(node->tag) { + case Constant_TAG: + node = node->payload.constant.value; + continue; + case RefDecl_TAG: + node = node->payload.ref_decl.decl; + continue; + case Load_TAG: { + if (config.enter_loads) { + const Node* source = node->payload.load.ptr; + const Node* result = shd_resolve_ptr_to_value(source, config); + if (!result) + break; + node = result; + continue; + } + } + case PrimOp_TAG: { + switch (node->payload.prim_op.op) { + case convert_op: + case reinterpret_op: { + if (config.allow_incompatible_types) { + node = shd_first(node->payload.prim_op.operands); + continue; + } + } + default: break; + } + break; + } + default: break; + } + break; + } + return node; +} + +const char* shd_get_string_literal(IrArena* arena, const Node* node) { + if (!node) + return NULL; + if (node->type && shd_get_unqualified_type(node->type)->tag == PtrType_TAG) { + NodeResolveConfig nrc = shd_default_node_resolve_config(); + const Node* ptr = shd_chase_ptr_to_source(node, nrc); + const Node* value = shd_resolve_ptr_to_value(ptr, nrc); + if (value) + return shd_get_string_literal(arena, value); + } + switch (node->tag) { + case Declaration_GlobalVariable_TAG: { + const Node* init = node->payload.global_variable.init; + if (init) { + return shd_get_string_literal(arena, init); + } + break; + } + case Declaration_Constant_TAG: { + return shd_get_string_literal(arena, node->payload.constant.value); + } + case RefDecl_TAG: { + const Node* decl = node->payload.ref_decl.decl; + return shd_get_string_literal(arena, decl); + } + /*case Lea_TAG: { + Lea lea = node->payload.lea; + if (lea.indices.count == 3 && is_zero(lea.offset) && is_zero(first(lea.indices))) { + const Node* ref = lea.ptr; + if (ref->tag != RefDecl_TAG) + return NULL; + const Node* decl = ref->payload.ref_decl.decl; + if (decl->tag != GlobalVariable_TAG || !decl->payload.global_variable.init) + return NULL; + return get_string_literal(arena, decl->payload.global_variable.init); + } + break; + }*/ + case 
StringLiteral_TAG: return node->payload.string_lit.string; + case Composite_TAG: { + Nodes contents = node->payload.composite.contents; + LARRAY(char, chars, contents.count); + for (size_t i = 0; i < contents.count; i++) { + const Node* value = contents.nodes[i]; + assert(value->tag == IntLiteral_TAG && value->payload.int_literal.width == IntTy8); + chars[i] = (unsigned char) shd_get_int_literal_value(*shd_resolve_to_int_literal(value), false); + } + assert(chars[contents.count - 1] == 0); + return shd_string(arena, chars); + } + default: break; + } + return NULL; +} diff --git a/src/shady/analysis/looptree.c b/src/shady/analysis/looptree.c index c29500602..785375a4d 100644 --- a/src/shady/analysis/looptree.c +++ b/src/shady/analysis/looptree.c @@ -14,7 +14,7 @@ typedef struct { } State; typedef struct { - Scope* s; + CFG* s; State* states; /** @@ -23,18 +23,18 @@ typedef struct { struct List* stack; } LoopTreeBuilder; -KeyHash hash_node(const Node**); -bool compare_node(const Node**, const Node**); +KeyHash shd_hash_node(const Node**); +bool shd_compare_node(const Node**, const Node**); -LTNode* new_lf_node(int type, LTNode* parent, int depth, struct List* cf_nodes) { +static LTNode* new_lf_node(int type, LTNode* parent, int depth, struct List* cf_nodes) { LTNode* n = calloc(sizeof(LTNode), 1); n->parent = parent; n->type = type; - n->lf_children = new_list(LTNode*); + n->lf_children = shd_new_list(LTNode*); n->cf_nodes = cf_nodes; n->depth = depth; if (parent) { - append_list(LTNode*, parent->lf_children, n); + shd_list_append(LTNode*, parent->lf_children, n); } return n; } @@ -51,8 +51,8 @@ LTNode* new_lf_node(int type, LTNode* parent, int depth, struct List* cf_nodes) static bool is_leaf(LoopTreeBuilder* ltb, const CFNode* n, size_t num) { if (num == 1) { struct List* succ_edges = n->succ_edges; - for (size_t i = 0; i < entries_count_list(succ_edges); i++) { - CFEdge e = read_list(CFEdge, succ_edges)[i]; + for (size_t i = 0; i < shd_list_count(succ_edges); i++) 
{ + CFEdge e = shd_read_list(CFEdge, succ_edges)[i]; CFNode* succ = e.dst; if (!is_head(ltb, succ) && n == succ) return false; @@ -70,7 +70,7 @@ static int visit(LoopTreeBuilder* ltb, const CFNode* n, int counter) { ltb->states[n->rpo_index].dfs = counter; ltb->states[n->rpo_index].low_link = counter; // push it - append_list(const CFNode*, ltb->stack, n); + shd_list_append(const CFNode*, ltb->stack, n); on_stack(ltb, n) = true; return counter + 1; } @@ -78,8 +78,8 @@ static int visit(LoopTreeBuilder* ltb, const CFNode* n, int counter) { static int walk_scc(LoopTreeBuilder* ltb, const CFNode* cur, LTNode* parent, int depth, int scc_counter) { scc_counter = visit(ltb, cur, scc_counter); - for (size_t succi = 0; succi < entries_count_list(cur->succ_edges); succi++) { - CFEdge succe = read_list(CFEdge, cur->succ_edges)[succi]; + for (size_t succi = 0; succi < shd_list_count(cur->succ_edges); succi++) { + CFEdge succe = shd_read_list(CFEdge, cur->succ_edges)[succi]; CFNode* succ = succe.dst; if (is_head(ltb, succ)) continue; // this is a backedge @@ -93,30 +93,30 @@ static int walk_scc(LoopTreeBuilder* ltb, const CFNode* cur, LTNode* parent, int // root of SCC if (lowlink(ltb, cur) == dfs(ltb, cur)) { - struct List* heads = new_list(const CFNode*); + struct List* heads = shd_new_list(const CFNode*); // mark all cf_nodes in current SCC (all cf_nodes from back to cur on the stack) as 'in_scc' - size_t num = 0, e = entries_count_list(ltb->stack); + size_t num = 0, e = shd_list_count(ltb->stack); size_t b = e - 1; do { - in_scc(ltb, read_list(const CFNode*, ltb->stack)[b]) = true; + in_scc(ltb, shd_read_list(const CFNode*, ltb->stack)[b]) = true; ++num; - } while (read_list(const CFNode*, ltb->stack)[b--] != cur); + } while (shd_read_list(const CFNode*, ltb->stack)[b--] != cur); // for all cf_nodes in current SCC for (size_t i = ++b; i != e; i++) { - const CFNode* n = read_list(const CFNode*, ltb->stack)[i]; + const CFNode* n = shd_read_list(const CFNode*, ltb->stack)[i]; 
if (ltb->s->entry == n) { - append_list(const CFNode*, heads, n); // entries are axiomatically heads + shd_list_append(const CFNode*, heads, n); // entries are axiomatically heads } else { - for (size_t j = 0; j < entries_count_list(n->pred_edges); j++) { - assert(n == read_list(CFEdge, n->pred_edges)[j].dst); - const CFNode* pred = read_list(CFEdge, n->pred_edges)[j].src; + for (size_t j = 0; j < shd_list_count(n->pred_edges); j++) { + assert(n == shd_read_list(CFEdge, n->pred_edges)[j].dst); + const CFNode* pred = shd_read_list(CFEdge, n->pred_edges)[j].src; // all backedges are also inducing heads // but do not yet mark them globally as head -- we are still running through the SCC if (!in_scc(ltb, pred)) { - append_list(const CFNode*, heads, n); + shd_list_append(const CFNode*, heads, n); break; } } @@ -124,7 +124,7 @@ static int walk_scc(LoopTreeBuilder* ltb, const CFNode* cur, LTNode* parent, int } if (is_leaf(ltb, cur, num)) { - assert(entries_count_list(heads) == 1); + assert(shd_list_count(heads) == 1); new_lf_node(LF_LEAF, parent, depth, heads); } else { new_lf_node(LF_HEAD, parent, depth, heads); @@ -132,13 +132,13 @@ static int walk_scc(LoopTreeBuilder* ltb, const CFNode* cur, LTNode* parent, int // reset in_scc and on_stack flags for (size_t i = b; i != e; ++i) { - in_scc(ltb, read_list(const CFNode*, ltb->stack)[i]) = false; - on_stack(ltb, read_list(const CFNode*, ltb->stack)[i]) = false; + in_scc(ltb, shd_read_list(const CFNode*, ltb->stack)[i]) = false; + on_stack(ltb, shd_read_list(const CFNode*, ltb->stack)[i]) = false; } // pop whole SCC - while (entries_count_list(ltb->stack) != b) { - pop_last_list(const CFNode*, ltb->stack); + while (shd_list_count(ltb->stack) != b) { + shd_list_pop(const CFNode*, ltb->stack); } } @@ -153,22 +153,22 @@ static void clear_set(LoopTreeBuilder* ltb) { static void recurse(LoopTreeBuilder* ltb, LTNode* parent, struct List* heads, int depth) { assert(parent->type == LF_HEAD); size_t cur_new_child = 0; - for (size_t i 
= 0; i < entries_count_list(heads); i++) { - const CFNode* head = read_list(const CFNode*, heads)[i]; + for (size_t i = 0; i < shd_list_count(heads); i++) { + const CFNode* head = shd_read_list(const CFNode*, heads)[i]; clear_set(ltb); walk_scc(ltb, head, parent, depth, 0); - for (size_t e = entries_count_list(parent->lf_children); cur_new_child != e; ++cur_new_child) { - struct List* new_child_nodes = read_list(LTNode*, parent->lf_children)[cur_new_child]->cf_nodes; - for (size_t j = 0; j < entries_count_list(new_child_nodes); j++) { - CFNode* head2 = read_list(CFNode*, new_child_nodes)[j]; + for (size_t e = shd_list_count(parent->lf_children); cur_new_child != e; ++cur_new_child) { + struct List* new_child_nodes = shd_read_list(LTNode*, parent->lf_children)[cur_new_child]->cf_nodes; + for (size_t j = 0; j < shd_list_count(new_child_nodes); j++) { + CFNode* head2 = shd_read_list(CFNode*, new_child_nodes)[j]; is_head(ltb, head2) = true; } } } - for (size_t i = 0; i < entries_count_list(parent->lf_children); i++) { - LTNode* node = read_list(LTNode*, parent->lf_children)[i]; + for (size_t i = 0; i < shd_list_count(parent->lf_children); i++) { + LTNode* node = shd_read_list(LTNode*, parent->lf_children)[i]; if (node->type == LF_HEAD) recurse(ltb, node, node->cf_nodes, depth + 1); } @@ -176,24 +176,24 @@ static void recurse(LoopTreeBuilder* ltb, LTNode* parent, struct List* heads, in static void build_map_recursive(struct Dict* map, LTNode* n) { if (n->type == LF_LEAF) { - assert(entries_count_list(n->cf_nodes) == 1); - const Node* node = read_list(CFNode*, n->cf_nodes)[0]->node; - insert_dict(const Node*, LTNode*, map, node, n); + assert(shd_list_count(n->cf_nodes) == 1); + const Node* node = shd_read_list(CFNode*, n->cf_nodes)[0]->node; + shd_dict_insert(const Node*, LTNode*, map, node, n); } else { - for (size_t i = 0; i < entries_count_list(n->lf_children); i++) { - LTNode* child = read_list(LTNode*, n->lf_children)[i]; + for (size_t i = 0; i < 
shd_list_count(n->lf_children); i++) { + LTNode* child = shd_read_list(LTNode*, n->lf_children)[i]; build_map_recursive(map, child); } } } -LTNode* looptree_lookup(LoopTree* lt, const Node* block) { - LTNode** found = find_value_dict(const Node*, LTNode*, lt->map, block); +LTNode* shd_loop_tree_lookup(LoopTree* lt, const Node* block) { + LTNode** found = shd_dict_find_value(const Node*, LTNode*, lt->map, block); if (found) return *found; assert(false); } -LoopTree* build_loop_tree(Scope* s) { +LoopTree* shd_new_loop_tree(CFG* s) { LARRAY(State, states, s->size); for (size_t i = 0; i < s->size; i++) { states[i] = (State) { @@ -208,37 +208,37 @@ LoopTree* build_loop_tree(Scope* s) { LoopTreeBuilder ltb = { .states = states, .s = s, - .stack = new_list(const CFNode*), + .stack = shd_new_list(const CFNode*), }; LoopTree* lt = calloc(sizeof(LoopTree), 1); - struct List* empty_list = new_list(CFNode*); + struct List* empty_list = shd_new_list(CFNode*); lt->root = new_lf_node(LF_HEAD, NULL, 0, empty_list); const CFNode* entry = s->entry; - struct List* global_heads = new_list(const CFNode*); - append_list(const CFNode*, global_heads, entry); + struct List* global_heads = shd_new_list(const CFNode*); + shd_list_append(const CFNode*, global_heads, entry); recurse(&ltb, lt->root, global_heads, 1); - destroy_list(global_heads); - destroy_list(ltb.stack); + shd_destroy_list(global_heads); + shd_destroy_list(ltb.stack); - lt->map = new_dict(const Node*, LTNode*, (HashFn) hash_node, (CmpFn) compare_node); + lt->map = shd_new_dict(const Node*, LTNode*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node); build_map_recursive(lt->map, lt->root); return lt; } static void destroy_lt_node(LTNode* n) { - for (size_t i = 0; i < entries_count_list(n->lf_children); i++) { - destroy_lt_node(read_list(LTNode*, n->lf_children)[i]); + for (size_t i = 0; i < shd_list_count(n->lf_children); i++) { + destroy_lt_node(shd_read_list(LTNode*, n->lf_children)[i]); } - destroy_list(n->lf_children); - 
destroy_list(n->cf_nodes); + shd_destroy_list(n->lf_children); + shd_destroy_list(n->cf_nodes); free(n); } -void destroy_loop_tree(LoopTree* lt) { +void shd_destroy_loop_tree(LoopTree* lt) { destroy_lt_node(lt->root); - destroy_dict(lt->map); + shd_destroy_dict(lt->map); free(lt); } @@ -247,7 +247,7 @@ static int extra_uniqueness = 0; static void dump_lt_node(FILE* f, const LTNode* n) { if (n->type == LF_HEAD) { fprintf(f, "subgraph cluster_%d {\n", extra_uniqueness++); - if (entries_count_list(n->cf_nodes) == 0) { + if (shd_list_count(n->cf_nodes) == 0) { fprintf(f, "label = \"%s\";\n", "Entry"); } else { fprintf(f, "label = \"%s\";\n", "LoopHead"); @@ -257,13 +257,13 @@ static void dump_lt_node(FILE* f, const LTNode* n) { fprintf(f, "label = \"%s\";\n", "Leaf"); } - for (size_t i = 0; i < entries_count_list(n->cf_nodes); i++) { - const Node* bb = read_list(const CFNode*, n->cf_nodes)[i]->node; - fprintf(f, "%s_%d;\n", get_abstraction_name(bb), extra_uniqueness++); + for (size_t i = 0; i < shd_list_count(n->cf_nodes); i++) { + const Node* bb = shd_read_list(const CFNode*, n->cf_nodes)[i]->node; + fprintf(f, "bb_%d[label=\"%s\"];\n", extra_uniqueness++, shd_get_abstraction_name_safe(bb)); } - for (size_t i = 0; i < entries_count_list(n->lf_children); i++) { - const LTNode* child = read_list(const LTNode*, n->lf_children)[i]; + for (size_t i = 0; i < shd_list_count(n->lf_children); i++) { + const LTNode* child = shd_read_list(const LTNode*, n->lf_children)[i]; dump_lt_node(f, child); } @@ -271,7 +271,7 @@ static void dump_lt_node(FILE* f, const LTNode* n) { fprintf(f, "}\n"); } -void dump_loop_tree(FILE* f, LoopTree* lt) { +void shd_dump_loop_tree(FILE* f, LoopTree* lt) { //fprintf(f, "digraph G {\n"); fprintf(f, "subgraph cluster_%d {\n", extra_uniqueness++); dump_lt_node(f, lt->root); @@ -279,19 +279,26 @@ void dump_loop_tree(FILE* f, LoopTree* lt) { //fprintf(f, "}\n"); } -void dump_loop_trees(FILE* output, Module* mod) { +void shd_dump_loop_trees(FILE* output, 
Module* mod) { if (output == NULL) output = stderr; fprintf(output, "digraph G {\n"); - struct List* scopes = build_scopes(mod); - for (size_t i = 0; i < entries_count_list(scopes); i++) { - Scope* scope = read_list(Scope*, scopes)[i]; - LoopTree* lt = build_loop_tree(scope); - dump_loop_tree(output, lt); - destroy_loop_tree(lt); - destroy_scope(scope); + struct List* cfgs = shd_build_cfgs(mod, default_forward_cfg_build()); + for (size_t i = 0; i < shd_list_count(cfgs); i++) { + CFG* cfg = shd_read_list(CFG*, cfgs)[i]; + LoopTree* lt = shd_new_loop_tree(cfg); + shd_dump_loop_tree(output, lt); + shd_destroy_loop_tree(lt); + shd_destroy_cfg(cfg); } - destroy_list(scopes); + shd_destroy_list(cfgs); fprintf(output, "}\n"); } + + +void shd_dump_loop_trees_auto(Module* mod) { + FILE* f = fopen("loop_trees.dot", "wb"); + shd_dump_loop_trees(f, mod); + fclose(f); +} diff --git a/src/shady/analysis/looptree.h b/src/shady/analysis/looptree.h index 838ae0c66..6ec6db055 100644 --- a/src/shady/analysis/looptree.h +++ b/src/shady/analysis/looptree.h @@ -1,7 +1,7 @@ #ifndef SHADY_LOOPTREE_H #define SHADY_LOOPTREE_H -#include "scope.h" +#include "cfg.h" // Loop tree implementation based on Thorin, translated to C and somewhat simplified // https://github.com/AnyDSL/thorin @@ -39,12 +39,11 @@ struct LoopTree_ { /** * Returns the leaf for this node. 
*/ -LTNode* looptree_lookup(LoopTree*, const Node* block); +LTNode* shd_loop_tree_lookup(LoopTree* lt, const Node* block); -static void destroy_lt_node(LTNode* n); -void destroy_loop_tree(LoopTree* lt); +void shd_destroy_loop_tree(LoopTree* lt); -LoopTree* build_loop_tree(Scope* s); -void dump_loop_trees(FILE* output, Module* mod); +LoopTree* shd_new_loop_tree(CFG* s); +void shd_dump_loop_trees(FILE* output, Module* mod); #endif // SHADY_LOOPTREE_H diff --git a/src/shady/analysis/scheduler.c b/src/shady/analysis/scheduler.c new file mode 100644 index 000000000..7b4955a81 --- /dev/null +++ b/src/shady/analysis/scheduler.c @@ -0,0 +1,84 @@ +#include "scheduler.h" + +#include "shady/visit.h" + +#include "dict.h" +#include <stdlib.h> + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +struct Scheduler_ { + Visitor v; + CFNode* result; + CFG* cfg; + struct Dict* scheduled; +}; + +static void schedule_after(CFNode** scheduled, CFNode* req) { + if (!req) + return; + CFNode* old = *scheduled; + if (!old) + *scheduled = req; + else { + // TODO: validate that old post-dominates req + if (req->rpo_index > old->rpo_index) { + assert(shd_cfg_is_dominated(req, old)); + *scheduled = req; + } else { + assert(shd_cfg_is_dominated(old, req)); + } + } +} + +static void visit_operand(Scheduler* s, NodeClass nc, String opname, const Node* op, size_t i) { + switch (nc) { + // We only care about mem and value dependencies + case NcMem: + case NcValue: + schedule_after(&s->result, shd_schedule_instruction(s, op)); + break; + default: + break; + } +} + +Scheduler* shd_new_scheduler(CFG* cfg) { + Scheduler* s = calloc(sizeof(Scheduler), 1); + *s = (Scheduler) { + .v = { + .visit_op_fn = (VisitOpFn) visit_operand, + }, + .cfg = cfg, + .scheduled = shd_new_dict(const Node*, CFNode*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + return s; +} + +CFNode* shd_schedule_instruction(Scheduler* s, const Node* n) { + //assert(n && is_instruction(n)); + CFNode**
found = shd_dict_find_value(const Node*, CFNode*, s->scheduled, n); + if (found) + return *found; + + Scheduler s2 = *s; + s2.result = NULL; + + if (n->tag == Param_TAG) { + schedule_after(&s2.result, shd_cfg_lookup(s->cfg, n->payload.param.abs)); + } else if (n->tag == BasicBlock_TAG) { + schedule_after(&s2.result, shd_cfg_lookup(s->cfg, n)); + } else if (n->tag == AbsMem_TAG) { + schedule_after(&s2.result, shd_cfg_lookup(s->cfg, n->payload.abs_mem.abs)); + } + + shd_visit_node_operands(&s2.v, 0, n); + shd_dict_insert(const Node*, CFNode*, s->scheduled, n, s2.result); + return s2.result; +} + +void shd_destroy_scheduler(Scheduler* s) { + shd_destroy_dict(s->scheduled); + free(s); +} diff --git a/src/shady/analysis/scheduler.h b/src/shady/analysis/scheduler.h new file mode 100644 index 000000000..f3fedfff4 --- /dev/null +++ b/src/shady/analysis/scheduler.h @@ -0,0 +1,15 @@ +#ifndef SHADY_SCHEDULER_H +#define SHADY_SCHEDULER_H + +#include "shady/ir.h" +#include "cfg.h" + +typedef struct Scheduler_ Scheduler; + +Scheduler* shd_new_scheduler(CFG* cfg); +void shd_destroy_scheduler(Scheduler* s); + +/// Returns the CFNode where that instruction should be placed, or NULL if it can be computed at the top-level +CFNode* shd_schedule_instruction(Scheduler* s, const Node* n); + +#endif diff --git a/src/shady/analysis/scope.c b/src/shady/analysis/scope.c deleted file mode 100644 index 3f1721f80..000000000 --- a/src/shady/analysis/scope.c +++ /dev/null @@ -1,629 +0,0 @@ -#include "scope.h" -#include "looptree.h" -#include "log.h" - -#include "list.h" -#include "dict.h" -#include "arena.h" -#include "util.h" - -#include "../ir_private.h" - -#include -#include - -struct List* build_scopes(Module* mod) { - struct List* scopes = new_list(Scope*); - - Nodes decls = get_module_declarations(mod); - for (size_t i = 0; i < decls.count; i++) { - const Node* decl = decls.nodes[i]; - if (decl->tag != Function_TAG) continue; - Scope* scope = new_scope(decl); - append_list(Scope*, scopes, 
scope); - } - - return scopes; -} - -KeyHash hash_node(const Node**); -bool compare_node(const Node**, const Node**); - -typedef struct { - Arena* arena; - const Node* entry; - LoopTree* lt; - struct Dict* nodes; - struct List* queue; - struct List* contents; - - struct Dict* join_point_values; -} ScopeBuildContext; - -CFNode* scope_lookup(Scope* scope, const Node* block) { - CFNode** found = find_value_dict(const Node*, CFNode*, scope->map, block); - if (found) { - assert((*found)->node); - return *found; - } - assert(false); - return NULL; -} - -static CFNode* get_or_enqueue(ScopeBuildContext* ctx, const Node* abs) { - assert(is_abstraction(abs)); - assert(!is_function(abs) || abs == ctx->entry); - CFNode** found = find_value_dict(const Node*, CFNode*, ctx->nodes, abs); - if (found) return *found; - - CFNode* new = arena_alloc(ctx->arena, sizeof(CFNode)); - *new = (CFNode) { - .node = abs, - .succ_edges = new_list(CFEdge), - .pred_edges = new_list(CFEdge), - .rpo_index = SIZE_MAX, - .idom = NULL, - .dominates = NULL, - .structurally_dominates = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node), - }; - assert(abs && new->node); - insert_dict(const Node*, CFNode*, ctx->nodes, abs, new); - append_list(Node*, ctx->queue, new); - append_list(Node*, ctx->contents, new); - return new; -} - -static bool in_loop(LoopTree* lt, const Node* entry, const Node* block) { - LTNode* lt_node = looptree_lookup(lt, block); - assert(lt_node); - LTNode* parent = lt_node->parent; - assert(parent); - - while (parent) { - if (entries_count_list(parent->cf_nodes) != 1) - return false; - - if (read_list(CFNode*, parent->cf_nodes)[0]->node == entry) - return true; - - parent = parent->parent; - } - - return false; -} - -static bool is_structural_edge(CFEdgeType edge_type) { return edge_type != JumpEdge; } - -/// Adds an edge to somewhere inside a basic block -static void add_edge(ScopeBuildContext* ctx, const Node* src, const Node* dst, CFEdgeType type) { - 
assert(is_abstraction(src) && is_abstraction(dst)); - assert(!is_function(dst)); - assert(is_structural_edge(type) == (bool) is_case(dst)); - if (ctx->lt && !in_loop(ctx->lt, ctx->entry, dst)) - return; - if (ctx->lt && dst == ctx->entry) - return; - - CFNode* src_node = get_or_enqueue(ctx, src); - CFNode* dst_node = get_or_enqueue(ctx, dst); - CFEdge edge = { - .type = type, - .src = src_node, - .dst = dst_node, - }; - append_list(CFEdge, src_node->succ_edges, edge); - append_list(CFEdge, dst_node->pred_edges, edge); -} - -static void add_structural_dominance_edge(ScopeBuildContext* ctx, CFNode* parent, const Node* dst, CFEdgeType type) { - add_edge(ctx, parent->node, dst, type); - insert_set_get_result(const Node*, parent->structurally_dominates, dst); -} - -static void add_jump_edge(ScopeBuildContext* ctx, const Node* src, const Node* j) { - assert(j->tag == Jump_TAG); - const Node* target = j->payload.jump.target; - add_edge(ctx, src, target, JumpEdge); -} - -static void process_instruction(ScopeBuildContext* ctx, CFNode* parent, const Node* instruction, const Node* let_tail) { - switch (is_instruction(instruction)) { - case NotAnInstruction: error("Grammar problem"); - case Instruction_Call_TAG: - case Instruction_PrimOp_TAG: - case Instruction_Comment_TAG: - add_structural_dominance_edge(ctx, parent, let_tail, LetTailEdge); - return; - case Instruction_Block_TAG: - add_structural_dominance_edge(ctx, parent, instruction->payload.block.inside, StructuredEnterBodyEdge); - add_structural_dominance_edge(ctx, parent, let_tail, LetTailEdge); - return; - case Instruction_If_TAG: - add_structural_dominance_edge(ctx, parent, instruction->payload.if_instr.if_true, StructuredEnterBodyEdge); - if(instruction->payload.if_instr.if_false) - add_structural_dominance_edge(ctx, parent, instruction->payload.if_instr.if_false, StructuredEnterBodyEdge); - break; - case Instruction_Match_TAG: - for (size_t i = 0; i < instruction->payload.match_instr.cases.count; i++) - 
add_structural_dominance_edge(ctx, parent, instruction->payload.match_instr.cases.nodes[i], StructuredEnterBodyEdge); - add_structural_dominance_edge(ctx, parent, instruction->payload.match_instr.default_case, StructuredEnterBodyEdge); - break; - case Instruction_Loop_TAG: - add_structural_dominance_edge(ctx, parent, instruction->payload.loop_instr.body, StructuredEnterBodyEdge); - break; - case Instruction_Control_TAG: - add_structural_dominance_edge(ctx, parent, instruction->payload.control.inside, StructuredEnterBodyEdge); - const Node* param = first(get_abstraction_params(instruction->payload.control.inside)); - CFNode* let_tail_cfnode = get_or_enqueue(ctx, let_tail); - insert_dict(const Node*, CFNode*, ctx->join_point_values, param, let_tail_cfnode); - break; - } - add_structural_dominance_edge(ctx, parent, let_tail, StructuredPseudoExitEdge); -} - -static void process_cf_node(ScopeBuildContext* ctx, CFNode* node) { - const Node* const abs = node->node; - assert(is_abstraction(abs)); - assert(!is_function(abs) || abs == ctx->entry); - const Node* terminator = get_abstraction_body(abs); - if (!terminator) - return; - switch (is_terminator(terminator)) { - case LetMut_TAG: - case Let_TAG: { - const Node* target = get_let_tail(terminator); - process_instruction(ctx, node, get_let_instruction(terminator), target); - break; - } - case Jump_TAG: { - add_jump_edge(ctx, abs, terminator); - break; - } - case Branch_TAG: { - add_jump_edge(ctx, abs, terminator->payload.branch.true_jump); - add_jump_edge(ctx, abs, terminator->payload.branch.false_jump); - break; - } - case Switch_TAG: { - for (size_t i = 0; i < terminator->payload.br_switch.case_jumps.count; i++) - add_jump_edge(ctx, abs, terminator->payload.br_switch.case_jumps.nodes[i]); - add_jump_edge(ctx, abs, terminator->payload.br_switch.default_jump); - break; - } - case Join_TAG: { - CFNode** dst = find_value_dict(const Node*, CFNode*, ctx->join_point_values, terminator->payload.join.join_point); - if (dst) - 
add_edge(ctx, node->node, (*dst)->node, StructuredLeaveBodyEdge); - break; - } - case Yield_TAG: - case MergeContinue_TAG: - case MergeBreak_TAG: { - break; // TODO i guess - } - case TailCall_TAG: - case Return_TAG: - case Unreachable_TAG: break; - case NotATerminator: if (terminator->arena->config.check_types) { error("Grammar problem"); } break; - } -} - -/** - * Invert all edges in this scope. Used to compute a post dominance tree. - */ -static void flip_scope(Scope* scope) { - scope->entry = NULL; - - for (size_t i = 0; i < scope->size; i++) { - CFNode * cur = read_list(CFNode*, scope->contents)[i]; - - struct List* tmp = cur->succ_edges; - cur->succ_edges = cur->pred_edges; - cur->pred_edges = tmp; - - for (size_t j = 0; j < entries_count_list(cur->succ_edges); j++) { - CFEdge* edge = &read_list(CFEdge, cur->succ_edges)[j]; - - CFNode* tmp = edge->dst; - edge->dst = edge->src; - edge->src = tmp; - } - - for (size_t j = 0; j < entries_count_list(cur->pred_edges); j++) { - CFEdge* edge = &read_list(CFEdge, cur->pred_edges)[j]; - - CFNode* tmp = edge->dst; - edge->dst = edge->src; - edge->src = tmp; - } - - if (entries_count_list(cur->pred_edges) == 0) { - if (scope->entry != NULL) { - if (scope->entry->node) { - CFNode* new_entry = arena_alloc(scope->arena, sizeof(CFNode)); - *new_entry = (CFNode) { - .node = NULL, - .succ_edges = new_list(CFEdge), - .pred_edges = new_list(CFEdge), - .rpo_index = SIZE_MAX, - .idom = NULL, - .dominates = NULL, - }; - - CFEdge prev_entry_edge = { - .type = JumpEdge, - .src = new_entry, - .dst = scope->entry - }; - append_list(CFEdge, new_entry->succ_edges, prev_entry_edge); - append_list(CFEdge, scope->entry->pred_edges, prev_entry_edge); - scope->entry = new_entry; - } - - CFEdge new_edge = { - .type = JumpEdge, - .src = scope->entry, - .dst = cur - }; - append_list(CFEdge, scope->entry->succ_edges, new_edge); - append_list(CFEdge, cur->pred_edges, new_edge); - } else { - scope->entry = cur; - } - } - } - - if 
(!scope->entry->node) { - scope->size += 1; - append_list(Node*, scope->contents, scope->entry); - } -} - -static void validate_scope(Scope* scope) { - for (size_t i = 0; i < scope->size; i++) { - CFNode* node = read_list(CFNode*, scope->contents)[i]; - if (is_case(node->node)) { - size_t structured_body_uses = 0; - for (size_t j = 0; j < entries_count_list(node->pred_edges); j++) { - CFEdge edge = read_list(CFEdge, node->pred_edges)[j]; - switch (edge.type) { - case JumpEdge: - error_print("Error: cases cannot be jumped to directly."); - error_die(); - case LetTailEdge: - structured_body_uses += 1; - break; - case StructuredEnterBodyEdge: - structured_body_uses += 1; - break; - case StructuredPseudoExitEdge: - structured_body_uses += 1; - case StructuredLeaveBodyEdge: - break; - } - } - if (structured_body_uses != 1 && node != scope->entry /* this exception exists since we might build scopes rooted in cases */) { - error_print("reachable cases must be used be as bodies exactly once (actual uses: %zu)", structured_body_uses); - error_die(); - } - } - } -} - -Scope* new_scope_impl(const Node* entry, LoopTree* lt, bool flipped) { - assert(is_abstraction(entry)); - Arena* arena = new_arena(); - - ScopeBuildContext context = { - .arena = arena, - .entry = entry, - .lt = lt, - .nodes = new_dict(const Node*, CFNode*, (HashFn) hash_node, (CmpFn) compare_node), - .join_point_values = new_dict(const Node*, CFNode*, (HashFn) hash_node, (CmpFn) compare_node), - .queue = new_list(CFNode*), - .contents = new_list(CFNode*), - }; - - CFNode* entry_node = get_or_enqueue(&context, entry); - - while (entries_count_list(context.queue) > 0) { - CFNode* this = pop_last_list(CFNode*, context.queue); - process_cf_node(&context, this); - } - - destroy_list(context.queue); - destroy_dict(context.join_point_values); - - Scope* scope = calloc(sizeof(Scope), 1); - *scope = (Scope) { - .arena = arena, - .entry = entry_node, - .size = entries_count_list(context.contents), - .flipped = flipped, 
- .contents = context.contents, - .map = context.nodes, - .rpo = NULL - }; - - validate_scope(scope); - - if (flipped) - flip_scope(scope); - - compute_rpo(scope); - compute_domtree(scope); - - return scope; -} - -void destroy_scope(Scope* scope) { - bool entry_destroyed = false; - for (size_t i = 0; i < scope->size; i++) { - CFNode* node = read_list(CFNode*, scope->contents)[i]; - entry_destroyed |= node == scope->entry; - destroy_list(node->pred_edges); - destroy_list(node->succ_edges); - if (node->dominates) - destroy_list(node->dominates); - if (node->structurally_dominates) - destroy_dict(node->structurally_dominates); - } - if (!entry_destroyed) { - destroy_list(scope->entry->pred_edges); - destroy_list(scope->entry->succ_edges); - if (scope->entry->dominates) - destroy_list(scope->entry->dominates); - } - destroy_dict(scope->map); - destroy_arena(scope->arena); - free(scope->rpo); - destroy_list(scope->contents); - free(scope); -} - -static size_t post_order_visit(Scope* scope, CFNode* n, size_t i) { - n->rpo_index = -2; - - for (size_t j = 0; j < entries_count_list(n->succ_edges); j++) { - CFEdge edge = read_list(CFEdge, n->succ_edges)[j]; - if (edge.dst->rpo_index == SIZE_MAX) - i = post_order_visit(scope, edge.dst, i); - } - - n->rpo_index = i - 1; - scope->rpo[n->rpo_index] = n; - return n->rpo_index; -} - -void compute_rpo(Scope* scope) { - scope->rpo = malloc(sizeof(const CFNode*) * scope->size); - size_t index = post_order_visit(scope, scope->entry, scope->size); - assert(index == 0); - - // debug_print("RPO: "); - // for (size_t i = 0; i < scope->size; i++) { - // debug_print("%s, ", scope->rpo[i]->node->payload.lam.name); - // } - // debug_print("\n"); -} - -CFNode* least_common_ancestor(CFNode* i, CFNode* j) { - assert(i && j); - while (i->rpo_index != j->rpo_index) { - while (i->rpo_index < j->rpo_index) j = j->idom; - while (i->rpo_index > j->rpo_index) i = i->idom; - } - return i; -} - -void compute_domtree(Scope* scope) { - for (size_t i = 0; i 
< scope->size; i++) { - CFNode* n = read_list(CFNode*, scope->contents)[i]; - if (n == scope->entry) - continue; - for (size_t j = 0; j < entries_count_list(n->pred_edges); j++) { - CFEdge e = read_list(CFEdge, n->pred_edges)[j]; - CFNode* p = e.src; - if (p->rpo_index < n->rpo_index) { - n->idom = p; - goto outer_loop; - } - } - error("no idom found"); - outer_loop:; - } - - bool todo = true; - while (todo) { - todo = false; - for (size_t i = 0; i < scope->size; i++) { - CFNode* n = read_list(CFNode*, scope->contents)[i]; - if (n == scope->entry) - continue; - CFNode* new_idom = NULL; - for (size_t j = 0; j < entries_count_list(n->pred_edges); j++) { - CFEdge e = read_list(CFEdge, n->pred_edges)[j]; - CFNode* p = e.src; - new_idom = new_idom ? least_common_ancestor(new_idom, p) : p; - } - assert(new_idom); - if (n->idom != new_idom) { - n->idom = new_idom; - todo = true; - } - } - } - - for (size_t i = 0; i < scope->size; i++) { - CFNode* n = read_list(CFNode*, scope->contents)[i]; - n->dominates = new_list(CFNode*); - } - for (size_t i = 0; i < scope->size; i++) { - CFNode* n = read_list(CFNode*, scope->contents)[i]; - if (n == scope->entry) - continue; - append_list(CFNode*, n->idom->dominates, n); - } -} - -/** - * @param node: Start node. - * @param target: List to extend. @ref List of @ref CFNode* - */ -static void get_undominated_children(const CFNode* node, struct List* target) { - for (size_t i = 0; i < entries_count_list(node->succ_edges); i++) { - CFEdge edge = read_list(CFEdge, node->succ_edges)[i]; - - bool contained = false; - for (size_t j = 0; j < entries_count_list(node->dominates); j++) { - CFNode* dominated = read_list(CFNode*, node->dominates)[j]; - if (edge.dst == dominated) { - contained = true; - break; - } - } - if (!contained) - append_list(CFNode*, target, edge.dst); - } -} - -//TODO: this function can produce duplicates. 
-struct List* scope_get_dom_frontier(Scope* scope, const CFNode* node) { - struct List* dom_frontier = new_list(CFNode*); - - get_undominated_children(node, dom_frontier); - for (size_t i = 0; i < entries_count_list(node->dominates); i++) { - CFNode* dom = read_list(CFNode*, node->dominates)[i]; - get_undominated_children(dom, dom_frontier); - } - - return dom_frontier; -} - -static int extra_uniqueness = 0; - -static CFNode* get_let_pred(const CFNode* n) { - if (entries_count_list(n->pred_edges) == 1) { - CFEdge pred = read_list(CFEdge, n->pred_edges)[0]; - assert(pred.dst == n); - if (pred.type == LetTailEdge && entries_count_list(pred.src->succ_edges) == 1) { - assert(is_case(n->node)); - return pred.src; - } - } - return NULL; -} - -static void dump_cf_node(FILE* output, const CFNode* n) { - const Node* bb = n->node; - const Node* body = get_abstraction_body(bb); - if (!body) - return; - if (get_let_pred(n)) - return; - - String color = "black"; - if (is_case(bb)) - color = "green"; - else if (is_basic_block(bb)) - color = "blue"; - - String label = ""; - - const CFNode* let_chain_end = n; - while (body->tag == Let_TAG) { - const Node* instr = body->payload.let.instruction; - // label = ""; - if (instr->tag == PrimOp_TAG) - label = format_string_arena(bb->arena->arena, "%slet ... = %s (...)\n", label, get_primop_name(instr->payload.prim_op.op)); - else - label = format_string_arena(bb->arena->arena, "%slet ... 
= %s (...)\n", label, node_tags[instr->tag]); - - if (entries_count_list(let_chain_end->succ_edges) != 1 || read_list(CFEdge, let_chain_end->succ_edges)[0].type != LetTailEdge) - break; - - let_chain_end = read_list(CFEdge, let_chain_end->succ_edges)[0].dst; - const Node* abs = body->payload.let.tail; - assert(let_chain_end->node == abs); - assert(is_case(abs)); - body = get_abstraction_body(abs); - } - - label = format_string_arena(bb->arena->arena, "%s%s", label, node_tags[body->tag]); - - if (is_basic_block(bb)) { - label = format_string_arena(bb->arena->arena, "%s\n%s", get_abstraction_name(bb), label); - } - - fprintf(output, "bb_%zu [label=\"%s\", color=\"%s\", shape=box];\n", (size_t) n, label, color); - - for (size_t i = 0; i < entries_count_list(n->dominates); i++) { - CFNode* d = read_list(CFNode*, n->dominates)[i]; - if (!find_key_dict(const Node*, n->structurally_dominates, d->node)) - dump_cf_node(output, d); - } -} - -static void dump_cfg_scope(FILE* output, Scope* scope) { - extra_uniqueness++; - - const Node* entry = scope->entry->node; - fprintf(output, "subgraph cluster_%s {\n", get_abstraction_name(entry)); - fprintf(output, "label = \"%s\";\n", get_abstraction_name(entry)); - for (size_t i = 0; i < entries_count_list(scope->contents); i++) { - const CFNode* n = read_list(const CFNode*, scope->contents)[i]; - dump_cf_node(output, n); - } - for (size_t i = 0; i < entries_count_list(scope->contents); i++) { - const CFNode* bb_node = read_list(const CFNode*, scope->contents)[i]; - const CFNode* src_node = bb_node; - while (true) { - const CFNode* let_parent = get_let_pred(src_node); - if (let_parent) - src_node = let_parent; - else - break; - } - - for (size_t j = 0; j < entries_count_list(bb_node->succ_edges); j++) { - CFEdge edge = read_list(CFEdge, bb_node->succ_edges)[j]; - const CFNode* target_node = edge.dst; - - if (edge.type == LetTailEdge && get_let_pred(target_node) == bb_node) - continue; - - String edge_color = "black"; - switch 
(edge.type) { - case LetTailEdge: edge_color = "green"; break; - case StructuredEnterBodyEdge: edge_color = "blue"; break; - case StructuredLeaveBodyEdge: edge_color = "red"; break; - case StructuredPseudoExitEdge: edge_color = "darkred"; break; - default: break; - } - - fprintf(output, "bb_%zu -> bb_%zu [color=\"%s\"];\n", (size_t) (src_node), (size_t) (target_node), edge_color); - } - } - fprintf(output, "}\n"); -} - -void dump_cfg(FILE* output, Module* mod) { - if (output == NULL) - output = stderr; - - fprintf(output, "digraph G {\n"); - struct List* scopes = build_scopes(mod); - for (size_t i = 0; i < entries_count_list(scopes); i++) { - Scope* scope = read_list(Scope*, scopes)[i]; - dump_cfg_scope(output, scope); - destroy_scope(scope); - } - destroy_list(scopes); - fprintf(output, "}\n"); -} - -void dump_cfg_auto(Module* mod) { - FILE* f = fopen("cfg.dot", "wb"); - dump_cfg(f, mod); - fclose(f); -} diff --git a/src/shady/analysis/scope.h b/src/shady/analysis/scope.h deleted file mode 100644 index a4c87546d..000000000 --- a/src/shady/analysis/scope.h +++ /dev/null @@ -1,117 +0,0 @@ -#ifndef SHADY_SCOPE_H - -#include "shady/ir.h" - -typedef struct CFNode_ CFNode; - -typedef enum { - JumpEdge, - LetTailEdge, - StructuredEnterBodyEdge, - StructuredLeaveBodyEdge, - /// Join points might leak, and as a consequence, there might be no static edge to the - /// tail of the enclosing let, which would make it look like dead code. 
- /// This edge type accounts for that risk, they can be ignored where more precise info is available - /// (see is_control_static for example) - StructuredPseudoExitEdge, -} CFEdgeType; - -typedef struct { - CFEdgeType type; - CFNode* src; - CFNode* dst; -} CFEdge; - -struct CFNode_ { - const Node* node; - - /** @brief Edges where this node is the source - * - * @ref List of @ref CFEdge - */ - struct List* succ_edges; - - /** @brief Edges where this node is the destination - * - * @ref List of @ref CFEdge - */ - struct List* pred_edges; - - // set by compute_rpo - size_t rpo_index; - - // set by compute_domtree - CFNode* idom; - - /** @brief All Nodes directly dominated by this CFNode. - * - * @ref List of @ref CFNode* - */ - struct List* dominates; - struct Dict* structurally_dominates; -}; - -typedef struct Arena_ Arena; -typedef struct Scope_ { - Arena* arena; - size_t size; - bool flipped; - - /** - * @ref List of @ref CFNode* - */ - struct List* contents; - - /** - * @ref Dict from const @ref Node* to @ref CFNode* - */ - struct Dict* map; - - CFNode* entry; - // set by compute_rpo - CFNode** rpo; -} Scope; - -/** - * @returns @ref List of @ref Scope* - */ -struct List* build_scopes(Module*); - -typedef struct LoopTree_ LoopTree; - -/** Construct the scope stating in Node. - */ -Scope* new_scope_impl(const Node* entry, LoopTree* lt, bool flipped); - -#define new_scope_lt(node, lt) new_scope_impl(node, lt, false); -#define new_scope_lt_flipped(node, lt) new_scope_impl(node, lt, true); - -Scope* new_scope_lt_impl(const Node* entry, LoopTree* lt, bool flipped); - -/** Construct the scope starting in Node. - * Dominance will only be computed with respect to the nodes reachable by @p entry. - */ -#define new_scope(node) new_scope_impl(node, NULL, false); - -/** Construct the scope stating in Node. - * Dominance will only be computed with respect to the nodes reachable by @p entry. - * This scope will contain post dominance information instead of regular dominance! 
- */ -#define new_scope_flipped(node) new_scope_impl(node, NULL, true); - -CFNode* scope_lookup(Scope*, const Node* block); -void compute_rpo(Scope*); -void compute_domtree(Scope*); - -CFNode* least_common_ancestor(CFNode* i, CFNode* j); - -void destroy_scope(Scope*); - -/** - * @returns @ref List of @ref CFNode* - */ -struct List* scope_get_dom_frontier(Scope*, const CFNode* node); - -#define SHADY_SCOPE_H - -#endif diff --git a/src/shady/analysis/uses.c b/src/shady/analysis/uses.c index dfbdbbe65..8b5f953bb 100644 --- a/src/shady/analysis/uses.c +++ b/src/shady/analysis/uses.c @@ -2,14 +2,14 @@ #include "log.h" -#include "../visit.h" +#include "shady/visit.h" #include #include #include -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); struct UsesMap_ { struct Dict* map; @@ -25,7 +25,7 @@ typedef struct { } UsesMapVisitor; static Use* get_last_use(UsesMap* map, const Node* n) { - Use* use = (Use*) get_first_use(map, n); + Use* use = (Use*) shd_get_first_use(map, n); if (!use) return NULL; while (use->next_use) @@ -34,58 +34,75 @@ static Use* get_last_use(UsesMap* map, const Node* n) { return use; } -static void uses_visit_op(UsesMapVisitor* v, NodeClass class, String op_name, const Node* op) { - Use* use = arena_alloc(v->map->a, sizeof(Use)); +static void uses_visit_node(UsesMapVisitor* v, const Node* n) { + if (!shd_dict_find_key(const Node*, v->seen, n)) { + shd_set_insert_get_result(const Node*, v->seen, n); + UsesMapVisitor nv = *v; + nv.user = n; + shd_visit_node_operands(&nv.v, v->exclude, n); + } +} + +static void uses_visit_op(UsesMapVisitor* v, NodeClass class, String op_name, const Node* op, size_t i) { + Use* use = shd_arena_alloc(v->map->a, sizeof(Use)); memset(use, 0, sizeof(Use)); *use = (Use) { .user = v->user, .operand_class = class, .operand_name = op_name, - .next_use = NULL + .operand_index = i, + .next_use = NULL, }; Use* last_use = 
get_last_use(v->map, op); if (last_use) last_use->next_use = use; else - insert_dict(const Node*, const Use*, v->map->map, op, use); + shd_dict_insert(const Node*, const Use*, v->map->map, op, use); - if (!find_key_dict(const Node*, v->seen, op)) { - insert_set_get_result(const Node*, v->seen, op); - UsesMapVisitor nv = *v; - nv.user = op; - visit_node_operands(&nv.v, v->exclude, op); - } + uses_visit_node(v, op); } -const UsesMap* create_uses_map(const Node* root, NodeClass exclude) { +static const UsesMap* create_uses_map_(const Node* root, const Module* m, NodeClass exclude) { UsesMap* uses = calloc(sizeof(UsesMap), 1); *uses = (UsesMap) { - .map = new_dict(const Node*, Use*, (HashFn) hash_node, (CmpFn) compare_node), - .a = new_arena(), + .map = shd_new_dict(const Node*, Use*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .a = shd_new_arena(), }; UsesMapVisitor v = { .v = { .visit_op_fn = (VisitOpFn) uses_visit_op }, .map = uses, .exclude = exclude, - .seen = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node), - .user = root, + .seen = shd_new_set(const Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), }; - insert_set_get_result(const Node*, v.seen, root); - visit_node_operands(&v.v, exclude, root); - destroy_dict(v.seen); + if (root) + uses_visit_node(&v, root); + if (m) { + Nodes nodes = shd_module_get_declarations(m); + for (size_t i = 0; i < nodes.count; i++) + uses_visit_node(&v, nodes.nodes[i]); + } + shd_destroy_dict(v.seen); return uses; } -void destroy_uses_map(const UsesMap* map) { - destroy_arena(map->a); - destroy_dict(map->map); +const UsesMap* shd_new_uses_map_fn(const Node* root, NodeClass exclude) { + return create_uses_map_(root, NULL, exclude); +} + +const UsesMap* shd_new_uses_map_module(const Module* m, NodeClass exclude) { + return create_uses_map_(NULL, m, exclude); +} + +void shd_destroy_uses_map(const UsesMap* map) { + shd_destroy_arena(map->a); + shd_destroy_dict(map->map); free((void*) map); } -const Use* 
get_first_use(const UsesMap* map, const Node* n) { - const Use** found = find_value_dict(const Node*, const Use*, map->map, n); +const Use* shd_get_first_use(const UsesMap* map, const Node* n) { + const Use** found = shd_dict_find_value(const Node*, const Use*, map->map, n); if (found) return *found; return NULL; diff --git a/src/shady/analysis/uses.h b/src/shady/analysis/uses.h index 3ebd1424d..e6fc21140 100644 --- a/src/shady/analysis/uses.h +++ b/src/shady/analysis/uses.h @@ -2,24 +2,26 @@ #define SHADY_USAGES #include "shady/ir.h" -#include "scope.h" +#include "cfg.h" #include "list.h" #include "dict.h" #include "arena.h" typedef struct UsesMap_ UsesMap; -const UsesMap* create_uses_map(const Node* root, NodeClass exclude); -void destroy_uses_map(const UsesMap*); +const UsesMap* shd_new_uses_map_fn(const Node* root, NodeClass exclude); +const UsesMap* shd_new_uses_map_module(const Module* m, NodeClass exclude); +void shd_destroy_uses_map(const UsesMap* map); typedef struct Use_ Use; struct Use_ { const Node* user; NodeClass operand_class; String operand_name; + size_t operand_index; const Use* next_use; }; -const Use* get_first_use(const UsesMap*, const Node*); +const Use* shd_get_first_use(const UsesMap* map, const Node* n); #endif diff --git a/src/shady/analysis/verify.c b/src/shady/analysis/verify.c index d4b085510..cd5ec63ad 100644 --- a/src/shady/analysis/verify.c +++ b/src/shady/analysis/verify.c @@ -1,12 +1,13 @@ #include "verify.h" -#include "free_variables.h" -#include "scope.h" -#include "log.h" -#include "../visit.h" +#include "shady/visit.h" + +#include "free_frontier.h" +#include "cfg.h" #include "../ir_private.h" -#include "../type.h" +#include "../check.h" +#include "log.h" #include "dict.h" #include "list.h" @@ -20,52 +21,65 @@ typedef struct { static void visit_verify_same_arena(ArenaVerifyVisitor* visitor, const Node* node) { assert(visitor->arena == node->arena); - if (find_key_dict(const Node*, visitor->once, node)) + if 
(shd_dict_find_key(const Node*, visitor->once, node)) return; - insert_set_get_result(const Node*, visitor->once, node); - visit_node_operands(&visitor->visitor, 0, node); + shd_set_insert_get_result(const Node*, visitor->once, node); + shd_visit_node_operands(&visitor->visitor, 0, node); } -KeyHash hash_node(const Node**); -bool compare_node(const Node**, const Node**); +KeyHash shd_hash_node(const Node**); +bool shd_compare_node(const Node**, const Node**); static void verify_same_arena(Module* mod) { - const IrArena* arena = get_module_arena(mod); + const IrArena* arena = shd_module_get_arena(mod); ArenaVerifyVisitor visitor = { .visitor = { .visit_node_fn = (VisitNodeFn) visit_verify_same_arena, }, .arena = arena, - .once = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node) + .once = shd_new_set(const Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node) }; - visit_module(&visitor.visitor, mod); - destroy_dict(visitor.once); + shd_visit_module(&visitor.visitor, mod); + shd_destroy_dict(visitor.once); } -static void verify_scoping(Module* mod) { - struct List* scopes = build_scopes(mod); - for (size_t i = 0; i < entries_count_list(scopes); i++) { - Scope* scope = read_list(Scope*, scopes)[i]; - struct List* leaking = compute_free_variables(scope, scope->entry->node); - for (size_t j = 0; j < entries_count_list(leaking); j++) { - log_node(ERROR, read_list(const Node*, leaking)[j]); - error_print("\n"); +static void verify_scoping(const CompilerConfig* config, Module* mod) { + struct List* cfgs = shd_build_cfgs(mod, structured_scope_cfg_build()); + for (size_t i = 0; i < shd_list_count(cfgs); i++) { + CFG* cfg = shd_read_list(CFG*, cfgs)[i]; + Scheduler* scheduler = shd_new_scheduler(cfg); + struct Dict* set = shd_free_frontier(scheduler, cfg, cfg->entry->node); + if (shd_dict_count(set) > 0) { + shd_log_fmt(ERROR, "Leaking variables in "); + shd_log_node(ERROR, cfg->entry->node); + shd_log_fmt(ERROR, ":\n"); + + size_t j = 0; + const Node* leaking; 
+ while (shd_dict_iter(set, &j, &leaking, NULL)) { + shd_log_node(ERROR, leaking); + shd_error_print("\n"); + } + + shd_log_fmt(ERROR, "Problematic module:\n"); + shd_log_module(ERROR, config, mod); + shd_error_die(); } - assert(entries_count_list(leaking) == 0); - destroy_list(leaking); - destroy_scope(scope); + shd_destroy_dict(set); + shd_destroy_scheduler(scheduler); + shd_destroy_cfg(cfg); } - destroy_list(scopes); + shd_destroy_list(cfgs); } static void verify_nominal_node(const Node* fn, const Node* n) { switch (n->tag) { case Function_TAG: { - assert(!fn && "functions cannot be part of a scope, except as the entry"); + assert(!fn && "functions cannot be part of a CFG, except as the entry"); break; } case BasicBlock_TAG: { - assert(is_subtype(noret_type(n->arena), n->payload.basic_block.body->type)); + assert(shd_is_subtype(noret_type(n->arena), n->payload.basic_block.body->type)); break; } case NominalType_TAG: { @@ -73,18 +87,20 @@ static void verify_nominal_node(const Node* fn, const Node* n) { break; } case Constant_TAG: { - const Type* t = n->payload.constant.instruction->type; - bool u = deconstruct_qualified_type(&t); - assert(u); - assert(is_subtype(n->payload.constant.type_hint, t)); + if (n->payload.constant.value) { + const Type* t = n->payload.constant.value->type; + bool u = shd_deconstruct_qualified_type(&t); + assert(u); + assert(shd_is_subtype(n->payload.constant.type_hint, t)); + } break; } case GlobalVariable_TAG: { if (n->payload.global_variable.init) { const Type* t = n->payload.global_variable.init->type; - bool u = deconstruct_qualified_type(&t); + bool u = shd_deconstruct_qualified_type(&t); assert(u); - assert(is_subtype(n->payload.global_variable.type, t)); + assert(shd_is_subtype(n->payload.global_variable.type, t)); } break; } @@ -92,35 +108,63 @@ static void verify_nominal_node(const Node* fn, const Node* n) { } } -static void verify_bodies(Module* mod) { - struct List* scopes = build_scopes(mod); - for (size_t i = 0; i < 
entries_count_list(scopes); i++) { - Scope* scope = read_list(Scope*, scopes)[i]; +typedef struct ScheduleContext_ { + Visitor visitor; + struct Dict* bound; + struct ScheduleContext_* parent; + CompilerConfig* config; + Module* mod; +} ScheduleContext; + +static void verify_schedule_visitor(ScheduleContext* ctx, const Node* node) { + if (is_instruction(node)) { + ScheduleContext* search = ctx; + while (search) { + if (shd_dict_find_key(const Node*, search->bound, node)) + break; + search = search->parent; + } + if (!search) { + shd_log_fmt(ERROR, "Scheduling problem: "); + shd_log_node(ERROR, node); + shd_log_fmt(ERROR, "was encountered before we saw it be bound by a let!\n"); + shd_log_fmt(ERROR, "Problematic module:\n"); + shd_log_module(ERROR, ctx->config, ctx->mod); + shd_error_die(); + } + } + shd_visit_node_operands(&ctx->visitor, NcTerminator | NcDeclaration, node); +} + +static void verify_bodies(const CompilerConfig* config, Module* mod) { + struct List* cfgs = shd_build_cfgs(mod, structured_scope_cfg_build()); + for (size_t i = 0; i < shd_list_count(cfgs); i++) { + CFG* cfg = shd_read_list(CFG*, cfgs)[i]; - for (size_t j = 0; j < scope->size; j++) { - CFNode* n = scope->rpo[j]; + for (size_t j = 0; j < cfg->size; j++) { + CFNode* n = cfg->rpo[j]; if (n->node->tag == BasicBlock_TAG) { - verify_nominal_node(scope->entry->node, n->node); + verify_nominal_node(cfg->entry->node, n->node); } } - destroy_scope(scope); + shd_destroy_cfg(cfg); } - destroy_list(scopes); + shd_destroy_list(cfgs); - Nodes decls = get_module_declarations(mod); + Nodes decls = shd_module_get_declarations(mod); for (size_t i = 0; i < decls.count; i++) { const Node* decl = decls.nodes[i]; verify_nominal_node(NULL, decl); } } -void verify_module(Module* mod) { +void shd_verify_module(const CompilerConfig* config, Module* mod) { verify_same_arena(mod); // before we normalize the IR, scopes are broken because decls appear where they should not // TODO add a normalized flag to the IR and 
check grammar is adhered to strictly - if (get_module_arena(mod)->config.check_types) { - verify_scoping(mod); - verify_bodies(mod); + if (shd_module_get_arena(mod)->config.check_types) { + verify_scoping(config, mod); + verify_bodies(config, mod); } } diff --git a/src/shady/analysis/verify.h b/src/shady/analysis/verify.h index aedd20142..85fa53682 100644 --- a/src/shady/analysis/verify.h +++ b/src/shady/analysis/verify.h @@ -3,6 +3,7 @@ #include "shady/ir.h" -void verify_module(Module*); +typedef struct CompilerConfig_ CompilerConfig; +void shd_verify_module(const CompilerConfig* config, Module* mod); #endif diff --git a/src/shady/annotation.c b/src/shady/annotation.c deleted file mode 100644 index abec2141e..000000000 --- a/src/shady/annotation.c +++ /dev/null @@ -1,104 +0,0 @@ -#include "ir_private.h" -#include "log.h" -#include "portability.h" - -#include -#include - -String get_annotation_name(const Node* node) { - assert(is_annotation(node)); - switch (node->tag) { - case Annotation_TAG: return node->payload.annotation.name; - case AnnotationValue_TAG: return node->payload.annotation_value.name; - case AnnotationValues_TAG: return node->payload.annotation_values.name; - case AnnotationCompound_TAG: return node->payload.annotation_compound.name; - default: return false; - } -} - -static const Node* search_annotations(const Node* decl, const char* name, size_t* i) { - assert(decl); - const Nodes* annotations = NULL; - switch (decl->tag) { - case Function_TAG: annotations = &decl->payload.fun.annotations; break; - case GlobalVariable_TAG: annotations = &decl->payload.global_variable.annotations; break; - case Constant_TAG: annotations = &decl->payload.constant.annotations; break; - case NominalType_TAG: annotations = &decl->payload.nom_type.annotations; break; - default: error("Not a declaration") - } - - while (*i < annotations->count) { - const Node* annotation = annotations->nodes[*i]; - (*i)++; - if (strcmp(get_annotation_name(annotation), name) == 0) { - 
return annotation; - } - } - - return NULL; -} - -const Node* lookup_annotation(const Node* decl, const char* name) { - size_t i = 0; - return search_annotations(decl, name, &i); -} - -const Node* lookup_annotation_list(Nodes annotations, const char* name) { - for (size_t i = 0; i < annotations.count; i++) { - if (strcmp(get_annotation_name(annotations.nodes[i]), name) == 0) { - return annotations.nodes[i]; - } - } - return NULL; -} - -const Node* get_annotation_value(const Node* annotation) { - assert(annotation); - if (annotation->tag != AnnotationValue_TAG) - error("This annotation does not have a single payload"); - return annotation->payload.annotation_value.value; -} - -Nodes get_annotation_values(const Node* annotation) { - assert(annotation); - if (annotation->tag != AnnotationValues_TAG) - error("This annotation does not have multiple payloads"); - return annotation->payload.annotation_values.values; -} - -/// Gets the string literal attached to an annotation, if present. -const char* get_annotation_string_payload(const Node* annotation) { - const Node* payload = get_annotation_value(annotation); - if (!payload) return NULL; - if (payload->tag != StringLiteral_TAG) - error("Wrong annotation payload tag, expected a string literal") - return payload->payload.string_lit.string; -} - -bool lookup_annotation_with_string_payload(const Node* decl, const char* annotation_name, const char* expected_payload) { - size_t i = 0; - while (true) { - const Node* next = search_annotations(decl, annotation_name, &i); - if (!next) return false; - if (strcmp(get_annotation_string_payload(next), expected_payload) == 0) - return true; - } -} - -Nodes filter_out_annotation(IrArena* arena, Nodes annotations, const char* name) { - LARRAY(const Node*, new_annotations, annotations.count); - size_t new_count = 0; - for (size_t i = 0; i < annotations.count; i++) { - if (strcmp(get_annotation_name(annotations.nodes[i]), name) != 0) { - new_annotations[new_count++] = 
annotations.nodes[i]; - } - } - return nodes(arena, new_count, new_annotations); -} - -ExecutionModel execution_model_from_string(const char* string) { -#define EM(n, _) if (strcmp(string, #n) == 0) return Em##n; - EXECUTION_MODELS(EM) -#undef EM - return EmNone; -} diff --git a/src/shady/api/CMakeLists.txt b/src/shady/api/CMakeLists.txt index d993d4dcf..ed82a1e99 100644 --- a/src/shady/api/CMakeLists.txt +++ b/src/shady/api/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(api INTERFACE) -target_include_directories(api INTERFACE "$" "$" "$") +target_include_directories(api INTERFACE "$" "$" "$") target_include_directories(api INTERFACE "$") +target_link_libraries(api INTERFACE "$") get_target_property(SPIRV_HEADERS_INCLUDE_DIRS SPIRV-Headers::SPIRV-Headers INTERFACE_INCLUDE_DIRECTORIES) @@ -11,3 +12,20 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grammar_generated.h DESTINATION includ add_generated_file(FILE_NAME primops_generated.h TARGET_NAME generate-primops-headers SOURCES generator_primops.c) add_dependencies(api INTERFACE generate-primops-headers) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/primops_generated.h DESTINATION include) + +add_generated_file(FILE_NAME type_generated.h TARGET_NAME generate-type-headers SOURCES generator_type.c) +add_dependencies(api INTERFACE generate-type-headers) +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/type_generated.h DESTINATION include) + +install(TARGETS api EXPORT shady_export_set) + +find_package(Python COMPONENTS Interpreter REQUIRED) +function(generate_extinst_headers NAME SRC) + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.h COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/SPIRV-Headers/tools/buildHeaders/bin/generate_language_headers.py --extinst-name=${NAME} --extinst-grammar=${CMAKE_CURRENT_SOURCE_DIR}/${SRC} --extinst-output-base=${NAME} DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${SRC} VERBATIM) + set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/${NAME}.h PROPERTIES GENERATED TRUE) + 
add_custom_target("${NAME}_h" DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.h) + + add_library(${NAME} INTERFACE) + add_dependencies(${NAME} "${NAME}_h") + target_include_directories(${NAME} INTERFACE ${CMAKE_CURRENT_BINARY_DIR}) +endfunction() diff --git a/src/shady/api/generator_grammar.c b/src/shady/api/generator_grammar.c index e2783a4df..454627aba 100644 --- a/src/shady/api/generator_grammar.c +++ b/src/shady/api/generator_grammar.c @@ -1,73 +1,21 @@ #include "generator.h" -static json_object* lookup_node_class(Data data, String name) { - json_object* node_classes = json_object_object_get(data.shd, "node-classes"); - for (size_t i = 0; i < json_object_array_length(node_classes); i++) { - json_object* class = json_object_array_get_idx(node_classes, i); - String class_name = json_object_get_string(json_object_object_get(class, "name")); - assert(class_name); - if (strcmp(name, class_name) == 0) - return class; - } - return NULL; -} - -static String class_to_type(Data data, String class, bool list) { - assert(class); - if (strcmp(class, "string") == 0) { - if (list) - return "Strings"; - else - return "String"; - } - // check the class is valid - if (!lookup_node_class(data, class)) { - error_print("invalid node class '%s'\n", class); - error_die(); - } - return list ? 
"Nodes" : "const Node*"; -} - -static String get_type_for_operand(Data data, json_object* op) { - String op_type = json_object_get_string(json_object_object_get(op, "type")); - bool list = json_object_get_boolean(json_object_object_get(op, "list")); - String op_class = NULL; - if (!op_type) { - op_class = json_object_get_string(json_object_object_get(op, "class")); - op_type = class_to_type(data, op_class, list); - } - assert(op_type); - return op_type; -} - static void generate_address_spaces(Growy* g, json_object* address_spaces) { - growy_append_formatted(g, "typedef enum AddressSpace_ {\n"); + shd_growy_append_formatted(g, "typedef enum AddressSpace_ {\n"); for (size_t i = 0; i < json_object_array_length(address_spaces); i++) { json_object* as = json_object_array_get_idx(address_spaces, i); String name = json_object_get_string(json_object_object_get(as, "name")); add_comments(g, "\t", json_object_object_get(as, "description")); - growy_append_formatted(g, "\tAs%s,\n", name); + shd_growy_append_formatted(g, "\tAs%s,\n", name); } - growy_append_formatted(g, "\tNumAddressSpaces,\n"); - growy_append_formatted(g, "} AddressSpace;\n\n"); - - growy_append_formatted(g, "static inline bool is_physical_as(AddressSpace as) {\n"); - growy_append_formatted(g, "\tswitch(as) {\n"); - for (size_t i = 0; i < json_object_array_length(address_spaces); i++) { - json_object* as = json_object_array_get_idx(address_spaces, i); - String name = json_object_get_string(json_object_object_get(as, "name")); - if (json_object_get_boolean(json_object_object_get(as, "physical"))) - growy_append_formatted(g, "\t\tcase As%s: return true;\n", name); - } - growy_append_formatted(g, "\t\tdefault: return false;\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "}\n\n"); + shd_growy_append_formatted(g, "\tNumAddressSpaces,\n"); + shd_growy_append_formatted(g, "} AddressSpace;\n\n"); } static void generate_node_tags(Growy* g, json_object* nodes) { 
assert(json_object_get_type(nodes) == json_type_array); - growy_append_formatted(g, "typedef enum {\n"); - growy_append_formatted(g, "\tInvalidNode_TAG,\n"); + shd_growy_append_formatted(g, "typedef enum {\n"); + shd_growy_append_formatted(g, "\tInvalidNode_TAG,\n"); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -77,12 +25,12 @@ static void generate_node_tags(Growy* g, json_object* nodes) { if (!ops) add_comments(g, "\t", json_object_object_get(node, "description")); - growy_append_formatted(g, "\t%s_TAG,\n", name); + shd_growy_append_formatted(g, "\t%s_TAG,\n", name); } - growy_append_formatted(g, "} NodeTag;\n\n"); + shd_growy_append_formatted(g, "} NodeTag;\n\n"); } -static void generate_node_payloads(Growy* g, Data data, json_object* nodes) { +static void generate_node_payloads(Growy* g, json_object* src, json_object* nodes) { for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -93,23 +41,24 @@ static void generate_node_payloads(Growy* g, Data data, json_object* nodes) { if (ops) { assert(json_object_get_type(ops) == json_type_array); add_comments(g, "", json_object_object_get(node, "description")); - growy_append_formatted(g, "typedef struct {\n"); + shd_growy_append_formatted(g, "typedef struct SHADY_DESIGNATED_INIT {\n"); for (size_t j = 0; j < json_object_array_length(ops); j++) { json_object* op = json_object_array_get_idx(ops, j); String op_name = json_object_get_string(json_object_object_get(op, "name")); - growy_append_formatted(g, "\t%s %s;\n", get_type_for_operand(data, op), op_name); + shd_growy_append_formatted(g, "\t%s %s;\n", get_type_for_operand(src, op), op_name); } - growy_append_formatted(g, "} %s;\n\n", name); + shd_growy_append_formatted(g, "} %s;\n\n", name); } } } static void generate_node_type(Growy* g, json_object* nodes) { - growy_append_formatted(g, "struct Node_ {\n"); - 
growy_append_formatted(g, "\tIrArena* arena;\n"); - growy_append_formatted(g, "\tconst Type* type;\n"); - growy_append_formatted(g, "\tNodeTag tag;\n"); - growy_append_formatted(g, "\tunion NodesUnion {\n"); + shd_growy_append_formatted(g, "struct Node_ {\n"); + shd_growy_append_formatted(g, "\tIrArena* arena;\n"); + shd_growy_append_formatted(g, "\tNodeId id;\n"); + shd_growy_append_formatted(g, "\tconst Type* type;\n"); + shd_growy_append_formatted(g, "\tNodeTag tag;\n"); + shd_growy_append_formatted(g, "\tunion NodesUnion {\n"); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -118,30 +67,24 @@ static void generate_node_type(Growy* g, json_object* nodes) { assert(name); String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); - void* alloc = NULL; - if (!snake_name) { - alloc = snake_name = to_snake_case(name); - } + assert(snake_name); json_object* ops = json_object_object_get(node, "ops"); if (ops) - growy_append_formatted(g, "\t\t%s %s;\n", name, snake_name); - - if (alloc) - free(alloc); + shd_growy_append_formatted(g, "\t\t%s %s;\n", name, snake_name); } - growy_append_formatted(g, "\t} payload;\n"); - growy_append_formatted(g, "};\n\n"); + shd_growy_append_formatted(g, "\t} payload;\n"); + shd_growy_append_formatted(g, "};\n\n"); } static void generate_node_tags_for_class(Growy* g, json_object* nodes, String class, String capitalized_class) { assert(json_object_get_type(nodes) == json_type_array); - growy_append_formatted(g, "typedef enum {\n"); + shd_growy_append_formatted(g, "typedef enum {\n"); if (starts_with_vowel(class)) - growy_append_formatted(g, "\tNotAn%s = 0,\n", capitalized_class); + shd_growy_append_formatted(g, "\tNotAn%s = 0,\n", capitalized_class); else - growy_append_formatted(g, "\tNotA%s = 0,\n", capitalized_class); + shd_growy_append_formatted(g, "\tNotA%s = 0,\n", capitalized_class); for (size_t i = 0; i < 
json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -153,12 +96,12 @@ static void generate_node_tags_for_class(Growy* g, json_object* nodes, String cl break; case json_type_string: if (nclass && strcmp(json_object_get_string(nclass), class) == 0) - growy_append_formatted(g, "\t%s_%s_TAG = %s_TAG,\n", capitalized_class, name, name); + shd_growy_append_formatted(g, "\t%s_%s_TAG = %s_TAG,\n", capitalized_class, name, name); break; case json_type_array: { for (size_t j = 0; j < json_object_array_length(nclass); j++) { if (nclass && strcmp(json_object_get_string(json_object_array_get_idx(nclass, j)), class) == 0) { - growy_append_formatted(g, "\t%s_%s_TAG = %s_TAG,\n", capitalized_class, name, name); + shd_growy_append_formatted(g, "\t%s_%s_TAG = %s_TAG,\n", capitalized_class, name, name); break; } } @@ -168,40 +111,173 @@ static void generate_node_tags_for_class(Growy* g, json_object* nodes, String cl case json_type_double: case json_type_int: case json_type_object: - error_print("Invalid datatype for a node's 'class' attribute"); + shd_error_print("Invalid datatype for a node's 'class' attribute"); + } + + } + shd_growy_append_formatted(g, "} %sTag;\n\n", capitalized_class); +} + +static void generate_isa_for_class(Growy* g, json_object* nodes, String class, String capitalized_class, bool use_enum) { + assert(json_object_get_type(nodes) == json_type_array); + if (use_enum) + shd_growy_append_formatted(g, "static inline %sTag is_%s(const Node* node) {\n", capitalized_class, class); + else + shd_growy_append_formatted(g, "static inline bool is_%s(const Node* node) {\n", class); + shd_growy_append_formatted(g, "\tif (shd_get_node_class_from_tag(node->tag) & Nc%s)\n", capitalized_class); + if (use_enum) { + shd_growy_append_formatted(g, "\t\treturn (%sTag) node->tag;\n", capitalized_class); + shd_growy_append_formatted(g, "\treturn (%sTag) 0;\n", capitalized_class); + } else { + shd_growy_append_formatted(g, "\t\treturn 
true;\n", capitalized_class); + shd_growy_append_formatted(g, "\treturn false;\n", capitalized_class); + } + shd_growy_append_formatted(g, "}\n\n"); +} + +static void generate_header_getters_for_class(Growy* g, json_object* src, json_object* node_class) { + String class_name = json_object_get_string(json_object_object_get(node_class, "name")); + json_object* class_ops = json_object_object_get(node_class, "ops"); + if (!class_ops) + return; + assert(json_object_get_type(class_ops) == json_type_array); + for (size_t i = 0; i < json_object_array_length(class_ops); i++) { + json_object* operand = json_object_array_get_idx(class_ops, i); + String operand_name = json_object_get_string(json_object_object_get(operand, "name")); + assert(operand_name); + shd_growy_append_formatted(g, "%s get_%s_%s(const Node* node);\n", get_type_for_operand(src, operand), class_name, operand_name); + } +} + +void generate_node_ctor(Growy* g, json_object* src, json_object* nodes) { + for (size_t i = 0; i < json_object_array_length(nodes); i++) { + json_object* node = json_object_array_get_idx(nodes, i); + + String name = json_object_get_string(json_object_object_get(node, "name")); + assert(name); + + if (has_custom_ctor(node)) + continue; + + if (i > 0) + shd_growy_append_formatted(g, "\n"); + + String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); + const void* alloc = NULL; + if (!snake_name) { + alloc = snake_name = to_snake_case(name); + } + + json_object* ops = json_object_object_get(node, "ops"); + if (ops) + shd_growy_append_formatted(g, "static inline const Node* %s(IrArena* arena, %s payload)", snake_name, name); + else + shd_growy_append_formatted(g, "static inline const Node* %s(IrArena* arena)", snake_name); + + shd_growy_append_formatted(g, " {\n"); + shd_growy_append_formatted(g, "\tNode node;\n"); + shd_growy_append_formatted(g, "\tmemset((void*) &node, 0, sizeof(Node));\n"); + shd_growy_append_formatted(g, "\tnode = (Node) {\n"); + 
shd_growy_append_formatted(g, "\t\t.arena = arena,\n"); + shd_growy_append_formatted(g, "\t\t.tag = %s_TAG,\n", name); + if (ops) + shd_growy_append_formatted(g, "\t\t.payload.%s = payload,\n", snake_name); + shd_growy_append_formatted(g, "\t\t.type = NULL,\n"); + shd_growy_append_formatted(g, "\t};\n"); + shd_growy_append_formatted(g, "\treturn _shd_create_node_helper(arena, node, NULL);\n"); + shd_growy_append_formatted(g, "}\n"); + + // Generate helper variant + if (ops) { + shd_growy_append_formatted(g, "static inline const Node* %s_helper(IrArena* arena, ", snake_name); + for (size_t j = 0; j < json_object_array_length(ops); j++) { + json_object* op = json_object_array_get_idx(ops, j); + String op_name = json_object_get_string(json_object_object_get(op, "name")); + shd_growy_append_formatted(g, "\t%s %s", get_type_for_operand(src, op), op_name); + if (j + 1 < json_object_array_length(ops)) + shd_growy_append_formatted(g, ", "); + } + shd_growy_append_formatted(g, ") {\n"); + shd_growy_append_formatted(g, "\treturn %s(arena, (%s) {", snake_name, name); + for (size_t j = 0; j < json_object_array_length(ops); j++) { + json_object* op = json_object_array_get_idx(ops, j); + String op_name = json_object_get_string(json_object_object_get(op, "name")); + shd_growy_append_formatted(g, ".%s = %s", op_name, op_name); + if (j + 1 < json_object_array_length(ops)) + shd_growy_append_formatted(g, ", "); + } + shd_growy_append_formatted(g, "});\n"); + shd_growy_append_formatted(g, "}\n"); } + if (alloc) + free((void*) alloc); } - growy_append_formatted(g, "} %sTag;\n\n", capitalized_class); + shd_growy_append_formatted(g, "\n"); } -void generate(Growy* g, Data data) { - generate_header(g, data); +static void generate_getters_for_class(Growy* g, json_object* src, json_object* nodes, json_object* node_class) { + String class_name = json_object_get_string(json_object_object_get(node_class, "name")); + json_object* class_ops = json_object_object_get(node_class, "ops"); + if 
(!class_ops) + return; + assert(json_object_get_type(class_ops) == json_type_array); + for (size_t i = 0; i < json_object_array_length(class_ops); i++) { + json_object* operand = json_object_array_get_idx(class_ops, i); + String operand_name = json_object_get_string(json_object_object_get(operand, "name")); + assert(operand_name); + shd_growy_append_formatted(g, "static inline %s get_%s_%s(const Node* node) {\n", get_type_for_operand(src, operand), class_name, operand_name); + shd_growy_append_formatted(g, "\tswitch(node->tag) {\n"); + for (size_t j = 0; j < json_object_array_length(nodes); j++) { + json_object* node = json_object_array_get_idx(nodes, j); + if (find_in_set(json_object_object_get(node, "class"), class_name)) { + String node_name = json_object_get_string(json_object_object_get(node, "name")); + shd_growy_append_formatted(g, "\t\tcase %s_TAG: ", node_name); + String node_snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); + assert(node_snake_name); + shd_growy_append_formatted(g, "return node->payload.%s.%s;\n", node_snake_name, operand_name); + } + } + shd_growy_append_formatted(g, "\t\tdefault: break;\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "\tassert(false);\n"); + shd_growy_append_formatted(g, "}\n\n"); + } +} - generate_address_spaces(g, json_object_object_get(data.shd, "address-spaces")); +void generate(Growy* g, json_object* src) { + generate_header(g, src); - json_object* node_classes = json_object_object_get(data.shd, "node-classes"); + generate_address_spaces(g, json_object_object_get(src, "address-spaces")); + + json_object* node_classes = json_object_object_get(src, "node-classes"); generate_bit_enum(g, "NodeClass", "Nc", node_classes); - json_object* nodes = json_object_object_get(data.shd, "nodes"); + json_object* nodes = json_object_object_get(src, "nodes"); generate_node_tags(g, nodes); - growy_append_formatted(g, "NodeClass get_node_class_from_tag(NodeTag tag);\n\n"); 
- generate_node_payloads(g, data, nodes); + shd_growy_append_formatted(g, "NodeClass shd_get_node_class_from_tag(NodeTag tag);\n\n"); + generate_node_payloads(g, src, nodes); generate_node_type(g, nodes); - generate_node_ctor(g, nodes, false); + + shd_growy_append_formatted(g, "#include \n"); + shd_growy_append_formatted(g, "#include \n"); + shd_growy_append_formatted(g, "Node* _shd_create_node_helper(IrArena* arena, Node node, bool* pfresh);\n"); + generate_node_ctor(g, src, nodes); for (size_t i = 0; i < json_object_array_length(node_classes); i++) { json_object* node_class = json_object_array_get_idx(node_classes, i); String name = json_object_get_string(json_object_object_get(node_class, "name")); assert(name); json_object* generate_enum = json_object_object_get(node_class, "generate-enum"); + String capitalized = capitalize(name); + if (!generate_enum || json_object_get_boolean(generate_enum)) { - String capitalized = capitalize(name); generate_node_tags_for_class(g, nodes, name, capitalized); - growy_append_formatted(g, "%sTag is_%s(const Node*);\n", capitalized, name); - free(capitalized); - } else { - growy_append_formatted(g, "bool is_%s(const Node*);\n", name); } + + //generate_header_getters_for_class(g, src, node_class); + generate_getters_for_class(g, src, nodes, node_class); + generate_isa_for_class(g, nodes, name, capitalized, !generate_enum || json_object_get_boolean(generate_enum)); + free((void*) capitalized); } } diff --git a/src/shady/api/generator_primops.c b/src/shady/api/generator_primops.c index a2910f1d3..7805ba4b4 100644 --- a/src/shady/api/generator_primops.c +++ b/src/shady/api/generator_primops.c @@ -1,10 +1,10 @@ #include "generator.h" -void generate(Growy* g, Data data) { - generate_header(g, data); +void generate(Growy* g, json_object* src) { + generate_header(g, src); - json_object* nodes = json_object_object_get(data.shd, "prim-ops"); - growy_append_formatted(g, "typedef enum Op_ {\n"); + json_object* nodes = 
json_object_object_get(src, "prim-ops"); + shd_growy_append_formatted(g, "typedef enum Op_ {\n"); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -12,13 +12,12 @@ void generate(Growy* g, Data data) { String name = json_object_get_string(json_object_object_get(node, "name")); assert(name); - growy_append_formatted(g, "\t%s_op,\n", name); + shd_growy_append_formatted(g, "\t%s_op,\n", name); } - growy_append_formatted(g, "\tPRIMOPS_COUNT,\n"); - growy_append_formatted(g, "} Op;\n"); + shd_growy_append_formatted(g, "\tPRIMOPS_COUNT,\n"); + shd_growy_append_formatted(g, "} Op;\n"); - json_object* op_classes = json_object_object_get(data.shd, "prim-ops-classes"); + json_object* op_classes = json_object_object_get(src, "prim-ops-classes"); generate_bit_enum(g, "OpClass", "Oc", op_classes); - growy_append_formatted(g, "OpClass get_primop_class(Op);\n\n"); } diff --git a/src/shady/api/generator_type.c b/src/shady/api/generator_type.c new file mode 100644 index 000000000..8445c591b --- /dev/null +++ b/src/shady/api/generator_type.c @@ -0,0 +1,34 @@ +#include "generator.h" + +void generate(Growy* g, json_object* src) { + generate_header(g, src); + + json_object* nodes = json_object_object_get(src, "nodes"); + for (size_t i = 0; i < json_object_array_length(nodes); i++) { + json_object* node = json_object_array_get_idx(nodes, i); + + String name = json_object_get_string(json_object_object_get(node, "name")); + assert(name); + + String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); + void* alloc = NULL; + if (!snake_name) { + snake_name = to_snake_case(name); + alloc = (void*) snake_name; + } + + json_object* t = json_object_object_get(node, "type"); + if (!t || json_object_get_boolean(t)) { + json_object* ops = json_object_object_get(node, "ops"); + if (ops) + shd_growy_append_formatted(g, "const Type* _shd_check_type_%s(IrArena*, %s);\n", snake_name, name); + else + 
shd_growy_append_formatted(g, "const Type* _shd_check_type_%s(IrArena*);\n", snake_name); + } + + if (alloc) + free(alloc); + } + + shd_growy_append_formatted(g, "const Type* _shd_check_type_generated(IrArena* a, const Node* node);\n"); +} diff --git a/src/shady/body_builder.c b/src/shady/body_builder.c index 890a4ce62..0625ee994 100644 --- a/src/shady/body_builder.c +++ b/src/shady/body_builder.c @@ -1,152 +1,379 @@ -#include "ir_private.h" -#include "log.h" -#include "portability.h" -#include "type.h" +#include "shady/ir/builder.h" +#include "shady/ir/grammar.h" +#include "shady/ir/function.h" +#include "shady/ir/composite.h" +#include "shady/ir/arena.h" +#include "shady/ir/mem.h" #include "list.h" #include "dict.h" +#include "log.h" +#include "portability.h" #include #include +#pragma GCC diagnostic error "-Wswitch" + +struct BodyBuilder_ { + IrArena* arena; + struct List* stack; + const Node* block_entry_block; + const Node* block_entry_mem; + const Node* mem; + Node* tail_block; +}; + typedef struct { - const Node* instr; + Structured_constructTag tag; + union NodesUnion payload; +} BlockEntry; + +typedef struct { + BlockEntry structured; Nodes vars; - bool mut; } StackEntry; -BodyBuilder* begin_body(IrArena* a) { +BodyBuilder* shd_bld_begin(IrArena* a, const Node* mem) { BodyBuilder* bb = malloc(sizeof(BodyBuilder)); *bb = (BodyBuilder) { .arena = a, - .stack = new_list(StackEntry), + .stack = shd_new_list(StackEntry), + .mem = mem, }; return bb; } -static Nodes create_output_variables(IrArena* a, const Node* value, size_t outputs_count, const Node** output_types, String const output_names[]) { - Nodes types; - if (a->config.check_types) { - types = unwrap_multiple_yield_types(a, value->type); - // outputs count has to match or not be given - assert(outputs_count == types.count || outputs_count == SIZE_MAX); - if (output_types) { - // Check that the types we got are subtypes of what we care about - for (size_t i = 0; i < types.count; i++) - 
assert(is_subtype(output_types[i], types.nodes[i])); - types = nodes(a, outputs_count, output_types); - } - outputs_count = types.count; - } else { - assert(outputs_count != SIZE_MAX); - if (output_types) { - types = nodes(a, outputs_count, output_types); - } else { - LARRAY(const Type*, nulls, outputs_count); - for (size_t i = 0; i < outputs_count; i++) - nulls[i] = NULL; - types = nodes(a, outputs_count, nulls); - } - } +BodyBuilder* shd_bld_begin_pseudo_instr(IrArena* a, const Node* mem) { + Node* block = basic_block(a, shd_empty(a), NULL); + BodyBuilder* builder = shd_bld_begin(a, shd_get_abstraction_mem(block)); + builder->tail_block = block; + builder->block_entry_block = block; + builder->block_entry_mem = mem; + return builder; +} - LARRAY(Node*, vars, types.count); - for (size_t i = 0; i < types.count; i++) { - String var_name = output_names ? output_names[i] : NULL; - vars[i] = (Node*) var(a, types.nodes[i], var_name); - } +BodyBuilder* shd_bld_begin_pure(IrArena* a) { + BodyBuilder* builder = shd_bld_begin(a, NULL); + return builder; +} - // for (size_t i = 0; i < outputs_count; i++) { - // vars[i]->payload.var.instruction = value; - // vars[i]->payload.var.output = i; - // } - return nodes(a, outputs_count, (const Node**) vars); +IrArena* shd_get_bb_arena(BodyBuilder* bb) { + return bb->arena; } -static Nodes bind_internal(BodyBuilder* bb, const Node* instruction, bool mut, size_t outputs_count, const Node** provided_types, String const output_names[]) { - if (bb->arena->config.check_types) { - assert(is_instruction(instruction)); - } - Nodes params = create_output_variables(bb->arena, instruction, outputs_count, provided_types, output_names); - StackEntry entry = { - .instr = instruction, - .vars = params, - .mut = mut, - }; - append_list(StackEntry, bb->stack, entry); - return params; +const Node* _shd_bb_insert_mem(BodyBuilder* bb) { + return bb->block_entry_mem; } -Nodes bind_instruction(BodyBuilder* bb, const Node* instruction) { - 
assert(bb->arena->config.check_types); - return bind_internal(bb, instruction, false, SIZE_MAX, NULL, NULL); +const Node* _shd_bb_insert_block(BodyBuilder* bb) { + return bb->block_entry_block; } -Nodes bind_instruction_named(BodyBuilder* bb, const Node* instruction, String const output_names[]) { - assert(bb->arena->config.check_types); - assert(output_names); - return bind_internal(bb, instruction, false, SIZE_MAX, NULL, output_names); +const Node* shd_bb_mem(BodyBuilder* bb) { + return bb->mem; } -Nodes bind_instruction_explicit_result_types(BodyBuilder* bb, const Node* instruction, Nodes provided_types, String const output_names[], bool mut) { - return bind_internal(bb, instruction, mut, provided_types.count, provided_types.nodes, output_names); +static Nodes bind_internal(BodyBuilder* bb, const Node* instruction, size_t outputs_count) { + if (shd_get_arena_config(bb->arena)->check_types) { + assert(is_mem(instruction)); + } + if (is_mem(instruction) && /* avoid things like ExtInstr with null mem input! 
*/ shd_get_parent_mem(instruction)) + bb->mem = instruction; + return shd_deconstruct_composite(bb->arena, instruction, outputs_count); } -Nodes bind_instruction_outputs_count(BodyBuilder* bb, const Node* instruction, size_t outputs_count, String const output_names[], bool mut) { - return bind_internal(bb, instruction, mut, outputs_count, NULL, output_names); +const Node* shd_bld_add_instruction(BodyBuilder* bb, const Node* instr) { + return shd_first(shd_bld_add_instruction_extract_count(bb, instr, 1)); } -void bind_variables(BodyBuilder* bb, Nodes vars, Nodes values) { - StackEntry entry = { - .instr = quote_helper(bb->arena, values), - .vars = vars, - .mut = false, - }; - append_list(StackEntry, bb->stack, entry); +Nodes shd_bld_add_instruction_extract(BodyBuilder* bb, const Node* instruction) { + assert(shd_get_arena_config(bb->arena)->check_types); + return bind_internal(bb, instruction, shd_singleton(instruction->type).count); } -const Node* finish_body(BodyBuilder* bb, const Node* terminator) { - size_t stack_size = entries_count_list(bb->stack); +Nodes shd_bld_add_instruction_extract_count(BodyBuilder* bb, const Node* instruction, size_t outputs_count) { + return bind_internal(bb, instruction, outputs_count); +} + +static const Node* build_body(BodyBuilder* bb, const Node* terminator) { + IrArena* a = bb->arena; + size_t stack_size = shd_list_count(bb->stack); for (size_t i = stack_size - 1; i < stack_size; i--) { - StackEntry entry = read_list(StackEntry, bb->stack)[i]; - const Node* lam = case_(bb->arena, entry.vars, terminator); - terminator = (entry.mut ? 
let_mut : let)(bb->arena, entry.instr, lam); + StackEntry entry = shd_read_list(StackEntry, bb->stack)[i]; + const Node* t2 = terminator; + switch (entry.structured.tag) { + case NotAStructured_construct: shd_error("") + case Structured_construct_If_TAG: { + terminator = if_instr(a, entry.structured.payload.if_instr); + break; + } + case Structured_construct_Match_TAG: { + terminator = match_instr(a, entry.structured.payload.match_instr); + break; + } + case Structured_construct_Loop_TAG: { + terminator = loop_instr(a, entry.structured.payload.loop_instr); + break; + } + case Structured_construct_Control_TAG: { + terminator = control(a, entry.structured.payload.control); + break; + } + } + shd_set_abstraction_body((Node*) get_structured_construct_tail(terminator), t2); } + return terminator; +} - destroy_list(bb->stack); +const Node* shd_bld_finish(BodyBuilder* bb, const Node* terminator) { + assert(bb->mem && !bb->block_entry_mem); + terminator = build_body(bb, terminator); + shd_destroy_list(bb->stack); free(bb); return terminator; } -const Node* yield_values_and_wrap_in_block_explicit_return_types(BodyBuilder* bb, Nodes values, const Nodes* types) { - IrArena* arena = bb->arena; - assert(arena->config.check_types || types); - const Node* terminator = yield(arena, (Yield) { .args = values }); - const Node* lam = case_(arena, empty(arena), finish_body(bb, terminator)); - return block(arena, (Block) { - .yield_types = arena->config.check_types ? 
get_values_types(arena, values) : *types, - .inside = lam, +const Node* shd_bld_return(BodyBuilder* bb, Nodes args) { + return shd_bld_finish(bb, fn_ret(bb->arena, (Return) { + .args = args, + .mem = shd_bb_mem(bb) + })); +} + +const Node* shd_bld_unreachable(BodyBuilder* bb) { + return shd_bld_finish(bb, unreachable(bb->arena, (Unreachable) { + .mem = shd_bb_mem(bb) + })); +} + +const Node* shd_bld_selection_merge(BodyBuilder* bb, Nodes args) { + return shd_bld_finish(bb, merge_selection(bb->arena, (MergeSelection) { + .args = args, + .mem = shd_bb_mem(bb), + })); +} + +const Node* shd_bld_loop_continue(BodyBuilder* bb, Nodes args) { + return shd_bld_finish(bb, merge_continue(bb->arena, (MergeContinue) { + .args = args, + .mem = shd_bb_mem(bb), + })); +} + +const Node* shd_bld_loop_break(BodyBuilder* bb, Nodes args) { + return shd_bld_finish(bb, merge_break(bb->arena, (MergeBreak) { + .args = args, + .mem = shd_bb_mem(bb), + })); +} + +const Node* shd_bld_join(BodyBuilder* bb, const Node* jp, Nodes args) { + return shd_bld_finish(bb, join(bb->arena, (Join) { + .join_point = jp, + .args = args, + .mem = shd_bb_mem(bb), + })); +} + +const Node* shd_bld_jump(BodyBuilder* bb, const Node* target, Nodes args) { + return shd_bld_finish(bb, jump(bb->arena, (Jump) { + .target = target, + .args = args, + .mem = shd_bb_mem(bb), + })); +} + +const Node* shd_bld_to_instr_yield_value(BodyBuilder* bb, const Node* value) { + IrArena* a = bb->arena; + if (!bb->tail_block && shd_list_count(bb->stack) == 0) { + const Node* last_mem = shd_bb_mem(bb); + shd_bld_cancel(bb); + if (last_mem) + return mem_and_value(a, (MemAndValue) { + .mem = last_mem, + .value = value + }); + return value; + } + assert(bb->block_entry_mem && "This builder wasn't started with 'shd_bld_begin_pure' or 'shd_bld_begin_pseudo_instr'"); + bb->tail_block->payload.basic_block.insert = bb; + const Node* r = mem_and_value(bb->arena, (MemAndValue) { + .mem = shd_bb_mem(bb), + .value = value }); + return r; +} + 
+const Node* shd_bld_to_instr_yield_values(BodyBuilder* bb, Nodes values) { + return shd_bld_to_instr_yield_value(bb, shd_maybe_tuple_helper(bb->arena, values)); } -const Node* yield_values_and_wrap_in_block(BodyBuilder* bb, Nodes values) { - return yield_values_and_wrap_in_block_explicit_return_types(bb, values, NULL); +const Node* _shd_bld_finish_pseudo_instr(BodyBuilder* bb, const Node* terminator) { + assert(bb->block_entry_mem); + terminator = build_body(bb, terminator); + shd_destroy_list(bb->stack); + free(bb); + return terminator; } -const Node* bind_last_instruction_and_wrap_in_block_explicit_return_types(BodyBuilder* bb, const Node* instruction, const Nodes* types) { - size_t stack_size = entries_count_list(bb->stack); +const Node* shd_bld_to_instr_with_last_instr(BodyBuilder* bb, const Node* instruction) { + size_t stack_size = shd_list_count(bb->stack); if (stack_size == 0) { - cancel_body(bb); + shd_bld_cancel(bb); return instruction; } - Nodes bound = bind_internal(bb, instruction, false, types ? types->count : SIZE_MAX, types ? 
types->nodes : NULL, NULL); - return yield_values_and_wrap_in_block_explicit_return_types(bb, bound, types); + bind_internal(bb, instruction, 0); + return shd_bld_to_instr_yield_value(bb, instruction); +} + +const Node* shd_bld_to_instr_pure_with_values(BodyBuilder* bb, Nodes values) { + IrArena* arena = bb->arena; + assert(!bb->mem && !bb->block_entry_mem && shd_list_count(bb->stack) == 0); + shd_bld_cancel(bb); + return shd_maybe_tuple_helper(arena, values); +} + +static Nodes gen_variables(BodyBuilder* bb, Nodes yield_types) { + IrArena* a = bb->arena; + + Nodes qyield_types = shd_add_qualifiers(a, yield_types, false); + LARRAY(const Node*, tail_params, yield_types.count); + for (size_t i = 0; i < yield_types.count; i++) + tail_params[i] = param(a, qyield_types.nodes[i], NULL); + return shd_nodes(a, yield_types.count, tail_params); +} + +static Nodes add_structured_construct(BodyBuilder* bb, Nodes params, Structured_constructTag tag, union NodesUnion payload) { + Node* tail = basic_block(bb->arena, params, NULL); + StackEntry entry = { + .structured = { + .tag = tag, + .payload = payload, + }, + .vars = params, + }; + switch (entry.structured.tag) { + case NotAStructured_construct: shd_error("") + case Structured_construct_If_TAG: { + entry.structured.payload.if_instr.tail = tail; + entry.structured.payload.if_instr.mem = shd_bb_mem(bb); + break; + } + case Structured_construct_Match_TAG: { + entry.structured.payload.match_instr.tail = tail; + entry.structured.payload.match_instr.mem = shd_bb_mem(bb); + break; + } + case Structured_construct_Loop_TAG: { + entry.structured.payload.loop_instr.tail = tail; + entry.structured.payload.loop_instr.mem = shd_bb_mem(bb); + break; + } + case Structured_construct_Control_TAG: { + entry.structured.payload.control.tail = tail; + entry.structured.payload.control.mem = shd_bb_mem(bb); + break; + } + } + bb->mem = shd_get_abstraction_mem(tail); + shd_list_append(StackEntry , bb->stack, entry); + bb->tail_block = tail; + return 
entry.vars; } -const Node* bind_last_instruction_and_wrap_in_block(BodyBuilder* bb, const Node* instruction) { - return bind_last_instruction_and_wrap_in_block_explicit_return_types(bb, instruction, NULL); +static Nodes gen_structured_construct(BodyBuilder* bb, Nodes yield_types, Structured_constructTag tag, union NodesUnion payload) { + return add_structured_construct(bb, gen_variables(bb, yield_types), tag, payload); } -void cancel_body(BodyBuilder* bb) { - destroy_list(bb->stack); +Nodes shd_bld_if(BodyBuilder* bb, Nodes yield_types, const Node* condition, const Node* true_case, Node* false_case) { + return gen_structured_construct(bb, yield_types, Structured_construct_If_TAG, (union NodesUnion) { + .if_instr = { + .condition = condition, + .if_true = true_case, + .if_false = false_case, + .yield_types = yield_types, + } + }); +} + +Nodes shd_bld_match(BodyBuilder* bb, Nodes yield_types, const Node* inspectee, Nodes literals, Nodes cases, Node* default_case) { + return gen_structured_construct(bb, yield_types, Structured_construct_Match_TAG, (union NodesUnion) { + .match_instr = { + .yield_types = yield_types, + .inspect = inspectee, + .literals = literals, + .cases = cases, + .default_case = default_case + } + }); +} + +Nodes shd_bld_loop(BodyBuilder* bb, Nodes yield_types, Nodes initial_args, Node* body) { + return gen_structured_construct(bb, yield_types, Structured_construct_Loop_TAG, (union NodesUnion) { + .loop_instr = { + .yield_types = yield_types, + .initial_args = initial_args, + .body = body + }, + }); +} + +Nodes shd_bld_control(BodyBuilder* bb, Nodes yield_types, Node* body) { + return gen_structured_construct(bb, yield_types, Structured_construct_Control_TAG, (union NodesUnion) { + .control = { + .yield_types = yield_types, + .inside = body + }, + }); +} + +begin_control_t shd_bld_begin_control(BodyBuilder* bb, Nodes yield_types) { + IrArena* a = bb->arena; + const Type* jp_type = qualified_type(a, (QualifiedType) { + .type = join_point_type(a, 
(JoinPointType) { .yield_types = yield_types }), + .is_uniform = true + }); + const Node* jp = param(a, jp_type, NULL); + Node* c = case_(a, shd_singleton(jp)); + return (begin_control_t) { + .results = shd_bld_control(bb, yield_types, c), + .case_ = c, + .jp = jp + }; +} + +begin_loop_helper_t shd_bld_begin_loop_helper(BodyBuilder* bb, Nodes yield_types, Nodes arg_types, Nodes initial_values) { + assert(arg_types.count == initial_values.count); + IrArena* a = bb->arena; + begin_control_t outer_control = shd_bld_begin_control(bb, yield_types); + BodyBuilder* outer_control_case_builder = shd_bld_begin(a, shd_get_abstraction_mem(outer_control.case_)); + LARRAY(const Node*, params, arg_types.count); + for (size_t i = 0; i < arg_types.count; i++) { + params[i] = param(a, shd_as_qualified_type(arg_types.nodes[i], false), NULL); + } + Node* loop_header = case_(a, shd_nodes(a, arg_types.count, params)); + shd_set_abstraction_body(outer_control.case_, shd_bld_jump(outer_control_case_builder, loop_header, initial_values)); + BodyBuilder* loop_header_builder = shd_bld_begin(a, shd_get_abstraction_mem(loop_header)); + begin_control_t inner_control = shd_bld_begin_control(loop_header_builder, arg_types); + shd_set_abstraction_body(loop_header, shd_bld_jump(loop_header_builder, loop_header, inner_control.results)); + + return (begin_loop_helper_t) { + .results = outer_control.results, + .params = shd_nodes(a, arg_types.count, params), + .loop_body = inner_control.case_, + .break_jp = outer_control.jp, + .continue_jp = inner_control.jp, + }; +} + +void shd_bld_cancel(BodyBuilder* bb) { + for (size_t i = 0; i < shd_list_count(bb->stack); i++) { + StackEntry entry = shd_read_list(StackEntry, bb->stack)[i]; + // if (entry.structured.tag != NotAStructured_construct) + // destroy_list(entry.structured.stack); + } + shd_destroy_list(bb->stack); + //destroy_list(bb->stack_stack); free(bb); } diff --git a/src/shady/builtins.c b/src/shady/builtins.c deleted file mode 100644 index 
750ba12dd..000000000 --- a/src/shady/builtins.c +++ /dev/null @@ -1,71 +0,0 @@ -#include "shady/builtins.h" -#include "spirv/unified1/spirv.h" - -#include "log.h" -#include "portability.h" -#include - -AddressSpace builtin_as[] = { -#define BUILTIN(_, as, _2) as, -SHADY_BUILTINS() -#undef BUILTIN -}; - -AddressSpace get_builtin_as(Builtin builtin) { return builtin_as[builtin]; } - -String builtin_names[] = { -#define BUILTIN(name, _, _2) #name, -SHADY_BUILTINS() -#undef BUILTIN -}; - -String get_builtin_name(Builtin builtin) { return builtin_names[builtin]; } - -const Type* get_builtin_type(IrArena* arena, Builtin builtin) { - switch (builtin) { -#define BUILTIN(name, _, datatype) case Builtin##name: return datatype; -SHADY_BUILTINS() -#undef BUILTIN - default: error("Unhandled builtin") - } -} - -// What's the decoration for the builtin -SpvBuiltIn spv_builtins[] = { -#define BUILTIN(name, _, _2) SpvBuiltIn##name, -SHADY_BUILTINS() -#undef BUILTIN -}; - -Builtin get_builtin_by_name(String s) { - for (size_t i = 0; i < BuiltinsCount; i++) { - if (strcmp(s, builtin_names[i]) == 0) { - return i; - } - } - return BuiltinsCount; -} - -Builtin get_builtin_by_spv_id(SpvBuiltIn id) { - Builtin b = BuiltinsCount; - for (size_t i = 0; i < BuiltinsCount; i++) { - if (id == spv_builtins[i]) { - b = i; - break; - } - } - return b; -} - -Builtin get_decl_builtin(const Node* decl) { - const Node* a = lookup_annotation(decl, "Builtin"); - if (!a) - return BuiltinsCount; - String payload = get_annotation_string_payload(a); - return get_builtin_by_name(payload); -} - - -bool is_decl_builtin(const Node* decl) { - return get_decl_builtin(decl) != BuiltinsCount; -} diff --git a/src/shady/check.c b/src/shady/check.c new file mode 100644 index 000000000..fe4cce8ce --- /dev/null +++ b/src/shady/check.c @@ -0,0 +1,964 @@ +#include "check.h" +#include "shady/ir/cast.h" + +#include "log.h" +#include "ir_private.h" +#include "portability.h" +#include "dict.h" +#include "util.h" + +#include 
"shady/ir/builtin.h" + +#include +#include + +static bool are_types_identical(size_t num_types, const Type* types[]) { + for (size_t i = 0; i < num_types; i++) { + assert(types[i]); + if (types[0] != types[i]) + return false; + } + return true; +} + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" + +const Type* _shd_check_type_join_point_type(IrArena* arena, JoinPointType type) { + for (size_t i = 0; i < type.yield_types.count; i++) { + assert(shd_is_data_type(type.yield_types.nodes[i])); + } + return NULL; +} + +const Type* _shd_check_type_record_type(IrArena* arena, RecordType type) { + assert(type.names.count == 0 || type.names.count == type.members.count); + for (size_t i = 0; i < type.members.count; i++) { + // member types are value types iff this is a return tuple + if (type.special == MultipleReturn) + assert(shd_is_value_type(type.members.nodes[i])); + else + assert(shd_is_data_type(type.members.nodes[i])); + } + return NULL; +} + +const Type* _shd_check_type_qualified_type(IrArena* arena, QualifiedType qualified_type) { + assert(shd_is_data_type(qualified_type.type)); + assert(arena->config.is_simt || qualified_type.is_uniform); + return NULL; +} + +const Type* _shd_check_type_arr_type(IrArena* arena, ArrType type) { + assert(shd_is_data_type(type.element_type)); + return NULL; +} + +const Type* _shd_check_type_pack_type(IrArena* arena, PackType pack_type) { + assert(shd_is_data_type(pack_type.element_type)); + return NULL; +} + +const Type* _shd_check_type_ptr_type(IrArena* arena, PtrType ptr_type) { + if (!arena->config.address_spaces[ptr_type.address_space].allowed) { + shd_error_print("Address space %s is not allowed in this arena\n", shd_get_address_space_name(ptr_type.address_space)); + shd_error_die(); + } + assert(ptr_type.pointed_type && "Shady does not support untyped pointers, but can infer them, see infer.c"); + if (ptr_type.pointed_type) { + if (ptr_type.pointed_type->tag == ArrType_TAG) { + 
assert(shd_is_data_type(ptr_type.pointed_type->payload.arr_type.element_type)); + return NULL; + } + if (ptr_type.pointed_type->tag == FnType_TAG || ptr_type.pointed_type == unit_type(arena)) { + // no diagnostic required, we just allow these + return NULL; + } + const Node* maybe_record_type = ptr_type.pointed_type; + if (maybe_record_type->tag == TypeDeclRef_TAG) + maybe_record_type = shd_get_nominal_type_body(maybe_record_type); + if (maybe_record_type && maybe_record_type->tag == RecordType_TAG && maybe_record_type->payload.record_type.special == DecorateBlock) { + return NULL; + } + assert(shd_is_data_type(ptr_type.pointed_type)); + } + return NULL; +} + +const Type* _shd_check_type_param(IrArena* arena, Param variable) { + assert(shd_is_value_type(variable.type)); + return variable.type; +} + +const Type* _shd_check_type_untyped_number(IrArena* arena, UntypedNumber untyped) { + shd_error("should never happen"); +} + +const Type* _shd_check_type_int_literal(IrArena* arena, IntLiteral lit) { + return qualified_type(arena, (QualifiedType) { + .is_uniform = true, + .type = int_type(arena, (Int) { .width = lit.width, .is_signed = lit.is_signed }) + }); +} + +const Type* _shd_check_type_float_literal(IrArena* arena, FloatLiteral lit) { + return qualified_type(arena, (QualifiedType) { + .is_uniform = true, + .type = float_type(arena, (Float) { .width = lit.width }) + }); +} + +const Type* _shd_check_type_true_lit(IrArena* arena) { return qualified_type(arena, (QualifiedType) { .type = bool_type(arena), .is_uniform = true }); } +const Type* _shd_check_type_false_lit(IrArena* arena) { return qualified_type(arena, (QualifiedType) { .type = bool_type(arena), .is_uniform = true }); } + +const Type* _shd_check_type_string_lit(IrArena* arena, StringLiteral str_lit) { + const Type* t = arr_type(arena, (ArrType) { + .element_type = shd_int8_type(arena), + .size = shd_int32_literal(arena, strlen(str_lit.string)) + }); + return qualified_type(arena, (QualifiedType) { + .type = 
t, + .is_uniform = true, + }); +} + +const Type* _shd_check_type_null_ptr(IrArena* a, NullPtr payload) { + assert(shd_is_data_type(payload.ptr_type) && payload.ptr_type->tag == PtrType_TAG); + return shd_as_qualified_type(payload.ptr_type, true); +} + +const Type* _shd_check_type_composite(IrArena* arena, Composite composite) { + if (composite.type) { + assert(shd_is_data_type(composite.type)); + Nodes expected_member_types = shd_get_composite_type_element_types(composite.type); + bool is_uniform = true; + assert(composite.contents.count == expected_member_types.count); + for (size_t i = 0; i < composite.contents.count; i++) { + const Type* element_type = composite.contents.nodes[i]->type; + is_uniform &= shd_deconstruct_qualified_type(&element_type); + assert(shd_is_subtype(expected_member_types.nodes[i], element_type)); + } + return qualified_type(arena, (QualifiedType) { + .is_uniform = is_uniform, + .type = composite.type + }); + } + bool is_uniform = true; + LARRAY(const Type*, member_ts, composite.contents.count); + for (size_t i = 0; i < composite.contents.count; i++) { + const Type* element_type = composite.contents.nodes[i]->type; + is_uniform &= shd_deconstruct_qualified_type(&element_type); + member_ts[i] = element_type; + } + return qualified_type(arena, (QualifiedType) { + .is_uniform = is_uniform, + .type = record_type(arena, (RecordType) { + .members = shd_nodes(arena, composite.contents.count, member_ts) + }) + }); +} + +const Type* _shd_check_type_fill(IrArena* arena, Fill payload) { + assert(shd_is_data_type(payload.type)); + const Node* element_t = shd_get_fill_type_element_type(payload.type); + const Node* value_t = payload.value->type; + bool u = shd_deconstruct_qualified_type(&value_t); + assert(shd_is_subtype(element_t, value_t)); + return qualified_type(arena, (QualifiedType) { + .is_uniform = u, + .type = payload.type + }); +} + +const Type* _shd_check_type_undef(IrArena* arena, Undef payload) { + assert(shd_is_data_type(payload.type)); + 
return qualified_type(arena, (QualifiedType) { + .is_uniform = true, + .type = payload.type + }); +} + +const Type* _shd_check_type_mem_and_value(IrArena* arena, MemAndValue mav) { + return mav.value->type; +} + +const Type* _shd_check_type_fn_addr(IrArena* arena, FnAddr fn_addr) { + assert(fn_addr.fn->type->tag == FnType_TAG); + assert(fn_addr.fn->tag == Function_TAG); + return qualified_type(arena, (QualifiedType) { + .is_uniform = true, + .type = ptr_type(arena, (PtrType) { + .pointed_type = fn_addr.fn->type, + .address_space = AsGeneric /* the actual AS does not matter because these are opaque anyways */, + }) + }); +} + +const Type* _shd_check_type_ref_decl(IrArena* arena, RefDecl ref_decl) { + const Type* t = ref_decl.decl->type; + assert(t && "RefDecl needs to be applied on a decl with a non-null type. Did you forget to set 'type' on a constant ?"); + switch (ref_decl.decl->tag) { + case GlobalVariable_TAG: + case Constant_TAG: break; + default: shd_error("You can only use RefDecl on a global or a constant. 
See FnAddr for taking addresses of functions.") + } + assert(t->tag != QualifiedType_TAG && "decl types may not be qualified"); + return qualified_type(arena, (QualifiedType) { + .type = t, + .is_uniform = true, + }); +} + +const Type* _shd_check_type_prim_op(IrArena* arena, PrimOp prim_op) { + for (size_t i = 0; i < prim_op.type_arguments.count; i++) { + const Node* ta = prim_op.type_arguments.nodes[i]; + assert(ta && is_type(ta)); + } + for (size_t i = 0; i < prim_op.operands.count; i++) { + const Node* operand = prim_op.operands.nodes[i]; + assert(operand && is_value(operand)); + } + + bool extended = false; + bool ordered = false; + AddressSpace as; + switch (prim_op.op) { + case neg_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 1); + + const Type* type = shd_first(prim_op.operands)->type; + assert(shd_is_arithm_type(shd_get_maybe_packed_type_element(shd_get_unqualified_type(type)))); + return type; + } + case rshift_arithm_op: + case rshift_logical_op: + case lshift_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 2); + const Type* first_operand_type = shd_first(prim_op.operands)->type; + const Type* second_operand_type = prim_op.operands.nodes[1]->type; + + bool uniform_result = shd_deconstruct_qualified_type(&first_operand_type); + uniform_result &= shd_deconstruct_qualified_type(&second_operand_type); + + size_t value_simd_width = shd_deconstruct_maybe_packed_type(&first_operand_type); + size_t shift_simd_width = shd_deconstruct_maybe_packed_type(&second_operand_type); + assert(value_simd_width == shift_simd_width); + + assert(first_operand_type->tag == Int_TAG); + assert(second_operand_type->tag == Int_TAG); + + return shd_as_qualified_type(shd_maybe_packed_type_helper(first_operand_type, value_simd_width), uniform_result); + } + case add_carry_op: + case sub_borrow_op: + case mul_extended_op: extended = true; SHADY_FALLTHROUGH; + case min_op: + case max_op: + case add_op: + case 
sub_op: + case mul_op: + case div_op: + case mod_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 2); + const Type* first_operand_type = shd_get_unqualified_type(shd_first(prim_op.operands)->type); + + bool result_uniform = true; + for (size_t i = 0; i < prim_op.operands.count; i++) { + const Node* arg = prim_op.operands.nodes[i]; + const Type* operand_type = arg->type; + bool operand_uniform = shd_deconstruct_qualified_type(&operand_type); + + assert(shd_is_arithm_type(shd_get_maybe_packed_type_element(operand_type))); + assert(first_operand_type == operand_type && "operand type mismatch"); + + result_uniform &= operand_uniform; + } + + const Type* result_t = first_operand_type; + if (extended) { + // TODO: assert unsigned + result_t = record_type(arena, (RecordType) {.members = mk_nodes(arena, result_t, result_t)}); + } + return shd_as_qualified_type(result_t, result_uniform); + } + + case not_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 1); + + const Type* type = shd_first(prim_op.operands)->type; + assert(shd_has_boolean_ops(shd_get_maybe_packed_type_element(shd_get_unqualified_type(type)))); + return type; + } + case or_op: + case xor_op: + case and_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 2); + const Type* first_operand_type = shd_get_unqualified_type(shd_first(prim_op.operands)->type); + + bool result_uniform = true; + for (size_t i = 0; i < prim_op.operands.count; i++) { + const Node* arg = prim_op.operands.nodes[i]; + const Type* operand_type = arg->type; + bool operand_uniform = shd_deconstruct_qualified_type(&operand_type); + + assert(shd_has_boolean_ops(shd_get_maybe_packed_type_element(operand_type))); + assert(first_operand_type == operand_type && "operand type mismatch"); + + result_uniform &= operand_uniform; + } + + return shd_as_qualified_type(first_operand_type, result_uniform); + } + case lt_op: + case lte_op: + case gt_op: 
+ case gte_op: ordered = true; SHADY_FALLTHROUGH + case eq_op: + case neq_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 2); + const Type* first_operand_type = shd_get_unqualified_type(shd_first(prim_op.operands)->type); + size_t first_operand_width = shd_get_maybe_packed_type_width(first_operand_type); + + bool result_uniform = true; + for (size_t i = 0; i < prim_op.operands.count; i++) { + const Node* arg = prim_op.operands.nodes[i]; + const Type* operand_type = arg->type; + bool operand_uniform = shd_deconstruct_qualified_type(&operand_type); + + assert((ordered ? shd_is_ordered_type : shd_is_comparable_type)(shd_get_maybe_packed_type_element(operand_type))); + assert(first_operand_type == operand_type && "operand type mismatch"); + + result_uniform &= operand_uniform; + } + + return shd_as_qualified_type(shd_maybe_packed_type_helper(bool_type(arena), first_operand_width), + result_uniform); + } + case sqrt_op: + case inv_sqrt_op: + case floor_op: + case ceil_op: + case round_op: + case fract_op: + case sin_op: + case cos_op: + case exp_op: + { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 1); + const Node* src_type = shd_first(prim_op.operands)->type; + bool uniform = shd_deconstruct_qualified_type(&src_type); + size_t width = shd_deconstruct_maybe_packed_type(&src_type); + assert(src_type->tag == Float_TAG); + return shd_as_qualified_type(shd_maybe_packed_type_helper(src_type, width), uniform); + } + case pow_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 2); + const Type* first_operand_type = shd_get_unqualified_type(shd_first(prim_op.operands)->type); + + bool result_uniform = true; + for (size_t i = 0; i < prim_op.operands.count; i++) { + const Node* arg = prim_op.operands.nodes[i]; + const Type* operand_type = arg->type; + bool operand_uniform = shd_deconstruct_qualified_type(&operand_type); + + 
assert(shd_get_maybe_packed_type_element(operand_type)->tag == Float_TAG); + assert(first_operand_type == operand_type && "operand type mismatch"); + + result_uniform &= operand_uniform; + } + + return shd_as_qualified_type(first_operand_type, result_uniform); + } + case fma_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 3); + const Type* first_operand_type = shd_get_unqualified_type(shd_first(prim_op.operands)->type); + + bool result_uniform = true; + for (size_t i = 0; i < prim_op.operands.count; i++) { + const Node* arg = prim_op.operands.nodes[i]; + const Type* operand_type = arg->type; + bool operand_uniform = shd_deconstruct_qualified_type(&operand_type); + + assert(shd_get_maybe_packed_type_element(operand_type)->tag == Float_TAG); + assert(first_operand_type == operand_type && "operand type mismatch"); + + result_uniform &= operand_uniform; + } + + return shd_as_qualified_type(first_operand_type, result_uniform); + } + case abs_op: + case sign_op: + { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 1); + const Node* src_type = shd_first(prim_op.operands)->type; + bool uniform = shd_deconstruct_qualified_type(&src_type); + size_t width = shd_deconstruct_maybe_packed_type(&src_type); + assert(src_type->tag == Float_TAG || src_type->tag == Int_TAG && src_type->payload.int_type.is_signed); + return shd_as_qualified_type(shd_maybe_packed_type_helper(src_type, width), uniform); + } + case align_of_op: + case size_of_op: { + assert(prim_op.type_arguments.count == 1); + assert(prim_op.operands.count == 0); + return qualified_type(arena, (QualifiedType) { + .is_uniform = true, + .type = int_type(arena, (Int) { .width = arena->config.memory.ptr_size, .is_signed = false }) + }); + } + case offset_of_op: { + assert(prim_op.type_arguments.count == 1); + assert(prim_op.operands.count == 1); + const Type* optype = shd_first(prim_op.operands)->type; + bool uniform = 
shd_deconstruct_qualified_type(&optype); + assert(uniform && optype->tag == Int_TAG); + return qualified_type(arena, (QualifiedType) { + .is_uniform = true, + .type = int_type(arena, (Int) { .width = arena->config.memory.ptr_size, .is_signed = false }) + }); + } + case select_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 3); + const Type* condition_type = prim_op.operands.nodes[0]->type; + bool condition_uniform = shd_deconstruct_qualified_type(&condition_type); + size_t width = shd_deconstruct_maybe_packed_type(&condition_type); + + const Type* alternatives_types[2]; + bool alternatives_all_uniform = true; + for (size_t i = 0; i < 2; i++) { + alternatives_types[i] = prim_op.operands.nodes[1 + i]->type; + alternatives_all_uniform &= shd_deconstruct_qualified_type(&alternatives_types[i]); + size_t alternative_width = shd_deconstruct_maybe_packed_type(&alternatives_types[i]); + assert(alternative_width == width); + } + + assert(shd_is_subtype(bool_type(arena), condition_type)); + // todo find true supertype + assert(are_types_identical(2, alternatives_types)); + + return shd_as_qualified_type(shd_maybe_packed_type_helper(alternatives_types[0], width), + alternatives_all_uniform && condition_uniform); + } + case insert_op: + case extract_dynamic_op: + case extract_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count >= 2); + const Node* source = shd_first(prim_op.operands); + + size_t indices_start = prim_op.op == insert_op ? 
2 : 1; + Nodes indices = shd_nodes(arena, prim_op.operands.count - indices_start, &prim_op.operands.nodes[indices_start]); + + const Type* t = source->type; + bool uniform = shd_deconstruct_qualified_type(&t); + shd_enter_composite_type_indices(&t, &uniform, indices, true); + + if (prim_op.op == insert_op) { + const Node* inserted_data = prim_op.operands.nodes[1]; + const Type* inserted_data_type = inserted_data->type; + bool is_uniform = uniform & shd_deconstruct_qualified_type(&inserted_data_type); + assert(shd_is_subtype(t, inserted_data_type) && "inserting data into a composite, but it doesn't match the target and indices"); + return qualified_type(arena, (QualifiedType) { + .is_uniform = is_uniform, + .type = shd_get_unqualified_type(source->type), + }); + } + + return shd_as_qualified_type(t, uniform); + } + case shuffle_op: { + assert(prim_op.operands.count >= 2); + assert(prim_op.type_arguments.count == 0); + const Node* lhs = prim_op.operands.nodes[0]; + const Node* rhs = prim_op.operands.nodes[1]; + const Type* lhs_t = lhs->type; + const Type* rhs_t = rhs->type; + bool lhs_u = shd_deconstruct_qualified_type(&lhs_t); + bool rhs_u = shd_deconstruct_qualified_type(&rhs_t); + assert(lhs_t->tag == PackType_TAG && rhs_t->tag == PackType_TAG); + size_t total_size = lhs_t->payload.pack_type.width + rhs_t->payload.pack_type.width; + const Type* element_t = lhs_t->payload.pack_type.element_type; + assert(element_t == rhs_t->payload.pack_type.element_type); + + size_t indices_count = prim_op.operands.count - 2; + const Node** indices = &prim_op.operands.nodes[2]; + bool u = lhs_u & rhs_u; + for (size_t i = 0; i < indices_count; i++) { + u &= shd_is_qualified_type_uniform(indices[i]->type); + int64_t index = shd_get_int_literal_value(*shd_resolve_to_int_literal(indices[i]), true); + assert(index < 0 /* poison */ || (index >= 0 && index < total_size && "shuffle element out of range")); + } + return shd_as_qualified_type( + pack_type(arena, (PackType) {.element_type = 
element_t, .width = indices_count}), u); + } + case reinterpret_op: { + assert(prim_op.type_arguments.count == 1); + assert(prim_op.operands.count == 1); + const Node* source = shd_first(prim_op.operands); + const Type* src_type = source->type; + bool src_uniform = shd_deconstruct_qualified_type(&src_type); + + const Type* dst_type = shd_first(prim_op.type_arguments); + assert(shd_is_data_type(dst_type)); + assert(shd_is_reinterpret_cast_legal(src_type, dst_type)); + + return qualified_type(arena, (QualifiedType) { + .is_uniform = src_uniform, + .type = dst_type + }); + } + case convert_op: { + assert(prim_op.type_arguments.count == 1); + assert(prim_op.operands.count == 1); + const Node* source = shd_first(prim_op.operands); + const Type* src_type = source->type; + bool src_uniform = shd_deconstruct_qualified_type(&src_type); + + const Type* dst_type = shd_first(prim_op.type_arguments); + assert(shd_is_data_type(dst_type)); + assert(shd_is_conversion_legal(src_type, dst_type)); + + // TODO check the conversion is legal + return qualified_type(arena, (QualifiedType) { + .is_uniform = src_uniform, + .type = dst_type + }); + } + // Mask management + case empty_mask_op: { + assert(prim_op.type_arguments.count == 0 && prim_op.operands.count == 0); + return shd_as_qualified_type(shd_get_actual_mask_type(arena), true); + } + case mask_is_thread_active_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 2); + return qualified_type(arena, (QualifiedType) { + .is_uniform = shd_is_qualified_type_uniform(prim_op.operands.nodes[0]->type) && shd_is_qualified_type_uniform(prim_op.operands.nodes[1]->type), + .type = bool_type(arena) + }); + } + // Subgroup ops + case subgroup_assume_uniform_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 1); + const Type* operand_type = shd_get_unqualified_type(prim_op.operands.nodes[0]->type); + return qualified_type(arena, (QualifiedType) { + .is_uniform = true, + .type = 
operand_type + }); + } + // Intermediary ops + case sample_texture_op: { + assert(prim_op.type_arguments.count == 0); + assert(prim_op.operands.count == 2); + const Type* sampled_image_t = shd_first(prim_op.operands)->type; + bool uniform_src = shd_deconstruct_qualified_type(&sampled_image_t); + const Type* coords_t = prim_op.operands.nodes[1]->type; + shd_deconstruct_qualified_type(&coords_t); + assert(sampled_image_t->tag == SampledImageType_TAG); + const Type* image_t = sampled_image_t->payload.sampled_image_type.image_type; + assert(image_t->tag == ImageType_TAG); + size_t coords_dim = shd_deconstruct_packed_type(&coords_t); + return qualified_type(arena, (QualifiedType) { .is_uniform = false, .type = shd_maybe_packed_type_helper(image_t->payload.image_type.sampled_type, 4) }); + } + case PRIMOPS_COUNT: assert(false); + } +} + +const Type* _shd_check_type_ext_instr(IrArena* arena, ExtInstr payload) { + return payload.result_t; +} + +static void check_arguments_types_against_parameters_helper(Nodes param_types, Nodes arg_types) { + if (param_types.count != arg_types.count) + shd_error("Mismatched number of arguments/parameters"); + for (size_t i = 0; i < param_types.count; i++) + shd_check_subtype(param_types.nodes[i], arg_types.nodes[i]); +} + +/// Shared logic between indirect calls and tailcalls +static Nodes check_value_call(const Node* callee, Nodes argument_types) { + assert(is_value(callee)); + + const Type* callee_type = callee->type; + SHADY_UNUSED bool callee_uniform = shd_deconstruct_qualified_type(&callee_type); + AddressSpace as = shd_deconstruct_pointer_type(&callee_type); + assert(as == AsGeneric); + + assert(callee_type->tag == FnType_TAG); + + const FnType* fn_type = &callee_type->payload.fn_type; + check_arguments_types_against_parameters_helper(fn_type->param_types, argument_types); + // TODO force the return types to be varying if the callee is not uniform + return fn_type->return_types; +} + +const Type* _shd_check_type_call(IrArena* arena, 
Call call) { + Nodes args = call.args; + for (size_t i = 0; i < args.count; i++) { + const Node* argument = args.nodes[i]; + assert(is_value(argument)); + } + Nodes argument_types = shd_get_values_types(arena, args); + return shd_maybe_multiple_return(arena, check_value_call(call.callee, argument_types)); +} + +static void ensure_types_are_data_types(const Nodes* yield_types) { + for (size_t i = 0; i < yield_types->count; i++) { + assert(shd_is_data_type(yield_types->nodes[i])); + } +} + +static void ensure_types_are_value_types(const Nodes* yield_types) { + for (size_t i = 0; i < yield_types->count; i++) { + assert(shd_is_value_type(yield_types->nodes[i])); + } +} + +const Type* _shd_check_type_if_instr(IrArena* arena, If if_instr) { + assert(if_instr.tail && is_abstraction(if_instr.tail)); + ensure_types_are_data_types(&if_instr.yield_types); + if (shd_get_unqualified_type(if_instr.condition->type) != bool_type(arena)) + shd_error("condition of an if should be bool"); + // TODO check the contained Merge instrs + if (if_instr.yield_types.count > 0) + assert(if_instr.if_false); + + check_arguments_types_against_parameters_helper(shd_get_param_types(arena, get_abstraction_params(if_instr.tail)), shd_add_qualifiers(arena, if_instr.yield_types, false)); + return noret_type(arena); +} + +const Type* _shd_check_type_match_instr(IrArena* arena, Match match_instr) { + ensure_types_are_data_types(&match_instr.yield_types); + // TODO check param against initial_args + // TODO check the contained Merge instrs + return noret_type(arena); +} + +const Type* _shd_check_type_loop_instr(IrArena* arena, Loop loop_instr) { + ensure_types_are_data_types(&loop_instr.yield_types); + // TODO check param against initial_args + // TODO check the contained Merge instrs + return noret_type(arena); +} + +const Type* _shd_check_type_control(IrArena* arena, Control control) { + ensure_types_are_data_types(&control.yield_types); + // TODO check it then ! 
+ const Node* join_point = shd_first(get_abstraction_params(control.inside)); + + const Type* join_point_type = join_point->type; + shd_deconstruct_qualified_type(&join_point_type); + assert(join_point_type->tag == JoinPointType_TAG); + + Nodes join_point_yield_types = join_point_type->payload.join_point_type.yield_types; + assert(join_point_yield_types.count == control.yield_types.count); + for (size_t i = 0; i < control.yield_types.count; i++) { + assert(shd_is_subtype(control.yield_types.nodes[i], join_point_yield_types.nodes[i])); + } + + assert(get_abstraction_params(control.tail).count == control.yield_types.count); + + return noret_type(arena); +} + +const Type* _shd_check_type_comment(IrArena* arena, SHADY_UNUSED Comment payload) { + return empty_multiple_return_type(arena); +} + +const Type* _shd_check_type_stack_alloc(IrArena* a, StackAlloc alloc) { + assert(is_type(alloc.type)); + return qualified_type(a, (QualifiedType) { + .is_uniform = shd_is_addr_space_uniform(a, AsPrivate), + .type = ptr_type(a, (PtrType) { + .pointed_type = alloc.type, + .address_space = AsPrivate, + .is_reference = false + }) + }); +} + +const Type* _shd_check_type_local_alloc(IrArena* a, LocalAlloc alloc) { + assert(is_type(alloc.type)); + return qualified_type(a, (QualifiedType) { + .is_uniform = shd_is_addr_space_uniform(a, AsFunction), + .type = ptr_type(a, (PtrType) { + .pointed_type = alloc.type, + .address_space = AsFunction, + .is_reference = true + }) + }); +} + +const Type* _shd_check_type_load(IrArena* a, Load load) { + const Node* ptr_type = load.ptr->type; + bool ptr_uniform = shd_deconstruct_qualified_type(&ptr_type); + size_t width = shd_deconstruct_maybe_packed_type(&ptr_type); + + assert(ptr_type->tag == PtrType_TAG); + const PtrType* node_ptr_type_ = &ptr_type->payload.ptr_type; + const Type* elem_type = node_ptr_type_->pointed_type; + elem_type = shd_maybe_packed_type_helper(elem_type, width); + return shd_as_qualified_type(elem_type, + ptr_uniform && 
shd_is_addr_space_uniform(a, ptr_type->payload.ptr_type.address_space)); +} + +const Type* _shd_check_type_store(IrArena* a, Store store) { + const Node* ptr_type = store.ptr->type; + bool ptr_uniform = shd_deconstruct_qualified_type(&ptr_type); + size_t width = shd_deconstruct_maybe_packed_type(&ptr_type); + assert(ptr_type->tag == PtrType_TAG); + const PtrType* ptr_type_payload = &ptr_type->payload.ptr_type; + const Type* elem_type = ptr_type_payload->pointed_type; + assert(elem_type); + elem_type = shd_maybe_packed_type_helper(elem_type, width); + // we don't enforce uniform stores - but we care about storing the right thing :) + const Type* val_expected_type = qualified_type(a, (QualifiedType) { + .is_uniform = !a->config.is_simt, + .type = elem_type + }); + + assert(shd_is_subtype(val_expected_type, store.value->type)); + return empty_multiple_return_type(a); +} + +const Type* _shd_check_type_ptr_array_element_offset(IrArena* a, PtrArrayElementOffset lea) { + const Type* base_ptr_type = lea.ptr->type; + bool uniform = shd_deconstruct_qualified_type(&base_ptr_type); + assert(base_ptr_type->tag == PtrType_TAG && "lea expects a ptr or ref as a base"); + const Type* pointee_type = base_ptr_type->payload.ptr_type.pointed_type; + + assert(lea.offset); + const Type* offset_type = lea.offset->type; + bool offset_uniform = shd_deconstruct_qualified_type(&offset_type); + assert(offset_type->tag == Int_TAG && "lea expects an integer offset"); + + const IntLiteral* lit = shd_resolve_to_int_literal(lea.offset); + bool offset_is_zero = lit && lit->value == 0; + assert(offset_is_zero || !base_ptr_type->payload.ptr_type.is_reference && "if an offset is used, the base cannot be a reference"); + assert(offset_is_zero || shd_is_data_type(pointee_type) && "if an offset is used, the base must point to a data type"); + uniform &= offset_uniform; + + return qualified_type(a, (QualifiedType) { + .is_uniform = uniform, + .type = ptr_type(a, (PtrType) { + .pointed_type = pointee_type, 
+ .address_space = base_ptr_type->payload.ptr_type.address_space, + .is_reference = base_ptr_type->payload.ptr_type.is_reference + }) + }); +} + +const Type* _shd_check_type_ptr_composite_element(IrArena* a, PtrCompositeElement lea) { + const Type* base_ptr_type = lea.ptr->type; + bool uniform = shd_deconstruct_qualified_type(&base_ptr_type); + assert(base_ptr_type->tag == PtrType_TAG && "lea expects a ptr or ref as a base"); + const Type* pointee_type = base_ptr_type->payload.ptr_type.pointed_type; + + shd_enter_composite_type(&pointee_type, &uniform, lea.index, true); + + return qualified_type(a, (QualifiedType) { + .is_uniform = uniform, + .type = ptr_type(a, (PtrType) { + .pointed_type = pointee_type, + .address_space = base_ptr_type->payload.ptr_type.address_space, + .is_reference = base_ptr_type->payload.ptr_type.is_reference + }) + }); +} + +const Type* _shd_check_type_copy_bytes(IrArena* a, CopyBytes copy_bytes) { + const Type* dst_t = copy_bytes.dst->type; + shd_deconstruct_qualified_type(&dst_t); + assert(dst_t->tag == PtrType_TAG); + const Type* src_t = copy_bytes.src->type; + shd_deconstruct_qualified_type(&src_t); + assert(src_t); + const Type* cnt_t = copy_bytes.count->type; + shd_deconstruct_qualified_type(&cnt_t); + assert(cnt_t->tag == Int_TAG); + return empty_multiple_return_type(a); +} + +const Type* _shd_check_type_fill_bytes(IrArena* a, FillBytes fill_bytes) { + const Type* dst_t = fill_bytes.dst->type; + shd_deconstruct_qualified_type(&dst_t); + assert(dst_t->tag == PtrType_TAG); + const Type* src_t = fill_bytes.src->type; + shd_deconstruct_qualified_type(&src_t); + assert(src_t); + const Type* cnt_t = fill_bytes.count->type; + shd_deconstruct_qualified_type(&cnt_t); + assert(cnt_t->tag == Int_TAG); + return empty_multiple_return_type(a); +} + +const Type* _shd_check_type_push_stack(IrArena* a, PushStack payload) { + assert(payload.value); + return empty_multiple_return_type(a); +} + +const Type* _shd_check_type_pop_stack(IrArena* a, PopStack 
payload) { + return shd_as_qualified_type(payload.type, false); +} + +const Type* _shd_check_type_set_stack_size(IrArena* a, SetStackSize payload) { + assert(shd_get_unqualified_type(payload.value->type) == shd_uint32_type(a)); + return shd_as_qualified_type(unit_type(a), true); +} + +const Type* _shd_check_type_get_stack_size(IrArena* a, SHADY_UNUSED GetStackSize ss) { + return qualified_type(a, (QualifiedType) { .is_uniform = false, .type = shd_uint32_type(a) }); +} + +const Type* _shd_check_type_get_stack_base_addr(IrArena* a, SHADY_UNUSED GetStackBaseAddr gsba) { + const Node* ptr = ptr_type(a, (PtrType) { .pointed_type = shd_uint8_type(a), .address_space = AsPrivate}); + return qualified_type(a, (QualifiedType) { .is_uniform = false, .type = ptr }); +} + +const Type* _shd_check_type_debug_printf(IrArena* a, DebugPrintf payload) { + return empty_multiple_return_type(a); +} + +const Type* _shd_check_type_tail_call(IrArena* arena, TailCall tail_call) { + Nodes args = tail_call.args; + for (size_t i = 0; i < args.count; i++) { + const Node* argument = args.nodes[i]; + assert(is_value(argument)); + } + assert(check_value_call(tail_call.callee, shd_get_values_types(arena, tail_call.args)).count == 0); + return noret_type(arena); +} + +static void check_basic_block_call(const Node* block, Nodes argument_types) { + assert(is_basic_block(block)); + assert(block->type->tag == BBType_TAG); + BBType bb_type = block->type->payload.bb_type; + check_arguments_types_against_parameters_helper(bb_type.param_types, argument_types); +} + +const Type* _shd_check_type_jump(IrArena* arena, Jump jump) { + for (size_t i = 0; i < jump.args.count; i++) { + const Node* argument = jump.args.nodes[i]; + assert(is_value(argument)); + } + + check_basic_block_call(jump.target, shd_get_values_types(arena, jump.args)); + return noret_type(arena); +} + +const Type* _shd_check_type_branch(IrArena* arena, Branch payload) { + assert(payload.true_jump->tag == Jump_TAG); + 
assert(payload.false_jump->tag == Jump_TAG); + return noret_type(arena); +} + +const Type* _shd_check_type_br_switch(IrArena* arena, Switch payload) { + for (size_t i = 0; i < payload.case_jumps.count; i++) + assert(payload.case_jumps.nodes[i]->tag == Jump_TAG); + assert(payload.case_values.count == payload.case_jumps.count); + assert(payload.default_jump->tag == Jump_TAG); + return noret_type(arena); +} + +const Type* _shd_check_type_join(IrArena* arena, Join join) { + for (size_t i = 0; i < join.args.count; i++) { + const Node* argument = join.args.nodes[i]; + assert(is_value(argument)); + } + + const Type* join_target_type = join.join_point->type; + + shd_deconstruct_qualified_type(&join_target_type); + assert(join_target_type->tag == JoinPointType_TAG); + + Nodes join_point_param_types = join_target_type->payload.join_point_type.yield_types; + join_point_param_types = shd_add_qualifiers(arena, join_point_param_types, !arena->config.is_simt); + + check_arguments_types_against_parameters_helper(join_point_param_types, shd_get_values_types(arena, join.args)); + + return noret_type(arena); +} + +const Type* _shd_check_type_unreachable(IrArena* arena, SHADY_UNUSED Unreachable u) { + return noret_type(arena); +} + +const Type* _shd_check_type_merge_continue(IrArena* arena, MergeContinue mc) { + // TODO check it + return noret_type(arena); +} + +const Type* _shd_check_type_merge_break(IrArena* arena, MergeBreak mc) { + // TODO check it + return noret_type(arena); +} + +const Type* _shd_check_type_merge_selection(IrArena* arena, SHADY_UNUSED MergeSelection payload) { + // TODO check it + return noret_type(arena); +} + +const Type* _shd_check_type_fn_ret(IrArena* arena, Return ret) { + // assert(ret.fn); + // TODO check it then ! 
+ return noret_type(arena); +} + +const Type* _shd_check_type_fun(IrArena* arena, Function fn) { + for (size_t i = 0; i < fn.return_types.count; i++) { + assert(shd_is_value_type(fn.return_types.nodes[i])); + } + return fn_type(arena, (FnType) { .param_types = shd_get_param_types(arena, (&fn)->params), .return_types = (&fn)->return_types }); +} + +const Type* _shd_check_type_basic_block(IrArena* arena, BasicBlock bb) { + return bb_type(arena, (BBType) { .param_types = shd_get_param_types(arena, (&bb)->params) }); +} + +const Type* _shd_check_type_global_variable(IrArena* arena, GlobalVariable global_variable) { + assert(is_type(global_variable.type)); + + const Node* ba = shd_lookup_annotation_list(global_variable.annotations, "Builtin"); + if (ba && arena->config.validate_builtin_types) { + Builtin b = shd_get_builtin_by_name(shd_get_annotation_string_payload(ba)); + assert(b != BuiltinsCount); + const Type* t = shd_get_builtin_type(arena, b); + if (t != global_variable.type) { + shd_error_print("Creating a @Builtin global variable '%s' with the incorrect type: ", global_variable.name); + shd_log_node(ERROR, global_variable.type); + shd_error_print(" instead of the expected "); + shd_log_node(ERROR, t); + shd_error_print(".\n"); + shd_error_die(); + } + } + + assert(global_variable.address_space < NumAddressSpaces); + + return ptr_type(arena, (PtrType) { + .pointed_type = global_variable.type, + .address_space = global_variable.address_space, + .is_reference = shd_lookup_annotation_list(global_variable.annotations, "Logical"), + }); +} + +const Type* _shd_check_type_constant(IrArena* arena, Constant cnst) { + assert(shd_is_data_type(cnst.type_hint)); + return cnst.type_hint; +} + +#include "type_generated.c" + +#pragma GCC diagnostic pop diff --git a/src/shady/check.h b/src/shady/check.h new file mode 100644 index 000000000..057a23cd3 --- /dev/null +++ b/src/shady/check.h @@ -0,0 +1,8 @@ +#ifndef SHADY_TYPE_H +#define SHADY_TYPE_H + +#include "shady/ir.h" + 
+#include "type_generated.h" + +#endif diff --git a/src/shady/compile.c b/src/shady/compile.c index 2713aded2..3ce720f97 100644 --- a/src/shady/compile.c +++ b/src/shady/compile.c @@ -1,136 +1,142 @@ +#include "ir_private.h" #include "shady/driver.h" -#include "compile.h" +#include "shady/ir.h" + +#include "passes/passes.h" +#include "analysis/verify.h" -#include "frontends/slim/parser.h" +#include "../frontend/slim/parser.h" #include "shady_scheduler_src.h" #include "transform/internal_constants.h" -#include "portability.h" -#include "ir_private.h" + #include "util.h" +#include "log.h" #include -#define KiB * 1024 -#define MiB * 1024 KiB - -CompilerConfig default_compiler_config() { - return (CompilerConfig) { - .dynamic_scheduling = true, - .per_thread_stack_size = 4 KiB, - - .target_spirv_version = { - .major = 1, - .minor = 4 - }, - - .logging = { - // most of the time, we are not interested in seeing generated & internal code in the debug output - .skip_internal = true, - .skip_generated = true, - }, - - .optimisations = { - .cleanup = { - .after_every_pass = true, - .delete_unused_instructions = true, - } - }, - - .specialization = { - .subgroup_size = 8, - .entry_point = NULL - } +static void add_scheduler_source(const CompilerConfig* config, Module* dst) { + SlimParserConfig pconfig = { + .front_end = true, }; + Module* builtin_scheduler_mod = shd_parse_slim_module(config, &pconfig, shady_scheduler_src, "builtin_scheduler"); + shd_debug_print("Adding builtin scheduler code"); + shd_module_link(dst, builtin_scheduler_mod); + shd_destroy_ir_arena(shd_module_get_arena(builtin_scheduler_mod)); } -ArenaConfig default_arena_config() { - return (ArenaConfig) { - .is_simt = true, - .validate_builtin_types = false, - .allow_subgroup_memory = true, - .allow_shared_memory = true, - - .memory = { - .word_size = IntTy8, - .ptr_size = IntTy64, - }, - - .optimisations = { - .delete_unreachable_structured_cases = true, - }, - }; +#ifdef NDEBUG +#define SHADY_RUN_VERIFY 0 
+#else +#define SHADY_RUN_VERIFY 1 +#endif + +void shd_run_pass_impl(const CompilerConfig* config, Module** pmod, IrArena* initial_arena, RewritePass pass, String pass_name) { + Module* old_mod = NULL; + old_mod = *pmod; + *pmod = pass(config, *pmod); + (*pmod)->sealed = true; + shd_debugvv_print("After pass %s: \n", pass_name); + if (SHADY_RUN_VERIFY) + shd_verify_module(config, *pmod); + if (shd_module_get_arena(old_mod) != shd_module_get_arena(*pmod) && shd_module_get_arena(old_mod) != initial_arena) + shd_destroy_ir_arena(shd_module_get_arena(old_mod)); + old_mod = *pmod; + if (config->optimisations.cleanup.after_every_pass) + *pmod = shd_cleanup(config, *pmod); + shd_log_module(DEBUGVV, config, *pmod); + if (SHADY_RUN_VERIFY) + shd_verify_module(config, *pmod); + if (shd_module_get_arena(old_mod) != shd_module_get_arena(*pmod) && shd_module_get_arena(old_mod) != initial_arena) + shd_destroy_ir_arena(shd_module_get_arena(old_mod)); + if (config->hooks.after_pass.fn) + config->hooks.after_pass.fn(config->hooks.after_pass.uptr, pass_name, *pmod); } -CompilationResult run_compiler_passes(CompilerConfig* config, Module** pmod) { - if (config->dynamic_scheduling) { - debugv_print("Parsing builtin scheduler code"); - ParserConfig pconfig = { - .front_end = true, - }; - parse_shady_ir(pconfig, shady_scheduler_src, *pmod); +void shd_apply_opt_impl(const CompilerConfig* config, bool* todo, Module** m, OptPass pass, String pass_name) { + bool changed = pass(config, m); + *todo |= changed; + + if (getenv("SHADY_DUMP_CLEAN_ROUNDS") && changed) { + shd_log_fmt(DEBUGVV, "%s changed something:\n", pass_name); + shd_log_module(DEBUGVV, config, *m); } +} +CompilationResult shd_run_compiler_passes(CompilerConfig* config, Module** pmod) { IrArena* initial_arena = (*pmod)->arena; - Module* old_mod = NULL; - generate_dummy_constants(config, *pmod); - - if (!get_module_arena(*pmod)->config.name_bound) - RUN_PASS(bind_program) - RUN_PASS(normalize) - - RUN_PASS(normalize_builtins); - 
RUN_PASS(infer_program) + // we don't want to mess with the original module + *pmod = shd_import(config, *pmod); + shd_log_fmt(DEBUG, "After import:\n"); + shd_log_module(DEBUG, config, *pmod); + + if (config->input_cf.has_scope_annotations) { + // RUN_PASS(shd_pass_scope_heuristic) + RUN_PASS(shd_pass_lift_everything) + RUN_PASS(shd_pass_scope2control) + } else if (config->input_cf.restructure_with_heuristics) { + RUN_PASS(shd_pass_remove_critical_edges) + // RUN_PASS(shd_pass_lcssa) + RUN_PASS(shd_pass_lift_everything) + RUN_PASS(shd_pass_reconvergence_heuristics) + } - RUN_PASS(opt_inline_jumps) + if (config->dynamic_scheduling) { + add_scheduler_source(config, *pmod); + } - RUN_PASS(lcssa) - RUN_PASS(reconvergence_heuristics) + RUN_PASS(shd_pass_eliminate_inlineable_constants) - RUN_PASS(lower_cf_instrs) - RUN_PASS(opt_mem2reg) - RUN_PASS(setup_stack_frames) + RUN_PASS(shd_pass_setup_stack_frames) if (!config->hacks.force_join_point_lifting) - RUN_PASS(mark_leaf_functions) + RUN_PASS(shd_pass_mark_leaf_functions) - RUN_PASS(lower_callf) - RUN_PASS(opt_inline) + RUN_PASS(shd_pass_lower_callf) + RUN_PASS(shd_pass_inline) - RUN_PASS(lift_indirect_targets) + RUN_PASS(shd_pass_lift_indirect_targets) - if (config->specialization.execution_model != EmNone) - RUN_PASS(specialize_execution_model) + RUN_PASS(shd_pass_specialize_execution_model) - RUN_PASS(opt_stack) + //RUN_PASS(shd_pass_opt_stack) - RUN_PASS(lower_tailcalls) - RUN_PASS(lower_switch_btree) - RUN_PASS(opt_restructurize) - RUN_PASS(opt_inline_jumps) + RUN_PASS(shd_pass_lower_tailcalls) + //RUN_PASS(shd_pass_lower_switch_btree) + //RUN_PASS(shd_pass_opt_mem2reg) - RUN_PASS(lower_mask) - RUN_PASS(lower_memcpy) - RUN_PASS(lower_subgroup_ops) - RUN_PASS(lower_stack) + if (config->specialization.entry_point) + RUN_PASS(shd_pass_specialize_entry_point) - RUN_PASS(lower_lea) - RUN_PASS(lower_generic_globals) - RUN_PASS(lower_generic_ptrs) - RUN_PASS(lower_physical_ptrs) - RUN_PASS(lower_subgroup_vars) - 
RUN_PASS(lower_memory_layout) + RUN_PASS(shd_pass_lower_logical_pointers) + + RUN_PASS(shd_pass_lower_mask) + RUN_PASS(shd_pass_lower_subgroup_ops) + if (config->lower.emulate_physical_memory) { + RUN_PASS(shd_pass_lower_alloca) + } + RUN_PASS(shd_pass_lower_stack) + RUN_PASS(shd_pass_lower_memcpy) + RUN_PASS(shd_pass_lower_lea) + RUN_PASS(shd_pass_lower_generic_globals) + if (config->lower.emulate_generic_ptrs) { + RUN_PASS(shd_pass_lower_generic_ptrs) + } + if (config->lower.emulate_physical_memory) { + RUN_PASS(shd_pass_lower_physical_ptrs) + } + RUN_PASS(shd_pass_lower_subgroup_vars) + RUN_PASS(shd_pass_lower_memory_layout) if (config->lower.decay_ptrs) - RUN_PASS(lower_decay_ptrs) + RUN_PASS(shd_pass_lower_decay_ptrs) - RUN_PASS(lower_int) + RUN_PASS(shd_pass_lower_int) - if (config->lower.simt_to_explicit_simd) - RUN_PASS(simt2d) + RUN_PASS(shd_pass_lower_fill) + RUN_PASS(shd_pass_lower_nullptr) + RUN_PASS(shd_pass_normalize_builtins) - if (config->specialization.entry_point) - RUN_PASS(specialize_entry_point) - RUN_PASS(lower_fill) + RUN_PASS(shd_pass_restructurize) return CompilationNoError; } diff --git a/src/shady/compile.h b/src/shady/compile.h deleted file mode 100644 index b6e147b21..000000000 --- a/src/shady/compile.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SHADY_COMPILE_H -#define SHADY_COMPILE_H - -#include "shady/ir.h" -#include "passes/passes.h" -#include "log.h" -#include "analysis/verify.h" - -#ifdef NDEBUG -#define SHADY_RUN_VERIFY 0 -#else -#define SHADY_RUN_VERIFY 1 -#endif - -#define RUN_PASS(pass_name) { \ -old_mod = *pmod; \ -*pmod = pass_name(config, *pmod); \ -(*pmod)->sealed = true; \ -debugvv_print("After "#pass_name" pass: \n"); \ -log_module(DEBUGVV, config, *pmod); \ -if (SHADY_RUN_VERIFY) \ - verify_module(*pmod); \ -if (get_module_arena(old_mod) != get_module_arena(*pmod) && get_module_arena(old_mod) != initial_arena) \ - destroy_ir_arena(get_module_arena(old_mod)); \ -old_mod = *pmod; \ -if 
(config->optimisations.cleanup.after_every_pass) \ - *pmod = cleanup(config, *pmod); \ -if (SHADY_RUN_VERIFY) \ - verify_module(*pmod); \ -if (get_module_arena(old_mod) != get_module_arena(*pmod) && get_module_arena(old_mod) != initial_arena) \ - destroy_ir_arena(get_module_arena(old_mod)); \ -if (config->hooks.after_pass.fn) \ - config->hooks.after_pass.fn(config->hooks.after_pass.uptr, #pass_name, *pmod); \ -} \ - -#endif diff --git a/src/shady/config.c b/src/shady/config.c new file mode 100644 index 000000000..e67850152 --- /dev/null +++ b/src/shady/config.c @@ -0,0 +1,87 @@ +#include "shady/ir.h" +#include "shady/config.h" + +#define KiB * 1024 +#define MiB * 1024 KiB + +CompilerConfig shd_default_compiler_config(void) { + return (CompilerConfig) { + .dynamic_scheduling = true, + .per_thread_stack_size = 4 KiB, + + .target_spirv_version = { + .major = 1, + .minor = 4 + }, + + .lower = { + .emulate_physical_memory = true, + .emulate_generic_ptrs = true, + }, + + .logging = { + // most of the time, we are not interested in seeing generated & internal code in the debug output + //.print_internal = true, + //.print_generated = true, + .print_builtin = true, + }, + + .optimisations = { + .cleanup = { + .after_every_pass = true, + .delete_unused_instructions = true, + } + }, + + /*.shader_diagnostics = { + .max_top_iterations = 10, + }, + + .printf_trace = { + .god_function = true, + },*/ + + .target = shd_default_target_config(), + + .specialization = { + .subgroup_size = 8, + .entry_point = NULL + } + }; +} + +TargetConfig shd_default_target_config(void) { + return (TargetConfig) { + .memory = { + .word_size = IntTy32, + .ptr_size = IntTy64, + }, + }; +} + +ArenaConfig shd_default_arena_config(const TargetConfig* target) { + ArenaConfig config = { + .is_simt = true, + .name_bound = true, + .allow_fold = true, + .check_types = true, + .validate_builtin_types = true, + .check_op_classes = true, + + .optimisations = { + .inline_single_use_bbs = true, + 
.fold_static_control_flow = true, + .delete_unreachable_structured_cases = true, + }, + + .memory = target->memory + }; + + for (size_t i = 0; i < NumAddressSpaces; i++) { + // by default, all address spaces are physical ! + config.address_spaces[i].physical = true; + config.address_spaces[i].allowed = true; + } + + return config; +} diff --git a/src/shady/constructors.c b/src/shady/constructors.c deleted file mode 100644 index af9466740..000000000 --- a/src/shady/constructors.c +++ /dev/null @@ -1,414 +0,0 @@ -#include "ir_private.h" -#include "type.h" -#include "log.h" -#include "fold.h" -#include "portability.h" - -#include "dict.h" -#include "visit.h" - -#include -#include - -Strings import_strings(IrArena*, Strings); -bool compare_nodes(Nodes* a, Nodes* b); - -typedef struct { Visitor visitor; const Node* parent; } VisitorPCV; - -static void post_construction_validation_visit_op(VisitorPCV* v, NodeClass class, SHADY_UNUSED String op_name, const Node* node) { - if (class == NcCase) - ((Node*) node)->payload.case_.structured_construct = v->parent; -} - -static void post_construction_validation(IrArena* arena, Node* node) { - VisitorPCV v = { - .visitor = { - .visit_op_fn = (VisitOpFn) post_construction_validation_visit_op - }, - .parent = node, - }; - visit_node_operands(&v.visitor, 0, node); -} - -static void pre_construction_validation(IrArena* arena, Node* node); - -static Node* create_node_helper(IrArena* arena, Node node, bool* pfresh) { - pre_construction_validation(arena, &node); - - if (pfresh) - *pfresh = false; - - Node* ptr = &node; - Node** found = find_key_dict(Node*, arena->node_set, ptr); - // sanity check nominal nodes to be unique, check for duplicates in structural nodes - if (is_nominal(&node)) - assert(!found); - else if (found) - return *found; - - if (pfresh) - *pfresh = true; - - if (arena->config.allow_fold) { - Node* folded = (Node*) fold_node(arena, ptr); - if (folded != ptr) { - // The folding process simplified the node, we store a 
mapping to that simplified node and bail out ! - insert_set_get_result(Node*, arena->node_set, folded); - post_construction_validation(arena, folded); - return folded; - } - } - - if (arena->config.check_types && node.type) - assert(is_type(node.type)); - - // place the node in the arena and return it - Node* alloc = (Node*) arena_alloc(arena->arena, sizeof(Node)); - *alloc = node; - insert_set_get_result(const Node*, arena->node_set, alloc); - - post_construction_validation(arena, alloc); - return alloc; -} - -#include "constructors_generated.c" - -const Node* let(IrArena* arena, const Node* instruction, const Node* tail) { - Let payload = { - .instruction = instruction, - .tail = tail, - }; - - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = arena, - .type = arena->config.check_types ? check_type_let(arena, payload) : NULL, - .tag = Let_TAG, - .payload.let = payload - }; - return create_node_helper(arena, node, NULL); -} - -Node* var(IrArena* arena, const Type* type, const char* name) { - Variable variable = { - .type = type, - .name = string(arena, name), - .id = fresh_id(arena) - }; - - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = arena, - .type = arena->config.check_types ? 
check_type_var(arena, variable) : NULL, - .tag = Variable_TAG, - .payload.var = variable - }; - return create_node_helper(arena, node, NULL); -} - -const Node* let_mut(IrArena* arena, const Node* instruction, const Node* tail) { - LetMut payload = { - .instruction = instruction, - .tail = tail, - }; - - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = arena, - .type = NULL, - .tag = LetMut_TAG, - .payload.let_mut = payload - }; - return create_node_helper(arena, node, NULL); -} - -const Node* composite_helper(IrArena* a, const Type* t, Nodes contents) { - return composite(a, (Composite) { .type = t, .contents = contents }); -} - -const Node* tuple_helper(IrArena* a, Nodes contents) { - const Type* t = NULL; - if (a->config.check_types) { - // infer the type of the tuple - Nodes member_types = get_values_types(a, contents); - t = record_type(a, (RecordType) {.members = strip_qualifiers(a, member_types)}); - } - - return composite_helper(a, t, contents); -} - -const Node* fn_addr_helper(IrArena* a, const Node* fn) { - return fn_addr(a, (FnAddr) { .fn = fn }); -} - -const Node* ref_decl_helper(IrArena* a, const Node* decl) { - return ref_decl(a, (RefDecl) { .decl = decl }); -} - -const Node* type_decl_ref_helper(IrArena* a, const Node* decl) { - return type_decl_ref(a, (TypeDeclRef) { .decl = decl }); -} - -Node* function(Module* mod, Nodes params, const char* name, Nodes annotations, Nodes return_types) { - assert(!mod->sealed); - IrArena* arena = mod->arena; - Function payload = { - .module = mod, - .params = params, - .body = NULL, - .name = name, - .annotations = annotations, - .return_types = return_types, - }; - - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = arena, - .type = arena->config.check_types ? 
check_type_fun(arena, payload) : NULL, - .tag = Function_TAG, - .payload.fun = payload - }; - Node* fn = create_node_helper(arena, node, NULL); - register_decl_module(mod, fn); - - for (size_t i = 0; i < params.count; i++) { - Node* param = (Node*) params.nodes[i]; - assert(param->tag == Variable_TAG); - assert(!param->payload.var.abs); - param->payload.var.abs = fn; - param->payload.var.pindex = i; - } - - return fn; -} - -Node* basic_block(IrArena* arena, Node* fn, Nodes params, const char* name) { - assert(!fn->payload.fun.module->sealed); - BasicBlock payload = { - .params = params, - .body = NULL, - .fn = fn, - .name = name, - }; - - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = arena, - .type = arena->config.check_types ? check_type_basic_block(arena, payload) : NULL, - .tag = BasicBlock_TAG, - .payload.basic_block = payload - }; - - Node* bb = create_node_helper(arena, node, NULL); - - for (size_t i = 0; i < params.count; i++) { - Node* param = (Node*) params.nodes[i]; - assert(param->tag == Variable_TAG); - assert(!param->payload.var.abs); - param->payload.var.abs = bb; - param->payload.var.pindex = i; - } - - return bb; -} - -const Node* case_(IrArena* a, Nodes params, const Node* body) { - Case payload = { - .params = params, - .body = body, - }; - - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = a, - .type = a->config.check_types ? 
check_type_case_(a, payload) : NULL, - .tag = Case_TAG, - .payload.case_ = payload - }; - - bool fresh; - const Node* lam = create_node_helper(a, node, &fresh); - - if (fresh || true) { - for (size_t i = 0; i < params.count; i++) { - Node* param = (Node*) params.nodes[i]; - assert(param->tag == Variable_TAG); - assert(!param->payload.var.abs); - param->payload.var.abs = lam; - param->payload.var.pindex = i; - } - } - - return lam; -} - -Node* constant(Module* mod, Nodes annotations, const Type* hint, String name) { - IrArena* arena = mod->arena; - Constant cnst = { - .annotations = annotations, - .name = string(arena, name), - .type_hint = hint, - .instruction = NULL, - }; - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = arena, - .type = arena->config.check_types ? check_type_constant(arena, cnst) : NULL, - .tag = Constant_TAG, - .payload.constant = cnst - }; - Node* decl = create_node_helper(arena, node, NULL); - register_decl_module(mod, decl); - return decl; -} - -Node* global_var(Module* mod, Nodes annotations, const Type* type, const char* name, AddressSpace as) { - const Node* existing = get_declaration(mod, name); - if (existing) { - assert(existing->tag == GlobalVariable_TAG); - assert(existing->payload.global_variable.type == type); - assert(existing->payload.global_variable.address_space == as); - assert(!mod->arena->config.check_types || compare_nodes((Nodes*) &existing->payload.global_variable.annotations, &annotations)); - return (Node*) existing; - } - - IrArena* arena = mod->arena; - GlobalVariable gvar = { - .annotations = annotations, - .name = string(arena, name), - .type = type, - .address_space = as, - .init = NULL, - }; - - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = arena, - .type = arena->config.check_types ? 
check_type_global_variable(arena, gvar) : NULL, - .tag = GlobalVariable_TAG, - .payload.global_variable = gvar - }; - Node* decl = create_node_helper(arena, node, NULL); - register_decl_module(mod, decl); - return decl; -} - -Type* nominal_type(Module* mod, Nodes annotations, String name) { - IrArena* arena = mod->arena; - NominalType payload = { - .name = string(arena, name), - .module = mod, - .annotations = annotations, - .body = NULL, - }; - - Node node; - memset((void*) &node, 0, sizeof(Node)); - node = (Node) { - .arena = arena, - .type = NULL, - .tag = NominalType_TAG, - .payload.nom_type = payload - }; - Node* decl = create_node_helper(arena, node, NULL); - register_decl_module(mod, decl); - return decl; -} - -const Node* quote_helper(IrArena* a, Nodes values) { - for (size_t i = 0; i < values.count; i++) - assert(is_value(values.nodes[i])); - - return prim_op(a, (PrimOp) { - .op = quote_op, - .type_arguments = nodes(a, 0, NULL), - .operands = values - }); -} - -const Node* prim_op_helper(IrArena* a, Op op, Nodes types, Nodes operands) { - return prim_op(a, (PrimOp) { - .op = op, - .type_arguments = types, - .operands = operands - }); -} - -const Node* jump_helper(IrArena* a, const Node* dst, Nodes args) { - return jump(a, (Jump) { - .target = dst, - .args = args, - }); -} - -const Node* unit_type(IrArena* arena) { - return record_type(arena, (RecordType) { - .members = empty(arena), - }); -} - -const Node* empty_multiple_return_type(IrArena* arena) { - return record_type(arena, (RecordType) { - .members = empty(arena), - .special = MultipleReturn, - }); -} - -const Node* annotation_value_helper(IrArena* a, String n, const Node* v) { - return annotation_value(a, (AnnotationValue) { .name = n, .value = v}); -} - -const Node* string_lit_helper(IrArena* a, String s) { - return string_lit(a, (StringLiteral) { .string = s }); -} - -const Type* int_type_helper(IrArena* a, bool s, IntSizes w) { return int_type(a, (Int) { .width = w, .is_signed = s }); } - -const 
Type* int8_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy8 , .is_signed = true }); } -const Type* int16_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy16, .is_signed = true }); } -const Type* int32_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy32, .is_signed = true }); } -const Type* int64_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy64, .is_signed = true }); } - -const Type* uint8_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy8 , .is_signed = false }); } -const Type* uint16_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy16, .is_signed = false }); } -const Type* uint32_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy32, .is_signed = false }); } -const Type* uint64_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy64, .is_signed = false }); } - -const Type* int8_literal (IrArena* arena, int8_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy8, .value = (uint64_t) (uint8_t) i, .is_signed = true }); } -const Type* int16_literal(IrArena* arena, int16_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy16, .value = (uint64_t) (uint16_t) i, .is_signed = true }); } -const Type* int32_literal(IrArena* arena, int32_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy32, .value = (uint64_t) (uint32_t) i, .is_signed = true }); } -const Type* int64_literal(IrArena* arena, int64_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy64, .value = (uint64_t) i, .is_signed = true }); } - -const Type* uint8_literal (IrArena* arena, uint8_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy8, .value = (int64_t) i }); } -const Type* uint16_literal(IrArena* arena, uint16_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy16, .value = (int64_t) i }); } -const Type* uint32_literal(IrArena* arena, uint32_t i) { return int_literal(arena, (IntLiteral) { .width = 
IntTy32, .value = (int64_t) i }); } -const Type* uint64_literal(IrArena* arena, uint64_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy64, .value = i }); } - -const Type* fp16_type(IrArena* arena) { return float_type(arena, (Float) { .width = FloatTy16 }); } -const Type* fp32_type(IrArena* arena) { return float_type(arena, (Float) { .width = FloatTy32 }); } -const Type* fp64_type(IrArena* arena) { return float_type(arena, (Float) { .width = FloatTy64 }); } - -const Node* fp_literal_helper(IrArena* a, FloatSizes size, double value) { - switch (size) { - case FloatTy16: assert(false); break; - case FloatTy32: { - float f = value; - uint64_t bits = 0; - memcpy(&bits, &f, sizeof(f)); - return float_literal(a, (FloatLiteral) { .width = size, .value = bits }); - } - case FloatTy64: { - uint64_t bits = 0; - memcpy(&bits, &value, sizeof(value)); - return float_literal(a, (FloatLiteral) { .width = size, .value = bits }); - } - } -} diff --git a/src/shady/emit/CMakeLists.txt b/src/shady/emit/CMakeLists.txt deleted file mode 100644 index 66d85f8e0..000000000 --- a/src/shady/emit/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_subdirectory(spirv) -add_subdirectory(c) - -target_link_libraries(shady PRIVATE "$") -target_link_libraries(shady PRIVATE "$") diff --git a/src/shady/emit/c/CMakeLists.txt b/src/shady/emit/c/CMakeLists.txt deleted file mode 100644 index 21dd62a62..000000000 --- a/src/shady/emit/c/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_library(shady_c OBJECT - emit_c.c - emit_c_instructions.c - emit_c_signatures.c - emit_c_builtins.c -) -set_property(TARGET shady_c PROPERTY POSITION_INDEPENDENT_CODE ON) - -target_link_libraries(shady_c PUBLIC "api") -target_link_libraries(shady_c PRIVATE "$") -target_link_libraries(shady_c PUBLIC "$") diff --git a/src/shady/emit/c/emit_c.c b/src/shady/emit/c/emit_c.c deleted file mode 100644 index 03c0952ac..000000000 --- a/src/shady/emit/c/emit_c.c +++ /dev/null @@ -1,776 +0,0 @@ -#include "emit_c.h" - 
-#include "portability.h" -#include "dict.h" -#include "log.h" -#include "util.h" - -#include "../../type.h" -#include "../../ir_private.h" -#include "../../compile.h" - -#include "../../transform/ir_gen_helpers.h" - -#include -#include -#include -#include - -#pragma GCC diagnostic error "-Wswitch" - -static void emit_terminator(Emitter* emitter, Printer* block_printer, const Node* terminator); - -CValue to_cvalue(SHADY_UNUSED Emitter* e, CTerm term) { - if (term.value) - return term.value; - if (term.var) - return format_string_arena(e->arena->arena, "(&%s)", term.var); - assert(false); -} - -CAddr deref_term(Emitter* e, CTerm term) { - if (term.value) - return format_string_arena(e->arena->arena, "(*%s)", term.value); - if (term.var) - return term.var; - assert(false); -} - -// TODO: utf8 -static bool is_legal_c_identifier_char(char c) { - if (c >= '0' && c <= '9') - return true; - if (c >= 'a' && c <= 'z') - return true; - if (c >= 'A' && c <= 'Z') - return true; - if (c == '_') - return true; - return false; -} - -String legalize_c_identifier(Emitter* e, String src) { - size_t len = strlen(src); - LARRAY(char, dst, len + 1); - size_t i; - for (i = 0; i < len; i++) { - char c = src[i]; - if (is_legal_c_identifier_char(c)) - dst[i] = c; - else - dst[i] = '_'; - } - dst[i] = '\0'; - // TODO: collision handling using a dict - return string(e->arena, dst); -} - -#include - -static enum { ObjectsList, StringLit, CharsLit } array_insides_helper(Emitter* e, Printer* block_printer, Printer* p, Growy* g, const Node* t, Nodes c) { - if (t->tag == Int_TAG && t->payload.int_type.width == 8) { - uint8_t* tmp = malloc(sizeof(uint8_t) * c.count); - bool ends_zero = false; - for (size_t i = 0; i < c.count; i++) { - tmp[i] = get_int_literal_value(*resolve_to_int_literal(c.nodes[i]), false); - if (tmp[i] == 0) { - if (i == c.count - 1) - ends_zero = true; - } - } - bool is_stringy = ends_zero; - for (size_t i = 0; i < c.count; i++) { - // ignore the last char in a string - if 
(is_stringy && i == c.count - 1) - break; - if (isprint(tmp[i])) - print(p, "%c", tmp[i]); - else - print(p, "\\x%02x", tmp[i]); - } - free(tmp); - return is_stringy ? StringLit : CharsLit; - } else { - for (size_t i = 0; i < c.count; i++) { - print(p, to_cvalue(e, emit_value(e, block_printer, c.nodes[i]))); - if (i + 1 < c.count) - print(p, ", "); - } - growy_append_bytes(g, 1, "\0"); - return ObjectsList; - } -} - -static bool has_forward_declarations(CDialect dialect) { - switch (dialect) { - case C: return true; - case GLSL: // no global variable forward declarations in GLSL - case ISPC: // ISPC seems to share this quirk - return false; - } -} - -static void emit_global_variable_definition(Emitter* emitter, String prefix, String decl_center, const Type* type, bool uniform, bool constant, String init) { - // GLSL wants 'const' to go on the left to start the declaration, but in C const should go on the right (east const convention) - switch (emitter->config.dialect) { - // ISPC defaults to varying, even for constants... 
yuck - case ISPC: - if (uniform) - decl_center = format_string_arena(emitter->arena->arena, "uniform %s", decl_center); - else - decl_center = format_string_arena(emitter->arena->arena, "varying %s", decl_center); - break; - case C: - if (constant) - decl_center = format_string_arena(emitter->arena->arena, "const %s", decl_center); - break; - case GLSL: - if (constant) - prefix = format_string_arena(emitter->arena->arena, "%s %s", "const", prefix); - break; - } - - if (init) - print(emitter->fn_defs, "\n%s%s = %s;", prefix, emit_type(emitter, type, decl_center), init); - else - print(emitter->fn_defs, "\n%s%s;", prefix, emit_type(emitter, type, decl_center)); - - if (!has_forward_declarations(emitter->config.dialect) || !init) - return; - - String declaration = emit_type(emitter, type, decl_center); - print(emitter->fn_decls, "\n%s;", declaration); -} - -CTerm emit_value(Emitter* emitter, Printer* block_printer, const Node* value) { - CTerm* found = lookup_existing_term(emitter, value); - if (found) return *found; - - String emitted = NULL; - - switch (is_value(value)) { - case NotAValue: assert(false); - case Value_ConstrainedValue_TAG: - case Value_UntypedNumber_TAG: error("lower me"); - case Value_Variable_TAG: error("variables need to be emitted beforehand"); - case Value_IntLiteral_TAG: { - if (value->payload.int_literal.is_signed) - emitted = format_string_arena(emitter->arena->arena, "%" PRIi64, value->payload.int_literal.value); - else - emitted = format_string_arena(emitter->arena->arena, "%" PRIu64, value->payload.int_literal.value); - - bool is_long = value->payload.int_literal.width == IntTy64; - bool is_signed = value->payload.int_literal.is_signed; - if (emitter->config.dialect == GLSL) { - if (!is_signed) - emitted = format_string_arena(emitter->arena->arena, "%sU", emitted); - if (is_long) - emitted = format_string_arena(emitter->arena->arena, "%sL", emitted); - } - - break; - } - case Value_FloatLiteral_TAG: { - uint64_t v = 
value->payload.float_literal.value; - switch (value->payload.float_literal.width) { - case FloatTy16: - assert(false); - case FloatTy32: { - float f; - memcpy(&f, &v, sizeof(uint32_t)); - double d = (double) f; - emitted = format_string_arena(emitter->arena->arena, "%.9g", d); break; - } - case FloatTy64: { - double d; - memcpy(&d, &v, sizeof(uint64_t)); - emitted = format_string_arena(emitter->arena->arena, "%.17g", d); break; - } - } - break; - } - case Value_True_TAG: return term_from_cvalue("true"); - case Value_False_TAG: return term_from_cvalue("false"); - case Value_Undef_TAG: { - if (emitter->config.dialect == GLSL) - return emit_value(emitter, block_printer, get_default_zero_value(emitter->arena, value->payload.undef.type)); - String name = unique_name(emitter->arena, "undef"); - emit_global_variable_definition(emitter, "", name, value->payload.undef.type, true, true, NULL); - emitted = name; - break; - } - case Value_NullPtr_TAG: return term_from_cvalue("NULL"); - case Value_Composite_TAG: { - const Type* type = value->payload.composite.type; - Nodes elements = value->payload.composite.contents; - - Growy* g = new_growy(); - Printer* p = open_growy_as_printer(g); - - if (type->tag == ArrType_TAG) { - switch (array_insides_helper(emitter, block_printer, p, g, type, elements)) { - case ObjectsList: - emitted = growy_data(g); - break; - case StringLit: - emitted = format_string_arena(emitter->arena->arena, "\"%s\"", growy_data(g)); - break; - case CharsLit: - emitted = format_string_arena(emitter->arena->arena, "'%s'", growy_data(g)); - break; - } - } else { - for (size_t i = 0; i < elements.count; i++) { - print(p, "%s", to_cvalue(emitter, emit_value(emitter, block_printer, elements.nodes[i]))); - if (i + 1 < elements.count) - print(p, ", "); - } - emitted = growy_data(g); - } - growy_append_bytes(g, 1, "\0"); - - switch (emitter->config.dialect) { - no_compound_literals: - case ISPC: { - // arrays need double the brackets - if (type->tag == ArrType_TAG) - 
emitted = format_string_arena(emitter->arena->arena, "{ %s }", emitted); - - if (block_printer) { - String tmp = unique_name(emitter->arena, "composite"); - print(block_printer, "\n%s = { %s };", emit_type(emitter, value->type, tmp), emitted); - emitted = tmp; - } else { - // this requires us to end up in the initialisation side of a declaration - emitted = format_string_arena(emitter->arena->arena, "{ %s }", emitted); - } - break; - } - case C: - // If we're C89 (ew) - if (!emitter->config.allow_compound_literals) - goto no_compound_literals; - emitted = format_string_arena(emitter->arena->arena, "((%s) { %s })", emit_type(emitter, value->type, NULL), emitted); - break; - case GLSL: - if (type->tag != PackType_TAG) - goto no_compound_literals; - // GLSL doesn't have compound literals, but it does have constructor syntax for vectors - emitted = format_string_arena(emitter->arena->arena, "%s(%s)", emit_type(emitter, value->type, NULL), emitted); - break; - } - - destroy_growy(g); - destroy_printer(p); - break; - } - case Value_Fill_TAG: error("lower me") - case Value_StringLiteral_TAG: { - Growy* g = new_growy(); - Printer* p = open_growy_as_printer(g); - - String str = value->payload.string_lit.string; - size_t len = strlen(str); - for (size_t i = 0; i < len; i++) { - char c = str[i]; - switch (c) { - case '\n': print(p, "\\n"); - break; - default: - growy_append_bytes(g, 1, &c); - } - } - growy_append_bytes(g, 1, "\0"); - - emitted = format_string_arena(emitter->arena->arena, "\"%s\"", growy_data(g)); - destroy_growy(g); - destroy_printer(p); - break; - } - case Value_FnAddr_TAG: { - emitted = legalize_c_identifier(emitter, get_decl_name(value->payload.fn_addr.fn)); - emitted = format_string_arena(emitter->arena->arena, "(&%s)", emitted); - break; - } - case Value_RefDecl_TAG: { - const Node* decl = value->payload.ref_decl.decl; - emit_decl(emitter, decl); - - if (emitter->config.dialect == ISPC && decl->tag == GlobalVariable_TAG) { - if 
(!is_addr_space_uniform(emitter->arena, decl->payload.global_variable.address_space) && !is_decl_builtin(decl)) { - assert(block_printer && "ISPC backend cannot statically refer to a varying variable"); - return ispc_varying_ptr_helper(emitter, block_printer, decl->type, *lookup_existing_term(emitter, decl)); - } - } - - return *lookup_existing_term(emitter, decl); - } - } - - assert(emitted); - return term_from_cvalue(emitted); -} - -/// hack for ISPC: there is no nice way to get a set of varying pointers (instead of a "pointer to a varying") pointing to a varying global -CTerm ispc_varying_ptr_helper(Emitter* emitter, Printer* block_printer, const Type* ptr_type, CTerm term) { - String interm = unique_name(emitter->arena, "intermediary_ptr_value"); - const Type* ut = qualified_type_helper(ptr_type, true); - const Type* vt = qualified_type_helper(ptr_type, false); - String lhs = emit_type(emitter, vt, interm); - print(block_printer, "\n%s = ((%s) %s) + programIndex;", lhs, emit_type(emitter, ut, NULL), to_cvalue(emitter, term)); - return term_from_cvalue(interm); -} - -void emit_variable_declaration(Emitter* emitter, Printer* block_printer, const Type* t, String variable_name, bool mut, const CTerm* initializer) { - assert((mut || initializer != NULL) && "unbound results are only allowed when creating a mutable local variable"); - - String prefix = ""; - String center = variable_name; - - // add extra qualifiers if immutable - if (!mut) switch (emitter->config.dialect) { - case ISPC: - center = format_string_arena(emitter->arena->arena, "const %s", center); - break; - case C: - prefix = "register "; - center = format_string_arena(emitter->arena->arena, "const %s", center); - break; - case GLSL: - prefix = "const "; - break; - } - - String decl = c_emit_type(emitter, t, center); - if (initializer) - print(block_printer, "\n%s%s = %s;", prefix, decl, to_cvalue(emitter, *initializer)); - else - print(block_printer, "\n%s%s;", prefix, decl); -} - -static void 
emit_terminator(Emitter* emitter, Printer* block_printer, const Node* terminator) { - switch (is_terminator(terminator)) { - case NotATerminator: assert(false); - case LetMut_TAG: - case Join_TAG: error("this must be lowered away!"); - case Jump_TAG: - case Branch_TAG: - case Switch_TAG: - case TailCall_TAG: error("TODO"); - case Let_TAG: { - const Node* instruction = get_let_instruction(terminator); - - // we declare N local variables in order to store the result of the instruction - Nodes yield_types = unwrap_multiple_yield_types(emitter->arena, instruction->type); - - LARRAY(CTerm, results, yield_types.count); - LARRAY(InstrResultBinding, bindings, yield_types.count); - InstructionOutputs ioutputs = { - .count = yield_types.count, - .results = results, - .binding = bindings, - }; - emit_instruction(emitter, block_printer, instruction, ioutputs); - - const Node* tail = get_let_tail(terminator); - assert(tail->tag == Case_TAG); - - const Nodes tail_params = tail->payload.case_.params; - assert(tail_params.count == yield_types.count); - for (size_t i = 0; i < yield_types.count; i++) { - bool has_result = results[i].value || results[i].var; - switch (bindings[i]) { - case NoBinding: { - assert(has_result && "unbound results can't be empty"); - register_emitted(emitter, tail_params.nodes[i], results[i]); - break; - } - case LetBinding: { - String variable_name = get_value_name(tail_params.nodes[i]); - - String bind_to; - if (variable_name) - bind_to = format_string_arena(emitter->arena->arena, "%s_%d", legalize_c_identifier(emitter, variable_name), fresh_id(emitter->arena)); - else - bind_to = format_string_arena(emitter->arena->arena, "v%d", fresh_id(emitter->arena)); - - const Type* t = yield_types.nodes[i]; - - if (has_result) - emit_variable_declaration(emitter, block_printer, t, bind_to, false, &results[i]); - else - emit_variable_declaration(emitter, block_printer, t, bind_to, false, NULL); - - register_emitted(emitter, tail_params.nodes[i], 
term_from_cvalue(bind_to)); - break; - } - default: assert(false); - } - } - emit_terminator(emitter, block_printer, tail->payload.case_.body); - - break; - } - case Terminator_Return_TAG: { - Nodes args = terminator->payload.fn_ret.args; - if (args.count == 0) { - print(block_printer, "\nreturn;"); - } else if (args.count == 1) { - print(block_printer, "\nreturn %s;", to_cvalue(emitter, emit_value(emitter, block_printer, args.nodes[0]))); - } else { - String packed = unique_name(emitter->arena, "pack_return"); - LARRAY(CValue, values, args.count); - for (size_t i = 0; i < args.count; i++) - values[i] = to_cvalue(emitter, emit_value(emitter, block_printer, args.nodes[i])); - emit_pack_code(block_printer, strings(emitter->arena, args.count, values), packed); - print(block_printer, "\nreturn %s;", packed); - } - break; - } - case Yield_TAG: { - Nodes args = terminator->payload.yield.args; - Phis phis = emitter->phis.selection; - assert(phis.count == args.count); - for (size_t i = 0; i < phis.count; i++) - print(block_printer, "\n%s = %s;", phis.strings[i], to_cvalue(emitter, emit_value(emitter, block_printer, args.nodes[i]))); - - break; - } - case MergeContinue_TAG: { - Nodes args = terminator->payload.merge_continue.args; - Phis phis = emitter->phis.loop_continue; - assert(phis.count == args.count); - for (size_t i = 0; i < phis.count; i++) - print(block_printer, "\n%s = %s;", phis.strings[i], to_cvalue(emitter, emit_value(emitter, block_printer, args.nodes[i]))); - print(block_printer, "\ncontinue;"); - break; - } - case MergeBreak_TAG: { - Nodes args = terminator->payload.merge_break.args; - Phis phis = emitter->phis.loop_break; - assert(phis.count == args.count); - for (size_t i = 0; i < phis.count; i++) - print(block_printer, "\n%s = %s;", phis.strings[i], to_cvalue(emitter, emit_value(emitter, block_printer, args.nodes[i]))); - print(block_printer, "\nbreak;"); - break; - } - case Terminator_Unreachable_TAG: { - switch (emitter->config.dialect) { - case C: - 
print(block_printer, "\n__builtin_unreachable();"); - break; - case ISPC: - print(block_printer, "\nassert(false);"); - break; - case GLSL: - print(block_printer, "\n//unreachable"); - break; - } - break; - } - } -} - -void emit_lambda_body_at(Emitter* emitter, Printer* p, const Node* body, const Nodes* bbs) { - assert(is_terminator(body)); - //print(p, "{"); - indent(p); - - emit_terminator(emitter, p, body); - - if (bbs && bbs->count > 0) { - assert(emitter->config.dialect != GLSL); - error("TODO"); - } - - deindent(p); - print(p, "\n"); -} - -String emit_lambda_body(Emitter* emitter, const Node* body, const Nodes* bbs) { - Growy* g = new_growy(); - Printer* p = open_growy_as_printer(g); - emit_lambda_body_at(emitter, p, body, bbs); - growy_append_bytes(g, 1, (char[]) { 0 }); - return printer_growy_unwrap(p); -} - -void emit_decl(Emitter* emitter, const Node* decl) { - assert(is_declaration(decl)); - - CTerm* found = lookup_existing_term(emitter, decl); - if (found) return; - - CType* found2 = lookup_existing_type(emitter, decl); - if (found2) return; - - const char* name = legalize_c_identifier(emitter, get_decl_name(decl)); - const Type* decl_type = decl->type; - const char* decl_center = name; - CTerm emit_as; - - switch (decl->tag) { - case GlobalVariable_TAG: { - String init = NULL; - if (decl->payload.global_variable.init) - init = to_cvalue(emitter, emit_value(emitter, NULL, decl->payload.global_variable.init)); - - const GlobalVariable* gvar = &decl->payload.global_variable; - if (is_decl_builtin(decl)) { - Builtin b = get_decl_builtin(decl); - register_emitted(emitter, decl, emit_c_builtin(emitter, b)); - return; - } - - decl_type = decl->payload.global_variable.type; - // we emit the global variable as a CVar, so we can refer to it's 'address' without explicit ptrs - emit_as = term_from_cvar(name); - - bool uniform = is_addr_space_uniform(emitter->arena, decl->payload.global_variable.address_space); - - String address_space_prefix = NULL; - switch 
(decl->payload.global_variable.address_space) { - case AsGeneric: - break; - case AsSubgroupLogical: - case AsSubgroupPhysical: - switch (emitter->config.dialect) { - case C: - case GLSL: - warn_print("C and GLSL do not have a 'subgroup' level addressing space, using shared instead"); - address_space_prefix = "shared "; - break; - case ISPC: - address_space_prefix = ""; - break; - } - break; - case AsPrivatePhysical: - case AsPrivateLogical: - address_space_prefix = ""; - case AsGlobalLogical: - case AsGlobalPhysical: - address_space_prefix = ""; - break; - case AsSharedPhysical: - case AsSharedLogical: - switch (emitter->config.dialect) { - case C: - break; - case GLSL: - address_space_prefix = "shared "; - break; - case ISPC: - // ISPC doesn't really know what "shared" is - break; - } - break; - case AsExternal: - address_space_prefix = "extern "; - break; - case AsInput: - case AsUInput: - address_space_prefix = "in "; - break; - case AsOutput: - address_space_prefix = "out "; - break; - case AsUniform: - case AsImage: - case AsUniformConstant: - case AsShaderStorageBufferObject: - case AsFunctionLogical: - case AsPushConstant: - break; // error("These only make sense for SPIR-V !") - default: error("Unhandled address space"); - } - - if (!address_space_prefix) { - warn_print("No known address space prefix for as %d, this might produce broken code\n", decl->payload.global_variable.address_space); - address_space_prefix = ""; - } - - register_emitted(emitter, decl, emit_as); - - emit_global_variable_definition(emitter, address_space_prefix, decl_center, decl_type, uniform, false, init); - return; - } - case Function_TAG: { - emit_as = term_from_cvalue(name); - register_emitted(emitter, decl, emit_as); - String head = emit_fn_head(emitter, decl->type, name, decl); - const Node* body = decl->payload.fun.body; - if (body) { - for (size_t i = 0; i < decl->payload.fun.params.count; i++) { - String param_name; - String variable_name = 
get_value_name(decl->payload.fun.params.nodes[i]); - if (variable_name) - param_name = format_string_arena(emitter->arena->arena, "%s_%d", legalize_c_identifier(emitter, variable_name), decl->payload.fun.params.nodes[i]->payload.var.id); - else - param_name = format_string_arena(emitter->arena->arena, "p%d", decl->payload.fun.params.nodes[i]->payload.var.id); - register_emitted(emitter, decl->payload.fun.params.nodes[i], term_from_cvalue(param_name)); - } - - String fn_body = emit_lambda_body(emitter, body, NULL); - String free_me = fn_body; - if (emitter->config.dialect == ISPC) { - // ISPC hack: This compiler (like seemingly all LLVM-based compilers) has broken handling of the execution mask - it fails to generated masked stores for the entry BB of a function that may be called non-uniformingly - // therefore we must tell ISPC to please, pretty please, mask everything by branching on what the mask should be - fn_body = format_string_arena(emitter->arena->arena, "if ((lanemask() >> programIndex) & 1u) { %s}", fn_body); - // I hate everything about this too. 
- } - print(emitter->fn_defs, "\n%s { %s }", head, fn_body); - free_tmp_str(free_me); - } - - print(emitter->fn_decls, "\n%s;", head); - return; - } - case Constant_TAG: { - emit_as = term_from_cvalue(name); - register_emitted(emitter, decl, emit_as); - - const Node* init_value = get_quoted_value(decl->payload.constant.instruction); - assert(init_value && "TODO: support some measure of constant expressions"); - String init = to_cvalue(emitter, emit_value(emitter, NULL, init_value)); - emit_global_variable_definition(emitter, "", decl_center, decl->type, true, true, init); - return; - } - case NominalType_TAG: { - CType emitted = name; - register_emitted_type(emitter, decl, emitted); - switch (emitter->config.dialect) { - case ISPC: - case C: print(emitter->type_decls, "\ntypedef %s;", emit_type(emitter, decl->payload.nom_type.body, emitted)); break; - case GLSL: emit_nominal_type_body(emitter, format_string_arena(emitter->arena->arena, "struct %s /* nominal */", emitted), decl->payload.nom_type.body); break; - } - return; - } - default: error("not a decl"); - } -} - -void register_emitted(Emitter* emitter, const Node* node, CTerm as) { - assert(as.value || as.var); - insert_dict(const Node*, CTerm, emitter->emitted_terms, node, as); -} - -void register_emitted_type(Emitter* emitter, const Node* node, String as) { - insert_dict(const Node*, String, emitter->emitted_types, node, as); -} - -CTerm* lookup_existing_term(Emitter* emitter, const Node* node) { - CTerm* found = find_value_dict(const Node*, CTerm, emitter->emitted_terms, node); - return found; -} - -CType* lookup_existing_type(Emitter* emitter, const Type* node) { - CType* found = find_value_dict(const Node*, CType, emitter->emitted_types, node); - return found; -} - -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); - -static Module* run_backend_specific_passes(CompilerConfig* config, CEmitterConfig* econfig, Module* initial_mod) { - IrArena* initial_arena = initial_mod->arena; - Module* 
old_mod = NULL; - Module** pmod = &initial_mod; - - if (econfig->dialect == ISPC) { - RUN_PASS(lower_workgroups) - } - if (econfig->dialect != GLSL) { - RUN_PASS(lower_vec_arr) - } - if (config->lower.simt_to_explicit_simd) { - RUN_PASS(simt2d) - } - // C lacks a nice way to express constants that can be used in type definitions afterwards, so let's just inline them all. - RUN_PASS(eliminate_constants) - return *pmod; -} - -void emit_c(CompilerConfig compiler_config, CEmitterConfig config, Module* mod, size_t* output_size, char** output, Module** new_mod) { - IrArena* initial_arena = get_module_arena(mod); - mod = run_backend_specific_passes(&compiler_config, &config, mod); - IrArena* arena = get_module_arena(mod); - - Growy* type_decls_g = new_growy(); - Growy* fn_decls_g = new_growy(); - Growy* fn_defs_g = new_growy(); - - Emitter emitter = { - .config = config, - .arena = arena, - .type_decls = open_growy_as_printer(type_decls_g), - .fn_decls = open_growy_as_printer(fn_decls_g), - .fn_defs = open_growy_as_printer(fn_defs_g), - .emitted_terms = new_dict(Node*, CTerm, (HashFn) hash_node, (CmpFn) compare_node), - .emitted_types = new_dict(Node*, String, (HashFn) hash_node, (CmpFn) compare_node), - }; - - Nodes decls = get_module_declarations(mod); - for (size_t i = 0; i < decls.count; i++) - emit_decl(&emitter, decls.nodes[i]); - - destroy_printer(emitter.type_decls); - destroy_printer(emitter.fn_decls); - destroy_printer(emitter.fn_defs); - - Growy* final = new_growy(); - Printer* finalp = open_growy_as_printer(final); - - if (emitter.config.dialect == GLSL) { - print(finalp, "#version 420\n"); - } - - print(finalp, "/* file generated by shady */\n"); - - switch (emitter.config.dialect) { - case ISPC: - break; - case C: - print(finalp, "\n#include "); - print(finalp, "\n#include "); - print(finalp, "\n#include "); - print(finalp, "\n#include "); - print(finalp, "\n#include "); - break; - case GLSL: - print(finalp, "#extension GL_ARB_gpu_shader_int64: require\n"); 
- print(finalp, "#define ubyte uint\n"); - print(finalp, "#define uchar uint\n"); - print(finalp, "#define ulong uint\n"); - break; - } - - print(finalp, "\n/* types: */\n"); - growy_append_bytes(final, growy_size(type_decls_g), growy_data(type_decls_g)); - - print(finalp, "\n/* declarations: */\n"); - growy_append_bytes(final, growy_size(fn_decls_g), growy_data(fn_decls_g)); - - print(finalp, "\n/* definitions: */\n"); - growy_append_bytes(final, growy_size(fn_defs_g), growy_data(fn_defs_g)); - - print(finalp, "\n"); - print(finalp, "\n"); - print(finalp, "\n"); - - destroy_growy(type_decls_g); - destroy_growy(fn_decls_g); - destroy_growy(fn_defs_g); - - destroy_dict(emitter.emitted_types); - destroy_dict(emitter.emitted_terms); - - *output_size = growy_size(final); - *output = growy_deconstruct(final); - destroy_printer(finalp); - - if (new_mod) - *new_mod = mod; - else if (initial_arena != arena) - destroy_ir_arena(arena); -} diff --git a/src/shady/emit/c/emit_c.h b/src/shady/emit/c/emit_c.h deleted file mode 100644 index 48e37725c..000000000 --- a/src/shady/emit/c/emit_c.h +++ /dev/null @@ -1,90 +0,0 @@ -#ifndef SHADY_EMIT_C -#define SHADY_EMIT_C - -#include "shady/ir.h" -#include "shady/builtins.h" -#include "growy.h" -#include "arena.h" -#include "printer.h" - -#define emit_type c_emit_type -#define emit_value c_emit_value -#define emit_instruction c_emit_instruction -#define emit_lambda_body c_emit_lambda_body -#define emit_decl c_emit_decl -#define emit_nominal_type_body c_emit_nominal_type_body - -/// SSA-like things, you can read them -typedef String CValue; -/// non-SSA like things, they represent addresses -typedef String CAddr; - -typedef String CType; - -typedef struct { - CValue value; - CAddr var; -} CTerm; - -#define term_from_cvalue(t) (CTerm) { .value = t } -#define term_from_cvar(t) (CTerm) { .var = t } - -typedef Strings Phis; - -typedef struct { - CEmitterConfig config; - IrArena* arena; - Printer *type_decls, *fn_decls, *fn_defs; - struct { - 
Phis selection, loop_continue, loop_break; - } phis; - - struct Dict* emitted_terms; - struct Dict* emitted_types; -} Emitter; - -void register_emitted(Emitter*, const Node*, CTerm); -void register_emitted_type(Emitter*, const Type*, String); - -CTerm* lookup_existing_term(Emitter* emitter, const Node*); -CType* lookup_existing_type(Emitter* emitter, const Type*); - -CValue to_cvalue(Emitter*, CTerm); -CAddr deref_term(Emitter*, CTerm); - -void emit_decl(Emitter* emitter, const Node* decl); -CType emit_type(Emitter* emitter, const Type*, const char* identifier); -String emit_fn_head(Emitter* emitter, const Node* fn_type, String center, const Node* fn); -void emit_nominal_type_body(Emitter* emitter, String name, const Type* type); -void emit_variable_declaration(Emitter* emitter, Printer* block_printer, const Type* t, String variable_name, bool mut, const CTerm* initializer); - -CTerm emit_value(Emitter* emitter, Printer*, const Node* value); -CTerm emit_c_builtin(Emitter*, Builtin); - -String legalize_c_identifier(Emitter*, String); -String get_record_field_name(const Type* t, size_t i); -CTerm ispc_varying_ptr_helper(Emitter* emitter, Printer* block_printer, const Type* ptr_type, CTerm term); - -typedef enum { NoBinding, LetBinding } InstrResultBinding; - -typedef struct { - size_t count; - CTerm* results; - /// What to do with the results at the call site - InstrResultBinding* binding; -} InstructionOutputs; - -void emit_instruction(Emitter* emitter, Printer* p, const Node* instruction, InstructionOutputs); -String emit_lambda_body (Emitter*, const Node*, const Nodes* nested_basic_blocks); -void emit_lambda_body_at(Emitter*, Printer*, const Node*, const Nodes* nested_basic_blocks); - -void emit_pack_code(Printer*, Strings, String dst); -void emit_unpack_code(Printer*, String src, Strings dst); - -#define free_tmp_str(s) free((char*) (s)) - -inline static bool is_glsl_scalar_type(const Type* t) { - return t->tag == Bool_TAG || t->tag == Int_TAG || t->tag == 
Float_TAG; -} - -#endif diff --git a/src/shady/emit/c/emit_c_instructions.c b/src/shady/emit/c/emit_c_instructions.c deleted file mode 100644 index 21d7cb9f0..000000000 --- a/src/shady/emit/c/emit_c_instructions.c +++ /dev/null @@ -1,769 +0,0 @@ -#include "emit_c.h" - -#include "portability.h" -#include "log.h" -#include "dict.h" -#include "util.h" - -#include "../../type.h" -#include "../../ir_private.h" - -#include -#include - -#pragma GCC diagnostic error "-Wswitch" - -void emit_pack_code(Printer* p, Strings src, String dst) { - for (size_t i = 0; i < src.count; i++) { - print(p, "\n%s->_%d = %s", dst, src.strings[i], i); - } -} - -void emit_unpack_code(Printer* p, String src, Strings dst) { - for (size_t i = 0; i < dst.count; i++) { - print(p, "\n%s = %s->_%d", dst.strings[i], src, i); - } -} - -static Strings emit_variable_declarations(Emitter* emitter, Printer* p, String given_name, Strings* given_names, Nodes types, bool mut, const Nodes* init_values) { - if (given_names) - assert(given_names->count == types.count); - if (init_values) - assert(init_values->count == types.count); - LARRAY(String, names, types.count); - for (size_t i = 0; i < types.count; i++) { - VarId id = fresh_id(emitter->arena); - String name = given_names ? 
given_names->strings[i] : given_name; - assert(name); - names[i] = format_string_arena(emitter->arena->arena, "%s_%d", name, id); - if (init_values) { - CTerm initializer = emit_value(emitter, p, init_values->nodes[i]); - emit_variable_declaration(emitter, p, types.nodes[i], names[i], mut, &initializer); - } else - emit_variable_declaration(emitter, p, types.nodes[i], names[i], mut, NULL); - } - return strings(emitter->arena, types.count, names); -} - -static const Type* get_first_op_scalar_type(Nodes ops) { - const Type* t = first(ops)->type; - deconstruct_qualified_type(&t); - deconstruct_maybe_packed_type(&t); - return t; -} - -typedef enum { - OsInfix, OsPrefix, OsCall, -} OpStyle; - -typedef enum { - IsNone, // empty entry - IsMono, - IsPoly -} ISelMechanism; - -typedef struct { - ISelMechanism isel_mechanism; - OpStyle style; - String op; - String u_ops[4]; - String s_ops[4]; - String f_ops[3]; -} ISelTableEntry; - -static const ISelTableEntry isel_table[PRIMOPS_COUNT] = { - [add_op] = { IsMono, OsInfix, "+" }, - [sub_op] = { IsMono, OsInfix, "-" }, - [mul_op] = { IsMono, OsInfix, "*" }, - [div_op] = { IsMono, OsInfix, "/" }, - [mod_op] = { IsMono, OsInfix, "%" }, - [neg_op] = { IsMono, OsPrefix, "-" }, - [gt_op] = { IsMono, OsInfix, ">" }, - [gte_op] = { IsMono, OsInfix, ">=" }, - [lt_op] = { IsMono, OsInfix, "<" }, - [lte_op] = { IsMono, OsInfix, "<=" }, - [eq_op] = { IsMono, OsInfix, "==" }, - [neq_op] = { IsMono, OsInfix, "!=" }, - [and_op] = { IsMono, OsInfix, "&" }, - [or_op] = { IsMono, OsInfix, "|" }, - [xor_op] = { IsMono, OsInfix, "^" }, - [not_op] = { IsMono, OsPrefix, "!" 
}, - [rshift_arithm_op] = { IsMono, OsInfix, ">>" }, - [rshift_logical_op] = { IsMono, OsInfix, ">>" }, // TODO achieve desired right shift semantics through unsigned/signed casts - [lshift_op] = { IsMono, OsInfix, "<<" }, -}; - -static const ISelTableEntry isel_table_c[PRIMOPS_COUNT] = { - [abs_op] = { IsPoly, OsCall, .s_ops = { "abs", "abs", "abs", "llabs" }, .f_ops = {"fabsf", "fabsf", "fabs"}}, - - [sin_op] = { IsPoly, OsCall, .f_ops = {"sinf", "sinf", "sin"}}, - [cos_op] = { IsPoly, OsCall, .f_ops = {"cosf", "cosf", "cos"}}, - [floor_op] = { IsPoly, OsCall, .f_ops = {"floorf", "floorf", "floor"}}, - [ceil_op] = { IsPoly, OsCall, .f_ops = {"ceilf", "ceilf", "ceil"}}, - [round_op] = { IsPoly, OsCall, .f_ops = {"roundf", "roundf", "round"}}, - - [sqrt_op] = { IsPoly, OsCall, .f_ops = {"sqrtf", "sqrtf", "sqrt"}}, - [exp_op] = { IsPoly, OsCall, .f_ops = {"expf", "expf", "exp"}}, - [pow_op] = { IsPoly, OsCall, .f_ops = {"powf", "powf", "pow"}}, -}; - -static const ISelTableEntry isel_table_glsl[PRIMOPS_COUNT] = { 0 }; - -static const ISelTableEntry isel_table_ispc[PRIMOPS_COUNT] = { - [abs_op] = { IsMono, OsCall, "abs" }, - - [sin_op] = { IsMono, OsCall, "sin" }, - [cos_op] = { IsMono, OsCall, "cos" }, - [floor_op] = { IsMono, OsCall, "floor" }, - [ceil_op] = { IsMono, OsCall, "ceil" }, - [round_op] = { IsMono, OsCall, "round" }, - - [sqrt_op] = { IsMono, OsCall, "sqrt" }, - [exp_op] = { IsMono, OsCall, "exp" }, - [pow_op] = { IsMono, OsCall, "pow" }, - - [subgroup_active_mask_op] = { IsMono, OsCall, "lanemask" }, - [subgroup_ballot_op] = { IsMono, OsCall, "packmask" }, - [subgroup_reduce_sum_op] = { IsMono, OsCall, "reduce_add" }, -}; - -static bool emit_using_entry(CTerm* out, Emitter* emitter, Printer* p, const ISelTableEntry* entry, Nodes operands) { - String operator_str = NULL; - switch (entry->isel_mechanism) { - case IsNone: return false; - case IsMono: operator_str = entry->op; break; - case IsPoly: { - const Type* t = get_first_op_scalar_type(operands); - 
if (t->tag == Float_TAG) - operator_str = entry->f_ops[t->payload.float_type.width]; - else if (t->tag == Int_TAG && t->payload.int_type.is_signed) - operator_str = entry->s_ops[t->payload.int_type.width]; - else if (t->tag == Int_TAG) - operator_str = entry->u_ops[t->payload.int_type.width]; - break; - } - } - - if (!operator_str) - return false; - - switch (entry->style) { - case OsInfix: { - CTerm a = emit_value(emitter, p, operands.nodes[0]); - CTerm b = emit_value(emitter, p, operands.nodes[1]); - *out = term_from_cvalue(format_string_arena(emitter->arena->arena, "%s %s %s", to_cvalue(emitter, a), operator_str, to_cvalue(emitter, b))); - break; - } - case OsPrefix: { - CTerm operand = emit_value(emitter, p, operands.nodes[0]); - *out = term_from_cvalue(format_string_arena(emitter->arena->arena, "%s%s", operator_str, to_cvalue(emitter, operand))); - break; - } - case OsCall: { - LARRAY(CTerm, cops, operands.count); - for (size_t i = 0; i < operands.count; i++) - cops[i] = emit_value(emitter, p, operands.nodes[i]); - if (operands.count == 1) - *out = term_from_cvalue(format_string_arena(emitter->arena->arena, "%s(%s)", operator_str, to_cvalue(emitter, cops[0]))); - else { - Growy* g = new_growy(); - growy_append_string(g, operator_str); - growy_append_string_literal(g, "("); - for (size_t i = 0; i < operands.count; i++) { - growy_append_string(g, to_cvalue(emitter, cops[i])); - if (i + 1 < operands.count) - growy_append_string_literal(g, ", "); - } - growy_append_string_literal(g, ")"); - *out = term_from_cvalue(growy_deconstruct(g)); - } - break; - } - } - return true; -} - -static const ISelTableEntry* lookup_entry(Emitter* emitter, Op op) { - const ISelTableEntry* isel_entry = NULL; - switch (emitter->config.dialect) { - case C: isel_entry = &isel_table_c[op]; break; - case GLSL: isel_entry = &isel_table_glsl[op]; break; - case ISPC: isel_entry = &isel_table_ispc[op]; break; - } - if (isel_entry->isel_mechanism == IsNone) - isel_entry = &isel_table[op]; - 
return isel_entry; -} - -static void emit_primop(Emitter* emitter, Printer* p, const Node* node, InstructionOutputs outputs) { - assert(node->tag == PrimOp_TAG); - IrArena* arena = emitter->arena; - const PrimOp* prim_op = &node->payload.prim_op; - CTerm term = term_from_cvalue(format_string_interned(emitter->arena, "/* todo %s */", get_primop_name(prim_op->op))); - const ISelTableEntry* isel_entry = lookup_entry(emitter, prim_op->op); - switch (prim_op->op) { - case deref_op: - case assign_op: - case subscript_op: assert(false); - case quote_op: { - assert(outputs.count == 1); - for (size_t i = 0; i < prim_op->operands.count; i++) { - outputs.results[i] = emit_value(emitter, p, prim_op->operands.nodes[i]); - outputs.binding[i] = NoBinding; - } - break; - } - case add_carry_op: - case sub_borrow_op: - case mul_extended_op: - error("TODO: implement extended arithm ops in C"); - break; - // MATH OPS - case fract_op: { - CTerm floored; - emit_using_entry(&floored, emitter, p, lookup_entry(emitter, floor_op), prim_op->operands); - term = term_from_cvalue(format_string_arena(arena->arena, "1 - %s", to_cvalue(emitter, floored))); - break; - } - case inv_sqrt_op: { - CTerm floored; - emit_using_entry(&floored, emitter, p, lookup_entry(emitter, sqrt_op), prim_op->operands); - term = term_from_cvalue(format_string_arena(arena->arena, "1.0f / %s", to_cvalue(emitter, floored))); - break; - } - case min_op: { - CValue a = to_cvalue(emitter, emit_value(emitter, p, first(prim_op->operands))); - CValue b = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[1])); - term = term_from_cvalue(format_string_arena(arena->arena, "(%s > %s ? %s : %s)", a, b, b, a)); - break; - } - case max_op: { - CValue a = to_cvalue(emitter, emit_value(emitter, p, first(prim_op->operands))); - CValue b = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[1])); - term = term_from_cvalue(format_string_arena(arena->arena, "(%s > %s ? 
%s : %s)", a, b, a, b)); - break; - } - case sign_op: { - CValue src = to_cvalue(emitter, emit_value(emitter, p, first(prim_op->operands))); - term = term_from_cvalue(format_string_arena(arena->arena, "(%s > 0 ? 1 : -1)", src)); - break; - } - case alloca_subgroup_op: error("Lower me"); - case alloca_op: - case alloca_logical_op: { - assert(outputs.count == 1); - String variable_name = unique_name(emitter->arena, "alloca"); - CTerm variable = (CTerm) { .value = NULL, .var = variable_name }; - emit_variable_declaration(emitter, p, first(prim_op->type_arguments), variable_name, true, NULL); - outputs.results[0] = variable; - if (emitter->config.dialect == ISPC) { - outputs.results[0] = ispc_varying_ptr_helper(emitter, p, get_unqualified_type(node->type), variable); - } - outputs.binding[0] = NoBinding; - return; - } - case load_op: { - CAddr dereferenced = deref_term(emitter, emit_value(emitter, p, first(prim_op->operands))); - outputs.results[0] = term_from_cvalue(dereferenced); - outputs.binding[0] = LetBinding; - return; - } - case store_op: { - const Node* addr = first(prim_op->operands); - const Node* value = prim_op->operands.nodes[1]; - const Type* addr_type = addr->type; - bool addr_uniform = deconstruct_qualified_type(&addr_type); - bool value_uniform = is_qualified_type_uniform(value->type); - assert(addr_type->tag == PtrType_TAG); - CAddr dereferenced = deref_term(emitter, emit_value(emitter, p, addr)); - CValue cvalue = to_cvalue(emitter, emit_value(emitter, p, value)); - // ISPC lets you broadcast to a uniform address space iff the address is non-uniform, otherwise we need to do this - if (emitter->config.dialect == ISPC && addr_uniform && is_addr_space_uniform(arena, addr_type->payload.ptr_type.address_space) && !value_uniform) - cvalue = format_string_arena(emitter->arena->arena, "extract(%s, count_trailing_zeros(lanemask()))", cvalue); - - print(p, "\n%s = %s;", dereferenced, cvalue); - return; - } case lea_op: { - CTerm acc = emit_value(emitter, p, 
prim_op->operands.nodes[0]); - - const Type* src_qtype = prim_op->operands.nodes[0]->type; - bool uniform = is_qualified_type_uniform(src_qtype); - const Type* curr_ptr_type = get_unqualified_type(src_qtype); - assert(curr_ptr_type->tag == PtrType_TAG); - - const IntLiteral* offset_static_value = resolve_to_int_literal(prim_op->operands.nodes[1]); - if (!offset_static_value || offset_static_value->value != 0) { - CTerm offset = emit_value(emitter, p, prim_op->operands.nodes[1]); - // we sadly need to drop to the value level (aka explicit pointer arithmetic) to do this - // this means such code is never going to be legal in GLSL - // also the cast is to account for our arrays-in-structs hack - acc = term_from_cvalue(format_string_arena(arena->arena, "((%s) &(%s.arr[%s]))", emit_type(emitter, curr_ptr_type, NULL), deref_term(emitter, acc), to_cvalue(emitter, offset))); - uniform &= is_qualified_type_uniform(prim_op->operands.nodes[1]->type); - } - - //t = t->payload.ptr_type.pointed_type; - for (size_t i = 2; i < prim_op->operands.count; i++) { - const Type* pointee_type = get_pointee_type(arena, curr_ptr_type); - const Node* selector = prim_op->operands.nodes[i]; - uniform &= is_qualified_type_uniform(selector->type); - switch (is_type(pointee_type)) { - case ArrType_TAG: { - CTerm index = emit_value(emitter, p, selector); - if (emitter->config.dialect == GLSL) - acc = term_from_cvar(format_string_arena(arena->arena, "(%s.arr[int(%s)])", deref_term(emitter, acc), to_cvalue(emitter, index))); - else - acc = term_from_cvar(format_string_arena(arena->arena, "(%s.arr[%s])", deref_term(emitter, acc), to_cvalue(emitter, index))); - curr_ptr_type = ptr_type(arena, (PtrType) { - .pointed_type = pointee_type->payload.arr_type.element_type, - .address_space = curr_ptr_type->payload.ptr_type.address_space - }); - break; - } - case TypeDeclRef_TAG: { - pointee_type = get_nominal_type_body(pointee_type); - SHADY_FALLTHROUGH - } - case RecordType_TAG: { - // yet another ISPC bug 
and workaround - // ISPC cannot deal with subscripting if you've done pointer arithmetic (!) inside the expression - // so hum we just need to introduce a temporary variable to hold the pointer expression so far, and go again from there - // See https://github.com/ispc/ispc/issues/2496 - if (emitter->config.dialect == ISPC) { - String interm = unique_name(arena, "lea_intermediary_ptr_value"); - print(p, "\n%s = %s;", emit_type(emitter, qualified_type_helper(curr_ptr_type, uniform), interm), to_cvalue(emitter, acc)); - acc = term_from_cvalue(interm); - } - - assert(selector->tag == IntLiteral_TAG && "selectors when indexing into a record need to be constant"); - size_t static_index = get_int_literal_value(*resolve_to_int_literal(selector), false); - String field_name = get_record_field_name(pointee_type, static_index); - acc = term_from_cvar(format_string_arena(arena->arena, "(%s.%s)", deref_term(emitter, acc), field_name)); - curr_ptr_type = ptr_type(arena, (PtrType) { - .pointed_type = pointee_type->payload.record_type.members.nodes[static_index], - .address_space = curr_ptr_type->payload.ptr_type.address_space - }); - break; - } - default: error("lea can't work on this"); - } - } - assert(outputs.count == 1); - outputs.results[0] = acc; - outputs.binding[0] = emitter->config.dialect == ISPC ? 
LetBinding : NoBinding; - outputs.binding[0] = NoBinding; - return; - } - case memcpy_op: { - print(p, "\nmemcpy(%s, %s, %s);", to_cvalue(emitter, c_emit_value(emitter, p, prim_op->operands.nodes[0])), to_cvalue(emitter, c_emit_value(emitter, p, prim_op->operands.nodes[1])), to_cvalue(emitter, c_emit_value(emitter, p, prim_op->operands.nodes[2]))); - return; - } - case size_of_op: - term = term_from_cvalue(format_string_arena(emitter->arena->arena, "sizeof(%s)", c_emit_type(emitter, first(prim_op->type_arguments), NULL))); - break; - case align_of_op: - term = term_from_cvalue(format_string_arena(emitter->arena->arena, "alignof(%s)", c_emit_type(emitter, first(prim_op->type_arguments), NULL))); - break; - case offset_of_op: { - const Type* t = first(prim_op->type_arguments); - while (t->tag == TypeDeclRef_TAG) { - t = get_nominal_type_body(t); - } - const Node* index = first(prim_op->operands); - uint64_t index_literal = get_int_literal_value(*resolve_to_int_literal(index), false); - String member_name = get_record_field_name(t, index_literal); - term = term_from_cvalue(format_string_arena(emitter->arena->arena, "offsetof(%s, %s)", c_emit_type(emitter, t, NULL), member_name)); - break; - } case select_op: { - assert(prim_op->operands.count == 3); - CValue condition = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[0])); - CValue l = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[1])); - CValue r = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[2])); - term = term_from_cvalue(format_string_arena(emitter->arena->arena, "(%s) ? 
(%s) : (%s)", condition, l, r)); - break; - } - case convert_op: { - assert(outputs.count == 1); - CTerm src = emit_value(emitter, p, first(prim_op->operands)); - const Type* src_type = get_unqualified_type(first(prim_op->operands)->type); - const Type* dst_type = first(prim_op->type_arguments); - if (emitter->config.dialect == GLSL) { - if (is_glsl_scalar_type(src_type) && is_glsl_scalar_type(dst_type)) { - CType t = emit_type(emitter, dst_type, NULL); - term = term_from_cvalue(format_string_arena(emitter->arena->arena, "%s(%s)", t, to_cvalue(emitter, src))); - } else - assert(false); - } else { - CType t = emit_type(emitter, dst_type, NULL); - term = term_from_cvalue(format_string_arena(emitter->arena->arena, "((%s) %s)", t, to_cvalue(emitter, src))); - } - break; - } - case reinterpret_op: { - assert(outputs.count == 1); - CTerm src_value = emit_value(emitter, p, first(prim_op->operands)); - const Type* src_type = get_unqualified_type(first(prim_op->operands)->type); - const Type* dst_type = first(prim_op->type_arguments); - switch (emitter->config.dialect) { - case C: { - String src = unique_name(arena, "bitcast_src"); - String dst = unique_name(arena, "bitcast_result"); - print(p, "\n%s = %s;", emit_type(emitter, src_type, src), to_cvalue(emitter, src_value)); - print(p, "\n%s;", emit_type(emitter, dst_type, dst)); - print(p, "\nmemcpy(&%s, &s, sizeof(%s));", dst, src, src); - outputs.results[0] = term_from_cvalue(dst); - outputs.binding[0] = NoBinding; - break; - } - case GLSL: { - String n = NULL; - if (dst_type->tag == Float_TAG) { - assert(src_type->tag == Int_TAG); - switch (dst_type->payload.float_type.width) { - case FloatTy16: break; - case FloatTy32: n = src_type->payload.int_type.is_signed ? 
"intBitsToFloat" : "uintBitsToFloat"; - break; - case FloatTy64: break; - } - } else if (dst_type->tag == Int_TAG) { - if (src_type->tag == Int_TAG) { - outputs.results[0] = src_value; - outputs.binding[0] = NoBinding; - break; - } - assert(src_type->tag == Float_TAG); - switch (src_type->payload.float_type.width) { - case FloatTy16: break; - case FloatTy32: n = dst_type->payload.int_type.is_signed ? "floatBitsToInt" : "floatBitsToUint"; - break; - case FloatTy64: break; - } - } - if (n) { - outputs.results[0] = term_from_cvalue(format_string_arena(emitter->arena->arena, "%s(%s)", n, to_cvalue(emitter, src_value))); - outputs.binding[0] = LetBinding; - break; - } - error_print("glsl: unsupported bit cast from "); - log_node(ERROR, src_type); - error_print(" to "); - log_node(ERROR, dst_type); - error_print(".\n"); - error_die(); - } - case ISPC: { - if (dst_type->tag == Float_TAG) { - assert(src_type->tag == Int_TAG); - String n; - switch (dst_type->payload.float_type.width) { - case FloatTy16: n = "float16bits"; - break; - case FloatTy32: n = "floatbits"; - break; - case FloatTy64: n = "doublebits"; - break; - } - outputs.results[0] = term_from_cvalue(format_string_arena(emitter->arena->arena, "%s(%s)", n, to_cvalue(emitter, src_value))); - outputs.binding[0] = LetBinding; - break; - } else if (src_type->tag == Float_TAG) { - assert(dst_type->tag == Int_TAG); - outputs.results[0] = term_from_cvalue(format_string_arena(emitter->arena->arena, "intbits(%s)", to_cvalue(emitter, src_value))); - outputs.binding[0] = LetBinding; - break; - } - - CType t = emit_type(emitter, dst_type, NULL); - outputs.results[0] = term_from_cvalue(format_string_arena(emitter->arena->arena, "((%s) %s)", t, to_cvalue(emitter, src_value))); - outputs.binding[0] = NoBinding; - break; - } - } - return; - } - case insert_op: - case extract_dynamic_op: - case extract_op: { - CValue acc = to_cvalue(emitter, emit_value(emitter, p, first(prim_op->operands))); - bool insert = prim_op->op == 
insert_op; - - if (insert) { - String dst = unique_name(arena, "modified"); - print(p, "\n%s = %s;", c_emit_type(emitter, node->type, dst), acc); - acc = dst; - term = term_from_cvalue(dst); - } - - const Type* t = get_unqualified_type(first(prim_op->operands)->type); - for (size_t i = (insert ? 2 : 1); i < prim_op->operands.count; i++) { - const Node* index = prim_op->operands.nodes[i]; - const IntLiteral* static_index = resolve_to_int_literal(index); - - switch (is_type(t)) { - case Type_TypeDeclRef_TAG: { - const Node* decl = t->payload.type_decl_ref.decl; - assert(decl && decl->tag == NominalType_TAG); - t = decl->payload.nom_type.body; - SHADY_FALLTHROUGH - } - case Type_RecordType_TAG: { - assert(static_index); - Strings names = t->payload.record_type.names; - if (names.count == 0) - acc = format_string_arena(emitter->arena->arena, "(%s._%d)", acc, static_index->value); - else - acc = format_string_arena(emitter->arena->arena, "(%s.%s)", acc, names.strings[static_index->value]); - break; - } - case Type_PackType_TAG: { - assert(static_index); - assert(static_index->value < 4 && static_index->value < t->payload.pack_type.width); - String suffixes = "xyzw"; - acc = format_string_arena(emitter->arena->arena, "(%s.%c)", acc, suffixes[static_index->value]); - break; - } - case Type_ArrType_TAG: { - if (emitter->config.dialect == GLSL) - acc = format_string_arena(emitter->arena->arena, "(%s.arr[int(%s)])", acc, to_cvalue(emitter, emit_value(emitter, p, index))); - else - acc = format_string_arena(emitter->arena->arena, "(%s.arr[%s])", acc, to_cvalue(emitter, emit_value(emitter, p, index))); - break; - } - default: - case NotAType: error("Must be a type"); - } - } - - if (insert) { - print(p, "\n%s = %s;", acc, to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[1]))); - break; - } - - term = term_from_cvalue(acc); - break; - } - case get_stack_base_op: - case push_stack_op: - case pop_stack_op: - case get_stack_pointer_op: - case set_stack_pointer_op: 
error("Stack operations need to be lowered."); - case default_join_point_op: - case create_joint_point_op: error("lowered in lower_tailcalls.c"); - case subgroup_elect_first_op: { - switch (emitter->config.dialect) { - case ISPC: term = term_from_cvalue(format_string_arena(emitter->arena->arena, "(programIndex == count_trailing_zeros(lanemask()))")); break; - case C: - case GLSL: error("TODO") - } - break; - } - case subgroup_broadcast_first_op: { - CValue value = to_cvalue(emitter, emit_value(emitter, p, first(prim_op->operands))); - switch (emitter->config.dialect) { - case ISPC: term = term_from_cvalue(format_string_arena(emitter->arena->arena, "extract(%s, count_trailing_zeros(lanemask()))", value)); break; - case C: - case GLSL: error("TODO") - } - break; - } - case empty_mask_op: - case mask_is_thread_active_op: error("lower_me"); - case debug_printf_op: { - String args_list = ""; - for (size_t i = 0; i < prim_op->operands.count; i++) { - CValue str = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[i])); - - if (emitter->config.dialect == ISPC && i > 0) - str = format_string_arena(emitter->arena->arena, "extract(%s, printf_thread_index)", str); - - if (i > 0) - args_list = format_string_arena(emitter->arena->arena, "%s, %s", args_list, str); - else - args_list = str; - } - switch (emitter->config.dialect) { - case ISPC: - print(p, "\nforeach_active(printf_thread_index) { print(%s); }", args_list); - break; - case C: - print(p, "\nprintf(%s);", args_list); - break; - case GLSL: warn_print("printf is not supported in GLSL"); - break; - } - - return; - } - default: break; - case PRIMOPS_COUNT: assert(false); break; - } - - if (isel_entry->isel_mechanism != IsNone) - emit_using_entry(&term, emitter, p, isel_entry, prim_op->operands); - - assert(outputs.count == 1); - outputs.binding[0] = LetBinding; - outputs.results[0] = term; - return; -} - -static void emit_call(Emitter* emitter, Printer* p, const Node* call, InstructionOutputs outputs) { - 
Nodes args; - if (call->tag == Call_TAG) - args = call->payload.call.args; - else - assert(false); - - Growy* g = new_growy(); - Printer* paramsp = open_growy_as_printer(g); - for (size_t i = 0; i < args.count; i++) { - print(paramsp, to_cvalue(emitter, emit_value(emitter, p, args.nodes[i]))); - if (i + 1 < args.count) - print(paramsp, ", "); - } - - CValue e_callee; - const Node* callee = call->payload.call.callee; - if (callee->tag == FnAddr_TAG) - e_callee = get_decl_name(callee->payload.fn_addr.fn); - else - e_callee = to_cvalue(emitter, emit_value(emitter, p, callee)); - - String params = printer_growy_unwrap(paramsp); - - Nodes yield_types = unwrap_multiple_yield_types(emitter->arena, call->type); - assert(yield_types.count == outputs.count); - if (yield_types.count > 1) { - String named = unique_name(emitter->arena, "result"); - print(p, "\n%s = %s(%s);", emit_type(emitter, call->type, named), e_callee, params); - for (size_t i = 0; i < yield_types.count; i++) { - outputs.results[i] = term_from_cvalue(format_string_arena(emitter->arena->arena, "%s->_%d", named, i)); - // we have let-bound the actual result already, and extracting their components can be done inline - outputs.binding[i] = NoBinding; - } - } else if (yield_types.count == 1) { - outputs.results[0] = term_from_cvalue(format_string_arena(emitter->arena->arena, "%s(%s)", e_callee, params)); - outputs.binding[0] = LetBinding; - } else { - print(p, "\n%s(%s);", e_callee, params); - } - free_tmp_str(params); -} - -static void emit_if(Emitter* emitter, Printer* p, const Node* if_instr, InstructionOutputs outputs) { - assert(if_instr->tag == If_TAG); - const If* if_ = &if_instr->payload.if_instr; - Emitter sub_emiter = *emitter; - Strings ephis = emit_variable_declarations(emitter, p, "if_phi", NULL, if_->yield_types, true, NULL); - sub_emiter.phis.selection = ephis; - - assert(get_abstraction_params(if_->if_true).count == 0); - String true_body = emit_lambda_body(&sub_emiter, 
get_abstraction_body(if_->if_true), NULL); - CValue condition = to_cvalue(emitter, emit_value(emitter, p, if_->condition)); - print(p, "\nif (%s) { %s}", condition, true_body); - free_tmp_str(true_body); - if (if_->if_false) { - assert(get_abstraction_params(if_->if_false).count == 0); - String false_body = emit_lambda_body(&sub_emiter, get_abstraction_body(if_->if_false), NULL); - print(p, " else {%s}", false_body); - free_tmp_str(false_body); - } - - assert(outputs.count == ephis.count); - for (size_t i = 0; i < outputs.count; i++) { - outputs.results[i] = term_from_cvalue(ephis.strings[i]); - outputs.binding[i] = NoBinding; - } -} - -static void emit_match(Emitter* emitter, Printer* p, const Node* match_instr, InstructionOutputs outputs) { - assert(match_instr->tag == Match_TAG); - const Match* match = &match_instr->payload.match_instr; - Emitter sub_emiter = *emitter; - Strings ephis = emit_variable_declarations(emitter, p, "match_phi", NULL, match->yield_types, true, NULL); - sub_emiter.phis.selection = ephis; - - // Of course, the sensible thing to do here would be to emit a switch statement. - // ... - // Except that doesn't work, because C/GLSL have a baffling design wart: the `break` statement is overloaded, - // meaning that if you enter a switch statement, which should be orthogonal to loops, you can't actually break - // out of the outer loop anymore. Brilliant. So we do this terrible if-chain instead. - // - // We could do GOTO for C, but at the cost of arguably even more noise in the output, and two different codepaths. - // I don't think it's quite worth it, just like it's not worth doing some data-flow based solution either. 
- - CValue inspectee = to_cvalue(emitter, emit_value(emitter, p, match->inspect)); - bool first = true; - LARRAY(CValue, literals, match->cases.count); - for (size_t i = 0; i < match->cases.count; i++) { - literals[i] = to_cvalue(emitter, emit_value(emitter, p, match->literals.nodes[i])); - } - for (size_t i = 0; i < match->cases.count; i++) { - String case_body = emit_lambda_body(&sub_emiter, get_abstraction_body(match->cases.nodes[i]), NULL); - print(p, "\n"); - if (!first) - print(p, "else "); - print(p, "if (%s == %s) { %s}", inspectee, literals[i], case_body); - free_tmp_str(case_body); - first = false; - } - if (match->default_case) { - String default_case_body = emit_lambda_body(&sub_emiter, get_abstraction_body(match->default_case), NULL); - print(p, "\nelse { %s}", default_case_body); - free_tmp_str(default_case_body); - } - - assert(outputs.count == ephis.count); - for (size_t i = 0; i < outputs.count; i++) { - outputs.results[i] = term_from_cvalue(ephis.strings[i]); - outputs.binding[i] = NoBinding; - } -} - -static void emit_loop(Emitter* emitter, Printer* p, const Node* loop_instr, InstructionOutputs outputs) { - assert(loop_instr->tag == Loop_TAG); - const Loop* loop = &loop_instr->payload.loop_instr; - - Emitter sub_emiter = *emitter; - Nodes params = get_abstraction_params(loop->body); - Nodes variables = params; - LARRAY(String, arr, variables.count); - for (size_t i = 0; i < variables.count; i++) { - arr[i] = get_value_name(variables.nodes[i]); - if (!arr[i]) - arr[i] = unique_name(emitter->arena, "phi"); - } - Strings param_names = strings(emitter->arena, variables.count, arr); - Strings eparams = emit_variable_declarations(emitter, p, NULL, ¶m_names, get_variables_types(emitter->arena, params), true, &loop_instr->payload.loop_instr.initial_args); - for (size_t i = 0; i < params.count; i++) - register_emitted(&sub_emiter, params.nodes[i], term_from_cvalue(eparams.strings[i])); - - sub_emiter.phis.loop_continue = eparams; - Strings ephis = 
emit_variable_declarations(emitter, p, "loop_break_phi", NULL, loop->yield_types, true, NULL); - sub_emiter.phis.loop_break = ephis; - - String body = emit_lambda_body(&sub_emiter, get_abstraction_body(loop->body), NULL); - print(p, "\nwhile(true) { %s}", body); - free_tmp_str(body); - - assert(outputs.count == ephis.count); - for (size_t i = 0; i < outputs.count; i++) { - outputs.results[i] = term_from_cvalue(ephis.strings[i]); - outputs.binding[i] = NoBinding; - } -} - -void emit_instruction(Emitter* emitter, Printer* p, const Node* instruction, InstructionOutputs outputs) { - assert(is_instruction(instruction)); - - switch (is_instruction(instruction)) { - case NotAnInstruction: assert(false); - case Instruction_PrimOp_TAG: emit_primop(emitter, p, instruction, outputs); break; - case Instruction_Call_TAG: emit_call (emitter, p, instruction, outputs); break; - case Instruction_If_TAG: emit_if (emitter, p, instruction, outputs); break; - case Instruction_Match_TAG: emit_match (emitter, p, instruction, outputs); break; - case Instruction_Loop_TAG: emit_loop (emitter, p, instruction, outputs); break; - case Instruction_Control_TAG: error("TODO") - case Instruction_Block_TAG: error("Should be eliminated by the compiler") - case Instruction_Comment_TAG: print(p, "/* %s */", instruction->payload.comment.string); break; - } -} diff --git a/src/shady/emit/c/emit_c_signatures.c b/src/shady/emit/c/emit_c_signatures.c deleted file mode 100644 index 306dffd90..000000000 --- a/src/shady/emit/c/emit_c_signatures.c +++ /dev/null @@ -1,302 +0,0 @@ -#include "emit_c.h" - -#include "dict.h" -#include "log.h" -#include "util.h" - -#include "../../type.h" -#include "../../ir_private.h" - -#include -#include -#include - -#pragma GCC diagnostic error "-Wswitch" - -String get_record_field_name(const Type* t, size_t i) { - assert(t->tag == RecordType_TAG); - RecordType r = t->payload.record_type; - assert(i < r.members.count); - if (i >= r.names.count) - return 
format_string_interned(t->arena, "_%d", i); - else - return r.names.strings[i]; -} - -void emit_nominal_type_body(Emitter* emitter, String name, const Type* type) { - assert(type->tag == RecordType_TAG); - Growy* g = new_growy(); - Printer* p = open_growy_as_printer(g); - - print(p, "\n%s {", name); - indent(p); - for (size_t i = 0; i < type->payload.record_type.members.count; i++) { - String member_identifier = get_record_field_name(type, i); - print(p, "\n%s;", emit_type(emitter, type->payload.record_type.members.nodes[i], member_identifier)); - } - deindent(p); - print(p, "\n};\n"); - growy_append_bytes(g, 1, (char[]) { '\0' }); - - print(emitter->type_decls, growy_data(g)); - destroy_growy(g); - destroy_printer(p); -} - -String emit_fn_head(Emitter* emitter, const Node* fn_type, String center, const Node* fn) { - assert(fn_type->tag == FnType_TAG); - assert(!fn || fn->type == fn_type); - Nodes codom = fn_type->payload.fn_type.return_types; - - Growy* paramg = new_growy(); - Printer* paramp = open_growy_as_printer(paramg); - Nodes dom = fn_type->payload.fn_type.param_types; - if (dom.count == 0 && emitter->config.dialect == C) - print(paramp, "void"); - else if (fn) { - Nodes params = fn->payload.fun.params; - assert(params.count == dom.count); - for (size_t i = 0; i < dom.count; i++) { - String param_name; - String variable_name = get_value_name(fn->payload.fun.params.nodes[i]); - if (variable_name) - param_name = format_string_arena(emitter->arena->arena, "%s_%d", legalize_c_identifier(emitter, variable_name), fn->payload.fun.params.nodes[i]->payload.var.id); - else - param_name = format_string_arena(emitter->arena->arena, "p%d", fn->payload.fun.params.nodes[i]->payload.var.id); - print(paramp, emit_type(emitter, params.nodes[i]->type, param_name)); - if (i + 1 < dom.count) { - print(paramp, ", "); - } - } - } else { - for (size_t i = 0; i < dom.count; i++) { - print(paramp, emit_type(emitter, dom.nodes[i], "")); - if (i + 1 < dom.count) { - print(paramp, ", 
"); - } - } - } - growy_append_bytes(paramg, 1, (char[]) { 0 }); - const char* parameters = printer_growy_unwrap(paramp); - switch (emitter->config.dialect) { - case ISPC: - case C: - center = format_string_arena(emitter->arena->arena, "(%s)(%s)", center, parameters); - break; - case GLSL: - // GLSL does not accept functions declared like void (foo)(int); - // it also does not support higher-order functions and/or function pointers, so we drop the parentheses - center = format_string_arena(emitter->arena->arena, "%s(%s)", center, parameters); - break; - } - free_tmp_str(parameters); - - String c_decl = emit_type(emitter, wrap_multiple_yield_types(emitter->arena, codom), center); - - const Node* entry_point = fn ? lookup_annotation(fn, "EntryPoint") : NULL; - if (entry_point) switch (emitter->config.dialect) { - case C: - break; - case GLSL: - break; - case ISPC: - c_decl = format_string_arena(emitter->arena->arena, "export %s", c_decl); - break; - } - - return c_decl; -} - -String emit_type(Emitter* emitter, const Type* type, const char* center) { - if (center == NULL) - center = ""; - - String emitted = NULL; - CType* found = lookup_existing_type(emitter, type); - if (found) { - emitted = *found; - goto type_goes_on_left; - } - - switch (is_type(type)) { - case NotAType: assert(false); break; - case LamType_TAG: - case BBType_TAG: error("these types do not exist in C"); - case MaskType_TAG: error("should be lowered away"); - case Type_CombinedImageSamplerType_TAG: - case Type_SamplerType_TAG: - case Type_ImageType_TAG: - case JoinPointType_TAG: error("TODO") - case NoRet_TAG: - case Bool_TAG: emitted = "bool"; break; - case Int_TAG: { - switch (emitter->config.dialect) { - case ISPC: { - const char* ispc_int_types[4][2] = { - { "uint8" , "int8" }, - { "uint16", "int16" }, - { "uint32", "int32" }, - { "uint64", "int64" }, - }; - emitted = ispc_int_types[type->payload.int_type.width][type->payload.int_type.is_signed]; - break; - } - case C: { - const char* 
c_classic_int_types[4][2] = { - { "unsigned char" , "char" }, - { "unsigned short", "short" }, - { "unsigned int" , "int" }, - { "unsigned long" , "long" }, - }; - const char* c_explicit_int_sizes[4][2] = { - { "uint8_t" , "int8_t" }, - { "uint16_t", "int16_t" }, - { "uint32_t", "int32_t" }, - { "uint64_t", "int64_t" }, - }; - emitted = (emitter->config.explicitly_sized_types ? c_explicit_int_sizes : c_classic_int_types)[type->payload.int_type.width][type->payload.int_type.is_signed]; - break; - } - case GLSL: - switch (type->payload.int_type.width) { - case IntTy8: warn_print("vanilla GLSL does not support 8-bit integers\n"); - emitted = "ubyte"; - break; - case IntTy16: warn_print("vanilla GLSL does not support 16-bit integers\n"); - emitted = "ushort"; - break; - case IntTy32: emitted = "uint"; break; - case IntTy64: warn_print("vanilla GLSL does not support 64-bit integers\n"); - emitted = "uint64_t"; - break; - } - break; - } - break; - } - case Float_TAG: - switch (type->payload.float_type.width) { - case FloatTy16: - assert(false); - break; - case FloatTy32: - emitted = "float"; - break; - case FloatTy64: - emitted = "double"; - break; - } - break; - case Type_RecordType_TAG: { - if (type->payload.record_type.members.count == 0) { - emitted = "void"; - break; - } - - emitted = unique_name(emitter->arena, "Record"); - String prefixed = format_string_arena(emitter->arena->arena, "struct %s", emitted); - emit_nominal_type_body(emitter, prefixed, type); - // C puts structs in their own namespace so we always need the prefix - if (emitter->config.dialect == C) - emitted = prefixed; - - break; - } - case Type_QualifiedType_TAG: - switch (emitter->config.dialect) { - case C: - case GLSL: - return emit_type(emitter, type->payload.qualified_type.type, center); - case ISPC: - if (type->payload.qualified_type.is_uniform) - return emit_type(emitter, type->payload.qualified_type.type, format_string_arena(emitter->arena->arena, "uniform %s", center)); - else - return 
emit_type(emitter, type->payload.qualified_type.type, format_string_arena(emitter->arena->arena, "varying %s", center)); - } - case Type_PtrType_TAG: { - CType t = emit_type(emitter, type->payload.ptr_type.pointed_type, format_string_arena(emitter->arena->arena, "* %s", center)); - // we always emit pointers to _uniform_ data, no exceptions - if (emitter->config.dialect == ISPC) - t = format_string_arena(emitter->arena->arena, "uniform %s", t); - return t; - } - case Type_FnType_TAG: { - return emit_fn_head(emitter, type, center, NULL); - } - case Type_ArrType_TAG: { - emitted = unique_name(emitter->arena, "Array"); - String prefixed = format_string_arena(emitter->arena->arena, "struct %s", emitted); - Growy* g = new_growy(); - Printer* p = open_growy_as_printer(g); - - print(p, "\n%s {", prefixed); - indent(p); - const Node* size = type->payload.arr_type.size; - String inner_decl_rhs; - if (size) - inner_decl_rhs = format_string_arena(emitter->arena->arena, "arr[%zu]", get_int_literal_value(*resolve_to_int_literal(size), false)); - else - inner_decl_rhs = format_string_arena(emitter->arena->arena, "arr[0]"); - print(p, "\n%s;", emit_type(emitter, type->payload.arr_type.element_type, inner_decl_rhs)); - deindent(p); - print(p, "\n};\n"); - growy_append_bytes(g, 1, (char[]) { '\0' }); - - String subdecl = printer_growy_unwrap(p); - print(emitter->type_decls, subdecl); - free_tmp_str(subdecl); - - // ditto from RecordType - switch (emitter->config.dialect) { - case C: - case ISPC: - emitted = prefixed; - break; - case GLSL: - break; - } - break; - } - case Type_PackType_TAG: { - int width = type->payload.pack_type.width; - const Type* element_type = type->payload.pack_type.element_type; - switch (emitter->config.dialect) { - case GLSL: { - assert(is_glsl_scalar_type(element_type)); - assert(width > 1); - String base; - switch (element_type->tag) { - case Bool_TAG: base = "bvec"; break; - case Int_TAG: base = "uvec"; break; // TODO not every int is 32-bit - case 
Float_TAG: base = "vec"; break; - default: error("not a valid GLSL vector type"); - } - emitted = format_string_arena(emitter->arena->arena, "%s%d", base, width); - break; - } - case ISPC: error("Please lower to something else") - case C: { - emitted = emit_type(emitter, element_type, NULL); - emitted = format_string_arena(emitter->arena->arena, "__attribute__ ((vector_size (%d * sizeof(%s) ))) %s", width, emitted, emitted); - break; - } - } - break; - } - case Type_TypeDeclRef_TAG: { - emit_decl(emitter, type->payload.type_decl_ref.decl); - emitted = *lookup_existing_type(emitter, type->payload.type_decl_ref.decl); - goto type_goes_on_left; - } - } - assert(emitted != NULL); - register_emitted_type(emitter, type, emitted); - - type_goes_on_left: - assert(emitted != NULL); - - if (strlen(center) > 0) - emitted = format_string_arena(emitter->arena->arena, "%s %s", emitted, center); - - return emitted; -} diff --git a/src/shady/emit/spirv/CMakeLists.txt b/src/shady/emit/spirv/CMakeLists.txt deleted file mode 100644 index 97f6d1ace..000000000 --- a/src/shady/emit/spirv/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -add_library(shady_spirv OBJECT - emit_spv.c - emit_spv_type.c - emit_spv_instructions.c - spirv_builder.c -) -set_property(TARGET shady_spirv PROPERTY POSITION_INDEPENDENT_CODE ON) - -target_link_libraries(shady_spirv PUBLIC "api") -target_link_libraries(shady_spirv PRIVATE "$") -target_link_libraries(shady_spirv PRIVATE "$") -target_link_libraries(shady_spirv PUBLIC "$") \ No newline at end of file diff --git a/src/shady/emit/spirv/emit_spv.c b/src/shady/emit/spirv/emit_spv.c deleted file mode 100644 index 5e6cb45a2..000000000 --- a/src/shady/emit/spirv/emit_spv.c +++ /dev/null @@ -1,581 +0,0 @@ -#include "list.h" -#include "dict.h" -#include "log.h" -#include "portability.h" -#include "growy.h" -#include "util.h" - -#include "shady/builtins.h" -#include "../../ir_private.h" -#include "../../analysis/scope.h" -#include "../../type.h" -#include 
"../../compile.h" - -#include "emit_spv.h" - -#include -#include -#include - -extern SpvBuiltIn spv_builtins[]; - -#pragma GCC diagnostic error "-Wswitch" - -void register_result(Emitter* emitter, const Node* node, SpvId id) { - if (is_value(node)) { - String name = get_value_name(node); - if (name) - spvb_name(emitter->file_builder, id, name); - } - insert_dict_and_get_result(struct Node*, SpvId, emitter->node_ids, node, id); -} - -SpvId emit_value(Emitter* emitter, BBBuilder bb_builder, const Node* node) { - SpvId* existing = find_value_dict(const Node*, SpvId, emitter->node_ids, node); - if (existing) - return *existing; - - SpvId new; - switch (is_value(node)) { - case NotAValue: error(""); - case Variable_TAG: error("tried to emit a variable: but all variables should be emitted by enclosing scope or preceding instructions !"); - case Value_ConstrainedValue_TAG: - case Value_UntypedNumber_TAG: - case Value_FnAddr_TAG: error("Should be lowered away earlier!"); - case IntLiteral_TAG: { - new = spvb_fresh_id(emitter->file_builder); - SpvId ty = emit_type(emitter, node->type); - // 64-bit constants take two spirv words, anything else fits in one - if (node->payload.int_literal.width == IntTy64) { - uint32_t arr[] = { node->payload.int_literal.value & 0xFFFFFFFF, node->payload.int_literal.value >> 32 }; - spvb_constant(emitter->file_builder, new, ty, 2, arr); - } else { - uint32_t arr[] = { node->payload.int_literal.value }; - spvb_constant(emitter->file_builder, new, ty, 1, arr); - } - break; - } - case FloatLiteral_TAG: { - new = spvb_fresh_id(emitter->file_builder); - SpvId ty = emit_type(emitter, node->type); - switch (node->payload.float_literal.width) { - case FloatTy16: { - uint32_t arr[] = { node->payload.float_literal.value & 0xFFFF }; - spvb_constant(emitter->file_builder, new, ty, 1, arr); - break; - } - case FloatTy32: { - uint32_t arr[] = { node->payload.float_literal.value }; - spvb_constant(emitter->file_builder, new, ty, 1, arr); - break; - } - case 
FloatTy64: { - uint32_t arr[] = { node->payload.float_literal.value & 0xFFFFFFFF, node->payload.float_literal.value >> 32 }; - spvb_constant(emitter->file_builder, new, ty, 2, arr); - break; - } - } - break; - } - case True_TAG: { - new = spvb_fresh_id(emitter->file_builder); - spvb_bool_constant(emitter->file_builder, new, emit_type(emitter, bool_type(emitter->arena)), true); - break; - } - case False_TAG: { - new = spvb_fresh_id(emitter->file_builder); - spvb_bool_constant(emitter->file_builder, new, emit_type(emitter, bool_type(emitter->arena)), false); - break; - } - case Value_StringLiteral_TAG: { - new = spvb_debug_string(emitter->file_builder, node->payload.string_lit.string); - break; - } - case Value_NullPtr_TAG: { - new = spvb_constant_null(emitter->file_builder, emit_type(emitter, node->payload.null_ptr.ptr_type)); - break; - } - case Composite_TAG: { - Nodes contents = node->payload.composite.contents; - LARRAY(SpvId, ids, contents.count); - for (size_t i = 0; i < contents.count; i++) { - ids[i] = emit_value(emitter, bb_builder, contents.nodes[i]); - } - if (bb_builder) { - new = spvb_composite(bb_builder, emit_type(emitter, node->type), contents.count, ids); - return new; - } else { - new = spvb_constant_composite(emitter->file_builder, emit_type(emitter, node->type), contents.count, ids); - break; - } - } - case Value_Undef_TAG: { - new = spvb_undef(emitter->file_builder, emit_type(emitter, node->payload.undef.type)); - break; - } - case Value_Fill_TAG: error("lower me") - case RefDecl_TAG: { - const Node* decl = node->payload.ref_decl.decl; - switch (decl->tag) { - case GlobalVariable_TAG: { - new = emit_decl(emitter, decl); - break; - } - case Constant_TAG: { - const Node* init_value = get_quoted_value(decl->payload.constant.instruction); - if (!init_value && bb_builder) { - SpvId r; - emit_instruction(emitter, NULL, &bb_builder, NULL, decl->payload.constant.instruction, 1, &r); - return r; - } - assert(init_value && "TODO: support some measure of 
constant expressions"); - new = emit_value(emitter, NULL, init_value); - break; - } - default: error("RefDecl must reference a constant or global"); - } - break; - } - } - - insert_dict_and_get_result(struct Node*, SpvId, emitter->node_ids, node, new); - return new; -} - -SpvId spv_find_reserved_id(Emitter* emitter, const Node* node) { - SpvId* found = find_value_dict(const Node*, SpvId, emitter->node_ids, node); - assert(found); - return *found; -} - -static BBBuilder find_basic_block_builder(Emitter* emitter, SHADY_UNUSED FnBuilder fn_builder, const Node* bb) { - // assert(is_basic_block(bb)); - BBBuilder* found = find_value_dict(const Node*, BBBuilder, emitter->bb_builders, bb); - assert(found); - return *found; -} - -static void add_branch_phis(Emitter* emitter, FnBuilder fn_builder, BBBuilder bb_builder, const Node* jump) { - assert(jump->tag == Jump_TAG); - const Node* dst = jump->payload.jump.target; - Nodes args = jump->payload.jump.args; - // because it's forbidden to jump back into the entry block of a function - // (which is actually a Function in this IR, not a BasicBlock) - // we assert that the destination must be an actual BasicBlock - assert(is_basic_block(dst)); - BBBuilder dst_builder = find_basic_block_builder(emitter, fn_builder, dst); - struct List* phis = spbv_get_phis(dst_builder); - assert(entries_count_list(phis) == args.count); - for (size_t i = 0; i < args.count; i++) { - SpvbPhi* phi = read_list(SpvbPhi*, phis)[i]; - spvb_add_phi_source(phi, get_block_builder_id(bb_builder), emit_value(emitter, bb_builder, args.nodes[i])); - } -} - -void emit_terminator(Emitter* emitter, FnBuilder fn_builder, BBBuilder basic_block_builder, MergeTargets merge_targets, const Node* terminator) { - switch (is_terminator(terminator)) { - case Return_TAG: { - const Nodes* ret_values = &terminator->payload.fn_ret.args; - switch (ret_values->count) { - case 0: spvb_return_void(basic_block_builder); return; - case 1: spvb_return_value(basic_block_builder, 
emit_value(emitter, basic_block_builder, ret_values->nodes[0])); return; - default: { - LARRAY(SpvId, arr, ret_values->count); - for (size_t i = 0; i < ret_values->count; i++) - arr[i] = emit_value(emitter, basic_block_builder, ret_values->nodes[i]); - SpvId return_that = spvb_composite(basic_block_builder, fn_ret_type_id(fn_builder), ret_values->count, arr); - spvb_return_value(basic_block_builder, return_that); - return; - } - } - } - case Let_TAG: { - const Node* tail = get_let_tail(terminator); - Nodes params = tail->payload.case_.params; - LARRAY(SpvId, results, params.count); - emit_instruction(emitter, fn_builder, &basic_block_builder, &merge_targets, get_let_instruction(terminator), params.count, results); - assert(tail->tag == Case_TAG); - - for (size_t i = 0; i < params.count; i++) - register_result(emitter, params.nodes[i], results[i]); - emit_terminator(emitter, fn_builder, basic_block_builder, merge_targets, tail->payload.case_.body); - return; - } - case Jump_TAG: { - add_branch_phis(emitter, fn_builder, basic_block_builder, terminator); - spvb_branch(basic_block_builder, find_reserved_id(emitter, terminator->payload.jump.target)); - } - case Branch_TAG: { - SpvId condition = emit_value(emitter, basic_block_builder, terminator->payload.branch.branch_condition); - add_branch_phis(emitter, fn_builder, basic_block_builder, terminator->payload.branch.true_jump); - add_branch_phis(emitter, fn_builder, basic_block_builder, terminator->payload.branch.false_jump); - spvb_branch_conditional(basic_block_builder, condition, find_reserved_id(emitter, terminator->payload.branch.true_jump->payload.jump.target), find_reserved_id(emitter, terminator->payload.branch.false_jump->payload.jump.target)); - } - case Switch_TAG: { - SpvId inspectee = emit_value(emitter, basic_block_builder, terminator->payload.br_switch.switch_value); - LARRAY(SpvId, targets, terminator->payload.br_switch.case_jumps.count * 2); - for (size_t i = 0; i < 
terminator->payload.br_switch.case_jumps.count; i++) { - add_branch_phis(emitter, fn_builder, basic_block_builder, terminator->payload.br_switch.case_jumps.nodes[i]); - error("TODO finish") - } - add_branch_phis(emitter, fn_builder, basic_block_builder, terminator->payload.br_switch.default_jump); - SpvId default_tgt = find_reserved_id(emitter, terminator->payload.br_switch.default_jump->payload.jump.target); - - spvb_switch(basic_block_builder, inspectee, default_tgt, terminator->payload.br_switch.case_jumps.count, targets); - } - case LetMut_TAG: - case TailCall_TAG: - case Join_TAG: error("Lower me"); - case Terminator_Yield_TAG: { - Nodes args = terminator->payload.yield.args; - for (size_t i = 0; i < args.count; i++) - spvb_add_phi_source(merge_targets.join_phis[i], get_block_builder_id(basic_block_builder), emit_value(emitter, basic_block_builder, args.nodes[i])); - spvb_branch(basic_block_builder, merge_targets.join_target); - return; - } - case MergeContinue_TAG: { - Nodes args = terminator->payload.merge_continue.args; - for (size_t i = 0; i < args.count; i++) - spvb_add_phi_source(merge_targets.continue_phis[i], get_block_builder_id(basic_block_builder), emit_value(emitter, basic_block_builder, args.nodes[i])); - spvb_branch(basic_block_builder, merge_targets.continue_target); - return; - } - case MergeBreak_TAG: { - Nodes args = terminator->payload.merge_break.args; - for (size_t i = 0; i < args.count; i++) - spvb_add_phi_source(merge_targets.break_phis[i], get_block_builder_id(basic_block_builder), emit_value(emitter, basic_block_builder, args.nodes[i])); - spvb_branch(basic_block_builder, merge_targets.break_target); - return; - } - case Unreachable_TAG: { - spvb_unreachable(basic_block_builder); - return; - } - case NotATerminator: error("TODO: emit terminator %s", node_tags[terminator->tag]); - } - SHADY_UNREACHABLE; -} - -static void emit_basic_block(Emitter* emitter, FnBuilder fn_builder, const Scope* scope, const CFNode* cf_node) { - const Node* 
bb_node = cf_node->node; - assert(is_basic_block(bb_node) || cf_node == scope->entry); - - const Node* body = get_abstraction_body(bb_node); - - // Find the preassigned ID to this - BBBuilder bb_builder = find_basic_block_builder(emitter, fn_builder, bb_node); - SpvId bb_id = get_block_builder_id(bb_builder); - spvb_add_bb(fn_builder, bb_builder); - - if (is_basic_block(bb_node)) - spvb_name(emitter->file_builder, bb_id, bb_node->payload.basic_block.name); - - MergeTargets merge_targets = { - .continue_target = 0, - .break_target = 0, - .join_target = 0 - }; - - emit_terminator(emitter, fn_builder, bb_builder, merge_targets, body); -} - -static void emit_function(Emitter* emitter, const Node* node) { - assert(node->tag == Function_TAG); - - const Type* fn_type = node->type; - SpvId fn_id = find_reserved_id(emitter, node); - FnBuilder fn_builder = spvb_begin_fn(emitter->file_builder, fn_id, emit_type(emitter, fn_type), nodes_to_codom(emitter, node->payload.fun.return_types)); - - Nodes params = node->payload.fun.params; - for (size_t i = 0; i < params.count; i++) { - const Type* param_type = params.nodes[i]->payload.var.type; - SpvId param_id = spvb_parameter(fn_builder, emit_type(emitter, param_type)); - insert_dict_and_get_result(struct Node*, SpvId, emitter->node_ids, params.nodes[i], param_id); - deconstruct_qualified_type(¶m_type); - if (param_type->tag == PtrType_TAG && param_type->payload.ptr_type.address_space == AsGlobalPhysical) { - spvb_decorate(emitter->file_builder, param_id, SpvDecorationAliased, 0, NULL); - } - } - - if (node->payload.fun.body) { - Scope* scope = new_scope(node); - // reserve a bunch of identifiers for the basic blocks in the scope - for (size_t i = 0; i < scope->size; i++) { - CFNode* cfnode = read_list(CFNode*, scope->contents)[i]; - assert(cfnode); - const Node* bb = cfnode->node; - if (is_case(bb)) - continue; - assert(is_basic_block(bb) || bb == node); - SpvId bb_id = spvb_fresh_id(emitter->file_builder); - BBBuilder 
basic_block_builder = spvb_begin_bb(fn_builder, bb_id); - insert_dict(const Node*, BBBuilder, emitter->bb_builders, bb, basic_block_builder); - // add phis for every non-entry basic block - if (i > 0) { - assert(is_basic_block(bb) && bb != node); - Nodes bb_params = bb->payload.basic_block.params; - for (size_t j = 0; j < bb_params.count; j++) { - const Node* bb_param = bb_params.nodes[j]; - spvb_add_phi(basic_block_builder, emit_type(emitter, bb_param->type), spvb_fresh_id(emitter->file_builder)); - } - // also make sure to register the label for basic blocks - register_result(emitter, bb, bb_id); - } - } - // emit the blocks using the dominator tree - //emit_basic_block(emitter, fn_builder, &scope, scope.entry); - for (size_t i = 0; i < scope->size; i++) { - CFNode* cfnode = scope->rpo[i]; - if (i == 0) - assert(cfnode == scope->entry); - if (is_case(cfnode->node)) - continue; - emit_basic_block(emitter, fn_builder, scope, cfnode); - } - - destroy_scope(scope); - - spvb_define_function(emitter->file_builder, fn_builder); - } else { - Growy* g = new_growy(); - spvb_literal_name(g, get_abstraction_name(node)); - growy_append_bytes(g, 4, (char*) &(uint32_t) { SpvLinkageTypeImport }); - spvb_decorate(emitter->file_builder, fn_id, SpvDecorationLinkageAttributes, growy_size(g) / 4, (uint32_t*) growy_data(g)); - destroy_growy(g); - spvb_declare_function(emitter->file_builder, fn_builder); - } -} - -SpvId emit_decl(Emitter* emitter, const Node* decl) { - SpvId* existing = find_value_dict(const Node*, SpvId, emitter->node_ids, decl); - if (existing) - return *existing; - - switch (is_declaration(decl)) { - case GlobalVariable_TAG: { - const GlobalVariable* gvar = &decl->payload.global_variable; - SpvId given_id = spvb_fresh_id(emitter->file_builder); - register_result(emitter, decl, given_id); - spvb_name(emitter->file_builder, given_id, gvar->name); - SpvId init = 0; - if (gvar->init) - init = emit_value(emitter, NULL, gvar->init); - 
assert(!is_physical_as(gvar->address_space)); - SpvStorageClass storage_class = emit_addr_space(emitter, gvar->address_space); - spvb_global_variable(emitter->file_builder, given_id, emit_type(emitter, decl->type), storage_class, false, init); - - Builtin b = BuiltinsCount; - for (size_t i = 0; i < gvar->annotations.count; i++) { - const Node* a = gvar->annotations.nodes[i]; - assert(is_annotation(a)); - String name = get_annotation_name(a); - if (strcmp(name, "Builtin") == 0) { - String builtin_name = get_annotation_string_payload(a); - assert(builtin_name); - assert(b == BuiltinsCount && "Only one @Builtin annotation permitted."); - b = get_builtin_by_name(builtin_name); - assert(b != BuiltinsCount); - SpvBuiltIn d = spv_builtins[b]; - uint32_t decoration_payload[] = { d }; - spvb_decorate(emitter->file_builder, given_id, SpvDecorationBuiltIn, 1, decoration_payload); - } else if (strcmp(name, "Location") == 0) { - size_t loc = get_int_literal_value(*resolve_to_int_literal(get_annotation_value(a)), false); - assert(loc >= 0); - spvb_decorate(emitter->file_builder, given_id, SpvDecorationLocation, 1, (uint32_t[]) { loc }); - } else if (strcmp(name, "DescriptorSet") == 0) { - size_t loc = get_int_literal_value(*resolve_to_int_literal(get_annotation_value(a)), false); - assert(loc >= 0); - spvb_decorate(emitter->file_builder, given_id, SpvDecorationDescriptorSet, 1, (uint32_t[]) { loc }); - } else if (strcmp(name, "DescriptorBinding") == 0) { - size_t loc = get_int_literal_value(*resolve_to_int_literal(get_annotation_value(a)), false); - assert(loc >= 0); - spvb_decorate(emitter->file_builder, given_id, SpvDecorationBinding, 1, (uint32_t[]) { loc }); - } - } - - switch (storage_class) { - case SpvStorageClassPushConstant: { - break; - } - case SpvStorageClassStorageBuffer: - case SpvStorageClassUniform: - case SpvStorageClassUniformConstant: { - const Node* descriptor_set = lookup_annotation(decl, "DescriptorSet"); - const Node* descriptor_binding = 
lookup_annotation(decl, "DescriptorBinding"); - assert(descriptor_set && descriptor_binding && "DescriptorSet and/or DescriptorBinding annotations are missing"); - break; - } - default: break; - } - - return given_id; - } case Function_TAG: { - SpvId given_id = spvb_fresh_id(emitter->file_builder); - register_result(emitter, decl, given_id); - spvb_name(emitter->file_builder, given_id, decl->payload.fun.name); - emit_function(emitter, decl); - return given_id; - } case Constant_TAG: { - // We don't emit constants at all ! - // With RefDecl, we directly grab the underlying value and emit that there and then. - // Emitting constants as their own IDs would be nicer, but it's painful to do because decls need their ID to be reserved in advance, - // but we also desire to cache reused values instead of emitting them multiple times. This means we can't really "force" an ID for a given value. - // The ideal fix would be if SPIR-V offered a way to "alias" an ID under a new one. This would allow applying new debug information to the decl ID, separate from the other instances of that value. 
- return 0; - } case NominalType_TAG: { - SpvId given_id = spvb_fresh_id(emitter->file_builder); - register_result(emitter, decl, given_id); - spvb_name(emitter->file_builder, given_id, decl->payload.nom_type.name); - emit_nominal_type_body(emitter, decl->payload.nom_type.body, given_id); - return given_id; - } - case NotADeclaration: error(""); - } - error("unreachable"); -} - -static SpvExecutionModel emit_exec_model(ExecutionModel model) { - switch (model) { - case EmCompute: return SpvExecutionModelGLCompute; - case EmVertex: return SpvExecutionModelVertex; - case EmFragment: return SpvExecutionModelFragment; - case EmNone: error("No execution model but we were asked to emit it anyways"); - } -} - -static void emit_entry_points(Emitter* emitter, Nodes declarations) { - // First, collect all the global variables, they're needed for the interface section of OpEntryPoint - // it can be a superset of the ones actually used, so the easiest option is to just grab _all_ global variables and shove them in there - // my gut feeling says it's unlikely any drivers actually care, but validation needs to be happy so here we go... - LARRAY(SpvId, interface_arr, declarations.count); - size_t interface_size = 0; - for (size_t i = 0; i < declarations.count; i++) { - const Node* node = declarations.nodes[i]; - if (node->tag != GlobalVariable_TAG) continue; - // Prior to SPIRV 1.4, _only_ input and output variables should be found here. 
- if (emitter->configuration->target_spirv_version.major == 1 && - emitter->configuration->target_spirv_version.minor < 4) { - switch (node->payload.global_variable.address_space) { - case AsOutput: - case AsInput: break; - default: continue; - } - } - interface_arr[interface_size++] = find_reserved_id(emitter, node); - } - - for (size_t i = 0; i < declarations.count; i++) { - const Node* decl = declarations.nodes[i]; - if (decl->tag != Function_TAG) continue; - SpvId fn_id = find_reserved_id(emitter, decl); - - const Node* entry_point = lookup_annotation(decl, "EntryPoint"); - if (entry_point) { - ExecutionModel execution_model = execution_model_from_string(get_string_literal(emitter->arena, get_annotation_value(entry_point))); - assert(execution_model != EmNone); - - spvb_entry_point(emitter->file_builder, emit_exec_model(execution_model), fn_id, decl->payload.fun.name, interface_size, interface_arr); - emitter->num_entry_pts++; - - const Node* workgroup_size = lookup_annotation(decl, "WorkgroupSize"); - if (execution_model == EmCompute) - assert(workgroup_size); - if (workgroup_size) { - Nodes values = get_annotation_values(workgroup_size); - assert(values.count == 3); - uint32_t wg_x_dim = (uint32_t) get_int_literal_value(*resolve_to_int_literal(values.nodes[0]), false); - uint32_t wg_y_dim = (uint32_t) get_int_literal_value(*resolve_to_int_literal(values.nodes[1]), false); - uint32_t wg_z_dim = (uint32_t) get_int_literal_value(*resolve_to_int_literal(values.nodes[2]), false); - - spvb_execution_mode(emitter->file_builder, fn_id, SpvExecutionModeLocalSize, 3, (uint32_t[3]) { wg_x_dim, wg_y_dim, wg_z_dim }); - } - - if (execution_model == EmFragment) { - spvb_execution_mode(emitter->file_builder, fn_id, SpvExecutionModeOriginUpperLeft, 0, NULL); - } - } - } -} - -static void emit_decls(Emitter* emitter, Nodes declarations) { - for (size_t i = 0; i < declarations.count; i++) { - const Node* decl = declarations.nodes[i]; - emit_decl(emitter, decl); - } -} - -SpvId 
get_extended_instruction_set(Emitter* emitter, const char* name) { - SpvId* found = find_value_dict(const char*, SpvId, emitter->extended_instruction_sets, name); - if (found) - return *found; - - SpvId new = spvb_extended_import(emitter->file_builder, name); - insert_dict(const char*, SpvId, emitter->extended_instruction_sets, name, new); - return new; -} - -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); - -KeyHash hash_string(const char** string); -bool compare_string(const char** a, const char** b); - -static Module* run_backend_specific_passes(CompilerConfig* config, Module* initial_mod) { - IrArena* initial_arena = initial_mod->arena; - Module* old_mod = NULL; - Module** pmod = &initial_mod; - - RUN_PASS(lower_entrypoint_args) - RUN_PASS(spirv_map_entrypoint_args) - RUN_PASS(spirv_lift_globals_ssbo) - RUN_PASS(import) - - return *pmod; -} - -void emit_spirv(CompilerConfig* config, Module* mod, size_t* output_size, char** output, Module** new_mod) { - IrArena* initial_arena = get_module_arena(mod); - mod = run_backend_specific_passes(config, mod); - IrArena* arena = get_module_arena(mod); - - FileBuilder file_builder = spvb_begin(); - spvb_set_version(file_builder, config->target_spirv_version.major, config->target_spirv_version.minor); - spvb_set_addressing_model(file_builder, SpvAddressingModelLogical); - - Emitter emitter = { - .module = mod, - .arena = arena, - .configuration = config, - .file_builder = file_builder, - .node_ids = new_dict(Node*, SpvId, (HashFn) hash_node, (CmpFn) compare_node), - .bb_builders = new_dict(Node*, BBBuilder, (HashFn) hash_node, (CmpFn) compare_node), - .num_entry_pts = 0, - }; - - emitter.extended_instruction_sets = new_dict(const char*, SpvId, (HashFn) hash_string, (CmpFn) compare_string); - - emitter.void_t = spvb_void_type(emitter.file_builder); - - spvb_extension(file_builder, "SPV_KHR_non_semantic_info"); - - Nodes decls = get_module_declarations(mod); - emit_decls(&emitter, decls); - 
emit_entry_points(&emitter, decls); - - if (emitter.num_entry_pts == 0) - spvb_capability(file_builder, SpvCapabilityLinkage); - - spvb_capability(file_builder, SpvCapabilityShader); - - *output_size = spvb_finish(file_builder, output); - - // cleanup the emitter - destroy_dict(emitter.node_ids); - destroy_dict(emitter.bb_builders); - destroy_dict(emitter.extended_instruction_sets); - - if (new_mod) - *new_mod = mod; - else if (initial_arena != arena) - destroy_ir_arena(arena); -} diff --git a/src/shady/emit/spirv/emit_spv.h b/src/shady/emit/spirv/emit_spv.h deleted file mode 100644 index 3c0dca14f..000000000 --- a/src/shady/emit/spirv/emit_spv.h +++ /dev/null @@ -1,56 +0,0 @@ -#ifndef SHADY_EMIT_SPIRV_H -#define SHADY_EMIT_SPIRV_H - -#include "shady/ir.h" -#include "spirv_builder.h" - -typedef SpvbFileBuilder* FileBuilder; -typedef SpvbFnBuilder* FnBuilder; -typedef SpvbBasicBlockBuilder* BBBuilder; - -typedef struct Emitter_ { - Module* module; - IrArena* arena; - CompilerConfig* configuration; - FileBuilder file_builder; - SpvId void_t; - struct Dict* node_ids; - struct Dict* bb_builders; - size_t num_entry_pts; - - struct Dict* extended_instruction_sets; -} Emitter; - -typedef SpvbPhi** Phis; - -typedef struct { - SpvId continue_target, break_target, join_target; - Phis continue_phis, break_phis, join_phis; -} MergeTargets; - -#define emit_decl spv_emit_decl -#define emit_type spv_emit_type -#define emit_value spv_emit_value -#define emit_instruction spv_emit_instruction -#define emit_terminator spv_emit_terminator -#define find_reserved_id spv_find_reserved_id -#define emit_nominal_type_body spv_emit_nominal_type_body - -SpvId emit_decl(Emitter*, const Node*); -SpvId emit_type(Emitter*, const Type*); -SpvId emit_value(Emitter*, BBBuilder, const Node*); -void emit_instruction(Emitter*, FnBuilder, BBBuilder*, MergeTargets*, const Node* instruction, size_t results_count, SpvId results[]); -void emit_terminator(Emitter*, FnBuilder, BBBuilder, MergeTargets, const 
Node* terminator); - -SpvId find_reserved_id(Emitter* emitter, const Node* node); -void register_result(Emitter*, const Node*, SpvId id); - -SpvId get_extended_instruction_set(Emitter*, const char*); - -SpvStorageClass emit_addr_space(Emitter*, AddressSpace address_space); -// SPIR-V doesn't have multiple return types, this bridges the gap... -SpvId nodes_to_codom(Emitter* emitter, Nodes return_types); -const Type* normalize_type(Emitter* emitter, const Type* type); -void emit_nominal_type_body(Emitter* emitter, const Type* type, SpvId id); - -#endif diff --git a/src/shady/emit/spirv/emit_spv_instructions.c b/src/shady/emit/spirv/emit_spv_instructions.c deleted file mode 100644 index dd09c5ed1..000000000 --- a/src/shady/emit/spirv/emit_spv_instructions.c +++ /dev/null @@ -1,638 +0,0 @@ -#include "emit_spv.h" - -#include "log.h" -#include "portability.h" - -#include "../../type.h" -#include "../../transform/memory_layout.h" -#include "../../transform/ir_gen_helpers.h" - -#include - -#include "spirv/unified1/NonSemanticDebugPrintf.h" -#include "spirv/unified1/GLSL.std.450.h" - -typedef enum { - Custom, Plain, -} InstrClass; - -/// What is considered when searching for an instruction -typedef enum { - None, Monomorphic, FirstOp, FirstAndResult -} ISelMechanism; - -typedef enum { - Same, SameTuple, Bool, Void, TyOperand -} ResultClass; - -typedef enum { - Signed, Unsigned, FP, Logical, Ptr, OperandClassCount -} OperandClass; - -static OperandClass classify_operand_type(const Type* type) { - assert(is_type(type) && is_data_type(type)); - - if (type->tag == PackType_TAG) - return classify_operand_type(type->payload.pack_type.element_type); - - switch (type->tag) { - case Int_TAG: return type->payload.int_type.is_signed ? 
Signed : Unsigned; - case Bool_TAG: return Logical; - case PtrType_TAG: return Ptr; - case Float_TAG: return FP; - default: error("we don't know what to do with this") - } -} - -typedef struct { - InstrClass class; - ISelMechanism isel_mechanism; - ResultClass result_kind; - union { - SpvOp op; - // matches first operand - SpvOp fo[OperandClassCount]; - // matches first operand and return type [first operand][result type] - SpvOp foar[OperandClassCount][OperandClassCount]; - }; - const char* extended_set; -} IselTableEntry; - -#define ISEL_IDENTITY (SpvOpNop /* no-op, should be lowered to nothing beforehand */) -#define ISEL_LOWERME (SpvOpMax /* boolean conversions don't exist as a single instruction, a pass should lower them instead */) -#define ISEL_ILLEGAL (SpvOpMax /* doesn't make sense to support */) -#define ISEL_CUSTOM (SpvOpMax /* doesn't make sense to support */) - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wmissing-field-initializers" - -static const IselTableEntry isel_table[] = { - [add_op] = {Plain, FirstOp, Same, .fo = {SpvOpIAdd, SpvOpIAdd, SpvOpFAdd, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - [sub_op] = {Plain, FirstOp, Same, .fo = {SpvOpISub, SpvOpISub, SpvOpFSub, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - [mul_op] = {Plain, FirstOp, Same, .fo = {SpvOpIMul, SpvOpIMul, SpvOpFMul, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - [div_op] = {Plain, FirstOp, Same, .fo = {SpvOpSDiv, SpvOpUDiv, SpvOpFDiv, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - [mod_op] = {Plain, FirstOp, Same, .fo = {SpvOpSMod, SpvOpUMod, SpvOpFMod, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - - [add_carry_op] = {Plain, FirstOp, SameTuple, .fo = {SpvOpIAddCarry, SpvOpIAddCarry, ISEL_ILLEGAL }}, - [sub_borrow_op] = {Plain, FirstOp, SameTuple, .fo = {SpvOpISubBorrow, SpvOpISubBorrow, ISEL_ILLEGAL }}, - [mul_extended_op] = {Plain, FirstOp, SameTuple, .fo = {SpvOpSMulExtended, SpvOpUMulExtended, ISEL_ILLEGAL }}, - - [neg_op] = {Plain, FirstOp, Same, .fo = {SpvOpSNegate, SpvOpSNegate, SpvOpFNegate }}, - - [eq_op] = {Plain, 
FirstOp, Bool, .fo = {SpvOpIEqual, SpvOpIEqual, SpvOpFOrdEqual, SpvOpLogicalEqual }}, - [neq_op] = {Plain, FirstOp, Bool, .fo = {SpvOpINotEqual, SpvOpINotEqual, SpvOpFOrdNotEqual, SpvOpLogicalNotEqual }}, - [lt_op] = {Plain, FirstOp, Bool, .fo = {SpvOpSLessThan, SpvOpULessThan, SpvOpFOrdLessThan, ISEL_IDENTITY }}, - [lte_op] = {Plain, FirstOp, Bool, .fo = {SpvOpSLessThanEqual, SpvOpULessThanEqual, SpvOpFOrdLessThanEqual, ISEL_IDENTITY }}, - [gt_op] = {Plain, FirstOp, Bool, .fo = {SpvOpSGreaterThan, SpvOpUGreaterThan, SpvOpFOrdGreaterThan, ISEL_IDENTITY }}, - [gte_op] = {Plain, FirstOp, Bool, .fo = {SpvOpSGreaterThanEqual, SpvOpUGreaterThanEqual, SpvOpFOrdGreaterThanEqual, ISEL_IDENTITY }}, - - [not_op] = {Plain, FirstOp, Same, .fo = {SpvOpNot, SpvOpNot, ISEL_ILLEGAL, SpvOpLogicalNot }}, - - [and_op] = {Plain, FirstOp, Same, .fo = {SpvOpBitwiseAnd, SpvOpBitwiseAnd, ISEL_ILLEGAL, SpvOpLogicalAnd }}, - [or_op] = {Plain, FirstOp, Same, .fo = {SpvOpBitwiseOr, SpvOpBitwiseOr, ISEL_ILLEGAL, SpvOpLogicalOr }}, - [xor_op] = {Plain, FirstOp, Same, .fo = {SpvOpBitwiseXor, SpvOpBitwiseXor, ISEL_ILLEGAL, SpvOpLogicalNotEqual }}, - - [lshift_op] = {Plain, FirstOp, Same, .fo = {SpvOpShiftLeftLogical, SpvOpShiftLeftLogical, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - [rshift_arithm_op] = {Plain, FirstOp, Same, .fo = {SpvOpShiftRightArithmetic, SpvOpShiftRightArithmetic, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - [rshift_logical_op] = {Plain, FirstOp, Same, .fo = {SpvOpShiftRightLogical, SpvOpShiftRightLogical, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - - [convert_op] = {Plain, FirstAndResult, TyOperand, .foar = { - { SpvOpSConvert, SpvOpUConvert, SpvOpConvertSToF, ISEL_LOWERME, ISEL_LOWERME }, - { SpvOpSConvert, SpvOpUConvert, SpvOpConvertUToF, ISEL_LOWERME, ISEL_LOWERME }, - { SpvOpConvertFToS, SpvOpConvertFToU, SpvOpFConvert, ISEL_ILLEGAL, ISEL_ILLEGAL }, - { ISEL_LOWERME, ISEL_LOWERME, ISEL_ILLEGAL, ISEL_IDENTITY, ISEL_ILLEGAL }, - { ISEL_LOWERME, ISEL_LOWERME, ISEL_ILLEGAL, ISEL_ILLEGAL, ISEL_IDENTITY } - 
}}, - - [reinterpret_op] = {Plain, FirstAndResult, TyOperand, .foar = { - { ISEL_ILLEGAL, SpvOpBitcast, SpvOpBitcast, ISEL_ILLEGAL, SpvOpConvertUToPtr }, - { SpvOpBitcast, ISEL_ILLEGAL, SpvOpBitcast, ISEL_ILLEGAL, SpvOpConvertUToPtr }, - { SpvOpBitcast, SpvOpBitcast, ISEL_IDENTITY, ISEL_ILLEGAL, ISEL_ILLEGAL /* no fp-ptr casts */ }, - { ISEL_ILLEGAL, ISEL_ILLEGAL, ISEL_ILLEGAL, ISEL_IDENTITY, ISEL_ILLEGAL /* no bool reinterpret */ }, - { SpvOpConvertPtrToU, SpvOpConvertPtrToU, ISEL_ILLEGAL, ISEL_ILLEGAL, ISEL_IDENTITY } - }}, - - [sqrt_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Sqrt }, - [inv_sqrt_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450InverseSqrt}, - [floor_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Floor }, - [ceil_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Ceil }, - [round_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Round }, - [fract_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Fract }, - [sin_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Sin }, - [cos_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Cos }, - - [abs_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = { (SpvOp) GLSLstd450SAbs, ISEL_ILLEGAL, (SpvOp) GLSLstd450FAbs, ISEL_ILLEGAL }}, - [sign_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = { (SpvOp) GLSLstd450SSign, ISEL_ILLEGAL, (SpvOp) GLSLstd450FSign, ISEL_ILLEGAL }}, - - [min_op] = {Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = {(SpvOp) GLSLstd450SMin, (SpvOp) GLSLstd450UMin, (SpvOp) GLSLstd450FMin, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - [max_op] = {Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = {(SpvOp) GLSLstd450SMax, (SpvOp) 
GLSLstd450UMax, (SpvOp) GLSLstd450FMax, ISEL_ILLEGAL, ISEL_ILLEGAL }}, - [exp_op] = {Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Exp}, - [pow_op] = {Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Pow}, - - [debug_printf_op] = {Plain, Monomorphic, Void, .extended_set = "NonSemantic.DebugPrintf", .op = (SpvOp) NonSemanticDebugPrintfDebugPrintf}, - - [sample_texture_op] = {Plain, Monomorphic, TyOperand, .op = SpvOpImageSampleImplicitLod }, - - [subgroup_assume_uniform_op] = {Plain, Monomorphic, Same, .op = ISEL_IDENTITY }, - - [PRIMOPS_COUNT] = { Custom } -}; - -#pragma GCC diagnostic pop -#pragma GCC diagnostic error "-Wswitch" - -static const Type* get_result_t(Emitter* emitter, IselTableEntry entry, Nodes args, Nodes type_arguments) { - switch (entry.result_kind) { - case Same: return get_unqualified_type(first(args)->type); - case SameTuple: return record_type(emitter->arena, (RecordType) { .members = mk_nodes(emitter->arena, get_unqualified_type(first(args)->type), get_unqualified_type(first(args)->type)) }); - case Bool: return bool_type(emitter->arena); - case TyOperand: return first(type_arguments); - case Void: return unit_type(emitter->arena); - } -} - -static SpvOp get_opcode(SHADY_UNUSED Emitter* emitter, IselTableEntry entry, Nodes args, Nodes type_arguments) { - switch (entry.isel_mechanism) { - case None: return SpvOpMax; - case Monomorphic: return entry.op; - case FirstOp: { - assert(args.count >= 1); - OperandClass op_class = classify_operand_type(get_unqualified_type(first(args)->type)); - return entry.fo[op_class]; - } - case FirstAndResult: { - assert(args.count >= 1); - assert(type_arguments.count == 1); - OperandClass op_class = classify_operand_type(get_unqualified_type(first(args)->type)); - OperandClass return_t_class = classify_operand_type(first(type_arguments)); - return entry.foar[op_class][return_t_class]; - } - } -} - -static void emit_primop(Emitter* emitter, FnBuilder 
fn_builder, BBBuilder bb_builder, const Node* instr, size_t results_count, SpvId results[]) { - PrimOp the_op = instr->payload.prim_op; - Nodes args = the_op.operands; - Nodes type_arguments = the_op.type_arguments; - - IselTableEntry entry = isel_table[the_op.op]; - if (entry.class != Custom) { - assert(results_count <= 1); - LARRAY(SpvId, emitted_args, args.count); - for (size_t i = 0; i < args.count; i++) - emitted_args[i] = emit_value(emitter, bb_builder, args.nodes[i]); - - switch (entry.class) { - case Plain: { - SpvOp opcode = get_opcode(emitter, entry, args, type_arguments); - if (opcode == SpvOpNop) { - assert(results_count == 1); - results[0] = emitted_args[0]; - return; - } - - Nodes results_ts = unwrap_multiple_yield_types(emitter->arena, instr->type); - SpvId result_t = results_ts.count >= 1 ? emit_type(emitter, instr->type) : emitter->void_t; - - assert(opcode != SpvOpMax); - - if (entry.extended_set) { - SpvId set_id = get_extended_instruction_set(emitter, entry.extended_set); - - SpvId result = spvb_ext_instruction(bb_builder, result_t, set_id, opcode, args.count, emitted_args); - if (results_count == 1) - results[0] = result; - } else { - SpvId result = spvb_op(bb_builder, opcode, result_t, args.count, emitted_args); - if (results_count == 1) - results[0] = result; - } - return; - } - case Custom: SHADY_UNREACHABLE; - } - - return; - } - switch (the_op.op) { - case subgroup_ballot_op: { - const Type* i32x4 = pack_type(emitter->arena, (PackType) { .width = 4, .element_type = uint32_type(emitter->arena) }); - SpvId scope_subgroup = emit_value(emitter, bb_builder, int32_literal(emitter->arena, SpvScopeSubgroup)); - SpvId raw_result = spvb_group_ballot(bb_builder, emit_type(emitter, i32x4), emit_value(emitter, bb_builder, first(args)), scope_subgroup); - // TODO: why are we doing this in SPIR-V and not the IR ? 
- SpvId low32 = spvb_extract(bb_builder, emit_type(emitter, uint32_type(emitter->arena)), raw_result, 1, (uint32_t[]) { 0 }); - SpvId hi32 = spvb_extract(bb_builder, emit_type(emitter, uint32_type(emitter->arena)), raw_result, 1, (uint32_t[]) { 1 }); - SpvId low64 = spvb_op(bb_builder, SpvOpUConvert, emit_type(emitter, uint64_type(emitter->arena)), 1, &low32); - SpvId hi64 = spvb_op(bb_builder, SpvOpUConvert, emit_type(emitter, uint64_type(emitter->arena)), 1, &hi32); - hi64 = spvb_op(bb_builder, SpvOpShiftLeftLogical, emit_type(emitter, uint64_type(emitter->arena)), 2, (SpvId []) { hi64, emit_value(emitter, bb_builder, int64_literal(emitter->arena, 32)) }); - SpvId final_result = spvb_op(bb_builder, SpvOpBitwiseOr, emit_type(emitter, uint64_type(emitter->arena)), 2, (SpvId []) { low64, hi64 }); - assert(results_count == 1); - results[0] = final_result; - spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniformBallot); - return; - } - case subgroup_broadcast_first_op: { - SpvId scope_subgroup = emit_value(emitter, bb_builder, int32_literal(emitter->arena, SpvScopeSubgroup)); - SpvId result; - - if (emitter->configuration->hacks.spv_shuffle_instead_of_broadcast_first) { - SpvId local_id; - const Node* b = ref_decl_helper(emitter->arena, get_builtin(emitter->module, BuiltinSubgroupLocalInvocationId, NULL)); - emit_primop(emitter, fn_builder, bb_builder, prim_op(emitter->arena, (PrimOp) { .op = load_op, .operands = singleton(b) }), 1, &local_id); - result = spvb_group_shuffle(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), scope_subgroup, emit_value(emitter, bb_builder, first(args)), local_id); - spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniformShuffle); - } else { - result = spvb_group_broadcast_first(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), emit_value(emitter, bb_builder, first(args)), scope_subgroup); - } - - assert(results_count == 1); - results[0] = result; - 
spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniformBallot); - return; - } - case subgroup_reduce_sum_op: { - SpvId scope_subgroup = emit_value(emitter, bb_builder, int32_literal(emitter->arena, SpvScopeSubgroup)); - assert(results_count == 1); - results[0] = spvb_group_non_uniform_iadd(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), emit_value(emitter, bb_builder, first(args)), scope_subgroup, SpvGroupOperationReduce, NULL); - spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniformArithmetic); - return; - } - case subgroup_elect_first_op: { - SpvId result_t = emit_type(emitter, bool_type(emitter->arena)); - SpvId scope_subgroup = emit_value(emitter, bb_builder, int32_literal(emitter->arena, SpvScopeSubgroup)); - SpvId result = spvb_group_elect(bb_builder, result_t, scope_subgroup); - assert(results_count == 1); - results[0] = result; - spvb_capability(emitter->file_builder, SpvCapabilityGroupNonUniform); - return; - } - case insert_op: - case extract_dynamic_op: - case extract_op: { - assert(results_count == 1); - bool insert = the_op.op == insert_op; - - const Node* src_value = first(args); - const Type* result_t = instr->type; - size_t indices_start = insert ? 
2 : 1; - size_t indices_count = args.count - indices_start; - assert(args.count > indices_start); - - bool dynamic = the_op.op == extract_dynamic_op; - - if (dynamic) { - LARRAY(SpvId, indices, indices_count); - for (size_t i = 0; i < indices_count; i++) { - indices[i] = emit_value(emitter, bb_builder, args.nodes[i + indices_start]); - } - assert(indices_count == 1); - results[0] = spvb_vector_extract_dynamic(bb_builder, emit_type(emitter, result_t), emit_value(emitter, bb_builder, src_value), indices[0]); - } else { - LARRAY(uint32_t, indices, indices_count); - for (size_t i = 0; i < indices_count; i++) { - // TODO: fallback to Dynamic variants transparently - indices[i] = get_int_literal_value(*resolve_to_int_literal(args.nodes[i + indices_start]), false); - } - - if (!insert) { - results[0] = spvb_extract(bb_builder, emit_type(emitter, result_t), emit_value(emitter, bb_builder, src_value), indices_count, indices); - } else - results[0] = spvb_insert(bb_builder, emit_type(emitter, result_t), emit_value(emitter, bb_builder, args.nodes[1]), emit_value(emitter, bb_builder, src_value), indices_count, indices); - } - return; - } - case shuffle_op: { - const Type* result_t = instr->type; - SpvId a = emit_value(emitter, bb_builder, args.nodes[0]); - SpvId b = emit_value(emitter, bb_builder, args.nodes[1]); - LARRAY(uint32_t, indices, args.count - 2); - for (size_t i = 0; i < args.count - 2; i++) { - int64_t indice = get_int_literal_value(*resolve_to_int_literal(args.nodes[i + 2]), true); - if (indice == -1) - indices[i] = 0xFFFFFFFF; - else - indices[i] = indice; - } - assert(results_count == 1); - results[0] = spvb_vecshuffle(bb_builder, emit_type(emitter, result_t), a, b, args.count - 2, indices); - return; - } - case load_op: { - const Type* ptr_type = first(args)->type; - deconstruct_qualified_type(&ptr_type); - assert(ptr_type->tag == PtrType_TAG); - const Type* elem_type = ptr_type->payload.ptr_type.pointed_type; - - size_t operands_count = 0; - uint32_t 
operands[2]; - if (ptr_type->payload.ptr_type.address_space == AsGlobalPhysical) { - // TODO only do this in VK mode ? - TypeMemLayout layout = get_mem_layout(emitter->arena, elem_type); - operands[operands_count + 0] = SpvMemoryAccessAlignedMask; - operands[operands_count + 1] = (uint32_t) layout.alignment_in_bytes; - operands_count += 2; - } - - SpvId eptr = emit_value(emitter, bb_builder, first(args)); - SpvId result = spvb_load(bb_builder, emit_type(emitter, elem_type), eptr, operands_count, operands); - assert(results_count == 1); - results[0] = result; - return; - } - case store_op: { - const Type* ptr_type = first(args)->type; - deconstruct_qualified_type(&ptr_type); - assert(ptr_type->tag == PtrType_TAG); - const Type* elem_type = ptr_type->payload.ptr_type.pointed_type; - - size_t operands_count = 0; - uint32_t operands[2]; - if (ptr_type->payload.ptr_type.address_space == AsGlobalPhysical) { - // TODO only do this in VK mode ? - TypeMemLayout layout = get_mem_layout(emitter->arena, elem_type); - operands[operands_count + 0] = SpvMemoryAccessAlignedMask; - operands[operands_count + 1] = (uint32_t) layout.alignment_in_bytes; - operands_count += 2; - } - - SpvId eptr = emit_value(emitter, bb_builder, first(args)); - SpvId eval = emit_value(emitter, bb_builder, args.nodes[1]); - spvb_store(bb_builder, eval, eptr, operands_count, operands); - assert(results_count == 0); - return; - } - case alloca_logical_op: { - const Type* elem_type = first(type_arguments); - SpvId result = spvb_local_variable(fn_builder, emit_type(emitter, ptr_type(emitter->arena, (PtrType) { - .address_space = AsFunctionLogical, - .pointed_type = elem_type - })), SpvStorageClassFunction); - assert(results_count == 1); - results[0] = result; - return; - } - case lea_op: { - SpvId base = emit_value(emitter, bb_builder, first(args)); - - LARRAY(SpvId, indices, args.count - 2); - for (size_t i = 2; i < args.count; i++) - indices[i - 2] = args.nodes[i] ? 
emit_value(emitter, bb_builder, args.nodes[i]) : 0; - - const IntLiteral* known_offset = resolve_to_int_literal(args.nodes[1]); - if (known_offset && known_offset->value == 0) { - const Type* target_type = instr->type; - SpvId result = spvb_access_chain(bb_builder, emit_type(emitter, target_type), base, args.count - 2, indices); - assert(results_count == 1); - results[0] = result; - } else { - error("TODO: OpPtrAccessChain") - } - return; - } - case select_op: { - SpvId cond = emit_value(emitter, bb_builder, first(args)); - SpvId truv = emit_value(emitter, bb_builder, args.nodes[1]); - SpvId flsv = emit_value(emitter, bb_builder, args.nodes[2]); - - SpvId result = spvb_select(bb_builder, emit_type(emitter, args.nodes[1]->type), cond, truv, flsv); - assert(results_count == 1); - results[0] = result; - return; - } - default: error("TODO: unhandled op"); - } - error("unreachable"); -} - -static void emit_leaf_call(Emitter* emitter, SHADY_UNUSED FnBuilder fn_builder, BBBuilder bb_builder, Call call, size_t results_count, SpvId results[]) { - const Node* fn = call.callee; - assert(fn->tag == FnAddr_TAG); - fn = fn->payload.fn_addr.fn; - SpvId callee = emit_decl(emitter, fn); - - const Type* callee_type = fn->type; - assert(callee_type->tag == FnType_TAG); - Nodes return_types = callee_type->payload.fn_type.return_types; - SpvId return_type = nodes_to_codom(emitter, return_types); - LARRAY(SpvId, args, call.args.count); - for (size_t i = 0; i < call.args.count; i++) - args[i] = emit_value(emitter, bb_builder, call.args.nodes[i]); - SpvId result = spvb_call(bb_builder, return_type, callee, call.args.count, args); - switch (results_count) { - case 0: break; - case 1: { - results[0] = result; - break; - } - default: { - assert(return_types.count == results_count); - for (size_t i = 0; i < results_count; i++) { - SpvId result_type = emit_type(emitter, return_types.nodes[i]->type); - SpvId extracted_component = spvb_extract(bb_builder, result_type, result, 1, (uint32_t []) { 
i }); - results[i] = extracted_component; - } - break; - } - } -} - -static void emit_if(Emitter* emitter, FnBuilder fn_builder, BBBuilder* bb_builder, MergeTargets* merge_targets, If if_instr, size_t results_count, SpvId results[]) { - Nodes yield_types = if_instr.yield_types; - assert(yield_types.count == results_count); - SpvId join_bb_id = spvb_fresh_id(emitter->file_builder); - - SpvId true_id = spvb_fresh_id(emitter->file_builder); - SpvId false_id = if_instr.if_false ? spvb_fresh_id(emitter->file_builder) : join_bb_id; - - spvb_selection_merge(*bb_builder, join_bb_id, 0); - SpvId condition = emit_value(emitter, *bb_builder, if_instr.condition); - spvb_branch_conditional(*bb_builder, condition, true_id, false_id); - - // When 'join' is codegen'd, these will be filled with the values given to it - BBBuilder join_bb = spvb_begin_bb(fn_builder, join_bb_id); - LARRAY(SpvbPhi*, join_phis, yield_types.count); - for (size_t i = 0; i < yield_types.count; i++) { - assert(if_instr.if_false && "Ifs with yield types need false branches !"); - SpvId phi_id = spvb_fresh_id(emitter->file_builder); - SpvId type = emit_type(emitter, yield_types.nodes[i]); - SpvbPhi* phi = spvb_add_phi(join_bb, type, phi_id); - join_phis[i] = phi; - results[i] = phi_id; - } - - MergeTargets merge_targets_branches = *merge_targets; - merge_targets_branches.join_target = join_bb_id; - merge_targets_branches.join_phis = join_phis; - - BBBuilder true_bb = spvb_begin_bb(fn_builder, true_id); - spvb_add_bb(fn_builder, true_bb); - assert(is_case(if_instr.if_true)); - emit_terminator(emitter, fn_builder, true_bb, merge_targets_branches, if_instr.if_true->payload.case_.body); - if (if_instr.if_false) { - BBBuilder false_bb = spvb_begin_bb(fn_builder, false_id); - spvb_add_bb(fn_builder, false_bb); - assert(is_case(if_instr.if_false)); - emit_terminator(emitter, fn_builder, false_bb, merge_targets_branches, if_instr.if_false->payload.case_.body); - } - - spvb_add_bb(fn_builder, join_bb); - *bb_builder = 
join_bb; -} - -static void emit_match(Emitter* emitter, FnBuilder fn_builder, BBBuilder* bb_builder, MergeTargets* merge_targets, Match match, size_t results_count, SHADY_UNUSED SpvId results[]) { - Nodes yield_types = match.yield_types; - assert(yield_types.count == results_count); - SpvId join_bb_id = spvb_fresh_id(emitter->file_builder); - - assert(get_unqualified_type(match.inspect->type)->tag == Int_TAG); - SpvId inspectee = emit_value(emitter, *bb_builder, match.inspect); - - SpvId default_id = spvb_fresh_id(emitter->file_builder); - - const Type* inspectee_t = match.inspect->type; - deconstruct_qualified_type(&inspectee_t); - assert(inspectee_t->tag == Int_TAG); - size_t literal_width = inspectee_t->payload.int_type.width == IntTy64 ? 2 : 1; - size_t literal_case_entry_size = literal_width + 1; - LARRAY(uint32_t, literals_and_cases, match.cases.count * literal_case_entry_size); - for (size_t i = 0; i < match.cases.count; i++) { - uint64_t value = (uint64_t) get_int_literal_value(*resolve_to_int_literal(match.literals.nodes[i]), false); - if (inspectee_t->payload.int_type.width == IntTy64) { - literals_and_cases[i * literal_case_entry_size + 0] = (SpvId) (uint32_t) (value & 0xFFFFFFFF); - literals_and_cases[i * literal_case_entry_size + 1] = (SpvId) (uint32_t) (value >> 32); - } else { - literals_and_cases[i * literal_case_entry_size + 0] = (SpvId) (uint32_t) value; - } - literals_and_cases[i * literal_case_entry_size + literal_width] = spvb_fresh_id(emitter->file_builder); - } - - spvb_selection_merge(*bb_builder, join_bb_id, 0); - spvb_switch(*bb_builder, inspectee, default_id, match.cases.count * literal_case_entry_size, literals_and_cases); - - // When 'join' is codegen'd, these will be filled with the values given to it - BBBuilder join_bb = spvb_begin_bb(fn_builder, join_bb_id); - LARRAY(SpvbPhi*, join_phis, yield_types.count); - for (size_t i = 0; i < yield_types.count; i++) { - SpvId phi_id = spvb_fresh_id(emitter->file_builder); - SpvId type = 
emit_type(emitter, yield_types.nodes[i]); - SpvbPhi* phi = spvb_add_phi(join_bb, type, phi_id); - join_phis[i] = phi; - results[i] = phi_id; - } - - MergeTargets merge_targets_branches = *merge_targets; - merge_targets_branches.join_target = join_bb_id; - merge_targets_branches.join_phis = join_phis; - - for (size_t i = 0; i < match.cases.count; i++) { - BBBuilder case_bb = spvb_begin_bb(fn_builder, literals_and_cases[i * literal_case_entry_size + literal_width]); - const Node* case_body = match.cases.nodes[i]; - assert(is_case(case_body)); - spvb_add_bb(fn_builder, case_bb); - emit_terminator(emitter, fn_builder, case_bb, merge_targets_branches, case_body->payload.case_.body); - } - BBBuilder default_bb = spvb_begin_bb(fn_builder, default_id); - assert(is_case(match.default_case)); - spvb_add_bb(fn_builder, default_bb); - emit_terminator(emitter, fn_builder, default_bb, merge_targets_branches, match.default_case->payload.case_.body); - - spvb_add_bb(fn_builder, join_bb); - *bb_builder = join_bb; -} - -static void emit_loop(Emitter* emitter, FnBuilder fn_builder, BBBuilder* bb_builder, MergeTargets* merge_targets, Loop loop_instr, size_t results_count, SpvId results[]) { - Nodes yield_types = loop_instr.yield_types; - assert(yield_types.count == results_count); - - const Node* body = loop_instr.body; - assert(is_case(body)); - Nodes body_params = body->payload.case_.params; - - // First we create all the basic blocks we'll need - SpvId header_id = spvb_fresh_id(emitter->file_builder); - BBBuilder header_builder = spvb_begin_bb(fn_builder, header_id); - spvb_name(emitter->file_builder, header_id, "loop_header"); - - SpvId body_id = spvb_fresh_id(emitter->file_builder); - BBBuilder body_builder = spvb_begin_bb(fn_builder, body_id); - spvb_name(emitter->file_builder, body_id, "loop_body"); - - SpvId continue_id = spvb_fresh_id(emitter->file_builder); - BBBuilder continue_builder = spvb_begin_bb(fn_builder, continue_id); - spvb_name(emitter->file_builder, continue_id, 
"loop_continue"); - - SpvId next_id = spvb_fresh_id(emitter->file_builder); - BBBuilder next = spvb_begin_bb(fn_builder, next_id); - spvb_name(emitter->file_builder, next_id, "loop_next"); - - // Wire up the phi nodes for loop exit - LARRAY(SpvbPhi*, loop_break_phis, yield_types.count); - for (size_t i = 0; i < yield_types.count; i++) { - SpvId yielded_type = emit_type(emitter, get_unqualified_type(yield_types.nodes[i])); - - SpvId break_phi_id = spvb_fresh_id(emitter->file_builder); - SpvbPhi* phi = spvb_add_phi(next, yielded_type, break_phi_id); - loop_break_phis[i] = phi; - results[i] = break_phi_id; - } - - // Wire up the phi nodes for the loop contents - LARRAY(SpvbPhi*, loop_continue_phis, body_params.count); - for (size_t i = 0; i < body_params.count; i++) { - SpvId loop_param_type = emit_type(emitter, get_unqualified_type(body_params.nodes[i]->type)); - - SpvId continue_phi_id = spvb_fresh_id(emitter->file_builder); - SpvbPhi* continue_phi = spvb_add_phi(continue_builder, loop_param_type, continue_phi_id); - loop_continue_phis[i] = continue_phi; - - // To get the actual loop parameter, we make a second phi for the nodes that go into the header - // We already know the two edges into the header so we immediately add the Phi sources for it. - SpvId loop_param_id = spvb_fresh_id(emitter->file_builder); - SpvbPhi* loop_param_phi = spvb_add_phi(header_builder, loop_param_type, loop_param_id); - SpvId param_initial_value = emit_value(emitter, *bb_builder, loop_instr.initial_args.nodes[i]); - spvb_add_phi_source(loop_param_phi, get_block_builder_id(*bb_builder), param_initial_value); - spvb_add_phi_source(loop_param_phi, get_block_builder_id(continue_builder), continue_phi_id); - register_result(emitter, body_params.nodes[i], loop_param_id); - } - - // The current block goes to the header (it can't be the header itself !) 
- spvb_branch(*bb_builder, header_id); - spvb_add_bb(fn_builder, header_builder); - - // the header block receives the loop merge annotation - spvb_loop_merge(header_builder, next_id, continue_id, 0, 0, NULL); - spvb_branch(header_builder, body_id); - spvb_add_bb(fn_builder, body_builder); - - // Emission of the body requires extra info for the break/continue merge terminators - MergeTargets merge_targets_branches = *merge_targets; - merge_targets_branches.continue_target = continue_id; - merge_targets_branches.continue_phis = loop_continue_phis; - merge_targets_branches.break_target = next_id; - merge_targets_branches.break_phis = loop_break_phis; - emit_terminator(emitter, fn_builder, body_builder, merge_targets_branches, body->payload.case_.body); - - // the continue block just jumps back into the header - spvb_branch(continue_builder, header_id); - spvb_add_bb(fn_builder, continue_builder); - - // We start the next block - spvb_add_bb(fn_builder, next); - *bb_builder = next; -} - -void emit_instruction(Emitter* emitter, FnBuilder fn_builder, BBBuilder* bb_builder, MergeTargets* merge_targets, const Node* instruction, size_t results_count, SpvId results[]) { - assert(is_instruction(instruction)); - - switch (is_instruction(instruction)) { - case NotAnInstruction: error(""); - case Instruction_Control_TAG: - case Instruction_Block_TAG: error("Should be lowered elsewhere") - case Instruction_Call_TAG: emit_leaf_call(emitter, fn_builder, *bb_builder, instruction->payload.call, results_count, results); break; - case PrimOp_TAG: emit_primop(emitter, fn_builder, *bb_builder, instruction, results_count, results); break; - case If_TAG: emit_if(emitter, fn_builder, bb_builder, merge_targets, instruction->payload.if_instr, results_count, results); break; - case Match_TAG: emit_match(emitter, fn_builder, bb_builder, merge_targets, instruction->payload.match_instr, results_count, results); break; - case Loop_TAG: emit_loop(emitter, fn_builder, bb_builder, merge_targets, 
instruction->payload.loop_instr, results_count, results); break; - case Comment_TAG: break; - } -} diff --git a/src/shady/fold.c b/src/shady/fold.c index 02b9d2b7f..98ef6a7b2 100644 --- a/src/shady/fold.c +++ b/src/shady/fold.c @@ -1,119 +1,36 @@ #include "fold.h" -#include "log.h" +#include "shady/ir/memory_layout.h" + +#include "check.h" -#include "type.h" #include "portability.h" -#include "rewrite.h" #include #include static const Node* quote_single(IrArena* a, const Node* value) { - return quote_helper(a, singleton(value)); + return value; } static bool is_zero(const Node* node) { - const IntLiteral* lit = resolve_to_int_literal(node); - if (lit && get_int_literal_value(*lit, false) == 0) + const IntLiteral* lit = shd_resolve_to_int_literal(node); + if (lit && shd_get_int_literal_value(*lit, false) == 0) return true; return false; } static bool is_one(const Node* node) { - const IntLiteral* lit = resolve_to_int_literal(node); - if (lit && get_int_literal_value(*lit, false) == 1) + const IntLiteral* lit = shd_resolve_to_int_literal(node); + if (lit && shd_get_int_literal_value(*lit, false) == 1) return true; return false; } -static const Node* fold_let(IrArena* arena, const Node* node) { - assert(node->tag == Let_TAG); - const Node* instruction = node->payload.let.instruction; - const Node* tail = node->payload.let.tail; - switch (instruction->tag) { - // eliminates blocks by "lifting" their contents out and replacing yield with the tail of the outer let - // In other words, we turn these patterns: - // - // let block { - // let I in case(x) => - // let J in case(y) => - // let K in case(z) => - // ... - // yield (x, y, z) } - // in case(a, b, c) => R - // - // into these: - // - // let I in case(x) => - // let J in case(y) => - // let K in case(z) => - // ... 
- // R[a->x, b->y, c->z] - case Block_TAG: { - // follow the terminator of the block until we hit a yield() - const Node* lam = instruction->payload.block.inside; - const Node* terminator = get_abstraction_body(lam); - size_t depth = 0; - bool dry_run = true; - const Node** lets = NULL; - while (true) { - assert(is_case(lam)); - switch (is_terminator(terminator)) { - case NotATerminator: assert(false); - case Terminator_Let_TAG: { - if (lets) - lets[depth] = terminator; - lam = get_let_tail(terminator); - terminator = get_abstraction_body(lam); - depth++; - continue; - } - case Terminator_Yield_TAG: { - if (dry_run) { - lets = calloc(sizeof(const Node*), depth); - dry_run = false; - depth = 0; - // Start over ! - lam = instruction->payload.block.inside; - terminator = get_abstraction_body(lam); - continue; - } else { - // wrap the original tail with the args of join() - assert(is_case(tail)); - const Node* acc = let(arena, quote_helper(arena, terminator->payload.yield.args), tail); - // rebuild the let chain that we traversed - for (size_t i = 0; i < depth; i++) { - const Node* olet = lets[depth - 1 - i]; - const Node* olam = get_let_tail(olet); - assert(olam->tag == Case_TAG); - Nodes params = get_abstraction_params(olam); - for (size_t j = 0; j < params.count; j++) { - // recycle the params by setting their abs value to NULL - *((Node**) &(params.nodes[j]->payload.var.abs)) = NULL; - } - const Node* nlam = case_(arena, params, acc); - acc = let(arena, get_let_instruction(olet), nlam); - } - free(lets); - return acc; - } - } - // if we see anything else, give up - default: { - assert(dry_run); - return node; - } - } - } - } - default: break; - } - - return node; -} +#define APPLY_FOLD(F) { const Node* applied_fold = F(node); if (applied_fold) return applied_fold; } -static const Node* fold_prim_op(IrArena* arena, const Node* node) { +static inline const Node* fold_constant_math(const Node* node) { + IrArena* arena = node->arena; PrimOp payload = 
node->payload.prim_op; LARRAY(const FloatLiteral*, float_literals, payload.operands.count); @@ -125,14 +42,14 @@ static const Node* fold_prim_op(IrArena* arena, const Node* node) { IntSizes int_width; bool is_signed; for (size_t i = 0; i < payload.operands.count; i++) { - int_literals[i] = resolve_to_int_literal(payload.operands.nodes[i]); + int_literals[i] = shd_resolve_to_int_literal(payload.operands.nodes[i]); all_int_literals &= int_literals[i] != NULL; if (int_literals[i]) { int_width = int_literals[i]->width; is_signed = int_literals[i]->is_signed; } - float_literals[i] = resolve_to_float_literal(payload.operands.nodes[i]); + float_literals[i] = shd_resolve_to_float_literal(payload.operands.nodes[i]); if (float_literals[i]) float_width = float_literals[i]->width; all_float_literals &= float_literals[i] != NULL; @@ -140,12 +57,12 @@ static const Node* fold_prim_op(IrArena* arena, const Node* node) { #define UN_OP(primop, op) case primop##_op: \ if (all_int_literals) return quote_single(arena, int_literal(arena, (IntLiteral) { .is_signed = is_signed, .width = int_width, .value = op int_literals[0]->value})); \ -else if (all_float_literals) return quote_single(arena, fp_literal_helper(arena, float_width, op get_float_literal_value(*float_literals[0]))); \ +else if (all_float_literals) return quote_single(arena, shd_fp_literal_helper(arena, float_width, op shd_get_float_literal_value(*float_literals[0]))); \ else break; #define BIN_OP(primop, op) case primop##_op: \ if (all_int_literals) return quote_single(arena, int_literal(arena, (IntLiteral) { .is_signed = is_signed, .width = int_width, .value = int_literals[0]->value op int_literals[1]->value })); \ -else if (all_float_literals) return quote_single(arena, fp_literal_helper(arena, float_width, get_float_literal_value(*float_literals[0]) op get_float_literal_value(*float_literals[1]))); \ +else if (all_float_literals) return quote_single(arena, shd_fp_literal_helper(arena, float_width, 
shd_get_float_literal_value(*float_literals[0]) op shd_get_float_literal_value(*float_literals[1]))); \ break; if (all_int_literals || all_float_literals) { @@ -159,9 +76,9 @@ break; if (all_int_literals) return quote_single(arena, int_literal(arena, (IntLiteral) { .is_signed = is_signed, .width = int_width, .value = int_literals[0]->value % int_literals[1]->value })); else - return quote_single(arena, fp_literal_helper(arena, float_width, fmod(get_float_literal_value(*float_literals[0]), get_float_literal_value(*float_literals[1])))); + return quote_single(arena, shd_fp_literal_helper(arena, float_width, fmod(shd_get_float_literal_value(*float_literals[0]), shd_get_float_literal_value(*float_literals[1])))); case reinterpret_op: { - const Type* dst_t = first(payload.type_arguments); + const Type* dst_t = shd_first(payload.type_arguments); uint64_t raw_value = int_literals[0] ? int_literals[0]->value : float_literals[0]->value; if (dst_t->tag == Int_TAG) { return quote_single(arena, int_literal(arena, (IntLiteral) { .is_signed = dst_t->payload.int_type.is_signed, .width = dst_t->payload.int_type.width, .value = raw_value })); @@ -171,29 +88,29 @@ break; break; } case convert_op: { - const Type* dst_t = first(payload.type_arguments); + const Type* dst_t = shd_first(payload.type_arguments); uint64_t bitmask = 0; - if (get_type_bitwidth(dst_t) == 64) + if (shd_get_type_bitwidth(dst_t) == 64) bitmask = UINT64_MAX; else - bitmask = ~(UINT64_MAX << get_type_bitwidth(dst_t)); + bitmask = ~(UINT64_MAX << shd_get_type_bitwidth(dst_t)); if (dst_t->tag == Int_TAG) { if (all_int_literals) { - uint64_t old_value = get_int_literal_value(*int_literals[0], int_literals[0]->is_signed); + uint64_t old_value = shd_get_int_literal_value(*int_literals[0], int_literals[0]->is_signed); uint64_t value = old_value & bitmask; - return quote_single(arena, int_literal(arena, (IntLiteral) {.is_signed = dst_t->payload.int_type.is_signed, .width = dst_t->payload.int_type.width, .value = 
value})); + return quote_single(arena, int_literal(arena, (IntLiteral) { .is_signed = dst_t->payload.int_type.is_signed, .width = dst_t->payload.int_type.width, .value = value })); } else if (all_float_literals) { - double old_value = get_float_literal_value(*float_literals[0]); + double old_value = shd_get_float_literal_value(*float_literals[0]); int64_t value = old_value; - return quote_single(arena, int_literal(arena, (IntLiteral) {.is_signed = dst_t->payload.int_type.is_signed, .width = dst_t->payload.int_type.width, .value = value})); + return quote_single(arena, int_literal(arena, (IntLiteral) { .is_signed = dst_t->payload.int_type.is_signed, .width = dst_t->payload.int_type.width, .value = value })); } } else if (dst_t->tag == Float_TAG) { if (all_int_literals) { - uint64_t old_value = get_int_literal_value(*int_literals[0], int_literals[0]->is_signed); + uint64_t old_value = shd_get_int_literal_value(*int_literals[0], int_literals[0]->is_signed); double value = old_value; - return quote_single(arena, fp_literal_helper(arena, dst_t->payload.float_type.width, value)); + return quote_single(arena, shd_fp_literal_helper(arena, dst_t->payload.float_type.width, value)); } else if (all_float_literals) { - double old_value = get_float_literal_value(*float_literals[0]); + double old_value = shd_get_float_literal_value(*float_literals[0]); return quote_single(arena, float_literal(arena, (FloatLiteral) { .width = dst_t->payload.float_type.width, .value = old_value })); } } @@ -203,6 +120,12 @@ break; } } + return NULL; +} + +static inline const Node* fold_simplify_math(const Node* node) { + IrArena* arena = node->arena; + PrimOp payload = node->payload.prim_op; switch (payload.op) { case add_op: { // If either operand is zero, destroy the add @@ -217,7 +140,7 @@ break; return quote_single(arena, payload.operands.nodes[0]); // if first operand is zero, invert the second one if (is_zero(payload.operands.nodes[0])) - return prim_op(arena, (PrimOp) { .op = neg_op, 
.operands = singleton(payload.operands.nodes[1]), .type_arguments = empty(arena) }); + return prim_op(arena, (PrimOp) { .op = neg_op, .operands = shd_singleton(payload.operands.nodes[1]), .type_arguments = shd_empty(arena) }); break; } case mul_op: { @@ -237,70 +160,263 @@ break; return quote_single(arena, payload.operands.nodes[0]); break; } - case subgroup_broadcast_first_op: { - const Node* value = first(payload.operands); - if (is_qualified_type_uniform(value->type)) + case eq_op: { + if (payload.operands.nodes[0] == payload.operands.nodes[1]) + return quote_single(arena, true_lit(arena)); + break; + } + case neq_op: { + if (payload.operands.nodes[0] == payload.operands.nodes[1]) + return quote_single(arena, false_lit(arena)); + break; + } + default: break; + } + + return NULL; +} + +static inline const Node* resolve_ptr_source(const Node* ptr) { + const Node* original_ptr = ptr; + IrArena* a = ptr->arena; + const Type* t = ptr->type; + bool u = shd_deconstruct_qualified_type(&t); + assert(t->tag == PtrType_TAG); + const Type* desired_pointee_type = t->payload.ptr_type.pointed_type; + // const Node* last_known_good = node; + + int distance = 0; + bool specialize_generic = false; + AddressSpace src_as = t->payload.ptr_type.address_space; + while (true) { + const Node* def = ptr; + switch (def->tag) { + case PrimOp_TAG: { + PrimOp instruction = def->payload.prim_op; + switch (instruction.op) { + case reinterpret_op: { + distance++; + ptr = shd_first(instruction.operands); + continue; + } + case convert_op: { + // only conversions to generic pointers are acceptable + if (shd_first(instruction.type_arguments)->tag != PtrType_TAG) + break; + assert(!specialize_generic && "something should not be converted to generic twice!"); + specialize_generic = true; + ptr = shd_first(instruction.operands); + src_as = shd_get_unqualified_type(ptr->type)->payload.ptr_type.address_space; + continue; + } + default: break; + } + break; + } + case PtrCompositeElement_TAG: { + 
PtrCompositeElement payload = def->payload.ptr_composite_element; + if (is_zero(payload.index)) { + distance++; + ptr = payload.ptr; + continue; + } + break; + } + default: break; + } + break; + } + + // if there was more than one of those pointless casts... + if (distance > 1 || specialize_generic) { + const Type* new_src_ptr_type = ptr->type; + shd_deconstruct_qualified_type(&new_src_ptr_type); + if (new_src_ptr_type->tag != PtrType_TAG || new_src_ptr_type->payload.ptr_type.pointed_type != desired_pointee_type) { + PtrType payload = t->payload.ptr_type; + payload.address_space = src_as; + ptr = prim_op_helper(a, reinterpret_op, shd_singleton(ptr_type(a, payload)), shd_singleton(ptr)); + } + return ptr; + } + return NULL; +} + +static inline const Node* simplify_ptr_operand(IrArena* a, const Node* old_op) { + const Type* ptr_t = old_op->type; + shd_deconstruct_qualified_type(&ptr_t); + if (ptr_t->payload.ptr_type.is_reference) + return NULL; + return resolve_ptr_source(old_op); +} + +static inline const Node* fold_simplify_ptr_operand(const Node* node) { + IrArena* arena = node->arena; + const Node* r = NULL; + switch (node->tag) { + case Load_TAG: { + Load payload = node->payload.load; + const Node* nptr = simplify_ptr_operand(arena, payload.ptr); + if (!nptr) break; + payload.ptr = nptr; + r = load(arena, payload); + break; + } + case Store_TAG: { + Store payload = node->payload.store; + const Node* nptr = simplify_ptr_operand(arena, payload.ptr); + if (!nptr) break; + payload.ptr = nptr; + r = store(arena, payload); + break; + } + case PtrCompositeElement_TAG: { + PtrCompositeElement payload = node->payload.ptr_composite_element; + const Node* nptr = simplify_ptr_operand(arena, payload.ptr); + if (!nptr) break; + payload.ptr = nptr; + r = ptr_composite_element(arena, payload); + break; + } + case PtrArrayElementOffset_TAG: { + PtrArrayElementOffset payload = node->payload.ptr_array_element_offset; + const Node* nptr = simplify_ptr_operand(arena, payload.ptr); + 
if (!nptr) break; + payload.ptr = nptr; + r = ptr_array_element_offset(arena, payload); + break; + } + case Call_TAG: { + Call payload = node->payload.call; + const Node* nptr = simplify_ptr_operand(arena, payload.callee); + if (!nptr) break; + payload.callee = nptr; + r = call(arena, payload); + break; + } + case TailCall_TAG: { + TailCall payload = node->payload.tail_call; + const Node* nptr = simplify_ptr_operand(arena, payload.callee); + if (!nptr) break; + payload.callee = nptr; + r = tail_call(arena, payload); + break; + } + default: return node; + } + + if (!r) + return node; + + if (!shd_is_subtype(node->type, r->type)) + r = prim_op_helper(arena, convert_op, shd_singleton(shd_get_unqualified_type(node->type)), shd_singleton(r)); + return r; +} + +static const Node* fold_prim_op(IrArena* arena, const Node* node) { + APPLY_FOLD(fold_constant_math) + APPLY_FOLD(fold_simplify_math) + + PrimOp payload = node->payload.prim_op; + switch (payload.op) { + // TODO: case subgroup_broadcast_first_op: + case subgroup_assume_uniform_op: { + const Node* value = shd_first(payload.operands); + if (shd_is_qualified_type_uniform(value->type)) return quote_single(arena, value); break; } - case reinterpret_op: case convert_op: + case reinterpret_op: { // get rid of identity casts - if (payload.type_arguments.nodes[0] == get_unqualified_type(payload.operands.nodes[0]->type)) + if (payload.type_arguments.nodes[0] == shd_get_unqualified_type(payload.operands.nodes[0]->type)) return quote_single(arena, payload.operands.nodes[0]); break; + } + default: break; + } + return node; +} + +static const Node* fold_memory_poison(IrArena* arena, const Node* node) { + switch (node->tag) { + case Load_TAG: { + if (node->payload.load.ptr->tag == Undef_TAG) + return mem_and_value(arena, (MemAndValue) { .value = undef(arena, (Undef) { .type = shd_get_unqualified_type(node->type) }), .mem = node->payload.load.mem }); + break; + } + case Store_TAG: { + if (node->payload.store.ptr->tag == 
Undef_TAG) + return mem_and_value(arena, (MemAndValue) { .value = shd_tuple_helper(arena, shd_empty(arena)), .mem = node->payload.store.mem }); + break; + } + case PtrArrayElementOffset_TAG: { + PtrArrayElementOffset payload = node->payload.ptr_array_element_offset; + if (payload.ptr->tag == Undef_TAG) + return quote_single(arena, undef(arena, (Undef) { .type = shd_get_unqualified_type(node->type) })); + break; + } + case PtrCompositeElement_TAG: { + PtrCompositeElement payload = node->payload.ptr_composite_element; + if (payload.ptr->tag == Undef_TAG) + return quote_single(arena, undef(arena, (Undef) { .type = shd_get_unqualified_type(node->type) })); + break; + } + case PrimOp_TAG: { + PrimOp payload = node->payload.prim_op; + switch (payload.op) { + case reinterpret_op: + case convert_op: { + if (shd_first(payload.operands)->tag == Undef_TAG) + return quote_single(arena, undef(arena, (Undef) { .type = shd_get_unqualified_type(node->type) })); + break; + } + default: break; + } + break; + } default: break; } return node; } static bool is_unreachable_case(const Node* c) { - assert(c && c->tag == Case_TAG); + assert(c && c->tag == BasicBlock_TAG); const Node* b = get_abstraction_body(c); return b->tag == Unreachable_TAG; } -const Node* fold_node(IrArena* arena, const Node* node) { - const Node* folded = node; +static bool is_unreachable_destination(const Node* j) { + assert(j && j->tag == Jump_TAG); + const Node* b = get_abstraction_body(j->payload.jump.target); + return b->tag == Unreachable_TAG; +} + +const Node* _shd_fold_node(IrArena* arena, const Node* node) { + const Node* const original_node = node; + node = fold_memory_poison(arena, node); + node = fold_simplify_ptr_operand(node); switch (node->tag) { - case Let_TAG: folded = fold_let(arena, node); break; - case PrimOp_TAG: folded = fold_prim_op(arena, node); break; - case Block_TAG: { - const Node* lam = node->payload.block.inside; - const Node* body = lam->payload.case_.body; - if (body->tag == Yield_TAG) 
{ - return quote_helper(arena, body->payload.yield.args); - } else if (body->tag == Let_TAG) { - // fold block { let x, y, z = I; yield (x, y, z); } back to I - const Node* instr = get_let_instruction(body); - const Node* let_case = get_let_tail(body); - const Node* let_case_body = get_abstraction_body(let_case); - if (let_case_body->tag == Yield_TAG) { - bool only_forwards = true; - Nodes let_case_params = get_abstraction_params(let_case); - Nodes yield_args = let_case_body->payload.yield.args; - if (let_case_params.count == yield_args.count) { - for (size_t i = 0; i < yield_args.count; i++) { - only_forwards &= yield_args.nodes[i] == let_case_params.nodes[i]; - } - if (only_forwards) { - debugv_print("Fold: simplify "); - log_node(DEBUGV, node); - debugv_print(" into just "); - log_node(DEBUGV, instr); - debugv_print(".\n"); - return instr; - } - } - } - } + case PrimOp_TAG: node = fold_prim_op(arena, node); break; + case PtrArrayElementOffset_TAG: { + PtrArrayElementOffset payload = node->payload.ptr_array_element_offset; + if (is_zero(payload.offset)) + return payload.ptr; break; } - case If_TAG: { - If payload = node->payload.if_instr; - const Node* false_case = payload.if_false; - if (arena->config.optimisations.delete_unreachable_structured_cases && false_case && is_unreachable_case(false_case)) - return block(arena, (Block) { .inside = payload.if_true, .yield_types = add_qualifiers(arena, payload.yield_types, false) }); + case Branch_TAG: { + Branch payload = node->payload.branch; + if (arena->config.optimisations.fold_static_control_flow) { + if (payload.condition == true_lit(arena)) { + return payload.true_jump; + } else if (payload.condition == false_lit(arena)) { + return payload.false_jump; + } + } else if (arena->config.optimisations.delete_unreachable_structured_cases) { + if (is_unreachable_destination(payload.true_jump)) + return payload.false_jump; + else if (is_unreachable_destination(payload.false_jump)) + return payload.true_jump; + } break; } 
case Match_TAG: { @@ -322,30 +438,44 @@ const Node* fold_node(IrArena* arena, const Node* node) { if (new_cases_count == old_cases.count) break; - if (new_cases_count == 1 && is_unreachable_case(payload.default_case)) + /*if (new_cases_count == 1 && is_unreachable_case(payload.default_case)) return block(arena, (Block) { .inside = cases[0], .yield_types = add_qualifiers(arena, payload.yield_types, false) }); if (new_cases_count == 0) - return block(arena, (Block) { .inside = payload.default_case, .yield_types = add_qualifiers(arena, payload.yield_types, false) }); + return block(arena, (Block) { .inside = payload.default_case, .yield_types = add_qualifiers(arena, payload.yield_types, false) });*/ return match_instr(arena, (Match) { .inspect = payload.inspect, .yield_types = payload.yield_types, .default_case = payload.default_case, - .literals = nodes(arena, new_cases_count, literals), - .cases = nodes(arena, new_cases_count, cases), + .literals = shd_nodes(arena, new_cases_count, literals), + .cases = shd_nodes(arena, new_cases_count, cases), + .tail = payload.tail, + .mem = payload.mem, }); } default: break; } // catch bad folding rules that mess things up - if (is_value(node)) assert(is_value(folded)); - if (is_instruction(node)) assert(is_instruction(folded)); - if (is_terminator(node)) assert(is_terminator(folded)); + if (is_value(original_node)) assert(is_value(node)); + if (is_instruction(original_node)) assert(is_instruction(node) || is_value(node)); + if (is_terminator(original_node)) assert(is_terminator(node)); if (node->type) - assert(is_subtype(node->type, folded->type)); + assert(shd_is_subtype(original_node->type, node->type)); + + return node; +} - return folded; +const Node* _shd_fold_node_operand(NodeTag tag, NodeClass nc, String opname, const Node* op) { + if (!op) + return NULL; + if (op->tag == MemAndValue_TAG) { + MemAndValue payload = op->payload.mem_and_value; + if (nc == NcMem) + return payload.mem; + return payload.value; + } + return op; 
} diff --git a/src/shady/fold.h b/src/shady/fold.h index b8359dfc8..f21854e4b 100644 --- a/src/shady/fold.h +++ b/src/shady/fold.h @@ -3,6 +3,6 @@ #include "ir_private.h" -const Node* fold_node(IrArena* arena, const Node* instruction); +const Node* _shd_fold_node(IrArena* arena, const Node* node); #endif diff --git a/src/shady/generator/CMakeLists.txt b/src/shady/generator/CMakeLists.txt index d5837075d..d5f3f77c5 100644 --- a/src/shady/generator/CMakeLists.txt +++ b/src/shady/generator/CMakeLists.txt @@ -1,19 +1,40 @@ find_package(json-c REQUIRED) -add_library(generator_common generator.c generator_common.c json_apply.c) +add_library(generator_common STATIC generator.c generator_common.c json_apply.c) target_link_libraries(generator_common PUBLIC common json-c::json-c) target_include_directories(generator_common PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +add_executable(import_spv_defs import_spv_defs.c) +target_link_libraries(import_spv_defs PUBLIC common generator_common) + +# This hacky job is required for being able to run built targets in-place when generating the code +# This is also required for the various drivers but since they're built in the same directory it will work for now +if (WIN32) + message("copying DLLs for generator targets") + add_custom_command(TARGET import_spv_defs POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy -t $ $ + COMMAND_EXPAND_LISTS + ) +endif () + +get_target_property(SPIRV_HEADERS_INCLUDE_DIRS SPIRV-Headers::SPIRV-Headers INTERFACE_INCLUDE_DIRECTORIES) +add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/imported.json COMMAND import_spv_defs ${CMAKE_CURRENT_BINARY_DIR}/imported.json ${PROJECT_SOURCE_DIR}/include/shady/spv_imports.json ${SPIRV_HEADERS_INCLUDE_DIRS} DEPENDS import_spv_defs SPIRV-Headers::SPIRV-Headers ${PROJECT_SOURCE_DIR}/include/shady/spv_imports.json VERBATIM) +add_custom_target(do_import_spv_defs DEPENDS import_spv_defs ${CMAKE_CURRENT_BINARY_DIR}/imported.json) + +set(SHADY_IMPORTED_JSON_PATH 
${CMAKE_CURRENT_BINARY_DIR}/imported.json CACHE INTERNAL "path to imported.json") + function(add_generated_file) cmake_parse_arguments(PARSE_ARGV 0 F "" "FILE_NAME;TARGET_NAME" "SOURCES" ) set(GENERATOR_NAME generator_${F_FILE_NAME}) - add_executable(${GENERATOR_NAME} ${F_SOURCES}) + add_executable(${GENERATOR_NAME} ${F_SOURCES} ${PROJECT_SOURCE_DIR}/src/shady/generator/generator_main.c) target_link_libraries(${GENERATOR_NAME} generator_common) get_target_property(SPIRV_HEADERS_INCLUDE_DIRS SPIRV-Headers::SPIRV-Headers INTERFACE_INCLUDE_DIRECTORIES) - add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${F_FILE_NAME} COMMAND ${GENERATOR_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${F_FILE_NAME} ${CMAKE_SOURCE_DIR}/include/shady/grammar.json ${CMAKE_SOURCE_DIR}/include/shady/primops.json ${SPIRV_HEADERS_INCLUDE_DIRS} DEPENDS ${GENERATOR_NAME} ${CMAKE_SOURCE_DIR}/include/shady/grammar.json ${CMAKE_SOURCE_DIR}/include/shady/primops.json VERBATIM) + if ("${F_TARGET_NAME}" STREQUAL "") set(F_TARGET_NAME generate_${F_FILE_NAME}) endif () + + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${F_FILE_NAME} COMMAND ${GENERATOR_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${F_FILE_NAME} "${SHADY_IMPORTED_JSON_PATH}" ${PROJECT_SOURCE_DIR}/include/shady/grammar.json ${PROJECT_SOURCE_DIR}/include/shady/primops.json DEPENDS do_import_spv_defs ${GENERATOR_NAME} ${PROJECT_SOURCE_DIR}/include/shady/grammar.json ${PROJECT_SOURCE_DIR}/include/shady/primops.json VERBATIM) add_custom_target(${F_TARGET_NAME} DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${F_FILE_NAME}) endfunction() diff --git a/src/shady/generator/generator.c b/src/shady/generator/generator.c index 3930b8db1..1f97d2722 100644 --- a/src/shady/generator/generator.c +++ b/src/shady/generator/generator.c @@ -1,6 +1,6 @@ #include "generator.h" -static bool should_include_instruction(json_object* instruction) { +inline static bool should_include_instruction(json_object* instruction) { String class = 
json_object_get_string(json_object_object_get(instruction, "class")); if (strcmp(class, "@exclude") == 0) return false; @@ -11,7 +11,7 @@ void add_comments(Growy* g, String indent, json_object* comments) { if (!indent) indent = ""; if (json_object_get_type(comments) == json_type_string) { - growy_append_formatted(g, "%s/// %s\n", indent, json_object_get_string(comments)); + shd_growy_append_formatted(g, "%s/// %s\n", indent, json_object_get_string(comments)); } else if (json_object_get_type(comments) == json_type_array) { size_t size = json_object_array_length(comments); for (size_t i = 0; i < size; i++) @@ -62,13 +62,14 @@ String capitalize(String str) { return dst; } -void generate_header(Growy* g, Data data) { - int32_t major = json_object_get_int(json_object_object_get(data.spv, "major_version")); - int32_t minor = json_object_get_int(json_object_object_get(data.spv, "minor_version")); - int32_t revision = json_object_get_int(json_object_object_get(data.spv, "revision")); - growy_append_formatted(g, "/* Generated from SPIR-V %d.%d revision %d */\n", major, minor, revision); - growy_append_formatted(g, "/* Do not edit this file manually ! */\n"); - growy_append_formatted(g, "/* It is generated by the 'generator' target using Json grammar files. */\n\n"); +void generate_header(Growy* g, json_object* root) { + json_object* spv = json_object_object_get(root, "spv"); + int32_t major = json_object_get_int(json_object_object_get(spv, "major_version")); + int32_t minor = json_object_get_int(json_object_object_get(spv, "minor_version")); + int32_t revision = json_object_get_int(json_object_object_get(spv, "revision")); + shd_growy_append_formatted(g, "/* Generated from SPIR-V %d.%d revision %d */\n", major, minor, revision); + shd_growy_append_formatted(g, "/* Do not edit this file manually ! */\n"); + shd_growy_append_formatted(g, "/* It is generated by the 'generator' target using Json grammar files. 
*/\n\n"); } bool starts_with_vowel(String str) { @@ -80,88 +81,3 @@ bool starts_with_vowel(String str) { } return false; } - -enum { - ArgSelf = 0, - ArgDstFile, - ArgShadyGrammarJson, - ArgShadyPrimopsJson, - ArgSpirvGrammarSearchPathBegins -}; - -int main(int argc, char** argv) { - assert(argc > ArgSpirvGrammarSearchPathBegins); - - //char* mode = argv[ArgGeneratorFn]; - char* dst_file = argv[ArgDstFile]; - char* shd_grammar_json_path = argv[ArgShadyGrammarJson]; - char* shd_primops_json_path = argv[ArgShadyPrimopsJson]; - // search the include path for spirv.core.grammar.json - char* spv_core_json_path = NULL; - for (size_t i = ArgSpirvGrammarSearchPathBegins; i < argc; i++) { - char* path = format_string_new("%s/spirv/unified1/spirv.core.grammar.json", argv[i]); - info_print("trying path %s\n", path); - FILE* f = fopen(path, "rb"); - if (f) { - spv_core_json_path = path; - fclose(f); - break; - } - free(path); - } - - if (!spv_core_json_path) - abort(); - - json_tokener* tokener = json_tokener_new_ex(32); - enum json_tokener_error json_err; - - typedef struct { - size_t size; - char* contents; - json_object* root; - } JsonFile; - - String json_paths[3] = { shd_grammar_json_path, shd_primops_json_path, spv_core_json_path }; - JsonFile json_files[3]; - for (size_t i = 0; i < sizeof(json_files) / sizeof(json_files[0]); i++) { - String path = json_paths[i]; - read_file(path, &json_files[i].size, &json_files[i].contents); - json_files[i].root = json_tokener_parse_ex(tokener, json_files[i].contents, json_files[i].size); - json_err = json_tokener_get_error(tokener); - if (json_err != json_tokener_success) { - error("Json tokener error while parsing %s:\n %s\n", path, json_tokener_error_desc(json_err)); - } - - info_print("Correctly opened json file: %s\n", path); - } - Growy* g = new_growy(); - - Data data = { - .shd = json_object_new_object(), - .spv = json_files[2].root, - }; - - for (size_t i = 0; i < 2; i++) { - json_apply_object(data.shd, json_files[i].root); - } 
- - generate(g, data); - - size_t final_size = growy_size(g); - growy_append_bytes(g, 1, (char[]) { 0 }); - char* generated = growy_deconstruct(g); - debug_print("debug: %s\n", generated); - if (!write_file(dst_file, final_size, generated)) { - error_print("Failed to write file '%s'\n", dst_file); - error_die(); - } - free(generated); - for (size_t i = 0; i < sizeof(json_files) / sizeof(json_files[0]); i++) { - free(json_files[i].contents); - json_object_put(json_files[i].root); - } - json_object_put(data.shd); - json_tokener_free(tokener); - free(spv_core_json_path); -} diff --git a/src/shady/generator/generator.h b/src/shady/generator/generator.h index ad4b87931..7c9e72fbb 100644 --- a/src/shady/generator/generator.h +++ b/src/shady/generator/generator.h @@ -14,21 +14,19 @@ #include typedef const char* String; -typedef struct { - json_object* shd; - json_object* spv; -} Data; -void generate(Growy* g, Data); - -void generate_header(Growy* g, Data data); +void generate_header(Growy* g, json_object* root); void add_comments(Growy* g, String indent, json_object* comments); String to_snake_case(String camel); String capitalize(String str); bool starts_with_vowel(String str); bool has_custom_ctor(json_object* node); -void generate_node_ctor(Growy* g, json_object* nodes, bool definition); + +json_object* lookup_node_class(json_object* src, String name); +bool find_in_set(json_object* node, String class_name); +String class_to_type(json_object* src, String class, bool list); +String get_type_for_operand(json_object* src, json_object* op); void generate_bit_enum(Growy* g, String enum_type_name, String enum_case_prefix, json_object* cases); void generate_bit_enum_classifier(Growy* g, String fn_name, String enum_type_name, String enum_case_prefix, String src_type_name, String src_case_prefix, String src_case_suffix, json_object* cases); diff --git a/src/shady/generator/generator_common.c b/src/shady/generator/generator_common.c index 5d20b2551..3cad55662 100644 --- 
a/src/shady/generator/generator_common.c +++ b/src/shady/generator/generator_common.c @@ -5,121 +5,129 @@ bool has_custom_ctor(json_object* node) { return (constructor && strcmp(constructor, "custom") == 0); } -void generate_node_ctor(Growy* g, json_object* nodes, bool definition) { - for (size_t i = 0; i < json_object_array_length(nodes); i++) { - json_object* node = json_object_array_get_idx(nodes, i); - - String name = json_object_get_string(json_object_object_get(node, "name")); - assert(name); - - if (has_custom_ctor(node)) - continue; +json_object* lookup_node_class(json_object* src, String name) { + json_object* node_classes = json_object_object_get(src, "node-classes"); + for (size_t i = 0; i < json_object_array_length(node_classes); i++) { + json_object* class = json_object_array_get_idx(node_classes, i); + String class_name = json_object_get_string(json_object_object_get(class, "name")); + assert(class_name); + if (strcmp(name, class_name) == 0) + return class; + } + return NULL; +} - if (definition && i > 0) - growy_append_formatted(g, "\n"); +String class_to_type(json_object* src, String class, bool list) { + assert(class); + if (strcmp(class, "string") == 0) { + if (list) + return "Strings"; + else + return "String"; + } + // check the class is valid + if (!lookup_node_class(src, class)) { + shd_error_print("invalid node class '%s'\n", class); + shd_error_die(); + } + return list ? 
"Nodes" : "const Node*"; +} - String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); - void* alloc = NULL; - if (!snake_name) { - alloc = snake_name = to_snake_case(name); +bool find_in_set(json_object* node, String class_name) { + switch (json_object_get_type(node)) { + case json_type_array: { + for (size_t i = 0; i < json_object_array_length(node); i++) + if (find_in_set(json_object_array_get_idx(node, i), class_name)) + return true; + break; } + case json_type_string: return strcmp(json_object_get_string(node), class_name) == 0; + default: break; + } + return false; +} - String ap = definition ? " arena" : ""; - json_object* ops = json_object_object_get(node, "ops"); - if (ops) - growy_append_formatted(g, "const Node* %s(IrArena*%s, %s%s)", snake_name, ap, name, definition ? " payload" : ""); - else - growy_append_formatted(g, "const Node* %s(IrArena*%s)", snake_name, ap); +String get_type_for_operand(json_object* src, json_object* op) { + String op_type = json_object_get_string(json_object_object_get(op, "type")); + bool list = json_object_get_boolean(json_object_object_get(op, "list")); + String op_class = NULL; + if (!op_type) { + op_class = json_object_get_string(json_object_object_get(op, "class")); + op_type = class_to_type(src, op_class, list); + } + assert(op_type); + return op_type; +} - if (definition) { - growy_append_formatted(g, " {\n"); - growy_append_formatted(g, "\tNode node;\n"); - growy_append_formatted(g, "\tmemset((void*) &node, 0, sizeof(Node));\n"); - growy_append_formatted(g, "\tnode = (Node) {\n"); - growy_append_formatted(g, "\t\t.arena = arena,\n"); - growy_append_formatted(g, "\t\t.tag = %s_TAG,\n", name); - if (ops) - growy_append_formatted(g, "\t\t.payload.%s = payload,\n", snake_name); - json_object* t = json_object_object_get(node, "type"); - if (!t || json_object_get_boolean(t)) { - growy_append_formatted(g, "\t\t.type = arena->config.check_types ? 
"); - if (ops) - growy_append_formatted(g, "check_type_%s(arena, payload)", snake_name); - else - growy_append_formatted(g, "check_type_%s(arena)", snake_name); - growy_append_formatted(g, ": NULL,\n"); - } else - growy_append_formatted(g, "\t\t.type = NULL,\n"); - growy_append_formatted(g, "\t};\n"); - growy_append_formatted(g, "\treturn create_node_helper(arena, node, NULL);\n"); - growy_append_formatted(g, "}\n"); - } else { - growy_append_formatted(g, ";\n"); +void preprocess(json_object* src) { + json_object* nodes = json_object_object_get(src, "nodes"); + for (size_t i = 0; i < json_object_array_length(nodes); i++) { + json_object* node = json_object_array_get_idx(nodes, i); + String name = json_object_get_string(json_object_object_get(node, "name")); + json_object* snake_name = json_object_object_get(node, "snake_name"); + if (!snake_name) { + String tmp = to_snake_case(name); + json_object* generated_snake_name = json_object_new_string(tmp); + json_object_object_add(node, "snake_name", generated_snake_name); + free((void*) tmp); } - - if (alloc) - free(alloc); } - growy_append_formatted(g, "\n"); } void generate_bit_enum(Growy* g, String enum_type_name, String enum_case_prefix, json_object* cases) { assert(json_object_get_type(cases) == json_type_array); - growy_append_formatted(g, "typedef enum {\n"); + shd_growy_append_formatted(g, "typedef enum {\n"); for (size_t i = 0; i < json_object_array_length(cases); i++) { json_object* node_class = json_object_array_get_idx(cases, i); String name = json_object_get_string(json_object_object_get(node_class, "name")); String capitalized = capitalize(name); - growy_append_formatted(g, "\t%s%s = 0b1", enum_case_prefix, capitalized); - for (int c = 0; c < i; c++) - growy_append_string_literal(g, "0"); - growy_append_formatted(g, ",\n"); - free(capitalized); + shd_growy_append_formatted(g, "\t%s%s = 0x%x", enum_case_prefix, capitalized, (1 << i)); + shd_growy_append_formatted(g, ",\n"); + free((void*) capitalized); } - 
growy_append_formatted(g, "} %s;\n\n", enum_type_name); + shd_growy_append_formatted(g, "} %s;\n\n", enum_type_name); } void generate_bit_enum_classifier(Growy* g, String fn_name, String enum_type_name, String enum_case_prefix, String src_type_name, String src_case_prefix, String src_case_suffix, json_object* cases) { - growy_append_formatted(g, "%s %s(%s tag) {\n", enum_type_name, fn_name, src_type_name); - growy_append_formatted(g, "\tswitch (tag) { \n"); + shd_growy_append_formatted(g, "%s %s(%s tag) {\n", enum_type_name, fn_name, src_type_name); + shd_growy_append_formatted(g, "\tswitch (tag) { \n"); assert(json_object_get_type(cases) == json_type_array); for (size_t i = 0; i < json_object_array_length(cases); i++) { json_object* node = json_object_array_get_idx(cases, i); String name = json_object_get_string(json_object_object_get(node, "name")); - growy_append_formatted(g, "\t\tcase %s%s%s: \n", src_case_prefix, name, src_case_suffix); + shd_growy_append_formatted(g, "\t\tcase %s%s%s: \n", src_case_prefix, name, src_case_suffix); json_object* class = json_object_object_get(node, "class"); switch (json_object_get_type(class)) { - case json_type_null: - growy_append_formatted(g, "\t\t\treturn 0;\n"); + case json_type_null:shd_growy_append_formatted(g, "\t\t\treturn 0;\n"); break; case json_type_string: { String cap = capitalize(json_object_get_string(class)); - growy_append_formatted(g, "\t\t\treturn %s%s;\n", enum_case_prefix, cap); - free(cap); + shd_growy_append_formatted(g, "\t\t\treturn %s%s;\n", enum_case_prefix, cap); + free((void*) cap); break; } case json_type_array: { - growy_append_formatted(g, "\t\t\treturn "); + shd_growy_append_formatted(g, "\t\t\treturn "); for (size_t j = 0; j < json_object_array_length(class); j++) { if (j > 0) - growy_append_formatted(g, " | "); + shd_growy_append_formatted(g, " | "); String cap = capitalize(json_object_get_string(json_object_array_get_idx(class, j))); - growy_append_formatted(g, "%s%s", enum_case_prefix, 
cap); - free(cap); + shd_growy_append_formatted(g, "%s%s", enum_case_prefix, cap); + free((void*) cap); } - growy_append_formatted(g, ";\n"); + shd_growy_append_formatted(g, ";\n"); break; } case json_type_boolean: case json_type_double: case json_type_int: case json_type_object: - error_print("Invalid datatype for a node's 'class' attribute"); + shd_error_print("Invalid datatype for a node's 'class' attribute"); break; } } - growy_append_formatted(g, "\t\tdefault: assert(false);\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "\tSHADY_UNREACHABLE;\n"); - growy_append_formatted(g, "}\n"); + shd_growy_append_formatted(g, "\t\tdefault: assert(false);\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "\tSHADY_UNREACHABLE;\n"); + shd_growy_append_formatted(g, "}\n"); } \ No newline at end of file diff --git a/src/shady/generator/generator_main.c b/src/shady/generator/generator_main.c new file mode 100644 index 000000000..a24fe3448 --- /dev/null +++ b/src/shady/generator/generator_main.c @@ -0,0 +1,67 @@ +#include "generator.h" + +#include "util.h" +#include "portability.h" + +void preprocess(json_object* root); +void generate(Growy* g, json_object* root); + +enum { + ArgSelf = 0, + ArgDstFile, + ArgFirstInput, +}; + +int main(int argc, char** argv) { + assert(argc > ArgFirstInput); + int inputs_count = argc - ArgFirstInput; + char* dst_file = argv[ArgDstFile]; + + json_tokener* tokener = json_tokener_new_ex(32); + enum json_tokener_error json_err; + + typedef struct { + size_t size; + char* contents; + json_object* root; + } JsonFile; + + LARRAY(JsonFile, json_files, inputs_count); + for (size_t i = 0; i < inputs_count; i++) { + String path = argv[ArgFirstInput + i]; + shd_read_file(path, &json_files[i].size, &json_files[i].contents); + json_files[i].root = json_tokener_parse_ex(tokener, json_files[i].contents, json_files[i].size); + json_err = json_tokener_get_error(tokener); + if (json_err != json_tokener_success) { 
+ shd_error("Json tokener error while parsing %s:\n %s\n", path, json_tokener_error_desc(json_err)); + } + + shd_info_print("Correctly opened json file: %s\n", path); + } + Growy* g = shd_new_growy(); + + json_object* src = json_object_new_object(); + + for (size_t i = 0; i < inputs_count; i++) { + json_apply_object(src, json_files[i].root); + } + + preprocess(src); + generate(g, src); + + size_t final_size = shd_growy_size(g); + shd_growy_append_bytes(g, 1, (char[]) { 0 }); + char* generated = shd_growy_deconstruct(g); + shd_debug_print("debug: %s\n", generated); + if (!shd_write_file(dst_file, final_size, generated)) { + shd_error_print("Failed to write file '%s'\n", dst_file); + shd_error_die(); + } + free(generated); + for (size_t i = 0; i < inputs_count; i++) { + free(json_files[i].contents); + json_object_put(json_files[i].root); + } + json_object_put(src); + json_tokener_free(tokener); +} diff --git a/src/shady/generator/import_spv_defs.c b/src/shady/generator/import_spv_defs.c new file mode 100644 index 000000000..57409307a --- /dev/null +++ b/src/shady/generator/import_spv_defs.c @@ -0,0 +1,340 @@ +#include "generator.h" + +#include "util.h" +#include "list.h" + +enum { + ArgSelf = 0, + ArgDstFile, + ArgImportsFile, + ArgSpirvGrammarSearchPathBegins +}; + +static String sanitize_node_name(String name) { + char* tmpname = NULL; + tmpname = calloc(strlen(name) + 1, 1); + bool is_type = false; + if (shd_string_starts_with(name, "OpType")) { + memcpy(tmpname, name + 6, strlen(name) - 6); + is_type = true; + } else if (shd_string_starts_with(name, "Op")) + memcpy(tmpname, name + 2, strlen(name) - 2); + else + memcpy(tmpname, name, strlen(name)); + + if (is_type) + memcpy(tmpname + strlen(tmpname), "Type", 4); + + return tmpname; +} + +static String sanitize_field_name(String name) { + char* tmpname = NULL; + tmpname = calloc(strlen(name) + 1, 1); + if (name[0] == '\'') { + memcpy(tmpname, name + 1, strlen(name) - 2); + name = tmpname; + } else { + 
memcpy(tmpname, name, strlen(name)); + } + for (size_t i = 0; i < strlen(tmpname); i++) { + if (tmpname[i] == ' ') + tmpname[i] = '_'; + else + tmpname[i] = tolower(tmpname[i]); + } + return tmpname; +} + +static void copy_object(json_object* dst, json_object* src, String name, String copied_name) { + json_object* o = json_object_object_get(src, name); + json_object_get(o); + json_object_object_add(dst, copied_name ? copied_name : name, o); +} + +void apply_instruction_filter(json_object* filter, json_object* instruction, json_object* instantiated_filter, struct List* pending) { + switch (json_object_get_type(filter)) { + case json_type_array: { + for (size_t i = 0; i < json_object_array_length(filter); i++) { + apply_instruction_filter(json_object_array_get_idx(filter, i), instruction, instantiated_filter, pending); + } + break; + } + case json_type_object: { + json_object* filter_name = json_object_object_get(filter, "filter-name"); + if (filter_name) { + assert(json_object_get_type(filter_name) == json_type_object); + String name = json_object_get_string(json_object_object_get(instruction, "opname")); + bool found = false; + json_object_object_foreach(filter_name, match_name, subfilter) { + if (strcmp(name, match_name) == 0) { + found = true; + shd_list_append(json_object*, pending, subfilter); + } + } + if (!found) + return; + } + + json_apply_object(instantiated_filter, filter); + /*json_object_object_foreach(filter, proprerty, value) { + json_object_get(value); + json_object_object_add(instantiated_filter, proprerty, value); + }*/ + break; + } + default: shd_error("Filters need to be arrays or objects"); + } +} + +json_object* apply_instruction_filters(json_object* filter, json_object* instruction) { + json_object* instantiated_filter = json_object_new_object(); + struct List* pending = shd_new_list(json_object*); + apply_instruction_filter(filter, instruction, instantiated_filter, pending); + while(shd_list_count(pending) > 0) { + json_object* pending_filter 
= shd_read_list(json_object*, pending)[0]; + shd_list_remove(json_object*, pending, 0); + apply_instruction_filter(pending_filter, instruction, instantiated_filter, pending); + continue; + } + shd_destroy_list(pending); + return instantiated_filter; +} + +void apply_operand_filter(json_object* filter, json_object* operand, json_object* instantiated_filter, struct List* pending) { + //fprintf(stderr, "applying %s\n", json_object_to_json_string(filter)); + switch (json_object_get_type(filter)) { + case json_type_array: { + for (size_t i = 0; i < json_object_array_length(filter); i++) { + apply_operand_filter(json_object_array_get_idx(filter, i), operand, instantiated_filter, pending); + } + break; + } + case json_type_object: { + json_object* filter_name = json_object_object_get(filter, "filter-name"); + if (filter_name) { + assert(json_object_get_type(filter_name) == json_type_object); + String name = json_object_get_string(json_object_object_get(operand, "name")); + if (!name) + name = ""; + bool found = false; + json_object_object_foreach(filter_name, match_name, subfilter) { + if (strcmp(name, match_name) == 0) { + found = true; + shd_list_append(json_object*, pending, subfilter); + } + } + if (!found) + return; + } + json_object* filter_kind = json_object_object_get(filter, "filter-kind"); + if (filter_kind) { + assert(json_object_get_type(filter_kind) == json_type_object); + String kind = json_object_get_string(json_object_object_get(operand, "kind")); + if (!kind) + kind = ""; + bool found = false; + json_object_object_foreach(filter_kind, match_name, subfilter) { + if (strcmp(kind, match_name) == 0) { + found = true; + shd_list_append(json_object*, pending, subfilter); + } + } + if (!found) + return; + } + + json_apply_object(instantiated_filter, filter); + break; + } + default: shd_error("Filters need to be arrays or objects"); + } +} + +json_object* apply_operand_filters(json_object* filter, json_object* operand) { + //fprintf(stderr, "building filter for 
%s\n", json_object_to_json_string(operand)); + json_object* instantiated_filter = json_object_new_object(); + struct List* pending = shd_new_list(json_object*); + apply_operand_filter(filter, operand, instantiated_filter, pending); + while(shd_list_count(pending) > 0) { + json_object* pending_filter = shd_read_list(json_object*, pending)[0]; + shd_list_remove(json_object*, pending, 0); + apply_operand_filter(pending_filter, operand, instantiated_filter, pending); + continue; + } + shd_destroy_list(pending); + //fprintf(stderr, "done: %s\n", json_object_to_json_string(instantiated_filter)); + return instantiated_filter; +} + +json_object* import_operand(json_object* operand, json_object* instruction_filter) { + String kind = json_object_get_string(json_object_object_get(operand, "kind")); + assert(kind); + String name = json_object_get_string(json_object_object_get(operand, "name")); + if (!name) + name = kind; + + json_object* operand_filters = json_object_object_get(instruction_filter, "operand-filters"); + assert(operand_filters); + json_object* filter = apply_operand_filters(operand_filters, operand); + + String import_property = json_object_get_string(json_object_object_get(filter, "import")); + if (!import_property || (strcmp(import_property, "no") == 0)) { + json_object_put(filter); + return NULL; + } else if (strcmp(import_property, "yes") != 0) { + shd_error("a filter's 'import' property needs to be 'yes' or 'no'") + } + + json_object* field = json_object_new_object(); + + const char* field_name = sanitize_field_name(name); + json_object_object_add(field, "name", json_object_new_string(field_name)); + free((void*) field_name); + + json_object* insert = json_object_object_get(filter, "overlay"); + if (insert) { + json_apply_object(field, insert); + } + json_object_put(filter); + + return field; +} + +json_object* import_filtered_instruction(json_object* instruction, json_object* filter) { + String name = 
json_object_get_string(json_object_object_get(instruction, "opname")); + assert(name && strlen(name) > 2); + + String import_property = json_object_get_string(json_object_object_get(filter, "import")); + if (!import_property || (strcmp(import_property, "no") == 0)) { + return NULL; + } else if (strcmp(import_property, "yes") != 0) { + shd_error("a filter's 'import' property needs to be 'yes' or 'no'") + } + String node_name = sanitize_node_name(name); + + json_object* node = json_object_new_object(); + json_object_object_add(node, "name", json_object_new_string(node_name)); + copy_object(node, instruction, "opcode", "spirv-opcode"); + + json_object* insert = json_object_object_get(filter, "overlay"); + if (insert) { + json_apply_object(node, insert); + } + + json_object* operands = json_object_object_get(instruction, "operands"); + assert(operands); + json_object* ops = json_object_new_array(); + for (size_t i = 0; i < json_object_array_length(operands); i++) { + json_object* operand = json_object_array_get_idx(operands, i); + json_object* field = import_operand(operand, filter); + if (field) + json_object_array_add(ops, field); + } + + if (json_object_array_length(ops) > 0) + json_object_object_add(node, "ops", ops); + else + json_object_put(ops); + + free((void*) node_name); + return node; +} + +void import_spirv_defs(json_object* imports, json_object* src, json_object* dst) { + json_object* spv = json_object_new_object(); + json_object_object_add(dst, "spv", spv); + copy_object(spv, src, "major_version", NULL); + copy_object(spv, src, "minor_version", NULL); + copy_object(spv, src, "revision", NULL); + + // import instructions + json_object* filters = json_object_object_get(imports, "instruction-filters"); + json_object* nodes = json_object_new_array(); + json_object_object_add(dst, "nodes", nodes); + json_object* instructions = json_object_object_get(src, "instructions"); + //assert(false); + for (size_t i = 0; i < json_object_array_length(instructions); i++) { 
+ json_object* instruction = json_object_array_get_idx(instructions, i); + + json_object* filter = apply_instruction_filters(filters, instruction); + json_object* result = import_filtered_instruction(instruction, filter); + if (result) { + json_object_array_add(nodes, result); + } + json_object_put(filter); + } +} + +int main(int argc, char** argv) { + assert(argc > ArgSpirvGrammarSearchPathBegins); + + char* dst_file = argv[ArgDstFile]; + // search the include path for spirv.core.grammar.json + char* spv_core_json_path = NULL; + for (size_t i = ArgSpirvGrammarSearchPathBegins; i < argc; i++) { + char* path = shd_format_string_new("%s/spirv/unified1/spirv.core.grammar.json", argv[i]); + shd_info_print("trying path %s\n", path); + FILE* f = fopen(path, "rb"); + if (f) { + spv_core_json_path = path; + fclose(f); + break; + } + free(path); + } + + if (!spv_core_json_path) + abort(); + + json_tokener* tokener = json_tokener_new_ex(32); + enum json_tokener_error json_err; + + typedef struct { + size_t size; + char* contents; + json_object* root; + } JsonFile; + + JsonFile imports; + shd_read_file(argv[ArgImportsFile], &imports.size, &imports.contents); + imports.root = json_tokener_parse_ex(tokener, imports.contents, imports.size); + json_err = json_tokener_get_error(tokener); + if (json_err != json_tokener_success) { + shd_error("Json tokener error while parsing %s:\n %s\n", argv[ArgImportsFile], json_tokener_error_desc(json_err)); + } + + JsonFile spirv; + shd_read_file(spv_core_json_path, &spirv.size, &spirv.contents); + spirv.root = json_tokener_parse_ex(tokener, spirv.contents, spirv.size); + json_err = json_tokener_get_error(tokener); + if (json_err != json_tokener_success) { + shd_error("Json tokener error while parsing %s:\n %s\n", spv_core_json_path, json_tokener_error_desc(json_err)); + } + + shd_info_print("Correctly opened json file: %s\n", spv_core_json_path); + + json_object* output = json_object_new_object(); + + import_spirv_defs(imports.root, 
spirv.root, output); + + Growy* g = shd_new_growy(); + shd_growy_append_string(g, json_object_to_json_string_ext(output, JSON_C_TO_STRING_PRETTY)); + json_object_put(output); + + size_t final_size = shd_growy_size(g); + shd_growy_append_bytes(g, 1, (char[]) { 0 }); + char* generated = shd_growy_deconstruct(g); + shd_debug_print("debug: %s\n", generated); + if (!shd_write_file(dst_file, final_size, generated)) { + shd_error_print("Failed to write file '%s'\n", dst_file); + shd_error_die(); + } + free(generated); + free(spirv.contents); + json_object_put(spirv.root); + free(imports.contents); + json_object_put(imports.root); + json_tokener_free(tokener); + free(spv_core_json_path); +} diff --git a/src/shady/generator/json_apply.c b/src/shady/generator/json_apply.c index 88077d4ed..5eac0168f 100644 --- a/src/shady/generator/json_apply.c +++ b/src/shady/generator/json_apply.c @@ -17,12 +17,22 @@ void json_apply_object(json_object* target, json_object* src) { json_object* existing = json_object_object_get(target, name); if (existing && json_object_get_type(existing) == json_type_object) { json_apply_object(existing, value); - } else if (existing && json_object_get_type(existing) == json_type_array && json_object_get_type(value) == json_type_array && json_object_array_length(value) <= json_object_array_length(existing)) { - for (size_t j = 0; j < json_object_array_length(value); j++) - json_object_array_put_idx(existing, j, json_object_array_get_idx(value, j)); + } else if (existing && json_object_get_type(existing) == json_type_array && json_object_get_type(value) == json_type_array/* && json_object_array_length(value) <= json_object_array_length(existing)*/) { + for (size_t j = 0; j < json_object_array_length(value); j++) { + json_object* elem = json_object_array_get_idx(value, j); + // json_object* copy = NULL; + // json_object_deep_copy(elem, ©, NULL); + // json_object_array_put_idx(existing, j, copy); + // json_object_array_put_idx(existing, j, elem); + 
json_object_array_add(existing, elem); + json_object_get(elem); + } } else { if (existing) - warn_print("json-apply: overwriting key '%s'\n", name); + shd_warn_print("json-apply: overwriting key '%s'\n", name); + // json_object* copy = NULL; + // json_object_deep_copy(value, ©, NULL); + // json_object_object_add(target, name, copy); json_object_object_add(target, name, value); json_object_get(value); } diff --git a/src/shady/generator_constructors.c b/src/shady/generator_constructors.c index 66d08a0e2..467ed3bd0 100644 --- a/src/shady/generator_constructors.c +++ b/src/shady/generator_constructors.c @@ -1,9 +1,9 @@ #include "generator.h" -static void generate_pre_construction_validation(Growy* g, Data data) { - json_object* nodes = json_object_object_get(data.shd, "nodes"); - growy_append_formatted(g, "void pre_construction_validation(IrArena* arena, Node* node) {\n"); - growy_append_formatted(g, "\tswitch (node->tag) { \n"); +static void generate_pre_construction_validation(Growy* g, json_object* src) { + json_object* nodes = json_object_object_get(src, "nodes"); + shd_growy_append_formatted(g, "void pre_construction_validation(IrArena* arena, Node* node) {\n"); + shd_growy_append_formatted(g, "\tswitch (node->tag) { \n"); assert(json_object_get_type(nodes) == json_type_array); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -11,9 +11,10 @@ static void generate_pre_construction_validation(Growy* g, Data data) { String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); void* alloc = NULL; if (!snake_name) { - alloc = snake_name = to_snake_case(name); + snake_name = to_snake_case(name); + alloc = (void*) snake_name; } - growy_append_formatted(g, "\tcase %s_TAG: {\n", name); + shd_growy_append_formatted(g, "\tcase %s_TAG: {\n", name); json_object* ops = json_object_object_get(node, "ops"); if (ops) { assert(json_object_get_type(ops) == json_type_array); @@ -26,46 
+27,58 @@ static void generate_pre_construction_validation(Growy* g, Data data) { bool list = json_object_get_boolean(json_object_object_get(op, "list")); if (strcmp(class, "string") == 0) { if (!list) - growy_append_formatted(g, "\t\tnode->payload.%s.%s = string(arena, node->payload.%s.%s);\n", snake_name, op_name, snake_name, op_name); + shd_growy_append_formatted(g, "\t\tnode->payload.%s.%s = shd_string(arena, node->payload.%s.%s);\n", snake_name, op_name, snake_name, op_name); else - growy_append_formatted(g, "\t\tnode->payload.%s.%s = import_strings(arena, node->payload.%s.%s);\n", snake_name, op_name, snake_name, op_name); + shd_growy_append_formatted(g, "\t\tnode->payload.%s.%s = _shd_import_strings(arena, node->payload.%s.%s);\n", snake_name, op_name, snake_name, op_name); } else { String cap = capitalize(class); - growy_append_formatted(g, "\t\t{\n"); + shd_growy_append_formatted(g, "\t\t{\n"); String extra = ""; if (list) { - growy_append_formatted(g, "\t\t\tNodes ops = node->payload.%s.%s;\n", snake_name, op_name); - growy_append_formatted(g, "\t\t\tfor (size_t i = 0; i < ops.count; i++) {\n"); - growy_append_formatted(g, "\t\t\tconst Node* op = ops.nodes[i];\n"); + shd_growy_append_formatted(g, "\t\t\tsize_t ops_count = node->payload.%s.%s.count;\n", snake_name, op_name); + shd_growy_append_formatted(g, "\t\t\tLARRAY(const Node*, ops, ops_count);\n"); + shd_growy_append_formatted(g, "\t\t\tif (ops_count > 0) memcpy(ops, node->payload.%s.%s.nodes, sizeof(const Node*) * ops_count);\n", snake_name, op_name); + shd_growy_append_formatted(g, "\t\t\tfor (size_t i = 0; i < ops_count; i++) {\n"); + shd_growy_append_formatted(g, "\t\t\tconst Node** pop = &ops[i];\n"); extra = "\t"; } if (!list) - growy_append_formatted(g, "\t\t\tconst Node* op = node->payload.%s.%s;\n", snake_name, op_name); - growy_append_formatted(g, "%s\t\t\tif (arena->config.check_op_classes && op != NULL && !is_%s(op)) {\n", extra, class); - growy_append_formatted(g, 
"%s\t\t\t\terror_print(\"Invalid '%s' operand for node '%s', expected a %s\");\n", extra, op_name, name, class); - growy_append_formatted(g, "%s\t\t\t\terror_die();\n", extra); - growy_append_formatted(g, "%s\t\t\t}\n", extra); - if (list) - growy_append_formatted(g, "\t\t\t}\n"); - free(cap); - growy_append_formatted(g, "\t\t}\n"); + shd_growy_append_formatted(g, "\t\t\tconst Node** pop = &node->payload.%s.%s;\n", snake_name, op_name); + + shd_growy_append_formatted(g, "\t\t\t*pop = _shd_fold_node_operand(%s_TAG, Nc%s, \"%s\", *pop);\n", name, cap, op_name); + + if (!(json_object_get_boolean(json_object_object_get(op, "nullable")) || json_object_get_boolean(json_object_object_get(op, "ignore")))) { + shd_growy_append_formatted(g, "%s\t\t\tif (!*pop) {\n", extra); + shd_growy_append_formatted(g, "%s\t\t\t\tshd_error(\"operand '%s' of node '%s' cannot be null\");\n", extra, op_name, name); + shd_growy_append_formatted(g, "%s\t\t\t}\n", extra); + } + + shd_growy_append_formatted(g, "%s\t\t\tif (arena->config.check_op_classes && *pop != NULL && !is_%s(*pop)) {\n", extra, class); + shd_growy_append_formatted(g, "%s\t\t\t\tshd_error_print(\"Invalid '%s' operand for node '%s', expected a %s\");\n", extra, op_name, name, class); + shd_growy_append_formatted(g, "%s\t\t\t\tshd_error_die();\n", extra); + shd_growy_append_formatted(g, "%s\t\t\t}\n", extra); + if (list) { + shd_growy_append_formatted(g, "\t\t\t}\n"); + shd_growy_append_formatted(g, "\t\t\tnode->payload.%s.%s = shd_nodes(arena, ops_count, ops);\n", snake_name, op_name); + } + free((void*) cap); + shd_growy_append_formatted(g, "\t\t}\n"); } } } - growy_append_formatted(g, "\t\tbreak;\n"); - growy_append_formatted(g, "\t}\n", name); + shd_growy_append_formatted(g, "\t\tbreak;\n"); + shd_growy_append_formatted(g, "\t}\n", name); if (alloc) free(alloc); } - growy_append_formatted(g, "\t\tdefault: break;\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "}\n\n"); + shd_growy_append_formatted(g, 
"\t\tdefault: break;\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n\n"); } -void generate(Growy* g, Data data) { - generate_header(g, data); +void generate(Growy* g, json_object* src) { + generate_header(g, src); - json_object* nodes = json_object_object_get(data.shd, "nodes"); - generate_node_ctor(g, nodes, true); - generate_pre_construction_validation(g, data); + json_object* nodes = json_object_object_get(src, "nodes"); + generate_pre_construction_validation(g, src); } diff --git a/src/shady/generator_node.c b/src/shady/generator_node.c index f6aa94655..59e2b2c11 100644 --- a/src/shady/generator_node.c +++ b/src/shady/generator_node.c @@ -3,39 +3,40 @@ #include "generator.h" static void generate_node_names_string_array(Growy* g, json_object* nodes) { - growy_append_formatted(g, "const char* node_tags[] = {\n"); - growy_append_formatted(g, "\t\"invalid\",\n"); + shd_growy_append_formatted(g, "const char* node_tags[] = {\n"); + shd_growy_append_formatted(g, "\t\"invalid\",\n"); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); String name = json_object_get_string(json_object_object_get(node, "name")); String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); void* alloc = NULL; if (!snake_name) { - alloc = snake_name = to_snake_case(name); + snake_name = to_snake_case(name); + alloc = (void*) snake_name; } assert(name); - growy_append_formatted(g, "\t\"%s\",\n", snake_name); + shd_growy_append_formatted(g, "\t\"%s\",\n", snake_name); if (alloc) free(alloc); } - growy_append_formatted(g, "};\n\n"); + shd_growy_append_formatted(g, "};\n\n"); } static void generate_node_has_payload_array(Growy* g, json_object* nodes) { - growy_append_formatted(g, "const bool node_type_has_payload[] = {\n"); - growy_append_formatted(g, "\tfalse,\n"); + shd_growy_append_formatted(g, "const bool node_type_has_payload[] = {\n"); + 
shd_growy_append_formatted(g, "\tfalse,\n"); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); json_object* ops = json_object_object_get(node, "ops"); - growy_append_formatted(g, "\t%s,\n", ops ? "true" : "false"); + shd_growy_append_formatted(g, "\t%s,\n", ops ? "true" : "false"); } - growy_append_formatted(g, "};\n\n"); + shd_growy_append_formatted(g, "};\n\n"); } -static void generate_node_payload_hash_fn(Growy* g, Data data, json_object* nodes) { - growy_append_formatted(g, "KeyHash hash_node_payload(const Node* node) {\n"); - growy_append_formatted(g, "\tKeyHash hash = 0;\n"); - growy_append_formatted(g, "\tswitch (node->tag) { \n"); +static void generate_node_payload_hash_fn(Growy* g, json_object* src, json_object* nodes) { + shd_growy_append_formatted(g, "KeyHash _shd_hash_node_payload(const Node* node) {\n"); + shd_growy_append_formatted(g, "\tKeyHash hash = 0;\n"); + shd_growy_append_formatted(g, "\tswitch (node->tag) { \n"); assert(json_object_get_type(nodes) == json_type_array); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -43,37 +44,38 @@ static void generate_node_payload_hash_fn(Growy* g, Data data, json_object* node String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); void* alloc = NULL; if (!snake_name) { - alloc = snake_name = to_snake_case(name); + snake_name = to_snake_case(name); + alloc = (void*) snake_name; } json_object* ops = json_object_object_get(node, "ops"); if (ops) { assert(json_object_get_type(ops) == json_type_array); - growy_append_formatted(g, "\tcase %s_TAG: {\n", name); - growy_append_formatted(g, "\t\t%s payload = node->payload.%s;\n", name, snake_name); + shd_growy_append_formatted(g, "\tcase %s_TAG: {\n", name); + shd_growy_append_formatted(g, "\t\t%s payload = node->payload.%s;\n", name, snake_name); for (size_t j = 0; j < 
json_object_array_length(ops); j++) { json_object* op = json_object_array_get_idx(ops, j); String op_name = json_object_get_string(json_object_object_get(op, "name")); bool ignore = json_object_get_boolean(json_object_object_get(op, "ignore")); if (!ignore) { - growy_append_formatted(g, "\t\thash = hash ^ hash_murmur(&payload.%s, sizeof(payload.%s));\n", op_name, op_name); + shd_growy_append_formatted(g, "\t\thash = hash ^ shd_hash(&payload.%s, sizeof(payload.%s));\n", op_name, op_name); } } - growy_append_formatted(g, "\t\tbreak;\n"); - growy_append_formatted(g, "\t}\n", name); + shd_growy_append_formatted(g, "\t\tbreak;\n"); + shd_growy_append_formatted(g, "\t}\n", name); } if (alloc) free(alloc); } - growy_append_formatted(g, "\t\tdefault: assert(false);\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "\treturn hash;\n"); - growy_append_formatted(g, "}\n"); + shd_growy_append_formatted(g, "\t\tdefault: assert(false);\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "\treturn hash;\n"); + shd_growy_append_formatted(g, "}\n"); } -static void generate_node_payload_cmp_fn(Growy* g, Data data, json_object* nodes) { - growy_append_formatted(g, "bool compare_node_payload(const Node* a, const Node* b) {\n"); - growy_append_formatted(g, "\tbool eq = true;\n"); - growy_append_formatted(g, "\tswitch (a->tag) { \n"); +static void generate_node_payload_cmp_fn(Growy* g, json_object* src, json_object* nodes) { + shd_growy_append_formatted(g, "bool _shd_compare_node_payload(const Node* a, const Node* b) {\n"); + shd_growy_append_formatted(g, "\tbool eq = true;\n"); + shd_growy_append_formatted(g, "\tswitch (a->tag) { \n"); assert(json_object_get_type(nodes) == json_type_array); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -81,83 +83,85 @@ static void generate_node_payload_cmp_fn(Growy* g, Data data, json_object* nodes String snake_name = 
json_object_get_string(json_object_object_get(node, "snake_name")); void* alloc = NULL; if (!snake_name) { - alloc = snake_name = to_snake_case(name); + snake_name = to_snake_case(name); + alloc = (void*) snake_name; } json_object* ops = json_object_object_get(node, "ops"); if (ops) { assert(json_object_get_type(ops) == json_type_array); - growy_append_formatted(g, "\tcase %s_TAG: {\n", name); - growy_append_formatted(g, "\t\t%s payload_a = a->payload.%s;\n", name, snake_name); - growy_append_formatted(g, "\t\t%s payload_b = b->payload.%s;\n", name, snake_name); + shd_growy_append_formatted(g, "\tcase %s_TAG: {\n", name); + shd_growy_append_formatted(g, "\t\t%s payload_a = a->payload.%s;\n", name, snake_name); + shd_growy_append_formatted(g, "\t\t%s payload_b = b->payload.%s;\n", name, snake_name); for (size_t j = 0; j < json_object_array_length(ops); j++) { json_object* op = json_object_array_get_idx(ops, j); String op_name = json_object_get_string(json_object_object_get(op, "name")); bool ignore = json_object_get_boolean(json_object_object_get(op, "ignore")); if (!ignore) { - growy_append_formatted(g, "\t\teq &= memcmp(&payload_a.%s, &payload_b.%s, sizeof(payload_a.%s)) == 0;\n", op_name, op_name, op_name); + shd_growy_append_formatted(g, "\t\teq &= memcmp(&payload_a.%s, &payload_b.%s, sizeof(payload_a.%s)) == 0;\n", op_name, op_name, op_name); } } - growy_append_formatted(g, "\t\tbreak;\n"); - growy_append_formatted(g, "\t}\n", name); + shd_growy_append_formatted(g, "\t\tbreak;\n"); + shd_growy_append_formatted(g, "\t}\n", name); } if (alloc) free(alloc); } - growy_append_formatted(g, "\t\tdefault: assert(false);\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "\treturn eq;\n"); - growy_append_formatted(g, "}\n"); + shd_growy_append_formatted(g, "\t\tdefault: assert(false);\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "\treturn eq;\n"); + shd_growy_append_formatted(g, "}\n"); } -static void 
generate_isa_for_class(Growy* g, json_object* nodes, String class, String capitalized_class, bool use_enum) { +static void generate_node_is_nominal(Growy* g, json_object* nodes) { + shd_growy_append_formatted(g, "bool shd_is_node_nominal(const Node* node) {\n"); + shd_growy_append_formatted(g, "\tswitch (node->tag) { \n"); assert(json_object_get_type(nodes) == json_type_array); - if (use_enum) - growy_append_formatted(g, "%sTag is_%s(const Node* node) {\n", capitalized_class, class); - else - growy_append_formatted(g, "bool is_%s(const Node* node) {\n", class); - growy_append_formatted(g, "\tif (get_node_class_from_tag(node->tag) & Nc%s)\n", capitalized_class); - if (use_enum) { - growy_append_formatted(g, "\t\treturn (%sTag) node->tag;\n", capitalized_class); - growy_append_formatted(g, "\treturn (%sTag) 0;\n", capitalized_class); - } else { - growy_append_formatted(g, "\t\treturn true;\n", capitalized_class); - growy_append_formatted(g, "\treturn false;\n", capitalized_class); + for (size_t i = 0; i < json_object_array_length(nodes); i++) { + json_object* node = json_object_array_get_idx(nodes, i); + String name = json_object_get_string(json_object_object_get(node, "name")); + if (json_object_get_boolean(json_object_object_get(node, "nominal"))) { + shd_growy_append_formatted(g, "\t\tcase %s_TAG: return true;\n", name); + } } - growy_append_formatted(g, "}\n\n"); + shd_growy_append_formatted(g, "\t\tdefault: return false;\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n"); } void generate_address_space_name_fn(Growy* g, json_object* address_spaces) { - growy_append_formatted(g, "String get_address_space_name(AddressSpace as) {\n"); - growy_append_formatted(g, "\tswitch (as) {\n"); + shd_growy_append_formatted(g, "String shd_get_address_space_name(AddressSpace as) {\n"); + shd_growy_append_formatted(g, "\tswitch (as) {\n"); for (size_t i = 0; i < json_object_array_length(address_spaces); i++) { json_object* as = 
json_object_array_get_idx(address_spaces, i); String name = json_object_get_string(json_object_object_get(as, "name")); - growy_append_formatted(g, "\t\t case As%s: return \"%s\";\n", name, name); + shd_growy_append_formatted(g, "\t\t case As%s: return \"%s\";\n", name, name); } - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "}\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n"); } -void generate(Growy* g, Data data) { - generate_header(g, data); +void generate(Growy* g, json_object* src) { + generate_header(g, src); - json_object* nodes = json_object_object_get(data.shd, "nodes"); - generate_address_space_name_fn(g, json_object_object_get(data.shd, "address-spaces")); + json_object* nodes = json_object_object_get(src, "nodes"); + generate_address_space_name_fn(g, json_object_object_get(src, "address-spaces")); generate_node_names_string_array(g, nodes); + generate_node_is_nominal(g, nodes); generate_node_has_payload_array(g, nodes); - generate_node_payload_hash_fn(g, data, nodes); - generate_node_payload_cmp_fn(g, data, nodes); - generate_bit_enum_classifier(g, "get_node_class_from_tag", "NodeClass", "Nc", "NodeTag", "", "_TAG", nodes); + generate_node_payload_hash_fn(g, src, nodes); + generate_node_payload_cmp_fn(g, src, nodes); + generate_bit_enum_classifier(g, "shd_get_node_class_from_tag", "NodeClass", "Nc", "NodeTag", "", "_TAG", nodes); - json_object* node_classes = json_object_object_get(data.shd, "node-classes"); + json_object* node_classes = json_object_object_get(src, "node-classes"); for (size_t i = 0; i < json_object_array_length(node_classes); i++) { json_object* node_class = json_object_array_get_idx(node_classes, i); String name = json_object_get_string(json_object_object_get(node_class, "name")); assert(name); + //generate_getters_for_class(g, src, nodes, node_class); + json_object* generate_enum = json_object_object_get(node_class, "generate-enum"); String capitalized = capitalize(name); - 
generate_isa_for_class(g, nodes, name, capitalized, !generate_enum || json_object_get_boolean(generate_enum)); - free(capitalized); + free((void*) capitalized); } } diff --git a/src/shady/generator_primops.c b/src/shady/generator_primops.c index 551065a50..c20793f1e 100644 --- a/src/shady/generator_primops.c +++ b/src/shady/generator_primops.c @@ -1,7 +1,7 @@ #include "generator.h" static void generate_primops_names_array(Growy* g, json_object* primops) { - growy_append_string(g, "const char* primop_names[] = {\n"); + shd_growy_append_string(g, "const char* primop_names[] = {\n"); for (size_t i = 0; i < json_object_array_length(primops); i++) { json_object* node = json_object_array_get_idx(primops, i); @@ -9,14 +9,14 @@ static void generate_primops_names_array(Growy* g, json_object* primops) { String name = json_object_get_string(json_object_object_get(node, "name")); assert(name); - growy_append_formatted(g, "\"%s\",", name); + shd_growy_append_formatted(g, "\"%s\",", name); } - growy_append_string(g, "\n};\n"); + shd_growy_append_string(g, "\n};\n"); } static void generate_primops_side_effects_array(Growy* g, json_object* primops) { - growy_append_string(g, "const bool primop_side_effects[] = {\n"); + shd_growy_append_string(g, "const bool primop_side_effects[] = {\n"); for (size_t i = 0; i < json_object_array_length(primops); i++) { json_object* node = json_object_array_get_idx(primops, i); @@ -26,20 +26,20 @@ static void generate_primops_side_effects_array(Growy* g, json_object* primops) bool side_effects = json_object_get_boolean(json_object_object_get(node, "side-effects")); if (side_effects) - growy_append_string(g, "true, "); + shd_growy_append_string(g, "true, "); else - growy_append_string(g, "false, "); + shd_growy_append_string(g, "false, "); } - growy_append_string(g, "\n};\n"); + shd_growy_append_string(g, "\n};\n"); } -void generate(Growy* g, Data data) { - generate_header(g, data); +void generate(Growy* g, json_object* shd) { + generate_header(g, 
shd); - json_object* primops = json_object_object_get(data.shd, "prim-ops"); + json_object* primops = json_object_object_get(shd, "prim-ops"); generate_primops_names_array(g, primops); generate_primops_side_effects_array(g, primops); - generate_bit_enum_classifier(g, "get_primop_class", "OpClass", "Oc", "Op", "", "_op", primops); + generate_bit_enum_classifier(g, "shd_get_primop_class", "OpClass", "Oc", "Op", "", "_op", primops); } diff --git a/src/shady/generator_print.c b/src/shady/generator_print.c new file mode 100644 index 000000000..14ddfb98d --- /dev/null +++ b/src/shady/generator_print.c @@ -0,0 +1,87 @@ +#include "generator.h" + +void generate_node_print_fns(Growy* g, json_object* src) { + json_object* nodes = json_object_object_get(src, "nodes"); + shd_growy_append_formatted(g, "void _shd_print_node_generated(PrinterCtx* ctx, const Node* node) {\n"); + shd_growy_append_formatted(g, "\tswitch (node->tag) { \n"); + assert(json_object_get_type(nodes) == json_type_array); + for (size_t i = 0; i < json_object_array_length(nodes); i++) { + json_object* node = json_object_array_get_idx(nodes, i); + String name = json_object_get_string(json_object_object_get(node, "name")); + String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); + void* alloc = NULL; + if (!snake_name) { + snake_name = to_snake_case(name); + alloc = (void*) snake_name; + } + shd_growy_append_formatted(g, "\tcase %s_TAG: {\n", name); + shd_growy_append_formatted(g, "\t\tshd_print(ctx->printer, GREEN);\n"); + shd_growy_append_formatted(g, "\t\tshd_print(ctx->printer, \"%s\");\n", name); + shd_growy_append_formatted(g, "\t\tshd_print(ctx->printer, RESET);\n"); + shd_growy_append_formatted(g, "\t\tshd_print(ctx->printer, \"(\");\n"); + json_object* ops = json_object_object_get(node, "ops"); + if (ops) { + assert(json_object_get_type(ops) == json_type_array); + for (size_t j = 0; j < json_object_array_length(ops); j++) { + json_object* op = 
json_object_array_get_idx(ops, j); + bool ignore = json_object_get_boolean(json_object_object_get(op, "ignore")); + if (ignore) + continue; + + String op_name = json_object_get_string(json_object_object_get(op, "name")); + String op_class = json_object_get_string(json_object_object_get(op, "class")); + if (op_class && strcmp(op_class, "string") != 0) { + bool is_list = json_object_get_boolean(json_object_object_get(op, "list")); + String cap_class = capitalize(op_class); + if (is_list) { + shd_growy_append_formatted(g, "\t\t{\n"); + shd_growy_append_formatted(g, "\t\t\t_shd_print_node_operand_list(ctx, node, \"%s\", Nc%s, node->payload.%s.%s);\n", op_name, cap_class, snake_name, op_name); + // growy_append_formatted(g, "\t\t\tsize_t count = node->payload.%s.%s.count;\n", snake_name, op_name); + // growy_append_formatted(g, "\t\t\tfor (size_t i = 0; i < count; i++) {\n"); + // growy_append_formatted(g, "\t\t\t\tprint_node_operand(printer, node, \"%s\", Nc%s, i, node->payload.%s.%s.nodes[i], config);\n", op_name, cap_class, snake_name, op_name); + // growy_append_formatted(g, "\t\t\t}\n"); + shd_growy_append_formatted(g, "\t\t}\n"); + } else { + shd_growy_append_formatted(g, "\t\t{\n"); + shd_growy_append_formatted(g, "\t\t\t_shd_print_node_operand(ctx, node, \"%s\", Nc%s, node->payload.%s.%s);\n", op_name, cap_class, snake_name, op_name); + shd_growy_append_formatted(g, "\t\t}\n"); + } + free((void*) cap_class); + } else { + String op_type = json_object_get_string(json_object_object_get(op, "type")); + if (!op_type) { + assert(op_class && strcmp(op_class, "string") == 0); + if (json_object_get_boolean(json_object_object_get(op, "list"))) + op_type = "Strings"; + else + op_type = "String"; + } + char* s = strdup(op_type); + for (size_t k = 0; k < strlen(op_type); k++) { + if (!isalnum(s[k])) + s[k] = '_'; + } + shd_growy_append_formatted(g, "\t\t_shd_print_node_operand_%s(ctx, node, \"%s\", node->payload.%s.%s);\n", s, op_name, snake_name, op_name); + free(s); + } + 
+ if (j + 1 < json_object_array_length(ops)) + shd_growy_append_formatted(g, "\t\tshd_print(ctx->printer, \", \");\n"); + } + } + shd_growy_append_formatted(g, "\t\tshd_print(ctx->printer, \")\");\n"); + shd_growy_append_formatted(g, "\t\tbreak;\n"); + shd_growy_append_formatted(g, "\t}\n", name); + if (alloc) + free(alloc); + } + shd_growy_append_formatted(g, "\t\tdefault: assert(false);\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n"); +} + +void generate(Growy* g, json_object* src) { + generate_header(g, src); + //growy_append_formatted(g, "#include \"print.h\"\n\n"); + generate_node_print_fns(g, src); +} diff --git a/src/shady/generator_rewrite.c b/src/shady/generator_rewrite.c index 65d948393..1a3172c34 100644 --- a/src/shady/generator_rewrite.c +++ b/src/shady/generator_rewrite.c @@ -1,24 +1,24 @@ #include "generator.h" static void generate_can_be_default_rewritten_fn(Growy* g, json_object* nodes) { - growy_append_formatted(g, "static bool can_be_default_rewritten(NodeTag tag) {\n"); - growy_append_formatted(g, "\tswitch (tag) { \n"); + shd_growy_append_formatted(g, "static bool can_be_default_rewritten(NodeTag tag) {\n"); + shd_growy_append_formatted(g, "\tswitch (tag) { \n"); assert(json_object_get_type(nodes) == json_type_array); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); if (has_custom_ctor(node)) continue; String name = json_object_get_string(json_object_object_get(node, "name")); - growy_append_formatted(g, "\t\tcase %s_TAG: return true;\n", name); + shd_growy_append_formatted(g, "\t\tcase %s_TAG: return true;\n", name); } - growy_append_formatted(g, "\t\tdefault: return false;\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "}\n\n"); + shd_growy_append_formatted(g, "\t\tdefault: return false;\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n\n"); } static void 
generate_rewriter_default_fns(Growy* g, json_object* nodes) { - growy_append_formatted(g, "static const Node* recreate_node_identity_generated(Rewriter* rewriter, const Node* node) {\n"); - growy_append_formatted(g, "\tswitch (node->tag) { \n"); + shd_growy_append_formatted(g, "static const Node* recreate_node_identity_generated(Rewriter* rewriter, const Node* node) {\n"); + shd_growy_append_formatted(g, "\tswitch (node->tag) { \n"); assert(json_object_get_type(nodes) == json_type_array); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -30,14 +30,16 @@ static void generate_rewriter_default_fns(Growy* g, json_object* nodes) { String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); void* alloc = NULL; if (!snake_name) { - alloc = snake_name = to_snake_case(name); + snake_name = to_snake_case(name); + alloc = (void*) snake_name; } - growy_append_formatted(g, "\t\tcase %s_TAG: {\n", name); + shd_growy_append_formatted(g, "\t\tcase %s_TAG: {\n", name); json_object* ops = json_object_object_get(node, "ops"); if (ops) { assert(json_object_get_type(ops) == json_type_array); - growy_append_formatted(g, "\t\t\t%s old_payload = node->payload.%s;\n", name, snake_name); - growy_append_formatted(g, "\t\t\t%s payload = { 0 };\n", name, snake_name); + shd_growy_append_formatted(g, "\t\t\t%s old_payload = node->payload.%s;\n", name, snake_name); + shd_growy_append_formatted(g, "\t\t\t%s payload;\n", name); + shd_growy_append_formatted(g, "\t\t\tmemset(&payload, 0, sizeof(payload));\n"); for (size_t j = 0; j < json_object_array_length(ops); j++) { json_object* op = json_object_array_get_idx(ops, j); String op_name = json_object_get_string(json_object_object_get(op, "name")); @@ -48,40 +50,40 @@ static void generate_rewriter_default_fns(Growy* g, json_object* nodes) { String class = json_object_get_string(json_object_object_get(op, "class")); if (!class) { assert(!list); - 
growy_append_formatted(g, "\t\t\tpayload.%s = old_payload.%s;\n", op_name, op_name); + shd_growy_append_formatted(g, "\t\t\tpayload.%s = old_payload.%s;\n", op_name, op_name); continue; } if (strcmp(class, "string") == 0) { if (list) - growy_append_formatted(g, "\t\t\tpayload.%s = strings(rewriter->dst_arena, old_payload.%s.count, old_payload.%s.strings);\n", op_name, op_name, op_name); + shd_growy_append_formatted(g, "\t\t\tpayload.%s = shd_strings(rewriter->dst_arena, old_payload.%s.count, old_payload.%s.strings);\n", op_name, op_name, op_name); else - growy_append_formatted(g, "\t\t\tpayload.%s = string(rewriter->dst_arena, old_payload.%s);\n", op_name, op_name); + shd_growy_append_formatted(g, "\t\t\tpayload.%s = shd_string(rewriter->dst_arena, old_payload.%s);\n", op_name, op_name); continue; } String class_cap = capitalize(class); if (list) - growy_append_formatted(g, "\t\t\tpayload.%s = rewrite_ops_helper(rewriter, Nc%s, \"%s\", old_payload.%s);\n", op_name, class_cap, op_name, op_name); + shd_growy_append_formatted(g, "\t\t\tpayload.%s = rewrite_ops_helper(rewriter, Nc%s, \"%s\", old_payload.%s);\n", op_name, class_cap, op_name, op_name); else - growy_append_formatted(g, "\t\t\tpayload.%s = rewrite_op_helper(rewriter, Nc%s, \"%s\", old_payload.%s);\n", op_name, class_cap, op_name, op_name); - free(class_cap); + shd_growy_append_formatted(g, "\t\t\tpayload.%s = rewrite_op_helper(rewriter, Nc%s, \"%s\", old_payload.%s);\n", op_name, class_cap, op_name, op_name); + free((void*) class_cap); } - growy_append_formatted(g, "\t\t\treturn %s(rewriter->dst_arena, payload);\n", snake_name); + shd_growy_append_formatted(g, "\t\t\treturn %s(rewriter->dst_arena, payload);\n", snake_name); } else - growy_append_formatted(g, "\t\t\treturn %s(rewriter->dst_arena);\n", snake_name); - growy_append_formatted(g, "\t\t}\n", name); + shd_growy_append_formatted(g, "\t\t\treturn %s(rewriter->dst_arena);\n", snake_name); + shd_growy_append_formatted(g, "\t\t}\n", name); if (alloc) 
free(alloc); } - growy_append_formatted(g, "\t\tdefault: assert(false);\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "}\n\n"); + shd_growy_append_formatted(g, "\t\tdefault: assert(false);\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n\n"); } -void generate(Growy* g, Data data) { - generate_header(g, data); +void generate(Growy* g, json_object* src) { + generate_header(g, src); - json_object* nodes = json_object_object_get(data.shd, "nodes"); + json_object* nodes = json_object_object_get(src, "nodes"); generate_can_be_default_rewritten_fn(g, nodes); generate_rewriter_default_fns(g, nodes); } diff --git a/src/shady/generator_type.c b/src/shady/generator_type.c index 0f279a3bd..11043b6b3 100644 --- a/src/shady/generator_type.c +++ b/src/shady/generator_type.c @@ -1,9 +1,12 @@ #include "generator.h" -void generate(Growy* g, Data data) { - generate_header(g, data); +void generate(Growy* g, json_object* src) { + generate_header(g, src); - json_object* nodes = json_object_object_get(data.shd, "nodes"); + json_object* nodes = json_object_object_get(src, "nodes"); + + shd_growy_append_formatted(g, "const Type* _shd_check_type_generated(IrArena* a, const Node* node) {\n"); + shd_growy_append_formatted(g, "\tswitch(node->tag) {\n"); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -13,19 +16,24 @@ void generate(Growy* g, Data data) { String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); void* alloc = NULL; if (!snake_name) { - alloc = snake_name = to_snake_case(name); + snake_name = to_snake_case(name); + alloc = (void*) snake_name; } json_object* t = json_object_object_get(node, "type"); if (!t || json_object_get_boolean(t)) { + shd_growy_append_formatted(g, "\t\tcase %s_TAG: ", name); json_object* ops = json_object_object_get(node, "ops"); if (ops) - growy_append_formatted(g, "const Type* 
check_type_%s(IrArena*, %s);\n", snake_name, name); + shd_growy_append_formatted(g, "return _shd_check_type_%s(a, node->payload.%s);\n", snake_name, snake_name); else - growy_append_formatted(g, "const Type* check_type_%s(IrArena*);\n", snake_name); + shd_growy_append_formatted(g, "return _shd_check_type_%s(a);\n", snake_name); } if (alloc) free(alloc); } + shd_growy_append_formatted(g, "\t\tdefault: return NULL;\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n"); } diff --git a/src/shady/generator_visit.c b/src/shady/generator_visit.c index ff43113fe..5be736aba 100644 --- a/src/shady/generator_visit.c +++ b/src/shady/generator_visit.c @@ -1,11 +1,11 @@ #include "generator.h" -void generate(Growy* g, Data data) { - generate_header(g, data); +void generate(Growy* g, json_object* src) { + generate_header(g, src); - json_object* nodes = json_object_object_get(data.shd, "nodes"); - growy_append_formatted(g, "void visit_node_operands(Visitor* visitor, NodeClass exclude, const Node* node) {\n"); - growy_append_formatted(g, "\tswitch (node->tag) { \n"); + json_object* nodes = json_object_object_get(src, "nodes"); + shd_growy_append_formatted(g, "void shd_visit_node_operands(Visitor* visitor, NodeClass exclude, const Node* node) {\n"); + shd_growy_append_formatted(g, "\tswitch (node->tag) { \n"); assert(json_object_get_type(nodes) == json_type_array); for (size_t i = 0; i < json_object_array_length(nodes); i++) { json_object* node = json_object_array_get_idx(nodes, i); @@ -13,13 +13,14 @@ void generate(Growy* g, Data data) { String snake_name = json_object_get_string(json_object_object_get(node, "snake_name")); void* alloc = NULL; if (!snake_name) { - alloc = snake_name = to_snake_case(name); + snake_name = to_snake_case(name); + alloc = (void*) snake_name; } - growy_append_formatted(g, "\tcase %s_TAG: {\n", name); + shd_growy_append_formatted(g, "\tcase %s_TAG: {\n", name); json_object* ops = json_object_object_get(node, "ops"); if (ops) 
{ assert(json_object_get_type(ops) == json_type_array); - growy_append_formatted(g, "\t\t%s payload = node->payload.%s;\n", name, snake_name); + shd_growy_append_formatted(g, "\t\t%s payload = node->payload.%s;\n", name, snake_name); for (size_t j = 0; j < json_object_array_length(ops); j++) { json_object* op = json_object_array_get_idx(ops, j); String op_name = json_object_get_string(json_object_object_get(op, "name")); @@ -30,21 +31,21 @@ void generate(Growy* g, Data data) { bool list = json_object_get_boolean(json_object_object_get(op, "list")); bool ignore = json_object_get_boolean(json_object_object_get(op, "ignore")); if (!ignore) { - growy_append_formatted(g, "\t\tif ((exclude & Nc%s) == 0)\n", class_cap); + shd_growy_append_formatted(g, "\t\tif ((exclude & Nc%s) == 0)\n", class_cap); if (list) - growy_append_formatted(g, "\t\t\tvisit_ops(visitor, Nc%s, \"%s\", payload.%s);\n", class_cap, op_name, op_name); + shd_growy_append_formatted(g, "\t\t\tshd_visit_ops(visitor, Nc%s, \"%s\", payload.%s);\n", class_cap, op_name, op_name); else - growy_append_formatted(g, "\t\t\tvisit_op(visitor, Nc%s, \"%s\", payload.%s);\n", class_cap, op_name, op_name); + shd_growy_append_formatted(g, "\t\t\tshd_visit_op(visitor, Nc%s, \"%s\", payload.%s, 0);\n", class_cap, op_name, op_name); } - free(class_cap); + free((void*) class_cap); } } - growy_append_formatted(g, "\t\tbreak;\n"); - growy_append_formatted(g, "\t}\n", name); + shd_growy_append_formatted(g, "\t\tbreak;\n"); + shd_growy_append_formatted(g, "\t}\n", name); if (alloc) free(alloc); } - growy_append_formatted(g, "\t\tdefault: assert(false);\n"); - growy_append_formatted(g, "\t}\n"); - growy_append_formatted(g, "}\n\n"); + shd_growy_append_formatted(g, "\t\tdefault: assert(false);\n"); + shd_growy_append_formatted(g, "\t}\n"); + shd_growy_append_formatted(g, "}\n\n"); } diff --git a/src/shady/internal/scheduler.slim b/src/shady/internal/scheduler.slim index afc9112cb..e0a521212 100644 --- 
a/src/shady/internal/scheduler.slim +++ b/src/shady/internal/scheduler.slim @@ -9,19 +9,39 @@ u32 payload; }; -@Internal subgroup u32 actual_subgroup_size; +// const u32 SpvScopeSubgroup = 3; +// const u32 SpvGroupOperationReduce = 0; +// const u32 OpGroupNonUniformElect = 333; +// const u32 OpGroupNonUniformBroadcastFirst = 338; +// const u32 OpGroupNonUniformBallot = 339; +// const u32 OpGroupNonUniformIAdd = 349; -@Internal subgroup u32 scheduler_cursor = 0; -@Internal subgroup [TreeNode; SUBGROUP_SIZE] scheduler_vector; -@Internal subgroup [u32; SUBGROUP_SIZE] resume_at; +@Internal var logical private u32 actual_subgroup_size; -@Internal subgroup u32 next_fn; -@Internal subgroup TreeNode active_branch; +@Internal var logical subgroup u32 scheduler_cursor = 0; +@Internal var logical subgroup [TreeNode; SUBGROUP_SIZE] scheduler_vector; +@Internal var logical subgroup [u32; SUBGROUP_SIZE] resume_at; + +@Internal var logical subgroup u32 next_fn; +@Internal var logical subgroup TreeNode active_branch; @Internal @Builtin("SubgroupLocalInvocationId") -input u32 subgroup_local_id; +var input u32 subgroup_local_id; + +@Internal @Builtin("SubgroupId") +var uniform input u32 subgroup_id; -@Internal @Structured @DisablePass("setup_stack_frames") @Leaf +@Internal +fn subgroup_active_mask uniform mask_t() { + return (ext_instr["spirv.core", 339, uniform mask_t](3, true)); +} + +@Internal +fn subgroup_ballot uniform mask_t(varying bool b) { + return (ext_instr["spirv.core", 339, uniform mask_t](3, b)); +} + +@Internal @Exported fn builtin_init_scheduler() { val init_mask = subgroup_active_mask(); @@ -29,10 +49,10 @@ fn builtin_init_scheduler() { scheduler_vector#(subgroup_local_id) = tree_node1; active_branch = tree_node1; - actual_subgroup_size = subgroup_reduce_sum(u32 1); + actual_subgroup_size = (ext_instr["spirv.core", 349, varying u32](3, u32 1, u32 0)); } -@Internal @Structured @DisablePass("setup_stack_frames") @Leaf +@Internal @Exported fn builtin_entry_join_point 
uniform JoinPoint() { val init_mask = subgroup_active_mask(); @@ -41,10 +61,10 @@ fn builtin_entry_join_point uniform JoinPoint() { return (jp); } -@Internal @Structured @Leaf +@Internal @Exported fn builtin_create_control_point varying JoinPoint(uniform u32 join_destination, varying u32 payload) { val curr_mask = subgroup_active_mask(); - val depth = subgroup_broadcast_first(scheduler_vector#(subgroup_local_id)#1); + val depth = ext_instr["spirv.core", 338, uniform u32](3, scheduler_vector#(subgroup_local_id)#1); val tree_node = composite TreeNode(curr_mask, depth); val jp = composite JoinPoint(tree_node, join_destination, payload); @@ -54,9 +74,9 @@ fn builtin_create_control_point varying JoinPoint(uniform u32 join_destination, return (jp); } -@Internal @Structured @Leaf +@Internal @Exported fn builtin_fork(varying u32 branch_destination) { - val first_branch = subgroup_broadcast_first(branch_destination); + val first_branch = ext_instr["spirv.core", 338, uniform u32](3, branch_destination); // if there is disagreement on the destination, then increase the depth of every branch val uniform_branch = subgroup_active_mask() == subgroup_ballot(first_branch == branch_destination); @@ -68,7 +88,7 @@ fn builtin_fork(varying u32 branch_destination) { // Partition the set of branch destinations and adapt the masks in turn loop() { - val elected = subgroup_broadcast_first(branch_destination); + val elected = ext_instr["spirv.core", 338, uniform u32](3, branch_destination); resume_at#(subgroup_local_id) = elected; scheduler_vector#(subgroup_local_id)#0 = subgroup_ballot(elected == branch_destination); if (elected == branch_destination) { @@ -76,11 +96,11 @@ fn builtin_fork(varying u32 branch_destination) { } } - // We must pick one branch as our 'favourite child' to schedule for immediate execution# + // We must pick one branch as our 'favourite child' to schedule for immediate execution // we could do fancy intrinsics, but for now we'll just pick the first one - if 
(subgroup_elect_first()) { - next_fn = subgroup_broadcast_first(branch_destination); - active_branch = subgroup_broadcast_first(scheduler_vector#(subgroup_local_id)); + if (ext_instr["spirv.core", 333, varying bool](3)) { + next_fn = ext_instr["spirv.core", 338, uniform u32](3, branch_destination); + active_branch = ext_instr["spirv.core", 338, uniform TreeNode](3, scheduler_vector#(subgroup_local_id)); // tag those variables as not in use# // resume_at#(subgroup_local_id) = -1; @@ -89,13 +109,13 @@ fn builtin_fork(varying u32 branch_destination) { } } -@Internal @Structured @Leaf +@Internal @Exported fn builtin_yield(uniform u32 resume_target) { resume_at#(subgroup_local_id) = resume_target; // resume_with#(subgroup_local_id) = subgroup_active_mask(); // only one thread runs that part - if (subgroup_elect_first()) { + if (ext_instr["spirv.core", 333, varying bool](3)) { // bump the cursor // TODO bump it in a smarter way scheduler_cursor = (scheduler_cursor + u32 1) % actual_subgroup_size; @@ -103,18 +123,18 @@ fn builtin_yield(uniform u32 resume_target) { } } -@Internal @Structured @Leaf +@Internal @Exported fn builtin_join(varying u32 join_at, varying TreeNode token) { resume_at#(subgroup_local_id) = join_at; scheduler_vector#(subgroup_local_id) = token; // only one thread runs that part - if (subgroup_elect_first()) { + if (ext_instr["spirv.core", 333, varying bool](3)) { builtin_find_schedulable_leaf(); } } -@Internal @Structured @Leaf +@Internal fn is_parent bool(varying TreeNode child, varying TreeNode maybe_parent) { val child_mask = child#0; val parent_mask = maybe_parent#0; @@ -124,43 +144,43 @@ fn is_parent bool(varying TreeNode child, varying TreeNode maybe_parent) { return (child_depth >= parent_depth); } -@Internal @Structured @Leaf +@Internal fn forward_distance u32(varying u32 x, varying u32 dst, varying u32 max_mod) { var u32 t = dst - x; t = t % max_mod; return (t); } -@Internal @Structured @Leaf +@Internal fn reduce2 u32(varying u32 a_index, 
varying u32 b_index) { val a = scheduler_vector#a_index; val b = scheduler_vector#b_index; - + if (is_parent(a, b)) { return (a_index); } if (is_parent(b, a)) { return (b_index); } - + val a_dist = forward_distance(a_index, scheduler_cursor, actual_subgroup_size); val b_dist = forward_distance(b_index, scheduler_cursor, actual_subgroup_size); - + if (a_dist < b_dist) { return (a_index); } return (b_index); } -@Internal @Structured @Leaf +@Internal fn builtin_find_schedulable_leaf() { var u32 reduced = u32 0; - loop (uniform u32 i = u32 1) { + loop (varying u32 i = u32 1) { if (i >= actual_subgroup_size) { break; } reduced = reduce2(reduced, i); continue(i + u32 1); } - next_fn = subgroup_broadcast_first(resume_at#reduced); - active_branch = subgroup_broadcast_first(scheduler_vector#reduced); + next_fn = ext_instr["spirv.core", 338, uniform u32](3, resume_at#reduced); + active_branch = ext_instr["spirv.core", 338, uniform TreeNode](3, scheduler_vector#reduced); return (); } -@Internal @Structured @Leaf +@Internal @Exported fn builtin_get_active_branch mask_t() { val this_thread_branch = scheduler_vector#(subgroup_local_id); val same_dest = resume_at#(subgroup_local_id) == next_fn; diff --git a/src/shady/ir.c b/src/shady/ir.c index f49260c40..b32f5abe9 100644 --- a/src/shady/ir.c +++ b/src/shady/ir.c @@ -10,129 +10,135 @@ #include #include -static KeyHash hash_nodes(Nodes* nodes); -bool compare_nodes(Nodes* a, Nodes* b); +static KeyHash shd_hash_nodes(Nodes* nodes); +bool shd_compare_nodes(Nodes* a, Nodes* b); -static KeyHash hash_strings(Strings* strings); -static bool compare_strings(Strings* a, Strings* b); +static KeyHash shd_hash_strings(Strings* strings); +static bool shd_compare_strings(Strings* a, Strings* b); -KeyHash hash_string(const char** string); -bool compare_string(const char** a, const char** b); +KeyHash shd_hash_string(const char** string); +bool shd_compare_string(const char** a, const char** b); -KeyHash hash_node(const Node**); -bool 
compare_node(const Node** a, const Node** b); +KeyHash shd_hash_node(const Node**); +bool shd_compare_node(const Node** a, const Node** b); -IrArena* new_ir_arena(ArenaConfig config) { +IrArena* shd_new_ir_arena(const ArenaConfig* config) { IrArena* arena = malloc(sizeof(IrArena)); *arena = (IrArena) { - .arena = new_arena(), - .config = config, + .arena = shd_new_arena(), + .config = *config, - .next_free_id = 0, + .modules = shd_new_list(Module*), - .modules = new_list(Module*), + .node_set = shd_new_set(const Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .string_set = shd_new_set(const char*, (HashFn) shd_hash_string, (CmpFn) shd_compare_string), - .node_set = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node), - .string_set = new_set(const char*, (HashFn) hash_string, (CmpFn) compare_string), + .nodes_set = shd_new_set(Nodes, (HashFn) shd_hash_nodes, (CmpFn) shd_compare_nodes), + .strings_set = shd_new_set(Strings, (HashFn) shd_hash_strings, (CmpFn) shd_compare_strings), - .nodes_set = new_set(Nodes, (HashFn) hash_nodes, (CmpFn) compare_nodes), - .strings_set = new_set(Strings, (HashFn) hash_strings, (CmpFn) compare_strings), + .ids = shd_new_growy(), }; return arena; } -void destroy_ir_arena(IrArena* arena) { - for (size_t i = 0; i < entries_count_list(arena->modules); i++) { - destroy_module(read_list(Module*, arena->modules)[i]); +const Node* shd_get_node_by_id(const IrArena* a, NodeId id) { + return ((const Node**) shd_growy_data(a->ids))[id]; +} + +void shd_destroy_ir_arena(IrArena* arena) { + for (size_t i = 0; i < shd_list_count(arena->modules); i++) { + shd_destroy_module(shd_read_list(Module*, arena->modules)[i]); } - destroy_list(arena->modules); - destroy_dict(arena->strings_set); - destroy_dict(arena->string_set); - destroy_dict(arena->nodes_set); - destroy_dict(arena->node_set); - destroy_arena(arena->arena); + shd_destroy_list(arena->modules); + shd_destroy_dict(arena->strings_set); + shd_destroy_dict(arena->string_set); 
+ shd_destroy_dict(arena->nodes_set); + shd_destroy_dict(arena->node_set); + shd_destroy_arena(arena->arena); + shd_destroy_growy(arena->ids); free(arena); } -ArenaConfig get_arena_config(const IrArena* a) { - return a->config; +const ArenaConfig* shd_get_arena_config(const IrArena* a) { + return &a->config; } -VarId fresh_id(IrArena* arena) { - return arena->next_free_id++; +NodeId _shd_allocate_node_id(IrArena* arena, const Node* n) { + shd_growy_append_object(arena->ids, n); + return shd_growy_size(arena->ids) / sizeof(const Node*); } -Nodes nodes(IrArena* arena, size_t count, const Node* in_nodes[]) { +Nodes shd_nodes(IrArena* arena, size_t count, const Node* in_nodes[]) { Nodes tmp = { .count = count, .nodes = in_nodes }; - const Nodes* found = find_key_dict(Nodes, arena->nodes_set, tmp); + const Nodes* found = shd_dict_find_key(Nodes, arena->nodes_set, tmp); if (found) return *found; Nodes nodes; nodes.count = count; - nodes.nodes = arena_alloc(arena->arena, sizeof(Node*) * count); + nodes.nodes = shd_arena_alloc(arena->arena, sizeof(Node*) * count); for (size_t i = 0; i < count; i++) nodes.nodes[i] = in_nodes[i]; - insert_set_get_result(Nodes, arena->nodes_set, nodes); + shd_set_insert_get_result(Nodes, arena->nodes_set, nodes); return nodes; } -Strings strings(IrArena* arena, size_t count, const char* in_strs[]) { +Strings shd_strings(IrArena* arena, size_t count, const char* in_strs[]) { Strings tmp = { .count = count, .strings = in_strs, }; - const Strings* found = find_key_dict(Strings, arena->strings_set, tmp); + const Strings* found = shd_dict_find_key(Strings, arena->strings_set, tmp); if (found) return *found; Strings strings; strings.count = count; - strings.strings = arena_alloc(arena->arena, sizeof(const char*) * count); + strings.strings = shd_arena_alloc(arena->arena, sizeof(const char*) * count); for (size_t i = 0; i < count; i++) strings.strings[i] = in_strs[i]; - insert_set_get_result(Strings, arena->strings_set, strings); + 
shd_set_insert_get_result(Strings, arena->strings_set, strings); return strings; } -Nodes empty(IrArena* a) { - return nodes(a, 0, NULL); +Nodes shd_empty(IrArena* a) { + return shd_nodes(a, 0, NULL); } -Nodes singleton(const Type* type) { - IrArena* arena = type->arena; - const Type* arr[] = { type }; - return nodes(arena, 1, arr); +Nodes shd_singleton(const Node* n) { + IrArena* arena = n->arena; + const Type* arr[] = { n }; + return shd_nodes(arena, 1, arr); } -const Node* first(Nodes nodes) { +const Node* shd_first(Nodes nodes) { assert(nodes.count > 0); return nodes.nodes[0]; } -Nodes append_nodes(IrArena* arena, Nodes old, const Node* new) { +Nodes shd_nodes_append(IrArena* arena, Nodes old, const Node* new) { LARRAY(const Node*, tmp, old.count + 1); for (size_t i = 0; i < old.count; i++) tmp[i] = old.nodes[i]; tmp[old.count] = new; - return nodes(arena, old.count + 1, tmp); + return shd_nodes(arena, old.count + 1, tmp); } -Nodes prepend_nodes(IrArena* arena, Nodes old, const Node* new) { +Nodes shd_nodes_prepend(IrArena* arena, Nodes old, const Node* new) { LARRAY(const Node*, tmp, old.count + 1); for (size_t i = 0; i < old.count; i++) tmp[i + 1] = old.nodes[i]; tmp[0] = new; - return nodes(arena, old.count + 1, tmp); + return shd_nodes(arena, old.count + 1, tmp); } -Nodes concat_nodes(IrArena* arena, Nodes a, Nodes b) { +Nodes shd_concat_nodes(IrArena* arena, Nodes a, Nodes b) { LARRAY(const Node*, tmp, a.count + b.count); size_t j = 0; for (size_t i = 0; i < a.count; i++) @@ -140,15 +146,22 @@ Nodes concat_nodes(IrArena* arena, Nodes a, Nodes b) { for (size_t i = 0; i < b.count; i++) tmp[j++] = b.nodes[i]; assert(j == a.count + b.count); - return nodes(arena, j, tmp); + return shd_nodes(arena, j, tmp); } -Nodes change_node_at_index(IrArena* arena, Nodes old, size_t i, const Node* n) { +Nodes shd_change_node_at_index(IrArena* arena, Nodes old, size_t i, const Node* n) { LARRAY(const Node*, tmp, old.count); for (size_t j = 0; j < old.count; j++) - tmp[j] = 
old.nodes[i]; + tmp[j] = old.nodes[j]; tmp[i] = n; - return nodes(arena, old.count, tmp); + return shd_nodes(arena, old.count, tmp); +} + +bool shd_find_in_nodes(Nodes nodes, const Node* n) { + for (size_t i = 0; i < nodes.count; i++) + if (nodes.nodes[i] == n) + return true; + return false; } /// takes care of structural sharing @@ -156,19 +169,19 @@ static const char* string_impl(IrArena* arena, size_t size, const char* zero_ter if (!zero_terminated) return NULL; const char* ptr = zero_terminated; - const char** found = find_key_dict(const char*, arena->string_set, ptr); + const char** found = shd_dict_find_key(const char*, arena->string_set, ptr); if (found) return *found; - char* new_str = (char*) arena_alloc(arena->arena, strlen(zero_terminated) + 1); + char* new_str = (char*) shd_arena_alloc(arena->arena, strlen(zero_terminated) + 1); strncpy(new_str, zero_terminated, size); new_str[size] = '\0'; - insert_set_get_result(const char*, arena->string_set, new_str); + shd_set_insert_get_result(const char*, arena->string_set, new_str); return new_str; } -const char* string_sized(IrArena* arena, size_t size, const char* str) { +const char* shd_string_sized(IrArena* arena, size_t size, const char* str) { LARRAY(char, new_str, size + 1); strncpy(new_str, str, size); @@ -177,78 +190,81 @@ const char* string_sized(IrArena* arena, size_t size, const char* str) { return string_impl(arena, size, str); } -const char* string(IrArena* arena, const char* str) { +const char* shd_string(IrArena* arena, const char* str) { if (!str) return NULL; return string_impl(arena, strlen(str), str); } // TODO merge with strings() -Strings import_strings(IrArena* dst_arena, Strings old_strings) { +Strings _shd_import_strings(IrArena* dst_arena, Strings old_strings) { size_t count = old_strings.count; LARRAY(String, arr, count); for (size_t i = 0; i < count; i++) - arr[i] = string(dst_arena, old_strings.strings[i]); - return strings(dst_arena, count, arr); + arr[i] = shd_string(dst_arena, 
old_strings.strings[i]); + return shd_strings(dst_arena, count, arr); } -void format_string_internal(const char* str, va_list args, void* uptr, void callback(void*, size_t, char*)); +void shd_format_string_internal(const char* str, va_list args, void* uptr, void callback(void*, size_t, char*)); -typedef struct { IrArena* a; const char** result; } InternInArenaPayload; +typedef struct { + IrArena* a; + const char** result; +} InternInArenaPayload; static void intern_in_arena(InternInArenaPayload* uptr, size_t len, char* tmp) { const char* interned = string_impl(uptr->a, len, tmp); *uptr->result = interned; } -String format_string_interned(IrArena* arena, const char* str, ...) { +String shd_fmt_string_irarena(IrArena* arena, const char* str, ...) { String result = NULL; InternInArenaPayload p = { .a = arena, .result = &result }; va_list args; va_start(args, str); - format_string_internal(str, args, &p, (void(*)(void*, size_t, char*)) intern_in_arena); + shd_format_string_internal(str, args, &p, (void (*)(void*, size_t, char*)) intern_in_arena); va_end(args); return result; } -const char* unique_name(IrArena* arena, const char* str) { - return format_string_interned(arena, "%s_%d", str, fresh_id(arena)); +const char* shd_make_unique_name(IrArena* arena, const char* str) { + return shd_fmt_string_irarena(arena, "%s_%d", str, _shd_allocate_node_id(arena, NULL)); } -KeyHash hash_nodes(Nodes* nodes) { - return hash_murmur(nodes->nodes, sizeof(const Node*) * nodes->count); +KeyHash shd_hash_nodes(Nodes* nodes) { + return shd_hash(nodes->nodes, sizeof(const Node*) * nodes->count); } -bool compare_nodes(Nodes* a, Nodes* b) { +bool shd_compare_nodes(Nodes* a, Nodes* b) { if (a->count != b->count) return false; if (a->count == 0) return true; assert(a->nodes != NULL && b->nodes != NULL); return memcmp(a->nodes, b->nodes, sizeof(Node*) * (a->count)) == 0; // actually compare the data } -KeyHash hash_strings(Strings* strings) { - return hash_murmur(strings->strings, 
sizeof(char*) * strings->count); +KeyHash shd_hash_strings(Strings* strings) { + return shd_hash(strings->strings, sizeof(char*) * strings->count); } -bool compare_strings(Strings* a, Strings* b) { +bool shd_compare_strings(Strings* a, Strings* b) { if (a->count != b->count) return false; if (a->count == 0) return true; assert(a->strings != NULL && b->strings != NULL); return memcmp(a->strings, b->strings, sizeof(const char*) * a->count) == 0; } -KeyHash hash_string(const char** string) { +KeyHash shd_hash_string(const char** string) { if (!*string) return 0; - return hash_murmur(*string, strlen(*string)); + return shd_hash(*string, strlen(*string)); } -bool compare_string(const char** a, const char** b) { +bool shd_compare_string(const char** a, const char** b) { if (*a == NULL || *b == NULL) return (!*a) == (!*b); return strlen(*a) == strlen(*b) && strcmp(*a, *b) == 0; } -Nodes list_to_nodes(IrArena* arena, struct List* list) { - return nodes(arena, entries_count_list(list), read_list(const Node*, list)); +Nodes shd_list_to_nodes(IrArena* arena, struct List* list) { + return shd_nodes(arena, shd_list_count(list), shd_read_list(const Node*, list)); } diff --git a/src/shady/ir/CMakeLists.txt b/src/shady/ir/CMakeLists.txt new file mode 100644 index 000000000..875f9e8b8 --- /dev/null +++ b/src/shady/ir/CMakeLists.txt @@ -0,0 +1,18 @@ +target_sources(shady PRIVATE + annotation.c + mem.c + module.c + int.c + float.c + composite.c + function.c + decl.c + builtin.c + type.c + grammar.c + debug.c + memory_layout.c + stack.c + cast.c + ext.c +) \ No newline at end of file diff --git a/src/shady/ir/annotation.c b/src/shady/ir/annotation.c new file mode 100644 index 000000000..fb56b03bf --- /dev/null +++ b/src/shady/ir/annotation.c @@ -0,0 +1,86 @@ +#include "ir_private.h" + +#include "log.h" +#include "portability.h" + +#include +#include + +static const Node* search_annotations(const Node* decl, const char* name, size_t* i) { + assert(decl); + const Nodes annotations = 
get_declaration_annotations(decl); + while (*i < annotations.count) { + const Node* annotation = annotations.nodes[*i]; + (*i)++; + if (strcmp(get_annotation_name(annotation), name) == 0) { + return annotation; + } + } + + return NULL; +} + +const Node* shd_lookup_annotation(const Node* decl, const char* name) { + size_t i = 0; + return search_annotations(decl, name, &i); +} + +const Node* shd_lookup_annotation_list(Nodes annotations, const char* name) { + for (size_t i = 0; i < annotations.count; i++) { + if (strcmp(get_annotation_name(annotations.nodes[i]), name) == 0) { + return annotations.nodes[i]; + } + } + return NULL; +} + +const Node* shd_get_annotation_value(const Node* annotation) { + assert(annotation); + if (annotation->tag != AnnotationValue_TAG) + shd_error("This annotation does not have a single payload"); + return annotation->payload.annotation_value.value; +} + +Nodes shd_get_annotation_values(const Node* annotation) { + assert(annotation); + if (annotation->tag != AnnotationValues_TAG) + shd_error("This annotation does not have multiple payloads"); + return annotation->payload.annotation_values.values; +} + +/// Gets the string literal attached to an annotation, if present. 
+const char* shd_get_annotation_string_payload(const Node* annotation) { + const Node* payload = shd_get_annotation_value(annotation); + if (!payload) return NULL; + if (payload->tag != StringLiteral_TAG) + shd_error("Wrong annotation payload tag, expected a string literal") + return payload->payload.string_lit.string; +} + +bool shd_lookup_annotation_with_string_payload(const Node* decl, const char* annotation_name, const char* expected_payload) { + size_t i = 0; + while (true) { + const Node* next = search_annotations(decl, annotation_name, &i); + if (!next) return false; + if (strcmp(shd_get_annotation_string_payload(next), expected_payload) == 0) + return true; + } +} + +Nodes shd_filter_out_annotation(IrArena* arena, Nodes annotations, const char* name) { + LARRAY(const Node*, new_annotations, annotations.count); + size_t new_count = 0; + for (size_t i = 0; i < annotations.count; i++) { + if (strcmp(get_annotation_name(annotations.nodes[i]), name) != 0) { + new_annotations[new_count++] = annotations.nodes[i]; + } + } + return shd_nodes(arena, new_count, new_annotations); +} + +ExecutionModel shd_execution_model_from_string(const char* string) { +#define EM(n, _) if (strcmp(string, #n) == 0) return Em##n; + EXECUTION_MODELS(EM) +#undef EM + return EmNone; +} diff --git a/src/shady/ir/builtin.c b/src/shady/ir/builtin.c new file mode 100644 index 000000000..4d7ab6b3c --- /dev/null +++ b/src/shady/ir/builtin.c @@ -0,0 +1,148 @@ +#include "shady/ir/builtin.h" +#include "shady/ir/builder.h" +#include "shady/ir/annotation.h" +#include "shady/ir/module.h" +#include "shady/ir/mem.h" +#include "shady/ir/decl.h" + +#include "log.h" +#include "portability.h" + +#include + +#include + +static AddressSpace builtin_as[] = { +#define BUILTIN(_, as, _2) as, +SHADY_BUILTINS() +#undef BUILTIN +}; + +AddressSpace shd_get_builtin_address_space(Builtin builtin) { + if (builtin >= BuiltinsCount) + return AsGeneric; + return builtin_as[builtin]; +} + +static String builtin_names[] = 
{ +#define BUILTIN(name, _, _2) #name, +SHADY_BUILTINS() +#undef BUILTIN +}; + +String shd_get_builtin_name(Builtin builtin) { + if (builtin >= BuiltinsCount) + return ""; + return builtin_names[builtin]; +} + +const Type* shd_get_builtin_type(IrArena* arena, Builtin builtin) { + switch (builtin) { +#define BUILTIN(name, _, datatype) case Builtin##name: return datatype; +SHADY_BUILTINS() +#undef BUILTIN + default: shd_error("Unhandled builtin") + } +} + +// What's the decoration for the builtin +static SpvBuiltIn spv_builtins[] = { +#define BUILTIN(name, _, _2) SpvBuiltIn##name, +SHADY_BUILTINS() +#undef BUILTIN +}; + +Builtin shd_get_builtin_by_name(String s) { + for (size_t i = 0; i < BuiltinsCount; i++) { + if (strcmp(s, builtin_names[i]) == 0) { + return i; + } + } + return BuiltinsCount; +} + +Builtin shd_get_builtin_by_spv_id(SpvBuiltIn id) { + Builtin b = BuiltinsCount; + for (size_t i = 0; i < BuiltinsCount; i++) { + if (id == spv_builtins[i]) { + b = i; + break; + } + } + return b; +} + +Builtin shd_get_decl_builtin(const Node* decl) { + const Node* a = shd_lookup_annotation(decl, "Builtin"); + if (!a) + return BuiltinsCount; + String payload = shd_get_annotation_string_payload(a); + return shd_get_builtin_by_name(payload); +} + + +bool shd_is_decl_builtin(const Node* decl) { + return shd_get_decl_builtin(decl) != BuiltinsCount; +} + +int32_t shd_get_builtin_spv_id(Builtin builtin) { + if (builtin >= BuiltinsCount) + return 0; + return spv_builtins[builtin]; +} + +bool shd_is_builtin_load_op(const Node* n, Builtin* out) { + assert(is_instruction(n)); + if (n->tag == Load_TAG) { + const Node* src = n->payload.load.ptr; + if (src->tag == RefDecl_TAG) + src = src->payload.ref_decl.decl; + if (src->tag == GlobalVariable_TAG) { + const Node* a = shd_lookup_annotation(src, "Builtin"); + if (a) { + String bn = shd_get_annotation_string_payload(a); + assert(bn); + Builtin b = shd_get_builtin_by_name(bn); + if (b != BuiltinsCount) { + *out = b; + return true; + } + 
} + } + } + return false; +} + +const Node* shd_get_builtin(Module* m, Builtin b) { + Nodes decls = shd_module_get_declarations(m); + for (size_t i = 0; i < decls.count; i++) { + const Node* decl = decls.nodes[i]; + if (decl->tag != GlobalVariable_TAG) + continue; + const Node* a = shd_lookup_annotation(decl, "Builtin"); + if (!a) + continue; + String builtin_name = shd_get_annotation_string_payload(a); + assert(builtin_name); + if (strcmp(builtin_name, shd_get_builtin_name(b)) == 0) + return decl; + } + + return NULL; +} + +const Node* shd_get_or_create_builtin(Module* m, Builtin b, String n) { + const Node* decl = shd_get_builtin(m, b); + if (decl) + return decl; + + AddressSpace as = shd_get_builtin_address_space(b); + IrArena* a = shd_module_get_arena(m); + decl = global_var(m, shd_singleton(annotation_value_helper(a, "Builtin", string_lit_helper(a,shd_get_builtin_name(b)))), + shd_get_builtin_type(a, b), n ? n : shd_fmt_string_irarena(a, "builtin_%s", shd_get_builtin_name(b)), as); + return decl; +} + +const Node* shd_bld_builtin_load(Module* m, BodyBuilder* bb, Builtin b) { + return shd_bld_load(bb, ref_decl_helper(shd_module_get_arena(m), shd_get_or_create_builtin(m, b, NULL))); +} diff --git a/src/shady/ir/cast.c b/src/shady/ir/cast.c new file mode 100644 index 000000000..a4d7f845d --- /dev/null +++ b/src/shady/ir/cast.c @@ -0,0 +1,59 @@ +#include "shady/ir/cast.h" +#include "shady/ir/grammar.h" +#include "shady/ir/type.h" +#include "shady/ir/memory_layout.h" + +#include + +const Node* shd_bld_reinterpret_cast(BodyBuilder* bb, const Type* dst, const Node* src) { + assert(is_type(dst)); + return prim_op(shd_get_bb_arena(bb), (PrimOp) { .op = reinterpret_op, .operands = shd_singleton(src), .type_arguments = shd_singleton(dst)}); +} + +const Node* shd_bld_conversion(BodyBuilder* bb, const Type* dst, const Node* src) { + assert(is_type(dst)); + return prim_op(shd_get_bb_arena(bb), (PrimOp) { .op = convert_op, .operands = shd_singleton(src), .type_arguments = 
shd_singleton(dst)}); +} + +bool shd_is_reinterpret_cast_legal(const Type* src_type, const Type* dst_type) { + assert(shd_is_data_type(src_type) && shd_is_data_type(dst_type)); + if (src_type == dst_type) + return true; // folding will eliminate those, but we need to pass type-checking first :) + if (!(shd_is_arithm_type(src_type) || src_type->tag == MaskType_TAG || shd_is_physical_ptr_type(src_type))) + return false; + if (!(shd_is_arithm_type(dst_type) || dst_type->tag == MaskType_TAG || shd_is_physical_ptr_type(dst_type))) + return false; + assert(shd_get_type_bitwidth(src_type) == shd_get_type_bitwidth(dst_type)); + // either both pointers need to be in the generic address space, and we're only casting the element type, OR neither can be + if ((shd_is_physical_ptr_type(src_type) && shd_is_physical_ptr_type(dst_type)) && (shd_is_generic_ptr_type(src_type) != shd_is_generic_ptr_type(dst_type))) + return false; + return true; +} + +bool shd_is_conversion_legal(const Type* src_type, const Type* dst_type) { + assert(shd_is_data_type(src_type) && shd_is_data_type(dst_type)); + if (!(shd_is_arithm_type(src_type) || (shd_is_physical_ptr_type(src_type) && shd_get_type_bitwidth(src_type) == shd_get_type_bitwidth(dst_type)))) + return false; + if (!(shd_is_arithm_type(dst_type) || (shd_is_physical_ptr_type(dst_type) && shd_get_type_bitwidth(src_type) == shd_get_type_bitwidth(dst_type)))) + return false; + // we only allow ptr-ptr conversions, use reinterpret otherwise + if (shd_is_physical_ptr_type(src_type) != shd_is_physical_ptr_type(dst_type)) + return false; + // exactly one of the pointers needs to be in the generic address space + if (shd_is_generic_ptr_type(src_type) && shd_is_generic_ptr_type(dst_type)) + return false; + if (src_type->tag == Int_TAG && dst_type->tag == Int_TAG) { + bool changes_sign = src_type->payload.int_type.is_signed != dst_type->payload.int_type.is_signed; + bool changes_width = src_type->payload.int_type.width != 
dst_type->payload.int_type.width; + if (changes_sign && changes_width) + return false; + } + // element types have to match (use reinterpret_cast for changing it) + if (shd_is_physical_ptr_type(src_type) && shd_is_physical_ptr_type(dst_type)) { + AddressSpace src_as = src_type->payload.ptr_type.address_space; + AddressSpace dst_as = dst_type->payload.ptr_type.address_space; + if (src_type->payload.ptr_type.pointed_type != dst_type->payload.ptr_type.pointed_type) + return false; + } + return true; +} diff --git a/src/shady/ir/composite.c b/src/shady/ir/composite.c new file mode 100644 index 000000000..281e001c7 --- /dev/null +++ b/src/shady/ir/composite.c @@ -0,0 +1,105 @@ +#include "shady/ir/composite.h" + +#include "ir_private.h" + +#include "log.h" +#include "portability.h" + +#include + +const Node* shd_extract_helper(IrArena* a, const Node* base, Nodes selectors) { + LARRAY(const Node*, ops, 1 + selectors.count); + ops[0] = base; + for (size_t i = 0; i < selectors.count; i++) + ops[1 + i] = selectors.nodes[i]; + return prim_op_helper(a, extract_op, shd_empty(a), shd_nodes(a, 1 + selectors.count, ops)); +} + +const Node* shd_extract_single_helper(IrArena* a, const Node* composite, const Node* index) { + return prim_op_helper(a, extract_op, shd_empty(a), mk_nodes(a, composite, index)); +} + +const Node* shd_maybe_tuple_helper(IrArena* a, Nodes values) { + if (values.count == 1) + return shd_first(values); + return shd_tuple_helper(a, values); +} + +const Node* shd_tuple_helper(IrArena* a, Nodes contents) { + const Type* t = NULL; + if (a->config.check_types) { + // infer the type of the tuple + Nodes member_types = shd_get_values_types(a, contents); + t = record_type(a, (RecordType) {.members = shd_strip_qualifiers(a, member_types)}); + } + + return composite_helper(a, t, contents); +} + +void shd_enter_composite_type(const Type** datatype, bool* uniform, const Node* selector, bool allow_entering_pack) { + const Type* current_type = *datatype; + + if 
(selector->arena->config.check_types) { + const Type* selector_type = selector->type; + bool selector_uniform = shd_deconstruct_qualified_type(&selector_type); + assert(selector_type->tag == Int_TAG && "selectors must be integers"); + *uniform &= selector_uniform; + } + + try_again: + switch (current_type->tag) { + case RecordType_TAG: { + size_t selector_value = shd_get_int_literal_value(*shd_resolve_to_int_literal(selector), false); + assert(selector_value < current_type->payload.record_type.members.count); + current_type = current_type->payload.record_type.members.nodes[selector_value]; + break; + } + case ArrType_TAG: { + current_type = current_type->payload.arr_type.element_type; + break; + } + case TypeDeclRef_TAG: { + const Node* nom_decl = current_type->payload.type_decl_ref.decl; + assert(nom_decl->tag == NominalType_TAG); + current_type = nom_decl->payload.nom_type.body; + goto try_again; + } + case PackType_TAG: { + assert(allow_entering_pack); + assert(selector->tag == IntLiteral_TAG && "selectors when indexing into a pack type need to be constant"); + size_t selector_value = shd_get_int_literal_value(*shd_resolve_to_int_literal(selector), false); + assert(selector_value < current_type->payload.pack_type.width); + current_type = current_type->payload.pack_type.element_type; + break; + } + // also remember to assert literals for the selectors ! 
+ default: { + shd_log_fmt(ERROR, "Trying to enter non-composite type '"); + shd_log_node(ERROR, current_type); + shd_log_fmt(ERROR, "' with selector '"); + shd_log_node(ERROR, selector); + shd_log_fmt(ERROR, "'."); + shd_error(""); + } + } + *datatype = current_type; +} + +void shd_enter_composite_type_indices(const Type** datatype, bool* uniform, Nodes indices, bool allow_entering_pack) { + for(size_t i = 0; i < indices.count; i++) { + const Node* selector = indices.nodes[i]; + shd_enter_composite_type(datatype, uniform, selector, allow_entering_pack); + } +} + +Nodes shd_deconstruct_composite(IrArena* a, const Node* value, size_t outputs_count) { + if (outputs_count > 1) { + LARRAY(const Node*, extracted, outputs_count); + for (size_t i = 0; i < outputs_count; i++) + extracted[i] = shd_extract_single_helper(a, value, shd_int32_literal(a, i)); + return shd_nodes(a, outputs_count, extracted); + } else if (outputs_count == 1) + return shd_singleton(value); + else + return shd_empty(a); +} diff --git a/src/shady/ir/debug.c b/src/shady/ir/debug.c new file mode 100644 index 000000000..92584c192 --- /dev/null +++ b/src/shady/ir/debug.c @@ -0,0 +1,35 @@ +#include "shady/ir/debug.h" +#include "shady/ir/grammar.h" + +#include +#include + +String shd_get_value_name_unsafe(const Node* v) { + assert(v && is_value(v)); + if (v->tag == Param_TAG) + return v->payload.param.name; + return NULL; +} + +String shd_get_value_name_safe(const Node* v) { + String name = shd_get_value_name_unsafe(v); + if (name && strlen(name) > 0) + return name; + //if (v->tag == Variable_TAG) + return shd_fmt_string_irarena(v->arena, "%%%d", v->id); + //return node_tags[v->tag]; +} + +void shd_set_value_name(const Node* var, String name) { + // TODO: annotations + // if (var->tag == Variablez_TAG) + // var->payload.varz.name = shd_string(var->arena, name); +} + +void shd_bld_comment(BodyBuilder* bb, String str) { + shd_bld_add_instruction_extract(bb, comment(shd_get_bb_arena(bb), (Comment) { .string = 
str, .mem = shd_bb_mem(bb) })); +} + +void shd_bld_debug_printf(BodyBuilder* bb, String pattern, Nodes args) { + shd_bld_add_instruction(bb, debug_printf(shd_get_bb_arena(bb), (DebugPrintf) { .string = pattern, .args = args, .mem = shd_bb_mem(bb) })); +} diff --git a/src/shady/ir/decl.c b/src/shady/ir/decl.c new file mode 100644 index 000000000..50add0afc --- /dev/null +++ b/src/shady/ir/decl.c @@ -0,0 +1,70 @@ +#include "shady/ir/decl.h" +#include "shady/rewrite.h" + +#include "../ir_private.h" + +#include + +bool shd_compare_nodes(Nodes* a, Nodes* b); + +Node* _shd_constant(Module* mod, Nodes annotations, const Type* hint, String name) { + IrArena* arena = mod->arena; + Constant cnst = { + .annotations = annotations, + .name = shd_string(arena, name), + .type_hint = hint, + .value = NULL, + }; + Node node; + memset((void*) &node, 0, sizeof(Node)); + node = (Node) { + .arena = arena, + .tag = Constant_TAG, + .payload.constant = cnst + }; + Node* decl = _shd_create_node_helper(arena, node, NULL); + _shd_module_add_decl(mod, decl); + return decl; +} + +Node* _shd_global_var(Module* mod, Nodes annotations, const Type* type, const char* name, AddressSpace as) { + const Node* existing = shd_module_get_declaration(mod, name); + if (existing) { + assert(existing->tag == GlobalVariable_TAG); + assert(existing->payload.global_variable.type == type); + assert(existing->payload.global_variable.address_space == as); + assert(!mod->arena->config.check_types || shd_compare_nodes((Nodes*) &existing->payload.global_variable.annotations, &annotations)); + return (Node*) existing; + } + + IrArena* arena = mod->arena; + GlobalVariable gvar = { + .annotations = annotations, + .name = shd_string(arena, name), + .type = type, + .address_space = as, + .init = NULL, + }; + + Node node; + memset((void*) &node, 0, sizeof(Node)); + node = (Node) { + .arena = arena, + .tag = GlobalVariable_TAG, + .payload.global_variable = gvar + }; + Node* decl = _shd_create_node_helper(arena, node, NULL); 
+ _shd_module_add_decl(mod, decl); + return decl; +} + +const Node* shd_find_or_process_decl(Rewriter* rewriter, const char* name) { + Nodes old_decls = shd_module_get_declarations(rewriter->src_module); + for (size_t i = 0; i < old_decls.count; i++) { + const Node* decl = old_decls.nodes[i]; + if (strcmp(get_declaration_name(decl), name) == 0) { + return shd_rewrite_node(rewriter, decl); + } + } + assert(false); +} diff --git a/src/shady/ir/ext.c b/src/shady/ir/ext.c new file mode 100644 index 000000000..b10c8fcab --- /dev/null +++ b/src/shady/ir/ext.c @@ -0,0 +1,12 @@ +#include "shady/ir/ext.h" +#include "shady/ir/grammar.h" + +const Node* shd_bld_ext_instruction(BodyBuilder* bb, String set, int opcode, const Type* return_t, Nodes operands) { + return shd_bld_add_instruction(bb, ext_instr(shd_get_bb_arena(bb), (ExtInstr) { + .mem = shd_bb_mem(bb), + .set = set, + .opcode = opcode, + .result_t = return_t, + .operands = operands, + })); +} diff --git a/src/shady/ir/float.c b/src/shady/ir/float.c new file mode 100644 index 000000000..87361d3ba --- /dev/null +++ b/src/shady/ir/float.c @@ -0,0 +1,59 @@ +#include "shady/ir/float.h" + +#include "shady/analysis/literal.h" + +#include "log.h" + +#include + +const Type* shd_fp16_type(IrArena* arena) { return float_type(arena, (Float) { .width = FloatTy16 }); } +const Type* shd_fp32_type(IrArena* arena) { return float_type(arena, (Float) { .width = FloatTy32 }); } +const Type* shd_fp64_type(IrArena* arena) { return float_type(arena, (Float) { .width = FloatTy64 }); } + +const Node* shd_fp_literal_helper(IrArena* a, FloatSizes size, double value) { + switch (size) { + case FloatTy16: assert(false); break; + case FloatTy32: { + float f = value; + uint64_t bits = 0; + memcpy(&bits, &f, sizeof(f)); + return float_literal(a, (FloatLiteral) { .width = size, .value = bits }); + } + case FloatTy64: { + uint64_t bits = 0; + memcpy(&bits, &value, sizeof(value)); + return float_literal(a, (FloatLiteral) { .width = size, .value = bits 
}); + } + } +} + +const FloatLiteral* shd_resolve_to_float_literal(const Node* node) { + node = shd_resolve_node_to_definition(node, shd_default_node_resolve_config()); + if (!node) + return NULL; + if (node->tag == FloatLiteral_TAG) + return &node->payload.float_literal; + return NULL; +} + +static_assert(sizeof(float) == sizeof(uint64_t) / 2, "floats aren't the size we expect"); +double shd_get_float_literal_value(FloatLiteral literal) { + double r; + switch (literal.width) { + case FloatTy16: + shd_error_print("TODO: fp16 literals"); + shd_error_die(); + SHADY_UNREACHABLE; + break; + case FloatTy32: { + float f; + memcpy(&f, &literal.value, sizeof(float)); + r = (double) f; + break; + } + case FloatTy64: + memcpy(&r, &literal.value, sizeof(double)); + break; + } + return r; +} diff --git a/src/shady/ir/function.c b/src/shady/ir/function.c new file mode 100644 index 000000000..b5dee81cb --- /dev/null +++ b/src/shady/ir/function.c @@ -0,0 +1,150 @@ +#include "shady/ir/function.h" + +#include "../ir_private.h" + +#include + +Node* _shd_param(IrArena* arena, const Type* type, const char* name) { + Param param = { + .type = type, + .name = shd_string(arena, name), + }; + + Node node; + memset((void*) &node, 0, sizeof(Node)); + node = (Node) { + .arena = arena, + .tag = Param_TAG, + .payload.param = param + }; + return _shd_create_node_helper(arena, node, NULL); +} + +Node* _shd_function(Module* mod, Nodes params, const char* name, Nodes annotations, Nodes return_types) { + assert(!mod->sealed); + IrArena* arena = mod->arena; + Function payload = { + .module = mod, + .params = params, + .body = NULL, + .name = name, + .annotations = annotations, + .return_types = return_types, + }; + + Node node; + memset((void*) &node, 0, sizeof(Node)); + node = (Node) { + .arena = arena, + .tag = Function_TAG, + .payload.fun = payload + }; + Node* fn = _shd_create_node_helper(arena, node, NULL); + _shd_module_add_decl(mod, fn); + + for (size_t i = 0; i < params.count; i++) { + Node* 
param = (Node*) params.nodes[i]; + assert(param->tag == Param_TAG); + assert(!param->payload.param.abs); + param->payload.param.abs = fn; + param->payload.param.pindex = i; + } + + return fn; +} + +Node* _shd_basic_block(IrArena* arena, Nodes params, const char* name) { + BasicBlock payload = { + .params = params, + .body = NULL, + .name = name, + }; + + Node node; + memset((void*) &node, 0, sizeof(Node)); + node = (Node) { + .arena = arena, + .tag = BasicBlock_TAG, + .payload.basic_block = payload + }; + + Node* bb = _shd_create_node_helper(arena, node, NULL); + + for (size_t i = 0; i < params.count; i++) { + Node* param = (Node*) params.nodes[i]; + assert(param->tag == Param_TAG); + assert(!param->payload.param.abs); + param->payload.param.abs = bb; + param->payload.param.pindex = i; + } + + return bb; +} + +const Node* shd_get_abstraction_mem(const Node* abs) { + return abs_mem(abs->arena, (AbsMem) { .abs = abs }); +} + +String shd_get_abstraction_name(const Node* abs) { + assert(is_abstraction(abs)); + switch (abs->tag) { + case Function_TAG: return abs->payload.fun.name; + case BasicBlock_TAG: return abs->payload.basic_block.name; + default: assert(false); + } +} + +String shd_get_abstraction_name_unsafe(const Node* abs) { + assert(is_abstraction(abs)); + switch (abs->tag) { + case Function_TAG: return abs->payload.fun.name; + case BasicBlock_TAG: return abs->payload.basic_block.name; + default: assert(false); + } +} + +String shd_get_abstraction_name_safe(const Node* abs) { + String name = shd_get_abstraction_name_unsafe(abs); + if (name) + return name; + return shd_fmt_string_irarena(abs->arena, "%%%d", abs->id); +} + +void shd_set_abstraction_body(Node* abs, const Node* body) { + assert(is_abstraction(abs)); + assert(!body || is_terminator(body)); + IrArena* a = abs->arena; + + if (body) { + while (true) { + const Node* mem0 = shd_get_original_mem(get_terminator_mem(body)); + assert(mem0->tag == AbsMem_TAG); + Node* mem_abs = mem0->payload.abs_mem.abs; + if 
(is_basic_block(mem_abs)) { + BodyBuilder* insert = mem_abs->payload.basic_block.insert; + if (insert && mem_abs != abs) { + const Node* mem = _shd_bb_insert_mem(insert); + const Node* block = _shd_bb_insert_block(insert); + shd_set_abstraction_body((Node*) block, _shd_bld_finish_pseudo_instr(insert, body)); + body = jump_helper(a, mem, block, shd_empty(a)); + // mem_abs->payload.basic_block.insert = NULL; + continue; + } + } + assert(mem_abs == abs); + break; + } + } + + switch (abs->tag) { + case Function_TAG: abs->payload.fun.body = body; break; + case BasicBlock_TAG: abs->payload.basic_block.body = body; break; + default: assert(false); + } +} + +Nodes shd_bld_call(BodyBuilder* bb, const Node* callee, Nodes args) { + assert(shd_get_arena_config(shd_get_bb_arena(bb))->check_types); + const Node* instruction = call(shd_get_bb_arena(bb), (Call) { .callee = callee, .args = args, .mem = shd_bb_mem(bb) }); + return shd_bld_add_instruction_extract(bb, instruction); +} diff --git a/src/shady/ir/grammar.c b/src/shady/ir/grammar.c new file mode 100644 index 000000000..0cc2830c6 --- /dev/null +++ b/src/shady/ir/grammar.c @@ -0,0 +1,60 @@ +#include "shady/ir/grammar.h" + +#include "../fold.h" + +#include "log.h" +#include "portability.h" +#include "dict.h" + +#include +#include + +Strings _shd_import_strings(IrArena* dst_arena, Strings old_strings); + +static void pre_construction_validation(IrArena* arena, Node* node); + +const Node* _shd_fold_node_operand(NodeTag tag, NodeClass nc, String opname, const Node* op); + +const Type* _shd_check_type_generated(IrArena* a, const Node* node); + +Node* _shd_create_node_helper(IrArena* arena, Node node, bool* pfresh) { + pre_construction_validation(arena, &node); + if (arena->config.check_types) + node.type = _shd_check_type_generated(arena, &node); + + if (pfresh) + *pfresh = false; + + Node* ptr = &node; + Node** found = shd_dict_find_key(Node*, arena->node_set, ptr); + // sanity check nominal nodes to be unique, check for 
duplicates in structural nodes + if (shd_is_node_nominal(&node)) + assert(!found); + else if (found) + return *found; + + if (pfresh) + *pfresh = true; + + if (arena->config.allow_fold) { + Node* folded = (Node*) _shd_fold_node(arena, ptr); + if (folded != ptr) { + // The folding process simplified the node, we store a mapping to that simplified node and bail out ! + shd_set_insert_get_result(Node*, arena->node_set, folded); + return folded; + } + } + + if (arena->config.check_types && node.type) + assert(is_type(node.type)); + + // place the node in the arena and return it + Node* alloc = (Node*) shd_arena_alloc(arena->arena, sizeof(Node)); + *alloc = node; + alloc->id = _shd_allocate_node_id(arena, alloc); + shd_set_insert_get_result(const Node*, arena->node_set, alloc); + + return alloc; +} + +#include "../constructors_generated.c" diff --git a/src/shady/ir/int.c b/src/shady/ir/int.c new file mode 100644 index 000000000..b055f5648 --- /dev/null +++ b/src/shady/ir/int.c @@ -0,0 +1,123 @@ +#include "shady/ir/int.h" +#include "shady/ir/type.h" + +#include "shady/analysis/literal.h" + +#include "log.h" + +#include + +const Type* shd_int_type_helper(IrArena* a, bool s, IntSizes w) { return int_type(a, (Int) { .width = w, .is_signed = s }); } + +const Type* shd_int8_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy8 , .is_signed = true }); } +const Type* shd_int16_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy16, .is_signed = true }); } +const Type* shd_int32_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy32, .is_signed = true }); } +const Type* shd_int64_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy64, .is_signed = true }); } + +const Type* shd_uint8_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy8 , .is_signed = false }); } +const Type* shd_uint16_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy16, .is_signed = false }); } +const Type* 
shd_uint32_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy32, .is_signed = false }); } +const Type* shd_uint64_type(IrArena* arena) { return int_type(arena, (Int) { .width = IntTy64, .is_signed = false }); } + +const Node* shd_int8_literal (IrArena* arena, int8_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy8, .value = (uint64_t) (uint8_t) i, .is_signed = true }); } +const Node* shd_int16_literal(IrArena* arena, int16_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy16, .value = (uint64_t) (uint16_t) i, .is_signed = true }); } +const Node* shd_int32_literal(IrArena* arena, int32_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy32, .value = (uint64_t) (uint32_t) i, .is_signed = true }); } +const Node* shd_int64_literal(IrArena* arena, int64_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy64, .value = (uint64_t) i, .is_signed = true }); } + +const Node* shd_uint8_literal (IrArena* arena, uint8_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy8, .value = (int64_t) i }); } +const Node* shd_uint16_literal(IrArena* arena, uint16_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy16, .value = (int64_t) i }); } +const Node* shd_uint32_literal(IrArena* arena, uint32_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy32, .value = (int64_t) i }); } +const Node* shd_uint64_literal(IrArena* arena, uint64_t i) { return int_literal(arena, (IntLiteral) { .width = IntTy64, .value = i }); } + +const IntLiteral* shd_resolve_to_int_literal(const Node* node) { + node = shd_resolve_node_to_definition(node, shd_default_node_resolve_config()); + if (!node) + return NULL; + if (node->tag == IntLiteral_TAG) + return &node->payload.int_literal; + return NULL; +} + +int64_t shd_get_int_literal_value(IntLiteral literal, bool sign_extend) { + if (sign_extend) { + switch (literal.width) { + case IntTy8: return (int64_t) (int8_t) (literal.value & 0xFF); + case IntTy16: return (int64_t) 
(int16_t) (literal.value & 0xFFFF); + case IntTy32: return (int64_t) (int32_t) (literal.value & 0xFFFFFFFF); + case IntTy64: return (int64_t) literal.value; + default: assert(false); + } + } else { + switch (literal.width) { + case IntTy8: return literal.value & 0xFF; + case IntTy16: return literal.value & 0xFFFF; + case IntTy32: return literal.value & 0xFFFFFFFF; + case IntTy64: return literal.value; + default: assert(false); + } + } +} + +int64_t shd_get_int_value(const Node* node, bool sign_extend) { + const IntLiteral* lit = shd_resolve_to_int_literal(node); + if (!lit) shd_error("Not a literal"); + return shd_get_int_literal_value(*lit, sign_extend); +} + +const Node* shd_bld_convert_int_extend_according_to_src_t(BodyBuilder* bb, const Type* dst_type, const Node* src) { + IrArena* a = shd_get_bb_arena(bb); + const Type* src_type = shd_get_unqualified_type(src->type); + assert(src_type->tag == Int_TAG); + assert(dst_type->tag == Int_TAG); + + // first convert to final bitsize then bitcast + const Type* extended_src_t = int_type(shd_get_bb_arena(bb), (Int) { .width = dst_type->payload.int_type.width, .is_signed = src_type->payload.int_type.is_signed }); + const Node* val = src; + val = prim_op_helper(a, convert_op, shd_singleton(extended_src_t), shd_singleton(val)); + val = prim_op_helper(a, reinterpret_op, shd_singleton(dst_type), shd_singleton(val)); + return val; +} + +const Node* shd_bld_convert_int_extend_according_to_dst_t(BodyBuilder* bb, const Type* dst_type, const Node* src) { + IrArena* a = shd_get_bb_arena(bb); + const Type* src_type = shd_get_unqualified_type(src->type); + assert(src_type->tag == Int_TAG); + assert(dst_type->tag == Int_TAG); + + // first bitcast then convert to final bitsize + const Type* reinterpreted_src_t = int_type(shd_get_bb_arena(bb), (Int) { .width = src_type->payload.int_type.width, .is_signed = dst_type->payload.int_type.is_signed }); + const Node* val = src; + val = prim_op_helper(a, reinterpret_op, 
shd_singleton(reinterpreted_src_t), shd_singleton(val)); + val = prim_op_helper(a, convert_op, shd_singleton(dst_type), shd_singleton(val)); + return val; +} + +const Node* shd_bld_convert_int_zero_extend(BodyBuilder* bb, const Type* dst_type, const Node* src) { + IrArena* a = shd_get_bb_arena(bb); + const Type* src_type = shd_get_unqualified_type(src->type); + assert(src_type->tag == Int_TAG); + assert(dst_type->tag == Int_TAG); + + const Node* val = src; + val = prim_op_helper(a, reinterpret_op, shd_singleton( + int_type(shd_get_bb_arena(bb), (Int) { .width = src_type->payload.int_type.width, .is_signed = false })), shd_singleton(val)); + val = prim_op_helper(a, convert_op, shd_singleton( + int_type(shd_get_bb_arena(bb), (Int) { .width = dst_type->payload.int_type.width, .is_signed = false })), shd_singleton(val)); + val = prim_op_helper(a, reinterpret_op, shd_singleton(dst_type), shd_singleton(val)); + return val; +} + +const Node* shd_bld_convert_int_sign_extend(BodyBuilder* bb, const Type* dst_type, const Node* src) { + IrArena* a = shd_get_bb_arena(bb); + const Type* src_type = shd_get_unqualified_type(src->type); + assert(src_type->tag == Int_TAG); + assert(dst_type->tag == Int_TAG); + + const Node* val = src; + val = prim_op_helper(a, reinterpret_op, shd_singleton( + int_type(shd_get_bb_arena(bb), (Int) { .width = src_type->payload.int_type.width, .is_signed = true })), shd_singleton(val)); + val = prim_op_helper(a, convert_op, shd_singleton( + int_type(shd_get_bb_arena(bb), (Int) { .width = dst_type->payload.int_type.width, .is_signed = true })), shd_singleton(val)); + val = prim_op_helper(a, reinterpret_op, shd_singleton(dst_type), shd_singleton(val)); + return val; +} diff --git a/src/shady/ir/mem.c b/src/shady/ir/mem.c new file mode 100644 index 000000000..a3ff90620 --- /dev/null +++ b/src/shady/ir/mem.c @@ -0,0 +1,87 @@ +#include "shady/ir/grammar.h" +#include "shady/ir/builder.h" + +#include + +#pragma GCC diagnostic error "-Wswitch" + +const Node* 
shd_get_parent_mem(const Node* mem) { + assert(is_mem(mem)); + switch (is_mem(mem)) { + case NotAMem: return NULL; + case Mem_AbsMem_TAG: + return NULL; + case Mem_Call_TAG: + mem = mem->payload.call.mem; + return mem; + case Mem_MemAndValue_TAG: + mem = mem->payload.mem_and_value.mem; + return mem; + case Mem_Comment_TAG: + mem = mem->payload.comment.mem; + return mem; + case Mem_StackAlloc_TAG: + mem = mem->payload.stack_alloc.mem; + return mem; + case Mem_LocalAlloc_TAG: + mem = mem->payload.local_alloc.mem; + return mem; + case Mem_Load_TAG: + mem = mem->payload.load.mem; + return mem; + case Mem_Store_TAG: + mem = mem->payload.store.mem; + return mem; + case Mem_CopyBytes_TAG: + mem = mem->payload.copy_bytes.mem; + return mem; + case Mem_FillBytes_TAG: + mem = mem->payload.fill_bytes.mem; + return mem; + case Mem_PushStack_TAG: + mem = mem->payload.push_stack.mem; + return mem; + case Mem_PopStack_TAG: + mem = mem->payload.pop_stack.mem; + return mem; + case Mem_GetStackSize_TAG: + mem = mem->payload.get_stack_size.mem; + return mem; + case Mem_SetStackSize_TAG: + mem = mem->payload.set_stack_size.mem; + return mem; + case Mem_DebugPrintf_TAG: + mem = mem->payload.debug_printf.mem; + return mem; + case Mem_ExtInstr_TAG: + mem = mem->payload.ext_instr.mem; + return mem; + } +} + +const Node* shd_get_original_mem(const Node* mem) { + while (true) { + const Node* nmem = shd_get_parent_mem(mem); + if (nmem) { + mem = nmem; + continue; + } + return mem; + } +} + +const Node* shd_bld_stack_alloc(BodyBuilder* bb, const Type* type) { + return shd_first(shd_bld_add_instruction_extract(bb, stack_alloc(shd_get_bb_arena(bb), (StackAlloc) { .type = type, .mem = shd_bb_mem(bb) }))); +} + +const Node* shd_bld_local_alloc(BodyBuilder* bb, const Type* type) { + return shd_first(shd_bld_add_instruction_extract(bb, local_alloc(shd_get_bb_arena(bb), (LocalAlloc) { .type = type, .mem = shd_bb_mem(bb) }))); +} + +const Node* shd_bld_load(BodyBuilder* bb, const Node* ptr) { + return 
shd_first(shd_bld_add_instruction_extract(bb, load(shd_get_bb_arena(bb), (Load) { .ptr = ptr, .mem = shd_bb_mem(bb) }))); +} + +void shd_bld_store(BodyBuilder* bb, const Node* ptr, const Node* value) { + shd_bld_add_instruction_extract(bb, store(shd_get_bb_arena(bb), (Store) { .ptr = ptr, .value = value, .mem = shd_bb_mem(bb) })); +} diff --git a/src/shady/ir/memory_layout.c b/src/shady/ir/memory_layout.c new file mode 100644 index 000000000..e77257bac --- /dev/null +++ b/src/shady/ir/memory_layout.c @@ -0,0 +1,161 @@ +#include "shady/ir/memory_layout.h" +#include "shady/ir/float.h" +#include "shady/ir/type.h" + +#include "log.h" +#include "portability.h" + +#include + +inline static size_t round_up(size_t a, size_t b) { + if (b == 0) + return a; + size_t divided = (a + b - 1) / b; + return divided * b; +} + +static int maxof(int a, int b) { + if (a > b) + return a; + return b; +} + +TypeMemLayout shd_get_record_layout(IrArena* a, const Node* record_type, FieldLayout* fields) { + assert(record_type->tag == RecordType_TAG); + + size_t offset = 0; + size_t max_align = 0; + + Nodes member_types = record_type->payload.record_type.members; + for (size_t i = 0; i < member_types.count; i++) { + TypeMemLayout member_layout = shd_get_mem_layout(a, member_types.nodes[i]); + offset = round_up(offset, member_layout.alignment_in_bytes); + if (fields) { + fields[i].mem_layout = member_layout; + fields[i].offset_in_bytes = offset; + } + offset += member_layout.size_in_bytes; + if (member_layout.alignment_in_bytes > max_align) + max_align = member_layout.alignment_in_bytes; + } + + return (TypeMemLayout) { + .type = record_type, + .size_in_bytes = round_up(offset, max_align), + .alignment_in_bytes = max_align, + }; +} + +size_t shd_get_record_field_offset_in_bytes(IrArena* a, const Type* t, size_t i) { + assert(t->tag == RecordType_TAG); + Nodes member_types = t->payload.record_type.members; + assert(i < member_types.count); + LARRAY(FieldLayout, fields, member_types.count); + 
shd_get_record_layout(a, t, fields); + return fields[i].offset_in_bytes; +} + +TypeMemLayout shd_get_mem_layout(IrArena* a, const Type* type) { + size_t base_word_size = int_size_in_bytes(shd_get_arena_config(a)->memory.word_size); + assert(is_type(type)); + switch (type->tag) { + case FnType_TAG: shd_error("Functions have an opaque memory representation"); + case PtrType_TAG: switch (type->payload.ptr_type.address_space) { + case AsPrivate: + case AsSubgroup: + case AsShared: + case AsGlobal: + case AsGeneric: return shd_get_mem_layout(a, int_type(a, (Int) { .width = shd_get_arena_config(a)->memory.ptr_size, .is_signed = false })); // TODO: use per-as layout + default: shd_error("Pointers in address space '%s' does not have a defined memory layout", shd_get_address_space_name(type->payload.ptr_type.address_space)); + } + case Int_TAG: return (TypeMemLayout) { + .type = type, + .size_in_bytes = int_size_in_bytes(type->payload.int_type.width), + .alignment_in_bytes = maxof(int_size_in_bytes(type->payload.int_type.width), base_word_size), + }; + case Float_TAG: return (TypeMemLayout) { + .type = type, + .size_in_bytes = float_size_in_bytes(type->payload.float_type.width), + .alignment_in_bytes = maxof(float_size_in_bytes(type->payload.float_type.width), base_word_size), + }; + case Bool_TAG: return (TypeMemLayout) { + .type = type, + .size_in_bytes = base_word_size, + .alignment_in_bytes = base_word_size, + }; + case ArrType_TAG: { + const Node* size = type->payload.arr_type.size; + assert(size && "We can't know the full layout of arrays of unknown size !"); + size_t actual_size = shd_get_int_literal_value(*shd_resolve_to_int_literal(size), false); + TypeMemLayout element_layout = shd_get_mem_layout(a, type->payload.arr_type.element_type); + return (TypeMemLayout) { + .type = type, + .size_in_bytes = actual_size * element_layout.size_in_bytes, + .alignment_in_bytes = element_layout.alignment_in_bytes + }; + } + case PackType_TAG: { + size_t width = 
type->payload.pack_type.width; + TypeMemLayout element_layout = shd_get_mem_layout(a, type->payload.pack_type.element_type); + return (TypeMemLayout) { + .type = type, + .size_in_bytes = width * element_layout.size_in_bytes /* TODO Vulkan vec3 -> vec4 alignment rules ? */, + .alignment_in_bytes = element_layout.alignment_in_bytes + }; + } + case QualifiedType_TAG: return shd_get_mem_layout(a, type->payload.qualified_type.type); + case TypeDeclRef_TAG: return shd_get_mem_layout(a, type->payload.type_decl_ref.decl->payload.nom_type.body); + case RecordType_TAG: return shd_get_record_layout(a, type, NULL); + default: shd_error("not a known type"); + } +} + +const Node* shd_bytes_to_words(BodyBuilder* bb, const Node* bytes) { + IrArena* a = bytes->arena; + const Type* word_type = int_type(a, (Int) { .width = shd_get_arena_config(a)->memory.word_size, .is_signed = false }); + size_t word_width = shd_get_type_bitwidth(word_type); + const Node* bytes_per_word = size_t_literal(a, word_width / 8); + return prim_op_helper(a, div_op, shd_empty(a), mk_nodes(a, bytes, bytes_per_word)); +} + +uint64_t shd_bytes_to_words_static(const IrArena* a, uint64_t bytes) { + uint64_t word_width = int_size_in_bytes(shd_get_arena_config(a)->memory.word_size); + return bytes / word_width; +} + +IntSizes shd_float_to_int_width(FloatSizes width) { + switch (width) { + case FloatTy16: return IntTy16; + case FloatTy32: return IntTy32; + case FloatTy64: return IntTy64; + } +} + +size_t shd_get_type_bitwidth(const Type* t) { + const ArenaConfig* aconfig = shd_get_arena_config(t->arena); + switch (t->tag) { + case Int_TAG: return int_size_in_bytes(t->payload.int_type.width) * 8; + case Float_TAG: return float_size_in_bytes(t->payload.float_type.width) * 8; + case PtrType_TAG: { + if (aconfig->address_spaces[t->payload.ptr_type.address_space].physical) + return int_size_in_bytes(aconfig->memory.ptr_size) * 8; + break; + } + default: break; + } + return SIZE_MAX; +} + +const Node* 
_shd_lea_helper(IrArena* a, const Node* ptr, const Node* offset, Nodes indices) { + const Node* lea = ptr_array_element_offset(a, (PtrArrayElementOffset) { + .ptr = ptr, + .offset = offset, + }); + for (size_t i = 0; i < indices.count; i++) { + lea = ptr_composite_element(a, (PtrCompositeElement) { + .ptr = lea, + .index = indices.nodes[i], + }); + } + return lea; +} diff --git a/src/shady/ir/module.c b/src/shady/ir/module.c new file mode 100644 index 000000000..1ec2586db --- /dev/null +++ b/src/shady/ir/module.c @@ -0,0 +1,50 @@ +#include "../ir_private.h" + +#include "list.h" +#include "portability.h" + +#include + +Module* shd_new_module(IrArena* arena, String name) { + Module* m = shd_arena_alloc(arena->arena, sizeof(Module)); + *m = (Module) { + .arena = arena, + .name = shd_string(arena, name), + .decls = shd_new_list(Node*), + }; + shd_list_append(Module*, arena->modules, m); + return m; +} + +IrArena* shd_module_get_arena(const Module* m) { + return m->arena; +} + +String shd_module_get_name(const Module* m) { + return m->name; +} + +Nodes shd_module_get_declarations(const Module* m) { + size_t count = shd_list_count(m->decls); + const Node** start = shd_read_list(const Node*, m->decls); + return shd_nodes(shd_module_get_arena(m), count, start); +} + +void _shd_module_add_decl(Module* m, Node* node) { + assert(is_declaration(node)); + assert(!shd_module_get_declaration(m, get_declaration_name(node)) && "duplicate declaration"); + shd_list_append(Node*, m->decls, node); +} + +Node* shd_module_get_declaration(const Module* m, String name) { + Nodes existing_decls = shd_module_get_declarations(m); + for (size_t i = 0; i < existing_decls.count; i++) { + if (strcmp(get_declaration_name(existing_decls.nodes[i]), name) == 0) + return (Node*) existing_decls.nodes[i]; + } + return NULL; +} + +void shd_destroy_module(Module* m) { + shd_destroy_list(m->decls); +} diff --git a/src/shady/ir/stack.c b/src/shady/ir/stack.c new file mode 100644 index 000000000..1a9934f4a 
--- /dev/null +++ b/src/shady/ir/stack.c @@ -0,0 +1,31 @@ +#include "shady/ir/stack.h" + +#include "shady/ir/grammar.h" + +void shd_bld_stack_push_value(BodyBuilder* bb, const Node* value) { + shd_bld_add_instruction_extract(bb, push_stack(shd_get_bb_arena(bb), (PushStack) { .value = value, .mem = shd_bb_mem(bb) })); +} + +void shd_bld_stack_push_values(BodyBuilder* bb, Nodes values) { + for (size_t i = values.count - 1; i < values.count; i--) { + const Node* value = values.nodes[i]; + shd_bld_stack_push_value(bb, value); + } +} + +const Node* shd_bld_stack_pop_value(BodyBuilder* bb, const Type* type) { + const Node* instruction = pop_stack(shd_get_bb_arena(bb), (PopStack) { .type = type, .mem = shd_bb_mem(bb) }); + return shd_first(shd_bld_add_instruction_extract(bb, instruction)); +} + +const Node* shd_bld_get_stack_base_addr(BodyBuilder* bb) { + return get_stack_base_addr(shd_get_bb_arena(bb), (GetStackBaseAddr) { .mem = shd_bb_mem(bb) }); +} + +const Node* shd_bld_get_stack_size(BodyBuilder* bb) { + return shd_first(shd_bld_add_instruction_extract(bb, get_stack_size(shd_get_bb_arena(bb), (GetStackSize) { .mem = shd_bb_mem(bb) }))); +} + +void shd_bld_set_stack_size(BodyBuilder* bb, const Node* new_size) { + shd_bld_add_instruction_extract(bb, set_stack_size(shd_get_bb_arena(bb), (SetStackSize) { .value = new_size, .mem = shd_bb_mem(bb) })); +} diff --git a/src/shady/ir/type.c b/src/shady/ir/type.c new file mode 100644 index 000000000..1d5d4c3cd --- /dev/null +++ b/src/shady/ir/type.c @@ -0,0 +1,559 @@ +#include "shady/ir/type.h" +#include "shady/ir/function.h" +#include "shady/ir/memory_layout.h" +#include "shady/ir/module.h" + +#include "../ir_private.h" + +#include "log.h" +#include "portability.h" +#include "util.h" + +#include + +#pragma GCC diagnostic error "-Wswitch" + +static bool are_types_identical(size_t num_types, const Type* types[]) { + for (size_t i = 0; i < num_types; i++) { + assert(types[i]); + if (types[0] != types[i]) + return false; + } + 
return true; +} + +bool shd_is_subtype(const Type* supertype, const Type* type) { + assert(supertype && type); + if (supertype->tag != type->tag) + return false; + if (type == supertype) + return true; + switch (is_type(supertype)) { + case NotAType: shd_error("supplied not a type to is_subtype"); + case QualifiedType_TAG: { + // uniform T <: varying T + if (supertype->payload.qualified_type.is_uniform && !type->payload.qualified_type.is_uniform) + return false; + return shd_is_subtype(supertype->payload.qualified_type.type, type->payload.qualified_type.type); + } + case RecordType_TAG: { + const Nodes* supermembers = &supertype->payload.record_type.members; + const Nodes* members = &type->payload.record_type.members; + if (supermembers->count != members->count) + return false; + for (size_t i = 0; i < members->count; i++) { + if (!shd_is_subtype(supermembers->nodes[i], members->nodes[i])) + return false; + } + return supertype->payload.record_type.special == type->payload.record_type.special; + } + case JoinPointType_TAG: { + const Nodes* superparams = &supertype->payload.join_point_type.yield_types; + const Nodes* params = &type->payload.join_point_type.yield_types; + if (params->count != superparams->count) return false; + for (size_t i = 0; i < params->count; i++) { + if (!shd_is_subtype(params->nodes[i], superparams->nodes[i])) + return false; + } + return true; + } + case FnType_TAG: { + // check returns + if (supertype->payload.fn_type.return_types.count != type->payload.fn_type.return_types.count) + return false; + for (size_t i = 0; i < type->payload.fn_type.return_types.count; i++) + if (!shd_is_subtype(supertype->payload.fn_type.return_types.nodes[i], type->payload.fn_type.return_types.nodes[i])) + return false; + // check params + const Nodes* superparams = &supertype->payload.fn_type.param_types; + const Nodes* params = &type->payload.fn_type.param_types; + if (params->count != superparams->count) return false; + for (size_t i = 0; i < params->count; 
i++) { + if (!shd_is_subtype(params->nodes[i], superparams->nodes[i])) + return false; + } + return true; + } case BBType_TAG: { + // check params + const Nodes* superparams = &supertype->payload.bb_type.param_types; + const Nodes* params = &type->payload.bb_type.param_types; + if (params->count != superparams->count) return false; + for (size_t i = 0; i < params->count; i++) { + if (!shd_is_subtype(params->nodes[i], superparams->nodes[i])) + return false; + } + return true; + } case LamType_TAG: { + // check params + const Nodes* superparams = &supertype->payload.lam_type.param_types; + const Nodes* params = &type->payload.lam_type.param_types; + if (params->count != superparams->count) return false; + for (size_t i = 0; i < params->count; i++) { + if (!shd_is_subtype(params->nodes[i], superparams->nodes[i])) + return false; + } + return true; + } case PtrType_TAG: { + if (supertype->payload.ptr_type.address_space != type->payload.ptr_type.address_space) + return false; + if (!supertype->payload.ptr_type.is_reference && type->payload.ptr_type.is_reference) + return false; + return shd_is_subtype(supertype->payload.ptr_type.pointed_type, type->payload.ptr_type.pointed_type); + } + case Int_TAG: return supertype->payload.int_type.width == type->payload.int_type.width && supertype->payload.int_type.is_signed == type->payload.int_type.is_signed; + case ArrType_TAG: { + if (!shd_is_subtype(supertype->payload.arr_type.element_type, type->payload.arr_type.element_type)) + return false; + // unsized arrays are supertypes of sized arrays (even though they're not datatypes...) 
+ // TODO: maybe change this so it's only valid when talking about to pointer-to-arrays + const IntLiteral* size_literal = shd_resolve_to_int_literal(supertype->payload.arr_type.size); + if (size_literal && size_literal->value == 0) + return true; + return supertype->payload.arr_type.size == type->payload.arr_type.size || !supertype->payload.arr_type.size; + } + case PackType_TAG: { + if (!shd_is_subtype(supertype->payload.pack_type.element_type, type->payload.pack_type.element_type)) + return false; + return supertype->payload.pack_type.width == type->payload.pack_type.width; + } + case Type_TypeDeclRef_TAG: { + return supertype->payload.type_decl_ref.decl == type->payload.type_decl_ref.decl; + } + case Type_ImageType_TAG: { + if (!shd_is_subtype(supertype->payload.image_type.sampled_type, type->payload.image_type.sampled_type)) + return false; + if (supertype->payload.image_type.depth != type->payload.image_type.depth) + return false; + if (supertype->payload.image_type.dim != type->payload.image_type.dim) + return false; + if (supertype->payload.image_type.arrayed != type->payload.image_type.arrayed) + return false; + if (supertype->payload.image_type.ms != type->payload.image_type.ms) + return false; + if (supertype->payload.image_type.sampled != type->payload.image_type.sampled) + return false; + if (supertype->payload.image_type.imageformat != type->payload.image_type.imageformat) + return false; + return true; + } + case Type_SampledImageType_TAG: + return shd_is_subtype(supertype->payload.sampled_image_type.image_type, type->payload.sampled_image_type.image_type); + default: break; + } + // Two types are always equal (and therefore subtypes of each other) if their payload matches + return memcmp(&supertype->payload, &type->payload, sizeof(type->payload)) == 0; +} + +void shd_check_subtype(const Type* supertype, const Type* type) { + if (!shd_is_subtype(supertype, type)) { + shd_log_node(ERROR, type); + shd_error_print(" isn't a subtype of "); + 
shd_log_node(ERROR, supertype); + shd_error_print("\n"); + shd_error("failed check_subtype") + } +} + +/// Is this a type that a value in the language can have ? +bool shd_is_value_type(const Type* type) { + if (type->tag != QualifiedType_TAG) + return false; + return shd_is_data_type(shd_get_unqualified_type(type)); +} + +/// Is this a valid data type (for usage in other types and as type arguments) ? +bool shd_is_data_type(const Type* type) { + switch (is_type(type)) { + case Type_MaskType_TAG: + case Type_JoinPointType_TAG: + case Type_Int_TAG: + case Type_Float_TAG: + case Type_Bool_TAG: + return true; + case Type_PtrType_TAG: + return true; + case Type_ArrType_TAG: + // array types _must_ be sized to be real data types + return type->payload.arr_type.size != NULL; + case Type_PackType_TAG: + return shd_is_data_type(type->payload.pack_type.element_type); + case Type_RecordType_TAG: { + for (size_t i = 0; i < type->payload.record_type.members.count; i++) + if (!shd_is_data_type(type->payload.record_type.members.nodes[i])) + return false; + // multi-return record types are the results of instructions, but are not values themselves + return type->payload.record_type.special == NotSpecial; + } + case Type_TypeDeclRef_TAG: + return !shd_get_nominal_type_body(type) || shd_is_data_type(shd_get_nominal_type_body(type)); + // qualified types are not data types because that information is only meant for values + case Type_QualifiedType_TAG: return false; + // values cannot contain abstractions + case Type_FnType_TAG: + case Type_BBType_TAG: + case Type_LamType_TAG: + return false; + // this type has no values to begin with + case Type_NoRet_TAG: + return false; + case NotAType: + return false; + // Image stuff is data (albeit opaque) + case Type_SampledImageType_TAG: + case Type_SamplerType_TAG: + case Type_ImageType_TAG: + return true; + } +} + +bool shd_is_arithm_type(const Type* t) { + return t->tag == Int_TAG || t->tag == Float_TAG; +} + +bool 
shd_is_shiftable_type(const Type* t) { + return t->tag == Int_TAG || t->tag == MaskType_TAG; +} + +bool shd_has_boolean_ops(const Type* t) { + return t->tag == Int_TAG || t->tag == Bool_TAG || t->tag == MaskType_TAG; +} + +bool shd_is_comparable_type(const Type* t) { + return true; // TODO this is fine to allow, but we'll need to lower it for composite and native ptr types ! +} + +bool shd_is_ordered_type(const Type* t) { + return shd_is_arithm_type(t); +} + +bool shd_is_physical_ptr_type(const Type* t) { + if (t->tag != PtrType_TAG) + return false; + return !t->payload.ptr_type.is_reference; + // AddressSpace as = t->payload.ptr_type.address_space; + // return t->shd_get_arena_config(arena)->address_spaces[as].physical; +} + +bool shd_is_generic_ptr_type(const Type* t) { + if (t->tag != PtrType_TAG) + return false; + AddressSpace as = t->payload.ptr_type.address_space; + return as == AsGeneric; +} + +bool shd_is_addr_space_uniform(IrArena* arena, AddressSpace as) { + switch (as) { + case AsGeneric: + case AsInput: + case AsOutput: + case AsFunction: + case AsPrivate: return !shd_get_arena_config(arena)->is_simt; + default: return true; + } +} + +const Type* shd_get_actual_mask_type(IrArena* arena) { + switch (shd_get_arena_config(arena)->specializations.subgroup_mask_representation) { + case SubgroupMaskAbstract: return mask_type(arena); + case SubgroupMaskInt64: return shd_uint64_type(arena); + default: assert(false); + } +} + +String shd_get_type_name(IrArena* arena, const Type* t) { + switch (is_type(t)) { + case NotAType: assert(false); + case Type_MaskType_TAG: return "mask_t"; + case Type_JoinPointType_TAG: return "join_type_t"; + case Type_NoRet_TAG: return "no_ret"; + case Type_Int_TAG: { + if (t->payload.int_type.is_signed) + return shd_fmt_string_irarena(arena, "i%s", ((String[]) { "8", "16", "32", "64" })[t->payload.int_type.width]); + else + return shd_fmt_string_irarena(arena, "u%s", ((String[]) { "8", "16", "32", "64" })[t->payload.int_type.width]); 
+ } + case Type_Float_TAG: return shd_fmt_string_irarena(arena, "f%s", ((String[]) { "16", "32", "64" })[t->payload.float_type.width]); + case Type_Bool_TAG: return "bool"; + case Type_TypeDeclRef_TAG: return t->payload.type_decl_ref.decl->payload.nom_type.name; + default: break; + } + return shd_make_unique_name(arena, shd_get_node_tag_string(t->tag)); +} + +const Type* shd_maybe_multiple_return(IrArena* arena, Nodes types) { + switch (types.count) { + case 0: return empty_multiple_return_type(arena); + case 1: return types.nodes[0]; + default: return record_type(arena, (RecordType) { + .members = types, + .names = shd_strings(arena, 0, NULL), + .special = MultipleReturn, + }); + } + SHADY_UNREACHABLE; +} + +Nodes shd_unwrap_multiple_yield_types(IrArena* arena, const Type* type) { + switch (type->tag) { + case RecordType_TAG: + if (type->payload.record_type.special == MultipleReturn) + return type->payload.record_type.members; + // fallthrough + default: + assert(shd_is_value_type(type)); + return shd_singleton(type); + } +} + +const Type* shd_get_pointee_type(IrArena* arena, const Type* type) { + bool qualified = false, uniform = false; + if (shd_is_value_type(type)) { + qualified = true; + uniform = shd_is_qualified_type_uniform(type); + type = shd_get_unqualified_type(type); + } + assert(type->tag == PtrType_TAG); + uniform &= shd_is_addr_space_uniform(arena, type->payload.ptr_type.address_space); + type = type->payload.ptr_type.pointed_type; + + if (qualified) + type = qualified_type(arena, (QualifiedType) { + .type = type, + .is_uniform = uniform + }); + return type; +} + +Nodes shd_get_param_types(IrArena* arena, Nodes variables) { + LARRAY(const Type*, arr, variables.count); + for (size_t i = 0; i < variables.count; i++) { + assert(variables.nodes[i]->tag == Param_TAG); + arr[i] = variables.nodes[i]->payload.param.type; + } + return shd_nodes(arena, variables.count, arr); +} + +Nodes shd_get_values_types(IrArena* arena, Nodes values) { + 
assert(shd_get_arena_config(arena)->check_types); + LARRAY(const Type*, arr, values.count); + for (size_t i = 0; i < values.count; i++) + arr[i] = values.nodes[i]->type; + return shd_nodes(arena, values.count, arr); +} + +bool shd_is_qualified_type_uniform(const Type* type) { + const Type* result_type = type; + bool is_uniform = shd_deconstruct_qualified_type(&result_type); + return is_uniform; +} + +const Type* shd_get_unqualified_type(const Type* type) { + assert(is_type(type)); + const Type* result_type = type; + shd_deconstruct_qualified_type(&result_type); + return result_type; +} + +bool shd_deconstruct_qualified_type(const Type** type_out) { + const Type* type = *type_out; + if (type->tag == QualifiedType_TAG) { + *type_out = type->payload.qualified_type.type; + return type->payload.qualified_type.is_uniform; + } else shd_error("Expected a value type (annotated with qual_type)") +} + +const Type* shd_as_qualified_type(const Type* type, bool uniform) { + return qualified_type(type->arena, (QualifiedType) { .type = type, .is_uniform = uniform }); +} + +Nodes shd_strip_qualifiers(IrArena* arena, Nodes tys) { + LARRAY(const Type*, arr, tys.count); + for (size_t i = 0; i < tys.count; i++) + arr[i] = shd_get_unqualified_type(tys.nodes[i]); + return shd_nodes(arena, tys.count, arr); +} + +Nodes shd_add_qualifiers(IrArena* arena, Nodes tys, bool uniform) { + LARRAY(const Type*, arr, tys.count); + for (size_t i = 0; i < tys.count; i++) + arr[i] = shd_as_qualified_type(tys.nodes[i], + uniform || !shd_get_arena_config(arena)->is_simt /* SIMD arenas ban varying value types */); + return shd_nodes(arena, tys.count, arr); +} + +const Type* shd_get_packed_type_element(const Type* type) { + const Type* t = type; + shd_deconstruct_packed_type(&t); + return t; +} + +size_t shd_get_packed_type_width(const Type* type) { + const Type* t = type; + return shd_deconstruct_packed_type(&t); +} + +size_t shd_deconstruct_packed_type(const Type** type) { + assert((*type)->tag == 
PackType_TAG); + return shd_deconstruct_maybe_packed_type(type); +} + +const Type* shd_get_maybe_packed_type_element(const Type* type) { + const Type* t = type; + shd_deconstruct_maybe_packed_type(&t); + return t; +} + +size_t shd_get_maybe_packed_type_width(const Type* type) { + const Type* t = type; + return shd_deconstruct_maybe_packed_type(&t); +} + +size_t shd_deconstruct_maybe_packed_type(const Type** type) { + const Type* t = *type; + assert(shd_is_data_type(t)); + if (t->tag == PackType_TAG) { + *type = t->payload.pack_type.element_type; + return t->payload.pack_type.width; + } + return 1; +} + +const Type* shd_maybe_packed_type_helper(const Type* type, size_t width) { + assert(width > 0); + if (width == 1) + return type; + return pack_type(type->arena, (PackType) { + .width = width, + .element_type = type, + }); +} + +const Type* shd_get_pointer_type_element(const Type* type) { + const Type* t = type; + shd_deconstruct_pointer_type(&t); + return t; +} + +AddressSpace shd_deconstruct_pointer_type(const Type** type) { + const Type* t = *type; + assert(t->tag == PtrType_TAG); + *type = t->payload.ptr_type.pointed_type; + return t->payload.ptr_type.address_space; +} + +const Node* shd_get_nominal_type_decl(const Type* type) { + assert(type->tag == TypeDeclRef_TAG); + return shd_get_maybe_nominal_type_decl(type); +} + +const Type* shd_get_nominal_type_body(const Type* type) { + assert(type->tag == TypeDeclRef_TAG); + return shd_get_maybe_nominal_type_body(type); +} + +const Node* shd_get_maybe_nominal_type_decl(const Type* type) { + if (type->tag == TypeDeclRef_TAG) { + const Node* decl = type->payload.type_decl_ref.decl; + assert(decl->tag == NominalType_TAG); + return decl; + } + return NULL; +} + +const Type* shd_get_maybe_nominal_type_body(const Type* type) { + const Node* decl = shd_get_maybe_nominal_type_decl(type); + if (decl) + return decl->payload.nom_type.body; + return type; +} + +Nodes shd_get_composite_type_element_types(const Type* type) { + 
switch (is_type(type)) { + case Type_TypeDeclRef_TAG: { + type = shd_get_nominal_type_body(type); + assert(type->tag == RecordType_TAG); + SHADY_FALLTHROUGH + } + case RecordType_TAG: { + return type->payload.record_type.members; + } + case Type_ArrType_TAG: + case Type_PackType_TAG: { + size_t size = shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_fill_type_size(type)), false); + if (size >= 1024) { + shd_warn_print("Potential performance issue: creating a really big array of composites of types (size=%d)!\n", size); + } + const Type* element_type = shd_get_fill_type_element_type(type); + LARRAY(const Type*, types, size); + for (size_t i = 0; i < size; i++) { + types[i] = element_type; + } + return shd_nodes(type->arena, size, types); + } + default: shd_error("Not a composite type !") + } +} + +const Node* shd_get_fill_type_element_type(const Type* composite_t) { + switch (composite_t->tag) { + case ArrType_TAG: return composite_t->payload.arr_type.element_type; + case PackType_TAG: return composite_t->payload.pack_type.element_type; + default: shd_error("fill values need to be either array or pack types") + } +} + +const Node* shd_get_fill_type_size(const Type* composite_t) { + switch (composite_t->tag) { + case ArrType_TAG: return composite_t->payload.arr_type.size; + case PackType_TAG: return shd_int32_literal(composite_t->arena, composite_t->payload.pack_type.width); + default: shd_error("fill values need to be either array or pack types") + } +} + +Type* _shd_nominal_type(Module* mod, Nodes annotations, String name) { + IrArena* arena = shd_module_get_arena(mod); + NominalType payload = { + .name = shd_string(arena, name), + .module = mod, + .annotations = annotations, + .body = NULL, + }; + + Node node; + memset((void*) &node, 0, sizeof(Node)); + node = (Node) { + .arena = arena, + .type = NULL, + .tag = NominalType_TAG, + .payload.nom_type = payload + }; + Node* decl = _shd_create_node_helper(arena, node, NULL); + _shd_module_add_decl(mod, 
decl); + return decl; +} + +const Node* shd_get_default_value(IrArena* a, const Type* t) { + switch (is_type(t)) { + case NotAType: shd_error("") + case Type_Int_TAG: return int_literal(a, (IntLiteral) { .width = t->payload.int_type.width, .is_signed = t->payload.int_type.is_signed, .value = 0 }); + case Type_Float_TAG: return float_literal(a, (FloatLiteral) { .width = t->payload.float_type.width, .value = 0 }); + case Type_Bool_TAG: return false_lit(a); + case Type_PtrType_TAG: return null_ptr(a, (NullPtr) { .ptr_type = t }); + case Type_QualifiedType_TAG: return shd_get_default_value(a, t->payload.qualified_type.type); + case Type_RecordType_TAG: + case Type_ArrType_TAG: + case Type_PackType_TAG: + case Type_TypeDeclRef_TAG: { + Nodes elem_tys = shd_get_composite_type_element_types(t); + if (elem_tys.count >= 1024) { + shd_warn_print("Potential performance issue: creating a really composite full of zero/default values (size=%d)!\n", elem_tys.count); + } + LARRAY(const Node*, elems, elem_tys.count); + for (size_t i = 0; i < elem_tys.count; i++) + elems[i] = shd_get_default_value(a, elem_tys.nodes[i]); + return composite_helper(a, t, shd_nodes(a, elem_tys.count, elems)); + } + default: break; + } + return NULL; +} diff --git a/src/shady/ir_private.h b/src/shady/ir_private.h index f2163f017..adb1ee238 100644 --- a/src/shady/ir_private.h +++ b/src/shady/ir_private.h @@ -2,17 +2,19 @@ #define SHADY_IR_PRIVATE_H #include "shady/ir.h" +#include "shady/config.h" #include "arena.h" +#include "growy.h" #include "stdlib.h" #include "stdio.h" -typedef struct IrArena_ { +struct IrArena_ { Arena* arena; ArenaConfig config; - VarId next_free_id; + Growy* ids; struct List* modules; struct Dict* node_set; @@ -20,7 +22,7 @@ typedef struct IrArena_ { struct Dict* nodes_set; struct Dict* strings_set; -} IrArena_; +}; struct Module_ { IrArena* arena; @@ -29,17 +31,16 @@ struct Module_ { bool sealed; }; -void register_decl_module(Module*, Node*); -void destroy_module(Module* m); +void 
_shd_module_add_decl(Module* m, Node* node); +void shd_destroy_module(Module* m); -struct BodyBuilder_ { - IrArena* arena; - struct List* stack; -}; +NodeId _shd_allocate_node_id(IrArena* arena, const Node* n); -VarId fresh_id(IrArena*); +const Node* _shd_bb_insert_mem(BodyBuilder* bb); +const Node* _shd_bb_insert_block(BodyBuilder* bb); +const Node* _shd_bld_finish_pseudo_instr(BodyBuilder* bb, const Node* terminator); struct List; -Nodes list_to_nodes(IrArena*, struct List*); +Nodes shd_list_to_nodes(IrArena* arena, struct List* list); #endif diff --git a/src/shady/module.c b/src/shady/module.c deleted file mode 100644 index f5f9d74ce..000000000 --- a/src/shady/module.c +++ /dev/null @@ -1,50 +0,0 @@ -#include "ir_private.h" - -#include "list.h" -#include "portability.h" - -#include - -Module* new_module(IrArena* arena, String name) { - Module* m = arena_alloc(arena->arena, sizeof(Module)); - *m = (Module) { - .arena = arena, - .name = string(arena, name), - .decls = new_list(Node*), - }; - append_list(Module*, arena->modules, m); - return m; -} - -IrArena* get_module_arena(const Module* m) { - return m->arena; -} - -String get_module_name(const Module* m) { - return m->name; -} - -Nodes get_module_declarations(const Module* m) { - size_t count = entries_count_list(m->decls); - const Node** start = read_list(const Node*, m->decls); - return nodes(get_module_arena(m), count, start); -} - -void register_decl_module(Module* m, Node* node) { - assert(is_declaration(node)); - assert(!get_declaration(m, get_decl_name(node)) && "duplicate declaration"); - append_list(Node*, m->decls, node); -} - -const Node* get_declaration(const Module* m, String name) { - Nodes existing_decls = get_module_declarations(m); - for (size_t i = 0; i < existing_decls.count; i++) { - if (strcmp(get_decl_name(existing_decls.nodes[i]), name) == 0) - return existing_decls.nodes[i]; - } - return NULL; -} - -void destroy_module(Module* m) { - destroy_list(m->decls); -} diff --git 
a/src/shady/node.c b/src/shady/node.c index 2ac77cf77..87a2ff768 100644 --- a/src/shady/node.c +++ b/src/shady/node.c @@ -1,311 +1,24 @@ -#include "type.h" -#include "log.h" #include "ir_private.h" -#include "portability.h" +#include "log.h" +#include "portability.h" #include "dict.h" -#include -#include - -String get_decl_name(const Node* node) { - switch (node->tag) { - case Constant_TAG: return node->payload.constant.name; - case Function_TAG: return node->payload.fun.name; - case GlobalVariable_TAG: return node->payload.global_variable.name; - case NominalType_TAG: return node->payload.nom_type.name; - default: error("Not a decl !"); - } -} - -String get_value_name(const Node* v) { - assert(v && is_value(v)); - if (v->tag == Variable_TAG) - return v->payload.var.name; - return NULL; -} - -String get_value_name_safe(const Node* v) { - String name = get_value_name(v); - if (name) - return name; - if (v->tag == Variable_TAG) - return format_string_interned(v->arena, "v%d", v->payload.var.id); - return node_tags[v->tag]; -} - -void set_variable_name(Node* var, String name) { - assert(var->tag == Variable_TAG); - var->payload.var.name = string(var->arena, name); -} - -int64_t get_int_literal_value(IntLiteral literal, bool sign_extend) { - if (sign_extend) { - switch (literal.width) { - case IntTy8: return (int64_t) (int8_t) (literal.value & 0xFF); - case IntTy16: return (int64_t) (int16_t) (literal.value & 0xFFFF); - case IntTy32: return (int64_t) (int32_t) (literal.value & 0xFFFFFFFF); - case IntTy64: return (int64_t) literal.value; - default: assert(false); - } - } else { - switch (literal.width) { - case IntTy8: return literal.value & 0xFF; - case IntTy16: return literal.value & 0xFFFF; - case IntTy32: return literal.value & 0xFFFFFFFF; - case IntTy64: return literal.value; - default: assert(false); - } - } -} - -static_assert(sizeof(float) == sizeof(uint64_t) / 2, "floats aren't the size we expect"); -double get_float_literal_value(FloatLiteral literal) { - 
double r; - switch (literal.width) { - case FloatTy16: - error_print("TODO: fp16 literals"); - error_die(); - SHADY_UNREACHABLE; - break; - case FloatTy32: { - float f; - memcpy(&f, &literal.value, sizeof(float)); - r = (double) f; - break; - } - case FloatTy64: - memcpy(&r, &literal.value, sizeof(double)); - break; - } - return r; -} - -const Node* get_quoted_value(const Node* instruction) { - if (instruction->payload.prim_op.op == quote_op) - return first(instruction->payload.prim_op.operands); - return NULL; -} +const char* node_tags[]; -const Node* resolve_ptr_to_value(const Node* ptr, NodeResolveConfig config) { - while (ptr) { - ptr = resolve_node_to_definition(ptr, config); - switch (ptr->tag) { - case PrimOp_TAG: { - switch (ptr->payload.prim_op.op) { - case convert_op: { // allow address space conversions - ptr = first(ptr->payload.prim_op.operands); - continue; - } - default: break; - } - } - case GlobalVariable_TAG: - if (config.assume_globals_immutability) - return ptr->payload.global_variable.init; - break; - default: break; - } - ptr = NULL; - } - return NULL; -} - -NodeResolveConfig default_node_resolve_config() { - return (NodeResolveConfig) { - .enter_loads = true, - .allow_incompatible_types = false, - .assume_globals_immutability = false, - }; +const char* shd_get_node_tag_string(NodeTag tag) { + return node_tags[tag]; } -const Node* resolve_node_to_definition(const Node* node, NodeResolveConfig config) { - while (node) { - switch (node->tag) { - case Constant_TAG: - node = node->payload.constant.instruction; - continue; - case RefDecl_TAG: - node = node->payload.ref_decl.decl; - continue; - case Variable_TAG: { - if (node->payload.var.pindex != 0) - break; - const Node* abs = node->payload.var.abs; - if (!abs || abs->tag != Case_TAG) - break; - const Node* user = abs->payload.case_.structured_construct; - if (user->tag != Let_TAG) - break; - node = user->payload.let.instruction; - continue; - } - case PrimOp_TAG: { - switch 
(node->payload.prim_op.op) { - case quote_op: { - node = first(node->payload.prim_op.operands);; - continue; - } - case load_op: { - if (config.enter_loads) { - const Node* source = first(node->payload.prim_op.operands); - const Node* result = resolve_ptr_to_value(source, config); - if (!result) - break; - node = result; - continue; - } - } - case reinterpret_op: { - if (config.allow_incompatible_types) { - node = first(node->payload.prim_op.operands); - continue; - } - } - default: break; - } - break; - } - default: break; - } - break; - } - return node; -} - -const IntLiteral* resolve_to_int_literal(const Node* node) { - node = resolve_node_to_definition(node, default_node_resolve_config()); - if (!node) - return NULL; - if (node->tag == IntLiteral_TAG) - return &node->payload.int_literal; - return NULL; -} - -const FloatLiteral* resolve_to_float_literal(const Node* node) { - node = resolve_node_to_definition(node, default_node_resolve_config()); - if (!node) - return NULL; - if (node->tag == FloatLiteral_TAG) - return &node->payload.float_literal; - return NULL; -} - -static bool is_zero(const Node* node) { - const IntLiteral* lit = resolve_to_int_literal(node); - if (lit && get_int_literal_value(*lit, false) == 0) - return true; - return false; -} - -const char* get_string_literal(IrArena* arena, const Node* node) { - if (!node) - return NULL; - switch (node->tag) { - case PrimOp_TAG: { - switch (node->payload.prim_op.op) { - case lea_op: { - Nodes ops = node->payload.prim_op.operands; - if (ops.count == 3 && is_zero(ops.nodes[1]) && is_zero(ops.nodes[2])) { - const Node* ref = first(ops); - if (ref->tag != RefDecl_TAG) - return NULL; - const Node* decl = ref->payload.ref_decl.decl; - if (decl->tag != GlobalVariable_TAG || !decl->payload.global_variable.init) - return NULL; - return get_string_literal(arena, decl->payload.global_variable.init); - } - break; - } - default: break; - } - return NULL; - } - case StringLiteral_TAG: return 
node->payload.string_lit.string; - case Composite_TAG: { - Nodes contents = node->payload.composite.contents; - LARRAY(char, chars, contents.count); - for (size_t i = 0; i < contents.count; i++) { - const Node* value = contents.nodes[i]; - assert(value->tag == IntLiteral_TAG && value->payload.int_literal.width == IntTy8); - chars[i] = (unsigned char) get_int_literal_value(*resolve_to_int_literal(value), false); - } - assert(chars[contents.count - 1] == 0); - return string(arena, chars); - } - default: return NULL; // error("This is not a string literal and it doesn't look like one either"); - } -} - -bool is_abstraction(const Node* node) { - NodeTag tag = node->tag; - return tag == Function_TAG || tag == BasicBlock_TAG || tag == Case_TAG; -} - -String get_abstraction_name(const Node* abs) { - assert(is_abstraction(abs)); - switch (abs->tag) { - case Function_TAG: return abs->payload.fun.name; - case BasicBlock_TAG: return abs->payload.basic_block.name; - case Case_TAG: return "case"; - default: assert(false); - } -} - -const Node* get_abstraction_body(const Node* abs) { - assert(is_abstraction(abs)); - switch (abs->tag) { - case Function_TAG: return abs->payload.fun.body; - case BasicBlock_TAG: return abs->payload.basic_block.body; - case Case_TAG: return abs->payload.case_.body; - default: assert(false); - } -} - -void set_abstraction_body(Node* abs, const Node* body) { - assert(is_abstraction(abs)); - assert(!body || is_terminator(body)); - switch (abs->tag) { - case Function_TAG: abs->payload.fun.body = body; break; - case BasicBlock_TAG: abs->payload.basic_block.body = body; break; - case Case_TAG: abs->payload.case_.body = body; break; - default: assert(false); - } -} - -Nodes get_abstraction_params(const Node* abs) { - assert(is_abstraction(abs)); - switch (abs->tag) { - case Function_TAG: return abs->payload.fun.params; - case BasicBlock_TAG: return abs->payload.basic_block.params; - case Case_TAG: return abs->payload.case_.params; - default: assert(false); 
- } -} - -const Node* get_let_instruction(const Node* let) { - switch (let->tag) { - case Let_TAG: return let->payload.let.instruction; - case LetMut_TAG: return let->payload.let_mut.instruction; - default: assert(false); - } -} - -const Node* get_let_tail(const Node* let) { - switch (let->tag) { - case Let_TAG: return let->payload.let.tail; - case LetMut_TAG: return let->payload.let_mut.tail; - default: assert(false); - } -} +const bool node_type_has_payload[]; -KeyHash hash_node_payload(const Node* node); +KeyHash _shd_hash_node_payload(const Node* node); -KeyHash hash_node(Node** pnode) { +KeyHash shd_hash_node(Node** pnode) { const Node* node = *pnode; KeyHash combined; - if (is_nominal(node)) { + if (shd_is_node_nominal(node)) { size_t ptr = (size_t) node; uint32_t upper = ptr >> 32; uint32_t lower = ptr; @@ -313,11 +26,11 @@ KeyHash hash_node(Node** pnode) { goto end; } - KeyHash tag_hash = hash_murmur(&node->tag, sizeof(NodeTag)); + KeyHash tag_hash = shd_hash(&node->tag, sizeof(NodeTag)); KeyHash payload_hash = 0; if (node_type_has_payload[node->tag]) { - payload_hash = hash_node_payload(node); + payload_hash = _shd_hash_node_payload(node); } combined = tag_hash ^ payload_hash; @@ -325,11 +38,11 @@ KeyHash hash_node(Node** pnode) { return combined; } -bool compare_node_payload(const Node*, const Node*); +bool _shd_compare_node_payload(const Node*, const Node*); -bool compare_node(Node** pa, Node** pb) { +bool shd_compare_node(Node** pa, Node** pb) { if ((*pa)->tag != (*pb)->tag) return false; - if (is_nominal((*pa))) + if (shd_is_node_nominal((*pa))) return *pa == *pb; const Node* a = *pa; @@ -339,7 +52,7 @@ bool compare_node(Node** pa, Node** pb) { #define field(w) eq &= memcmp(&a->payload.w, &b->payload.w, sizeof(a->payload.w)) == 0; if (node_type_has_payload[a->tag]) { - return compare_node_payload(a, b); + return _shd_compare_node_payload(a, b); } else return true; } diff --git a/src/shady/passes/CMakeLists.txt b/src/shady/passes/CMakeLists.txt new file 
mode 100644 index 000000000..34cb94de5 --- /dev/null +++ b/src/shady/passes/CMakeLists.txt @@ -0,0 +1,45 @@ +target_sources(shady PRIVATE + import.c + cleanup.c + lower_cf_instrs.c + lift_indirect_targets.c + lower_callf.c + lower_alloca.c + lower_stack.c + lower_lea.c + lower_physical_ptrs.c + lower_generic_ptrs.c + lower_memory_layout.c + lower_memcpy.c + lower_decay_ptrs.c + lower_tailcalls.c + lower_mask.c + lower_fill.c + lower_nullptr.c + lower_switch_btree.c + setup_stack_frames.c + eliminate_constants.c + normalize_builtins.c + lower_subgroup_ops.c + lower_subgroup_vars.c + lower_int64.c + lower_vec_arr.c + lower_workgroups.c + lower_generic_globals.c + mark_leaf_functions.c + opt_inline.c + restructure.c + opt_demote_alloca.c + opt_mem2reg.c + specialize_entry_point.c + specialize_execution_model.c + lower_logical_pointers.c + lower_entrypoint_args.c + scope2control.c + lift_everything.c + scope_heuristic.c + reconvergence_heuristics.c + lcssa.c + remove_critical_edges.c + lower_inclusive_scan.c +) diff --git a/src/shady/passes/bind.c b/src/shady/passes/bind.c deleted file mode 100644 index 04809176e..000000000 --- a/src/shady/passes/bind.c +++ /dev/null @@ -1,336 +0,0 @@ -#include "passes.h" - -#include "list.h" -#include "log.h" -#include "portability.h" - -#include "../ir_private.h" -#include "../rewrite.h" - -#include -#include - -typedef struct NamedBindEntry_ NamedBindEntry; -struct NamedBindEntry_ { - const char* name; - bool is_var; - Node* node; - NamedBindEntry* next; -}; - -typedef struct { - Rewriter rewriter; - - const Node* current_function; - NamedBindEntry* local_variables; -} Context; - -typedef struct { - bool is_var; - const Node* node; -} Resolved; - -static Resolved resolve_using_name(Context* ctx, const char* name) { - for (NamedBindEntry* entry = ctx->local_variables; entry != NULL; entry = entry->next) { - if (strcmp(entry->name, name) == 0) { - return (Resolved) { - .is_var = entry->is_var, - .node = entry->node - }; - } - } - - 
Nodes new_decls = get_module_declarations(ctx->rewriter.dst_module); - for (size_t i = 0; i < new_decls.count; i++) { - const Node* decl = new_decls.nodes[i]; - if (strcmp(get_decl_name(decl), name) == 0) { - return (Resolved) { - .is_var = decl->tag == GlobalVariable_TAG, - .node = decl - }; - } - } - - Nodes old_decls = get_module_declarations(ctx->rewriter.src_module); - for (size_t i = 0; i < old_decls.count; i++) { - const Node* old_decl = old_decls.nodes[i]; - if (strcmp(get_decl_name(old_decl), name) == 0) { - Context top_ctx = *ctx; - top_ctx.current_function = NULL; - top_ctx.local_variables = NULL; - const Node* decl = rewrite_node(&top_ctx.rewriter, old_decl); - return (Resolved) { - .is_var = decl->tag == GlobalVariable_TAG, - .node = decl - }; - } - } - - error("could not resolve node %s", name) -} - -static void add_binding(Context* ctx, bool is_var, String name, const Node* node) { - NamedBindEntry* entry = arena_alloc(ctx->rewriter.dst_arena->arena, sizeof(NamedBindEntry)); - *entry = (NamedBindEntry) { - .name = string(ctx->rewriter.dst_arena, name), - .is_var = is_var, - .node = (Node*) node, - .next = NULL - }; - entry->next = ctx->local_variables; - ctx->local_variables = entry; -} - -static const Node* get_node_address(Context* ctx, const Node* node); -static const Node* get_node_address_safe(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; - switch (node->tag) { - case Unbound_TAG: { - Resolved entry = resolve_using_name(ctx, node->payload.unbound.name); - // can't take the address if it's not a var! 
- if (!entry.is_var) - return NULL; - return entry.node; - } - case PrimOp_TAG: { - if (node->tag == PrimOp_TAG && node->payload.prim_op.op == subscript_op) { - const Node* src_ptr = get_node_address_safe(ctx, node->payload.prim_op.operands.nodes[0]); - if (src_ptr == NULL) - return NULL; - const Node* index = rewrite_node(&ctx->rewriter, node->payload.prim_op.operands.nodes[1]); - return prim_op(a, (PrimOp) { - .op = lea_op, - .operands = nodes(a, 3, (const Node* []) {src_ptr, int32_literal(a, 0), index }) - }); - } else if (node->tag == PrimOp_TAG && node->payload.prim_op.op == deref_op) { - return rewrite_node(&ctx->rewriter, first(node->payload.prim_op.operands)); - } - } - default: break; - } - return NULL; -} - -static const Node* get_node_address(Context* ctx, const Node* node) { - const Node* got = get_node_address_safe(ctx, node); - if (!got) - error("This doesn't really look like a place expression...") - return got; -} - -static const Node* desugar_let_mut(Context* ctx, const Node* node) { - assert(node->tag == LetMut_TAG); - IrArena* a = ctx->rewriter.dst_arena; - Context body_infer_ctx = *ctx; - const Node* ninstruction = rewrite_node(&ctx->rewriter, node->payload.let.instruction); - - const Node* old_lam = node->payload.let.tail; - assert(old_lam && is_case(old_lam)); - - BodyBuilder* bb = begin_body(a); - - Nodes initial_values = bind_instruction_outputs_count(bb, ninstruction, old_lam->payload.case_.params.count, NULL, false); - Nodes old_params = old_lam->payload.case_.params; - for (size_t i = 0; i < old_params.count; i++) { - const Node* oparam = old_params.nodes[i]; - const Type* type_annotation = oparam->payload.var.type; - assert(type_annotation); - const Node* alloca = prim_op(a, (PrimOp) { - .op = alloca_op, - .type_arguments = nodes(a, 1, (const Node* []){rewrite_node(&ctx->rewriter, type_annotation) }), - .operands = nodes(a, 0, NULL) - }); - const Node* ptr = bind_instruction_outputs_count(bb, alloca, 1, &oparam->payload.var.name, 
false).nodes[0]; - const Node* store = prim_op(a, (PrimOp) { - .op = store_op, - .operands = nodes(a, 2, (const Node* []) {ptr, initial_values.nodes[0] }) - }); - bind_instruction_outputs_count(bb, store, 0, NULL, false); - - add_binding(&body_infer_ctx, true, oparam->payload.var.name, ptr); - debugv_print("Lowered mutable variable %s\n", get_value_name_safe(oparam)); - } - - const Node* terminator = rewrite_node(&body_infer_ctx.rewriter, old_lam->payload.case_.body); - return finish_body(bb, terminator); -} - -static const Node* rewrite_decl(Context* ctx, const Node* decl) { - assert(is_declaration(decl)); - switch (decl->tag) { - case GlobalVariable_TAG: { - const GlobalVariable* ogvar = &decl->payload.global_variable; - Node* bound = global_var(ctx->rewriter.dst_module, rewrite_nodes(&ctx->rewriter, ogvar->annotations), rewrite_node(&ctx->rewriter, ogvar->type), ogvar->name, ogvar->address_space); - register_processed(&ctx->rewriter, decl, bound); - bound->payload.global_variable.init = rewrite_node(&ctx->rewriter, decl->payload.global_variable.init); - return bound; - } - case Constant_TAG: { - const Constant* cnst = &decl->payload.constant; - Node* bound = constant(ctx->rewriter.dst_module, rewrite_nodes(&ctx->rewriter, cnst->annotations), rewrite_node(&ctx->rewriter, decl->payload.constant.type_hint), cnst->name); - register_processed(&ctx->rewriter, decl, bound); - bound->payload.constant.instruction = rewrite_node(&ctx->rewriter, decl->payload.constant.instruction); - return bound; - } - case Function_TAG: { - Nodes new_fn_params = recreate_variables(&ctx->rewriter, decl->payload.fun.params); - Node* bound = function(ctx->rewriter.dst_module, new_fn_params, decl->payload.fun.name, rewrite_nodes(&ctx->rewriter, decl->payload.fun.annotations), rewrite_nodes(&ctx->rewriter, decl->payload.fun.return_types)); - register_processed(&ctx->rewriter, decl, bound); - Context fn_ctx = *ctx; - for (size_t i = 0; i < new_fn_params.count; i++) { - add_binding(&fn_ctx, 
false, decl->payload.fun.params.nodes[i]->payload.var.name, new_fn_params.nodes[i]); - } - register_processed_list(&ctx->rewriter, decl->payload.fun.params, new_fn_params); - - fn_ctx.current_function = bound; - bound->payload.fun.body = rewrite_node(&fn_ctx.rewriter, decl->payload.fun.body); - return bound; - } - case NominalType_TAG: { - Node* bound = nominal_type(ctx->rewriter.dst_module, rewrite_nodes(&ctx->rewriter, decl->payload.nom_type.annotations), decl->payload.nom_type.name); - register_processed(&ctx->rewriter, decl, bound); - bound->payload.nom_type.body = rewrite_node(&ctx->rewriter, decl->payload.nom_type.body); - return bound; - } - default: error("unknown declaration kind"); - } - - error("unreachable") - //register_processed(&ctx->rewriter, decl, bound); - //return bound; -} - -static const Node* bind_node(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; - if (node == NULL) - return NULL; - - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - - // in case the node is an l-value, we load it - const Node* lhs = get_node_address_safe(ctx, node); - if (lhs) { - return prim_op(a, (PrimOp) { - .op = load_op, - .operands = singleton(lhs) - }); - } - - switch (node->tag) { - case Function_TAG: - case Constant_TAG: - case GlobalVariable_TAG: - case NominalType_TAG: { - assert(is_declaration(node)); - return rewrite_decl(ctx, node); - } - case Variable_TAG: error("the binders should be handled such that this node is never reached"); - case Unbound_TAG: { - Resolved entry = resolve_using_name(ctx, node->payload.unbound.name); - assert(!entry.is_var); - return entry.node; - } - case UnboundBBs_TAG: { - Nodes unbound_blocks = node->payload.unbound_bbs.children_blocks; - LARRAY(Node*, new_bbs, unbound_blocks.count); - - // First create stubs - for (size_t i = 0; i < unbound_blocks.count; i++) { - const Node* old_bb = unbound_blocks.nodes[i]; - assert(is_basic_block(old_bb)); - Nodes new_bb_params = 
recreate_variables(&ctx->rewriter, old_bb->payload.basic_block.params); - Node* new_bb = basic_block(a, (Node*) ctx->current_function, new_bb_params, old_bb->payload.basic_block.name); - new_bbs[i] = new_bb; - add_binding(ctx, false, old_bb->payload.basic_block.name, new_bb); - register_processed(&ctx->rewriter, old_bb, new_bb); - debugv_print("Bound (stub) basic block %s\n", old_bb->payload.basic_block.name); - } - - const Node* bound_body = rewrite_node(&ctx->rewriter, node->payload.unbound_bbs.body); - - // Rebuild the basic blocks now - for (size_t i = 0; i < unbound_blocks.count; i++) { - const Node* old_bb = unbound_blocks.nodes[i]; - Node* new_bb = new_bbs[i]; - - Context bb_ctx = *ctx; - Nodes new_bb_params = get_abstraction_params(new_bb); - for (size_t j = 0; j < new_bb_params.count; j++) - add_binding(&bb_ctx, false, new_bb->payload.basic_block.params.nodes[j]->payload.var.name, new_bb_params.nodes[j]); - - new_bb->payload.basic_block.body = rewrite_node(&bb_ctx.rewriter, old_bb->payload.basic_block.body); - debugv_print("Bound basic block %s\n", new_bb->payload.basic_block.name); - } - - return bound_body; - } - case BasicBlock_TAG: { - assert(is_basic_block(node)); - Nodes new_bb_params = recreate_variables(&ctx->rewriter, node->payload.basic_block.params); - Node* new_bb = basic_block(a, (Node*) ctx->current_function, new_bb_params, node->payload.basic_block.name); - Context bb_ctx = *ctx; - ctx = &bb_ctx; - add_binding(ctx, false, node->payload.basic_block.name, new_bb); - register_processed(&ctx->rewriter, node, new_bb); - register_processed_list(&ctx->rewriter, node->payload.basic_block.params, new_bb_params); - new_bb->payload.basic_block.body = rewrite_node(&ctx->rewriter, node->payload.basic_block.body); - return new_bb; - } - case Case_TAG: { - Nodes old_params = node->payload.case_.params; - Nodes new_params = recreate_variables(&ctx->rewriter, old_params); - for (size_t i = 0; i < new_params.count; i++) - add_binding(ctx, false, 
old_params.nodes[i]->payload.var.name, new_params.nodes[i]); - register_processed_list(&ctx->rewriter, old_params, new_params); - const Node* new_body = rewrite_node(&ctx->rewriter, node->payload.case_.body); - return case_(a, new_params, new_body); - } - case LetMut_TAG: return desugar_let_mut(ctx, node); - case Return_TAG: { - assert(ctx->current_function); - return fn_ret(a, (Return) { - .fn = ctx->current_function, - .args = rewrite_nodes(&ctx->rewriter, node->payload.fn_ret.args) - }); - } - default: { - if (node->tag == PrimOp_TAG && node->payload.prim_op.op == assign_op) { - const Node* target_ptr = get_node_address(ctx, node->payload.prim_op.operands.nodes[0]); - assert(target_ptr); - const Node* value = rewrite_node(&ctx->rewriter, node->payload.prim_op.operands.nodes[1]); - return prim_op(a, (PrimOp) { - .op = store_op, - .operands = nodes(a, 2, (const Node* []) {target_ptr, value }) - }); - } else if (node->tag == PrimOp_TAG && node->payload.prim_op.op == subscript_op) { - return prim_op(a, (PrimOp) { - .op = extract_op, - .operands = mk_nodes(a, rewrite_node(&ctx->rewriter, node->payload.prim_op.operands.nodes[0]), rewrite_node(&ctx->rewriter, node->payload.prim_op.operands.nodes[1])) - }); - } else if (node->tag == PrimOp_TAG && node->payload.prim_op.op == addrof_op) { - const Node* target_ptr = get_node_address(ctx, node->payload.prim_op.operands.nodes[0]); - return target_ptr; - } - return recreate_node_identity(&ctx->rewriter, node); - } - } -} - -Module* bind_program(SHADY_UNUSED const CompilerConfig* compiler_config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - assert(!src->arena->config.name_bound); - aconfig.name_bound = true; - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) bind_node), - .local_variables = NULL, - .current_function = NULL, - }; - - rewrite_module(&ctx.rewriter); - 
destroy_rewriter(&ctx.rewriter); - return dst; -} diff --git a/src/shady/passes/cleanup.c b/src/shady/passes/cleanup.c index 03e1bbc79..466e20b83 100644 --- a/src/shady/passes/cleanup.c +++ b/src/shady/passes/cleanup.c @@ -1,10 +1,13 @@ -#include "passes.h" +#include "shady/pass.h" + +#include "../analysis/uses.h" +#include "../analysis/leak.h" +#include "../ir_private.h" #include "portability.h" #include "log.h" -#include "../rewrite.h" -#include "../analysis/uses.h" +#pragma GCC diagnostic error "-Wswitch" typedef struct { Rewriter rewriter; @@ -12,69 +15,147 @@ typedef struct { bool* todo; } Context; -const Node* process(Context* ctx, const Node* old) { +static size_t count_calls(const UsesMap* map, const Node* bb) { + size_t count = 0; + const Use* use = shd_get_first_use(map, bb); + for (; use; use = use->next_use) { + if (use->user->tag == Jump_TAG) { + const Use* jump_use = shd_get_first_use(map, use->user); + for (; jump_use; jump_use = jump_use->next_use) { + if (jump_use->operand_class == NcJump) + return SIZE_MAX; // you can never inline conditional jumps + count++; + } + } else if (use->operand_class == NcBasic_block) + return SIZE_MAX; // you can never inline basic blocks used for other purposes + } + return count; +} + +static bool is_used_as_value(const UsesMap* map, const Node* instr) { + const Use* use = shd_get_first_use(map, instr); + for (; use; use = use->next_use) { + if (use->operand_class == NcValue) + return true; + } + return false; +} + +static const Node* process(Context* ctx, const Node* old) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; if (old->tag == Function_TAG || old->tag == Constant_TAG) { Context c = *ctx; - c.map = create_uses_map(old, NcType | NcDeclaration); - const Node* new = recreate_node_identity(&c.rewriter, old); - destroy_uses_map(c.map); + c.map = shd_new_uses_map_fn(old, NcType | NcDeclaration); + const Node* new = shd_recreate_node(&c.rewriter, old); + shd_destroy_uses_map(c.map); return new; } 
switch (old->tag) { - case Let_TAG: { - Let payload = old->payload.let; - bool side_effects = true; - if (payload.instruction->tag == PrimOp_TAG) - side_effects = has_primop_got_side_effects(payload.instruction->payload.prim_op.op); - bool consumed = false; - const Node* tail_case = payload.tail; - Nodes tail_params = get_abstraction_params(tail_case); - for (size_t i = 0; i < tail_params.count; i++) { - const Use* use = get_first_use(ctx->map, tail_params.nodes[i]); - assert(use); - for (;use; use = use->next_use) { - if (use->user == tail_case) - continue; - consumed = true; - break; + case BasicBlock_TAG: { + size_t uses = count_calls(ctx->map, old); + if (uses <= 1 && a->config.optimisations.inline_single_use_bbs) { + shd_log_fmt(DEBUGVV, "Eliminating basic block '%s' since it's used only %d times.\n", shd_get_abstraction_name_safe(old), uses); + *ctx->todo = true; + return NULL; + } + break; + } + case Jump_TAG: { + const Node* otarget = old->payload.jump.target; + const Node* ntarget = shd_rewrite_node(r, otarget); + if (!ntarget) { + // it's been inlined away! 
just steal the body + Nodes nargs = shd_rewrite_nodes(r, old->payload.jump.args); + shd_register_processed_list(r, get_abstraction_params(otarget), nargs); + shd_register_processed(r, shd_get_abstraction_mem(otarget), shd_rewrite_node(r, old->payload.jump.mem)); + return shd_rewrite_node(r, get_abstraction_body(otarget)); + } + break; + } + case Control_TAG: { + Control payload = old->payload.control; + if (shd_is_control_static(ctx->map, old)) { + const Node* control_inside = payload.inside; + const Node* term = get_abstraction_body(control_inside); + if (term->tag == Join_TAG) { + Join payload_join = term->payload.join; + if (payload_join.join_point == shd_first(get_abstraction_params(control_inside))) { + // if we immediately consume the join point and it's never leaked, this control block does nothing and can be eliminated + shd_register_processed(r, shd_get_abstraction_mem(control_inside), shd_rewrite_node(r, payload.mem)); + shd_register_processed(r, control_inside, NULL); + *ctx->todo = true; + return shd_rewrite_node(r, term); + } } - if (consumed) - break; } - if (!consumed && !side_effects && ctx->rewriter.dst_arena) { - debug_print("Cleanup: found an unused instruction: "); - log_node(DEBUG, payload.instruction); - debug_print("\n"); - *ctx->todo = true; - return rewrite_node(&ctx->rewriter, get_abstraction_body(tail_case)); + break; + } + case Join_TAG: { + Join payload = old->payload.join; + const Node* control = shd_get_control_for_jp(ctx->map, payload.join_point); + if (control) { + Control old_control_payload = control->payload.control; + // there was a control but now there is not anymore - jump to the tail! 
+ if (shd_rewrite_node(r, old_control_payload.inside) == NULL) { + return jump_helper(a, shd_rewrite_node(r, payload.mem), shd_rewrite_node(r, old_control_payload.tail), + shd_rewrite_nodes(r, payload.args)); + } } break; } + case Load_TAG: { + if (!is_used_as_value(ctx->map, old)) + return shd_rewrite_node(r, old->payload.load.mem); + break; + } default: break; } - return recreate_node_identity(&ctx->rewriter, old);; + return shd_recreate_node(&ctx->rewriter, old); +} + +OptPass shd_opt_simplify; + +bool shd_opt_simplify(SHADY_UNUSED const CompilerConfig* config, Module** m) { + Module* src = *m; + + IrArena* a = shd_module_get_arena(src); + *m = shd_new_module(a, shd_module_get_name(*m)); + bool todo = false; + Context ctx = { .todo = &todo }; + ctx.rewriter = shd_create_node_rewriter(src, *m, (RewriteNodeFn) process); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return todo; } -Module* cleanup(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); +OptPass shd_opt_demote_alloca; +OptPass shd_opt_mem2reg; +RewritePass shd_import; + +Module* shd_cleanup(const CompilerConfig* config, Module* const src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); if (!aconfig.check_types) return src; - IrArena* a = new_ir_arena(aconfig); bool todo; - Context ctx = { .todo = &todo }; size_t r = 0; - Module* m; + Module* m = src; + bool changed_at_all = false; do { - debug_print("Cleanup round %d\n", r); todo = false; - m = new_module(a, get_module_name(src)); - ctx.rewriter = create_rewriter(src, m, (RewriteNodeFn) process), - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - src = m; + shd_debugv_print("Cleanup round %d\n", r); + + APPLY_OPT(shd_opt_demote_alloca); + APPLY_OPT(shd_opt_mem2reg); + APPLY_OPT(shd_opt_simplify); + + changed_at_all |= todo; + r++; } while (todo); - return m; + if (changed_at_all) + shd_debugv_print("After 
%d rounds of cleanup:\n", r); + return shd_import(config, m); } diff --git a/src/shady/passes/eliminate_constants.c b/src/shady/passes/eliminate_constants.c index 2d20a79c4..a3b6301ff 100644 --- a/src/shady/passes/eliminate_constants.c +++ b/src/shady/passes/eliminate_constants.c @@ -1,64 +1,56 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/annotation.h" -#include "../rewrite.h" #include "portability.h" #include "log.h" +#include "dict.h" typedef struct { Rewriter rewriter; - BodyBuilder* bb; + bool all; } Context; static const Node* process(Context* ctx, const Node* node) { - if (!node) return NULL; - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; IrArena* a = ctx->rewriter.dst_arena; - BodyBuilder* abs_bb = NULL; - Context c = *ctx; - ctx = &c; - if (is_abstraction(node)) { - c.bb = abs_bb = begin_body(a); - } - switch (node->tag) { - case Constant_TAG: return NULL; + case Constant_TAG: + if (!node->payload.constant.value) + break; + if (!ctx->all && !shd_lookup_annotation(node, "Inline")) + break; + return NULL; case RefDecl_TAG: { const Node* decl = node->payload.ref_decl.decl; - if (decl->tag == Constant_TAG) { - const Node* value = get_quoted_value(decl->payload.constant.instruction); - if (value) - return rewrite_node(&ctx->rewriter, value); - assert(ctx->bb); - // TODO: actually _copy_ the instruction so we can duplicate the code safely! 
- return first(bind_instruction(ctx->bb, rewrite_node(&ctx->rewriter, decl->payload.constant.instruction))); + if (decl->tag == Constant_TAG && decl->payload.constant.value) { + return shd_rewrite_node(&ctx->rewriter, decl->payload.constant.value); } break; } default: break; } - Node* new = (Node*) recreate_node_identity(&ctx->rewriter, node); - if (abs_bb) { - assert(is_abstraction(new)); - if (get_abstraction_body(new)) - set_abstraction_body(new, finish_body(abs_bb, get_abstraction_body(new))); - else - cancel_body(abs_bb); - } - return new; + return shd_recreate_node(&ctx->rewriter, node); } -Module* eliminate_constants(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +static Module* eliminate_constants_(SHADY_UNUSED const CompilerConfig* config, Module* src, bool all) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process) + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .all = all, }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } + +Module* shd_pass_eliminate_constants(const CompilerConfig* config, Module* src) { + return eliminate_constants_(config, src, true); +} + +Module* shd_pass_eliminate_inlineable_constants(const CompilerConfig* config, Module* src) { + return eliminate_constants_(config, src, false); +} diff --git a/src/shady/passes/import.c b/src/shady/passes/import.c index 6bb7c770e..02dd87872 100644 --- a/src/shady/passes/import.c +++ b/src/shady/passes/import.c @@ -1,22 +1,82 @@ -#include "passes.h" +#include "shady/pass.h" -#include 
"portability.h" +#include "../ir_private.h" -#include "../rewrite.h" +#include "portability.h" +#include "log.h" typedef struct { Rewriter rewriter; } Context; -Module* import(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +static void replace_or_compare(const Node** dst, const Node* with) { + if (!*dst) + *dst = with; + else { + assert(*dst == with && "conflicting definitions"); + } +} + +static const Node* import_node(Rewriter* r, const Node* node) { + if (is_declaration(node)) { + Node* existing = shd_module_get_declaration(r->dst_module, get_declaration_name(node)); + if (existing) { + const Node* imported_t = shd_rewrite_node(r, node->type); + if (imported_t != existing->type) { + shd_error_print("Incompatible types for to-be-merged declaration: %s ", get_declaration_name(node)); + shd_log_node(ERROR, existing->type); + shd_error_print(" vs "); + shd_log_node(ERROR, imported_t); + shd_error_print(".\n"); + shd_error_die(); + } + if (node->tag != existing->tag) { + shd_error_print("Incompatible node types for to-be-merged declaration: %s ", get_declaration_name(node)); + shd_error_print("%s", shd_get_node_tag_string(existing->tag)); + shd_error_print(" vs "); + shd_error_print("%s", shd_get_node_tag_string(node->tag)); + shd_error_print(".\n"); + shd_error_die(); + } + switch (is_declaration(node)) { + case NotADeclaration: assert(false); + case Declaration_Function_TAG: + replace_or_compare(&existing->payload.fun.body, shd_rewrite_node(r, node->payload.fun.body)); + break; + case Declaration_Constant_TAG: + replace_or_compare(&existing->payload.constant.value, shd_rewrite_node(r, node->payload.constant.value)); + break; + case Declaration_GlobalVariable_TAG: + replace_or_compare(&existing->payload.global_variable.init, shd_rewrite_node(r, node->payload.global_variable.init)); + break; + case 
Declaration_NominalType_TAG: + replace_or_compare(&existing->payload.nom_type.body, shd_rewrite_node(r, node->payload.nom_type.body)); + break; + } + return existing; + } + } + + return shd_recreate_node(r, node); +} + +Module* shd_import(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) recreate_node_identity), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) shd_recreate_node), }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } + +void shd_module_link(Module* dst, Module* src) { + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) import_node), + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); +} diff --git a/src/shady/passes/infer.c b/src/shady/passes/infer.c deleted file mode 100644 index b9a952f89..000000000 --- a/src/shady/passes/infer.c +++ /dev/null @@ -1,860 +0,0 @@ -#include "passes.h" - -#include "log.h" -#include "portability.h" - -#include "../type.h" -#include "../rewrite.h" -#include "../transform/ir_gen_helpers.h" - -#include -#include - -#pragma GCC diagnostic error "-Wswitch" - -static Nodes annotate_all_types(IrArena* a, Nodes types, bool uniform_by_default) { - LARRAY(const Type*, ntypes, types.count); - for (size_t i = 0; i < types.count; i++) { - if (is_data_type(types.nodes[i])) - ntypes[i] = qualified_type(a, (QualifiedType) { - .type = types.nodes[i], - .is_uniform = uniform_by_default, - }); - else - ntypes[i] = types.nodes[i]; - } - return nodes(a, types.count, ntypes); -} - -typedef struct { - Rewriter rewriter; - - const Type* expected_type; - - const Nodes* merge_types; - const Nodes* break_types; - 
const Nodes* continue_types; -} Context; - -static const Node* infer(Context* ctx, const Node* node, const Type* expect) { - Context ctx2 = *ctx; - ctx2.expected_type = expect; - return rewrite_node(&ctx2.rewriter, node); -} - -static Nodes infer_nodes(Context* ctx, Nodes nodes) { - Context ctx2 = *ctx; - ctx2.expected_type = NULL; - return rewrite_nodes(&ctx->rewriter, nodes); -} - -#define rewrite_node error("don't use this directly, use the 'infer' and 'infer_node' helpers") -#define rewrite_nodes rewrite_node - -static const Node* _infer_annotation(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; - assert(is_annotation(node)); - switch (node->tag) { - case Annotation_TAG: return annotation(a, (Annotation) { .name = node->payload.annotation.name }); - case AnnotationValue_TAG: return annotation_value(a, (AnnotationValue) { .name = node->payload.annotation_value.name, .value = infer(ctx, node->payload.annotation_value.value, NULL) }); - case AnnotationValues_TAG: return annotation_values(a, (AnnotationValues) { .name = node->payload.annotation_values.name, .values = infer_nodes(ctx, node->payload.annotation_values.values) }); - case AnnotationCompound_TAG: return annotation_compound(a, (AnnotationCompound) { .name = node->payload.annotation_compound.name, .entries = infer_nodes(ctx, node->payload.annotation_compound.entries) }); - default: error("Not an annotation"); - } -} - -static const Node* _infer_type(Context* ctx, const Type* type) { - IrArena* a = ctx->rewriter.dst_arena; - switch (type->tag) { - case ArrType_TAG: { - const Node* size = infer(ctx, type->payload.arr_type.size, NULL); - return arr_type(a, (ArrType) { - .size = size, - .element_type = infer(ctx, type->payload.arr_type.element_type, NULL) - }); - } - case PtrType_TAG: { - const Node* element_type = infer(ctx, type->payload.ptr_type.pointed_type, NULL); - if (!element_type) - element_type = unit_type(a); - return ptr_type(a, (PtrType) { .pointed_type = element_type, 
.address_space = type->payload.ptr_type.address_space }); - } - default: return recreate_node_identity(&ctx->rewriter, type); - } -} - -static const Node* _infer_decl(Context* ctx, const Node* node) { - assert(is_declaration(node)); - const Node* already_done = search_processed(&ctx->rewriter, node); - if (already_done) - return already_done; - - if (lookup_annotation(node, "SkipOnInfer")) - return NULL; - - IrArena* a = ctx->rewriter.dst_arena; - switch (is_declaration(node)) { - case Function_TAG: { - Context body_context = *ctx; - - LARRAY(const Node*, nparams, node->payload.fun.params.count); - for (size_t i = 0; i < node->payload.fun.params.count; i++) { - const Variable* old_param = &node->payload.fun.params.nodes[i]->payload.var; - const Type* imported_param_type = infer(ctx, old_param->type, NULL); - nparams[i] = var(a, imported_param_type, old_param->name); - register_processed(&body_context.rewriter, node->payload.fun.params.nodes[i], nparams[i]); - } - - Nodes nret_types = annotate_all_types(a, infer_nodes(ctx, node->payload.fun.return_types), false); - Node* fun = function(ctx->rewriter.dst_module, nodes(a, node->payload.fun.params.count, nparams), string(a, node->payload.fun.name), infer_nodes(ctx, node->payload.fun.annotations), nret_types); - register_processed(&ctx->rewriter, node, fun); - fun->payload.fun.body = infer(&body_context, node->payload.fun.body, NULL); - return fun; - } - case Constant_TAG: { - const Constant* oconstant = &node->payload.constant; - const Type* imported_hint = infer(ctx, oconstant->type_hint, NULL); - const Node* instruction; - if (imported_hint) { - assert(is_data_type(imported_hint)); - instruction = infer(ctx, oconstant->instruction, qualified_type_helper(imported_hint, true)); - } else { - instruction = infer(ctx, oconstant->instruction, NULL); - } - imported_hint = get_unqualified_type(instruction->type); - - Node* nconstant = constant(ctx->rewriter.dst_module, infer_nodes(ctx, oconstant->annotations), imported_hint, 
oconstant->name); - register_processed(&ctx->rewriter, node, nconstant); - nconstant->payload.constant.instruction = instruction; - - return nconstant; - } - case GlobalVariable_TAG: { - const GlobalVariable* old_gvar = &node->payload.global_variable; - const Type* imported_ty = infer(ctx, old_gvar->type, NULL); - Node* ngvar = global_var(ctx->rewriter.dst_module, infer_nodes(ctx, old_gvar->annotations), imported_ty, old_gvar->name, old_gvar->address_space); - register_processed(&ctx->rewriter, node, ngvar); - - ngvar->payload.global_variable.init = infer(ctx, old_gvar->init, qualified_type_helper(imported_ty, true)); - return ngvar; - } - case NominalType_TAG: { - const NominalType* onom_type = &node->payload.nom_type; - Node* nnominal_type = nominal_type(ctx->rewriter.dst_module, infer_nodes(ctx, onom_type->annotations), onom_type->name); - register_processed(&ctx->rewriter, node, nnominal_type); - nnominal_type->payload.nom_type.body = infer(ctx, onom_type->body, NULL); - return nnominal_type; - } - case NotADeclaration: error("not a decl"); - } -} - -/// Like get_unqualified_type but won't error out if type wasn't qualified to begin with -static const Type* remove_uniformity_qualifier(const Node* type) { - if (is_value_type(type)) - return get_unqualified_type(type); - return type; -} - -static const Node* _infer_value(Context* ctx, const Node* node, const Type* expected_type) { - if (!node) return NULL; - - if (expected_type) { - assert(is_value_type(expected_type)); - } - - IrArena* a = ctx->rewriter.dst_arena; - switch (is_value(node)) { - case NotAValue: error(""); - case Variable_TAG: return find_processed(&ctx->rewriter, node); - case Value_ConstrainedValue_TAG: { - const Type* type = infer(ctx, node->payload.constrained.type, NULL); - bool expect_uniform = false; - if (expected_type) { - expect_uniform = deconstruct_qualified_type(&expected_type); - assert(is_subtype(expected_type, type)); - } - return infer(ctx, node->payload.constrained.value, 
qualified_type_helper(type, expect_uniform)); - } - case IntLiteral_TAG: { - if (expected_type) { - expected_type = remove_uniformity_qualifier(expected_type); - assert(expected_type->tag == Int_TAG); - assert(expected_type->payload.int_type.width == node->payload.int_literal.width); - } - return int_literal(a, (IntLiteral) { - .width = node->payload.int_literal.width, - .is_signed = node->payload.int_literal.is_signed, - .value = node->payload.int_literal.value}); - } - case UntypedNumber_TAG: { - char* endptr; - int64_t i = strtoll(node->payload.untyped_number.plaintext, &endptr, 10); - if (!expected_type) { - bool valid_int = *endptr == '\0'; - expected_type = valid_int ? int32_type(a) : fp32_type(a); - } - expected_type = remove_uniformity_qualifier(expected_type); - if (expected_type->tag == Int_TAG) { - // TODO chop off extra bits based on width ? - return int_literal(a, (IntLiteral) { - .width = expected_type->payload.int_type.width, - .is_signed = expected_type->payload.int_literal.is_signed, - .value = i - }); - } else if (expected_type->tag == Float_TAG) { - uint64_t v; - switch (expected_type->payload.float_type.width) { - case FloatTy16: - error("TODO: implement fp16 parsing"); - case FloatTy32: - assert(sizeof(float) == sizeof(uint32_t)); - float f = strtof(node->payload.untyped_number.plaintext, NULL); - memcpy(&v, &f, sizeof(uint32_t)); - break; - case FloatTy64: - assert(sizeof(double) == sizeof(uint64_t)); - double d = strtod(node->payload.untyped_number.plaintext, NULL); - memcpy(&v, &d, sizeof(uint64_t)); - break; - } - return float_literal(a, (FloatLiteral) {.value = v, .width = expected_type->payload.float_type.width}); - } - } - case FloatLiteral_TAG: { - if (expected_type) { - expected_type = remove_uniformity_qualifier(expected_type); - assert(expected_type->tag == Float_TAG); - assert(expected_type->payload.float_type.width == node->payload.float_literal.width); - } - return float_literal(a, (FloatLiteral) { .width = 
node->payload.float_literal.width, .value = node->payload.float_literal.value }); - } - case True_TAG: return true_lit(a); - case False_TAG: return false_lit(a); - case StringLiteral_TAG: return string_lit(a, (StringLiteral) { .string = string(a, node->payload.string_lit.string )}); - case RefDecl_TAG: - case FnAddr_TAG: return recreate_node_identity(&ctx->rewriter, node); - case Value_Undef_TAG: return recreate_node_identity(&ctx->rewriter, node); - case Value_Composite_TAG: { - const Node* elem_type = infer(ctx, node->payload.composite.type, NULL); - bool uniform = false; - if (elem_type && expected_type) { - assert(is_subtype(get_unqualified_type(expected_type), elem_type)); - } else if (expected_type) { - uniform = deconstruct_qualified_type(&elem_type); - elem_type = expected_type; - } - - Nodes omembers = node->payload.composite.contents; - LARRAY(const Node*, inferred, omembers.count); - if (elem_type) { - Nodes expected_members = get_composite_type_element_types(elem_type); - for (size_t i = 0; i < omembers.count; i++) - inferred[i] = infer(ctx, omembers.nodes[i], qualified_type(a, (QualifiedType) { .is_uniform = uniform, .type = expected_members.nodes[i] })); - } else { - for (size_t i = 0; i < omembers.count; i++) - inferred[i] = infer(ctx, omembers.nodes[i], NULL); - } - Nodes nmembers = nodes(a, omembers.count, inferred); - - // Composites are tuples by default - if (!elem_type) - elem_type = record_type(a, (RecordType) { .members = strip_qualifiers(a, get_values_types(a, nmembers)) }); - - return composite_helper(a, elem_type, nmembers); - } - case Value_Fill_TAG: { - const Node* composite_t = infer(ctx, node->payload.fill.type, NULL); - assert(composite_t); - bool uniform = false; - if (composite_t && expected_type) { - assert(is_subtype(get_unqualified_type(expected_type), composite_t)); - } else if (expected_type) { - uniform = deconstruct_qualified_type(&composite_t); - composite_t = expected_type; - } - assert(composite_t); - const Node* element_t 
= get_fill_type_element_type(composite_t); - const Node* value = infer(ctx, node->payload.fill.value, qualified_type(a, (QualifiedType) { .is_uniform = uniform, .type = element_t })); - return fill(a, (Fill) { .type = composite_t, .value = value }); - } - case Value_NullPtr_TAG: return recreate_node_identity(&ctx->rewriter, node); - } - SHADY_UNREACHABLE; -} - -static const Node* _infer_case(Context* ctx, const Node* node, const Node* expected) { - IrArena* a = ctx->rewriter.dst_arena; - assert(is_case(node)); - assert(expected); - Nodes inferred_arg_type = unwrap_multiple_yield_types(a, expected); - assert(inferred_arg_type.count == node->payload.case_.params.count || node->payload.case_.params.count == 0); - - Context body_context = *ctx; - LARRAY(const Node*, nparams, inferred_arg_type.count); - for (size_t i = 0; i < inferred_arg_type.count; i++) { - if (node->payload.case_.params.count == 0) { - // syntax sugar: make up a parameter if there was none - nparams[i] = var(a, inferred_arg_type.nodes[i], unique_name(a, "_")); - } else { - const Variable* old_param = &node->payload.case_.params.nodes[i]->payload.var; - // for the param type: use the inferred one if none is already provided - // if one is provided, check the inferred argument type is a subtype of the param type - const Type* param_type = infer(ctx, old_param->type, NULL); - // and do not use the provided param type if it is an untyped ptr - if (!param_type || param_type->tag != PtrType_TAG || param_type->payload.ptr_type.pointed_type) - param_type = inferred_arg_type.nodes[i]; - assert(is_subtype(param_type, inferred_arg_type.nodes[i])); - nparams[i] = var(a, param_type, old_param->name); - register_processed(&body_context.rewriter, node->payload.case_.params.nodes[i], nparams[i]); - } - } - - const Node* new_body = infer(&body_context, node->payload.case_.body, NULL); - return case_(a, nodes(a, inferred_arg_type.count, nparams), new_body); -} - -static const Node* _infer_basic_block(Context* ctx, 
const Node* node) { - assert(is_basic_block(node)); - IrArena* a = ctx->rewriter.dst_arena; - - Context body_context = *ctx; - LARRAY(const Node*, nparams, node->payload.basic_block.params.count); - for (size_t i = 0; i < node->payload.basic_block.params.count; i++) { - const Variable* old_param = &node->payload.basic_block.params.nodes[i]->payload.var; - // for the param type: use the inferred one if none is already provided - // if one is provided, check the inferred argument type is a subtype of the param type - const Type* param_type = infer(ctx, old_param->type, NULL); - assert(param_type); - nparams[i] = var(a, param_type, old_param->name); - register_processed(&body_context.rewriter, node->payload.basic_block.params.nodes[i], nparams[i]); - } - - Node* fn = (Node*) infer(ctx, node->payload.basic_block.fn, NULL); - Node* bb = basic_block(a, fn, nodes(a, node->payload.basic_block.params.count, nparams), node->payload.basic_block.name); - assert(bb); - register_processed(&ctx->rewriter, node, bb); - - bb->payload.basic_block.body = infer(&body_context, node->payload.basic_block.body, NULL); - return bb; -} - -static const Type* type_untyped_ptr(const Type* untyped_ptr_t, const Type* element_type) { - assert(element_type); - IrArena* a = untyped_ptr_t->arena; - assert(untyped_ptr_t->tag == PtrType_TAG); - const Type* typed_ptr_t = ptr_type(a, (PtrType) { .pointed_type = element_type, .address_space = untyped_ptr_t->payload.ptr_type.address_space }); - return typed_ptr_t; -} - -static const Node* reinterpret_cast_helper(BodyBuilder* bb, const Node* ptr, const Type* typed_ptr_t) { - IrArena* a = ptr->arena; - ptr = gen_reinterpret_cast(bb, typed_ptr_t, ptr); - return ptr; -} - -static void fix_source_pointer(BodyBuilder* bb, const Node** operand, const Type* element_type) { - IrArena* a = element_type->arena; - const Type* original_operand_t = get_unqualified_type((*operand)->type); - assert(original_operand_t->tag == PtrType_TAG); - if 
(is_physical_ptr_type(original_operand_t)) { - // typed loads - normalise to typed ptrs instead by generating an extra cast! - const Type *ptr_type = original_operand_t; - ptr_type = type_untyped_ptr(ptr_type, element_type); - *operand = reinterpret_cast_helper(bb, *operand, ptr_type); - } else { - // we can't insert a cast but maybe we can make this work - do { - const Type* pointee = get_pointer_type_element(get_unqualified_type((*operand)->type)); - if (pointee == element_type) - return; - pointee = get_maybe_nominal_type_body(pointee); - if (pointee->tag == RecordType_TAG) { - *operand = gen_lea(bb, *operand, int32_literal(a, 0), singleton(int32_literal(a, 0))); - continue; - } - // TODO arrays - assert(false); - } while(true); - } -} - -static const Node* _infer_primop(Context* ctx, const Node* node, const Type* expected_type) { - assert(node->tag == PrimOp_TAG); - IrArena* a = ctx->rewriter.dst_arena; - - for (size_t i = 0; i < node->payload.prim_op.type_arguments.count; i++) - assert(node->payload.prim_op.type_arguments.nodes[i] && is_type(node->payload.prim_op.type_arguments.nodes[i])); - for (size_t i = 0; i < node->payload.prim_op.operands.count; i++) - assert(node->payload.prim_op.operands.nodes[i] && is_value(node->payload.prim_op.operands.nodes[i])); - - Nodes old_type_args = node->payload.prim_op.type_arguments; - Nodes type_args = infer_nodes(ctx, old_type_args); - Nodes old_operands = node->payload.prim_op.operands; - - BodyBuilder* bb = begin_body(a); - Op op = node->payload.prim_op.op; - LARRAY(const Node*, new_operands, old_operands.count); - Nodes input_types = empty(a); - switch (node->payload.prim_op.op) { - case push_stack_op: { - assert(old_operands.count == 1); - assert(type_args.count == 1); - const Type* element_type = type_args.nodes[0]; - assert(is_data_type(element_type)); - new_operands[0] = infer(ctx, old_operands.nodes[0], qualified_type_helper(element_type, false)); - goto rebuild; - } - case pop_stack_op: { - 
assert(old_operands.count == 0); - assert(type_args.count == 1); - const Type* element_type = type_args.nodes[0]; - assert(is_data_type(element_type)); - //new_inputs_scratch[0] = element_type; - goto rebuild; - } - case load_op: { - assert(old_operands.count == 1); - assert(type_args.count <= 1); - new_operands[0] = infer(ctx, old_operands.nodes[0], NULL); - if (type_args.count == 1) { - fix_source_pointer(bb, &new_operands[0], first(type_args)); - type_args = empty(a); - } - goto rebuild; - } - case store_op: { - assert(old_operands.count == 2); - assert(type_args.count <= 1); - new_operands[0] = infer(ctx, old_operands.nodes[0], NULL); - if (type_args.count == 1) { - fix_source_pointer(bb, &new_operands[0], first(type_args)); - type_args = empty(a); - } - const Type* ptr_type = get_unqualified_type(new_operands[0]->type); - assert(ptr_type->tag == PtrType_TAG); - const Type* element_t = ptr_type->payload.ptr_type.pointed_type; - assert(element_t); - new_operands[1] = infer(ctx, old_operands.nodes[1], qualified_type_helper(element_t, false)); - goto rebuild; - } - case alloca_op: { - assert(type_args.count == 1); - assert(old_operands.count == 0); - const Type* element_type = type_args.nodes[0]; - assert(is_type(element_type)); - assert(is_data_type(element_type)); - goto rebuild; - } - case reinterpret_op: - case convert_op: { - new_operands[0] = infer(ctx, old_operands.nodes[0], NULL); - const Type* src_pointer_type = get_unqualified_type(new_operands[0]->type); - const Type* old_dst_pointer_type = first(old_type_args); - const Type* dst_pointer_type = first(type_args); - - if (is_generic_ptr_type(src_pointer_type) != is_generic_ptr_type(dst_pointer_type)) - op = convert_op; - - if (old_dst_pointer_type->tag == PtrType_TAG && !old_dst_pointer_type->payload.ptr_type.pointed_type) { - const Type* element_type = uint8_type(a); - if (src_pointer_type->tag == PtrType_TAG && src_pointer_type->payload.ptr_type.pointed_type) { - // element_type = infer(ctx, 
old_src_pointer_type->payload.ptr_type.pointed_type, NULL); - element_type = src_pointer_type->payload.ptr_type.pointed_type; - } - dst_pointer_type = type_untyped_ptr(dst_pointer_type, element_type); - type_args = change_node_at_index(a, type_args, 0, dst_pointer_type); - } - - goto rebuild; - } - case lea_op: { - assert(old_operands.count >= 2); - assert(type_args.count <= 1); - new_operands[0] = infer(ctx, old_operands.nodes[0], NULL); - new_operands[1] = infer(ctx, old_operands.nodes[1], NULL); - for (size_t i = 2; i < old_operands.count; i++) { - new_operands[i] = infer(ctx, old_operands.nodes[i], NULL); - } - - const Type* src_ptr = remove_uniformity_qualifier(new_operands[0]->type); - const Type* base_datatype = src_ptr; - assert(base_datatype->tag == PtrType_TAG); - AddressSpace as = get_pointer_type_address_space(base_datatype); - bool was_untyped = false; - if (type_args.count == 1) { - was_untyped = true; - base_datatype = type_untyped_ptr(base_datatype, first(type_args)); - new_operands[0] = reinterpret_cast_helper(bb, new_operands[0], base_datatype); - type_args = empty(a); - } - - Nodes new_ops = nodes(a, old_operands.count, new_operands); - - const Node* offset = new_operands[1]; - const IntLiteral* offset_lit = resolve_to_int_literal(offset); - if ((!offset_lit || offset_lit->value) != 0 && base_datatype->tag != ArrType_TAG) { - warn_print("LEA used on a pointer to a non-array type!\n"); - const Type* arrayed_src_t = ptr_type(a, (PtrType) { - .address_space = as, - .pointed_type = arr_type(a, (ArrType) { - .element_type = get_pointer_type_element(base_datatype), - .size = NULL - }), - }); - const Node* cast_base = gen_reinterpret_cast(bb, arrayed_src_t, first(new_ops)); - Nodes final_lea_ops = mk_nodes(a, cast_base, offset, int32_literal(a, 0)); - final_lea_ops = concat_nodes(a, final_lea_ops, nodes(a, old_operands.count - 2, new_operands + 2)); - new_ops = final_lea_ops; - } - - const Node* result = first(bind_instruction(bb, prim_op(a, (PrimOp) { 
- .op = lea_op, - .type_arguments = empty(a), - .operands = new_ops - }))); - - if (was_untyped && is_physical_as(get_pointer_type_address_space(src_ptr))) { - const Type* result_t = type_untyped_ptr(base_datatype, unit_type(a)); - result = gen_reinterpret_cast(bb, result_t, result); - } - - return yield_values_and_wrap_in_block(bb, singleton(result)); - } - case empty_mask_op: - case subgroup_active_mask_op: - case subgroup_elect_first_op: - input_types = nodes(a, 0, NULL); - break; - case subgroup_broadcast_first_op: - new_operands[0] = infer(ctx, old_operands.nodes[0], NULL); - goto rebuild; - case subgroup_ballot_op: - input_types = singleton(qualified_type_helper(bool_type(a), false)); - break; - case mask_is_thread_active_op: { - input_types = mk_nodes(a, qualified_type_helper(mask_type(a), false), qualified_type_helper(uint32_type(a), false)); - break; - } - case debug_printf_op: { - String lit = get_string_literal(a, old_operands.nodes[0]); - assert(lit && "debug_printf requires a string literal"); - new_operands[0] = string_lit_helper(a, lit); - for (size_t i = 1; i < old_operands.count; i++) - new_operands[i] = infer(ctx, old_operands.nodes[i], NULL); - goto rebuild; - } - default: { - for (size_t i = 0; i < old_operands.count; i++) { - new_operands[i] = old_operands.nodes[i] ? 
infer(ctx, old_operands.nodes[i], NULL) : NULL; - } - goto rebuild; - } - } - - assert(input_types.count == old_operands.count); - for (size_t i = 0; i < input_types.count; i++) - new_operands[i] = infer(ctx, old_operands.nodes[i], input_types.nodes[i]); - - rebuild: { - const Node* new_instruction = prim_op(a, (PrimOp) { - .op = op, - .type_arguments = type_args, - .operands = nodes(a, old_operands.count, new_operands) - }); - return bind_last_instruction_and_wrap_in_block(bb, new_instruction); - } -} - -static const Node* _infer_indirect_call(Context* ctx, const Node* node, const Type* expected_type) { - assert(node->tag == Call_TAG); - IrArena* a = ctx->rewriter.dst_arena; - - const Node* new_callee = infer(ctx, node->payload.call.callee, NULL); - assert(is_value(new_callee)); - LARRAY(const Node*, new_args, node->payload.call.args.count); - - const Type* callee_type = get_unqualified_type(new_callee->type); - if (callee_type->tag != PtrType_TAG) - error("functions are called through function pointers"); - callee_type = callee_type->payload.ptr_type.pointed_type; - - if (callee_type->tag != FnType_TAG) - error("Callees must have a function type"); - if (callee_type->payload.fn_type.param_types.count != node->payload.call.args.count) - error("Mismatched argument counts"); - for (size_t i = 0; i < node->payload.call.args.count; i++) { - const Node* arg = node->payload.call.args.nodes[i]; - assert(arg); - new_args[i] = infer(ctx, node->payload.call.args.nodes[i], callee_type->payload.fn_type.param_types.nodes[i]); - assert(new_args[i]->type); - } - - return call(a, (Call) { - .callee = new_callee, - .args = nodes(a, node->payload.call.args.count, new_args) - }); -} - -static const Node* _infer_if(Context* ctx, const Node* node, const Type* expected_type) { - assert(node->tag == If_TAG); - IrArena* a = ctx->rewriter.dst_arena; - const Node* condition = infer(ctx, node->payload.if_instr.condition, bool_type(a)); - - Nodes join_types = infer_nodes(ctx, 
node->payload.if_instr.yield_types); - Context infer_if_body_ctx = *ctx; - // When we infer the types of the arguments to a call to merge(), they are expected to be varying - Nodes expected_join_types = annotate_all_types(a, join_types, false); - infer_if_body_ctx.merge_types = &expected_join_types; - - const Node* true_body = infer(&infer_if_body_ctx, node->payload.if_instr.if_true, wrap_multiple_yield_types(a, nodes(a, 0, NULL))); - // don't allow seeing the variables made available in the true branch - infer_if_body_ctx.rewriter = ctx->rewriter; - const Node* false_body = node->payload.if_instr.if_false ? infer(&infer_if_body_ctx, node->payload.if_instr.if_false, wrap_multiple_yield_types(a, nodes(a, 0, NULL))) : NULL; - - return if_instr(a, (If) { - .yield_types = join_types, - .condition = condition, - .if_true = true_body, - .if_false = false_body, - }); -} - -static const Node* _infer_loop(Context* ctx, const Node* node, const Type* expected_type) { - assert(node->tag == Loop_TAG); - IrArena* a = ctx->rewriter.dst_arena; - Context loop_body_ctx = *ctx; - const Node* old_body = node->payload.loop_instr.body; - - Nodes old_params = get_abstraction_params(old_body); - Nodes old_params_types = get_variables_types(a, old_params); - Nodes new_params_types = infer_nodes(ctx, old_params_types); - - Nodes old_initial_args = node->payload.loop_instr.initial_args; - LARRAY(const Node*, new_initial_args, old_params.count); - for (size_t i = 0; i < old_params.count; i++) - new_initial_args[i] = infer(ctx, old_initial_args.nodes[i], new_params_types.nodes[i]); - - Nodes loop_yield_types = infer_nodes(ctx, node->payload.loop_instr.yield_types); - - loop_body_ctx.merge_types = NULL; - loop_body_ctx.break_types = &loop_yield_types; - loop_body_ctx.continue_types = &new_params_types; - - const Node* nbody = infer(&loop_body_ctx, old_body, wrap_multiple_yield_types(a, new_params_types)); - // TODO check new body params match continue types - - return loop_instr(a, (Loop) { - 
.yield_types = loop_yield_types, - .initial_args = nodes(a, old_params.count, new_initial_args), - .body = nbody, - }); -} - -static const Node* _infer_control(Context* ctx, const Node* node, const Type* expected_type) { - assert(node->tag == Control_TAG); - IrArena* a = ctx->rewriter.dst_arena; - - Nodes yield_types = infer_nodes(ctx, node->payload.control.yield_types); - - const Node* olam = node->payload.control.inside; - const Node* ojp = first(get_abstraction_params(olam)); - - Context joinable_ctx = *ctx; - const Type* jpt = join_point_type(a, (JoinPointType) { - .yield_types = yield_types - }); - jpt = qualified_type(a, (QualifiedType) { .is_uniform = true, .type = jpt }); - const Node* jp = var(a, jpt, ojp->payload.var.name); - register_processed(&ctx->rewriter, ojp, jp); - - const Node* nlam = case_(a, singleton(jp), infer(&joinable_ctx, get_abstraction_body(olam), NULL)); - - return control(a, (Control) { - .yield_types = yield_types, - .inside = nlam - }); -} - -static const Node* _infer_block(Context* ctx, const Node* node, const Type* expected_type) { - assert(node->tag == Block_TAG); - IrArena* a = ctx->rewriter.dst_arena; - - Context block_inside_ctx = *ctx; - Nodes nyield_types = infer_nodes(ctx, node->payload.block.yield_types); - block_inside_ctx.merge_types = &nyield_types; - const Node* olam = node->payload.block.inside; - const Node* nlam = case_(a, empty(a), infer(&block_inside_ctx, get_abstraction_body(olam), NULL)); - - return block(a, (Block) { - .yield_types = nyield_types, - .inside = nlam - }); -} - -static const Node* _infer_instruction(Context* ctx, const Node* node, const Type* expected_type) { - switch (is_instruction(node)) { - case PrimOp_TAG: return _infer_primop(ctx, node, expected_type); - case Call_TAG: return _infer_indirect_call(ctx, node, expected_type); - case If_TAG: return _infer_if (ctx, node, expected_type); - case Loop_TAG: return _infer_loop (ctx, node, expected_type); - case Match_TAG: error("TODO") - case 
Control_TAG: return _infer_control(ctx, node, expected_type); - case Block_TAG: return _infer_block (ctx, node, expected_type); - case Instruction_Comment_TAG: return recreate_node_identity(&ctx->rewriter, node); - case NotAnInstruction: error("not an instruction"); - } - SHADY_UNREACHABLE; -} - -static const Node* _infer_terminator(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; - switch (is_terminator(node)) { - case Terminator_LetMut_TAG: - case NotATerminator: assert(false); - case Let_TAG: { - const Node* otail = node->payload.let.tail; - Nodes annotated_types = get_variables_types(a, otail->payload.case_.params); - const Node* inferred_instruction = infer(ctx, node->payload.let.instruction, wrap_multiple_yield_types(a, annotated_types)); - Nodes inferred_yield_types = unwrap_multiple_yield_types(a, inferred_instruction->type); - for (size_t i = 0; i < inferred_yield_types.count; i++) { - assert(is_value_type(inferred_yield_types.nodes[i])); - } - const Node* inferred_tail = infer(ctx, otail, wrap_multiple_yield_types(a, inferred_yield_types)); - return let(a, inferred_instruction, inferred_tail); - } - case Return_TAG: { - const Node* imported_fn = infer(ctx, node->payload.fn_ret.fn, NULL); - Nodes return_types = imported_fn->payload.fun.return_types; - - const Nodes* old_values = &node->payload.fn_ret.args; - LARRAY(const Node*, nvalues, old_values->count); - for (size_t i = 0; i < old_values->count; i++) - nvalues[i] = infer(ctx, old_values->nodes[i], return_types.nodes[i]); - return fn_ret(a, (Return) { - .args = nodes(a, old_values->count, nvalues), - .fn = NULL - }); - } - case Jump_TAG: { - assert(is_basic_block(node->payload.jump.target)); - const Node* ntarget = infer(ctx, node->payload.jump.target, NULL); - Nodes param_types = get_variables_types(a, get_abstraction_params(ntarget)); - - LARRAY(const Node*, tmp, node->payload.jump.args.count); - for (size_t i = 0; i < node->payload.jump.args.count; i++) - tmp[i] = infer(ctx, 
node->payload.jump.args.nodes[i], param_types.nodes[i]); - - Nodes new_args = nodes(a, node->payload.jump.args.count, tmp); - - return jump(a, (Jump) { - .target = ntarget, - .args = new_args - }); - } - case Branch_TAG: - case Terminator_Switch_TAG: break; - case Terminator_TailCall_TAG: break; - case Terminator_Yield_TAG: { - const Nodes* expected_types = ctx->merge_types; - // TODO: block nodes should set merge types - assert(expected_types && "Merge terminator found but we're not within a suitable if instruction !"); - const Nodes* old_args = &node->payload.yield.args; - assert(expected_types->count == old_args->count); - LARRAY(const Node*, new_args, old_args->count); - for (size_t i = 0; i < old_args->count; i++) - new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]); - return yield(a, (Yield) { - .args = nodes(a, old_args->count, new_args) - }); - } - case MergeContinue_TAG: { - const Nodes* expected_types = ctx->continue_types; - assert(expected_types && "Merge terminator found but we're not within a suitable loop instruction !"); - const Nodes* old_args = &node->payload.merge_continue.args; - assert(expected_types->count == old_args->count); - LARRAY(const Node*, new_args, old_args->count); - for (size_t i = 0; i < old_args->count; i++) - new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]); - return merge_continue(a, (MergeContinue) { - .args = nodes(a, old_args->count, new_args) - }); - } - case MergeBreak_TAG: { - const Nodes* expected_types = ctx->break_types; - assert(expected_types && "Merge terminator found but we're not within a suitable loop instruction !"); - const Nodes* old_args = &node->payload.merge_break.args; - assert(expected_types->count == old_args->count); - LARRAY(const Node*, new_args, old_args->count); - for (size_t i = 0; i < old_args->count; i++) - new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]); - return merge_break(a, (MergeBreak) { - .args = nodes(a, 
old_args->count, new_args) - }); - } - case Unreachable_TAG: return unreachable(a); - case Terminator_Join_TAG: return join(a, (Join) { - .join_point = infer(ctx, node->payload.join.join_point, NULL), - .args = infer_nodes(ctx, node->payload.join.args), - }); - } - return recreate_node_identity(&ctx->rewriter, node); -} - -static const Node* process(Context* src_ctx, const Node* node) { - const Type* expect = src_ctx->expected_type; - Context ctx = *src_ctx; - ctx.expected_type = NULL; - - const Node* found = search_processed(&src_ctx->rewriter, node); - if (found) { - //if (expect) - // assert(is_subtype(expect, found->type)); - return found; - } - - if (is_type(node)) { - assert(expect == NULL); - return _infer_type(&ctx, node); - } else if (is_value(node)) { - const Node* value = _infer_value(&ctx, node, expect); - assert(is_value_type(value->type)); - return value; - }else if (is_instruction(node)) - return _infer_instruction(&ctx, node, expect); - else if (is_terminator(node)) { - assert(expect == NULL); - return _infer_terminator(&ctx, node); - } else if (is_declaration(node)) { - return _infer_decl(&ctx, node); - } else if (is_annotation(node)) { - assert(expect == NULL); - return _infer_annotation(&ctx, node); - } else if (is_case(node)) { - assert(expect != NULL); - return _infer_case(&ctx, node, expect); - } else if (is_basic_block(node)) { - return _infer_basic_block(&ctx, node); - } - assert(false); -} - -Module* infer_program(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - assert(!aconfig.check_types); - aconfig.check_types = true; - aconfig.untyped_ptrs = false; - aconfig.allow_fold = true; // TODO was moved here because a refactor, does this cause issues ? 
- IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - }; - ctx.rewriter.config.search_map = false; - ctx.rewriter.config.write_map = false; - ctx.rewriter.config.rebind_let = true; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - return dst; -} diff --git a/src/shady/passes/join_point_ops.h b/src/shady/passes/join_point_ops.h new file mode 100644 index 000000000..338cab249 --- /dev/null +++ b/src/shady/passes/join_point_ops.h @@ -0,0 +1,9 @@ +#ifndef SHADY_JOIN_POINT_OPS_H +#define SHADY_JOIN_POINT_OPS_H + +typedef enum { + ShadyOpDefaultJoinPoint, + ShadyOpCreateJoinPoint, +} ShadyJoinPointOpcodes; + +#endif diff --git a/src/shady/passes/lcssa.c b/src/shady/passes/lcssa.c index 4a5b7f968..fb4e25ffa 100644 --- a/src/shady/passes/lcssa.c +++ b/src/shady/passes/lcssa.c @@ -1,11 +1,14 @@ #include "shady/ir.h" +#include "shady/pass.h" -#include "../rewrite.h" -#include "../analysis/scope.h" +#include "shady/rewrite.h" +#include "../ir_private.h" +#include "../analysis/cfg.h" +#include "../analysis/scheduler.h" #include "../analysis/looptree.h" #include "../analysis/uses.h" #include "../analysis/leak.h" -#include "../analysis/free_variables.h" +#include "../analysis/free_frontier.h" #include "portability.h" #include "log.h" @@ -15,8 +18,8 @@ typedef struct Context_ { Rewriter rewriter; const CompilerConfig* config; const Node* current_fn; - Scope* scope; - const UsesMap* scope_uses; + CFG* cfg; + Scheduler* scheduler; LoopTree* loop_tree; struct Dict* lifted_arguments; } Context; @@ -39,64 +42,63 @@ static const LTNode* get_loop(const LTNode* n) { } static String loop_name(const LTNode* n) { - if (n && n->type == LF_HEAD && entries_count_list(n->cf_nodes) > 0) { - return get_abstraction_name(read_list(CFNode*, n->cf_nodes)[0]->node); + if (n && n->type == LF_HEAD && shd_list_count(n->cf_nodes) > 0) { + return 
shd_get_abstraction_name(shd_read_list(CFNode*, n->cf_nodes)[0]->node); } return ""; } -void find_liftable_loop_values(Context* ctx, const Node* old, Nodes* nparams, Nodes* lparams, Nodes* nargs) { +static void find_liftable_loop_values(Context* ctx, const Node* old, Nodes* nparams, Nodes* lparams, Nodes* nargs) { IrArena* a = ctx->rewriter.dst_arena; assert(old->tag == BasicBlock_TAG); - const LTNode* bb_loop = get_loop(looptree_lookup(ctx->loop_tree, old)); + const LTNode* bb_loop = get_loop(shd_loop_tree_lookup(ctx->loop_tree, old)); - *nparams = empty(a); - *lparams = empty(a); - *nargs = empty(a); + *nparams = shd_empty(a); + *lparams = shd_empty(a); + *nargs = shd_empty(a); - struct List* fvs = compute_free_variables(ctx->scope, old); - for (size_t i = 0; i < entries_count_list(fvs); i++) { - const Node* fv = read_list(const Node*, fvs)[i]; - const Node* defining_abs = get_binding_abstraction(ctx->scope_uses, fv); - const CFNode* defining_cf_node = scope_lookup(ctx->scope, defining_abs); + struct Dict* fvs = shd_free_frontier(ctx->scheduler, ctx->cfg, old); + const Node* fv; + for (size_t i = 0; shd_dict_iter(fvs, &i, &fv, NULL);) { + const CFNode* defining_cf_node = shd_schedule_instruction(ctx->scheduler, fv); assert(defining_cf_node); - const LTNode* defining_loop = get_loop(looptree_lookup(ctx->loop_tree, defining_cf_node->node)); + const LTNode* defining_loop = get_loop(shd_loop_tree_lookup(ctx->loop_tree, defining_cf_node->node)); if (!is_child(defining_loop, bb_loop)) { // that's it, that variable is leaking ! 
- debug_print("lcssa: %s~%d is used outside of the loop that defines it %s %s\n", get_value_name_safe(fv), fv->payload.var.id, loop_name(defining_loop), loop_name(bb_loop)); - const Node* narg = rewrite_node(&ctx->rewriter, fv); - const Node* nparam = var(a, narg->type, "lcssa_phi"); - *nparams = append_nodes(a, *nparams, nparam); - *lparams = append_nodes(a, *lparams, fv); - *nargs = append_nodes(a, *nargs, narg); + shd_log_fmt(DEBUGV, "lcssa: "); + shd_log_node(DEBUGV, fv); + shd_log_fmt(DEBUGV, " (%%%d) is used outside of the loop that defines it %s %s\n", fv->id, loop_name(defining_loop), loop_name(bb_loop)); + const Node* narg = shd_rewrite_node(&ctx->rewriter, fv); + const Node* nparam = param(a, narg->type, "lcssa_phi"); + *nparams = shd_nodes_append(a, *nparams, nparam); + *lparams = shd_nodes_append(a, *lparams, fv); + *nargs = shd_nodes_append(a, *nargs, narg); } } - destroy_list(fvs); + shd_destroy_dict(fvs); if (nparams->count > 0) - insert_dict(const Node*, Nodes, ctx->lifted_arguments, old, *nparams); + shd_dict_insert(const Node*, Nodes, ctx->lifted_arguments, old, *nparams); } -const Node* process_abstraction_body(Context* ctx, const Node* old, const Node* body) { +static const Node* process_abstraction_body(Context* ctx, const Node* old, const Node* body) { IrArena* a = ctx->rewriter.dst_arena; Context ctx2 = *ctx; ctx = &ctx2; - Node* nfn = (Node*) rewrite_node(&ctx->rewriter, ctx->current_fn); - - if (!ctx->scope) { - error_print("LCSSA: Trying to process an abstraction that's not part of a function ('%s')!", get_abstraction_name(old)); - log_module(ERROR, ctx->config, ctx->rewriter.src_module); - error_die(); + if (!ctx->cfg) { + shd_error_print("LCSSA: Trying to process an abstraction that's not part of a function ('%s')!", shd_get_abstraction_name(old)); + shd_log_module(ERROR, ctx->config, ctx->rewriter.src_module); + shd_error_die(); } - const CFNode* n = scope_lookup(ctx->scope, old); + const CFNode* n = shd_cfg_lookup(ctx->cfg, old); 
size_t children_count = 0; - LARRAY(const Node*, old_children, entries_count_list(n->dominates)); - for (size_t i = 0; i < entries_count_list(n->dominates); i++) { - CFNode* c = read_list(CFNode*, n->dominates)[i]; - if (is_case(c->node)) + LARRAY(const Node*, old_children, shd_list_count(n->dominates)); + for (size_t i = 0; i < shd_list_count(n->dominates); i++) { + CFNode* c = shd_read_list(CFNode*, n->dominates)[i]; + if (shd_cfg_is_node_structural_target(c)) continue; old_children[children_count++] = c->node; } @@ -107,32 +109,33 @@ const Node* process_abstraction_body(Context* ctx, const Node* old, const Node* for (size_t i = 0; i < children_count; i++) { Nodes nargs; find_liftable_loop_values(ctx, old_children[i], &new_params[i], &lifted_params[i], &nargs); - Nodes nparams = recreate_variables(&ctx->rewriter, get_abstraction_params(old_children[i])); - new_children[i] = basic_block(a, nfn, concat_nodes(a, nparams, new_params[i]), get_abstraction_name(old_children[i])); - register_processed(&ctx->rewriter, old_children[i], new_children[i]); - register_processed_list(&ctx->rewriter, get_abstraction_params(old_children[i]), nparams); - insert_dict(const Node*, Nodes, ctx->lifted_arguments, old_children[i], nargs); + Nodes nparams = shd_recreate_params(&ctx->rewriter, get_abstraction_params(old_children[i])); + new_children[i] = basic_block(a, shd_concat_nodes(a, nparams, new_params[i]), shd_get_abstraction_name(old_children[i])); + shd_register_processed(&ctx->rewriter, old_children[i], new_children[i]); + shd_register_processed_list(&ctx->rewriter, get_abstraction_params(old_children[i]), nparams); + shd_dict_insert(const Node*, Nodes, ctx->lifted_arguments, old_children[i], nargs); } - const Node* new = rewrite_node(&ctx->rewriter, body); + const Node* new = shd_rewrite_node(&ctx->rewriter, body); - ctx->rewriter.map = clone_dict(ctx->rewriter.map); + ctx->rewriter.map = shd_clone_dict(ctx->rewriter.map); for (size_t i = 0; i < children_count; i++) { for 
(size_t j = 0; j < lifted_params[i].count; j++) { - remove_dict(const Node*, ctx->rewriter.map, lifted_params[i].nodes[j]); + shd_dict_remove(const Node*, ctx->rewriter.map, lifted_params[i].nodes[j]); } - register_processed_list(&ctx->rewriter, lifted_params[i], new_params[i]); + shd_register_processed_list(&ctx->rewriter, lifted_params[i], new_params[i]); new_children[i]->payload.basic_block.body = process_abstraction_body(ctx, old_children[i], get_abstraction_body(old_children[i])); } - destroy_dict(ctx->rewriter.map); + shd_destroy_dict(ctx->rewriter.map); return new; } -const Node* process_node(Context* ctx, const Node* old) { - IrArena* a = ctx->rewriter.dst_arena; +static const Node* process_node(Context* ctx, const Node* old) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; switch (old->tag) { case NominalType_TAG: @@ -140,72 +143,65 @@ const Node* process_node(Context* ctx, const Node* old) { case Constant_TAG: { Context not_a_fn_ctx = *ctx; ctx = ¬_a_fn_ctx; - ctx->scope = NULL; - return recreate_node_identity(&ctx->rewriter, old); + ctx->cfg = NULL; + return shd_recreate_node(&ctx->rewriter, old); } case Function_TAG: { Context fn_ctx = *ctx; ctx = &fn_ctx; ctx->current_fn = old; - ctx->scope = new_scope(old); - ctx->scope_uses = create_uses_map(old, (NcDeclaration | NcType)); - ctx->loop_tree = build_loop_tree(ctx->scope); + ctx->cfg = build_fn_cfg(old); + ctx->scheduler = shd_new_scheduler(ctx->cfg); + ctx->loop_tree = shd_new_loop_tree(ctx->cfg); - Node* new = recreate_decl_header_identity(&ctx->rewriter, old); + Node* new = shd_recreate_node_head(&ctx->rewriter, old); new->payload.fun.body = process_abstraction_body(ctx, old, get_abstraction_body(old)); - destroy_loop_tree(ctx->loop_tree); - destroy_uses_map(ctx->scope_uses); - destroy_scope(ctx->scope); + shd_destroy_loop_tree(ctx->loop_tree); + shd_destroy_scheduler(ctx->scheduler); + shd_destroy_cfg(ctx->cfg); return new; } case Jump_TAG: { - Nodes nargs = 
rewrite_nodes(&ctx->rewriter, old->payload.jump.args); - Nodes* lifted_args = find_value_dict(const Node*, Nodes, ctx->lifted_arguments, old->payload.jump.target); + Jump payload = old->payload.jump; + Nodes nargs = shd_rewrite_nodes(&ctx->rewriter, old->payload.jump.args); + Nodes* lifted_args = shd_dict_find_value(const Node*, Nodes, ctx->lifted_arguments, old->payload.jump.target); if (lifted_args) { - nargs = concat_nodes(a, nargs, *lifted_args); + nargs = shd_concat_nodes(a, nargs, *lifted_args); } return jump(a, (Jump) { - .target = rewrite_node(&ctx->rewriter, old->payload.jump.target), - .args = nargs + .target = shd_rewrite_node(&ctx->rewriter, old->payload.jump.target), + .args = nargs, + .mem = shd_rewrite_node(r, payload.mem), }); } case BasicBlock_TAG: { assert(false); } - case Case_TAG: { - if (ctx->scope) { - Nodes nparams = recreate_variables(&ctx->rewriter, get_abstraction_params(old)); - register_processed_list(&ctx->rewriter, get_abstraction_params(old), nparams); - return case_(a, nparams, process_abstraction_body(ctx, old, get_abstraction_body(old))); - } - } default: break; } - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); } -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); -Module* lcssa(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lcssa(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process_node), 
.config = config, .current_fn = NULL, - .lifted_arguments = new_dict(const Node*, Nodes, (HashFn) hash_node, (CmpFn) compare_node) + .lifted_arguments = shd_new_dict(const Node*, Nodes, (HashFn) shd_hash_node, (CmpFn) shd_compare_node) }; - ctx.rewriter.config.fold_quote = false; - - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - destroy_dict(ctx.lifted_arguments); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_dict(ctx.lifted_arguments); return dst; } diff --git a/src/shady/passes/lift_everything.c b/src/shady/passes/lift_everything.c new file mode 100644 index 000000000..99ad3fc1c --- /dev/null +++ b/src/shady/passes/lift_everything.c @@ -0,0 +1,117 @@ +#include "shady/pass.h" + +#include "../ir_private.h" +#include "../analysis/cfg.h" +#include "../analysis/scheduler.h" +#include "../analysis/free_frontier.h" + +#include "dict.h" +#include "portability.h" + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +typedef struct { + Rewriter rewriter; + struct Dict* lift; + CFG* cfg; + Scheduler* scheduler; +} Context; + +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + switch (node->tag) { + case Function_TAG: { + Context fn_ctx = *ctx; + fn_ctx.cfg = build_fn_cfg(node); + fn_ctx.scheduler = shd_new_scheduler(fn_ctx.cfg); + + Node* new_fn = shd_recreate_node_head(r, node); + shd_recreate_node_body(&fn_ctx.rewriter, node, new_fn); + + shd_destroy_scheduler(fn_ctx.scheduler); + shd_destroy_cfg(fn_ctx.cfg); + return new_fn; + } + case BasicBlock_TAG: { + CFNode* n = shd_cfg_lookup(ctx->cfg, node); + if (shd_cfg_is_node_structural_target(n)) + break; + struct Dict* frontier = shd_free_frontier(ctx->scheduler, ctx->cfg, node); + // insert_dict(const Node*, Dict*, ctx->lift, node, frontier); + + Nodes additional_args = shd_empty(a); + Nodes new_params = shd_recreate_params(r, 
get_abstraction_params(node)); + shd_register_processed_list(r, get_abstraction_params(node), new_params); + size_t i = 0; + const Node* value; + + Context bb_ctx = *ctx; + bb_ctx.rewriter = shd_create_children_rewriter(&ctx->rewriter); + + while (shd_dict_iter(frontier, &i, &value, NULL)) { + if (is_value(value)) { + additional_args = shd_nodes_append(a, additional_args, value); + const Type* t = shd_rewrite_node(r, value->type); + const Node* p = param(a, t, NULL); + new_params = shd_nodes_append(a, new_params, p); + shd_register_processed(&bb_ctx.rewriter, value, p); + } + } + + shd_destroy_dict(frontier); + shd_dict_insert(const Node*, Nodes, ctx->lift, node, additional_args); + Node* new_bb = basic_block(a, new_params, shd_get_abstraction_name_unsafe(node)); + + Context* fn_ctx = ctx; + while (fn_ctx->rewriter.parent) { + Context* parent_ctx = (Context*) fn_ctx->rewriter.parent; + if (parent_ctx->cfg) + fn_ctx = parent_ctx; + else + break; + } + + shd_register_processed(&fn_ctx->rewriter, node, new_bb); + shd_set_abstraction_body(new_bb, shd_rewrite_node(&bb_ctx.rewriter, get_abstraction_body(node))); + shd_destroy_rewriter(&bb_ctx.rewriter); + return new_bb; + } + case Jump_TAG: { + Jump payload = node->payload.jump; + shd_rewrite_node(r, payload.target); + + Nodes* additional_args = shd_dict_find_value(const Node*, Nodes, ctx->lift, payload.target); + assert(additional_args); + return jump(a, (Jump) { + .mem = shd_rewrite_node(r, payload.mem), + .target = shd_rewrite_node(r, payload.target), + .args = shd_concat_nodes(a, shd_rewrite_nodes(r, payload.args), shd_rewrite_nodes(r, *additional_args)) + }); + } + default: break; + } + + return shd_recreate_node(&ctx->rewriter, node); +} + +Module* shd_pass_lift_everything(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + bool todo = true; + Module* dst; + while (todo) { + todo = false; + 
dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .lift = shd_new_dict(const Node*, Nodes, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_dict(ctx.lift); + shd_destroy_rewriter(&ctx.rewriter); + src = dst; + } + return dst; +} diff --git a/src/shady/passes/lift_indirect_targets.c b/src/shady/passes/lift_indirect_targets.c index 138d9c0be..9fca55226 100644 --- a/src/shady/passes/lift_indirect_targets.c +++ b/src/shady/passes/lift_indirect_targets.c @@ -1,35 +1,40 @@ -#include "shady/ir.h" +#include "join_point_ops.h" -#include "log.h" -#include "portability.h" -#include "list.h" -#include "dict.h" -#include "util.h" +#include "shady/pass.h" +#include "shady/visit.h" +#include "shady/ir/stack.h" +#include "shady/ir/ext.h" -#include "../type.h" -#include "../rewrite.h" #include "../ir_private.h" -#include "../transform/ir_gen_helpers.h" -#include "../analysis/scope.h" -#include "../analysis/free_variables.h" +#include "../analysis/cfg.h" #include "../analysis/uses.h" #include "../analysis/leak.h" +#include "../analysis/verify.h" +#include "../analysis/scheduler.h" +#include "../analysis/free_frontier.h" + +#include "log.h" +#include "portability.h" +#include "list.h" +#include "dict.h" +#include "util.h" #include -#include -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); typedef struct Context_ { Rewriter rewriter; - Scope* scope; - const UsesMap* scope_uses; + CFG* cfg; + const UsesMap* uses; struct Dict* lifted; bool disable_lowering; const CompilerConfig* config; + + bool* todo; } Context; static const Node* process_node(Context* ctx, const Node* node); @@ -37,187 +42,231 @@ static const Node* process_node(Context* ctx, const Node* node); typedef struct { const Node* old_cont; const Node* lifted_fn; - struct List* 
save_values; + Nodes save_values; } LiftedCont; #pragma GCC diagnostic error "-Wswitch" -static const Node* add_spill_instrs(Context* ctx, BodyBuilder* builder, struct List* spilled_vars) { - IrArena* a = ctx->rewriter.dst_arena; - - size_t recover_context_size = entries_count_list(spilled_vars); - for (size_t i = 0; i < recover_context_size; i++) { - const Node* ovar = read_list(const Node*, spilled_vars)[i]; - const Node* nvar = rewrite_node(&ctx->rewriter, ovar); +static const Node* add_spill_instrs(Context* ctx, BodyBuilder* builder, Nodes spilled_vars) { + for (size_t i = 0; i < spilled_vars.count; i++) { + const Node* ovar = spilled_vars.nodes[i]; + const Node* nvar = shd_rewrite_node(&ctx->rewriter, ovar); const Type* t = nvar->type; - deconstruct_qualified_type(&t); - assert(t->tag != PtrType_TAG || is_physical_as(t->payload.ptr_type.address_space)); - const Node* save_instruction = prim_op(a, (PrimOp) { - .op = push_stack_op, - .type_arguments = singleton(get_unqualified_type(nvar->type)), - .operands = singleton(nvar), - }); - bind_instruction(builder, save_instruction); + shd_deconstruct_qualified_type(&t); + assert(t->tag != PtrType_TAG || !t->payload.ptr_type.is_reference && "References cannot be spilled"); + shd_bld_stack_push_value(builder, nvar); } - const Node* sp = gen_primop_ce(builder, get_stack_pointer_op, 0, NULL); + return shd_bld_get_stack_size(builder); +} - return sp; +static Nodes set2nodes(IrArena* a, struct Dict* set) { + size_t count = shd_dict_count(set); + LARRAY(const Node*, tmp, count); + size_t i = 0, j = 0; + const Node* key; + while (shd_dict_iter(set, &i, &key, NULL)) { + tmp[j++] = key; + } + assert(j == count); + return shd_nodes(a, count, tmp); } -static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name) { - assert(is_basic_block(cont) || is_case(cont)); - LiftedCont** found = find_value_dict(const Node*, LiftedCont*, ctx->lifted, cont); +static LiftedCont* lambda_lift(Context* ctx, CFG* cfg, const 
Node* liftee) { + assert(is_basic_block(liftee)); + LiftedCont** found = shd_dict_find_value(const Node*, LiftedCont*, ctx->lifted, liftee); if (found) return *found; IrArena* a = ctx->rewriter.dst_arena; - Nodes oparams = get_abstraction_params(cont); - const Node* obody = get_abstraction_body(cont); + const Node* obody = get_abstraction_body(liftee); + String name = shd_get_abstraction_name_safe(liftee); + + Scheduler* scheduler = shd_new_scheduler(cfg); + struct Dict* frontier_set = shd_free_frontier(scheduler, cfg, liftee); + Nodes frontier = set2nodes(a, frontier_set); + shd_destroy_dict(frontier_set); + + size_t recover_context_size = frontier.count; - String name = is_basic_block(cont) ? format_string_arena(a->arena, "%s_%s", get_abstraction_name(cont->payload.basic_block.fn), get_abstraction_name(cont)) : unique_name(a, given_name); + shd_destroy_scheduler(scheduler); - // Compute the live stuff we'll need - Scope* scope = new_scope(cont); - struct List* recover_context = compute_free_variables(scope, cont); - size_t recover_context_size = entries_count_list(recover_context); - destroy_scope(scope); + Context lifting_ctx = *ctx; + lifting_ctx.rewriter = shd_create_decl_rewriter(&ctx->rewriter); + Rewriter* r = &lifting_ctx.rewriter; - debugv_print("free (spilled) variables at '%s': ", name); + Nodes ovariables = get_abstraction_params(liftee); + shd_debugv_print("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): ", shd_get_abstraction_name_safe(liftee), recover_context_size); for (size_t i = 0; i < recover_context_size; i++) { - const Node* item = read_list(const Node*, recover_context)[i]; - debugv_print(get_value_name_safe(item)); + const Node* item = frontier.nodes[i]; + if (!is_value(item)) { + //lambda_lift() + continue; + } + shd_debugv_print("%%%d", item->id); if (i + 1 < recover_context_size) - debugv_print(", "); + shd_debugv_print(", "); } - debugv_print("\n"); + shd_debugv_print("\n"); // Create and register new parameters for the 
lifted continuation - Nodes new_params = recreate_variables(&ctx->rewriter, oparams); + LARRAY(const Node*, new_params_arr, ovariables.count); + for (size_t i = 0; i < ovariables.count; i++) + new_params_arr[i] = param(a, shd_rewrite_node(&ctx->rewriter, ovariables.nodes[i]->type), shd_get_value_name_unsafe(ovariables.nodes[i])); + Nodes new_params = shd_nodes(a, ovariables.count, new_params_arr); LiftedCont* lifted_cont = calloc(sizeof(LiftedCont), 1); - lifted_cont->old_cont = cont; - lifted_cont->save_values = recover_context; - insert_dict(const Node*, LiftedCont*, ctx->lifted, cont, lifted_cont); + lifted_cont->old_cont = liftee; + lifted_cont->save_values = frontier; + shd_dict_insert(const Node*, LiftedCont*, ctx->lifted, liftee, lifted_cont); - Context lifting_ctx = *ctx; - lifting_ctx.rewriter = create_rewriter(ctx->rewriter.src_module, ctx->rewriter.dst_module, (RewriteNodeFn) process_node); - register_processed_list(&lifting_ctx.rewriter, oparams, new_params); + shd_register_processed_list(r, ovariables, new_params); - const Node* payload = var(a, qualified_type_helper(uint32_type(a), false), "sp"); + const Node* payload = param(a, shd_as_qualified_type(shd_uint32_type(a), false), "sp"); // Keep annotations the same - Nodes annotations = nodes(a, 0, NULL); - new_params = prepend_nodes(a, new_params, payload); - Node* new_fn = function(ctx->rewriter.dst_module, new_params, name, annotations, nodes(a, 0, NULL)); + Nodes annotations = shd_singleton(annotation(a, (Annotation) { .name = "Exported" })); + new_params = shd_nodes_prepend(a, new_params, payload); + Node* new_fn = function(ctx->rewriter.dst_module, new_params, name, annotations, shd_nodes(a, 0, NULL)); lifted_cont->lifted_fn = new_fn; // Recover that stuff inside the new body - BodyBuilder* bb = begin_body(a); - gen_primop(bb, set_stack_pointer_op, empty(a), singleton(payload)); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new_fn)); + shd_bld_set_stack_size(bb, payload); for 
(size_t i = recover_context_size - 1; i < recover_context_size; i--) { - const Node* ovar = read_list(const Node*, recover_context)[i]; - assert(ovar->tag == Variable_TAG); + const Node* ovar = frontier.nodes[i]; + // assert(ovar->tag == Variable_TAG); - const Type* value_type = rewrite_node(&ctx->rewriter, ovar->type); + const Type* value_type = shd_rewrite_node(r, ovar->type); - const Node* recovered_value = first(bind_instruction_named(bb, prim_op(a, (PrimOp) { - .op = pop_stack_op, - .type_arguments = singleton(get_unqualified_type(value_type)) - }), &ovar->payload.var.name)); + //String param_name = get_value_name_unsafe(ovar); + const Node* recovered_value = shd_bld_stack_pop_value(bb, shd_get_unqualified_type(value_type)); + //if (param_name) + // set_value_name(recovered_value, param_name); - if (is_qualified_type_uniform(ovar->type)) - recovered_value = first(bind_instruction_named(bb, prim_op(a, (PrimOp) { .op = subgroup_broadcast_first_op, .operands = singleton(recovered_value) }), &ovar->payload.var.name)); + if (shd_is_qualified_type_uniform(ovar->type)) + recovered_value = prim_op(a, (PrimOp) { .op = subgroup_assume_uniform_op, .operands = shd_singleton(recovered_value) }); - register_processed(&lifting_ctx.rewriter, ovar, recovered_value); + shd_register_processed(r, ovar, recovered_value); } - const Node* substituted = rewrite_node(&lifting_ctx.rewriter, obody); - //destroy_dict(lifting_ctx.rewriter.processed); - destroy_rewriter(&lifting_ctx.rewriter); + shd_register_processed(r, shd_get_abstraction_mem(liftee), shd_bb_mem(bb)); + shd_register_processed(r, liftee, new_fn); + const Node* substituted = shd_rewrite_node(r, obody); + shd_destroy_rewriter(r); assert(is_terminator(substituted)); - new_fn->payload.fun.body = finish_body(bb, substituted); + shd_set_abstraction_body(new_fn, shd_bld_finish(bb, substituted)); return lifted_cont; } static const Node* process_node(Context* ctx, const Node* node) { - const Node* found = 
search_processed(&ctx->rewriter, node); - if (found) return found; - - // TODO: share this code - if (is_declaration(node)) { - String name = get_decl_name(node); - Nodes decls = get_module_declarations(ctx->rewriter.dst_module); - for (size_t i = 0; i < decls.count; i++) { - if (strcmp(get_decl_name(decls.nodes[i]), name) == 0) - return decls.nodes[i]; - } - } - IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; - if (ctx->disable_lowering) - return recreate_node_identity(&ctx->rewriter, node); - - switch (node->tag) { + switch (is_declaration(node)) { case Function_TAG: { + while (ctx->rewriter.parent) + ctx = (Context*) ctx->rewriter.parent; + Context fn_ctx = *ctx; - fn_ctx.scope = new_scope(node); - fn_ctx.scope_uses = create_uses_map(node, (NcDeclaration | NcType)); + fn_ctx.cfg = build_fn_cfg(node); + fn_ctx.uses = shd_new_uses_map_fn(node, (NcDeclaration | NcType)); + fn_ctx.disable_lowering = shd_lookup_annotation(node, "Internal"); ctx = &fn_ctx; - Node* new = recreate_decl_header_identity(&ctx->rewriter, node); - recreate_decl_body_identity(&ctx->rewriter, node, new); + Node* new = shd_recreate_node_head(&ctx->rewriter, node); + shd_recreate_node_body(&ctx->rewriter, node, new); - destroy_uses_map(ctx->scope_uses); - destroy_scope(ctx->scope); + shd_destroy_uses_map(ctx->uses); + shd_destroy_cfg(ctx->cfg); return new; } - case Let_TAG: { - const Node* oinstruction = get_let_instruction(node); - if (oinstruction->tag == Control_TAG) { - const Node* oinside = oinstruction->payload.control.inside; - assert(is_case(oinside)); - if (!is_control_static(ctx->scope_uses, oinstruction) || ctx->config->hacks.force_join_point_lifting) { - const Node* otail = get_let_tail(node); - BodyBuilder* bb = begin_body(a); - LiftedCont* lifted_tail = lambda_lift(ctx, otail, unique_name(a, format_string_arena(a->arena, "post_control_%s", get_abstraction_name(ctx->scope->entry->node)))); - const Node* sp = add_spill_instrs(ctx, bb, lifted_tail->save_values); 
- const Node* tail_ptr = fn_addr_helper(a, lifted_tail->lifted_fn); - - const Node* jp = gen_primop_e(bb, create_joint_point_op, rewrite_nodes(&ctx->rewriter, oinstruction->payload.control.yield_types), mk_nodes(a, tail_ptr, sp)); - - return finish_body(bb, let(a, quote_helper(a, singleton(jp)), rewrite_node(&ctx->rewriter, oinside))); - } - } + default: + break; + } - return recreate_node_identity(&ctx->rewriter, node); + if (ctx->disable_lowering) + return shd_recreate_node(&ctx->rewriter, node); + + switch (node->tag) { + case Control_TAG: { + const Node* oinside = node->payload.control.inside; + if (!shd_is_control_static(ctx->uses, node) || ctx->config->hacks.force_join_point_lifting) { + *ctx->todo = true; + + const Node* otail = get_structured_construct_tail(node); + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, node->payload.control.mem)); + LiftedCont* lifted_tail = lambda_lift(ctx, ctx->cfg, otail); + const Node* sp = add_spill_instrs(ctx, bb, lifted_tail->save_values); + const Node* tail_ptr = fn_addr_helper(a, lifted_tail->lifted_fn); + + const Type* jp_type = join_point_type(a, (JoinPointType) { + .yield_types = shd_rewrite_nodes(&ctx->rewriter, node->payload.control.yield_types), + }); + const Node* jp = shd_bld_ext_instruction(bb, "shady.internal", ShadyOpCreateJoinPoint, + shd_as_qualified_type(jp_type, true), mk_nodes(a, tail_ptr, sp)); + // dumbass hack + jp = prim_op_helper(a, subgroup_assume_uniform_op, shd_empty(a), shd_singleton(jp)); + + shd_register_processed(r, shd_first(get_abstraction_params(oinside)), jp); + shd_register_processed(r, shd_get_abstraction_mem(oinside), shd_bb_mem(bb)); + shd_register_processed(r, oinside, NULL); + return shd_bld_finish(bb, shd_rewrite_node(&ctx->rewriter, get_abstraction_body(oinside))); + } + break; } - default: return recreate_node_identity(&ctx->rewriter, node); + default: break; } + return shd_recreate_node(&ctx->rewriter, node); } -Module* lift_indirect_targets(const CompilerConfig* config, 
Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node), - .lifted = new_dict(const Node*, LiftedCont*, (HashFn) hash_node, (CmpFn) compare_node), - .config = config, - }; - - rewrite_module(&ctx.rewriter); - - size_t iter = 0; - LiftedCont* lifted_cont; - while (dict_iter(ctx.lifted, &iter, NULL, &lifted_cont)) { - destroy_list(lifted_cont->save_values); - free(lifted_cont); +Module* shd_pass_lift_indirect_targets(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = NULL; + Module* dst; + + int round = 0; + while (true) { + shd_debugv_print("lift_indirect_target: round %d\n", round++); + IrArena* oa = a; + a = shd_new_ir_arena(&aconfig); + dst = shd_new_module(a, shd_module_get_name(src)); + bool todo = false; + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process_node), + .lifted = shd_new_dict(const Node*, LiftedCont*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .config = config, + + .todo = &todo + }; + + shd_rewrite_module(&ctx.rewriter); + + size_t iter = 0; + LiftedCont* lifted_cont; + while (shd_dict_iter(ctx.lifted, &iter, NULL, &lifted_cont)) { + free(lifted_cont); + } + shd_destroy_dict(ctx.lifted); + shd_destroy_rewriter(&ctx.rewriter); + shd_verify_module(config, dst); + src = dst; + if (oa) + shd_destroy_ir_arena(oa); + if (!todo) { + break; + } } - destroy_dict(ctx.lifted); - destroy_rewriter(&ctx.rewriter); + + // this will be safe now since we won't lift any more code after this pass + aconfig.optimisations.weaken_non_leaking_allocas = true; + IrArena* a2 = shd_new_ir_arena(&aconfig); + dst = shd_new_module(a2, shd_module_get_name(src)); + Rewriter r = shd_create_importer(src, dst); + shd_rewrite_module(&r); + 
shd_destroy_rewriter(&r); + shd_destroy_ir_arena(a); return dst; } diff --git a/src/shady/passes/lower_alloca.c b/src/shady/passes/lower_alloca.c new file mode 100644 index 000000000..df5ab2206 --- /dev/null +++ b/src/shady/passes/lower_alloca.c @@ -0,0 +1,178 @@ +#include "shady/pass.h" +#include "shady/visit.h" +#include "shady/ir/stack.h" +#include "shady/ir/cast.h" + +#include "../ir_private.h" + +#include "log.h" +#include "portability.h" +#include "list.h" +#include "dict.h" +#include "util.h" + +#include + +typedef struct Context_ { + Rewriter rewriter; + bool disable_lowering; + + const CompilerConfig* config; + struct Dict* prepared_offsets; + const Node* base_stack_addr_on_entry; + const Node* stack_size_on_entry; + size_t num_slots; + const Node* frame_size; + + const Type* stack_ptr_t; +} Context; + +typedef struct { + Visitor visitor; + Context* context; + BodyBuilder* bb; + Node* nom_t; + size_t num_slots; + struct List* members; + struct Dict* prepared_offsets; +} VContext; + +typedef struct { + size_t i; + const Node* offset; + const Type* type; + AddressSpace as; +} StackSlot; + +static void search_operand_for_alloca(VContext* vctx, const Node* node) { + IrArena* a = vctx->context->rewriter.dst_arena; + switch (node->tag) { + case StackAlloc_TAG: { + StackSlot* found = shd_dict_find_value(const Node*, StackSlot, vctx->prepared_offsets, node); + if (found) + break; + + const Type* element_type = shd_rewrite_node(&vctx->context->rewriter, node->payload.stack_alloc.type); + assert(shd_is_data_type(element_type)); + const Node* slot_offset = prim_op_helper(a, offset_of_op, shd_singleton(type_decl_ref_helper(a, vctx->nom_t)), shd_singleton(shd_int32_literal(a, shd_list_count(vctx->members)))); + shd_list_append(const Type*, vctx->members, element_type); + + StackSlot slot = { vctx->num_slots, slot_offset, element_type, AsPrivate }; + shd_dict_insert(const Node*, StackSlot, vctx->prepared_offsets, node, slot); + + vctx->num_slots++; + break; + } + 
default: break; + } + + shd_visit_node_operands(&vctx->visitor, ~NcMem, node); +} + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + Module* m = r->dst_module; + switch (node->tag) { + case Function_TAG: { + Node* fun = shd_recreate_node_head(&ctx->rewriter, node); + if (!node->payload.fun.body) + return fun; + + Context ctx2 = *ctx; + ctx2.disable_lowering = shd_lookup_annotation_with_string_payload(node, "DisablePass", "setup_stack_frames") || ctx->config->per_thread_stack_size == 0; + if (ctx2.disable_lowering) { + shd_set_abstraction_body(fun, shd_rewrite_node(&ctx2.rewriter, node->payload.fun.body)); + return fun; + } + + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); + ctx2.prepared_offsets = shd_new_dict(const Node*, StackSlot, (HashFn) shd_hash_node, (CmpFn) shd_compare_node); + ctx2.base_stack_addr_on_entry = shd_bld_get_stack_base_addr(bb); + ctx2.stack_size_on_entry = shd_bld_get_stack_size(bb); + shd_set_value_name((Node*) ctx2.stack_size_on_entry, "stack_size_before_alloca"); + + Node* nom_t = nominal_type(m, shd_empty(a), shd_format_string_arena(a->arena, "%s_stack_frame", shd_get_abstraction_name(node))); + VContext vctx = { + .visitor = { + .visit_node_fn = (VisitNodeFn) search_operand_for_alloca, + }, + .context = &ctx2, + .bb = bb, + .nom_t = nom_t, + .num_slots = 0, + .members = shd_new_list(const Node*), + .prepared_offsets = ctx2.prepared_offsets, + }; + shd_visit_function_bodies_rpo(&vctx.visitor, node); + + vctx.nom_t->payload.nom_type.body = record_type(a, (RecordType) { + .members = shd_nodes(a, vctx.num_slots, shd_read_list(const Node*, vctx.members)), + .names = shd_strings(a, 0, NULL), + .special = 0 + }); + shd_destroy_list(vctx.members); + ctx2.num_slots = vctx.num_slots; + ctx2.frame_size = prim_op_helper(a, size_of_op, shd_singleton(type_decl_ref_helper(a, 
vctx.nom_t)), shd_empty(a)); + ctx2.frame_size = shd_bld_convert_int_extend_according_to_src_t(bb, ctx->stack_ptr_t, ctx2.frame_size); + + // make sure to use the new mem from then on + shd_register_processed(r, shd_get_abstraction_mem(node), shd_bb_mem(bb)); + shd_set_abstraction_body(fun, shd_bld_finish(bb, shd_rewrite_node(&ctx2.rewriter, get_abstraction_body(node)))); + + shd_destroy_dict(ctx2.prepared_offsets); + return fun; + } + case StackAlloc_TAG: { + if (!ctx->disable_lowering) { + StackSlot* found_slot = shd_dict_find_value(const Node*, StackSlot, ctx->prepared_offsets, node); + if (!found_slot) { + shd_error_print("lower_alloca: failed to find a stack offset for "); + shd_log_node(ERROR, node); + shd_error_print(", most likely this means this alloca was not found in the shd_first block of a function.\n"); + shd_log_module(DEBUG, ctx->config, ctx->rewriter.src_module); + shd_error_die(); + } + + BodyBuilder* bb = shd_bld_begin_pseudo_instr(a, shd_rewrite_node(r, node->payload.stack_alloc.mem)); + if (!ctx->stack_size_on_entry) { + //String tmp_name = format_string_arena(a->arena, "stack_ptr_before_alloca_%s", get_abstraction_name(fun)); + assert(false); + } + + //const Node* lea_instr = prim_op_helper(a, lea_op, empty(a), mk_nodes(a, rewrite_node(&ctx->rewriter, first(node->payload.prim_op.operands)), found_slot->offset)); + const Node* converted_offset = shd_bld_convert_int_extend_according_to_dst_t(bb, ctx->stack_ptr_t, found_slot->offset); + const Node* slot = ptr_array_element_offset(a, (PtrArrayElementOffset) { .ptr = ctx->base_stack_addr_on_entry, .offset = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, ctx->stack_size_on_entry, converted_offset)) }); + const Node* ptr_t = ptr_type(a, (PtrType) { .pointed_type = found_slot->type, .address_space = found_slot->as }); + slot = shd_bld_reinterpret_cast(bb, ptr_t, slot); + //bool last = found_slot->i == ctx->num_slots - 1; + //if (last) { + const Node* updated_stack_ptr = prim_op_helper(a, add_op, 
shd_empty(a), mk_nodes(a, ctx->stack_size_on_entry, ctx->frame_size)); + shd_bld_set_stack_size(bb, updated_stack_ptr); + //} + + return shd_bld_to_instr_yield_values(bb, shd_singleton(slot)); + } + break; + } + default: break; + } + return shd_recreate_node(&ctx->rewriter, node); +} + +Module* shd_pass_lower_alloca(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .config = config, + .stack_ptr_t = int_type(a, (Int) { .is_signed = false, .width = IntTy32 }), + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/shady/passes/lower_callf.c b/src/shady/passes/lower_callf.c index 0a6c005b0..2abed36fe 100644 --- a/src/shady/passes/lower_callf.c +++ b/src/shady/passes/lower_callf.c @@ -1,12 +1,13 @@ -#include "passes.h" +#include "join_point_ops.h" + +#include "shady/pass.h" +#include "shady/ir/ext.h" +#include "shady/ir/annotation.h" +#include "shady/ir/function.h" -#include "../rewrite.h" -#include "../type.h" #include "log.h" #include "portability.h" -#include "../transform/ir_gen_helpers.h" - #include typedef uint32_t FnPtr; @@ -15,69 +16,87 @@ typedef struct Context_ { Rewriter rewriter; bool disable_lowering; - Node* self; const Node* return_jp; } Context; static const Node* lower_callf_process(Context* ctx, const Node* old) { - const Node* found = search_processed(&ctx->rewriter, old); - if (found) return found; IrArena* a = ctx->rewriter.dst_arena; Module* m = ctx->rewriter.dst_module; + Rewriter* r = &ctx->rewriter; if (old->tag == Function_TAG) { Context ctx2 = *ctx; - ctx2.disable_lowering = lookup_annotation(old, "Leaf"); + ctx2.disable_lowering = shd_lookup_annotation(old, "Leaf"); ctx2.return_jp = NULL; - Node* fun = 
NULL; - BodyBuilder* bb = begin_body(a); - if (!ctx2.disable_lowering) { + if (!ctx2.disable_lowering && get_abstraction_body(old)) { Nodes oparams = get_abstraction_params(old); - Nodes nparams = recreate_variables(&ctx->rewriter, oparams); - register_processed_list(&ctx->rewriter, oparams, nparams); + Nodes nparams = shd_recreate_params(&ctx->rewriter, oparams); + shd_register_processed_list(&ctx->rewriter, oparams, nparams); + + Nodes nannots = shd_rewrite_nodes(&ctx->rewriter, old->payload.fun.annotations); + + Node* prelude = case_(a, shd_empty(a)); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(prelude)); // Supplement an additional parameter for the join point const Type* jp_type = join_point_type(a, (JoinPointType) { - .yield_types = strip_qualifiers(a, rewrite_nodes(&ctx->rewriter, old->payload.fun.return_types)) + .yield_types = shd_strip_qualifiers(a, shd_rewrite_nodes(&ctx->rewriter, old->payload.fun.return_types)) }); - if (lookup_annotation_list(old->payload.fun.annotations, "EntryPoint")) { - ctx2.return_jp = gen_primop_e(bb, default_join_point_op, empty(a), empty(a)); + if (shd_lookup_annotation_list(old->payload.fun.annotations, "EntryPoint")) { + ctx2.return_jp = shd_bld_ext_instruction(bb, "shady.internal", ShadyOpDefaultJoinPoint, + shd_as_qualified_type(jp_type, true), shd_empty(a)); } else { - const Node* jp_variable = var(a, qualified_type_helper(jp_type, false), "return_jp"); - nparams = append_nodes(a, nparams, jp_variable); + const Node* jp_variable = param(a, shd_as_qualified_type(jp_type, false), "return_jp"); + nparams = shd_nodes_append(a, nparams, jp_variable); ctx2.return_jp = jp_variable; } - Nodes nannots = rewrite_nodes(&ctx->rewriter, old->payload.fun.annotations); - fun = function(ctx->rewriter.dst_module, nparams, get_abstraction_name(old), nannots, empty(a)); - ctx2.self = fun; - register_processed(&ctx->rewriter, old, fun); - } else - fun = recreate_decl_header_identity(&ctx->rewriter, old); + Node* fun = 
function(ctx->rewriter.dst_module, nparams, shd_get_abstraction_name(old), nannots, shd_empty(a)); + shd_register_processed(&ctx->rewriter, old, fun); + + shd_register_processed(&ctx2.rewriter, shd_get_abstraction_mem(old), shd_bb_mem(bb)); + shd_set_abstraction_body(prelude, shd_bld_finish(bb, shd_rewrite_node(&ctx2.rewriter, old->payload.fun.body))); + shd_set_abstraction_body(fun, jump_helper(a, shd_get_abstraction_mem(fun), prelude, shd_empty(a))); + return fun; + } + + Node* fun = shd_recreate_node_head(&ctx->rewriter, old); if (old->payload.fun.body) - fun->payload.fun.body = finish_body(bb, rewrite_node(&ctx2.rewriter, old->payload.fun.body)); - else - cancel_body(bb); + shd_set_abstraction_body(fun, shd_rewrite_node(&ctx2.rewriter, old->payload.fun.body)); return fun; } if (ctx->disable_lowering) - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); switch (old->tag) { + case FnType_TAG: { + Nodes param_types = shd_rewrite_nodes(r, old->payload.fn_type.param_types); + Nodes returned_types = shd_rewrite_nodes(&ctx->rewriter, old->payload.fn_type.return_types); + const Type* jp_type = qualified_type(a, (QualifiedType) { + .type = join_point_type(a, (JoinPointType) { .yield_types = shd_strip_qualifiers(a, returned_types) }), + .is_uniform = false + }); + param_types = shd_nodes_append(a, param_types, jp_type); + return fn_type(a, (FnType) { + .param_types = param_types, + .return_types = shd_empty(a), + }); + } case Return_TAG: { - Nodes nargs = rewrite_nodes(&ctx->rewriter, old->payload.fn_ret.args); + Nodes nargs = shd_rewrite_nodes(&ctx->rewriter, old->payload.fn_ret.args); const Node* return_jp = ctx->return_jp; if (return_jp) { - BodyBuilder* bb = begin_body(a); - return_jp = gen_primop_ce(bb, subgroup_broadcast_first_op, 1, (const Node* []) {return_jp}); + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, old->payload.fn_ret.mem)); + return_jp = prim_op_helper(a, subgroup_assume_uniform_op, 
shd_empty(a), shd_singleton(return_jp)); // Join up at the return address instead of returning - return finish_body(bb, join(a, (Join) { - .join_point = return_jp, - .args = nargs, + return shd_bld_finish(bb, join(a, (Join) { + .join_point = return_jp, + .args = nargs, + .mem = shd_bb_mem(bb), })); } else { assert(false); @@ -86,53 +105,57 @@ static const Node* lower_callf_process(Context* ctx, const Node* old) { // we convert calls to tail-calls within a control - only if the // call_indirect(...) to control(jp => save(jp); tailcall(...)) case Call_TAG: { - const Node* ocallee = old->payload.call.callee; + Call payload = old->payload.call; + const Node* ocallee = payload.callee; // if we know the callee and it's a leaf - then we don't change the call - if (ocallee->tag == FnAddr_TAG && lookup_annotation(ocallee->payload.fn_addr.fn, "Leaf")) + if (ocallee->tag == FnAddr_TAG && shd_lookup_annotation(ocallee->payload.fn_addr.fn, "Leaf")) break; const Type* ocallee_type = ocallee->type; - bool callee_uniform = deconstruct_qualified_type(&ocallee_type); - ocallee_type = get_pointee_type(a, ocallee_type); + bool callee_uniform = shd_deconstruct_qualified_type(&ocallee_type); + ocallee_type = shd_get_pointee_type(a, ocallee_type); assert(ocallee_type->tag == FnType_TAG); - Nodes returned_types = rewrite_nodes(&ctx->rewriter, ocallee_type->payload.fn_type.return_types); + Nodes returned_types = shd_rewrite_nodes(&ctx->rewriter, ocallee_type->payload.fn_type.return_types); // Rewrite the callee and its arguments - const Node* ncallee = rewrite_node(&ctx->rewriter, ocallee); - Nodes nargs = rewrite_nodes(&ctx->rewriter, old->payload.call.args); + const Node* ncallee = shd_rewrite_node(&ctx->rewriter, ocallee); + Nodes nargs = shd_rewrite_nodes(&ctx->rewriter, payload.args); // Create the body of the control that receives the appropriately typed join point const Type* jp_type = qualified_type(a, (QualifiedType) { - .type = join_point_type(a, (JoinPointType) { .yield_types = 
strip_qualifiers(a, returned_types) }), + .type = join_point_type(a, (JoinPointType) { .yield_types = shd_strip_qualifiers(a, returned_types) }), .is_uniform = false }); - const Node* jp = var(a, jp_type, "fn_return_point"); + const Node* jp = param(a, jp_type, "fn_return_point"); // Add that join point as the last argument to the newly made function - nargs = append_nodes(a, nargs, jp); + nargs = shd_nodes_append(a, nargs, jp); // the body of the control is just an immediate tail-call + Node* control_case = case_(a, shd_singleton(jp)); const Node* control_body = tail_call(a, (TailCall) { - .target = ncallee, + .callee = ncallee, .args = nargs, + .mem = shd_get_abstraction_mem(control_case), }); - const Node* control_lam = case_(a, nodes(a, 1, (const Node* []) {jp}), control_body); - return control(a, (Control) { .yield_types = strip_qualifiers(a, returned_types), .inside = control_lam }); + shd_set_abstraction_body(control_case, control_body); + BodyBuilder* bb = shd_bld_begin_pseudo_instr(a, shd_rewrite_node(r, payload.mem)); + return shd_bld_to_instr_yield_values(bb, shd_bld_control(bb, shd_strip_qualifiers(a, returned_types), control_case)); } default: break; } - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); } -Module* lower_callf(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_callf(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) lower_callf_process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) lower_callf_process), .disable_lowering = false, 
}; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_cf_instrs.c b/src/shady/passes/lower_cf_instrs.c index 4326627ca..86def4543 100644 --- a/src/shady/passes/lower_cf_instrs.c +++ b/src/shady/passes/lower_cf_instrs.c @@ -1,13 +1,11 @@ -#include "passes.h" +#include "shady/pass.h" + +#include "../analysis/cfg.h" #include "log.h" #include "portability.h" #include "dict.h" -#include "../type.h" -#include "../rewrite.h" -#include "../analysis/scope.h" - #include typedef struct Context_ { @@ -16,66 +14,78 @@ typedef struct Context_ { Node* current_fn; struct Dict* structured_join_tokens; - Scope* scope; - const Node* abs; + CFG* cfg; } Context; -static const Node* process_node(Context* ctx, const Node* node); +static const Node* process_node(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; -static const Node* process_let(Context* ctx, const Node* node) { - assert(node->tag == Let_TAG); - IrArena* a = ctx->rewriter.dst_arena; + Context sub_ctx = *ctx; + if (node->tag == Function_TAG) { + Node* fun = shd_recreate_node_head(&ctx->rewriter, node); + sub_ctx.disable_lowering = shd_lookup_annotation(fun, "Structured"); + sub_ctx.current_fn = fun; + sub_ctx.cfg = build_fn_cfg(node); + shd_set_abstraction_body(fun, shd_rewrite_node(&sub_ctx.rewriter, node->payload.fun.body)); + shd_destroy_cfg(sub_ctx.cfg); + return fun; + } else if (node->tag == Constant_TAG) { + sub_ctx.cfg = NULL; + sub_ctx.current_fn = NULL; + ctx = &sub_ctx; + } - const Node* old_instruction = node->payload.let.instruction; - const Node* new_instruction = NULL; - const Node* old_tail = node->payload.let.tail; - const Node* new_tail = NULL; + if (ctx->disable_lowering) + return shd_recreate_node(&ctx->rewriter, node); - switch (old_instruction->tag) { + switch (node->tag) { case If_TAG: { - bool has_false_branch = 
old_instruction->payload.if_instr.if_false; - Nodes yield_types = rewrite_nodes(&ctx->rewriter, old_instruction->payload.if_instr.yield_types); + If payload = node->payload.if_instr; + bool has_false_branch = payload.if_false; + Nodes yield_types = shd_rewrite_nodes(&ctx->rewriter, node->payload.if_instr.yield_types); + const Node* nmem = shd_rewrite_node(r, node->payload.if_instr.mem); const Type* jp_type = qualified_type(a, (QualifiedType) { .type = join_point_type(a, (JoinPointType) { .yield_types = yield_types }), .is_uniform = false, }); - const Node* join_point = var(a, jp_type, "if_join"); - Context join_context = *ctx; - Nodes jps = singleton(join_point); - insert_dict(const Node*, Nodes, ctx->structured_join_tokens, old_instruction, jps); + const Node* jp = param(a, jp_type, "if_join"); + Nodes jps = shd_singleton(jp); + shd_dict_insert(const Node*, Nodes, ctx->structured_join_tokens, node, jps); - Node* true_block = basic_block(a, ctx->current_fn, nodes(a, 0, NULL), unique_name(a, "if_true")); - join_context.abs = old_instruction->payload.if_instr.if_true; - true_block->payload.basic_block.body = rewrite_node(&join_context.rewriter, old_instruction->payload.if_instr.if_true->payload.case_.body); + const Node* true_block = shd_rewrite_node(r, payload.if_true); - Node* flse_block = basic_block(a, ctx->current_fn, nodes(a, 0, NULL), unique_name(a, "if_false")); + const Node* false_block; if (has_false_branch) { - join_context.abs = old_instruction->payload.if_instr.if_false; - flse_block->payload.basic_block.body = rewrite_node(&join_context.rewriter, old_instruction->payload.if_instr.if_false->payload.case_.body); + false_block = shd_rewrite_node(r, payload.if_false); } else { assert(yield_types.count == 0); - flse_block->payload.basic_block.body = join(a, (Join) { .join_point = join_point, .args = nodes(a, 0, NULL) }); + false_block = basic_block(a, shd_nodes(a, 0, NULL), shd_make_unique_name(a, "if_false")); + shd_set_abstraction_body((Node*) false_block, 
join(a, (Join) { .join_point = jp, .args = shd_nodes(a, 0, NULL), .mem = shd_get_abstraction_mem(false_block) })); } + Node* control_case = basic_block(a, shd_singleton(jp), NULL); const Node* control_body = branch(a, (Branch) { - .branch_condition = rewrite_node(&ctx->rewriter, old_instruction->payload.if_instr.condition), - .true_jump = jump_helper(a, true_block, empty(a)), - .false_jump = jump_helper(a, flse_block, empty(a)), + .condition = shd_rewrite_node(r, node->payload.if_instr.condition), + .true_jump = jump_helper(a, shd_get_abstraction_mem(control_case), true_block, shd_empty(a)), + .false_jump = jump_helper(a, shd_get_abstraction_mem(control_case), false_block, shd_empty(a)), + .mem = shd_get_abstraction_mem(control_case), }); + shd_set_abstraction_body(control_case, control_body); - const Node* control_lam = case_(a, nodes(a, 1, (const Node* []) {join_point}), control_body); - new_instruction = control(a, (Control) { .yield_types = yield_types, .inside = control_lam }); - break; + BodyBuilder* bb = shd_bld_begin(a, nmem); + Nodes results = shd_bld_control(bb, yield_types, control_case); + return shd_bld_finish(bb, jump_helper(a, shd_bb_mem(bb), shd_rewrite_node(r, payload.tail), results)); } // TODO: match case Loop_TAG: { - const Node* old_loop_body = old_instruction->payload.loop_instr.body; - assert(is_case(old_loop_body)); + Loop payload = node->payload.loop_instr; + const Node* old_loop_block = payload.body; - Nodes yield_types = rewrite_nodes(&ctx->rewriter, old_instruction->payload.loop_instr.yield_types); - Nodes param_types = rewrite_nodes(&ctx->rewriter, get_variables_types(a, old_loop_body->payload.case_.params)); - param_types = strip_qualifiers(a, param_types); + Nodes yield_types = shd_rewrite_nodes(&ctx->rewriter, node->payload.loop_instr.yield_types); + Nodes param_types = shd_rewrite_nodes(&ctx->rewriter, shd_get_param_types(a, get_abstraction_params(old_loop_block))); + param_types = shd_strip_qualifiers(a, param_types); const Type* 
break_jp_type = qualified_type(a, (QualifiedType) { .type = join_point_type(a, (JoinPointType) { .yield_types = yield_types }), @@ -85,196 +95,155 @@ static const Node* process_let(Context* ctx, const Node* node) { .type = join_point_type(a, (JoinPointType) { .yield_types = param_types }), .is_uniform = false, }); - const Node* break_point = var(a, break_jp_type, "loop_break_point"); - const Node* continue_point = var(a, continue_jp_type, "loop_continue_point"); - Context join_context = *ctx; + const Node* break_point = param(a, break_jp_type, "loop_break_point"); + const Node* continue_point = param(a, continue_jp_type, "loop_continue_point"); Nodes jps = mk_nodes(a, break_point, continue_point); - insert_dict(const Node*, Nodes, ctx->structured_join_tokens, old_instruction, jps); - - Nodes new_params = recreate_variables(&ctx->rewriter, old_loop_body->payload.case_.params); - Node* loop_body = basic_block(a, ctx->current_fn, new_params, unique_name(a, "loop_body")); - register_processed_list(&join_context.rewriter, old_loop_body->payload.case_.params, loop_body->payload.basic_block.params); + shd_dict_insert(const Node*, Nodes, ctx->structured_join_tokens, node, jps); - join_context.abs = old_loop_body; - const Node* inner_control_body = rewrite_node(&join_context.rewriter, old_loop_body->payload.case_.body); - const Node* inner_control_lam = case_(a, nodes(a, 1, (const Node* []) {continue_point}), inner_control_body); + Nodes new_params = shd_recreate_params(&ctx->rewriter, get_abstraction_params(old_loop_block)); + Node* loop_header_block = basic_block(a, new_params, shd_make_unique_name(a, "loop_header")); - BodyBuilder* bb = begin_body(a); - const Node* inner_control = control(a, (Control) { - .yield_types = param_types, - .inside = inner_control_lam, - }); - Nodes args = bind_instruction(bb, inner_control); + BodyBuilder* inner_bb = shd_bld_begin(a, shd_get_abstraction_mem(loop_header_block)); + Node* inner_control_case = case_(a, 
shd_singleton(continue_point)); + shd_set_abstraction_body(inner_control_case, jump_helper(a, shd_get_abstraction_mem(inner_control_case), + shd_rewrite_node(r, old_loop_block), new_params)); + Nodes args = shd_bld_control(inner_bb, param_types, inner_control_case); - // TODO let_in_block or use a Jump ! - loop_body->payload.basic_block.body = finish_body(bb, jump(a, (Jump) { .target = loop_body, .args = args })); + shd_set_abstraction_body(loop_header_block, shd_bld_finish(inner_bb, jump(a, (Jump) { .target = loop_header_block, .args = args, .mem = shd_bb_mem(inner_bb) }))); - const Node* initial_jump = jump(a, (Jump) { - .target = loop_body, - .args = rewrite_nodes(&ctx->rewriter, old_instruction->payload.loop_instr.initial_args), + Node* outer_control_case = case_(a, shd_singleton(break_point)); + const Node* first_iteration_jump = jump(a, (Jump) { + .target = loop_header_block, + .args = shd_rewrite_nodes(r, payload.initial_args), + .mem = shd_get_abstraction_mem(outer_control_case), }); - const Node* outer_body = case_(a, nodes(a, 1, (const Node* []) {break_point}), initial_jump); - new_instruction = control(a, (Control) { .yield_types = yield_types, .inside = outer_body }); - break; - } - default: - new_instruction = rewrite_node(&ctx->rewriter, old_instruction); - break; - } - - if (!new_tail) - new_tail = rewrite_node(&ctx->rewriter, old_tail); - - assert(new_instruction && new_tail); - return let(a, new_instruction, new_tail); -} - -static const Node* process_node(Context* ctx, const Node* node) { - const Node* already_done = search_processed(&ctx->rewriter, node); - if (already_done) - return already_done; - - IrArena* a = ctx->rewriter.dst_arena; - - Context sub_ctx = *ctx; - if (node->tag == Function_TAG) { - Node* fun = recreate_decl_header_identity(&ctx->rewriter, node); - sub_ctx.disable_lowering = lookup_annotation(fun, "Structured"); - sub_ctx.current_fn = fun; - sub_ctx.scope = new_scope(node); - sub_ctx.abs = node; - fun->payload.fun.body = 
rewrite_node(&sub_ctx.rewriter, node->payload.fun.body); - destroy_scope(sub_ctx.scope); - return fun; - } + shd_set_abstraction_body(outer_control_case, first_iteration_jump); - if (is_abstraction(node)) { - sub_ctx.abs = node; - ctx = &sub_ctx; - } - - if (ctx->disable_lowering) - return recreate_node_identity(&ctx->rewriter, node); - - CFNode* cfnode = ctx->scope ? scope_lookup(ctx->scope, ctx->abs) : NULL; - switch (node->tag) { - case Let_TAG: return process_let(ctx, node); - case Yield_TAG: { - if (!cfnode) - break; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + Nodes results = shd_bld_control(bb, yield_types, outer_control_case); + return shd_bld_finish(bb, jump_helper(a, shd_bb_mem(bb), shd_rewrite_node(r, payload.tail), results)); + } + case MergeSelection_TAG: { + MergeSelection payload = node->payload.merge_selection; + const Node* root_mem = shd_get_original_mem(payload.mem); + assert(root_mem->tag == AbsMem_TAG); + CFNode* cfnode = shd_cfg_lookup(ctx->cfg, root_mem->payload.abs_mem.abs); CFNode* dom = cfnode->idom; const Node* selection_instr = NULL; while (dom) { const Node* body = get_abstraction_body(dom->node); - if (body->tag == Let_TAG) { - const Node* instr = get_let_instruction(body); - if (instr->tag == If_TAG || instr->tag == Match_TAG) { - selection_instr = instr; - break; - } + if(body->tag == If_TAG || body->tag == Match_TAG) { + selection_instr = body; + break; } dom = dom->idom; } if (!selection_instr) { - error_print("Scoping error: Failed to find a dominating selection construct for "); - log_node(ERROR, node); - error_print(".\n"); - error_die(); + shd_error_print("Scoping error: Failed to find a dominating selection construct for "); + shd_log_node(ERROR, node); + shd_error_print(".\n"); + shd_error_die(); } - Nodes* jps = find_value_dict(const Node*, Nodes, ctx->structured_join_tokens, selection_instr); + Nodes* jps = shd_dict_find_value(const Node*, Nodes, ctx->structured_join_tokens, selection_instr); 
assert(jps && jps->count == 1); - const Node* jp = first(*jps); + const Node* jp = shd_first(*jps); assert(jp); + const Node* nmem = shd_rewrite_node(r, payload.mem); return join(a, (Join) { .join_point = jp, - .args = rewrite_nodes(&ctx->rewriter, node->payload.yield.args), + .args = shd_rewrite_nodes(&ctx->rewriter, payload.args), + .mem = nmem }); } case MergeContinue_TAG: { - assert(cfnode); + MergeContinue payload = node->payload.merge_continue; + const Node* root_mem = shd_get_original_mem(payload.mem); + assert(root_mem->tag == AbsMem_TAG); + CFNode* cfnode = shd_cfg_lookup(ctx->cfg, root_mem->payload.abs_mem.abs); CFNode* dom = cfnode->idom; - const Node* selection_instr = NULL; + const Node* loop_start = NULL; while (dom) { const Node* body = get_abstraction_body(dom->node); - if (body->tag == Let_TAG) { - const Node* instr = get_let_instruction(body); - if (instr->tag == Loop_TAG) { - selection_instr = instr; - break; - } + if (body->tag == Loop_TAG) { + loop_start = body; + break; } dom = dom->idom; } - if (!selection_instr) { - error_print("Scoping error: Failed to find a dominating selection construct for "); - log_node(ERROR, node); - error_print(".\n"); - error_die(); + if (!loop_start) { + shd_error_print("Scoping error: Failed to find a dominating loop construct for "); + shd_log_node(ERROR, node); + shd_error_print(".\n"); + shd_error_die(); } - Nodes* jps = find_value_dict(const Node*, Nodes, ctx->structured_join_tokens, selection_instr); + Nodes* jps = shd_dict_find_value(const Node*, Nodes, ctx->structured_join_tokens, loop_start); assert(jps && jps->count == 2); const Node* jp = jps->nodes[1]; assert(jp); + const Node* nmem = shd_rewrite_node(r, payload.mem); return join(a, (Join) { .join_point = jp, - .args = rewrite_nodes(&ctx->rewriter, node->payload.merge_continue.args), + .args = shd_rewrite_nodes(&ctx->rewriter, payload.args), + .mem = nmem, }); } case MergeBreak_TAG: { - assert(cfnode); + MergeBreak payload = node->payload.merge_break; 
+ const Node* root_mem = shd_get_original_mem(payload.mem); + assert(root_mem->tag == AbsMem_TAG); + CFNode* cfnode = shd_cfg_lookup(ctx->cfg, root_mem->payload.abs_mem.abs); CFNode* dom = cfnode->idom; - const Node* selection_instr = NULL; + const Node* loop_start = NULL; while (dom) { const Node* body = get_abstraction_body(dom->node); - if (body->tag == Let_TAG) { - const Node* instr = get_let_instruction(body); - if (instr->tag == Loop_TAG) { - selection_instr = instr; - break; - } + if (body->tag == Loop_TAG) { + loop_start = body; + break; } dom = dom->idom; } - if (!selection_instr) { - error_print("Scoping error: Failed to find a dominating selection construct for "); - log_node(ERROR, node); - error_print(".\n"); - error_die(); + if (!loop_start) { + shd_error_print("Scoping error: Failed to find a dominating loop construct for "); + shd_log_node(ERROR, node); + shd_error_print(".\n"); + shd_error_die(); } - Nodes* jps = find_value_dict(const Node*, Nodes, ctx->structured_join_tokens, selection_instr); + Nodes* jps = shd_dict_find_value(const Node*, Nodes, ctx->structured_join_tokens, loop_start); assert(jps && jps->count == 2); - const Node* jp = first(*jps); + const Node* jp = shd_first(*jps); assert(jp); + const Node* nmem = shd_rewrite_node(r, payload.mem); return join(a, (Join) { .join_point = jp, - .args = rewrite_nodes(&ctx->rewriter, node->payload.merge_break.args), + .args = shd_rewrite_nodes(&ctx->rewriter, payload.args), + .mem = nmem, }); } - default: return recreate_node_identity(&ctx->rewriter, node); + default: break; } + return shd_recreate_node(&ctx->rewriter, node); } -KeyHash hash_node(const Node**); -bool compare_node(const Node**, const Node**); +KeyHash shd_hash_node(const Node**); +bool shd_compare_node(const Node**, const Node**); -Module* lower_cf_instrs(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = 
new_module(a, get_module_name(src)); +Module* shd_pass_lower_cf_instrs(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node), - .structured_join_tokens = new_dict(const Node*, Nodes, (HashFn) hash_node, (CmpFn) compare_node), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process_node), + .structured_join_tokens = shd_new_dict(const Node*, Nodes, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), }; - ctx.rewriter.config.fold_quote = false; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - destroy_dict(ctx.structured_join_tokens); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_dict(ctx.structured_join_tokens); return dst; } diff --git a/src/shady/passes/lower_decay_ptrs.c b/src/shady/passes/lower_decay_ptrs.c index 4f1b185c8..512bc3c3f 100644 --- a/src/shady/passes/lower_decay_ptrs.c +++ b/src/shady/passes/lower_decay_ptrs.c @@ -1,12 +1,4 @@ -#include "passes.h" - -#include "log.h" -#include "portability.h" - -#include "../ir_private.h" -#include "../type.h" -#include "../rewrite.h" -#include "../transform/ir_gen_helpers.h" +#include "shady/pass.h" typedef struct { Rewriter rewriter; @@ -14,9 +6,6 @@ typedef struct { } Context; static const Node* process(Context* ctx, const Node* node) { - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - IrArena* arena = ctx->rewriter.dst_arena; switch (node->tag) { @@ -24,7 +13,7 @@ static const Node* process(Context* ctx, const Node* node) { const Node* arr_t = node->payload.ptr_type.pointed_type; if (arr_t->tag == ArrType_TAG && !arr_t->payload.arr_type.size) { return ptr_type(arena, (PtrType) { - .pointed_type = rewrite_node(&ctx->rewriter, 
arr_t->payload.arr_type.element_type), + .pointed_type = shd_rewrite_node(&ctx->rewriter, arr_t->payload.arr_type.element_type), .address_space = node->payload.ptr_type.address_space, }); } @@ -34,18 +23,18 @@ static const Node* process(Context* ctx, const Node* node) { } rebuild: - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_decay_ptrs(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_decay_ptrs(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config, }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_entrypoint_args.c b/src/shady/passes/lower_entrypoint_args.c index 591bdce64..745237eab 100644 --- a/src/shady/passes/lower_entrypoint_args.c +++ b/src/shady/passes/lower_entrypoint_args.c @@ -1,14 +1,15 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/memory_layout.h" +#include "shady/ir/function.h" +#include "shady/ir/debug.h" +#include "shady/ir/decl.h" +#include "shady/ir/annotation.h" +#include "shady/ir/mem.h" #include "portability.h" #include "log.h" #include "util.h" -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" -#include "../transform/memory_layout.h" - typedef struct { Rewriter rewriter; const CompilerConfig* config; @@ -17,10 +18,10 @@ typedef struct { static Node* 
rewrite_entry_point_fun(Context* ctx, const Node* node) { IrArena* a = ctx->rewriter.dst_arena; - Nodes annotations = rewrite_nodes(&ctx->rewriter, node->payload.fun.annotations); - Node* fun = function(ctx->rewriter.dst_module, empty(a), node->payload.fun.name, annotations, empty(a)); + Nodes annotations = shd_rewrite_nodes(&ctx->rewriter, node->payload.fun.annotations); + Node* fun = function(ctx->rewriter.dst_module, shd_empty(a), node->payload.fun.name, annotations, shd_empty(a)); - register_processed(&ctx->rewriter, node, fun); + shd_register_processed(&ctx->rewriter, node, fun); return fun; } @@ -32,18 +33,18 @@ static const Node* generate_arg_struct_type(Rewriter* rewriter, Nodes params) { LARRAY(String, names, params.count); for (int i = 0; i < params.count; ++i) { - const Type* type = rewrite_node(rewriter, params.nodes[i]->type); + const Type* type = shd_rewrite_node(rewriter, params.nodes[i]->type); - if (!deconstruct_qualified_type(&type)) - error("EntryPoint parameters must be uniform"); + if (!shd_deconstruct_qualified_type(&type)) + shd_error("EntryPoint parameters must be uniform"); types[i] = type; - names[i] = get_value_name_safe(params.nodes[i]); + names[i] = shd_get_value_name_safe(params.nodes[i]); } return record_type(a, (RecordType) { - .members = nodes(a, params.count, types), - .names = strings(a, params.count, names) + .members = shd_nodes(a, params.count, types), + .names = shd_strings(a, params.count, names) }); } @@ -52,57 +53,54 @@ static const Node* generate_arg_struct(Rewriter* rewriter, const Node* old_entry Nodes annotations = mk_nodes(a, annotation_value(a, (AnnotationValue) { .name = "EntryPointArgs", .value = fn_addr_helper(a, new_entry_point) })); const Node* type = generate_arg_struct_type(rewriter, old_entry_point->payload.fun.params); - String name = format_string_arena(a->arena, "__%s_args", old_entry_point->payload.fun.name); + String name = shd_fmt_string_irarena(a, "__%s_args", old_entry_point->payload.fun.name); Node* 
var = global_var(rewriter->dst_module, annotations, type, name, AsExternal); return ref_decl_helper(a, var); } -static const Node* rewrite_body(Context* ctx, const Node* old_entry_point, const Node* arg_struct) { +static const Node* rewrite_body(Context* ctx, const Node* old_entry_point, const Node* new, const Node* arg_struct) { IrArena* a = ctx->rewriter.dst_arena; - BodyBuilder* bb = begin_body(a); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new)); Nodes params = old_entry_point->payload.fun.params; for (int i = 0; i < params.count; ++i) { - const Node* addr = gen_lea(bb, arg_struct, int32_literal(a, 0), singleton(int32_literal(a, i))); - const Node* val = gen_load(bb, addr); - register_processed(&ctx->rewriter, params.nodes[i], val); + const Node* addr = lea_helper(a, arg_struct, shd_int32_literal(a, 0), shd_singleton(shd_int32_literal(a, i))); + const Node* val = shd_bld_load(bb, addr); + shd_register_processed(&ctx->rewriter, params.nodes[i], val); } - return finish_body(bb, rewrite_node(&ctx->rewriter, old_entry_point->payload.fun.body)); + shd_register_processed(&ctx->rewriter, shd_get_abstraction_mem(old_entry_point), shd_bb_mem(bb)); + return shd_bld_finish(bb, shd_rewrite_node(&ctx->rewriter, old_entry_point->payload.fun.body)); } static const Node* process(Context* ctx, const Node* node) { - if (!node) return NULL; - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - switch (node->tag) { case Function_TAG: - if (lookup_annotation(node, "EntryPoint") && node->payload.fun.params.count > 0) { + if (shd_lookup_annotation(node, "EntryPoint") && node->payload.fun.params.count > 0) { Node* new_entry_point = rewrite_entry_point_fun(ctx, node); const Node* arg_struct = generate_arg_struct(&ctx->rewriter, node, new_entry_point); - new_entry_point->payload.fun.body = rewrite_body(ctx, node, arg_struct); + shd_set_abstraction_body(new_entry_point, rewrite_body(ctx, node, new_entry_point, arg_struct)); return 
new_entry_point; } break; default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_entrypoint_args(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_entrypoint_args(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_fill.c b/src/shady/passes/lower_fill.c index eaab886e0..f3492c792 100644 --- a/src/shady/passes/lower_fill.c +++ b/src/shady/passes/lower_fill.c @@ -1,44 +1,42 @@ -#include "passes.h" +#include "shady/pass.h" + +#include "../ir_private.h" #include "log.h" #include "portability.h" -#include "../ir_private.h" -#include "../type.h" -#include "../rewrite.h" -#include "../transform/ir_gen_helpers.h" - typedef struct { Rewriter rewriter; } Context; static const Node* process(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; switch (node->tag) { case Fill_TAG: { - const Type* composite_t = rewrite_node(&ctx->rewriter, node->payload.fill.type); - size_t actual_size = get_int_literal_value(*resolve_to_int_literal(get_fill_type_size(composite_t)), false); - const Node* value = rewrite_node(&ctx->rewriter, node->payload.fill.value); + const Type* composite_t = shd_rewrite_node(r, 
node->payload.fill.type); + size_t actual_size = shd_get_int_literal_value(*shd_resolve_to_int_literal(shd_get_fill_type_size(composite_t)), false); + const Node* value = shd_rewrite_node(r, node->payload.fill.value); LARRAY(const Node*, copies, actual_size); for (size_t i = 0; i < actual_size; i++) { copies[i] = value; } - return composite_helper(a, composite_t, nodes(a, actual_size, copies)); + return composite_helper(a, composite_t, shd_nodes(a, actual_size, copies)); } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(r, node); } -Module* lower_fill(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_fill(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_generic_globals.c b/src/shady/passes/lower_generic_globals.c index 9fdc0efff..a8b5856f6 100644 --- a/src/shady/passes/lower_generic_globals.c +++ b/src/shady/passes/lower_generic_globals.c @@ -1,51 +1,53 @@ -#include "passes.h" +#include "shady/pass.h" + +#include "../ir_private.h" #include "log.h" #include "portability.h" -#include "../ir_private.h" -#include "../type.h" -#include "../rewrite.h" -#include "../transform/ir_gen_helpers.h" - typedef struct { Rewriter rewriter; } Context; static const Node* process(Context* ctx, const Node* node) { - 
IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; switch (node->tag) { + case RefDecl_TAG: { + // make sure we rewrite the decl first, and then look if it rewrote the ref to it! + shd_rewrite_node(r, node->payload.ref_decl.decl); + const Node** f = shd_search_processed(r, node); + if (f) return *f; + break; + } case GlobalVariable_TAG: { if (node->payload.global_variable.address_space == AsGeneric) { - AddressSpace dst_as = AsGlobalPhysical; - const Type* t = rewrite_node(&ctx->rewriter, node->payload.global_variable.type); - Node* new_global = global_var(ctx->rewriter.dst_module, rewrite_nodes(&ctx->rewriter, node->payload.global_variable.annotations), t, node->payload.global_variable.name, dst_as); + AddressSpace dst_as = AsGlobal; + const Type* t = shd_rewrite_node(&ctx->rewriter, node->payload.global_variable.type); + Node* new_global = global_var(ctx->rewriter.dst_module, shd_rewrite_nodes(&ctx->rewriter, node->payload.global_variable.annotations), t, node->payload.global_variable.name, dst_as); + shd_register_processed(&ctx->rewriter, node, new_global); const Type* dst_t = ptr_type(a, (PtrType) { .pointed_type = t, .address_space = AsGeneric }); - Nodes decl_annotations = singleton(annotation(a, (Annotation) { .name = "Generated" })); - Node* constant_decl = constant(ctx->rewriter.dst_module, decl_annotations, dst_t, - format_string_interned(a, "%s_generic", get_decl_name(node))); - const Node* result = constant_decl; - constant_decl->payload.constant.instruction = prim_op_helper(a, convert_op, singleton(dst_t), singleton(ref_decl_helper(a, new_global))); - register_processed(&ctx->rewriter, node, result); - new_global->payload.global_variable.init = rewrite_node(&ctx->rewriter, node->payload.global_variable.init); - return result; + const Node* converted = prim_op_helper(a, convert_op, shd_singleton(dst_t), shd_singleton(ref_decl_helper(a, new_global))); + shd_register_processed(&ctx->rewriter, 
ref_decl_helper(node->arena, node), converted); + return new_global; } + break; } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_generic_globals(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_generic_globals(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_generic_ptrs.c b/src/shady/passes/lower_generic_ptrs.c index 27aab375a..bfd3121f3 100644 --- a/src/shady/passes/lower_generic_ptrs.c +++ b/src/shady/passes/lower_generic_ptrs.c @@ -1,16 +1,14 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/cast.h" +#include "shady/ir/memory_layout.h" + +#include "../ir_private.h" #include "log.h" #include "portability.h" #include "util.h" #include "dict.h" -#include "../rewrite.h" -#include "../type.h" -#include "../ir_private.h" -#include "../transform/ir_gen_helpers.h" -#include "../transform/memory_layout.h" - #include typedef struct { @@ -20,9 +18,15 @@ typedef struct { const CompilerConfig* config; } Context; -static AddressSpace generic_ptr_tags[4] = { AsGlobalPhysical, AsSharedPhysical, AsSubgroupPhysical, AsPrivatePhysical }; +static AddressSpace generic_ptr_tags[8] = { + [0x0] = AsGlobal, + [0x1] = AsShared, + [0x2] = AsSubgroup, + [0x3] = 
AsPrivate, + [0x7] = AsGlobal +}; -static size_t generic_ptr_tag_bitwidth = 2; +static size_t generic_ptr_tag_bitwidth = 3; static AddressSpace get_addr_space_from_tag(size_t tag) { size_t max_tag = sizeof(generic_ptr_tags) / sizeof(generic_ptr_tags[0]); @@ -36,7 +40,7 @@ static uint64_t get_tag_for_addr_space(AddressSpace as) { if (generic_ptr_tags[i] == as) return (uint64_t) i; } - error("this address space can't be converted to generic"); + shd_error("address space '%s' can't be converted to generic", shd_get_address_space_name(as)); } static const Node* recover_full_pointer(Context* ctx, BodyBuilder* bb, uint64_t tag, const Node* nptr, const Type* element_type) { @@ -45,28 +49,28 @@ static const Node* recover_full_pointer(Context* ctx, BodyBuilder* bb, uint64_t const Node* generic_ptr_type = int_type(a, (Int) {.width = a->config.memory.ptr_size, .is_signed = false}); // first_non_tag_bit = nptr >> (64 - 2 - 1) - const Node* first_non_tag_bit = gen_primop_e(bb, rshift_logical_op, empty(a), mk_nodes(a, nptr, size_t_literal(a, get_type_bitwidth(generic_ptr_type) - generic_ptr_tag_bitwidth - 1))); + const Node* first_non_tag_bit = prim_op_helper(a, rshift_logical_op, shd_empty(a), mk_nodes(a, nptr, size_t_literal(a, shd_get_type_bitwidth(generic_ptr_type) - generic_ptr_tag_bitwidth - 1))); // first_non_tag_bit &= 1 - first_non_tag_bit = gen_primop_e(bb, and_op, empty(a), mk_nodes(a, first_non_tag_bit, size_t_literal(a, 1))); + first_non_tag_bit = prim_op_helper(a, and_op, shd_empty(a), mk_nodes(a, first_non_tag_bit, size_t_literal(a, 1))); // needs_sign_extension = first_non_tag_bit == 1 - const Node* needs_sign_extension = gen_primop_e(bb, eq_op, empty(a), mk_nodes(a, first_non_tag_bit, size_t_literal(a, 1))); + const Node* needs_sign_extension = prim_op_helper(a, eq_op, shd_empty(a), mk_nodes(a, first_non_tag_bit, size_t_literal(a, 1))); // sign_extension_patch = needs_sign_extension ? 
((1 << 2) - 1) << (64 - 2) : 0 - const Node* sign_extension_patch = gen_primop_e(bb, select_op, empty(a), mk_nodes(a, needs_sign_extension, size_t_literal(a, ((size_t) ((1 << max_tag) - 1)) << (get_type_bitwidth(generic_ptr_type) - generic_ptr_tag_bitwidth)), size_t_literal(a, 0))); + const Node* sign_extension_patch = prim_op_helper(a, select_op, shd_empty(a), mk_nodes(a, needs_sign_extension, size_t_literal(a, ((size_t) ((1 << max_tag) - 1)) << (shd_get_type_bitwidth(generic_ptr_type) - generic_ptr_tag_bitwidth)), size_t_literal(a, 0))); // patched_ptr = nptr & 0b00111 ... 111 - const Node* patched_ptr = gen_primop_e(bb, and_op, empty(a), mk_nodes(a, nptr, size_t_literal(a, SIZE_MAX >> generic_ptr_tag_bitwidth))); + const Node* patched_ptr = prim_op_helper(a, and_op, shd_empty(a), mk_nodes(a, nptr, size_t_literal(a, SIZE_MAX >> generic_ptr_tag_bitwidth))); // patched_ptr = patched_ptr | sign_extension_patch - patched_ptr = gen_primop_e(bb, or_op, empty(a), mk_nodes(a, patched_ptr, sign_extension_patch)); + patched_ptr = prim_op_helper(a, or_op, shd_empty(a), mk_nodes(a, patched_ptr, sign_extension_patch)); const Type* dst_ptr_t = ptr_type(a, (PtrType) { .pointed_type = element_type, .address_space = get_addr_space_from_tag(tag) }); - const Node* reinterpreted_ptr = gen_reinterpret_cast(bb, dst_ptr_t, patched_ptr); + const Node* reinterpreted_ptr = shd_bld_reinterpret_cast(bb, dst_ptr_t, patched_ptr); return reinterpreted_ptr; } static bool allowed(Context* ctx, AddressSpace as) { - if (as == AsGlobalPhysical && ctx->config->hacks.no_physical_global_ptrs) - return false; - if (as == AsSharedPhysical && !ctx->rewriter.dst_arena->config.allow_shared_memory) + // some tags aren't in use + if (as == AsGeneric) return false; - if (as == AsSubgroupPhysical && !ctx->rewriter.dst_arena->config.allow_subgroup_memory) + // if an address space is logical-only, or isn't allowed at all in the module, we can skip emitting a case for it. 
+ if (!ctx->rewriter.dst_arena->config.address_spaces[as].physical || !ctx->rewriter.dst_arena->config.address_spaces[as].allowed) return false; return true; } @@ -76,94 +80,105 @@ static const Node* get_or_make_access_fn(Context* ctx, WhichFn which, bool unifo IrArena* a = ctx->rewriter.dst_arena; String name; switch (which) { - case LoadFn: name = format_string_interned(a, "generated_load_Generic_%s", name_type_safe(a, t)); break; - case StoreFn: name = format_string_interned(a, "generated_store_Generic_%s", name_type_safe(a, t)); break; + case LoadFn: name = shd_fmt_string_irarena(a, "generated_load_Generic_%s%s", shd_get_type_name(a, t), uniform_ptr ? "_uniform" : ""); break; + case StoreFn: name = shd_fmt_string_irarena(a, "generated_store_Generic_%s", shd_get_type_name(a, t)); break; } - const Node** found = find_value_dict(String, const Node*, ctx->fns, name); + const Node** found = shd_dict_find_value(String, const Node*, ctx->fns, name); if (found) return *found; - const Node* ptr_param = var(a, qualified_type_helper(ctx->generic_ptr_type, false), "ptr"); + const Node* ptr_param = param(a, shd_as_qualified_type(ctx->generic_ptr_type, uniform_ptr), "ptr"); const Node* value_param; - Nodes params = singleton(ptr_param); - Nodes return_ts = empty(a); + Nodes params = shd_singleton(ptr_param); + Nodes return_ts = shd_empty(a); switch (which) { case LoadFn: - return_ts = singleton(qualified_type_helper(t, false)); + return_ts = shd_singleton(shd_as_qualified_type(t, uniform_ptr)); break; case StoreFn: - value_param = var(a, qualified_type_helper(t, false), "value"); - params = append_nodes(a, params, value_param); + value_param = param(a, shd_as_qualified_type(t, false), "value"); + params = shd_nodes_append(a, params, value_param); break; } - Node* new_fn = function(ctx->rewriter.dst_module, params, name, singleton(annotation(a, (Annotation) { .name = "Generated" })), return_ts); - insert_dict(String, const Node*, ctx->fns, name, new_fn); + Node* new_fn = 
function(ctx->rewriter.dst_module, params, name, mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }), annotation(a, (Annotation) { .name = "Leaf" })), return_ts); + shd_dict_insert(String, const Node*, ctx->fns, name, new_fn); size_t max_tag = sizeof(generic_ptr_tags) / sizeof(generic_ptr_tags[0]); switch (which) { case LoadFn: { + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new_fn)); + shd_bld_comment(bb, "Generated generic ptr store"); + begin_control_t r = shd_bld_begin_control(bb, shd_singleton(t)); + const Node* final_loaded_value = shd_first(r.results); + LARRAY(const Node*, literals, max_tag); - LARRAY(const Node*, cases, max_tag); + LARRAY(const Node*, jumps, max_tag); for (size_t tag = 0; tag < max_tag; tag++) { literals[tag] = size_t_literal(a, tag); if (!allowed(ctx, generic_ptr_tags[tag])) { - cases[tag] = case_(a, empty(a), unreachable(a)); + Node* tag_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(tag_case, unreachable(a, (Unreachable) { .mem = shd_get_abstraction_mem(tag_case) })); + jumps[tag] = jump_helper(a, shd_get_abstraction_mem(r.case_), tag_case, shd_empty(a)); continue; } - BodyBuilder* case_bb = begin_body(a); + Node* tag_case = case_(a, shd_empty(a)); + BodyBuilder* case_bb = shd_bld_begin(a, shd_get_abstraction_mem(tag_case)); const Node* reinterpreted_ptr = recover_full_pointer(ctx, case_bb, tag, ptr_param, t); - const Node* loaded_value = gen_load(case_bb, reinterpreted_ptr); - cases[tag] = case_(a, empty(a), finish_body(case_bb, yield(a, (Yield) { - .args = singleton(loaded_value), - }))); + const Node* loaded_value = shd_bld_load(case_bb, reinterpreted_ptr); + shd_set_abstraction_body(tag_case, shd_bld_join(case_bb, r.jp, shd_singleton(loaded_value))); + jumps[tag] = jump_helper(a, shd_get_abstraction_mem(r.case_), tag_case, shd_empty(a)); } - - BodyBuilder* bb = begin_body(a); - gen_comment(bb, "Generated generic ptr store"); // extracted_tag = nptr >> (64 - 2), for example - const Node* 
extracted_tag = gen_primop_e(bb, rshift_logical_op, empty(a), mk_nodes(a, ptr_param, size_t_literal(a, get_type_bitwidth(ctx->generic_ptr_type) - generic_ptr_tag_bitwidth))); + const Node* extracted_tag = prim_op_helper(a, rshift_logical_op, shd_empty(a), mk_nodes(a, ptr_param, size_t_literal(a, shd_get_type_bitwidth(ctx->generic_ptr_type) - generic_ptr_tag_bitwidth))); - const Node* loaded_value = first(bind_instruction(bb, match_instr(a, (Match) { - .inspect = extracted_tag, - .yield_types = singleton(t), - .literals = nodes(a, max_tag, literals), - .cases = nodes(a, max_tag, cases), - .default_case = case_(a, empty(a), unreachable(a)), - }))); - new_fn->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { .args = singleton(loaded_value), .fn = new_fn })); + Node* default_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(default_case, unreachable(a, (Unreachable) { .mem = shd_get_abstraction_mem(default_case) })); + shd_set_abstraction_body(r.case_, br_switch(a, (Switch) { + .mem = shd_get_abstraction_mem(r.case_), + .switch_value = extracted_tag, + .case_values = shd_nodes(a, max_tag, literals), + .case_jumps = shd_nodes(a, max_tag, jumps), + .default_jump = jump_helper(a, shd_get_abstraction_mem(r.case_), default_case, shd_empty(a)) + })); + shd_set_abstraction_body(new_fn, shd_bld_finish(bb, fn_ret(a, (Return) { .args = shd_singleton(final_loaded_value), .mem = shd_bb_mem(bb) }))); break; } case StoreFn: { + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new_fn)); + shd_bld_comment(bb, "Generated generic ptr store"); + begin_control_t r = shd_bld_begin_control(bb, shd_empty(a)); + LARRAY(const Node*, literals, max_tag); - LARRAY(const Node*, cases, max_tag); + LARRAY(const Node*, jumps, max_tag); for (size_t tag = 0; tag < max_tag; tag++) { literals[tag] = size_t_literal(a, tag); if (!allowed(ctx, generic_ptr_tags[tag])) { - cases[tag] = case_(a, empty(a), unreachable(a)); + Node* tag_case = case_(a, shd_empty(a)); + 
shd_set_abstraction_body(tag_case, unreachable(a, (Unreachable) { .mem = shd_get_abstraction_mem(tag_case) })); + jumps[tag] = jump_helper(a, shd_get_abstraction_mem(r.case_), tag_case, shd_empty(a)); continue; } - BodyBuilder* case_bb = begin_body(a); + Node* tag_case = case_(a, shd_empty(a)); + BodyBuilder* case_bb = shd_bld_begin(a, shd_get_abstraction_mem(tag_case)); const Node* reinterpreted_ptr = recover_full_pointer(ctx, case_bb, tag, ptr_param, t); - gen_store(case_bb, reinterpreted_ptr, value_param); - cases[tag] = case_(a, empty(a), finish_body(case_bb, yield(a, (Yield) { - .args = empty(a), - }))); + shd_bld_store(case_bb, reinterpreted_ptr, value_param); + shd_set_abstraction_body(tag_case, shd_bld_join(case_bb, r.jp, shd_empty(a))); + jumps[tag] = jump_helper(a, shd_get_abstraction_mem(r.case_), tag_case, shd_empty(a)); } - - BodyBuilder* bb = begin_body(a); - gen_comment(bb, "Generated generic ptr store"); // extracted_tag = nptr >> (64 - 2), for example - const Node* extracted_tag = gen_primop_e(bb, rshift_logical_op, empty(a), mk_nodes(a, ptr_param, size_t_literal(a, get_type_bitwidth(ctx->generic_ptr_type) - generic_ptr_tag_bitwidth))); + const Node* extracted_tag = prim_op_helper(a, rshift_logical_op, shd_empty(a), mk_nodes(a, ptr_param, size_t_literal(a, shd_get_type_bitwidth(ctx->generic_ptr_type) - generic_ptr_tag_bitwidth))); - bind_instruction(bb, match_instr(a, (Match) { - .inspect = extracted_tag, - .yield_types = empty(a), - .literals = nodes(a, max_tag, literals), - .cases = nodes(a, max_tag, cases), - .default_case = case_(a, empty(a), unreachable(a)), + Node* default_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(default_case, unreachable(a, (Unreachable) { .mem = shd_get_abstraction_mem(default_case) })); + shd_set_abstraction_body(r.case_, br_switch(a, (Switch) { + .mem = shd_get_abstraction_mem(r.case_), + .switch_value = extracted_tag, + .case_values = shd_nodes(a, max_tag, literals), + .case_jumps = shd_nodes(a, max_tag, 
jumps), + .default_jump = jump_helper(a, shd_get_abstraction_mem(r.case_), default_case, shd_empty(a)) })); - new_fn->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { .args = empty(a), .fn = new_fn })); + shd_set_abstraction_body(new_fn, shd_bld_finish(bb, fn_ret(a, (Return) { .args = shd_empty(a), .mem = shd_bb_mem(bb) }))); break; } } @@ -171,10 +186,8 @@ static const Node* get_or_make_access_fn(Context* ctx, WhichFn which, bool unifo } static const Node* process(Context* ctx, const Node* old) { - const Node* found = search_processed(&ctx->rewriter, old); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; Module* m = ctx->rewriter.dst_module; size_t max_tag = sizeof(generic_ptr_tags) / sizeof(generic_ptr_tags[0]); @@ -190,54 +203,60 @@ static const Node* process(Context* ctx, const Node* old) { return size_t_literal(a, 0); break; } + case Load_TAG: { + Load payload = old->payload.load; + const Type* old_ptr_t = payload.ptr->type; + bool u = shd_deconstruct_qualified_type(&old_ptr_t); + u &= shd_is_addr_space_uniform(a, old_ptr_t->payload.ptr_type.address_space); + if (old_ptr_t->payload.ptr_type.address_space == AsGeneric) { + return call(a, (Call) { + .callee = fn_addr_helper(a, get_or_make_access_fn(ctx, LoadFn, u, shd_rewrite_node(r, old_ptr_t->payload.ptr_type.pointed_type))), + .args = shd_singleton(shd_rewrite_node(&ctx->rewriter, payload.ptr)), + .mem = shd_rewrite_node(r, payload.mem) + }); + } + break; + } + case Store_TAG: { + Store payload = old->payload.store; + const Type* old_ptr_t = payload.ptr->type; + shd_deconstruct_qualified_type(&old_ptr_t); + if (old_ptr_t->payload.ptr_type.address_space == AsGeneric) { + return call(a, (Call) { + .callee = fn_addr_helper(a, get_or_make_access_fn(ctx, StoreFn, false, shd_rewrite_node(r, old_ptr_t->payload.ptr_type.pointed_type))), + .args = mk_nodes(a, shd_rewrite_node(r, payload.ptr), shd_rewrite_node(r, payload.value)), + 
.mem = shd_rewrite_node(r, payload.mem), + }); + } + break; + } case PrimOp_TAG: { switch (old->payload.prim_op.op) { case convert_op: { - const Node* old_src = first(old->payload.prim_op.operands); + const Node* old_src = shd_first(old->payload.prim_op.operands); const Type* old_src_t = old_src->type; - deconstruct_qualified_type(&old_src_t); - const Type* old_dst_t = first(old->payload.prim_op.type_arguments); + shd_deconstruct_qualified_type(&old_src_t); + const Type* old_dst_t = shd_first(old->payload.prim_op.type_arguments); if (old_dst_t->tag == PtrType_TAG && old_dst_t->payload.ptr_type.address_space == AsGeneric) { // cast _into_ generic AddressSpace src_as = old_src_t->payload.ptr_type.address_space; size_t tag = get_tag_for_addr_space(src_as); - BodyBuilder* bb = begin_body(a); - String x = format_string_arena(a->arena, "Generated generic ptr convert src %d tag %d", src_as, tag); - gen_comment(bb, x); - const Node* src_ptr = rewrite_node(&ctx->rewriter, old_src); - const Node* generic_ptr = gen_reinterpret_cast(bb, ctx->generic_ptr_type, src_ptr); + BodyBuilder* bb = shd_bld_begin_pure(a); + // TODO: find another way to annotate this ? + // String x = format_string_arena(a->arena, "Generated generic ptr convert src %d tag %d", src_as, tag); + // gen_comment(bb, x); + const Node* src_ptr = shd_rewrite_node(&ctx->rewriter, old_src); + const Node* generic_ptr = shd_bld_reinterpret_cast(bb, ctx->generic_ptr_type, src_ptr); const Node* ptr_mask = size_t_literal(a, (UINT64_MAX >> (uint64_t) (generic_ptr_tag_bitwidth))); // generic_ptr = generic_ptr & 0x001111 ... 
111 - generic_ptr = gen_primop_e(bb, and_op, empty(a), mk_nodes(a, generic_ptr, ptr_mask)); - const Node* shifted_tag = size_t_literal(a, (tag << (uint64_t) (get_type_bitwidth(ctx->generic_ptr_type) - generic_ptr_tag_bitwidth))); + generic_ptr = prim_op_helper(a, and_op, shd_empty(a), mk_nodes(a, generic_ptr, ptr_mask)); + const Node* shifted_tag = size_t_literal(a, (tag << (uint64_t) (shd_get_type_bitwidth(ctx->generic_ptr_type) - generic_ptr_tag_bitwidth))); // generic_ptr = generic_ptr | 01000000 ... 000 - generic_ptr = gen_primop_e(bb, or_op, empty(a), mk_nodes(a, generic_ptr, shifted_tag)); - return yield_values_and_wrap_in_block(bb, singleton(generic_ptr)); + generic_ptr = prim_op_helper(a, or_op, shd_empty(a), mk_nodes(a, generic_ptr, shifted_tag)); + return shd_bld_to_instr_yield_values(bb, shd_singleton(generic_ptr)); } else if (old_src_t->tag == PtrType_TAG && old_src_t->payload.ptr_type.address_space == AsGeneric) { // cast _from_ generic - error("TODO"); - } - break; - } - case load_op: { - const Type* old_ptr_t = first(old->payload.prim_op.operands)->type; - deconstruct_qualified_type(&old_ptr_t); - if (old_ptr_t->payload.ptr_type.address_space == AsGeneric) { - return call(a, (Call) { - .callee = fn_addr_helper(a, get_or_make_access_fn(ctx, LoadFn, false, rewrite_node(&ctx->rewriter, old_ptr_t->payload.ptr_type.pointed_type))), - .args = singleton(rewrite_node(&ctx->rewriter, first(old->payload.prim_op.operands))), - }); - } - break; - } - case store_op: { - const Type* old_ptr_t = first(old->payload.prim_op.operands)->type; - deconstruct_qualified_type(&old_ptr_t); - if (old_ptr_t->payload.ptr_type.address_space == AsGeneric) { - return call(a, (Call) { - .callee = fn_addr_helper(a, get_or_make_access_fn(ctx, StoreFn, false, rewrite_node(&ctx->rewriter, old_ptr_t->payload.ptr_type.pointed_type))), - .args = rewrite_nodes(&ctx->rewriter, old->payload.prim_op.operands), - }); + shd_error("TODO"); } break; } @@ -247,24 +266,24 @@ static const Node* 
process(Context* ctx, const Node* old) { default: break; } - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); } -KeyHash hash_string(const char** string); -bool compare_string(const char** a, const char** b); +KeyHash shd_hash_string(const char** string); +bool shd_compare_string(const char** a, const char** b); -Module* lower_generic_ptrs(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_generic_ptrs(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - .fns = new_dict(String, const Node*, (HashFn) hash_string, (CmpFn) compare_string), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .fns = shd_new_dict(String, const Node*, (HashFn) shd_hash_string, (CmpFn) shd_compare_string), .generic_ptr_type = int_type(a, (Int) {.width = a->config.memory.ptr_size, .is_signed = false}), .config = config, }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - destroy_dict(ctx.fns); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_dict(ctx.fns); return dst; } diff --git a/src/shady/passes/lower_inclusive_scan.c b/src/shady/passes/lower_inclusive_scan.c new file mode 100644 index 000000000..39b26fb30 --- /dev/null +++ b/src/shady/passes/lower_inclusive_scan.c @@ -0,0 +1,135 @@ +#include "shady/pass.h" +#include "shady/ir/type.h" + +#include "log.h" +#include "portability.h" + +#include + +#include + +typedef struct { + Rewriter rewriter; +} Context; + +typedef struct { + SpvOp spv_op; + Op scalar; + const Node* 
(*I)(IrArena*, const Type* t); +} GroupOp; + +static const Node* zero(IrArena* a, const Type* t) { + t = shd_get_unqualified_type(t); + assert(t->tag == Int_TAG); + Int t_payload = t->payload.int_type; + IntLiteral lit = { + .width = t_payload.width, + .is_signed = t_payload.is_signed, + .value = 0 + }; + return int_literal(a, lit); +} + +static const Node* one(IrArena* a, const Type* t) { + t = shd_get_unqualified_type(t); + assert(t->tag == Int_TAG); + Int t_payload = t->payload.int_type; + IntLiteral lit = { + .width = t_payload.width, + .is_signed = t_payload.is_signed, + .value = 1 + }; + return int_literal(a, lit); +} + +static GroupOp group_operations[] = { + { SpvOpGroupIAdd, add_op }, + { SpvOpGroupFAdd, add_op }, + { SpvOpGroupFMin, min_op }, + { SpvOpGroupUMin, min_op }, + { SpvOpGroupSMin, min_op }, + { SpvOpGroupFMax, max_op, }, + { SpvOpGroupUMax, max_op }, + { SpvOpGroupSMax, max_op }, + { SpvOpGroupNonUniformBallotBitCount, /* todo */ }, + { SpvOpGroupNonUniformIAdd, add_op }, + { SpvOpGroupNonUniformFAdd, add_op }, + { SpvOpGroupNonUniformIMul, mul_op }, + { SpvOpGroupNonUniformFMul, mul_op }, + { SpvOpGroupNonUniformSMin, min_op }, + { SpvOpGroupNonUniformUMin, min_op }, + { SpvOpGroupNonUniformFMin, min_op }, + { SpvOpGroupNonUniformSMax, max_op }, + { SpvOpGroupNonUniformUMax, max_op }, + { SpvOpGroupNonUniformFMax, max_op }, + { SpvOpGroupNonUniformBitwiseAnd, and_op }, + { SpvOpGroupNonUniformBitwiseOr, or_op }, + { SpvOpGroupNonUniformBitwiseXor, xor_op }, + { SpvOpGroupNonUniformLogicalAnd, and_op }, + { SpvOpGroupNonUniformLogicalOr, or_op }, + { SpvOpGroupNonUniformLogicalXor, xor_op }, + { SpvOpGroupIAddNonUniformAMD, /* todo: map to std */ }, + { SpvOpGroupFAddNonUniformAMD, /* todo: map to std */ }, + { SpvOpGroupFMinNonUniformAMD, /* todo: map to std */ }, + { SpvOpGroupUMinNonUniformAMD, /* todo: map to std */ }, + { SpvOpGroupSMinNonUniformAMD, /* todo: map to std */ }, + { SpvOpGroupFMaxNonUniformAMD, /* todo: map to std */ }, + { 
SpvOpGroupUMaxNonUniformAMD, /* todo: map to std */ }, + { SpvOpGroupSMaxNonUniformAMD, /* todo: map to std */ }, + { SpvOpGroupIMulKHR, /* todo: map to std */ }, + { SpvOpGroupFMulKHR, /* todo: map to std */ }, + { SpvOpGroupBitwiseAndKHR, /* todo: map to std */ }, + { SpvOpGroupBitwiseOrKHR, /* todo: map to std */ }, + { SpvOpGroupBitwiseXorKHR, /* todo: map to std */ }, + { SpvOpGroupLogicalAndKHR, /* todo: map to std */ }, + { SpvOpGroupLogicalOrKHR, /* todo: map to std */ }, + { SpvOpGroupLogicalXorKHR, /* todo: map to std */ }, +}; + +enum { + NumGroupOps = sizeof(group_operations) / sizeof(group_operations[0]) +}; + +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + switch (node->tag) { + case ExtInstr_TAG: { + ExtInstr payload = node->payload.ext_instr; + if (strcmp(payload.set, "spirv.core") == 0) { + for (size_t i = 0; i < NumGroupOps; i++) { + if (payload.opcode == group_operations[i].spv_op) { + if (shd_get_int_value(payload.operands.nodes[1], false) == SpvGroupOperationInclusiveScan) { + //assert(group_operations[i].I); + IrArena* oa = node->arena; + payload.operands = shd_change_node_at_index(oa, payload.operands, 1, shd_uint32_literal(a, SpvGroupOperationReduce)); + const Node* new = shd_recreate_node(r, ext_instr(oa, payload)); + // new = prim_op_helper(a, group_operations[i].scalar, shd_empty(a), mk_nodes(a, new, group_operations[i].I(a, new->type) )); + new = prim_op_helper(a, group_operations[i].scalar, shd_empty(a), mk_nodes(a, new, shd_recreate_node(r, payload.operands.nodes[2]) )); + return new; + } + } + } + } + } + default: break; + } + + return shd_recreate_node(r, node); +} + +/// Transforms +/// SpvOpGroupXXX(Scope, 'GroupOperationInclusiveScan', v) +/// into +/// SpvOpGroupXXX(Scope, 'GroupOperationExclusiveScan', v) op v +Module* shd_pass_lower_inclusive_scan(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = 
*shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/shady/passes/lower_int64.c b/src/shady/passes/lower_int64.c index 025703dbd..754d5e248 100644 --- a/src/shady/passes/lower_int64.c +++ b/src/shady/passes/lower_int64.c @@ -1,57 +1,51 @@ -#include "passes.h" +#include "shady/pass.h" + +#include "../ir_private.h" #include "log.h" #include "portability.h" -#include "../ir_private.h" -#include "../type.h" -#include "../rewrite.h" -#include "../transform/ir_gen_helpers.h" - typedef struct { Rewriter rewriter; const CompilerConfig* config; } Context; static bool should_convert(Context* ctx, const Type* t) { - t = get_unqualified_type(t); + t = shd_get_unqualified_type(t); return t->tag == Int_TAG && t->payload.int_type.width == IntTy64 && ctx->config->lower.int64; } -static void extract_low_hi_halves(BodyBuilder* bb, const Node* src, const Node** lo, const Node** hi) { - *lo = first(bind_instruction(bb, prim_op(bb->arena, - (PrimOp) { .op = extract_op, .operands = mk_nodes(bb->arena, src, int32_literal(bb->arena, 0))}))); - *hi = first(bind_instruction(bb, prim_op(bb->arena, - (PrimOp) { .op = extract_op, .operands = mk_nodes(bb->arena, src, int32_literal(bb->arena, 1))}))); +static void extract_low_hi_halves(IrArena* a, BodyBuilder* bb, const Node* src, const Node** lo, const Node** hi) { + *lo = shd_first(shd_bld_add_instruction_extract(bb, prim_op(a, + (PrimOp) { .op = extract_op, .operands = mk_nodes(a, src, shd_int32_literal(a, 0)) }))); + *hi = shd_first(shd_bld_add_instruction_extract(bb, prim_op(a, + (PrimOp) { .op = extract_op, .operands = mk_nodes(a, src, shd_int32_literal(a, 1)) }))); } -static void extract_low_hi_halves_list(BodyBuilder* bb, Nodes src, 
const Node** lows, const Node** his) { +static void extract_low_hi_halves_list(IrArena* a, BodyBuilder* bb, Nodes src, const Node** lows, const Node** his) { for (size_t i = 0; i < src.count; i++) { - extract_low_hi_halves(bb, src.nodes[i], lows, his); + extract_low_hi_halves(a, bb, src.nodes[i], lows, his); lows++; his++; } } static const Node* process(Context* ctx, const Node* node) { - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - IrArena* a = ctx->rewriter.dst_arena; switch (node->tag) { case Int_TAG: if (node->payload.int_type.width == IntTy64 && ctx->config->lower.int64) return record_type(a, (RecordType) { - .members = mk_nodes(a, int32_type(a), int32_type(a)) + .members = mk_nodes(a, shd_int32_type(a), shd_int32_type(a)) }); break; case IntLiteral_TAG: if (node->payload.int_literal.width == IntTy64 && ctx->config->lower.int64) { uint64_t raw = node->payload.int_literal.value; - const Node* lower = uint32_literal(a, (uint32_t) raw); - const Node* upper = uint32_literal(a, (uint32_t) (raw >> 32)); - return tuple_helper(a, mk_nodes(a, lower, upper)); + const Node* lower = shd_uint32_literal(a, (uint32_t) raw); + const Node* upper = shd_uint32_literal(a, (uint32_t) (raw >> 32)); + return shd_tuple_helper(a, mk_nodes(a, lower, upper)); } break; case PrimOp_TAG: { @@ -60,17 +54,17 @@ static const Node* process(Context* ctx, const Node* node) { LARRAY(const Node*, lows, old_nodes.count); LARRAY(const Node*, his, old_nodes.count); switch(op) { - case add_op: if (should_convert(ctx, first(old_nodes)->type)) { - Nodes new_nodes = rewrite_nodes(&ctx->rewriter, old_nodes); + case add_op: if (should_convert(ctx, shd_first(old_nodes)->type)) { + Nodes new_nodes = shd_rewrite_nodes(&ctx->rewriter, old_nodes); // TODO: convert into and then out of unsigned - BodyBuilder* bb = begin_body(a); - extract_low_hi_halves_list(bb, new_nodes, lows, his); - Nodes low_and_carry = bind_instruction(bb, prim_op(a, (PrimOp) { .op = add_carry_op, 
.operands = nodes(a, 2, lows)})); - const Node* lo = first(low_and_carry); + BodyBuilder* bb = shd_bld_begin_pure(a); + extract_low_hi_halves_list(a, bb, new_nodes, lows, his); + Nodes low_and_carry = shd_bld_add_instruction_extract(bb, prim_op(a, (PrimOp) { .op = add_carry_op, .operands = shd_nodes(a, 2, lows) })); + const Node* lo = shd_first(low_and_carry); // compute the high side, without forgetting the carry bit - const Node* hi = first(bind_instruction(bb, prim_op(a, (PrimOp) { .op = add_op, .operands = nodes(a, 2, his)}))); - hi = first(bind_instruction(bb, prim_op(a, (PrimOp) { .op = add_op, .operands = mk_nodes(a, hi, low_and_carry.nodes[1])}))); - return yield_values_and_wrap_in_block(bb, singleton(tuple_helper(a, mk_nodes(a, lo, hi)))); + const Node* hi = shd_first(shd_bld_add_instruction_extract(bb, prim_op(a, (PrimOp) { .op = add_op, .operands = shd_nodes(a, 2, his) }))); + hi = shd_first(shd_bld_add_instruction_extract(bb, prim_op(a, (PrimOp) { .op = add_op, .operands = mk_nodes(a, hi, low_and_carry.nodes[1]) }))); + return shd_bld_to_instr_yield_values(bb, shd_singleton(shd_tuple_helper(a, mk_nodes(a, lo, hi)))); } break; default: break; } @@ -80,18 +74,18 @@ static const Node* process(Context* ctx, const Node* node) { } rebuild: - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_int(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_int(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) 
process), .config = config, }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_lea.c b/src/shady/passes/lower_lea.c index 49a3ad562..fc154f423 100644 --- a/src/shady/passes/lower_lea.c +++ b/src/shady/passes/lower_lea.c @@ -1,9 +1,7 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/cast.h" -#include "../rewrite.h" -#include "../type.h" #include "../ir_private.h" -#include "../transform/ir_gen_helpers.h" #include "log.h" #include "portability.h" @@ -12,129 +10,168 @@ typedef struct { Rewriter rewriter; + const CompilerConfig* config; } Context; -static const Node* lower_ptr_arithm(Context* ctx, BodyBuilder* bb, const Type* pointer_type, const Node* base, const Node* offset, size_t n_indices, const Node** indices) { +// TODO: make this configuration-dependant +static bool is_as_emulated(SHADY_UNUSED Context* ctx, AddressSpace as) { + switch (as) { + case AsPrivate: return true; // TODO have a config option to do this with swizzled global memory + case AsSubgroup: return true; + case AsShared: return true; + case AsGlobal: return true; // TODO have a config option to do this with SSBOs + default: return false; + } +} + +static const Node* lower_ptr_index(Context* ctx, BodyBuilder* bb, const Type* pointer_type, const Node* base, const Node* index) { IrArena* a = ctx->rewriter.dst_arena; const Type* emulated_ptr_t = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false }); assert(pointer_type->tag == PtrType_TAG); const Node* ptr = base; - const IntLiteral* offset_value = resolve_to_int_literal(offset); - bool offset_is_zero = offset_value && offset_value->value == 0; - if (!offset_is_zero) { - const Type* arr_type = pointer_type->payload.ptr_type.pointed_type; - assert(arr_type->tag == ArrType_TAG); - const Type* element_type = arr_type->payload.arr_type.element_type; + 
assert(pointer_type->tag == PtrType_TAG); + const Type* pointed_type = pointer_type->payload.ptr_type.pointed_type; + switch (pointed_type->tag) { + case ArrType_TAG: { + const Type* element_type = pointed_type->payload.arr_type.element_type; + + const Node* element_t_size = prim_op_helper(a, size_of_op, shd_singleton(element_type), shd_empty(a)); - const Node* element_t_size = gen_primop_e(bb, size_of_op, singleton(element_type), empty(a)); + const Node* new_index = shd_bld_convert_int_extend_according_to_src_t(bb, emulated_ptr_t, index); + const Node* physical_offset = prim_op_helper(a, mul_op, shd_empty(a), mk_nodes(a, new_index, element_t_size)); - const Node* new_offset = convert_int_extend_according_to_src_t(bb, emulated_ptr_t, offset); - const Node* physical_offset = gen_primop_ce(bb, mul_op, 2, (const Node* []) { new_offset, element_t_size}); + ptr = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, ptr, physical_offset)); - ptr = gen_primop_ce(bb, add_op, 2, (const Node* []) { ptr, physical_offset}); + pointer_type = ptr_type(a, (PtrType) { + .pointed_type = element_type, + .address_space = pointer_type->payload.ptr_type.address_space + }); + break; + } + case TypeDeclRef_TAG: { + const Node* nom_decl = pointed_type->payload.type_decl_ref.decl; + assert(nom_decl && nom_decl->tag == NominalType_TAG); + pointed_type = nom_decl->payload.nom_type.body; + SHADY_FALLTHROUGH + } + case RecordType_TAG: { + Nodes member_types = pointed_type->payload.record_type.members; + + const IntLiteral* selector_value = shd_resolve_to_int_literal(index); + assert(selector_value && "selector value must be known for LEA into a record"); + size_t n = selector_value->value; + assert(n < member_types.count); + + const Node* offset_of = prim_op_helper(a, offset_of_op, shd_singleton(pointed_type), shd_singleton(shd_uint64_literal(a, n))); + ptr = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, ptr, offset_of)); + + pointer_type = ptr_type(a, (PtrType) { + .pointed_type = 
member_types.nodes[n], + .address_space = pointer_type->payload.ptr_type.address_space + }); + break; + } + default: shd_error("cannot index into this") } - for (size_t i = 0; i < n_indices; i++) { - assert(pointer_type->tag == PtrType_TAG); - const Type* pointed_type = pointer_type->payload.ptr_type.pointed_type; - switch (pointed_type->tag) { - case ArrType_TAG: { - const Type* element_type = pointed_type->payload.arr_type.element_type; + return ptr; +} + +static const Node* lower_ptr_offset(Context* ctx, BodyBuilder* bb, const Type* pointer_type, const Node* base, const Node* offset) { + IrArena* a = ctx->rewriter.dst_arena; + const Type* emulated_ptr_t = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false }); + assert(pointer_type->tag == PtrType_TAG); - const Node* element_t_size = gen_primop_e(bb, size_of_op, singleton(element_type), empty(a)); + const Node* ptr = base; - const Node* new_index = convert_int_extend_according_to_src_t(bb, emulated_ptr_t, indices[i]); - const Node* physical_offset = gen_primop_ce(bb, mul_op, 2, (const Node* []) {new_index, element_t_size}); + const IntLiteral* offset_value = shd_resolve_to_int_literal(offset); + bool offset_is_zero = offset_value && offset_value->value == 0; + if (!offset_is_zero) { + const Type* element_type = pointer_type->payload.ptr_type.pointed_type; + // assert(arr_type->tag == ArrType_TAG); + // const Type* element_type = arr_type->payload.arr_type.element_type; - ptr = gen_primop_ce(bb, add_op, 2, (const Node* []) { ptr, physical_offset }); + const Node* element_t_size = prim_op_helper(a, size_of_op, shd_singleton(element_type), shd_empty(a)); - pointer_type = ptr_type(a, (PtrType) { - .pointed_type = element_type, - .address_space = pointer_type->payload.ptr_type.address_space - }); - break; - } - case TypeDeclRef_TAG: { - const Node* nom_decl = pointed_type->payload.type_decl_ref.decl; - assert(nom_decl && nom_decl->tag == NominalType_TAG); - pointed_type = 
nom_decl->payload.nom_type.body; - SHADY_FALLTHROUGH - } - case RecordType_TAG: { - Nodes member_types = pointed_type->payload.record_type.members; - - const IntLiteral* selector_value = resolve_to_int_literal(indices[i]); - assert(selector_value && "selector value must be known for LEA into a record"); - size_t n = selector_value->value; - assert(n < member_types.count); - - const Node* offset_of = gen_primop_e(bb, offset_of_op, singleton(pointed_type), singleton(uint64_literal(a, n))); - ptr = gen_primop_ce(bb, add_op, 2, (const Node* []) { ptr, offset_of }); - - pointer_type = ptr_type(a, (PtrType) { - .pointed_type = member_types.nodes[n], - .address_space = pointer_type->payload.ptr_type.address_space - }); - break; - } - default: error("cannot index into this") - } + const Node* new_offset = shd_bld_convert_int_extend_according_to_src_t(bb, emulated_ptr_t, offset); + const Node* physical_offset = prim_op_helper(a, mul_op, shd_empty(a), mk_nodes(a, new_offset, element_t_size)); + + ptr = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, ptr, physical_offset)); } return ptr; } static const Node* process(Context* ctx, const Node* old) { - const Node* found = search_processed(&ctx->rewriter, old); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; const Type* emulated_ptr_t = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false }); switch (old->tag) { - case PrimOp_TAG: { - switch (old->payload.prim_op.op) { - case lea_op: { - Nodes old_ops = old->payload.prim_op.operands; - const Node* old_base = first(old_ops); - const Type* old_base_ptr_t = old_base->type; - deconstruct_qualified_type(&old_base_ptr_t); - assert(old_base_ptr_t->tag == PtrType_TAG); - const Node* old_result_t = old->type; - deconstruct_qualified_type(&old_result_t); - // Leave logical ptrs alone - if (!is_physical_as(old_base_ptr_t->payload.ptr_type.address_space)) - break; - BodyBuilder* bb = 
begin_body(a); - Nodes new_ops = rewrite_nodes(&ctx->rewriter, old_ops); - const Node* cast_base = gen_reinterpret_cast(bb, emulated_ptr_t, first(new_ops)); - const Type* new_base_t = rewrite_node(&ctx->rewriter, old_base_ptr_t); - const Node* result = lower_ptr_arithm(ctx, bb, new_base_t, cast_base, new_ops.nodes[1], new_ops.count - 2, &new_ops.nodes[2]); - const Type* new_ptr_t = rewrite_node(&ctx->rewriter, old_result_t); - const Node* cast_result = gen_reinterpret_cast(bb, new_ptr_t, result); - return yield_values_and_wrap_in_block(bb, singleton(cast_result)); - } - default: break; - } - break; + case PtrArrayElementOffset_TAG: { + PtrArrayElementOffset lea = old->payload.ptr_array_element_offset; + const Node* old_base = lea.ptr; + const Type* old_base_ptr_t = old_base->type; + shd_deconstruct_qualified_type(&old_base_ptr_t); + assert(old_base_ptr_t->tag == PtrType_TAG); + const Node* old_result_t = old->type; + shd_deconstruct_qualified_type(&old_result_t); + bool must_lower = false; + // we have to lower generic pointers if we emulate them using ints + must_lower |= ctx->config->lower.emulate_generic_ptrs && old_base_ptr_t->payload.ptr_type.address_space == AsGeneric; + must_lower |= ctx->config->lower.emulate_physical_memory && !old_base_ptr_t->payload.ptr_type.is_reference && is_as_emulated(ctx, old_base_ptr_t->payload.ptr_type.address_space); + if (!must_lower) + break; + BodyBuilder* bb = shd_bld_begin_pure(a); + // Nodes new_ops = rewrite_nodes(&ctx->rewriter, old_ops); + const Node* cast_base = shd_bld_reinterpret_cast(bb, emulated_ptr_t, shd_rewrite_node(r, lea.ptr)); + const Type* new_base_t = shd_rewrite_node(&ctx->rewriter, old_base_ptr_t); + const Node* result = lower_ptr_offset(ctx, bb, new_base_t, cast_base, shd_rewrite_node(r, lea.offset)); + const Type* new_ptr_t = shd_rewrite_node(&ctx->rewriter, old_result_t); + const Node* cast_result = shd_bld_reinterpret_cast(bb, new_ptr_t, result); + return shd_bld_to_instr_yield_values(bb, 
shd_singleton(cast_result)); + } + case PtrCompositeElement_TAG: { + PtrCompositeElement lea = old->payload.ptr_composite_element; + const Node* old_base = lea.ptr; + const Type* old_base_ptr_t = old_base->type; + shd_deconstruct_qualified_type(&old_base_ptr_t); + assert(old_base_ptr_t->tag == PtrType_TAG); + const Node* old_result_t = old->type; + shd_deconstruct_qualified_type(&old_result_t); + bool must_lower = false; + // we have to lower generic pointers if we emulate them using ints + must_lower |= ctx->config->lower.emulate_generic_ptrs && old_base_ptr_t->payload.ptr_type.address_space == AsGeneric; + must_lower |= ctx->config->lower.emulate_physical_memory && !old_base_ptr_t->payload.ptr_type.is_reference && is_as_emulated(ctx, old_base_ptr_t->payload.ptr_type.address_space); + if (!must_lower) + break; + BodyBuilder* bb = shd_bld_begin_pure(a); + // Nodes new_ops = rewrite_nodes(&ctx->rewriter, old_ops); + const Node* cast_base = shd_bld_reinterpret_cast(bb, emulated_ptr_t, shd_rewrite_node(r, lea.ptr)); + const Type* new_base_t = shd_rewrite_node(&ctx->rewriter, old_base_ptr_t); + const Node* result = lower_ptr_index(ctx, bb, new_base_t, cast_base, shd_rewrite_node(r, lea.index)); + const Type* new_ptr_t = shd_rewrite_node(&ctx->rewriter, old_result_t); + const Node* cast_result = shd_bld_reinterpret_cast(bb, new_ptr_t, result); + return shd_bld_to_instr_yield_values(bb, shd_singleton(cast_result)); } default: break; } - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); } -Module* lower_lea(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_lea(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = 
shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process) + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .config = config, }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_logical_pointers.c b/src/shady/passes/lower_logical_pointers.c new file mode 100644 index 000000000..b25bd6ae0 --- /dev/null +++ b/src/shady/passes/lower_logical_pointers.c @@ -0,0 +1,153 @@ +#include "shady/pass.h" +#include "shady/ir/memory_layout.h" + +#include "../ir_private.h" + +#include "log.h" +#include "portability.h" +#include "shady/ir.h" + +#include + +typedef struct { + Rewriter rewriter; + const CompilerConfig* config; +} Context; + +static const Node* guess_pointer_casts(Context* ctx, BodyBuilder* bb, const Node* ptr, const Type* expected_type) { + IrArena* a = ctx->rewriter.dst_arena; + while (true) { + const Type* actual_type = shd_get_unqualified_type(ptr->type); + assert(actual_type->tag == PtrType_TAG); + actual_type = shd_get_pointer_type_element(actual_type); + if (expected_type == actual_type) + break; + + actual_type = shd_get_maybe_nominal_type_body(actual_type); + assert(expected_type != actual_type && "todo: rework this function if we change how nominal types are handled"); + + switch (actual_type->tag) { + case RecordType_TAG: + case ArrType_TAG: + case PackType_TAG: { + ptr = lea_helper(a, ptr, shd_int32_literal(a, 0), shd_singleton(shd_int32_literal(a, 0))); + continue; + } + default: break; + } + shd_error("Cannot fix pointer") + } + return ptr; +} + +static const Node* process(Context* ctx, const Node* old) { + IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + + switch (old->tag) { + case PtrType_TAG: { + PtrType payload = old->payload.ptr_type; + if 
(!shd_get_arena_config(a)->address_spaces[payload.address_space].physical) + payload.is_reference = true; + payload.pointed_type = shd_rewrite_node(r, payload.pointed_type); + return ptr_type(a, payload); + } + /*case PtrArrayElementOffset_TAG: { + Lea payload = old->payload.lea; + const Type* optr_t = payload.ptr->type; + deconstruct_qualified_type(&optr_t); + assert(optr_t->tag == PtrType_TAG); + const Type* expected_type = rewrite_node(r, optr_t); + const Node* ptr = rewrite_node(r, payload.ptr); + const Type* actual_type = get_unqualified_type(ptr->type); + BodyBuilder* bb = begin_block_pure(a); + if (expected_type != actual_type) + ptr = guess_pointer_casts(ctx, bb, ptr, get_pointer_type_element(expected_type)); + return bind_last_instruction_and_wrap_in_block(bb, lea(a, (Lea) { .ptr = ptr, .offset = rewrite_node(r, payload.offset), .indices = rewrite_nodes(r, payload.indices)})); + }*/ + // TODO: we actually want to match stuff that has a ptr as an input operand. + case PtrCompositeElement_TAG: { + PtrCompositeElement payload = old->payload.ptr_composite_element; + const Type* optr_t = payload.ptr->type; + shd_deconstruct_qualified_type(&optr_t); + assert(optr_t->tag == PtrType_TAG); + const Type* expected_type = shd_rewrite_node(r, optr_t); + const Node* ptr = shd_rewrite_node(r, payload.ptr); + const Type* actual_type = shd_get_unqualified_type(ptr->type); + BodyBuilder* bb = shd_bld_begin_pure(a); + if (expected_type != actual_type) + ptr = guess_pointer_casts(ctx, bb, ptr, shd_get_pointer_type_element(expected_type)); + return shd_bld_to_instr_with_last_instr(bb, ptr_composite_element(a, (PtrCompositeElement) { .ptr = ptr, .index = shd_rewrite_node(r, payload.index) })); + } + case PrimOp_TAG: { + PrimOp payload = old->payload.prim_op; + switch (payload.op) { + case reinterpret_op: { + const Node* osrc = shd_first(payload.operands); + const Type* osrc_t = osrc->type; + shd_deconstruct_qualified_type(&osrc_t); + if (osrc_t->tag == PtrType_TAG && 
!shd_get_arena_config(a)->address_spaces[osrc_t->payload.ptr_type.address_space].physical) + return shd_rewrite_node(r, osrc); + break; + } + default: break; + } + break; + } + case Load_TAG: { + Load payload = old->payload.load; + const Type* optr_t = payload.ptr->type; + shd_deconstruct_qualified_type(&optr_t); + assert(optr_t->tag == PtrType_TAG); + const Type* expected_type = shd_rewrite_node(r, optr_t); + const Node* ptr = shd_rewrite_node(r, payload.ptr); + const Type* actual_type = shd_get_unqualified_type(ptr->type); + BodyBuilder* bb = shd_bld_begin_pure(a); + if (expected_type != actual_type) + ptr = guess_pointer_casts(ctx, bb, ptr, shd_get_pointer_type_element(expected_type)); + return load(a, (Load) { .ptr = shd_bld_to_instr_yield_value(bb, ptr), .mem = shd_rewrite_node(r, payload.mem) }); + } + case Store_TAG: { + Store payload = old->payload.store; + const Type* optr_t = payload.ptr->type; + shd_deconstruct_qualified_type(&optr_t); + assert(optr_t->tag == PtrType_TAG); + const Type* expected_type = shd_rewrite_node(r, optr_t); + const Node* ptr = shd_rewrite_node(r, payload.ptr); + const Type* actual_type = shd_get_unqualified_type(ptr->type); + BodyBuilder* bb = shd_bld_begin_pure(a); + if (expected_type != actual_type) + ptr = guess_pointer_casts(ctx, bb, ptr, shd_get_pointer_type_element(expected_type)); + return shd_bld_to_instr_with_last_instr(bb, store(a, (Store) { .ptr = ptr, .value = shd_rewrite_node(r, payload.value), .mem = shd_rewrite_node(r, payload.mem) })); + } + case GlobalVariable_TAG: { + AddressSpace as = old->payload.global_variable.address_space; + if (shd_get_arena_config(a)->address_spaces[as].physical) + break; + Nodes annotations = shd_rewrite_nodes(r, old->payload.global_variable.annotations); + annotations = shd_nodes_append(a, annotations, annotation(a, (Annotation) { .name = "Logical" })); + Node* new = global_var(ctx->rewriter.dst_module, annotations, shd_rewrite_node(r, old->payload.global_variable.type), 
old->payload.global_variable.name, as); + shd_recreate_node_body(r, old, new); + return new; + } + default: break; + } + + return shd_recreate_node(&ctx->rewriter, old); +} + +Module* shd_pass_lower_logical_pointers(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + aconfig.address_spaces[AsInput].physical = false; + aconfig.address_spaces[AsOutput].physical = false; + aconfig.address_spaces[AsUniformConstant].physical = false; + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .config = config, + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/shady/passes/lower_mask.c b/src/shady/passes/lower_mask.c index b89e0b206..54f95d7a5 100644 --- a/src/shady/passes/lower_mask.c +++ b/src/shady/passes/lower_mask.c @@ -1,13 +1,10 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/type.h" +#include "shady/ir/cast.h" #include "log.h" #include "portability.h" -#include "../ir_private.h" -#include "../type.h" -#include "../rewrite.h" -#include "../transform/ir_gen_helpers.h" - typedef struct { Rewriter rewriter; const Node* zero; @@ -15,34 +12,29 @@ typedef struct { } Context; static const Node* process(Context* ctx, const Node* node) { - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - IrArena* a = ctx->rewriter.dst_arena; switch (node->tag) { - case MaskType_TAG: return get_actual_mask_type(ctx->rewriter.dst_arena); + case MaskType_TAG: return shd_get_actual_mask_type(ctx->rewriter.dst_arena); case PrimOp_TAG: { Op op = node->payload.prim_op.op; Nodes old_nodes = node->payload.prim_op.operands; switch(op) { - case empty_mask_op: return quote_helper(a, singleton(ctx->zero)); - case subgroup_active_mask_op: // this is just ballot(true) - return 
prim_op(a, (PrimOp) { .op = subgroup_ballot_op, .type_arguments = empty(a), .operands = singleton(true_lit(ctx->rewriter.dst_arena)) }); + case empty_mask_op: return ctx->zero; // extract the relevant bit case mask_is_thread_active_op: { - BodyBuilder* bb = begin_body(a); - const Node* mask = rewrite_node(&ctx->rewriter, old_nodes.nodes[0]); - const Node* index = rewrite_node(&ctx->rewriter, old_nodes.nodes[1]); - index = gen_conversion(bb, get_actual_mask_type(ctx->rewriter.dst_arena), index); + BodyBuilder* bb = shd_bld_begin_pure(a); + const Node* mask = shd_rewrite_node(&ctx->rewriter, old_nodes.nodes[0]); + const Node* index = shd_rewrite_node(&ctx->rewriter, old_nodes.nodes[1]); + index = shd_bld_conversion(bb, shd_get_actual_mask_type(ctx->rewriter.dst_arena), index); const Node* acc = mask; // acc >>= index - acc = gen_primop_ce(bb, rshift_logical_op, 2, (const Node* []) { acc, index }); + acc = prim_op_helper(a, rshift_logical_op, shd_empty(a), mk_nodes(a, acc, index)); // acc &= 0x1 - acc = gen_primop_ce(bb, and_op, 2, (const Node* []) { acc, ctx->one }); + acc = prim_op_helper(a, and_op, shd_empty(a), mk_nodes(a, acc, ctx->one)); // acc == 1 - acc = gen_primop_ce(bb, eq_op, 2, (const Node* []) { acc, ctx->one }); - return yield_values_and_wrap_in_block(bb, singleton(acc)); + acc = prim_op_helper(a, eq_op, shd_empty(a), mk_nodes(a, acc, ctx->one)); + return shd_bld_to_instr_yield_value(bb, acc); } default: break; } @@ -51,24 +43,24 @@ static const Node* process(Context* ctx, const Node* node) { default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_mask(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); +Module* shd_pass_lower_mask(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); 
aconfig.specializations.subgroup_mask_representation = SubgroupMaskInt64; - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); - const Type* mask_type = get_actual_mask_type(a); + const Type* mask_type = shd_get_actual_mask_type(a); assert(mask_type->tag == Int_TAG); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .zero = int_literal(a, (IntLiteral) { .width = mask_type->payload.int_type.width, .value = 0 }), .one = int_literal(a, (IntLiteral) { .width = mask_type->payload.int_type.width, .value = 1 }), }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_memcpy.c b/src/shady/passes/lower_memcpy.c index d076c2a3f..12901a000 100644 --- a/src/shady/passes/lower_memcpy.c +++ b/src/shady/passes/lower_memcpy.c @@ -1,9 +1,7 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/cast.h" +#include "shady/ir/memory_layout.h" -#include "../transform/ir_gen_helpers.h" -#include "../transform/memory_layout.h" -#include "../rewrite.h" -#include "../type.h" #include "../ir_private.h" #include "log.h" @@ -17,123 +15,124 @@ typedef struct { } Context; static const Node* process(Context* ctx, const Node* old) { - const Node* found = search_processed(&ctx->rewriter, old); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; - Module* m = ctx->rewriter.dst_module; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + Module* m = r->dst_module; switch (old->tag) { - case PrimOp_TAG: { - switch (old->payload.prim_op.op) { - case memcpy_op: { - const Type* word_type = int_type(a, (Int) { .is_signed = false, .width = a->config.memory.word_size }); - - 
BodyBuilder* bb = begin_body(a); - Nodes old_ops = old->payload.prim_op.operands; - - const Node* dst_addr = rewrite_node(&ctx->rewriter, old_ops.nodes[0]); - const Type* dst_addr_type = dst_addr->type; - deconstruct_qualified_type(&dst_addr_type); - assert(dst_addr_type->tag == PtrType_TAG); - dst_addr_type = ptr_type(a, (PtrType) { - .address_space = dst_addr_type->payload.ptr_type.address_space, - .pointed_type = arr_type(a, (ArrType) { .element_type = word_type, .size = NULL }), - }); - dst_addr = gen_reinterpret_cast(bb, dst_addr_type, dst_addr); - - const Node* src_addr = rewrite_node(&ctx->rewriter, old_ops.nodes[1]); - const Type* src_addr_type = src_addr->type; - deconstruct_qualified_type(&src_addr_type); - assert(src_addr_type->tag == PtrType_TAG); - src_addr_type = ptr_type(a, (PtrType) { - .address_space = src_addr_type->payload.ptr_type.address_space, - .pointed_type = arr_type(a, (ArrType) { .element_type = word_type, .size = NULL }), - }); - src_addr = gen_reinterpret_cast(bb, src_addr_type, src_addr); - - const Node* num = rewrite_node(&ctx->rewriter, old_ops.nodes[2]); - const Node* num_in_bytes = gen_conversion(bb, uint32_type(a), bytes_to_words(bb, num)); - - const Node* index = var(a, qualified_type_helper(uint32_type(a), false), "memcpy_i"); - BodyBuilder* loop_bb = begin_body(a); - const Node* loaded_word = gen_load(loop_bb, gen_lea(loop_bb, src_addr, index, singleton(uint32_literal(a, 0)))); - gen_store(loop_bb, gen_lea(loop_bb, dst_addr, index, singleton(uint32_literal(a, 0))), loaded_word); - const Node* next_index = gen_primop_e(loop_bb, add_op, empty(a), mk_nodes(a, index, uint32_literal(a, 1))); - bind_instruction(loop_bb, if_instr(a, (If) { - .condition = gen_primop_e(loop_bb, lt_op, empty(a), mk_nodes(a, next_index, num_in_bytes)), - .yield_types = empty(a), - .if_true = case_(a, empty(a), merge_continue(a, (MergeContinue) {.args = singleton(next_index)})), - .if_false = case_(a, empty(a), merge_break(a, (MergeBreak) {.args = 
empty(a)})) - })); - - bind_instruction(bb, loop_instr(a, (Loop) { - .yield_types = empty(a), - .body = case_(a, singleton(index), finish_body(loop_bb, unreachable(a))), - .initial_args = singleton(uint32_literal(a, 0)) - })); - return yield_values_and_wrap_in_block(bb, empty(a)); - } - case memset_op: { - Nodes old_ops = old->payload.prim_op.operands; - const Node* src_value = rewrite_node(&ctx->rewriter, old_ops.nodes[1]); - const Type* src_type = src_value->type; - deconstruct_qualified_type(&src_type); - assert(src_type->tag == Int_TAG); - const Type* word_type = src_type;// int_type(a, (Int) { .is_signed = false, .width = a->config.memory.word_size }); - - BodyBuilder* bb = begin_body(a); - - const Node* dst_addr = rewrite_node(&ctx->rewriter, old_ops.nodes[0]); - const Type* dst_addr_type = dst_addr->type; - deconstruct_qualified_type(&dst_addr_type); - assert(dst_addr_type->tag == PtrType_TAG); - dst_addr_type = ptr_type(a, (PtrType) { - .address_space = dst_addr_type->payload.ptr_type.address_space, - .pointed_type = arr_type(a, (ArrType) { .element_type = word_type, .size = NULL }), - }); - dst_addr = gen_reinterpret_cast(bb, dst_addr_type, dst_addr); - - const Node* num = rewrite_node(&ctx->rewriter, old_ops.nodes[2]); - const Node* num_in_bytes = gen_conversion(bb, uint32_type(a), bytes_to_words(bb, num)); - - const Node* index = var(a, qualified_type_helper(uint32_type(a), false), "memset_i"); - BodyBuilder* loop_bb = begin_body(a); - gen_store(loop_bb, gen_lea(loop_bb, dst_addr, index, singleton(uint32_literal(a, 0))), src_value); - const Node* next_index = gen_primop_e(loop_bb, add_op, empty(a), mk_nodes(a, index, uint32_literal(a, 1))); - bind_instruction(loop_bb, if_instr(a, (If) { - .condition = gen_primop_e(loop_bb, lt_op, empty(a), mk_nodes(a, next_index, num_in_bytes)), - .yield_types = empty(a), - .if_true = case_(a, empty(a), merge_continue(a, (MergeContinue) {.args = singleton(next_index)})), - .if_false = case_(a, empty(a), merge_break(a, 
(MergeBreak) {.args = empty(a)})) - })); - - bind_instruction(bb, loop_instr(a, (Loop) { - .yield_types = empty(a), - .body = case_(a, singleton(index), finish_body(loop_bb, unreachable(a))), - .initial_args = singleton(uint32_literal(a, 0)) - })); - return yield_values_and_wrap_in_block(bb, empty(a)); - } - default: break; - } - break; + case CopyBytes_TAG: { + CopyBytes payload = old->payload.copy_bytes; + const Type* word_type = int_type(a, (Int) { .is_signed = false, .width = a->config.memory.word_size }); + + BodyBuilder* bb = shd_bld_begin_pseudo_instr(a, shd_rewrite_node(r, payload.mem)); + + const Node* dst_addr = shd_rewrite_node(&ctx->rewriter, payload.dst); + const Type* dst_addr_type = dst_addr->type; + shd_deconstruct_qualified_type(&dst_addr_type); + assert(dst_addr_type->tag == PtrType_TAG); + dst_addr_type = ptr_type(a, (PtrType) { + .address_space = dst_addr_type->payload.ptr_type.address_space, + .pointed_type = word_type, + }); + dst_addr = shd_bld_reinterpret_cast(bb, dst_addr_type, dst_addr); + + const Node* src_addr = shd_rewrite_node(&ctx->rewriter, payload.src); + const Type* src_addr_type = src_addr->type; + shd_deconstruct_qualified_type(&src_addr_type); + assert(src_addr_type->tag == PtrType_TAG); + src_addr_type = ptr_type(a, (PtrType) { + .address_space = src_addr_type->payload.ptr_type.address_space, + .pointed_type = word_type, + }); + src_addr = shd_bld_reinterpret_cast(bb, src_addr_type, src_addr); + + const Node* num_in_bytes = shd_bld_convert_int_extend_according_to_dst_t(bb, size_t_type(a), shd_rewrite_node(&ctx->rewriter, payload.count)); + const Node* num_in_words = shd_bld_conversion(bb, shd_uint32_type(a), shd_bytes_to_words(bb, num_in_bytes)); + + begin_loop_helper_t l = shd_bld_begin_loop_helper(bb, shd_empty(a), shd_singleton(shd_uint32_type(a)), shd_singleton(shd_uint32_literal(a, 0))); + + const Node* index = shd_first(l.params); + shd_set_value_name(index, "memcpy_i"); + Node* loop_case = l.loop_body; + BodyBuilder* 
loop_bb = shd_bld_begin(a, shd_get_abstraction_mem(loop_case)); + const Node* loaded_word = shd_bld_load(loop_bb, lea_helper(a, src_addr, index, shd_empty(a))); + shd_bld_store(loop_bb, lea_helper(a, dst_addr, index, shd_empty(a)), loaded_word); + const Node* next_index = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, index, shd_uint32_literal(a, 1))); + + Node* true_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(true_case, join(a, (Join) { .join_point = l.continue_jp, .mem = shd_get_abstraction_mem(true_case), .args = shd_singleton(next_index) })); + Node* false_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(false_case, join(a, (Join) { .join_point = l.break_jp, .mem = shd_get_abstraction_mem(false_case), .args = shd_empty(a) })); + + shd_set_abstraction_body(loop_case, shd_bld_finish(loop_bb, branch(a, (Branch) { + .mem = shd_bb_mem(loop_bb), + .condition = prim_op_helper(a, lt_op, shd_empty(a), mk_nodes(a, next_index, num_in_words)), + .true_jump = jump_helper(a, shd_bb_mem(loop_bb), true_case, shd_empty(a)), + .false_jump = jump_helper(a, shd_bb_mem(loop_bb), false_case, shd_empty(a)), + }))); + + return shd_bld_to_instr_yield_values(bb, shd_empty(a)); + } + case FillBytes_TAG: { + FillBytes payload = old->payload.fill_bytes; + const Node* src_value = shd_rewrite_node(&ctx->rewriter, payload.src); + const Type* src_type = src_value->type; + shd_deconstruct_qualified_type(&src_type); + assert(src_type->tag == Int_TAG); + const Type* word_type = src_type;// int_type(a, (Int) { .is_signed = false, .width = a->config.memory.word_size }); + + BodyBuilder* bb = shd_bld_begin_pseudo_instr(a, shd_rewrite_node(r, payload.mem)); + + const Node* dst_addr = shd_rewrite_node(&ctx->rewriter, payload.dst); + const Type* dst_addr_type = dst_addr->type; + shd_deconstruct_qualified_type(&dst_addr_type); + assert(dst_addr_type->tag == PtrType_TAG); + dst_addr_type = ptr_type(a, (PtrType) { + .address_space = 
dst_addr_type->payload.ptr_type.address_space, + .pointed_type = word_type, + }); + dst_addr = shd_bld_reinterpret_cast(bb, dst_addr_type, dst_addr); + + const Node* num = shd_rewrite_node(&ctx->rewriter, payload.count); + const Node* num_in_words = shd_bld_conversion(bb, shd_uint32_type(a), shd_bytes_to_words(bb, num)); + + begin_loop_helper_t l = shd_bld_begin_loop_helper(bb, shd_empty(a), shd_singleton(shd_uint32_type(a)), shd_singleton(shd_uint32_literal(a, 0))); + + const Node* index = shd_first(l.params); + shd_set_value_name(index, "memset_i"); + Node* loop_case = l.loop_body; + BodyBuilder* loop_bb = shd_bld_begin(a, shd_get_abstraction_mem(loop_case)); + shd_bld_store(loop_bb, lea_helper(a, dst_addr, index, shd_empty(a)), src_value); + const Node* next_index = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, index, shd_uint32_literal(a, 1))); + + Node* true_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(true_case, join(a, (Join) { .join_point = l.continue_jp, .mem = shd_get_abstraction_mem(true_case), .args = shd_singleton(next_index) })); + Node* false_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(false_case, join(a, (Join) { .join_point = l.break_jp, .mem = shd_get_abstraction_mem(false_case), .args = shd_empty(a) })); + + shd_set_abstraction_body(loop_case, shd_bld_finish(loop_bb, branch(a, (Branch) { + .mem = shd_bb_mem(loop_bb), + .condition = prim_op_helper(a, lt_op, shd_empty(a), mk_nodes(a, next_index, num_in_words)), + .true_jump = jump_helper(a, shd_bb_mem(loop_bb), true_case, shd_empty(a)), + .false_jump = jump_helper(a, shd_bb_mem(loop_bb), false_case, shd_empty(a)), + }))); + return shd_bld_to_instr_yield_values(bb, shd_empty(a)); } default: break; } - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); } -Module* lower_memcpy(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = 
new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_memcpy(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process) + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process) }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_memory_layout.c b/src/shady/passes/lower_memory_layout.c index 856ba78a1..b93bd04a4 100644 --- a/src/shady/passes/lower_memory_layout.c +++ b/src/shady/passes/lower_memory_layout.c @@ -1,11 +1,10 @@ -#include "passes.h" - -#include "../transform/memory_layout.h" -#include "../rewrite.h" -#include "../type.h" +#include "shady/pass.h" +#include "shady/ir/memory_layout.h" +#include "shady/ir/type.h" #include "log.h" #include "portability.h" + #include typedef struct { @@ -13,33 +12,30 @@ typedef struct { } Context; static const Node* process(Context* ctx, const Node* old) { - const Node* found = search_processed(&ctx->rewriter, old); - if (found) return found; - IrArena* a = ctx->rewriter.dst_arena; switch (old->tag) { case PrimOp_TAG: { switch (old->payload.prim_op.op) { case size_of_op: { - const Type* t = rewrite_node(&ctx->rewriter, first(old->payload.prim_op.type_arguments)); - TypeMemLayout layout = get_mem_layout(a, t); - return quote_helper(a, singleton(int_literal(a, (IntLiteral) {.width = a->config.memory.ptr_size, .is_signed = false, .value = layout.size_in_bytes}))); + const Type* t = shd_rewrite_node(&ctx->rewriter, shd_first(old->payload.prim_op.type_arguments)); + TypeMemLayout layout = shd_get_mem_layout(a, t); + return int_literal(a, (IntLiteral) {.width = 
shd_get_arena_config(a)->memory.ptr_size, .is_signed = false, .value = layout.size_in_bytes}); } case align_of_op: { - const Type* t = rewrite_node(&ctx->rewriter, first(old->payload.prim_op.type_arguments)); - TypeMemLayout layout = get_mem_layout(a, t); - return quote_helper(a, singleton(int_literal(a, (IntLiteral) {.width = a->config.memory.ptr_size, .is_signed = false, .value = layout.alignment_in_bytes}))); + const Type* t = shd_rewrite_node(&ctx->rewriter, shd_first(old->payload.prim_op.type_arguments)); + TypeMemLayout layout = shd_get_mem_layout(a, t); + return int_literal(a, (IntLiteral) {.width = shd_get_arena_config(a)->memory.ptr_size, .is_signed = false, .value = layout.alignment_in_bytes}); } case offset_of_op: { - const Type* t = rewrite_node(&ctx->rewriter, first(old->payload.prim_op.type_arguments)); - const Node* n = rewrite_node(&ctx->rewriter, first(old->payload.prim_op.operands)); - const IntLiteral* literal = resolve_to_int_literal(n); + const Type* t = shd_rewrite_node(&ctx->rewriter, shd_first(old->payload.prim_op.type_arguments)); + const Node* n = shd_rewrite_node(&ctx->rewriter, shd_first(old->payload.prim_op.operands)); + const IntLiteral* literal = shd_resolve_to_int_literal(n); assert(literal); - t = get_maybe_nominal_type_body(t); - uint64_t offset_in_bytes = (uint64_t) get_record_field_offset_in_bytes(a, t, literal->value); - const Node* offset_literal = int_literal(a, (IntLiteral) { .width = a->config.memory.ptr_size, .is_signed = false, .value = offset_in_bytes }); - return quote_helper(a, singleton(offset_literal)); + t = shd_get_maybe_nominal_type_body(t); + uint64_t offset_in_bytes = (uint64_t) shd_get_record_field_offset_in_bytes(a, t, literal->value); + const Node* offset_literal = int_literal(a, (IntLiteral) { .width = shd_get_arena_config(a)->memory.ptr_size, .is_signed = false, .value = offset_in_bytes }); + return offset_literal; } default: break; } @@ -48,19 +44,18 @@ static const Node* process(Context* ctx, const Node* 
old) { default: break; } - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); } -Module* lower_memory_layout(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_memory_layout(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process) + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process) }; - ctx.rewriter.config.rebind_let = true; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_nullptr.c b/src/shady/passes/lower_nullptr.c new file mode 100644 index 000000000..e0c7e471f --- /dev/null +++ b/src/shady/passes/lower_nullptr.c @@ -0,0 +1,62 @@ +#include "shady/pass.h" +#include "shady/ir/cast.h" + +#include "../ir_private.h" + +#include "log.h" +#include "portability.h" +#include "dict.h" + +typedef struct { + Rewriter rewriter; + struct Dict* map; +} Context; + +static const Node* make_nullptr(Context* ctx, const Type* t) { + IrArena* a = ctx->rewriter.dst_arena; + const Node** found = shd_dict_find_value(const Type*, const Node*, ctx->map, t); + if (found) + return *found; + + BodyBuilder* bb = shd_bld_begin_pure(a); + const Node* nul = shd_bld_reinterpret_cast(bb, t, shd_uint64_literal(a, 0)); + Node* decl = constant(ctx->rewriter.dst_module, shd_singleton(annotation(a, (Annotation) { + .name = "Generated", + })), t, shd_fmt_string_irarena(a, "nullptr_%s", shd_get_type_name(a, t))); + decl->payload.constant.value = 
shd_bld_to_instr_pure_with_values(bb, shd_singleton(nul)); + const Node* ref = ref_decl_helper(a, decl); + shd_dict_insert(const Type*, const Node*, ctx->map, t, ref); + return ref; +} + +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + switch (node->tag) { + case NullPtr_TAG: { + const Type* t = shd_rewrite_node(r, node->payload.null_ptr.ptr_type); + assert(t->tag == PtrType_TAG); + return make_nullptr(ctx, t); + } + default: break; + } + + return shd_recreate_node(&ctx->rewriter, node); +} + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +Module* shd_pass_lower_nullptr(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .map = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_dict(ctx.map); + return dst; +} diff --git a/src/shady/passes/lower_physical_ptrs.c b/src/shady/passes/lower_physical_ptrs.c index 3c68e6254..cb187f2e0 100644 --- a/src/shady/passes/lower_physical_ptrs.c +++ b/src/shady/passes/lower_physical_ptrs.c @@ -1,16 +1,12 @@ -#include "passes.h" - -#include "../transform/ir_gen_helpers.h" -#include "../transform/memory_layout.h" +#include "shady/pass.h" +#include "shady/ir/cast.h" +#include "shady/ir/memory_layout.h" #include "../ir_private.h" -#include "../rewrite.h" -#include "../type.h" #include "log.h" #include "portability.h" #include "util.h" - #include "list.h" #include "dict.h" @@ -39,142 +35,142 @@ static void store_init_data(Context* ctx, AddressSpace as, Nodes collected, Body // TODO: make this configuration-dependant static 
bool is_as_emulated(SHADY_UNUSED Context* ctx, AddressSpace as) { switch (as) { - case AsPrivatePhysical: return true; // TODO have a config option to do this with swizzled global memory - case AsSubgroupPhysical: return true; - case AsSharedPhysical: return true; - case AsGlobalPhysical: return false; // TODO have a config option to do this with SSBOs + case AsPrivate: return true; // TODO have a config option to do this with swizzled global memory + case AsSubgroup: return true; + case AsShared: return true; + case AsGlobal: return false; // TODO have a config option to do this with SSBOs default: return false; } } static const Node** get_emulated_as_word_array(Context* ctx, AddressSpace as) { switch (as) { - case AsPrivatePhysical: return &ctx->fake_private_memory; - case AsSubgroupPhysical: return &ctx->fake_subgroup_memory; - case AsSharedPhysical: return &ctx->fake_shared_memory; - default: error("Emulation of this AS is not supported"); + case AsPrivate: return &ctx->fake_private_memory; + case AsSubgroup: return &ctx->fake_subgroup_memory; + case AsShared: return &ctx->fake_shared_memory; + default: shd_error("Emulation of this AS is not supported"); } } -static const Node* gen_deserialisation(Context* ctx, BodyBuilder* bb, const Type* element_type, const Node* arr, const Node* base_offset) { +static const Node* gen_deserialisation(Context* ctx, BodyBuilder* bb, const Type* element_type, const Node* arr, const Node* address) { IrArena* a = ctx->rewriter.dst_arena; const CompilerConfig* config = ctx->config; const Node* zero = size_t_literal(a, 0); switch (element_type->tag) { case Bool_TAG: { - const Node* logical_ptr = gen_primop_ce(bb, lea_op, 3, (const Node* []) { arr, zero, base_offset }); - const Node* value = gen_load(bb, logical_ptr); - return gen_primop_ce(bb, neq_op, 2, (const Node*[]) {value, int_literal(a, (IntLiteral) { .value = 0, .width = a->config.memory.word_size })}); + const Node* logical_ptr = lea_helper(a, arr, zero, 
shd_singleton(address)); + const Node* value = shd_bld_load(bb, logical_ptr); + return prim_op_helper(a, neq_op, shd_empty(a), mk_nodes(a, value, int_literal(a, (IntLiteral) { .value = 0, .width = a->config.memory.word_size }))); } case PtrType_TAG: switch (element_type->payload.ptr_type.address_space) { - case AsGlobalPhysical: { + case AsGlobal: { + // TODO: add a per-as size configuration const Type* ptr_int_t = int_type(a, (Int) {.width = a->config.memory.ptr_size, .is_signed = false }); - const Node* unsigned_int = gen_deserialisation(ctx, bb, ptr_int_t, arr, base_offset); - return gen_reinterpret_cast(bb, element_type, unsigned_int); + const Node* unsigned_int = gen_deserialisation(ctx, bb, ptr_int_t, arr, address); + return shd_bld_reinterpret_cast(bb, element_type, unsigned_int); } - default: error("TODO") + default: shd_error("TODO") } case Int_TAG: ser_int: { assert(element_type->tag == Int_TAG); const Node* acc = int_literal(a, (IntLiteral) { .width = element_type->payload.int_type.width, .is_signed = false, .value = 0 }); size_t length_in_bytes = int_size_in_bytes(element_type->payload.int_type.width); size_t word_size_in_bytes = int_size_in_bytes(a->config.memory.word_size); - const Node* offset = base_offset; + const Node* offset = shd_bytes_to_words(bb, address); const Node* shift = int_literal(a, (IntLiteral) { .width = element_type->payload.int_type.width, .is_signed = false, .value = 0 }); const Node* word_bitwidth = int_literal(a, (IntLiteral) { .width = element_type->payload.int_type.width, .is_signed = false, .value = word_size_in_bytes * 8 }); for (size_t byte = 0; byte < length_in_bytes; byte += word_size_in_bytes) { - const Node* word = gen_load(bb, gen_primop_ce(bb, lea_op, 3, (const Node* []) {arr, zero, offset})); - word = gen_conversion(bb, int_type(a, (Int) { .width = element_type->payload.int_type.width, .is_signed = false }), word); // widen/truncate the word we just loaded - word = first(gen_primop(bb, lshift_op, empty(a), 
mk_nodes(a, word, shift))); // shift it - acc = gen_primop_e(bb, or_op, empty(a), mk_nodes(a, acc, word)); + const Node* word = shd_bld_load(bb, lea_helper(a, arr, zero, shd_singleton(offset))); + word = shd_bld_conversion(bb, int_type(a, (Int) { .width = element_type->payload.int_type.width, .is_signed = false }), word); // widen/truncate the word we just loaded + word = prim_op_helper(a, lshift_op, shd_empty(a), mk_nodes(a, word, shift)); // shift it + acc = prim_op_helper(a, or_op, shd_empty(a), mk_nodes(a, acc, word)); - offset = first(gen_primop(bb, add_op, empty(a), mk_nodes(a, offset, size_t_literal(a, 1)))); - shift = first(gen_primop(bb, add_op, empty(a), mk_nodes(a, shift, word_bitwidth))); + offset = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, offset, size_t_literal(a, 1))); + shift = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, shift, word_bitwidth)); } if (config->printf_trace.memory_accesses) { - AddressSpace as = get_unqualified_type(arr->type)->payload.ptr_type.address_space; - String template = format_string_interned(a, "loaded %s at %s:%s\n", element_type->payload.int_type.width == IntTy64 ? "%lu" : "%u", get_address_space_name(as), "%lx"); + AddressSpace as = shd_get_unqualified_type(arr->type)->payload.ptr_type.address_space; + String template = shd_fmt_string_irarena(a, "loaded %s at %s:0x%s\n", element_type->payload.int_type.width == IntTy64 ? 
"%lu" : "%u", shd_get_address_space_name(as), "%lx"); const Node* widened = acc; if (element_type->payload.int_type.width < IntTy32) - widened = gen_conversion(bb, uint32_type(a), acc); - bind_instruction(bb, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = template }), widened, base_offset) })); + widened = shd_bld_conversion(bb, shd_uint32_type(a), acc); + shd_bld_debug_printf(bb, template, mk_nodes(a, widened, address)); } - acc = gen_reinterpret_cast(bb, int_type(a, (Int) { .width = element_type->payload.int_type.width, .is_signed = element_type->payload.int_type.is_signed }), acc);\ + acc = shd_bld_reinterpret_cast(bb, int_type(a, (Int) { .width = element_type->payload.int_type.width, .is_signed = element_type->payload.int_type.is_signed }), acc);\ return acc; } case Float_TAG: { - const Type* unsigned_int_t = int_type(a, (Int) {.width = float_to_int_width(element_type->payload.float_type.width), .is_signed = false }); - const Node* unsigned_int = gen_deserialisation(ctx, bb, unsigned_int_t, arr, base_offset); - return gen_reinterpret_cast(bb, element_type, unsigned_int); + const Type* unsigned_int_t = int_type(a, (Int) {.width = shd_float_to_int_width(element_type->payload.float_type.width), .is_signed = false }); + const Node* unsigned_int = gen_deserialisation(ctx, bb, unsigned_int_t, arr, address); + return shd_bld_reinterpret_cast(bb, element_type, unsigned_int); } case TypeDeclRef_TAG: case RecordType_TAG: { const Type* compound_type = element_type; - compound_type = get_maybe_nominal_type_body(compound_type); + compound_type = shd_get_maybe_nominal_type_body(compound_type); Nodes member_types = compound_type->payload.record_type.members; LARRAY(const Node*, loaded, member_types.count); for (size_t i = 0; i < member_types.count; i++) { - const Node* field_offset = gen_primop_e(bb, offset_of_op, singleton(element_type), singleton(size_t_literal(a, i))); - field_offset = bytes_to_words(bb, 
field_offset); - const Node* adjusted_offset = gen_primop_e(bb, add_op, empty(a), mk_nodes(a, base_offset, field_offset)); + const Node* field_offset = prim_op_helper(a, offset_of_op, shd_singleton(element_type), shd_singleton(size_t_literal(a, i))); + const Node* adjusted_offset = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, address, field_offset)); loaded[i] = gen_deserialisation(ctx, bb, member_types.nodes[i], arr, adjusted_offset); } - return composite_helper(a, element_type, nodes(a, member_types.count, loaded)); + return composite_helper(a, element_type, shd_nodes(a, member_types.count, loaded)); } case ArrType_TAG: case PackType_TAG: { - const Node* size = get_fill_type_size(element_type); + const Node* size = shd_get_fill_type_size(element_type); if (size->tag != IntLiteral_TAG) { - error_print("Size of type "); - log_node(ERROR, element_type); - error_print(" is not known a compile-time!\n"); + shd_error_print("Size of type "); + shd_log_node(ERROR, element_type); + shd_error_print(" is not known a compile-time!\n"); } - size_t components_count = get_int_literal_value(*resolve_to_int_literal(size), 0); - const Type* component_type = get_fill_type_element_type(element_type); + size_t components_count = shd_get_int_literal_value(*shd_resolve_to_int_literal(size), 0); + const Type* component_type = shd_get_fill_type_element_type(element_type); LARRAY(const Node*, components, components_count); - const Node* offset = base_offset; + const Node* offset = address; for (size_t i = 0; i < components_count; i++) { components[i] = gen_deserialisation(ctx, bb, component_type, arr, offset); - offset = gen_primop_e(bb, add_op, empty(a), mk_nodes(a, offset, gen_primop_e(bb, size_of_op, singleton(component_type), empty(a)))); + offset = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, offset, prim_op_helper(a, size_of_op, shd_singleton(component_type), shd_empty(a)))); } - return composite_helper(a, element_type, nodes(a, components_count, components)); + return 
composite_helper(a, element_type, shd_nodes(a, components_count, components)); } - default: error("TODO"); + default: shd_error("TODO"); } } -static void gen_serialisation(Context* ctx, BodyBuilder* bb, const Type* element_type, const Node* arr, const Node* base_offset, const Node* value) { +static void gen_serialisation(Context* ctx, BodyBuilder* bb, const Type* element_type, const Node* arr, const Node* address, const Node* value) { IrArena* a = ctx->rewriter.dst_arena; const CompilerConfig* config = ctx->config; const Node* zero = size_t_literal(a, 0); switch (element_type->tag) { case Bool_TAG: { - const Node* logical_ptr = gen_primop_ce(bb, lea_op, 3, (const Node* []) { arr, zero, base_offset }); + const Node* logical_ptr = lea_helper(a, arr, zero, shd_singleton(address)); const Node* zero_b = int_literal(a, (IntLiteral) { .value = 1, .width = a->config.memory.word_size }); const Node* one_b = int_literal(a, (IntLiteral) { .value = 0, .width = a->config.memory.word_size }); - const Node* int_value = gen_primop_ce(bb, select_op, 3, (const Node*[]) { value, one_b, zero_b }); - gen_store(bb, logical_ptr, int_value); + const Node* int_value = prim_op_helper(a, select_op, shd_empty(a), mk_nodes(a, value, one_b, zero_b)); + shd_bld_store(bb, logical_ptr, int_value); return; } case PtrType_TAG: switch (element_type->payload.ptr_type.address_space) { - case AsGlobalPhysical: { + case AsGlobal: { const Type* ptr_int_t = int_type(a, (Int) {.width = a->config.memory.ptr_size, .is_signed = false }); - const Node* unsigned_value = gen_primop_e(bb, reinterpret_op, singleton(ptr_int_t), singleton(value)); - return gen_serialisation(ctx, bb, ptr_int_t, arr, base_offset, unsigned_value); + const Node* unsigned_value = prim_op_helper(a, reinterpret_op, shd_singleton(ptr_int_t), shd_singleton(value)); + return gen_serialisation(ctx, bb, ptr_int_t, arr, address, unsigned_value); } - default: error("TODO") + default: shd_error("TODO") } case Int_TAG: des_int: { 
assert(element_type->tag == Int_TAG); // First bitcast to unsigned so we always get zero-extension and not sign-extension afterwards const Type* element_t_unsigned = int_type(a, (Int) { .width = element_type->payload.int_type.width, .is_signed = false}); - value = convert_int_extend_according_to_src_t(bb, element_t_unsigned, value); + value = shd_bld_convert_int_extend_according_to_src_t(bb, element_t_unsigned, value); // const Node* acc = int_literal(a, (IntLiteral) { .width = element_type->payload.int_type.width, .is_signed = false, .value = 0 }); size_t length_in_bytes = int_size_in_bytes(element_type->payload.int_type.width); size_t word_size_in_bytes = int_size_in_bytes(a->config.memory.word_size); - const Node* offset = base_offset; + const Node* offset = shd_bytes_to_words(bb, address); const Node* shift = int_literal(a, (IntLiteral) { .width = element_type->payload.int_type.width, .is_signed = false, .value = 0 }); const Node* word_bitwidth = int_literal(a, (IntLiteral) { .width = element_type->payload.int_type.width, .is_signed = false, .value = word_size_in_bytes * 8 }); for (size_t byte = 0; byte < length_in_bytes; byte += word_size_in_bytes) { @@ -182,41 +178,40 @@ static void gen_serialisation(Context* ctx, BodyBuilder* bb, const Type* element /*bool needs_patch = is_last_word && word_size_in_bytes < length_in_bytes; const Node* original_word = NULL; if (needs_patch) { - original_word = gen_load(bb, gen_primop_ce(bb, lea_op, 3, (const Node* []) {arr, zero, offset})); - error_print("TODO"); - error_die(); + original_word = gen_load(bb, gen_lea(bb, arr, zero, singleton(base_offset))); + shd_error_print("TODO"); + shd_error_die(); // word = gen_conversion(bb, int_type(a, (Int) { .width = element_type->payload.int_type.width, .is_signed = false }), word); // widen/truncate the word we just loaded }*/ const Node* word = value; - word = first(gen_primop(bb, rshift_logical_op, empty(a), mk_nodes(a, word, shift))); // shift it - word = gen_conversion(bb, 
int_type(a, (Int) { .width = a->config.memory.word_size, .is_signed = false }), word); // widen/truncate the word we want to store - gen_store(bb, gen_primop_ce(bb, lea_op, 3, (const Node* []) {arr, zero, offset}), word); + word = (prim_op_helper(a, rshift_logical_op, shd_empty(a), mk_nodes(a, word, shift))); // shift it + word = shd_bld_conversion(bb, int_type(a, (Int) { .width = a->config.memory.word_size, .is_signed = false }), word); // widen/truncate the word we want to store + shd_bld_store(bb, lea_helper(a, arr, zero, shd_singleton(offset)), word); - offset = first(gen_primop(bb, add_op, empty(a), mk_nodes(a, offset, size_t_literal(a, 1)))); - shift = first(gen_primop(bb, add_op, empty(a), mk_nodes(a, shift, word_bitwidth))); + offset = (prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, offset, size_t_literal(a, 1)))); + shift = (prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, shift, word_bitwidth))); } if (config->printf_trace.memory_accesses) { - AddressSpace as = get_unqualified_type(arr->type)->payload.ptr_type.address_space; - String template = format_string_interned(a, "stored %s at %s:%s\n", element_type->payload.int_type.width == IntTy64 ? "%lu" : "%u", get_address_space_name(as), "%lx"); + AddressSpace as = shd_get_unqualified_type(arr->type)->payload.ptr_type.address_space; + String template = shd_fmt_string_irarena(a, "stored %s at %s:0x%s\n", element_type->payload.int_type.width == IntTy64 ? 
"%lu" : "%u", shd_get_address_space_name(as), "%lx"); const Node* widened = value; if (element_type->payload.int_type.width < IntTy32) - widened = gen_conversion(bb, uint32_type(a), value); - bind_instruction(bb, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = template }), widened, base_offset) })); + widened = shd_bld_conversion(bb, shd_uint32_type(a), value); + shd_bld_debug_printf(bb, template, mk_nodes(a, widened, address)); } return; } case Float_TAG: { - const Type* unsigned_int_t = int_type(a, (Int) {.width = float_to_int_width(element_type->payload.float_type.width), .is_signed = false }); - const Node* unsigned_value = gen_primop_e(bb, reinterpret_op, singleton(unsigned_int_t), singleton(value)); - return gen_serialisation(ctx, bb, unsigned_int_t, arr, base_offset, unsigned_value); + const Type* unsigned_int_t = int_type(a, (Int) {.width = shd_float_to_int_width(element_type->payload.float_type.width), .is_signed = false }); + const Node* unsigned_value = prim_op_helper(a, reinterpret_op, shd_singleton(unsigned_int_t), shd_singleton(value)); + return gen_serialisation(ctx, bb, unsigned_int_t, arr, address, unsigned_value); } case RecordType_TAG: { Nodes member_types = element_type->payload.record_type.members; for (size_t i = 0; i < member_types.count; i++) { - const Node* extracted_value = first(bind_instruction(bb, prim_op(a, (PrimOp) { .op = extract_op, .operands = mk_nodes(a, value, int32_literal(a, i)), .type_arguments = empty(a) }))); - const Node* field_offset = gen_primop_e(bb, offset_of_op, singleton(element_type), singleton(size_t_literal(a, i))); - field_offset = bytes_to_words(bb, field_offset); - const Node* adjusted_offset = gen_primop_e(bb, add_op, empty(a), mk_nodes(a, base_offset, field_offset)); + const Node* extracted_value = prim_op(a, (PrimOp) { .op = extract_op, .operands = mk_nodes(a, value, shd_int32_literal(a, i)), .type_arguments = shd_empty(a) }); + const Node* 
field_offset = prim_op_helper(a, offset_of_op, shd_singleton(element_type), shd_singleton(size_t_literal(a, i))); + const Node* adjusted_offset = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, address, field_offset)); gen_serialisation(ctx, bb, member_types.nodes[i], arr, adjusted_offset, extracted_value); } return; @@ -224,27 +219,27 @@ static void gen_serialisation(Context* ctx, BodyBuilder* bb, const Type* element case TypeDeclRef_TAG: { const Node* nom = element_type->payload.type_decl_ref.decl; assert(nom && nom->tag == NominalType_TAG); - gen_serialisation(ctx, bb, nom->payload.nom_type.body, arr, base_offset, value); + gen_serialisation(ctx, bb, nom->payload.nom_type.body, arr, address, value); return; } case ArrType_TAG: case PackType_TAG: { - const Node* size = get_fill_type_size(element_type); + const Node* size = shd_get_fill_type_size(element_type); if (size->tag != IntLiteral_TAG) { - error_print("Size of type "); - log_node(ERROR, element_type); - error_print(" is not known a compile-time!\n"); + shd_error_print("Size of type "); + shd_log_node(ERROR, element_type); + shd_error_print(" is not known a compile-time!\n"); } - size_t components_count = get_int_literal_value(*resolve_to_int_literal(size), 0); - const Type* component_type = get_fill_type_element_type(element_type); - const Node* offset = base_offset; + size_t components_count = shd_get_int_literal_value(*shd_resolve_to_int_literal(size), 0); + const Type* component_type = shd_get_fill_type_element_type(element_type); + const Node* offset = address; for (size_t i = 0; i < components_count; i++) { - gen_serialisation(ctx, bb, component_type, arr, offset, gen_extract(bb, value, singleton(int32_literal(a, i)))); - offset = gen_primop_e(bb, add_op, empty(a), mk_nodes(a, offset, gen_primop_e(bb, size_of_op, singleton(component_type), empty(a)))); + gen_serialisation(ctx, bb, component_type, arr, offset, shd_extract_helper(a, value, shd_singleton(shd_int32_literal(a, i)))); + offset = 
prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, offset, prim_op_helper(a, size_of_op, shd_singleton(component_type), shd_empty(a)))); } return; } - default: error("TODO"); + default: shd_error("TODO"); } } @@ -257,82 +252,78 @@ static const Node* gen_serdes_fn(Context* ctx, const Type* element_type, bool un else cache = ser ? ctx->serialisation_varying[as] : ctx->deserialisation_varying[as]; - const Node** found = find_value_dict(const Node*, const Node*, cache, element_type); + const Node** found = shd_dict_find_value(const Node*, const Node*, cache, element_type); if (found) return *found; IrArena* a = ctx->rewriter.dst_arena; const Type* emulated_ptr_type = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false }); - const Node* address_param = var(a, qualified_type(a, (QualifiedType) { .is_uniform = !a->config.is_simt || uniform_address, .type = emulated_ptr_type }), "ptr"); + const Node* address_param = param(a, qualified_type(a, (QualifiedType) { .is_uniform = !a->config.is_simt || uniform_address, .type = emulated_ptr_type }), "ptr"); - const Type* input_value_t = qualified_type(a, (QualifiedType) { .is_uniform = !a->config.is_simt || (uniform_address && is_addr_space_uniform(a, as) && false), .type = element_type }); - const Node* value_param = ser ? var(a, input_value_t, "value") : NULL; - Nodes params = ser ? mk_nodes(a, address_param, value_param) : singleton(address_param); + const Type* input_value_t = qualified_type(a, (QualifiedType) { .is_uniform = !a->config.is_simt || (uniform_address && shd_is_addr_space_uniform(a, as) && false), .type = element_type }); + const Node* value_param = ser ? param(a, input_value_t, "value") : NULL; + Nodes params = ser ? mk_nodes(a, address_param, value_param) : shd_singleton(address_param); - const Type* return_value_t = qualified_type(a, (QualifiedType) { .is_uniform = !a->config.is_simt || (uniform_address && is_addr_space_uniform(a, as)), .type = element_type }); - Nodes return_ts = ser ? 
empty(a) : singleton(return_value_t); + const Type* return_value_t = qualified_type(a, (QualifiedType) { .is_uniform = !a->config.is_simt || (uniform_address && shd_is_addr_space_uniform(a, as)), .type = element_type }); + Nodes return_ts = ser ? shd_empty(a) : shd_singleton(return_value_t); - String name = format_string_arena(a->arena, "generated_%s_%s_%s_%s", ser ? "store" : "load", get_address_space_name(as), uniform_address ? "uniform" : "varying", name_type_safe(a, element_type)); - Node* fun = function(ctx->rewriter.dst_module, params, name, singleton(annotation(a, (Annotation) { .name = "Generated" })), return_ts); - insert_dict(const Node*, Node*, cache, element_type, fun); + String name = shd_format_string_arena(a->arena, "generated_%s_%s_%s_%s", ser ? "store" : "load", shd_get_address_space_name(as), uniform_address ? "uniform" : "varying", shd_get_type_name(a, element_type)); + Node* fun = function(ctx->rewriter.dst_module, params, name, mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }), annotation(a, (Annotation) { .name = "Leaf" })), return_ts); + shd_dict_insert(const Node*, Node*, cache, element_type, fun); - BodyBuilder* bb = begin_body(a); - const Node* address = bytes_to_words(bb, address_param); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); const Node* base = *get_emulated_as_word_array(ctx, as); if (ser) { - gen_serialisation(ctx, bb, element_type, base, address, value_param); - fun->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { .fn = fun, .args = empty(a) })); + gen_serialisation(ctx, bb, element_type, base, address_param, value_param); + shd_set_abstraction_body(fun, shd_bld_return(bb, shd_empty(a))); } else { - const Node* loaded_value = gen_deserialisation(ctx, bb, element_type, base, address); + const Node* loaded_value = gen_deserialisation(ctx, bb, element_type, base, address_param); assert(loaded_value); - fun->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { .fn = fun, .args = 
singleton(loaded_value) })); + shd_set_abstraction_body(fun, shd_bld_return(bb, shd_singleton(loaded_value))); } return fun; } static const Node* process_node(Context* ctx, const Node* old) { - const Node* found = search_processed(&ctx->rewriter, old); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; switch (old->tag) { - case PrimOp_TAG: { - const PrimOp* oprim_op = &old->payload.prim_op; - switch (oprim_op->op) { - case alloca_subgroup_op: - case alloca_op: error("This needs to be lowered (see setup_stack_frames.c)") - // lowering for either kind of memory accesses is similar - case load_op: - case store_op: { - const Node* old_ptr = oprim_op->operands.nodes[0]; - const Type* ptr_type = old_ptr->type; - bool uniform_ptr = deconstruct_qualified_type(&ptr_type); - assert(ptr_type->tag == PtrType_TAG); - if (!is_as_emulated(ctx, ptr_type->payload.ptr_type.address_space)) - break; - BodyBuilder* bb = begin_body(a); - - const Type* element_type = rewrite_node(&ctx->rewriter, ptr_type->payload.ptr_type.pointed_type); - const Node* pointer_as_offset = rewrite_node(&ctx->rewriter, old_ptr); - const Node* fn = gen_serdes_fn(ctx, element_type, uniform_ptr, oprim_op->op == store_op, ptr_type->payload.ptr_type.address_space); - - if (oprim_op->op == load_op) { - Nodes r = bind_instruction(bb, call(a, (Call) {.callee = fn_addr_helper(a, fn), .args = singleton(pointer_as_offset)})); - return yield_values_and_wrap_in_block(bb, r); - } else { - const Node* value = rewrite_node(&ctx->rewriter, oprim_op->operands.nodes[1]); - bind_instruction(bb, call(a, (Call) { .callee = fn_addr_helper(a, fn), .args = mk_nodes(a, pointer_as_offset, value) })); - return yield_values_and_wrap_in_block(bb, empty(a)); - } - } - default: break; - } - break; + case Load_TAG: { + Load payload = old->payload.load; + const Type* ptr_type = payload.ptr->type; + bool uniform_ptr = shd_deconstruct_qualified_type(&ptr_type); + 
assert(ptr_type->tag == PtrType_TAG); + if (ptr_type->payload.ptr_type.is_reference || !is_as_emulated(ctx, ptr_type->payload.ptr_type.address_space)) + break; + BodyBuilder* bb = shd_bld_begin_pseudo_instr(a, shd_rewrite_node(r, payload.mem)); + const Type* element_type = shd_rewrite_node(&ctx->rewriter, ptr_type->payload.ptr_type.pointed_type); + const Node* pointer_as_offset = shd_rewrite_node(&ctx->rewriter, payload.ptr); + const Node* fn = gen_serdes_fn(ctx, element_type, uniform_ptr, false, ptr_type->payload.ptr_type.address_space); + Nodes results = shd_bld_call(bb, fn_addr_helper(a, fn), shd_singleton(pointer_as_offset)); + return shd_bld_to_instr_yield_values(bb, results); + } + case Store_TAG: { + Store payload = old->payload.store; + const Type* ptr_type = payload.ptr->type; + bool uniform_ptr = shd_deconstruct_qualified_type(&ptr_type); + assert(ptr_type->tag == PtrType_TAG); + if (ptr_type->payload.ptr_type.is_reference || !is_as_emulated(ctx, ptr_type->payload.ptr_type.address_space)) + break; + BodyBuilder* bb = shd_bld_begin_pseudo_instr(a, shd_rewrite_node(r, payload.mem)); + + const Type* element_type = shd_rewrite_node(&ctx->rewriter, ptr_type->payload.ptr_type.pointed_type); + const Node* pointer_as_offset = shd_rewrite_node(&ctx->rewriter, payload.ptr); + const Node* fn = gen_serdes_fn(ctx, element_type, uniform_ptr, true, ptr_type->payload.ptr_type.address_space); + + const Node* value = shd_rewrite_node(&ctx->rewriter, payload.value); + shd_bld_call(bb, fn_addr_helper(a, fn), mk_nodes(a, pointer_as_offset, value)); + return shd_bld_to_instr_yield_values(bb, shd_empty(a)); } + case StackAlloc_TAG: shd_error("This needs to be lowered (see setup_stack_frames.c)") case PtrType_TAG: { - if (is_as_emulated(ctx, old->payload.ptr_type.address_space)) + if (!old->payload.ptr_type.is_reference && is_as_emulated(ctx, old->payload.ptr_type.address_space)) return int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false }); break; } @@ 
-344,21 +335,21 @@ static const Node* process_node(Context* ctx, const Node* old) { case GlobalVariable_TAG: { const GlobalVariable* old_gvar = &old->payload.global_variable; // Global variables into emulated address spaces become integer constants (to index into arrays used for emulation of said address space) - if (is_as_emulated(ctx, old_gvar->address_space)) { + if (!shd_lookup_annotation(old, "Logical") && is_as_emulated(ctx, old_gvar->address_space)) { assert(false); } break; } case Function_TAG: { - if (strcmp(get_abstraction_name(old), "generated_init") == 0) { - Node *new = recreate_decl_header_identity(&ctx->rewriter, old); - BodyBuilder *bb = begin_body(a); - + if (strcmp(shd_get_abstraction_name(old), "generated_init") == 0) { + Node* new = shd_recreate_node_head(&ctx->rewriter, old); + BodyBuilder *bb = shd_bld_begin(a, shd_get_abstraction_mem(new)); for (AddressSpace as = 0; as < NumAddressSpaces; as++) { if (is_as_emulated(ctx, as)) store_init_data(ctx, as, ctx->collected[as], bb); } - new->payload.fun.body = finish_body(bb, rewrite_node(&ctx->rewriter, old->payload.fun.body)); + shd_register_processed(&ctx->rewriter, shd_get_abstraction_mem(old), shd_bb_mem(bb)); + shd_set_abstraction_body(new, shd_bld_finish(bb, shd_rewrite_node(&ctx->rewriter, old->payload.fun.body))); return new; } break; @@ -366,15 +357,15 @@ static const Node* process_node(Context* ctx, const Node* old) { default: break; } - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); } -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); static Nodes collect_globals(Context* ctx, AddressSpace as) { IrArena* a = ctx->rewriter.dst_arena; - Nodes old_decls = get_module_declarations(ctx->rewriter.src_module); + Nodes old_decls = shd_module_get_declarations(ctx->rewriter.src_module); LARRAY(const Type*, collected, old_decls.count); size_t 
members_count = 0; @@ -382,11 +373,12 @@ static Nodes collect_globals(Context* ctx, AddressSpace as) { const Node* decl = old_decls.nodes[i]; if (decl->tag != GlobalVariable_TAG) continue; if (decl->payload.global_variable.address_space != as) continue; + if (shd_lookup_annotation(decl, "Logical")) continue; collected[members_count] = decl; members_count++; } - return nodes(a, members_count, collected); + return shd_nodes(a, members_count, collected); } /// Collects all global variables in a specific AS, and creates a record type for them. @@ -394,8 +386,8 @@ static const Node* make_record_type(Context* ctx, AddressSpace as, Nodes collect IrArena* a = ctx->rewriter.dst_arena; Module* m = ctx->rewriter.dst_module; - String as_name = get_address_space_name(as); - Node* global_struct_t = nominal_type(m, singleton(annotation(a, (Annotation) { .name = "Generated" })), format_string_arena(a->arena, "globals_physical_%s_t", as_name)); + String as_name = shd_get_address_space_name(as); + Node* global_struct_t = nominal_type(m, shd_singleton(annotation(a, (Annotation) { .name = "Generated" })), shd_format_string_arena(a->arena, "globals_physical_%s_t", as_name)); LARRAY(String, member_names, collected.count); LARRAY(const Type*, member_tys, collected.count); @@ -404,27 +396,26 @@ static const Node* make_record_type(Context* ctx, AddressSpace as, Nodes collect const Node* decl = collected.nodes[i]; const Type* type = decl->payload.global_variable.type; - member_tys[i] = rewrite_node(&ctx->rewriter, type); + member_tys[i] = shd_rewrite_node(&ctx->rewriter, type); member_names[i] = decl->payload.global_variable.name; // Turn the old global variable into a pointer (which are also now integers) const Type* emulated_ptr_type = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false }); - Nodes annotations = rewrite_nodes(&ctx->rewriter, decl->payload.global_variable.annotations); + Nodes annotations = shd_rewrite_nodes(&ctx->rewriter, 
decl->payload.global_variable.annotations); Node* new_address = constant(ctx->rewriter.dst_module, annotations, emulated_ptr_type, decl->payload.global_variable.name); // we need to compute the actual pointer by getting the offset and dividing it // after lower_memory_layout, optimisations will eliminate this and resolve to a value - BodyBuilder* bb = begin_body(a); - const Node* offset = gen_primop_e(bb, offset_of_op, singleton(type_decl_ref(a, (TypeDeclRef) { .decl = global_struct_t })), singleton(size_t_literal(a, i))); - // const Node* offset_in_words = bytes_to_words(bb, offset); - new_address->payload.constant.instruction = yield_values_and_wrap_in_block(bb, singleton(offset)); + BodyBuilder* bb = shd_bld_begin_pure(a); + const Node* offset = prim_op_helper(a, offset_of_op, shd_singleton(type_decl_ref(a, (TypeDeclRef) { .decl = global_struct_t })), shd_singleton(size_t_literal(a, i))); + new_address->payload.constant.value = shd_bld_to_instr_pure_with_values(bb, shd_singleton(offset)); - register_processed(&ctx->rewriter, decl, new_address); + shd_register_processed(&ctx->rewriter, decl, new_address); } const Type* record_t = record_type(a, (RecordType) { - .members = nodes(a, collected.count, member_tys), - .names = strings(a, collected.count, member_names) + .members = shd_nodes(a, collected.count, member_tys), + .names = shd_strings(a, collected.count, member_names) }); //return record_t; @@ -433,23 +424,25 @@ static const Node* make_record_type(Context* ctx, AddressSpace as, Nodes collect } static void store_init_data(Context* ctx, AddressSpace as, Nodes collected, BodyBuilder* bb) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; IrArena* oa = ctx->rewriter.src_arena; - IrArena* a = ctx->rewriter.dst_arena; for (size_t i = 0; i < collected.count; i++) { const Node* old_decl = collected.nodes[i]; assert(old_decl->tag == GlobalVariable_TAG); const Node* old_init = old_decl->payload.global_variable.init; if (old_init) { - const Node* old_store 
= prim_op_helper(oa, store_op, empty(oa), mk_nodes(oa, ref_decl_helper(oa, old_decl), old_init)); - bind_instruction(bb, rewrite_node(&ctx->rewriter, old_store)); + const Node* value = shd_rewrite_node(r, old_init); + const Node* fn = gen_serdes_fn(ctx, shd_get_unqualified_type(value->type), false, true, old_decl->payload.global_variable.address_space); + shd_bld_call(bb, fn_addr_helper(a, fn), mk_nodes(a, shd_rewrite_node(r, ref_decl_helper(oa, old_decl)), value)); } } } -static void construct_emulated_memory_array(Context* ctx, AddressSpace as, AddressSpace logical_as) { +static void construct_emulated_memory_array(Context* ctx, AddressSpace as) { IrArena* a = ctx->rewriter.dst_arena; Module* m = ctx->rewriter.dst_module; - String as_name = get_address_space_name(as); + String as_name = shd_get_address_space_name(as); const Type* word_type = int_type(a, (Int) { .width = a->config.memory.word_size, .is_signed = false }); const Type* ptr_size_type = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false }); @@ -460,66 +453,70 @@ static void construct_emulated_memory_array(Context* ctx, AddressSpace as, Addre .element_type = word_type, .size = NULL }); - *get_emulated_as_word_array(ctx, as) = undef(a, (Undef) { .type = ptr_type(a, (PtrType) { .address_space = logical_as, .pointed_type = words_array_type }) }); + *get_emulated_as_word_array(ctx, as) = undef(a, (Undef) { .type = ptr_type(a, (PtrType) { .address_space = as, .pointed_type = words_array_type }) }); return; } const Node* global_struct_t = make_record_type(ctx, as, ctx->collected[as]); - Nodes annotations = singleton(annotation(a, (Annotation) { .name = "Generated" })); + Nodes annotations = shd_singleton(annotation(a, (Annotation) { .name = "Generated" })); // compute the size - BodyBuilder* bb = begin_body(a); - const Node* size_of = gen_primop_e(bb, size_of_op, singleton(type_decl_ref(a, (TypeDeclRef) { .decl = global_struct_t })), empty(a)); - const Node* size_in_words = 
bytes_to_words(bb, size_of); + BodyBuilder* bb = shd_bld_begin_pure(a); + const Node* size_of = prim_op_helper(a, size_of_op, shd_singleton(type_decl_ref(a, (TypeDeclRef) { .decl = global_struct_t })), shd_empty(a)); + const Node* size_in_words = shd_bytes_to_words(bb, size_of); - Node* constant_decl = constant(m, annotations, ptr_size_type, format_string_interned(a, "globals_physical_%s_size", as_name)); - constant_decl->payload.constant.instruction = yield_values_and_wrap_in_block(bb, singleton(size_in_words)); + Node* constant_decl = constant(m, annotations, ptr_size_type, shd_fmt_string_irarena(a, "memory_%s_size", as_name)); + constant_decl->payload.constant.value = shd_bld_to_instr_pure_with_values(bb, shd_singleton(size_in_words)); const Type* words_array_type = arr_type(a, (ArrType) { .element_type = word_type, .size = ref_decl_helper(a, constant_decl) }); - Node* words_array = global_var(m, annotations, words_array_type, format_string_arena(a->arena, "addressable_word_memory_%s", as_name), logical_as); + Node* words_array = global_var(m, shd_nodes_append(a, annotations, annotation(a, (Annotation) { .name = "Logical" })), words_array_type, shd_format_string_arena(a->arena, "memory_%s", as_name), as); *get_emulated_as_word_array(ctx, as) = ref_decl_helper(a, words_array); } -Module* lower_physical_ptrs(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_physical_ptrs(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + aconfig.address_spaces[AsPrivate].physical = false; + aconfig.address_spaces[AsShared].physical = false; + aconfig.address_spaces[AsSubgroup].physical = false; + + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = 
create_rewriter(src, dst, (RewriteNodeFn) process_node), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process_node), .config = config, }; - construct_emulated_memory_array(&ctx, AsPrivatePhysical, AsPrivateLogical); - if (dst->arena->config.allow_subgroup_memory) - construct_emulated_memory_array(&ctx, AsSubgroupPhysical, AsSubgroupLogical); - if (dst->arena->config.allow_shared_memory) - construct_emulated_memory_array(&ctx, AsSharedPhysical, AsSharedLogical); + construct_emulated_memory_array(&ctx, AsPrivate); + if (dst->arena->config.address_spaces[AsSubgroup].allowed) + construct_emulated_memory_array(&ctx, AsSubgroup); + if (dst->arena->config.address_spaces[AsShared].allowed) + construct_emulated_memory_array(&ctx, AsShared); for (size_t i = 0; i < NumAddressSpaces; i++) { if (is_as_emulated(&ctx, i)) { - ctx.serialisation_varying[i] = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node); - ctx.deserialisation_varying[i] = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node); - ctx.serialisation_uniform[i] = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node); - ctx.deserialisation_uniform[i] = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node); + ctx.serialisation_varying[i] = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node); + ctx.deserialisation_varying[i] = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node); + ctx.serialisation_uniform[i] = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node); + ctx.deserialisation_uniform[i] = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node); } } - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); for (size_t i = 0; i < NumAddressSpaces; i++) { if (is_as_emulated(&ctx, i)) { - 
destroy_dict(ctx.serialisation_varying[i]); - destroy_dict(ctx.deserialisation_varying[i]); - destroy_dict(ctx.serialisation_uniform[i]); - destroy_dict(ctx.deserialisation_uniform[i]); + shd_destroy_dict(ctx.serialisation_varying[i]); + shd_destroy_dict(ctx.deserialisation_varying[i]); + shd_destroy_dict(ctx.serialisation_uniform[i]); + shd_destroy_dict(ctx.deserialisation_uniform[i]); } } diff --git a/src/shady/passes/lower_stack.c b/src/shady/passes/lower_stack.c index 13cec9a93..1be80d300 100644 --- a/src/shady/passes/lower_stack.c +++ b/src/shady/passes/lower_stack.c @@ -1,4 +1,8 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/cast.h" +#include "shady/ir/memory_layout.h" + +#include "../ir_private.h" #include "log.h" #include "portability.h" @@ -6,12 +10,6 @@ #include "dict.h" #include "util.h" -#include "../rewrite.h" -#include "../type.h" -#include "../ir_private.h" - -#include "../transform/ir_gen_helpers.h" - #include #include @@ -31,188 +29,189 @@ typedef struct Context_ { static const Node* gen_fn(Context* ctx, const Type* element_type, bool push) { struct Dict* cache = push ? ctx->push : ctx->pop; - const Node** found = find_value_dict(const Node*, const Node*, cache, element_type); + const Node** found = shd_dict_find_value(const Node*, const Node*, cache, element_type); if (found) return *found; IrArena* a = ctx->rewriter.dst_arena; const Type* qualified_t = qualified_type(a, (QualifiedType) { .is_uniform = false, .type = element_type }); - const Node* param = push ? var(a, qualified_t, "value") : NULL; - Nodes params = push ? singleton(param) : empty(a); - Nodes return_ts = push ? empty(a) : singleton(qualified_t); - String name = format_string_arena(a->arena, "generated_%s_%s", push ? 
"push" : "pop", name_type_safe(a, element_type)); - Node* fun = function(ctx->rewriter.dst_module, params, name, singleton(annotation(a, (Annotation) { .name = "Generated" })), return_ts); - insert_dict(const Node*, Node*, cache, element_type, fun); + const Node* value_param = push ? param(a, qualified_t, "value") : NULL; + Nodes params = push ? shd_singleton(value_param) : shd_empty(a); + Nodes return_ts = push ? shd_empty(a) : shd_singleton(qualified_t); + String name = shd_format_string_arena(a->arena, "generated_%s_%s", push ? "push" : "pop", shd_get_type_name(a, element_type)); + Node* fun = function(ctx->rewriter.dst_module, params, name, mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }), annotation(a, (Annotation) { .name = "Leaf" })), return_ts); + shd_dict_insert(const Node*, Node*, cache, element_type, fun); - BodyBuilder* bb = begin_body(a); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); - const Node* element_size = gen_primop_e(bb, size_of_op, singleton(element_type), empty(a)); - element_size = gen_conversion(bb, uint32_type(a), element_size); + const Node* element_size = prim_op_helper(a, size_of_op, shd_singleton(element_type), shd_empty(a)); + element_size = shd_bld_conversion(bb, shd_uint32_type(a), element_size); // TODO somehow annotate the uniform guys as uniform const Node* stack_pointer = ctx->stack_pointer; const Node* stack = ctx->stack; - const Node* stack_size = gen_load(bb, stack_pointer); + const Node* stack_size = shd_bld_load(bb, stack_pointer); if (!push) // for pop, we decrease the stack size first - stack_size = gen_primop_ce(bb, sub_op, 2, (const Node* []) { stack_size, element_size}); + stack_size = prim_op_helper(a, sub_op, shd_empty(a), mk_nodes(a, stack_size, element_size)); - const Node* addr = gen_lea(bb, stack, stack_size, nodes(a, 1, (const Node* []) {uint32_literal(a, 0) })); - assert(get_unqualified_type(addr->type)->tag == PtrType_TAG); - AddressSpace addr_space = 
get_unqualified_type(addr->type)->payload.ptr_type.address_space; + const Node* addr = lea_helper(a, ctx->stack, shd_int32_literal(a, 0), shd_singleton(stack_size)); + assert(shd_get_unqualified_type(addr->type)->tag == PtrType_TAG); + AddressSpace addr_space = shd_get_unqualified_type(addr->type)->payload.ptr_type.address_space; - addr = gen_reinterpret_cast(bb, ptr_type(a, (PtrType) {.address_space = addr_space, .pointed_type = element_type}), addr); + addr = shd_bld_reinterpret_cast(bb, ptr_type(a, (PtrType) { .address_space = addr_space, .pointed_type = element_type }), addr); const Node* popped_value = NULL; if (push) - gen_store(bb, addr, param); + shd_bld_store(bb, addr, value_param); else - popped_value = gen_primop_ce(bb, load_op, 1, (const Node* []) { addr }); + popped_value = shd_bld_load(bb, addr); if (push) // for push, we increase the stack size after the store - stack_size = gen_primop_ce(bb, add_op, 2, (const Node* []) { stack_size, element_size}); + stack_size = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, stack_size, element_size)); // store updated stack size - gen_store(bb, stack_pointer, stack_size); + shd_bld_store(bb, stack_pointer, stack_size); if (ctx->config->printf_trace.stack_size) { - bind_instruction(bb, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = name })) })); - bind_instruction(bb, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = "stack size after: %d\n" }), stack_size) })); + shd_bld_debug_printf(bb, name, shd_empty(a)); + shd_bld_debug_printf(bb, "stack size after: %d\n", shd_singleton(stack_size)); } if (push) { - fun->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { .fn = fun, .args = empty(a) })); + shd_set_abstraction_body(fun, shd_bld_return(bb, shd_empty(a))); } else { assert(popped_value); - fun->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { .fn = fun, .args = 
singleton(popped_value) })); + shd_set_abstraction_body(fun, shd_bld_return(bb, shd_singleton(popped_value))); } return fun; } -static const Node* process_let(Context* ctx, const Node* node) { - assert(node->tag == Let_TAG); - IrArena* a = ctx->rewriter.dst_arena; - - const Node* old_instruction = node->payload.let.instruction; - const Node* tail = rewrite_node(&ctx->rewriter, node->payload.let.tail); - - if (old_instruction->tag == PrimOp_TAG) { - const PrimOp* oprim_op = &old_instruction->payload.prim_op; - switch (oprim_op->op) { - case get_stack_pointer_op: { - BodyBuilder* bb = begin_body(a); - const Node* sp = gen_load(bb, ctx->stack_pointer); - return finish_body(bb, let(a, quote_helper(a, singleton(sp)), tail)); - } - case set_stack_pointer_op: { - BodyBuilder* bb = begin_body(a); - const Node* val = rewrite_node(&ctx->rewriter, oprim_op->operands.nodes[0]); - gen_store(bb, ctx->stack_pointer, val); - return finish_body(bb, let(a, quote_helper(a, empty(a)), tail)); - } - case get_stack_base_op: { - BodyBuilder* bb = begin_body(a); - const Node* stack_pointer = ctx->stack_pointer; - const Node* stack_size = gen_load(bb, stack_pointer); - const Node* stack_base_ptr = gen_lea(bb, ctx->stack, stack_size, empty(a)); - if (ctx->config->printf_trace.stack_size) { - if (oprim_op->op == get_stack_base_op) - bind_instruction(bb, prim_op(a, (PrimOp) {.op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) {.string = "trace: stack_size=%d\n"}), stack_size)})); - else - bind_instruction(bb, prim_op(a, (PrimOp) {.op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) {.string = "trace: uniform stack_size=%d\n"}), stack_size)})); - } - return finish_body(bb, let(a, quote_helper(a, singleton(stack_base_ptr)), tail)); - } - case push_stack_op: - case pop_stack_op: { - BodyBuilder* bb = begin_body(a); - const Type* element_type = rewrite_node(&ctx->rewriter, first(oprim_op->type_arguments)); - - bool push = oprim_op->op == 
push_stack_op; +static const Node* process_node(Context* ctx, const Node* old) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; - const Node* fn = gen_fn(ctx, element_type, push); - Nodes args = push ? singleton(rewrite_node(&ctx->rewriter, first(oprim_op->operands))) : empty(a); - Nodes results = bind_instruction(bb, call(a, (Call) { .callee = fn_addr_helper(a, fn), .args = args})); + if (old->tag == Function_TAG && strcmp(shd_get_abstraction_name(old), "generated_init") == 0) { + Node* new = shd_recreate_node_head(&ctx->rewriter, old); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new)); - if (push) - return finish_body(bb, let(a, quote_helper(a, empty(a)), tail)); + // Make sure to zero-init the stack pointers + // TODO isn't this redundant with thoose things having an initial value already ? + // is this an old forgotten workaround ? + if (ctx->stack) { + const Node* stack_pointer = ctx->stack_pointer; + shd_bld_store(bb, stack_pointer, shd_uint32_literal(a, 0)); + } + shd_register_processed(r, shd_get_abstraction_mem(old), shd_bb_mem(bb)); + shd_set_abstraction_body(new, shd_bld_finish(bb, shd_rewrite_node(&ctx->rewriter, old->payload.fun.body))); + return new; + } - assert(results.count == 1); - return finish_body(bb, let(a, quote_helper(a, results), tail)); + switch (old->tag) { + case GetStackSize_TAG: { + assert(ctx->stack); + GetStackSize payload = old->payload.get_stack_size; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Node* sp = shd_bld_load(bb, ctx->stack_pointer); + return shd_bld_to_instr_yield_values(bb, shd_singleton(sp)); + } + case SetStackSize_TAG: { + assert(ctx->stack); + SetStackSize payload = old->payload.set_stack_size; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Node* val = shd_rewrite_node(r, old->payload.set_stack_size.value); + shd_bld_store(bb, ctx->stack_pointer, val); + return shd_bld_to_instr_yield_values(bb, shd_empty(a)); + } + 
case GetStackBaseAddr_TAG: { + assert(ctx->stack); + GetStackBaseAddr payload = old->payload.get_stack_base_addr; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Node* stack_pointer = ctx->stack_pointer; + const Node* stack_size = shd_bld_load(bb, stack_pointer); + const Node* stack_base_ptr = lea_helper(a, ctx->stack, shd_int32_literal(a, 0), shd_singleton(stack_size)); + if (ctx->config->printf_trace.stack_size) { + shd_bld_debug_printf(bb, "trace: stack_size=%d\n", shd_singleton(stack_size)); } - default: break; + return shd_bld_to_instr_yield_values(bb, shd_singleton(stack_base_ptr)); } - } + case PushStack_TAG:{ + assert(ctx->stack); + PushStack payload = old->payload.push_stack; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Type* element_type = shd_rewrite_node(&ctx->rewriter, shd_get_unqualified_type(old->payload.push_stack.value->type)); - return let(a, rewrite_node(&ctx->rewriter, old_instruction), tail); -} + bool push = true; -static const Node* process_node(Context* ctx, const Node* old) { - const Node* found = search_processed(&ctx->rewriter, old); - if (found) return found; + const Node* fn = gen_fn(ctx, element_type, push); + Nodes args = shd_singleton(shd_rewrite_node(&ctx->rewriter, old->payload.push_stack.value)); + shd_bld_call(bb, fn_addr_helper(a, fn), args); - IrArena* a = ctx->rewriter.dst_arena; + return shd_bld_to_instr_yield_values(bb, shd_empty(a)); + } + case PopStack_TAG: { + assert(ctx->stack); + PopStack payload = old->payload.pop_stack; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Type* element_type = shd_rewrite_node(&ctx->rewriter, old->payload.pop_stack.type); - if (old->tag == Function_TAG && strcmp(get_abstraction_name(old), "generated_init") == 0) { - Node* new = recreate_decl_header_identity(&ctx->rewriter, old); - BodyBuilder* bb = begin_body(a); + bool push = false; - // Make sure to zero-init the stack pointers - // TODO 
isn't this redundant with thoose things having an initial value already ? - // is this an old forgotten workaround ? - const Node* stack_pointer = ctx->stack_pointer; - gen_store(bb, stack_pointer, uint32_literal(a, 0)); - new->payload.fun.body = finish_body(bb, rewrite_node(&ctx->rewriter, old->payload.fun.body)); - return new; - } + const Node* fn = gen_fn(ctx, element_type, push); + Nodes results = shd_bld_call(bb, fn_addr_helper(a, fn), shd_empty(a)); - switch (old->tag) { - case Let_TAG: return process_let(ctx, old); - default: return recreate_node_identity(&ctx->rewriter, old); + assert(results.count == 1); + return shd_bld_to_instr_yield_values(bb, results); + } + default: break; } + + return shd_recreate_node(&ctx->rewriter, old); } -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); -Module* lower_stack(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_stack(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); - const Type* stack_base_element = uint8_type(a); - const Type* stack_arr_type = arr_type(a, (ArrType) { - .element_type = stack_base_element, - .size = uint32_literal(a, config->per_thread_stack_size), - }); - const Type* stack_counter_t = uint32_type(a); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process_node), - Nodes annotations = mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" })); + .config = config, - // Arrays for the stacks - Node* stack_decl = global_var(dst, annotations, stack_arr_type, "stack", AsPrivatePhysical); + .push = 
shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .pop = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; - // Pointers into those arrays - Node* stack_ptr_decl = global_var(dst, annotations, stack_counter_t, "stack_ptr", AsPrivateLogical); - stack_ptr_decl->payload.global_variable.init = uint32_literal(a, 0); + if (config->per_thread_stack_size > 0) { + const Type* stack_base_element = shd_uint8_type(a); + const Type* stack_arr_type = arr_type(a, (ArrType) { + .element_type = stack_base_element, + .size = shd_uint32_literal(a, config->per_thread_stack_size), + }); + const Type* stack_counter_t = shd_uint32_type(a); - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node), + Nodes annotations = mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" })); - .config = config, + // Arrays for the stacks + Node* stack_decl = global_var(dst, annotations, stack_arr_type, "stack", AsPrivate); - .push = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node), - .pop = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node), + // Pointers into those arrays + Node* stack_ptr_decl = global_var(dst, shd_nodes_append(a, annotations, annotation(a, (Annotation) { .name = "Logical" })), stack_counter_t, "stack_ptr", AsPrivate); + stack_ptr_decl->payload.global_variable.init = shd_uint32_literal(a, 0); - .stack = ref_decl_helper(a, stack_decl), - .stack_pointer = ref_decl_helper(a, stack_ptr_decl), - }; + ctx.stack = ref_decl_helper(a, stack_decl); + ctx.stack_pointer = ref_decl_helper(a, stack_ptr_decl); + } - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); - destroy_dict(ctx.push); - destroy_dict(ctx.pop); + shd_destroy_dict(ctx.push); + shd_destroy_dict(ctx.pop); return dst; } diff --git a/src/shady/passes/lower_subgroup_ops.c 
b/src/shady/passes/lower_subgroup_ops.c index 61ca943fe..190cc5810 100644 --- a/src/shady/passes/lower_subgroup_ops.c +++ b/src/shady/passes/lower_subgroup_ops.c @@ -1,16 +1,23 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/cast.h" +#include "shady/ir/memory_layout.h" +#include "shady/ir/ext.h" +#include "shady/ir/type.h" +#include "shady/ir/composite.h" +#include "shady/ir/function.h" #include "portability.h" #include "log.h" +#include "dict.h" -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" -#include "../transform/memory_layout.h" +#include + +#include typedef struct { Rewriter rewriter; const CompilerConfig* config; + struct Dict* fns; } Context; static bool is_extended_type(SHADY_UNUSED IrArena* a, const Type* t, bool allow_vectors) { @@ -26,82 +33,135 @@ static bool is_extended_type(SHADY_UNUSED IrArena* a, const Type* t, bool allow_ } } -static const Node* process_let(Context* ctx, const Node* old) { - assert(old->tag == Let_TAG); +static bool is_supported_natively(Context* ctx, const Type* element_type) { + IrArena* a = ctx->rewriter.dst_arena; + if (element_type->tag == Int_TAG && element_type->payload.int_type.width == IntTy32) { + return true; + } else if (!ctx->config->lower.emulate_subgroup_ops_extended_types && is_extended_type(a, element_type, true)) { + return true; + } + + return false; +} + +static const Node* build_subgroup_first(Context* ctx, BodyBuilder* bb, const Node* scope, const Node* src); + +static const Node* generate(Context* ctx, BodyBuilder* bb, const Node* scope, const Node* t, const Node* param) { IrArena* a = ctx->rewriter.dst_arena; - const Node* tail = rewrite_node(&ctx->rewriter, old->payload.let.tail); - const Node* old_instruction = old->payload.let.instruction; - - if (old_instruction->tag == PrimOp_TAG) { - PrimOp payload = old_instruction->payload.prim_op; - switch (payload.op) { - case subgroup_broadcast_first_op: { - BodyBuilder* builder = begin_body(a); - 
const Node* varying_value = rewrite_node(&ctx->rewriter, payload.operands.nodes[0]); - const Type* element_type = get_unqualified_type(varying_value->type); - - if (element_type->tag == Int_TAG && element_type->payload.int_type.width == IntTy32) { - cancel_body(builder); - break; - } else if (is_extended_type(a, element_type, true) && !ctx->config->lower.emulate_subgroup_ops_extended_types) { - cancel_body(builder); - break; - } - - TypeMemLayout layout = get_mem_layout(a, element_type); - - const Type* local_arr_ty = arr_type(a, (ArrType) { .element_type = int32_type(a), .size = NULL }); - - const Node* varying_top_of_stack = gen_primop_e(builder, get_stack_base_op, empty(a), empty(a)); - const Type* varying_raw_ptr_t = ptr_type(a, (PtrType) { .address_space = AsPrivatePhysical, .pointed_type = local_arr_ty }); - const Node* varying_raw_ptr = gen_reinterpret_cast(builder, varying_raw_ptr_t, varying_top_of_stack); - const Type* varying_typed_ptr_t = ptr_type(a, (PtrType) { .address_space = AsPrivatePhysical, .pointed_type = element_type }); - const Node* varying_typed_ptr = gen_reinterpret_cast(builder, varying_typed_ptr_t, varying_top_of_stack); - - gen_store(builder, varying_typed_ptr, varying_value); - for (int32_t j = 0; j < bytes_to_words_static(a, layout.size_in_bytes); j++) { - const Node* varying_logical_addr = gen_lea(builder, varying_raw_ptr, int32_literal(a, 0), nodes(a, 1, (const Node* []) {int32_literal(a, j) })); - const Node* input = gen_load(builder, varying_logical_addr); - - const Node* partial_result = gen_primop_ce(builder, subgroup_broadcast_first_op, 1, (const Node* []) { input }); - - if (ctx->config->printf_trace.subgroup_ops) - gen_primop(builder, debug_printf_op, empty(a), mk_nodes(a, string_lit(a, (StringLiteral) { .string = "partial_result %d"}), partial_result)); - - gen_store(builder, varying_logical_addr, partial_result); - } - const Node* result = gen_load(builder, varying_typed_ptr); - result = first(gen_primop(builder, 
subgroup_assume_uniform_op, empty(a), singleton(result))); - return finish_body(builder, let(a, quote_helper(a, singleton(result)), tail)); + const Type* original_t = t; + t = shd_get_maybe_nominal_type_body(t); + switch (is_type(t)) { + case Type_ArrType_TAG: + case Type_RecordType_TAG: { + assert(t->payload.record_type.special == 0); + Nodes element_types = shd_get_composite_type_element_types(t); + LARRAY(const Node*, elements, element_types.count); + for (size_t i = 0; i < element_types.count; i++) { + const Node* e = shd_extract_helper(a, param, shd_singleton(shd_uint32_literal(a, i))); + elements[i] = build_subgroup_first(ctx, bb, scope, e); + } + return composite_helper(a, original_t, shd_nodes(a, element_types.count, elements)); + } + case Type_Int_TAG: { + if (t->payload.int_type.width == IntTy64) { + const Node* hi = prim_op_helper(a, rshift_logical_op, shd_empty(a), mk_nodes(a, param, shd_int32_literal(a, 32))); + hi = shd_bld_convert_int_zero_extend(bb, shd_int32_type(a), hi); + const Node* lo = shd_bld_convert_int_zero_extend(bb, shd_int32_type(a), param); + hi = build_subgroup_first(ctx, bb, scope, hi); + lo = build_subgroup_first(ctx, bb, scope, lo); + const Node* it = int_type(a, (Int) { .width = IntTy64, .is_signed = t->payload.int_type.is_signed }); + hi = shd_bld_convert_int_zero_extend(bb, it, hi); + lo = shd_bld_convert_int_zero_extend(bb, it, lo); + hi = prim_op_helper(a, lshift_op, shd_empty(a), mk_nodes(a, hi, shd_int32_literal(a, 32))); + return prim_op_helper(a, or_op, shd_empty(a), mk_nodes(a, lo, hi)); } - default: break; + break; } + case Type_PtrType_TAG: { + param = shd_bld_reinterpret_cast(bb, shd_uint64_type(a), param); + return shd_bld_reinterpret_cast(bb, t, generate(ctx, bb, scope, shd_uint64_type(a), param)); + } + default: break; + } + return NULL; +} + +static void build_fn_body(Context* ctx, Node* fn, const Node* scope, const Node* param, const Type* t) { + IrArena* a = ctx->rewriter.dst_arena; + BodyBuilder* bb = 
shd_bld_begin(a, shd_get_abstraction_mem(fn)); + const Node* result = generate(ctx, bb, scope, t, param); + if (result) { + shd_set_abstraction_body(fn, shd_bld_finish(bb, fn_ret(a, (Return) { + .args = shd_singleton(result), + .mem = shd_bb_mem(bb), + }))); + return; } - return let(a, rewrite_node(&ctx->rewriter, old_instruction), tail); + shd_log_fmt(ERROR, "subgroup_first emulation is not supported for "); + shd_log_node(ERROR, t); + shd_log_fmt(ERROR, ".\n"); + shd_error_die(); } -static const Node* process(Context* ctx, const Node* node) { - if (!node) return NULL; - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; +static const Node* build_subgroup_first(Context* ctx, BodyBuilder* bb, const Node* scope, const Node* src) { + IrArena* a = ctx->rewriter.dst_arena; + Module* m = ctx->rewriter.dst_module; + const Node* t = shd_get_unqualified_type(src->type); + if (is_supported_natively(ctx, t)) + return shd_bld_ext_instruction(bb, "spirv.core", SpvOpGroupNonUniformBroadcastFirst, shd_as_qualified_type(t, true), mk_nodes(a, scope, src)); + + if (shd_resolve_to_int_literal(scope)->value != SpvScopeSubgroup) + shd_error("TODO") + + Node* fn = NULL; + Node** found = shd_dict_find_value(const Node*, Node*, ctx->fns, t); + if (found) + fn = *found; + else { + const Node* src_param = param(a, shd_as_qualified_type(t, false), "src"); + fn = function(m, shd_singleton(src_param), shd_fmt_string_irarena(a, "subgroup_first_%s", shd_get_type_name(a, t)), + mk_nodes(a, annotation(a, (Annotation) { .name = "Generated"}), annotation(a, (Annotation) { .name = "Leaf" })), shd_singleton( + shd_as_qualified_type(t, true))); + shd_dict_insert(const Node*, Node*, ctx->fns, t, fn); + build_fn_body(ctx, fn, scope, src_param, t); + } + + return shd_first(shd_bld_call(bb, fn_addr_helper(a, fn), shd_singleton(src))); +} +static const Node* process(Context* ctx, const Node* node) { + IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; 
switch (node->tag) { - case Let_TAG: return process_let(ctx, node); - default: return recreate_node_identity(&ctx->rewriter, node); + case ExtInstr_TAG: { + ExtInstr payload = node->payload.ext_instr; + if (strcmp(payload.set, "spirv.core") == 0 && payload.opcode == SpvOpGroupNonUniformBroadcastFirst) { + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + return shd_bld_to_instr_yield_values(bb, shd_singleton( + build_subgroup_first(ctx, bb, shd_rewrite_node(r, payload.operands.nodes[0]), shd_rewrite_node(r, payload.operands.nodes[1])))); + } + } + default: break; } + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_subgroup_ops(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +Module* shd_pass_lower_subgroup_ops(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); assert(!config->lower.emulate_subgroup_ops && "TODO"); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config, + .fns = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node) }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_dict(ctx.fns); return dst; } diff --git a/src/shady/passes/lower_subgroup_vars.c b/src/shady/passes/lower_subgroup_vars.c index 4c357aa3e..38452a8d4 100644 --- a/src/shady/passes/lower_subgroup_vars.c +++ b/src/shady/passes/lower_subgroup_vars.c @@ -1,59 +1,76 @@ -#include "passes.h" +#include 
"shady/pass.h" +#include "shady/ir/memory_layout.h" +#include "shady/ir/function.h" +#include "shady/ir/builtin.h" +#include "shady/ir/annotation.h" +#include "shady/ir/decl.h" +#include "dict.h" #include "portability.h" #include "log.h" -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" - typedef struct { Rewriter rewriter; const CompilerConfig* config; - BodyBuilder* b; + BodyBuilder* bb; } Context; -static const Node* process(Context* ctx, NodeClass class, String op_name, const Node* node) { - if (!node) return NULL; - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; switch (node->tag) { + case Function_TAG: { + Node* newfun = shd_recreate_node_head(r, node); + if (get_abstraction_body(node)) { + Context functx = *ctx; + functx.rewriter.map = shd_clone_dict(functx.rewriter.map); + shd_dict_clear(functx.rewriter.map); + shd_register_processed_list(&functx.rewriter, get_abstraction_params(node), get_abstraction_params(newfun)); + functx.bb = shd_bld_begin(a, shd_get_abstraction_mem(newfun)); + Node* post_prelude = basic_block(a, shd_empty(a), "post-prelude"); + shd_register_processed(&functx.rewriter, shd_get_abstraction_mem(node), shd_get_abstraction_mem(post_prelude)); + shd_set_abstraction_body(post_prelude, shd_rewrite_node(&functx.rewriter, get_abstraction_body(node))); + shd_set_abstraction_body(newfun, shd_bld_finish(functx.bb, jump_helper(a, shd_bb_mem(functx.bb), post_prelude, + shd_empty(a)))); + shd_destroy_dict(functx.rewriter.map); + } + return newfun; + } case PtrType_TAG: { AddressSpace as = node->payload.ptr_type.address_space; - if (as == AsSubgroupLogical) { - return ptr_type(a, (PtrType) { .pointed_type = rewrite_op(&ctx->rewriter, NcType, "pointed_type", node->payload.ptr_type.pointed_type), .address_space = 
AsSharedLogical }); + if (as == AsSubgroup) { + return ptr_type(a, (PtrType) { .pointed_type = shd_rewrite_op(&ctx->rewriter, NcType, "pointed_type", node->payload.ptr_type.pointed_type), .address_space = AsShared, .is_reference = node->payload.ptr_type.is_reference }); } break; } case RefDecl_TAG: { const Node* odecl = node->payload.ref_decl.decl; - if (odecl->tag != GlobalVariable_TAG || odecl->payload.global_variable.address_space != AsSubgroupLogical) + if (odecl->tag != GlobalVariable_TAG || odecl->payload.global_variable.address_space != AsSubgroup) break; - assert(ctx->b); - const Node* ndecl = rewrite_op(&ctx->rewriter, NcDeclaration, "decl", odecl); - const Node* index = gen_builtin_load(ctx->rewriter.dst_module, ctx->b, BuiltinSubgroupId); - const Node* slice = gen_lea(ctx->b, ref_decl_helper(a, ndecl), int32_literal(a, 0), mk_nodes(a, index)); + const Node* ndecl = shd_rewrite_node(&ctx->rewriter, odecl); + assert(ctx->bb); + const Node* index = shd_bld_builtin_load(ctx->rewriter.dst_module, ctx->bb, BuiltinSubgroupId); + const Node* slice = lea_helper(a, ref_decl_helper(a, ndecl), shd_int32_literal(a, 0), mk_nodes(a, index)); return slice; } case GlobalVariable_TAG: { AddressSpace as = node->payload.global_variable.address_space; - if (as == AsSubgroupLogical) { - const Type* ntype = rewrite_op(&ctx->rewriter, NcType, "type", node->payload.global_variable.type); + if (as == AsSubgroup) { + const Type* ntype = shd_rewrite_node(&ctx->rewriter, node->payload.global_variable.type); const Type* atype = arr_type(a, (ArrType) { .element_type = ntype, - .size = ref_decl_helper(a, rewrite_op(&ctx->rewriter, NcDeclaration, "decl", get_declaration(ctx->rewriter.src_module, "SUBGROUPS_PER_WG"))) + .size = ref_decl_helper(a, shd_rewrite_node(&ctx->rewriter, shd_module_get_declaration(ctx->rewriter.src_module, "SUBGROUPS_PER_WG"))) }); - Node* new = global_var(ctx->rewriter.dst_module, rewrite_ops(&ctx->rewriter, NcAnnotation, "annotations", 
node->payload.global_variable.annotations), atype, node->payload.global_variable.name, AsSharedLogical); - register_processed(&ctx->rewriter, node, new); + assert(shd_lookup_annotation(node, "Logical") && "All subgroup variables should be logical by now!"); + Node* new = global_var(ctx->rewriter.dst_module, shd_rewrite_nodes(&ctx->rewriter, node->payload.global_variable.annotations), atype, node->payload.global_variable.name, AsShared); + shd_register_processed(&ctx->rewriter, node, new); if (node->payload.global_variable.init) { new->payload.global_variable.init = fill(a, (Fill) { .type = atype, - .value = rewrite_op(&ctx->rewriter, NcValue, "init", node->payload.global_variable.init) + .value = shd_rewrite_node(&ctx->rewriter, node->payload.global_variable.init) }); } return new; @@ -63,25 +80,24 @@ static const Node* process(Context* ctx, NodeClass class, String op_name, const default: break; } - if (class == NcTerminator) { - BodyBuilder* b = begin_body(a); - Context c = *ctx; - c.b = b; - return finish_body(b, recreate_node_identity(&c.rewriter, node)); + if (is_declaration(node)) { + Context declctx = *ctx; + declctx.bb = NULL; + return shd_recreate_node(&declctx.rewriter, node); } - return recreate_node_identity(&ctx->rewriter, node); + + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_subgroup_vars(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_subgroup_vars(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, NULL), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config }; - ctx.rewriter.rewrite_op_fn = 
(RewriteOpFn) process; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_switch_btree.c b/src/shady/passes/lower_switch_btree.c index adc96115a..a3d7d12f0 100644 --- a/src/shady/passes/lower_switch_btree.c +++ b/src/shady/passes/lower_switch_btree.c @@ -1,13 +1,10 @@ -#include "passes.h" +#include "shady/pass.h" + +#include "../ir_private.h" #include "log.h" #include "portability.h" -#include "../ir_private.h" -#include "../type.h" -#include "../rewrite.h" -#include "../transform/ir_gen_helpers.h" - typedef struct { Rewriter rewriter; @@ -77,24 +74,25 @@ TreeNode* insert(TreeNode* t, TreeNode* x) { return t; } -static const Node* generate_default_fallback_case(Context* ctx) { +/*static const Node* gen_yield(Context* ctx, bool in_if, const Node* mem, Nodes args) { + if (in_if) + return merge_selection(ctx->rewriter.dst_arena, (MergeSelection) { .args = args, .mem = mem }); + return block_yield(ctx->rewriter.dst_arena, (BlockYield) { args }); +} + +static const Node* generate_default_fallback_case(Context* ctx, bool in_if, const Node* mem) { IrArena* a = ctx->rewriter.dst_arena; - BodyBuilder* bb = begin_body(a); + BodyBuilder* bb = begin_body_with_mem(a, mem); gen_store(bb, ctx->run_default_case, true_lit(a)); LARRAY(const Node*, undefs, ctx->yield_types.count); for (size_t i = 0; i < ctx->yield_types.count; i++) undefs[i] = undef(a, (Undef) { .type = ctx->yield_types.nodes[i] }); - return case_(a, empty(a), finish_body(bb, yield(a, (Yield) {.args = nodes(a, ctx->yield_types.count, undefs)}))); -} - -static const Node* wrap_instr_in_lambda(const Node* instr) { - IrArena* a = instr->arena; - BodyBuilder* bb = begin_body(a); - Nodes values = bind_instruction(bb, instr); - return case_(a, empty(a), finish_body(bb, yield(a, (Yield) {.args = values}))); + Node* c = case_(a, empty(a)); + set_abstraction_body(c, finish_body(bb, 
gen_yield(ctx, in_if, shd_nodes(a, ctx->yield_types.count, undefs)))); + return c; } -static const Node* generate_decision_tree(Context* ctx, TreeNode* n, uint64_t min, uint64_t max) { +static const Node* generate_decision_tree(Context* ctx, TreeNode* n, bool in_if, uint64_t min, uint64_t max) { IrArena* a = ctx->rewriter.dst_arena; assert(n->key >= min && n->key <= max); assert(n->lam); @@ -110,91 +108,85 @@ static const Node* generate_decision_tree(Context* ctx, TreeNode* n, uint64_t mi if (min < n->key) { BodyBuilder* bb = begin_body(a); - const Node* instr = if_instr(a, (If) { - .yield_types = ctx->yield_types, - .condition = gen_primop_e(bb, lt_op, empty(a), mk_nodes(a, ctx->inspectee, pivot)), - .if_true = n->children[0] ? generate_decision_tree(ctx, n->children[0], min, n->key - 1) : generate_default_fallback_case(ctx), - .if_false = body, - }); - Nodes values = bind_instruction(bb, instr); - body = case_(a, empty(a), finish_body(bb, yield(a, (Yield) {.args = values}))); + const Node* true_branch = n->children[0] ? generate_decision_tree(ctx, n->children[0], true, min, n->key - 1) : generate_default_fallback_case(ctx, true); + Nodes values = gen_if(bb, ctx->yield_types, gen_primop_e(bb, lt_op, empty(a), mk_nodes(a, ctx->inspectee, pivot)), true_branch, body); + Node* c = case_(a, empty(a)); + set_abstraction_body(c, finish_body(bb, gen_yield(ctx, in_if || max > n->key, values))); + body = c; } if (max > n->key) { BodyBuilder* bb = begin_body(a); - const Node* instr = if_instr(a, (If) { - .yield_types = ctx->yield_types, - .condition = gen_primop_e(bb, gt_op, empty(a), mk_nodes(a, ctx->inspectee, pivot)), - .if_true = n->children[1] ? generate_decision_tree(ctx, n->children[1], n->key + 1, max) : generate_default_fallback_case(ctx), - .if_false = body, - }); - Nodes values = bind_instruction(bb, instr); - body = case_(a, empty(a), finish_body(bb, yield(a, (Yield) {.args = values}))); + const Node* true_branch = n->children[1] ? 
generate_decision_tree(ctx, n->children[1], true, n->key + 1, max) : generate_default_fallback_case(ctx, true); + Nodes values = gen_if(bb, ctx->yield_types, gen_primop_e(bb, gt_op, empty(a), mk_nodes(a, ctx->inspectee, pivot)), true_branch, body); + Node* c = case_(a, empty(a)); + set_abstraction_body(c, finish_body(bb, gen_yield(ctx, in_if, values))); + body = c; } return body; -} +}*/ static const Node* process(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; switch (node->tag) { case Match_TAG: { - Nodes yield_types = rewrite_nodes(&ctx->rewriter, node->payload.match_instr.yield_types); - Nodes literals = rewrite_nodes(&ctx->rewriter, node->payload.match_instr.literals); - Nodes cases = rewrite_nodes(&ctx->rewriter, node->payload.match_instr.cases); + Match payload = node->payload.match_instr; + Nodes yield_types = shd_rewrite_nodes(&ctx->rewriter, node->payload.match_instr.yield_types); + Nodes literals = shd_rewrite_nodes(&ctx->rewriter, node->payload.match_instr.literals); + Nodes cases = shd_rewrite_nodes(&ctx->rewriter, node->payload.match_instr.cases); // TODO handle degenerate no-cases case ? 
// TODO or maybe do that in fold() assert(cases.count > 0); - Arena* arena = new_arena(); + Arena* arena = shd_new_arena(); TreeNode* root = NULL; for (size_t i = 0; i < literals.count; i++) { - TreeNode* t = arena_alloc(arena, sizeof(TreeNode)); - t->key = get_int_literal_value(*resolve_to_int_literal(literals.nodes[i]), false); + TreeNode* t = shd_arena_alloc(arena, sizeof(TreeNode)); + t->key = shd_get_int_literal_value(*shd_resolve_to_int_literal(literals.nodes[i]), false); t->lam = cases.nodes[i]; root = insert(root, t); } - BodyBuilder* bb = begin_body(a); - const Node* run_default_case = gen_primop_e(bb, alloca_logical_op, singleton(bool_type(a)), empty(a)); - gen_store(bb, run_default_case, false_lit(a)); + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Node* run_default_case = shd_bld_stack_alloc(bb, bool_type(a)); + shd_bld_store(bb, run_default_case, false_lit(a)); - Context ctx2 = *ctx; + /*Context ctx2 = *ctx; ctx2.run_default_case = run_default_case; ctx2.yield_types = yield_types; ctx2.inspectee = rewrite_node(&ctx->rewriter, node->payload.match_instr.inspect); - Nodes matched_results = bind_instruction(bb, block(a, (Block) { .yield_types = add_qualifiers(a, ctx2.yield_types, false), .inside = generate_decision_tree(&ctx2, root, 0, UINT64_MAX) })); + Nodes matched_results = bind_instruction(bb, block(a, (Block) { .yield_types = add_qualifiers(a, ctx2.yield_types, false), .inside = generate_decision_tree(&ctx2, root, false, 0, UINT64_MAX) })); // Check if we need to run the default case - Nodes final_results = bind_instruction(bb, if_instr(a, (If) { - .yield_types = ctx2.yield_types, - .condition = gen_load(bb, run_default_case), - .if_true = rewrite_node(&ctx->rewriter, node->payload.match_instr.default_case), - .if_false = case_(a, empty(a), yield(a, (Yield) { - .args = matched_results, - })) - })); + Node* yield_case = case_(a, empty(a)); + set_abstraction_body(yield_case, gen_yield(ctx, true, matched_results)); + 
Nodes final_results = gen_if(bb, ctx2.yield_types, gen_load(bb, run_default_case), rewrite_node(&ctx->rewriter, node->payload.match_instr.default_case), yield_case); + register_processed_list(r, get_abstraction_params(get_structured_construct_tail(node)), final_results); destroy_arena(arena); - return yield_values_and_wrap_in_block(bb, final_results); + return finish_body(bb, rewrite_node(r, get_abstraction_body(get_structured_construct_tail(node))));*/ + shd_error("TODO") + // return yield_values_and_wrap_in_block(bb, final_results); } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_switch_btree(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_switch_btree(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_tailcalls.c b/src/shady/passes/lower_tailcalls.c index 32b9719a6..c5e8c03ca 100644 --- a/src/shady/passes/lower_tailcalls.c +++ b/src/shady/passes/lower_tailcalls.c @@ -1,18 +1,18 @@ -#include "passes.h" +#include "join_point_ops.h" -#include "log.h" -#include "portability.h" -#include "util.h" +#include "shady/pass.h" +#include "shady/ir/stack.h" +#include "shady/ir/cast.h" +#include "shady/ir/builtin.h" -#include "../rewrite.h" -#include "../type.h" 
#include "../ir_private.h" - -#include "../analysis/scope.h" +#include "../analysis/cfg.h" #include "../analysis/uses.h" #include "../analysis/leak.h" -#include "../transform/ir_gen_helpers.h" +#include "log.h" +#include "portability.h" +#include "util.h" #include "list.h" #include "dict.h" @@ -28,8 +28,8 @@ typedef struct Context_ { struct Dict* assigned_fn_ptrs; FnPtr* next_fn_ptr; - Scope* scope; - const UsesMap* scope_uses; + CFG* cfg; + const UsesMap* uses; Node** top_dispatcher_fn; Node* init_fn; @@ -37,22 +37,40 @@ typedef struct Context_ { static const Node* process(Context* ctx, const Node* old); -static const Node* fn_ptr_as_value(IrArena* a, FnPtr ptr) { - return uint64_literal(a, ptr); +static const Node* get_fn(Rewriter* rewriter, const char* name) { + const Node* decl = shd_find_or_process_decl(rewriter, name); + return fn_addr_helper(rewriter->dst_arena, decl); } -static const Node* lower_fn_addr(Context* ctx, const Node* the_function) { +static const Type* lowered_fn_type(Context* ctx) { + IrArena* a = ctx->rewriter.dst_arena; + return shd_int_type_helper(a, false, ctx->config->target.memory.ptr_size); +} + +static const Node* fn_ptr_as_value(Context* ctx, FnPtr ptr) { IrArena* a = ctx->rewriter.dst_arena; + return int_literal(a, (IntLiteral) { + .is_signed = false, + .width = ctx->config->target.memory.ptr_size, + .value = ptr + }); +} + +static FnPtr get_fn_ptr(Context* ctx, const Node* the_function) { assert(the_function->arena == ctx->rewriter.src_arena); assert(the_function->tag == Function_TAG); - FnPtr* found = find_value_dict(const Node*, FnPtr, ctx->assigned_fn_ptrs, the_function); - if (found) return fn_ptr_as_value(a, *found); + FnPtr* found = shd_dict_find_value(const Node*, FnPtr, ctx->assigned_fn_ptrs, the_function); + if (found) return *found; FnPtr ptr = (*ctx->next_fn_ptr)++; - bool r = insert_dict_and_get_result(const Node*, FnPtr, ctx->assigned_fn_ptrs, the_function, ptr); + bool r = shd_dict_insert_get_result(const Node*, FnPtr, 
ctx->assigned_fn_ptrs, the_function, ptr); assert(r); - return fn_ptr_as_value(a, ptr); + return ptr; +} + +static const Node* lower_fn_addr(Context* ctx, const Node* the_function) { + return fn_ptr_as_value(ctx, get_fn_ptr(ctx, the_function)); } /// Turn a function into a top-level entry point, calling into the top dispatch function. @@ -61,376 +79,386 @@ static void lift_entry_point(Context* ctx, const Node* old, const Node* fun) { Context ctx2 = *ctx; IrArena* a = ctx->rewriter.dst_arena; // For the lifted entry point, we keep _all_ annotations - Nodes rewritten_params = recreate_variables(&ctx2.rewriter, old->payload.fun.params); - Node* new_entry_pt = function(ctx2.rewriter.dst_module, rewritten_params, old->payload.fun.name, rewrite_nodes(&ctx2.rewriter, old->payload.fun.annotations), nodes(a, 0, NULL)); + Nodes rewritten_params = shd_recreate_params(&ctx2.rewriter, old->payload.fun.params); + Node* new_entry_pt = function(ctx2.rewriter.dst_module, rewritten_params, old->payload.fun.name, shd_rewrite_nodes(&ctx2.rewriter, old->payload.fun.annotations), shd_nodes(a, 0, NULL)); - BodyBuilder* bb = begin_body(a); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new_entry_pt)); - bind_instruction(bb, call(a, (Call) { .callee = fn_addr_helper(a, ctx->init_fn), .args = empty(a) })); - bind_instruction(bb, call(a, (Call) { .callee = access_decl(&ctx->rewriter, "builtin_init_scheduler"), .args = empty(a) })); + shd_bld_call(bb, fn_addr_helper(a, ctx->init_fn), shd_empty(a)); + shd_bld_call(bb, get_fn(&ctx->rewriter, "builtin_init_scheduler"), shd_empty(a)); // shove the arguments on the stack for (size_t i = rewritten_params.count - 1; i < rewritten_params.count; i--) { - gen_push_value_stack(bb, rewritten_params.nodes[i]); + shd_bld_stack_push_value(bb, rewritten_params.nodes[i]); } // Initialise next_fn/next_mask to the entry function - const Node* jump_fn = access_decl(&ctx->rewriter, "builtin_fork"); - const Node* fn_addr = lower_fn_addr(ctx, old); - 
fn_addr = gen_conversion(bb, uint32_type(a), fn_addr); - bind_instruction(bb, call(a, (Call) { .callee = jump_fn, .args = singleton(fn_addr) })); + const Node* jump_fn = get_fn(&ctx->rewriter, "builtin_fork"); + const Node* fn_addr = shd_uint32_literal(a, get_fn_ptr(ctx, old)); + // fn_addr = gen_conversion(bb, lowered_fn_type(ctx), fn_addr); + shd_bld_call(bb, jump_fn, shd_singleton(fn_addr)); if (!*ctx->top_dispatcher_fn) { - *ctx->top_dispatcher_fn = function(ctx->rewriter.dst_module, nodes(a, 0, NULL), "top_dispatcher", mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }), annotation(a, (Annotation) { .name = "Leaf" }), annotation(a, (Annotation) { .name = "Structured" })), nodes(a, 0, NULL)); + *ctx->top_dispatcher_fn = function(ctx->rewriter.dst_module, shd_nodes(a, 0, NULL), "top_dispatcher", mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }), annotation(a, (Annotation) { .name = "Leaf" })), shd_nodes(a, 0, NULL)); } - bind_instruction(bb, call(a, (Call) { - .callee = fn_addr_helper(a, *ctx->top_dispatcher_fn), - .args = nodes(a, 0, NULL) - })); + shd_bld_call(bb, fn_addr_helper(a, *ctx->top_dispatcher_fn), shd_empty(a)); - new_entry_pt->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { - .fn = NULL, - .args = nodes(a, 0, NULL) - })); + shd_set_abstraction_body(new_entry_pt, shd_bld_finish(bb, fn_ret(a, (Return) { + .args = shd_nodes(a, 0, NULL), + .mem = shd_bb_mem(bb), + }))); } static const Node* process(Context* ctx, const Node* old) { - const Node* found = search_processed(&ctx->rewriter, old); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; switch (old->tag) { case Function_TAG: { Context ctx2 = *ctx; - ctx2.scope = new_scope(old); - ctx2.scope_uses = create_uses_map(old, (NcDeclaration | NcType)); + ctx2.cfg = build_fn_cfg(old); + ctx2.uses = shd_new_uses_map_fn(old, (NcDeclaration | NcType)); ctx = &ctx2; - const Node* 
entry_point_annotation = lookup_annotation_list(old->payload.fun.annotations, "EntryPoint"); + const Node* entry_point_annotation = shd_lookup_annotation_list(old->payload.fun.annotations, "EntryPoint"); // Leave leaf-calls alone :) - ctx2.disable_lowering = lookup_annotation(old, "Leaf") || !old->payload.fun.body; + ctx2.disable_lowering = shd_lookup_annotation(old, "Leaf") || !old->payload.fun.body; if (ctx2.disable_lowering) { - Node* fun = recreate_decl_header_identity(&ctx2.rewriter, old); + Node* fun = shd_recreate_node_head(&ctx2.rewriter, old); if (old->payload.fun.body) { - const Node* nbody = rewrite_node(&ctx2.rewriter, old->payload.fun.body); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); if (entry_point_annotation) { - const Node* lam = case_(a, empty(a), nbody); - nbody = let(a, call(a, (Call) { .callee = fn_addr_helper(a, ctx2.init_fn), .args = empty(a)}), lam); + shd_bld_call(bb, fn_addr_helper(a, ctx2.init_fn), shd_empty(a)); } - fun->payload.fun.body = nbody; + shd_register_processed(&ctx2.rewriter, shd_get_abstraction_mem(old), shd_bb_mem(bb)); + shd_set_abstraction_body(fun, shd_bld_finish(bb, shd_rewrite_node(&ctx2.rewriter, get_abstraction_body(old)))); } - destroy_uses_map(ctx2.scope_uses); - destroy_scope(ctx2.scope); + shd_destroy_uses_map(ctx2.uses); + shd_destroy_cfg(ctx2.cfg); return fun; } assert(ctx->config->dynamic_scheduling && "Dynamic scheduling is disabled, but we encountered a non-leaf function"); - Nodes new_annotations = rewrite_nodes(&ctx->rewriter, old->payload.fun.annotations); - new_annotations = append_nodes(a, new_annotations, annotation_value(a, (AnnotationValue) { .name = "FnId", .value = lower_fn_addr(ctx, old) })); - new_annotations = append_nodes(a, new_annotations, annotation(a, (Annotation) { .name = "Leaf" })); + Nodes new_annotations = shd_rewrite_nodes(&ctx->rewriter, old->payload.fun.annotations); + new_annotations = shd_nodes_append(a, new_annotations, annotation_value(a, (AnnotationValue) 
{ .name = "FnId", .value = lower_fn_addr(ctx, old) })); + new_annotations = shd_nodes_append(a, new_annotations, annotation(a, (Annotation) { .name = "Leaf" })); - String new_name = format_string_arena(a->arena, "%s_indirect", old->payload.fun.name); + String new_name = shd_format_string_arena(a->arena, "%s_indirect", old->payload.fun.name); - Node* fun = function(ctx->rewriter.dst_module, nodes(a, 0, NULL), new_name, filter_out_annotation(a, new_annotations, "EntryPoint"), nodes(a, 0, NULL)); - register_processed(&ctx->rewriter, old, fun); + Node* fun = function(ctx->rewriter.dst_module, shd_nodes(a, 0, NULL), new_name, shd_filter_out_annotation(a, new_annotations, "EntryPoint"), shd_nodes(a, 0, NULL)); + shd_register_processed(&ctx->rewriter, old, fun); if (entry_point_annotation) lift_entry_point(ctx, old, fun); - BodyBuilder* bb = begin_body(a); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); // Params become stack pops ! for (size_t i = 0; i < old->payload.fun.params.count; i++) { const Node* old_param = old->payload.fun.params.nodes[i]; - const Type* new_param_type = rewrite_node(&ctx->rewriter, get_unqualified_type(old_param->type)); - const Node* popped = first(bind_instruction_named(bb, prim_op(a, (PrimOp) { .op = pop_stack_op, .type_arguments = singleton(new_param_type), .operands = empty(a) }), &old_param->payload.var.name)); + const Type* new_param_type = shd_rewrite_node(&ctx->rewriter, shd_get_unqualified_type(old_param->type)); + const Node* popped = shd_bld_stack_pop_value(bb, new_param_type); // TODO use the uniform stack instead ? or no ? 
- if (is_qualified_type_uniform(old_param->type)) - popped = first(bind_instruction_named(bb, prim_op(a, (PrimOp) { .op = subgroup_broadcast_first_op, .type_arguments = empty(a), .operands =singleton(popped) }), &old_param->payload.var.name)); - register_processed(&ctx->rewriter, old_param, popped); + if (shd_is_qualified_type_uniform(old_param->type)) + popped = prim_op(a, (PrimOp) { .op = subgroup_assume_uniform_op, .type_arguments = shd_empty(a), .operands = shd_singleton(popped) }); + if (old_param->payload.param.name) + shd_set_value_name((Node*) popped, old_param->payload.param.name); + shd_register_processed(&ctx->rewriter, old_param, popped); } - fun->payload.fun.body = finish_body(bb, rewrite_node(&ctx2.rewriter, old->payload.fun.body)); - destroy_uses_map(ctx2.scope_uses); - destroy_scope(ctx2.scope); + shd_register_processed(&ctx2.rewriter, shd_get_abstraction_mem(old), shd_bb_mem(bb)); + shd_set_abstraction_body(fun, shd_bld_finish(bb, shd_rewrite_node(&ctx2.rewriter, get_abstraction_body(old)))); + shd_destroy_uses_map(ctx2.uses); + shd_destroy_cfg(ctx2.cfg); return fun; } case FnAddr_TAG: return lower_fn_addr(ctx, old->payload.fn_addr.fn); case Call_TAG: { - const Node* ocallee = old->payload.call.callee; - assert(ocallee->tag == FnAddr_TAG); + Call payload = old->payload.call; + assert(payload.callee->tag == FnAddr_TAG && "Only direct calls should survive this pass"); return call(a, (Call) { - .callee = fn_addr_helper(a, rewrite_node(&ctx->rewriter, ocallee->payload.fn_addr.fn)), - .args = rewrite_nodes(&ctx->rewriter, old->payload.call.args), + .callee = fn_addr_helper(a, shd_rewrite_node(&ctx->rewriter, payload.callee->payload.fn_addr.fn)), + .args = shd_rewrite_nodes(&ctx->rewriter, payload.args), + .mem = shd_rewrite_node(r, payload.mem) }); } case JoinPointType_TAG: return type_decl_ref(a, (TypeDeclRef) { - .decl = find_or_process_decl(&ctx->rewriter, "JoinPoint"), + .decl = shd_find_or_process_decl(&ctx->rewriter, "JoinPoint"), }); - case 
PrimOp_TAG: { - switch (old->payload.prim_op.op) { - case create_joint_point_op: { - BodyBuilder* bb = begin_body(a); - Nodes args = rewrite_nodes(&ctx->rewriter, old->payload.prim_op.operands); - assert(args.count == 2); - const Node* dst = first(args); - const Node* sp = args.nodes[1]; - dst = gen_conversion(bb, uint32_type(a), dst); - Nodes r = bind_instruction(bb, call(a, (Call) { - .callee = access_decl(&ctx->rewriter, "builtin_create_control_point"), - .args = mk_nodes(a, dst, sp), - })); - return yield_values_and_wrap_in_block(bb, r); + case ExtInstr_TAG: { + ExtInstr payload = old->payload.ext_instr; + if (strcmp(payload.set, "shady.internal") == 0) { + String callee_name = NULL; + Nodes args = shd_rewrite_nodes(r, payload.operands); + switch ((ShadyJoinPointOpcodes ) payload.opcode) { + case ShadyOpDefaultJoinPoint: + callee_name = "builtin_entry_join_point"; + break; + case ShadyOpCreateJoinPoint: + callee_name = "builtin_create_control_point"; + args = shd_change_node_at_index(a, args, 0, prim_op_helper(a, convert_op, shd_singleton(shd_uint32_type(a)), shd_singleton(args.nodes[0]))); + break; } - case default_join_point_op: { - BodyBuilder* bb = begin_body(a); - Nodes r = bind_instruction(bb, call(a, (Call) { - .callee = access_decl(&ctx->rewriter, "builtin_entry_join_point"), - .args = empty(a) - })); - return yield_values_and_wrap_in_block(bb, r); - } - default: return recreate_node_identity(&ctx->rewriter, old); + return call(a, (Call) { + .mem = shd_rewrite_node(r, payload.mem), + .callee = get_fn(r, callee_name), + .args = args, + }); } + break; } case TailCall_TAG: { //if (ctx->disable_lowering) // return recreate_node_identity(&ctx->rewriter, old); - BodyBuilder* bb = begin_body(a); - gen_push_values_stack(bb, rewrite_nodes(&ctx->rewriter, old->payload.tail_call.args)); - const Node* target = rewrite_node(&ctx->rewriter, old->payload.tail_call.target); - target = gen_conversion(bb, uint32_type(a), target); - - const Node* fork_call = call(a, 
(Call) { - .callee = access_decl(&ctx->rewriter, "builtin_fork"), - .args = nodes(a, 1, (const Node*[]) { target }) - }); - bind_instruction(bb, fork_call); - return finish_body(bb, fn_ret(a, (Return) { .fn = NULL, .args = nodes(a, 0, NULL) })); + TailCall payload = old->payload.tail_call; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + shd_bld_stack_push_values(bb, shd_rewrite_nodes(&ctx->rewriter, payload.args)); + const Node* target = shd_rewrite_node(&ctx->rewriter, payload.callee); + target = shd_bld_conversion(bb, shd_uint32_type(a), target); + + shd_bld_call(bb, get_fn(&ctx->rewriter, "builtin_fork"), shd_singleton(target)); + return shd_bld_finish(bb, fn_ret(a, (Return) { .args = shd_empty(a), .mem = shd_bb_mem(bb) })); } case Join_TAG: { + Join payload = old->payload.join; //if (ctx->disable_lowering) // return recreate_node_identity(&ctx->rewriter, old); - const Node* jp = rewrite_node(&ctx->rewriter, old->payload.join.join_point); + const Node* jp = shd_rewrite_node(&ctx->rewriter, old->payload.join.join_point); const Node* jp_type = jp->type; - deconstruct_qualified_type(&jp_type); + shd_deconstruct_qualified_type(&jp_type); if (jp_type->tag == JoinPointType_TAG) break; - BodyBuilder* bb = begin_body(a); - gen_push_values_stack(bb, rewrite_nodes(&ctx->rewriter, old->payload.join.args)); - const Node* payload = gen_primop_e(bb, extract_op, empty(a), mk_nodes(a, jp, int32_literal(a, 2))); - gen_push_value_stack(bb, payload); - const Node* dst = gen_primop_e(bb, extract_op, empty(a), mk_nodes(a, jp, int32_literal(a, 1))); - const Node* tree_node = gen_primop_e(bb, extract_op, empty(a), mk_nodes(a, jp, int32_literal(a, 0))); + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + shd_bld_stack_push_values(bb, shd_rewrite_nodes(&ctx->rewriter, old->payload.join.args)); + const Node* jp_payload = prim_op_helper(a, extract_op, shd_empty(a), mk_nodes(a, jp, shd_int32_literal(a, 2))); + shd_bld_stack_push_value(bb, 
jp_payload); + const Node* dst = prim_op_helper(a, extract_op, shd_empty(a), mk_nodes(a, jp, shd_int32_literal(a, 1))); + const Node* tree_node = prim_op_helper(a, extract_op, shd_empty(a), mk_nodes(a, jp, shd_int32_literal(a, 0))); - const Node* join_call = call(a, (Call) { - .callee = access_decl(&ctx->rewriter, "builtin_join"), - .args = mk_nodes(a, dst, tree_node) - }); - bind_instruction(bb, join_call); - return finish_body(bb, fn_ret(a, (Return) { .fn = NULL, .args = nodes(a, 0, NULL) })); + shd_bld_call(bb, get_fn(&ctx->rewriter, "builtin_join"), mk_nodes(a, dst, tree_node)); + return shd_bld_finish(bb, fn_ret(a, (Return) { .args = shd_empty(a), .mem = shd_bb_mem(bb) })); } case PtrType_TAG: { const Node* pointee = old->payload.ptr_type.pointed_type; if (pointee->tag == FnType_TAG) { - const Type* emulated_fn_ptr_type = uint64_type(a); + const Type* emulated_fn_ptr_type = shd_uint64_type(a); return emulated_fn_ptr_type; } break; } case Control_TAG: { - if (is_control_static(ctx->scope_uses, old)) { - const Node* old_inside = old->payload.control.inside; - const Node* old_jp = first(get_abstraction_params(old_inside)); - assert(old_jp->tag == Variable_TAG); + Control payload = old->payload.control; + if (shd_is_control_static(ctx->uses, old)) { + // const Node* old_inside = old->payload.control.inside; + const Node* old_jp = shd_first(get_abstraction_params(payload.inside)); + assert(old_jp->tag == Param_TAG); const Node* old_jp_type = old_jp->type; - deconstruct_qualified_type(&old_jp_type); + shd_deconstruct_qualified_type(&old_jp_type); assert(old_jp_type->tag == JoinPointType_TAG); const Node* new_jp_type = join_point_type(a, (JoinPointType) { - .yield_types = rewrite_nodes(&ctx->rewriter, old_jp_type->payload.join_point_type.yield_types), + .yield_types = shd_rewrite_nodes(&ctx->rewriter, old_jp_type->payload.join_point_type.yield_types), }); - const Node* new_jp = var(a, qualified_type_helper(new_jp_type, true), old_jp->payload.var.name); - 
register_processed(&ctx->rewriter, old_jp, new_jp); - const Node* new_body = case_(a, singleton(new_jp), rewrite_node(&ctx->rewriter, get_abstraction_body(old_inside))); + const Node* new_jp = param(a, shd_as_qualified_type(new_jp_type, true), old_jp->payload.param.name); + shd_register_processed(&ctx->rewriter, old_jp, new_jp); + Node* new_control_case = case_(a, shd_singleton(new_jp)); + shd_register_processed(r, payload.inside, new_control_case); + shd_set_abstraction_body(new_control_case, shd_rewrite_node(&ctx->rewriter, get_abstraction_body(payload.inside))); + // BodyBuilder* bb = begin_body_with_mem(a, rewrite_node(r, payload.mem)); + Nodes nyield_types = shd_rewrite_nodes(&ctx->rewriter, old->payload.control.yield_types); return control(a, (Control) { - .yield_types = rewrite_nodes(&ctx->rewriter, old->payload.control.yield_types), - .inside = new_body, + .yield_types = nyield_types, + .inside = new_control_case, + .tail = shd_rewrite_node(r, get_structured_construct_tail(old)), + .mem = shd_rewrite_node(r, payload.mem), }); + //return yield_values_and_wrap_in_block(bb, gen_control(bb, nyield_types, new_body)); } break; } default: break; } - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); } -void generate_top_level_dispatch_fn(Context* ctx) { +static void generate_top_level_dispatch_fn(Context* ctx) { assert(ctx->config->dynamic_scheduling); assert(*ctx->top_dispatcher_fn); assert((*ctx->top_dispatcher_fn)->tag == Function_TAG); IrArena* a = ctx->rewriter.dst_arena; - BodyBuilder* loop_body_builder = begin_body(a); - - const Node* next_function = gen_load(loop_body_builder, access_decl(&ctx->rewriter, "next_fn")); - const Node* get_active_branch_fn = access_decl(&ctx->rewriter, "builtin_get_active_branch"); - const Node* next_mask = first(bind_instruction(loop_body_builder, call(a, (Call) { .callee = get_active_branch_fn, .args = empty(a) }))); - const Node* local_id = 
gen_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupLocalInvocationId); - const Node* should_run = gen_primop_e(loop_body_builder, mask_is_thread_active_op, empty(a), mk_nodes(a, next_mask, local_id)); + BodyBuilder* dispatcher_body_builder = shd_bld_begin(a, shd_get_abstraction_mem(*ctx->top_dispatcher_fn)); bool count_iterations = ctx->config->shader_diagnostics.max_top_iterations > 0; + const Node* iterations_count_param = NULL; + // if (count_iterations) + // iterations_count_param = param(a, qualified_type(a, (QualifiedType) { .type = int32_type(a), .is_uniform = true }), "iterations"); + + // Node* loop_inside_case = case_(a, count_iterations ? singleton(iterations_count_param) : shd_nodes(a, 0, NULL)); + // gen_loop(dispatcher_body_builder, empty(a), count_iterations ? singleton(int32_literal(a, 0)) : empty(a), loop_inside_case); + begin_loop_helper_t l = shd_bld_begin_loop_helper(dispatcher_body_builder, shd_empty(a), count_iterations ? shd_singleton(shd_int32_type(a)) : shd_empty(a), count_iterations ? 
shd_singleton(shd_int32_literal(a, 0)) : shd_empty(a)); + Node* loop_inside_case = l.loop_body; if (count_iterations) - iterations_count_param = var(a, qualified_type(a, (QualifiedType) { .type = int32_type(a), .is_uniform = true }), "iterations"); + iterations_count_param = shd_first(l.params); + BodyBuilder* loop_body_builder = shd_bld_begin(a, shd_get_abstraction_mem(loop_inside_case)); + + const Node* next_function = shd_bld_load(loop_body_builder, ref_decl_helper(a, shd_find_or_process_decl(&ctx->rewriter, "next_fn"))); + const Node* get_active_branch_fn = get_fn(&ctx->rewriter, "builtin_get_active_branch"); + const Node* next_mask = shd_first(shd_bld_call(loop_body_builder, get_active_branch_fn, shd_empty(a))); + const Node* local_id = shd_bld_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupLocalInvocationId); + const Node* should_run = prim_op_helper(a, mask_is_thread_active_op, shd_empty(a), mk_nodes(a, next_mask, local_id)); if (ctx->config->printf_trace.god_function) { - const Node* sid = gen_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupId); + const Node* sid = shd_bld_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupId); if (count_iterations) - bind_instruction(loop_body_builder, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = "trace: top loop, thread:%d:%d iteration=%d next_fn=%d next_mask=%lx\n" }), sid, local_id, iterations_count_param, next_function, next_mask) })); + shd_bld_debug_printf(loop_body_builder, "trace: top loop, thread:%d:%d iteration=%d next_fn=%d next_mask=%lx\n", mk_nodes(a, sid, local_id, iterations_count_param, next_function, next_mask)); else - bind_instruction(loop_body_builder, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = "trace: top loop, thread:%d:%d next_fn=%d next_mask=%lx\n" }), sid, local_id, next_function, next_mask) 
})); + shd_bld_debug_printf(loop_body_builder, "trace: top loop, thread:%d:%d next_fn=%d next_mask=%lx\n", mk_nodes(a, sid, local_id, next_function, next_mask)); } const Node* iteration_count_plus_one = NULL; if (count_iterations) - iteration_count_plus_one = gen_primop_e(loop_body_builder, add_op, empty(a), mk_nodes(a, iterations_count_param, int32_literal(a, 1))); - - const Node* break_terminator = merge_break(a, (MergeBreak) { .args = nodes(a, 0, NULL) }); - const Node* continue_terminator = merge_continue(a, (MergeContinue) { - .args = count_iterations ? singleton(iteration_count_plus_one) : nodes(a, 0, NULL), - }); + iteration_count_plus_one = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, iterations_count_param, shd_int32_literal(a, 1))); if (ctx->config->shader_diagnostics.max_top_iterations > 0) { - const Node* bail_condition = gen_primop_e(loop_body_builder, gt_op, empty(a), mk_nodes(a, iterations_count_param, int32_literal(a, ctx->config->shader_diagnostics.max_top_iterations))); - const Node* bail_true_lam = case_(a, empty(a), break_terminator); - const Node* bail_if = if_instr(a, (If) { + begin_control_t c = shd_bld_begin_control(loop_body_builder, shd_empty(a)); + const Node* bail_condition = prim_op_helper(a, gt_op, shd_empty(a), mk_nodes(a, iterations_count_param, shd_int32_literal(a, ctx->config->shader_diagnostics.max_top_iterations))); + Node* bail_case = case_(a, shd_empty(a)); + const Node* break_terminator = join(a, (Join) { .args = shd_empty(a), .join_point = l.break_jp, .mem = shd_get_abstraction_mem(bail_case) }); + shd_set_abstraction_body(bail_case, break_terminator); + Node* proceed_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(proceed_case, join(a, (Join) { + .join_point = c.jp, + .mem = shd_get_abstraction_mem(proceed_case), + .args = shd_empty(a), + })); + shd_set_abstraction_body(c.case_, branch(a, (Branch) { + .mem = shd_get_abstraction_mem(c.case_), .condition = bail_condition, - .if_true = bail_true_lam, - 
.if_false = NULL, - .yield_types = empty(a) - }); - bind_instruction(loop_body_builder, bail_if); + .true_jump = jump_helper(a, shd_get_abstraction_mem(c.case_), bail_case, shd_empty(a)), + .false_jump = jump_helper(a, shd_get_abstraction_mem(c.case_), proceed_case, shd_empty(a)), + })); + // gen_if(loop_body_builder, empty(a), bail_condition, bail_case, NULL); } - struct List* literals = new_list(const Node*); - struct List* cases = new_list(const Node*); + struct List* literals = shd_new_list(const Node*); + struct List* jumps = shd_new_list(const Node*); // Build 'zero' case (exits the program) - BodyBuilder* zero_case_builder = begin_body(a); - BodyBuilder* zero_if_case_builder = begin_body(a); + Node* zero_case_lam = case_(a, shd_nodes(a, 0, NULL)); + Node* zero_if_true_lam = case_(a, shd_empty(a)); + BodyBuilder* zero_if_case_builder = shd_bld_begin(a, shd_get_abstraction_mem(zero_if_true_lam)); if (ctx->config->printf_trace.god_function) { - const Node* sid = gen_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupId); - bind_instruction(zero_if_case_builder, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = "trace: kill thread %d:%d\n" }), sid, local_id) })); + const Node* sid = shd_bld_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupId); + shd_bld_debug_printf(zero_if_case_builder, "trace: kill thread %d:%d\n", mk_nodes(a, sid, local_id)); } - const Node* zero_if_true_lam = case_(a, empty(a), finish_body(zero_if_case_builder, break_terminator)); - const Node* zero_if_instruction = if_instr(a, (If) { - .condition = should_run, - .if_true = zero_if_true_lam, - .if_false = NULL, - .yield_types = empty(a), - }); - bind_instruction(zero_case_builder, zero_if_instruction); + shd_set_abstraction_body(zero_if_true_lam, shd_bld_join(zero_if_case_builder, l.break_jp, shd_empty(a))); + + Node* zero_if_false = case_(a, shd_empty(a)); + BodyBuilder* 
zero_false_builder = shd_bld_begin(a, shd_get_abstraction_mem(zero_if_false)); if (ctx->config->printf_trace.god_function) { - const Node* sid = gen_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupId); - bind_instruction(zero_case_builder, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = "trace: thread %d:%d escaped death!\n" }), sid, local_id) })); + const Node* sid = shd_bld_builtin_load(ctx->rewriter.dst_module, zero_false_builder, BuiltinSubgroupId); + shd_bld_debug_printf(zero_false_builder, "trace: thread %d:%d escaped death!\n", mk_nodes(a, sid, local_id)); } + shd_set_abstraction_body(zero_if_false, shd_bld_join(zero_false_builder, l.continue_jp, count_iterations ? shd_singleton(iteration_count_plus_one) : shd_empty(a))); - const Node* zero_case_lam = case_(a, nodes(a, 0, NULL), finish_body(zero_case_builder, continue_terminator)); - const Node* zero_lit = uint64_literal(a, 0); - append_list(const Node*, literals, zero_lit); - append_list(const Node*, cases, zero_case_lam); + shd_set_abstraction_body(zero_case_lam, branch(a, (Branch) { + .mem = shd_get_abstraction_mem(zero_case_lam), + .condition = should_run, + .true_jump = jump_helper(a, shd_get_abstraction_mem(zero_case_lam), zero_if_true_lam, shd_empty(a)), + .false_jump = jump_helper(a, shd_get_abstraction_mem(zero_case_lam), zero_if_false, shd_empty(a)), + })); + + const Node* zero_lit = shd_uint64_literal(a, 0); + shd_list_append(const Node*, literals, zero_lit); + const Node* zero_jump = jump_helper(a, shd_bb_mem(loop_body_builder), zero_case_lam, shd_empty(a)); + shd_list_append(const Node*, jumps, zero_jump); - Nodes old_decls = get_module_declarations(ctx->rewriter.src_module); + Nodes old_decls = shd_module_get_declarations(ctx->rewriter.src_module); for (size_t i = 0; i < old_decls.count; i++) { const Node* decl = old_decls.nodes[i]; if (decl->tag == Function_TAG) { - if (lookup_annotation(decl, "Leaf")) + if 
(shd_lookup_annotation(decl, "Leaf")) continue; - const Node* fn_lit = lower_fn_addr(ctx, decl); + const Node* fn_lit = shd_uint32_literal(a, get_fn_ptr(ctx, decl)); - BodyBuilder* if_builder = begin_body(a); + Node* if_true_case = case_(a, shd_empty(a)); + BodyBuilder* if_builder = shd_bld_begin(a, shd_get_abstraction_mem(if_true_case)); if (ctx->config->printf_trace.god_function) { - const Node* sid = gen_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupId); - bind_instruction(if_builder, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = "trace: thread %d:%d will run fn %d with mask = %lx\n" }), sid, local_id, fn_lit, next_mask) })); + const Node* sid = shd_bld_builtin_load(ctx->rewriter.dst_module, loop_body_builder, BuiltinSubgroupId); + shd_bld_debug_printf(if_builder, "trace: thread %d:%d will run fn %u with mask = %lx\n", mk_nodes(a, sid, local_id, fn_lit, next_mask)); } - bind_instruction(if_builder, call(a, (Call) { - .callee = fn_addr_helper(a, find_processed(&ctx->rewriter, decl)), - .args = nodes(a, 0, NULL) + shd_bld_call(if_builder, fn_addr_helper(a, shd_rewrite_node(&ctx->rewriter, decl)), shd_empty(a)); + shd_set_abstraction_body(if_true_case, shd_bld_join(if_builder, l.continue_jp, count_iterations ? shd_singleton(iteration_count_plus_one) : shd_empty(a))); + + Node* if_false = case_(a, shd_empty(a)); + shd_set_abstraction_body(if_false, join(a, (Join) { + .mem = shd_get_abstraction_mem(if_false), + .join_point = l.continue_jp, + .args = count_iterations ? 
shd_singleton(iteration_count_plus_one) : shd_empty(a) })); - const Node* if_true_lam = case_(a, empty(a), finish_body(if_builder, yield(a, (Yield) {.args = nodes(a, 0, NULL)}))); - const Node* if_instruction = if_instr(a, (If) { - .condition = should_run, - .if_true = if_true_lam, - .if_false = NULL, - .yield_types = empty(a), - }); - BodyBuilder* case_builder = begin_body(a); - bind_instruction(case_builder, if_instruction); - const Node* case_lam = case_(a, nodes(a, 0, NULL), finish_body(case_builder, continue_terminator)); + Node* fn_case = case_(a, shd_nodes(a, 0, NULL)); + shd_set_abstraction_body(fn_case, branch(a, (Branch) { + .mem = shd_get_abstraction_mem(fn_case), + .condition = should_run, + .true_jump = jump_helper(a, shd_get_abstraction_mem(fn_case), if_true_case, shd_empty(a)), + .false_jump = jump_helper(a, shd_get_abstraction_mem(fn_case), if_false, shd_empty(a)), + })); - append_list(const Node*, literals, fn_lit); - append_list(const Node*, cases, case_lam); + shd_list_append(const Node*, literals, fn_lit); + const Node* j = jump_helper(a, shd_bb_mem(loop_body_builder), fn_case, shd_empty(a)); + shd_list_append(const Node*, jumps, j); } } - const Node* default_case_lam = case_(a, nodes(a, 0, NULL), unreachable(a)); - - bind_instruction(loop_body_builder, match_instr(a, (Match) { - .yield_types = nodes(a, 0, NULL), - .inspect = next_function, - .literals = nodes(a, entries_count_list(literals), read_list(const Node*, literals)), - .cases = nodes(a, entries_count_list(cases), read_list(const Node*, cases)), - .default_case = default_case_lam, - })); - - destroy_list(literals); - destroy_list(cases); + Node* default_case = case_(a, shd_nodes(a, 0, NULL)); + shd_set_abstraction_body(default_case, unreachable(a, (Unreachable) { .mem = shd_get_abstraction_mem(default_case) })); - const Node* loop_inside_lam = case_(a, count_iterations ? 
singleton(iterations_count_param) : nodes(a, 0, NULL), finish_body(loop_body_builder, unreachable(a))); + shd_set_abstraction_body(loop_inside_case, shd_bld_finish(loop_body_builder, br_switch(a, (Switch) { + .mem = shd_bb_mem(loop_body_builder), + .switch_value = next_function, + .case_values = shd_nodes(a, shd_list_count(literals), shd_read_list(const Node*, literals)), + .case_jumps = shd_nodes(a, shd_list_count(jumps), shd_read_list(const Node*, jumps)), + .default_jump = jump_helper(a, shd_bb_mem(loop_body_builder), default_case, shd_empty(a)) + }))); - const Node* the_loop = loop_instr(a, (Loop) { - .yield_types = nodes(a, 0, NULL), - .initial_args = count_iterations ? singleton(int32_literal(a, 0)) : nodes(a, 0, NULL), - .body = loop_inside_lam - }); + shd_destroy_list(literals); + shd_destroy_list(jumps); - BodyBuilder* dispatcher_body_builder = begin_body(a); - bind_instruction(dispatcher_body_builder, the_loop); if (ctx->config->printf_trace.god_function) - bind_instruction(dispatcher_body_builder, prim_op(a, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(a, string_lit(a, (StringLiteral) { .string = "trace: end of top\n" })) })); + shd_bld_debug_printf(dispatcher_body_builder, "trace: end of top\n", shd_empty(a)); - (*ctx->top_dispatcher_fn)->payload.fun.body = finish_body(dispatcher_body_builder, fn_ret(a, (Return) { - .args = nodes(a, 0, NULL), - .fn = *ctx->top_dispatcher_fn, - })); + shd_set_abstraction_body(*ctx->top_dispatcher_fn, shd_bld_finish(dispatcher_body_builder, fn_ret(a, (Return) { + .args = shd_nodes(a, 0, NULL), + .mem = shd_bb_mem(dispatcher_body_builder), + }))); } -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); -Module* lower_tailcalls(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, 
get_module_name(src)); +Module* shd_pass_lower_tailcalls(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); - struct Dict* ptrs = new_dict(const Node*, FnPtr, (HashFn) hash_node, (CmpFn) compare_node); + struct Dict* ptrs = shd_new_dict(const Node*, FnPtr, (HashFn) shd_hash_node, (CmpFn) shd_compare_node); - Node* init_fn = function(dst, nodes(a, 0, NULL), "generated_init", mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }), annotation(a, (Annotation) { .name = "Leaf" }), annotation(a, (Annotation) { .name = "Structured" })), nodes(a, 0, NULL)); - init_fn->payload.fun.body = fn_ret(a, (Return) { .fn = init_fn, .args = empty(a) }); + Node* init_fn = function(dst, shd_nodes(a, 0, NULL), "generated_init", mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }), annotation(a, (Annotation) { .name = "Leaf" })), shd_nodes(a, 0, NULL)); + shd_set_abstraction_body(init_fn, fn_ret(a, (Return) { .args = shd_empty(a), .mem = shd_get_abstraction_mem(init_fn) })); FnPtr next_fn_ptr = 1; Node* top_dispatcher_fn = NULL; Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config, .disable_lowering = false, .assigned_fn_ptrs = ptrs, @@ -440,13 +468,13 @@ Module* lower_tailcalls(SHADY_UNUSED const CompilerConfig* config, Module* src) .init_fn = init_fn, }; - rewrite_module(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); // Generate the top dispatcher, but only if it is used for realsies if (*ctx.top_dispatcher_fn) generate_top_level_dispatch_fn(&ctx); - destroy_dict(ptrs); - destroy_rewriter(&ctx.rewriter); + shd_destroy_dict(ptrs); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/lower_vec_arr.c 
b/src/shady/passes/lower_vec_arr.c index 78802fb36..eab656732 100644 --- a/src/shady/passes/lower_vec_arr.c +++ b/src/shady/passes/lower_vec_arr.c @@ -1,8 +1,7 @@ -#include "passes.h" - -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" +#include "shady/pass.h" +#include "shady/ir/type.h" +#include "shady/ir/composite.h" +#include "shady/ir/primop.h" #include "portability.h" @@ -14,24 +13,24 @@ typedef struct { static const Node* scalarify_primop(Context* ctx, const Node* old) { IrArena* a = ctx->rewriter.dst_arena; const Type* dst_type = old->type; - deconstruct_qualified_type(&dst_type); - size_t width = deconstruct_maybe_packed_type(&dst_type); + shd_deconstruct_qualified_type(&dst_type); + size_t width = shd_deconstruct_maybe_packed_type(&dst_type); if (width == 1) - return recreate_node_identity(&ctx->rewriter, old); + return shd_recreate_node(&ctx->rewriter, old); LARRAY(const Node*, elements, width); - BodyBuilder* bb = begin_body(a); - Nodes noperands = rewrite_nodes(&ctx->rewriter, old->payload.prim_op.operands); + BodyBuilder* bb = shd_bld_begin_pure(a); + Nodes noperands = shd_rewrite_nodes(&ctx->rewriter, old->payload.prim_op.operands); for (size_t i = 0; i < width; i++) { LARRAY(const Node*, nops, noperands.count); for (size_t j = 0; j < noperands.count; j++) - nops[j] = gen_extract(bb, noperands.nodes[j], singleton(int32_literal(a, i))); - elements[i] = gen_primop_e(bb, old->payload.prim_op.op, empty(a), nodes(a, noperands.count, nops)); + nops[j] = shd_extract_helper(a, noperands.nodes[j], shd_singleton(shd_int32_literal(a, i))); + elements[i] = prim_op_helper(a, old->payload.prim_op.op, shd_empty(a), shd_nodes(a, noperands.count, nops)); } const Type* t = arr_type(a, (ArrType) { - .element_type = rewrite_node(&ctx->rewriter, dst_type), - .size = int32_literal(a, width) + .element_type = shd_rewrite_node(&ctx->rewriter, dst_type), + .size = shd_int32_literal(a, width) }); - return 
yield_values_and_wrap_in_block(bb, singleton(composite_helper(a, t, nodes(a, width, elements)))); + return shd_bld_to_instr_yield_values(bb, shd_singleton(composite_helper(a, t, shd_nodes(a, width, elements)))); } static const Node* process(Context* ctx, const Node* node) { @@ -40,30 +39,30 @@ static const Node* process(Context* ctx, const Node* node) { switch (node->tag) { case PackType_TAG: { return arr_type(a, (ArrType) { - .element_type = rewrite_node(&ctx->rewriter, node->payload.pack_type.element_type), - .size = int32_literal(a, node->payload.pack_type.width) + .element_type = shd_rewrite_node(&ctx->rewriter, node->payload.pack_type.element_type), + .size = shd_int32_literal(a, node->payload.pack_type.width) }); } case PrimOp_TAG: { - if (get_primop_class(node->payload.prim_op.op) & (OcArithmetic | OcLogic | OcCompare | OcShift | OcMath)) + if (shd_get_primop_class(node->payload.prim_op.op) & (OcArithmetic | OcLogic | OcCompare | OcShift | OcMath)) return scalarify_primop(ctx, node); } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_vec_arr(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); +Module* shd_pass_lower_vec_arr(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); aconfig.validate_builtin_types = false; // TODO: hacky - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config, }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return 
dst; } diff --git a/src/shady/passes/lower_workgroups.c b/src/shady/passes/lower_workgroups.c index 06cc7d982..ba9383c49 100644 --- a/src/shady/passes/lower_workgroups.c +++ b/src/shady/passes/lower_workgroups.c @@ -1,11 +1,9 @@ -#include "passes.h" - -#include "util.h" +#include "shady/pass.h" +#include "shady/ir/builtin.h" #include "../ir_private.h" -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" + +#include "util.h" #include #include @@ -18,171 +16,187 @@ typedef struct { bool is_entry_point; } Context; +static void add_bounds_check(IrArena* a, BodyBuilder* bb, const Node* i, const Node* max) { + Node* out_of_bounds_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(out_of_bounds_case, merge_break(a, (MergeBreak) { + .args = shd_empty(a), + .mem = shd_get_abstraction_mem(out_of_bounds_case) + })); + shd_bld_if(bb, shd_empty(a), prim_op_helper(a, gte_op, shd_empty(a), mk_nodes(a, i, max)), out_of_bounds_case, NULL); +} + static const Node* process(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; - Module* m = ctx->rewriter.dst_module; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + Module* m = r->dst_module; switch (node->tag) { case GlobalVariable_TAG: { - const Node* ba = lookup_annotation(node, "Builtin"); + const Node* ba = shd_lookup_annotation(node, "Builtin"); if (ba) { - Nodes filtered_as = rewrite_nodes(&ctx->rewriter, filter_out_annotation(a, node->payload.global_variable.annotations, "Builtin")); - Builtin b = get_builtin_by_name(get_annotation_string_payload(ba)); + Nodes filtered_as = shd_rewrite_nodes(&ctx->rewriter, shd_filter_out_annotation(a, node->payload.global_variable.annotations, "Builtin")); + Builtin b = shd_get_builtin_by_name(shd_get_annotation_string_payload(ba)); switch (b) { case BuiltinSubgroupId: case BuiltinWorkgroupId: case BuiltinGlobalInvocationId: case BuiltinLocalInvocationId: - return global_var(m, filtered_as, rewrite_node(&ctx->rewriter, 
node->payload.global_variable.type), node->payload.global_variable.name, AsPrivateLogical); + return global_var(m, filtered_as, shd_rewrite_node(&ctx->rewriter, node->payload.global_variable.type), node->payload.global_variable.name, AsPrivate); case BuiltinNumWorkgroups: - return global_var(m, filtered_as, rewrite_node(&ctx->rewriter, node->payload.global_variable.type), node->payload.global_variable.name, AsExternal); - case BuiltinWorkgroupSize: - assert(false); + return global_var(m, filtered_as, shd_rewrite_node(&ctx->rewriter, node->payload.global_variable.type), node->payload.global_variable.name, AsExternal); default: break; } - return get_builtin(ctx->rewriter.dst_module, b, get_decl_name(node)); + return shd_get_or_create_builtin(ctx->rewriter.dst_module, b, get_declaration_name(node)); } break; } case Function_TAG: { Context ctx2 = *ctx; ctx2.is_entry_point = false; - const Node* epa = lookup_annotation(node, "EntryPoint"); - if (epa && strcmp(get_annotation_string_payload(epa), "Compute") == 0) { + const Node* epa = shd_lookup_annotation(node, "EntryPoint"); + if (epa && strcmp(shd_get_annotation_string_payload(epa), "Compute") == 0) { ctx2.is_entry_point = true; assert(node->payload.fun.return_types.count == 0 && "entry points do not return at this stage"); - Nodes wannotations = rewrite_nodes(&ctx->rewriter, node->payload.fun.annotations); - Nodes wparams = recreate_variables(&ctx->rewriter, node->payload.fun.params); - Node* wrapper = function(m, wparams, get_abstraction_name(node), wannotations, empty(a)); - register_processed(&ctx->rewriter, node, wrapper); + Nodes wannotations = shd_rewrite_nodes(&ctx->rewriter, node->payload.fun.annotations); + Nodes wparams = shd_recreate_params(&ctx->rewriter, node->payload.fun.params); + Node* wrapper = function(m, wparams, shd_get_abstraction_name(node), wannotations, shd_empty(a)); + shd_register_processed(&ctx->rewriter, node, wrapper); // recreate the old entry point, but this time it's not the entry point 
anymore - Nodes nannotations = filter_out_annotation(a, wannotations, "EntryPoint"); - Nodes nparams = recreate_variables(&ctx->rewriter, node->payload.fun.params); - Node* inner = function(m, nparams, format_string_arena(a->arena, "%s_wrapped", get_abstraction_name(node)), nannotations, empty(a)); - register_processed_list(&ctx->rewriter, node->payload.fun.params, nparams); - inner->payload.fun.body = recreate_node_identity(&ctx->rewriter, node->payload.fun.body); + Nodes nannotations = shd_filter_out_annotation(a, wannotations, "EntryPoint"); + Nodes nparams = shd_recreate_params(&ctx->rewriter, node->payload.fun.params); + Node* inner = function(m, nparams, shd_format_string_arena(a->arena, "%s_wrapped", shd_get_abstraction_name(node)), nannotations, shd_empty(a)); + shd_register_processed_list(&ctx->rewriter, node->payload.fun.params, nparams); + shd_register_processed(&ctx->rewriter, shd_get_abstraction_mem(node), shd_get_abstraction_mem(inner)); + shd_set_abstraction_body(inner, shd_recreate_node(&ctx->rewriter, node->payload.fun.body)); - BodyBuilder* bb = begin_body(a); - const Node* num_workgroups_var = rewrite_node(&ctx->rewriter, get_builtin(ctx->rewriter.src_module, BuiltinNumWorkgroups, NULL)); - const Node* workgroup_num_vec3 = gen_load(bb, ref_decl_helper(a, num_workgroups_var)); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(wrapper)); + const Node* num_workgroups_var = shd_rewrite_node(&ctx->rewriter, shd_get_or_create_builtin(ctx->rewriter.src_module, BuiltinNumWorkgroups, NULL)); + const Node* workgroup_num_vec3 = shd_bld_load(bb, ref_decl_helper(a, num_workgroups_var)); // prepare variables for iterating over workgroups String names[] = { "gx", "gy", "gz" }; const Node* workgroup_id[3]; const Node* num_workgroups[3]; for (int dim = 0; dim < 3; dim++) { - workgroup_id[dim] = var(a, qualified_type_helper(uint32_type(a), false), names[dim]); - num_workgroups[dim] = gen_extract(bb, workgroup_num_vec3, singleton(uint32_literal(a, 
dim))); + workgroup_id[dim] = param(a, shd_as_qualified_type(shd_uint32_type(a), false), names[dim]); + num_workgroups[dim] = shd_extract_helper(a, workgroup_num_vec3, shd_singleton(shd_uint32_literal(a, dim))); } // Prepare variables for iterating inside workgroups const Node* subgroup_id[3]; uint32_t num_subgroups[3]; const Node* num_subgroups_literals[3]; - assert(a->config.specializations.subgroup_size); + assert(ctx->config->specialization.subgroup_size); assert(a->config.specializations.workgroup_size[0] && a->config.specializations.workgroup_size[1] && a->config.specializations.workgroup_size[2]); - num_subgroups[0] = a->config.specializations.workgroup_size[0] / a->config.specializations.subgroup_size; + num_subgroups[0] = a->config.specializations.workgroup_size[0] / ctx->config->specialization.subgroup_size; num_subgroups[1] = a->config.specializations.workgroup_size[1]; num_subgroups[2] = a->config.specializations.workgroup_size[2]; String names2[] = { "sgx", "sgy", "sgz" }; for (int dim = 0; dim < 3; dim++) { - subgroup_id[dim] = var(a, qualified_type_helper(uint32_type(a), false), names2[dim]); - num_subgroups_literals[dim] = uint32_literal(a, num_subgroups[dim]); + subgroup_id[dim] = param(a, shd_as_qualified_type(shd_uint32_type(a), false), names2[dim]); + num_subgroups_literals[dim] = shd_uint32_literal(a, num_subgroups[dim]); } - BodyBuilder* bb2 = begin_body(a); + Node* cases[6]; + BodyBuilder* builders[6]; + for (int scope = 0; scope < 2; scope++) { + const Node** params; + const Node** maxes; + if (scope == 1) { + params = subgroup_id; + maxes = num_subgroups_literals; + } else if (scope == 0) { + params = workgroup_id; + maxes = num_workgroups; + } else + assert(false); + for (int dim = 0; dim < 3; dim++) { + Node* loop_body = case_(a, shd_singleton(params[dim])); + cases[scope * 3 + dim] = loop_body; + BodyBuilder* loop_bb = shd_bld_begin(a, shd_get_abstraction_mem(loop_body)); + builders[scope * 3 + dim] = loop_bb; + add_bounds_check(a, 
loop_bb, params[dim], maxes[dim]); + } + } + + // BodyBuilder* bb2 = begin_block_with_side_effects(a, bb_mem(builders[5])); + BodyBuilder* bb2 = builders[5]; // write the workgroup ID - gen_store(bb2, ref_decl_helper(a, rewrite_node(&ctx->rewriter, get_builtin(ctx->rewriter.src_module, BuiltinWorkgroupId, NULL))), composite_helper(a, pack_type(a, (PackType) { .element_type = uint32_type(a), .width = 3 }), mk_nodes(a, workgroup_id[0], workgroup_id[1], workgroup_id[2]))); + shd_bld_store(bb2, ref_decl_helper(a, shd_rewrite_node(&ctx->rewriter, shd_get_or_create_builtin(ctx->rewriter.src_module, BuiltinWorkgroupId, NULL))), composite_helper(a, pack_type(a, (PackType) { .element_type = shd_uint32_type(a), .width = 3 }), mk_nodes(a, workgroup_id[0], workgroup_id[1], workgroup_id[2]))); // write the local ID const Node* local_id[3]; // local_id[0] = SUBGROUP_SIZE * subgroup_id[0] + subgroup_local_id - local_id[0] = gen_primop_e(bb2, add_op, empty(a), mk_nodes(a, gen_primop_e(bb2, mul_op, empty(a), mk_nodes(a, uint32_literal(a, a->config.specializations.subgroup_size), subgroup_id[0])), gen_builtin_load(m, bb, BuiltinSubgroupLocalInvocationId))); + local_id[0] = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, prim_op_helper(a, mul_op, shd_empty(a), mk_nodes(a, shd_uint32_literal(a, ctx->config->specialization.subgroup_size), subgroup_id[0])), shd_bld_builtin_load(m, bb, BuiltinSubgroupLocalInvocationId))); local_id[1] = subgroup_id[1]; local_id[2] = subgroup_id[2]; - gen_store(bb2, ref_decl_helper(a, rewrite_node(&ctx->rewriter, get_builtin(ctx->rewriter.src_module, BuiltinLocalInvocationId, NULL))), composite_helper(a, pack_type(a, (PackType) { .element_type = uint32_type(a), .width = 3 }), mk_nodes(a, local_id[0], local_id[1], local_id[2]))); + shd_bld_store(bb2, ref_decl_helper(a, shd_rewrite_node(&ctx->rewriter, shd_get_or_create_builtin(ctx->rewriter.src_module, BuiltinLocalInvocationId, NULL))), composite_helper(a, pack_type(a, (PackType) { .element_type = 
shd_uint32_type(a), .width = 3 }), mk_nodes(a, local_id[0], local_id[1], local_id[2]))); // write the global ID const Node* global_id[3]; for (int dim = 0; dim < 3; dim++) - global_id[dim] = gen_primop_e(bb2, add_op, empty(a), mk_nodes(a, gen_primop_e(bb2, mul_op, empty(a), mk_nodes(a, uint32_literal(a, a->config.specializations.workgroup_size[dim]), workgroup_id[dim])), local_id[dim])); - gen_store(bb2, ref_decl_helper(a, rewrite_node(&ctx->rewriter, get_builtin(ctx->rewriter.src_module, BuiltinGlobalInvocationId, NULL))), composite_helper(a, pack_type(a, (PackType) { .element_type = uint32_type(a), .width = 3 }), mk_nodes(a, global_id[0], global_id[1], global_id[2]))); + global_id[dim] = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, prim_op_helper(a, mul_op, shd_empty(a), mk_nodes(a, shd_uint32_literal(a, a->config.specializations.workgroup_size[dim]), workgroup_id[dim])), local_id[dim])); + shd_bld_store(bb2, ref_decl_helper(a, shd_rewrite_node(&ctx->rewriter, shd_get_or_create_builtin(ctx->rewriter.src_module, BuiltinGlobalInvocationId, NULL))), composite_helper(a, pack_type(a, (PackType) { .element_type = shd_uint32_type(a), .width = 3 }), mk_nodes(a, global_id[0], global_id[1], global_id[2]))); // TODO: write the subgroup ID - - bind_instruction(bb2, call(a, (Call) { .callee = fn_addr_helper(a, inner), .args = wparams })); - const Node* instr = yield_values_and_wrap_in_block(bb2, empty(a)); + shd_bld_call(bb2, fn_addr_helper(a, inner), wparams); // Wrap in 3 loops for iterating over subgroups, then again for workgroups - for (int scope = 0; scope < 2; scope++) { + for (unsigned scope = 1; scope < 2; scope--) { const Node** params; - const Node** maxes; if (scope == 0) { - params = subgroup_id; - maxes = num_subgroups_literals; - } else if (scope == 1) { params = workgroup_id; - maxes = num_workgroups; + } else if (scope == 1) { + params = subgroup_id; } else assert(false); - for (int dim = 0; dim < 3; dim++) { - BodyBuilder* body_bb = begin_body(a); - 
bind_instruction(body_bb, if_instr(a, (If) { - .yield_types = empty(a), - .condition = gen_primop_e(body_bb, gte_op, empty(a), mk_nodes(a, params[dim], maxes[dim])), - .if_true = case_(a, empty(a), merge_break(a, (MergeBreak) {.args = empty(a)})), - .if_false = NULL - })); - bind_instruction(body_bb, instr); - const Node* loop = loop_instr(a, (Loop) { - .yield_types = empty(a), - .initial_args = singleton(uint32_literal(a, 0)), - .body = case_(a, singleton(params[dim]), finish_body(body_bb, merge_continue(a, (MergeContinue) {.args = singleton(gen_primop_e(body_bb, add_op, empty(a), mk_nodes(a, params[dim], uint32_literal(a, 1))))}))) - }); - instr = loop; + for (unsigned dim = 2; dim < 3; dim--) { + size_t depth = scope * 3 + dim; + Node* loop_body = cases[depth]; + BodyBuilder* body_bb = builders[depth]; + + shd_set_abstraction_body(loop_body, shd_bld_finish(body_bb, merge_continue(a, (MergeContinue) { + .args = shd_singleton(prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, params[dim], shd_uint32_literal(a, 1)))), + .mem = shd_bb_mem(body_bb) + }))); + shd_bld_loop(depth > 0 ? 
builders[depth - 1] : bb, shd_empty(a), shd_singleton(shd_uint32_literal(a, 0)), loop_body); } } - bind_instruction(bb, instr); - wrapper->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { .fn = wrapper, .args = empty(a) })); + shd_set_abstraction_body(wrapper, shd_bld_finish(bb, fn_ret(a, (Return) { .args = shd_empty(a), .mem = shd_bb_mem(bb) }))); return wrapper; } - return recreate_node_identity(&ctx2.rewriter, node); + return shd_recreate_node(&ctx2.rewriter, node); } - case PrimOp_TAG: { - switch (node->payload.prim_op.op) { - case load_op: { - const Node* ptr = first(node->payload.prim_op.operands); - if (ptr->tag == RefDecl_TAG) - ptr = ptr->payload.ref_decl.decl; - if (ptr == get_builtin(ctx->rewriter.src_module, BuiltinSubgroupId, NULL)) { - BodyBuilder* bb = begin_body(a); - const Node* loaded = first(bind_instruction(bb, recreate_node_identity(&ctx->rewriter, node))); - const Node* uniformized = first(gen_primop(bb, subgroup_broadcast_first_op, empty(a), singleton(loaded))); - return yield_values_and_wrap_in_block(bb, singleton(uniformized)); - } - } - default: break; + case Load_TAG: { + Load payload = node->payload.load; + const Node* ptr = payload.ptr; + if (ptr->tag == RefDecl_TAG) + ptr = ptr->payload.ref_decl.decl; + if (ptr == shd_get_or_create_builtin(ctx->rewriter.src_module, BuiltinSubgroupId, NULL)) { + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Node* loaded = shd_first(shd_bld_add_instruction_extract(bb, shd_recreate_node(&ctx->rewriter, node))); + const Node* uniformized = prim_op_helper(a, subgroup_assume_uniform_op, shd_empty(a), shd_singleton(loaded)); + return shd_bld_to_instr_yield_values(bb, shd_singleton(uniformized)); } - break; } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* lower_workgroups(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - 
IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_lower_workgroups(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config, .globals = calloc(sizeof(Node*), PRIMOPS_COUNT), }; - rewrite_module(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); free(ctx.globals); - destroy_rewriter(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/mark_leaf_functions.c b/src/shady/passes/mark_leaf_functions.c index 101109359..2e5b84442 100644 --- a/src/shady/passes/mark_leaf_functions.c +++ b/src/shady/passes/mark_leaf_functions.c @@ -1,24 +1,22 @@ -#include "passes.h" - -#include "dict.h" -#include "portability.h" -#include "log.h" - -#include "../rewrite.h" +#include "shady/pass.h" #include "../analysis/callgraph.h" -#include "../analysis/scope.h" +#include "../analysis/cfg.h" #include "../analysis/uses.h" #include "../analysis/leak.h" +#include "dict.h" +#include "portability.h" +#include "log.h" + typedef struct { Rewriter rewriter; CallGraph* graph; struct Dict* fns; bool is_leaf; - Scope* scope; - const UsesMap* scope_uses; + CFG* cfg; + const UsesMap* uses; } Context; typedef struct { @@ -29,7 +27,7 @@ typedef struct { } FnInfo; static bool is_leaf_fn(Context* ctx, CGNode* fn_node) { - FnInfo* info = find_value_dict(const Node*, FnInfo, ctx->fns, fn_node->fn); + FnInfo* info = shd_dict_find_value(const Node*, FnInfo, ctx->fns, fn_node->fn); if (info) { // if we encounter a function before 'done' is set, it must be part of a recursive chain if (!info->done) { @@ -43,29 +41,47 @@ static bool is_leaf_fn(Context* ctx, CGNode* fn_node) { 
.node = fn_node, .done = false, }; - insert_dict(const Node*, FnInfo, ctx->fns, fn_node->fn, initial_info); - info = find_value_dict(const Node*, FnInfo, ctx->fns, fn_node->fn); + shd_dict_insert(const Node*, FnInfo, ctx->fns, fn_node->fn, initial_info); + info = shd_dict_find_value(const Node*, FnInfo, ctx->fns, fn_node->fn); assert(info); - if (fn_node->is_address_captured || fn_node->is_recursive) { + if (fn_node->is_address_captured || fn_node->is_recursive || fn_node->calls_indirect) { info->is_leaf = false; info->done = true; - debugv_print("Function %s can't be a leaf function because %s.\n", get_abstraction_name(fn_node->fn), fn_node->is_address_captured ? "its address is captured" : "it is recursive" ); + shd_debugv_print("Function %s can't be a leaf function because", shd_get_abstraction_name(fn_node->fn)); + bool and = false; + if (fn_node->is_address_captured) { + shd_debugv_print("its address is captured"); + and = true; + } + if (fn_node->is_recursive) { + if (and) + shd_debugv_print(" and "); + shd_debugv_print("it is recursive"); + and = true; + } + if (fn_node->calls_indirect) { + if (and) + shd_debugv_print(" and "); + shd_debugv_print("it makes indirect calls"); + and = true; + } + shd_debugv_print(".\n"); return false; } size_t iter = 0; CGEdge e; - while (dict_iter(fn_node->callees, &iter, &e, NULL)) { + while (shd_dict_iter(fn_node->callees, &iter, &e, NULL)) { if (!is_leaf_fn(ctx, e.dst_fn)) { - debugv_print("Function %s can't be a leaf function because its callee %s is not a leaf function.\n", get_abstraction_name(fn_node->fn), get_abstraction_name(e.dst_fn->fn)); + shd_debugv_print("Function %s can't be a leaf function because its callee %s is not a leaf function.\n", shd_get_abstraction_name(fn_node->fn), shd_get_abstraction_name(e.dst_fn->fn)); info->is_leaf = false; info->done = true; } } // by analysing the callees, the dict might have been regrown so we must refetch this to update the ptr if needed - info = find_value_dict(const Node*, 
FnInfo, ctx->fns, fn_node->fn); + info = shd_dict_find_value(const Node*, FnInfo, ctx->fns, fn_node->fn); if (!info->done) { info->is_leaf = true; @@ -77,88 +93,76 @@ static bool is_leaf_fn(Context* ctx, CGNode* fn_node) { static const Node* process(Context* ctx, const Node* node) { IrArena* a = ctx->rewriter.dst_arena; - if (!node) return NULL; - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - switch (node->tag) { case Function_TAG: { Context fn_ctx = *ctx; - CGNode* fn_node = *find_value_dict(const Node*, CGNode*, ctx->graph->fn2cgn, node); + CGNode* fn_node = *shd_dict_find_value(const Node*, CGNode*, ctx->graph->fn2cgn, node); fn_ctx.is_leaf = is_leaf_fn(ctx, fn_node); - fn_ctx.scope = new_scope(node); - fn_ctx.scope_uses = create_uses_map(node, (NcDeclaration | NcType)); + fn_ctx.cfg = build_fn_cfg(node); + fn_ctx.uses = shd_new_uses_map_fn(node, (NcDeclaration | NcType)); ctx = &fn_ctx; - Nodes annotations = rewrite_nodes(&ctx->rewriter, node->payload.fun.annotations); - Node* new = function(ctx->rewriter.dst_module, recreate_variables(&ctx->rewriter, node->payload.fun.params), node->payload.fun.name, annotations, rewrite_nodes(&ctx->rewriter, node->payload.fun.return_types)); + Nodes annotations = shd_rewrite_nodes(&ctx->rewriter, node->payload.fun.annotations); + Node* new = function(ctx->rewriter.dst_module, shd_recreate_params(&ctx->rewriter, node->payload.fun.params), node->payload.fun.name, annotations, shd_rewrite_nodes(&ctx->rewriter, node->payload.fun.return_types)); for (size_t i = 0; i < new->payload.fun.params.count; i++) - register_processed(&ctx->rewriter, node->payload.fun.params.nodes[i], new->payload.fun.params.nodes[i]); - register_processed(&ctx->rewriter, node, new); - recreate_decl_body_identity(&ctx->rewriter, node, new); + shd_register_processed(&ctx->rewriter, node->payload.fun.params.nodes[i], new->payload.fun.params.nodes[i]); + shd_register_processed(&ctx->rewriter, node, new); + 
shd_recreate_node_body(&ctx->rewriter, node, new); if (fn_ctx.is_leaf) { - debugv_print("Function %s is a leaf function!\n", get_abstraction_name(node)); - new->payload.fun.annotations = append_nodes(a, annotations, annotation(a, (Annotation) { - .name = "Leaf", + shd_debugv_print("Function %s is a leaf function!\n", shd_get_abstraction_name(node)); + new->payload.fun.annotations = shd_nodes_append(a, annotations, annotation(a, (Annotation) { + .name = "Leaf", })); } - destroy_uses_map(fn_ctx.scope_uses); - destroy_scope(fn_ctx.scope); + shd_destroy_uses_map(fn_ctx.uses); + shd_destroy_cfg(fn_ctx.cfg); return new; } case Control_TAG: { - if (!is_control_static(ctx->scope_uses, node)) { - debugv_print("Function %s can't be a leaf function because the join point ", get_abstraction_name(ctx->scope->entry->node)); - log_node(DEBUGV, first(node->payload.control.inside->payload.case_.params)); - debugv_print("escapes its control block, preventing restructuring.\n"); + if (!shd_is_control_static(ctx->uses, node)) { + shd_debugv_print("Function %s can't be a leaf function because the join point ", shd_get_abstraction_name(ctx->cfg->entry->node)); + shd_log_node(DEBUGV, shd_first(get_abstraction_params(node->payload.control.inside))); + shd_debugv_print("escapes its control block, preventing restructuring.\n"); ctx->is_leaf = false; } break; } case Join_TAG: { const Node* old_jp = node->payload.join.join_point; - // is it associated with a control node ? 
- if (old_jp->tag == Variable_TAG) { - const Node* abs = old_jp->payload.var.abs; - assert(abs); - if (abs->tag == Case_TAG) { - const Node* structured = abs->payload.case_.structured_construct; - assert(structured); - // this join point is defined by a control - we can be a leaf :) - if (structured->tag == Control_TAG) - break; - } + if (old_jp->tag == Param_TAG) { + const Node* control = shd_get_control_for_jp(ctx->uses, old_jp); + if (control && shd_is_control_static(ctx->uses, control)) + break; } - debugv_print("Function %s can't be a leaf function because it joins with ", get_abstraction_name(ctx->scope->entry->node)); - log_node(DEBUGV, old_jp); - debugv_print("which is not bound by a control node within that function.\n"); - // we join with some random join point; we can't be a leaf :( + shd_debugv_print("Function %s can't be a leaf function because it joins with ", shd_get_abstraction_name(ctx->cfg->entry->node)); + shd_log_node(DEBUGV, old_jp); + shd_debugv_print("which is not bound by a control node within that function.\n"); ctx->is_leaf = false; break; } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); -Module* mark_leaf_functions(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_mark_leaf_functions(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - .fns = new_dict(const Node*, FnInfo, 
(HashFn) hash_node, (CmpFn) compare_node), - .graph = new_callgraph(src) + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .fns = shd_new_dict(const Node*, FnInfo, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .graph = shd_new_callgraph(src) }; - rewrite_module(&ctx.rewriter); - destroy_dict(ctx.fns); - destroy_callgraph(ctx.graph); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_dict(ctx.fns); + shd_destroy_callgraph(ctx.graph); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/normalize.c b/src/shady/passes/normalize.c deleted file mode 100644 index c42d7a13d..000000000 --- a/src/shady/passes/normalize.c +++ /dev/null @@ -1,158 +0,0 @@ -#include "shady/ir.h" - -#include "log.h" -#include "portability.h" - -#include "../type.h" -#include "../rewrite.h" - -#include - -typedef struct Context_ { - Rewriter rewriter; - BodyBuilder* bb; -} Context; - -static const Node* process_node(Context* ctx, const Node* node); - -static const Node* force_to_be_value(Context* ctx, const Node* node) { - if (node == NULL) return NULL; - IrArena* a = ctx->rewriter.dst_arena; - - if (is_instruction(node)) { - const Node* let_bound; - let_bound = process_node(ctx, node); - return first(bind_instruction_outputs_count(ctx->bb, let_bound, 1, NULL, false)); - } - - switch (node->tag) { - // All decls map to refdecl/fnaddr - case Constant_TAG: - case GlobalVariable_TAG: { - return ref_decl_helper(a, process_node(ctx, node)); - } - case Function_TAG: { - return fn_addr_helper(a, process_node(ctx, node)); - } - case Variable_TAG: return find_processed(&ctx->rewriter, node); - default: - break; - } - - assert(is_value(node)); - const Node* value = process_node(ctx, node); - assert(is_value(value)); - return value; -} - -static const Node* process_op(Context* ctx, NodeClass op_class, SHADY_UNUSED String op_name, const Node* node) { - if (node == NULL) return NULL; - IrArena* a = 
ctx->rewriter.dst_arena; - switch (op_class) { - case NcType: { - switch (node->tag) { - case NominalType_TAG: { - return type_decl_ref(ctx->rewriter.dst_arena, (TypeDeclRef) { - .decl = process_node(ctx, node), - }); - } - default: break; - } - assert(is_type(node)); - const Node* type = process_node(ctx, node); - assert(is_type(type)); - return type; - } - case NcValue: - return force_to_be_value(ctx, node); - case NcVariable: - break; - case NcInstruction: { - if (is_instruction(node)) - return process_node(ctx, node); - const Node* val = force_to_be_value(ctx, node); - return quote_helper(a, singleton(val)); - } - case NcTerminator: - break; - case NcDeclaration: - break; - case NcCase: - break; - case NcBasic_block: - break; - case NcAnnotation: - break; - case NcJump: - break; - } - return process_node(ctx, node); -} - -static const Node* process_node(Context* ctx, const Node* node) { - if (node == NULL) return NULL; - - const Node* already_done = search_processed(&ctx->rewriter, node); - if (already_done) - return already_done; - - IrArena* a = ctx->rewriter.dst_arena; - - // add a builder to each abstraction... 
- switch (node->tag) { - case Function_TAG: { - Node* new = recreate_decl_header_identity(&ctx->rewriter, node); - BodyBuilder* bb = begin_body(a); - Context ctx2 = *ctx; - ctx2.bb = bb; - ctx2.rewriter.rewrite_fn = (RewriteNodeFn) process_node; - - new->payload.fun.body = finish_body(bb, rewrite_node(&ctx2.rewriter, node->payload.fun.body)); - return new; - } - case BasicBlock_TAG: { - Node* new = basic_block(a, (Node*) rewrite_node(&ctx->rewriter, node->payload.basic_block.fn), recreate_variables(&ctx->rewriter, node->payload.basic_block.params), node->payload.basic_block.name); - register_processed(&ctx->rewriter, node, new); - register_processed_list(&ctx->rewriter, node->payload.basic_block.params, new->payload.basic_block.params); - BodyBuilder* bb = begin_body(a); - Context ctx2 = *ctx; - ctx2.bb = bb; - ctx2.rewriter.rewrite_fn = (RewriteNodeFn) process_node; - new->payload.basic_block.body = finish_body(bb, rewrite_node(&ctx2.rewriter, node->payload.basic_block.body)); - return new; - } - case Case_TAG: { - Nodes new_params = recreate_variables(&ctx->rewriter, node->payload.case_.params); - register_processed_list(&ctx->rewriter, node->payload.case_.params, new_params); - BodyBuilder* bb = begin_body(a); - Context ctx2 = *ctx; - ctx2.bb = bb; - ctx2.rewriter.rewrite_fn = (RewriteNodeFn) process_node; - - const Node* new_body = finish_body(bb, rewrite_node(&ctx2.rewriter, node->payload.case_.body)); - return case_(a, new_params, new_body); - } - default: break; - } - - return recreate_node_identity(&ctx->rewriter, node); -} - -Module* normalize(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - aconfig.check_op_classes = true; - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) NULL), - .bb = NULL, - }; - - ctx.rewriter.config.search_map = false; - 
ctx.rewriter.config.write_map = false; - ctx.rewriter.rewrite_op_fn = (RewriteOpFn) process_op; - - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - return dst; -} diff --git a/src/shady/passes/normalize_builtins.c b/src/shady/passes/normalize_builtins.c index cd8e3edc9..9a20111e5 100644 --- a/src/shady/passes/normalize_builtins.c +++ b/src/shady/passes/normalize_builtins.c @@ -1,15 +1,14 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/builtin.h" + +#include "../ir_private.h" #include "log.h" #include "portability.h" -#include "../ir_private.h" -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" - typedef struct { Rewriter rewriter; + Node** builtins; } Context; static const Type* get_req_cast(Context* ctx, const Node* src) { @@ -18,45 +17,33 @@ static const Type* get_req_cast(Context* ctx, const Node* src) { switch (src->tag) { case GlobalVariable_TAG: { GlobalVariable global_variable = src->payload.global_variable; - const Node* ba = lookup_annotation_list(global_variable.annotations, "Builtin"); + const Node* ba = shd_lookup_annotation_list(global_variable.annotations, "Builtin"); if (ba) { - Builtin b = get_builtin_by_name(get_annotation_string_payload(ba)); + Builtin b = shd_get_builtin_by_name(shd_get_annotation_string_payload(ba)); assert(b != BuiltinsCount); - const Type* expected_t = get_builtin_type(a, b); - const Type* actual_t = rewrite_node(&ctx->rewriter, src)->payload.global_variable.type; + const Type* expected_t = shd_get_builtin_type(a, b); + const Type* actual_t = shd_rewrite_node(&ctx->rewriter, src)->payload.global_variable.type; if (expected_t != actual_t) { - log_string(INFO, "normalize_builtins: found builtin decl '%s' not matching expected type: '", global_variable.name); - log_node(INFO, expected_t); - log_string(INFO, "', got '"); - log_node(INFO, actual_t); - log_string(INFO, "'."); + shd_log_fmt(INFO, "normalize_builtins: found builtin decl '%s' not matching 
expected type: '", global_variable.name); + shd_log_node(INFO, expected_t); + shd_log_fmt(INFO, "', got '"); + shd_log_node(INFO, actual_t); + shd_log_fmt(INFO, "'."); return actual_t; } } break; } case RefDecl_TAG: return get_req_cast(ctx, src->payload.ref_decl.decl); - case Variable_TAG: { - const Node* abs = src->payload.var.abs; - if (abs) { - const Node* construct = abs->payload.case_.structured_construct; - if (construct && construct->tag == Let_TAG) { - return get_req_cast(ctx, construct->payload.let.instruction); - } + case PtrCompositeElement_TAG: { + const Type* src_req_cast = get_req_cast(ctx, src->payload.ptr_composite_element.ptr); + if (src_req_cast) { + bool u = shd_deconstruct_qualified_type(&src_req_cast); + shd_enter_composite_type(&src_req_cast, &u, src->payload.ptr_composite_element.index, false); + return src_req_cast; } break; } - case PrimOp_TAG: { - PrimOp prim_op = src->payload.prim_op; - if (prim_op.op == lea_op) { - const Type* src_req_cast = get_req_cast(ctx, first(prim_op.operands)); - if (src_req_cast) { - bool u = deconstruct_qualified_type(&src_req_cast); - enter_composite(&src_req_cast, &u, nodes(a, prim_op.operands.count - 2, &prim_op.operands.nodes[2]), false); - return src_req_cast; - } - } - } default: break; } @@ -64,58 +51,57 @@ static const Type* get_req_cast(Context* ctx, const Node* src) { } static const Node* process(Context* ctx, const Node* node) { - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; switch (node->tag) { case GlobalVariable_TAG: { GlobalVariable global_variable = node->payload.global_variable; - const Node* ba = lookup_annotation_list(global_variable.annotations, "Builtin"); + const Node* ba = shd_lookup_annotation_list(global_variable.annotations, "Builtin"); if (ba) { - Builtin b = get_builtin_by_name(get_annotation_string_payload(ba)); + Builtin b = 
shd_get_builtin_by_name(shd_get_annotation_string_payload(ba)); assert(b != BuiltinsCount); - const Type* t = get_builtin_type(a, b); - Node* ndecl = global_var(ctx->rewriter.dst_module, rewrite_nodes(&ctx->rewriter, global_variable.annotations), t, global_variable.name, global_variable.address_space); - register_processed(&ctx->rewriter, node, ndecl); + if (ctx->builtins[b]) + return ctx->builtins[b]; + const Type* t = shd_get_builtin_type(a, b); + Node* ndecl = global_var(r->dst_module, shd_rewrite_nodes(r, global_variable.annotations), t, global_variable.name, + shd_get_builtin_address_space(b)); + shd_register_processed(r, node, ndecl); // no 'init' for builtins, right ? assert(!global_variable.init); + ctx->builtins[b] = ndecl; return ndecl; } + break; } - case PrimOp_TAG: { - Op op = node->payload.prim_op.op; - switch (op) { - case load_op: { - const Type* req_cast = get_req_cast(ctx, first(node->payload.prim_op.operands)); - if (req_cast) { - assert(is_data_type(req_cast)); - BodyBuilder* bb = begin_body(a); - const Node* r = first(bind_instruction(bb, recreate_node_identity(&ctx->rewriter, node))); - const Node* r2 = first(gen_primop(bb, reinterpret_op, singleton(req_cast), singleton(r))); - return yield_values_and_wrap_in_block(bb, singleton(r2)); - } - } - default: break; + case Load_TAG: { + const Type* req_cast = get_req_cast(ctx, node->payload.load.ptr); + if (req_cast) { + assert(shd_is_data_type(req_cast)); + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, node->payload.load.mem)); + const Node* r1 = shd_bld_add_instruction(bb, shd_recreate_node(r, node)); + const Node* r2 = prim_op_helper(a, reinterpret_op, shd_singleton(req_cast), shd_singleton(r1)); + return shd_bld_to_instr_yield_values(bb, shd_singleton(r2)); } break; } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } -Module* normalize_builtins(SHADY_UNUSED const CompilerConfig* config, Module* src) { - 
ArenaConfig aconfig = get_arena_config(get_module_arena(src)); +Module* shd_pass_normalize_builtins(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); aconfig.validate_builtin_types = true; - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .builtins = calloc(sizeof(Node*), BuiltinsCount) }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + free(ctx.builtins); return dst; } diff --git a/src/shady/passes/opt_demote_alloca.c b/src/shady/passes/opt_demote_alloca.c new file mode 100644 index 000000000..b70d0adf2 --- /dev/null +++ b/src/shady/passes/opt_demote_alloca.c @@ -0,0 +1,253 @@ +#include "shady/pass.h" +#include "shady/visit.h" +#include "shady/ir/cast.h" + +#include "../ir_private.h" +#include "../check.h" +#include "../analysis/uses.h" + +#include "log.h" +#include "portability.h" +#include "dict.h" +#include "util.h" + +#include + +typedef struct Context_ { + Rewriter rewriter; + bool disable_lowering; + + const UsesMap* uses; + const CompilerConfig* config; + Arena* arena; + struct Dict* alloca_info; + bool* todo; +} Context; + +typedef struct { + const Type* type; + /// Set when the alloca is used in a way the analysis cannot follow + /// Allocation must be left alone in such cases! + bool leaks; + /// Set when the alloca is read from. 
+ bool read_from; + /// Set when the alloca is used in a manner forbidden by logical pointer rules + bool non_logical_use; + const Node* new; +} AllocaInfo; + +typedef struct { + AllocaInfo* src_alloca; +} PtrSourceKnowledge; + +static void visit_ptr_uses(const Node* ptr_value, const Type* slice_type, AllocaInfo* k, const UsesMap* map) { + const Type* ptr_type = ptr_value->type; + bool ptr_u = shd_deconstruct_qualified_type(&ptr_type); + assert(ptr_type->tag == PtrType_TAG); + + const Use* use = shd_get_first_use(map, ptr_value); + for (;use; use = use->next_use) { + if (is_abstraction(use->user) && use->operand_class == NcParam) + continue; + if (use->operand_class == NcMem) + continue; + else if (use->user->tag == Load_TAG) { + //if (get_pointer_type_element(ptr_type) != slice_type) + // k->reinterpreted = true; + k->read_from = true; + continue; // loads don't leak the address. + } else if (use->user->tag == Store_TAG) { + //if (get_pointer_type_element(ptr_type) != slice_type) + // k->reinterpreted = true; + // stores leak the value if it's stored + if (ptr_value == use->user->payload.store.value) + k->leaks = true; + continue; + } else if (use->user->tag == PrimOp_TAG) { + PrimOp payload = use->user->payload.prim_op; + switch (payload.op) { + case reinterpret_op: { + k->non_logical_use = true; + visit_ptr_uses(use->user, slice_type, k, map); + continue; + } + case convert_op: { + if (shd_first(payload.type_arguments)->tag == PtrType_TAG) { + k->non_logical_use = true; + visit_ptr_uses(use->user, slice_type, k, map); + } else { + k->leaks = true; + } + continue; + } + default: break; + } + if (shd_has_primop_got_side_effects(payload.op)) + k->leaks = true; + } /*else if (use->user->tag == Lea_TAG) { + // TODO: follow where those derived pointers are used and establish whether they leak themselves + // use slice_type to keep track of the expected type for the relevant sub-object + k->leaks = true; + continue; + } */else if (use->user->tag == Composite_TAG) { + 
// todo... + // note: if a composite literal containing our POI (pointer-of-interest) is extracted from, folding ops simplify this to the original POI + // so we don't need to be so clever here I think + k->leaks = true; + } else { + k->leaks = true; + } + } +} + +static PtrSourceKnowledge get_ptr_source_knowledge(Context* ctx, const Node* ptr) { + PtrSourceKnowledge k = { 0 }; + while (ptr) { + assert(is_value(ptr)); + const Node* instr = ptr; + switch (instr->tag) { + case StackAlloc_TAG: + case LocalAlloc_TAG: { + k.src_alloca = *shd_dict_find_value(const Node*, AllocaInfo*, ctx->alloca_info, instr); + return k; + } + case PrimOp_TAG: { + PrimOp payload = instr->payload.prim_op; + switch (payload.op) { + case convert_op: + case reinterpret_op: { + ptr = shd_first(payload.operands); + continue; + } + // TODO: lea and co + default: + break; + } + } + default: break; + } + + ptr = NULL; + } + return k; +} + +static const Node* handle_alloc(Context* ctx, const Node* old, const Type* old_type) { + IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + + const Node* nmem = shd_rewrite_node(r, old->tag == StackAlloc_TAG ? old->payload.stack_alloc.mem : old->payload.local_alloc.mem); + + AllocaInfo* k = shd_arena_alloc(ctx->arena, sizeof(AllocaInfo)); + *k = (AllocaInfo) { .type = shd_rewrite_node(r, old_type) }; + assert(ctx->uses); + visit_ptr_uses(old, old_type, k, ctx->uses); + shd_dict_insert(const Node*, AllocaInfo*, ctx->alloca_info, old, k); + // debugv_print("demote_alloca: uses analysis results for "); + // log_node(DEBUGV, old); + // debugv_print(": leaks=%d read_from=%d non_logical_use=%d\n", k->leaks, k->read_from, k->non_logical_use); + if (!k->leaks) { + if (!k->read_from/* this should include killing dead stores! 
*/) { + *ctx->todo |= true; + const Node* new = undef(a, (Undef) { .type = shd_get_unqualified_type(shd_rewrite_node(r, old->type)) }); + new = mem_and_value(a, (MemAndValue) { .value = new, .mem = nmem }); + k->new = new; + return new; + } + if (!k->non_logical_use && shd_get_arena_config(a)->optimisations.weaken_non_leaking_allocas) { + *ctx->todo |= old->tag != LocalAlloc_TAG; + const Node* new = local_alloc(a, (LocalAlloc) { .type = shd_rewrite_node(r, old_type), .mem = nmem }); + k->new = new; + return new; + } + } + const Node* new = shd_recreate_node(r, old); + k->new = new; + return new; +} + +static const Node* process(Context* ctx, const Node* old) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + + switch (old->tag) { + case Function_TAG: { + Node* fun = shd_recreate_node_head(&ctx->rewriter, old); + Context fun_ctx = *ctx; + fun_ctx.uses = shd_new_uses_map_fn(old, (NcDeclaration | NcType)); + fun_ctx.disable_lowering = shd_lookup_annotation_with_string_payload(old, "DisableOpt", "demote_alloca"); + if (old->payload.fun.body) + shd_set_abstraction_body(fun, shd_rewrite_node(&fun_ctx.rewriter, old->payload.fun.body)); + shd_destroy_uses_map(fun_ctx.uses); + return fun; + } + case Constant_TAG: { + Context fun_ctx = *ctx; + fun_ctx.uses = NULL; + return shd_recreate_node(&fun_ctx.rewriter, old); + } + case Load_TAG: { + Load payload = old->payload.load; + shd_rewrite_node(r, payload.mem); + PtrSourceKnowledge k = get_ptr_source_knowledge(ctx, payload.ptr); + if (k.src_alloca) { + const Type* access_type = shd_get_pointer_type_element(shd_get_unqualified_type(shd_rewrite_node(r, payload.ptr->type))); + if (shd_is_reinterpret_cast_legal(access_type, k.src_alloca->type)) { + if (k.src_alloca->new == shd_rewrite_node(r, payload.ptr)) + break; + *ctx->todo |= true; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Node* data = shd_bld_load(bb, k.src_alloca->new); + data = shd_bld_reinterpret_cast(bb, access_type, 
data); + return shd_bld_to_instr_yield_value(bb, data); + } + } + break; + } + case Store_TAG: { + Store payload = old->payload.store; + shd_rewrite_node(r, payload.mem); + PtrSourceKnowledge k = get_ptr_source_knowledge(ctx, payload.ptr); + if (k.src_alloca) { + const Type* access_type = shd_get_pointer_type_element(shd_get_unqualified_type(shd_rewrite_node(r, payload.ptr->type))); + if (shd_is_reinterpret_cast_legal(access_type, k.src_alloca->type)) { + if (k.src_alloca->new == shd_rewrite_node(r, payload.ptr)) + break; + *ctx->todo |= true; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + const Node* data = shd_bld_reinterpret_cast(bb, access_type, shd_rewrite_node(r, payload.value)); + shd_bld_store(bb, k.src_alloca->new, data); + return shd_bld_to_instr_yield_values(bb, shd_empty(a)); + } + } + break; + } + case LocalAlloc_TAG: return handle_alloc(ctx, old, old->payload.local_alloc.type); + case StackAlloc_TAG: return handle_alloc(ctx, old, old->payload.stack_alloc.type); + default: break; + } + return shd_recreate_node(&ctx->rewriter, old); +} + +KeyHash shd_hash_node(const Node**); +bool shd_compare_node(const Node**, const Node**); + +bool shd_opt_demote_alloca(SHADY_UNUSED const CompilerConfig* config, Module** m) { + bool todo = false; + Module* src = *m; + IrArena* a = shd_module_get_arena(src); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .config = config, + .arena = shd_new_arena(), + .alloca_info = shd_new_dict(const Node*, AllocaInfo*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .todo = &todo + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_dict(ctx.alloca_info); + shd_destroy_arena(ctx.arena); + *m = dst; + return todo; +} diff --git a/src/shady/passes/opt_inline.c b/src/shady/passes/opt_inline.c index 26fda811b..6cb4e099d 100644 --- 
a/src/shady/passes/opt_inline.c +++ b/src/shady/passes/opt_inline.c @@ -1,26 +1,25 @@ -#include "passes.h" +#include "shady/pass.h" + +#include "../ir_private.h" +#include "../analysis/callgraph.h" #include "dict.h" -#include "list.h" #include "portability.h" #include "util.h" #include "log.h" -#include "../rewrite.h" -#include "../type.h" -#include "../ir_private.h" - -#include "../analysis/scope.h" -#include "../analysis/callgraph.h" +typedef struct { + const Node* host_fn; + const Node* return_jp; +} InlinedCall; typedef struct { Rewriter rewriter; - Scope* scope; + const CompilerConfig* config; CallGraph* graph; const Node* old_fun; Node* fun; - bool allow_fn_inlining; - struct Dict* inlined_return_sites; + InlinedCall* inlined_call; } Context; static const Node* ignore_immediate_fn_addr(const Node* node) { @@ -31,9 +30,21 @@ static const Node* ignore_immediate_fn_addr(const Node* node) { } static bool is_call_potentially_inlineable(const Node* src_fn, const Node* dst_fn) { - if (lookup_annotation(src_fn, "Leaf")) + if (shd_lookup_annotation(src_fn, "Internal")) + return false; + if (shd_lookup_annotation(dst_fn, "NoInline")) + return false; + if (!dst_fn->payload.fun.body) return false; - if (lookup_annotation(dst_fn, "NoInline")) + return true; +} + +static bool is_call_safely_removable(const Node* fn) { + if (shd_lookup_annotation(fn, "Internal")) + return false; + if (shd_lookup_annotation(fn, "EntryPoint")) + return false; + if (shd_lookup_annotation(fn, "Exported")) return false; return true; } @@ -45,21 +56,19 @@ typedef struct { bool can_be_eliminated; } FnInliningCriteria; -static FnInliningCriteria get_inlining_heuristic(CGNode* fn_node) { +static FnInliningCriteria get_inlining_heuristic(const CompilerConfig* config, CGNode* fn_node) { FnInliningCriteria crit = { 0 }; CGEdge e; size_t i = 0; - while (dict_iter(fn_node->callers, &i, &e, NULL)) { + while (shd_dict_iter(fn_node->callers, &i, &e, NULL)) { crit.num_calls++; if 
(is_call_potentially_inlineable(e.src_fn->fn, e.dst_fn->fn)) crit.num_inlineable_calls++; } - debugv_print("%s has %d callers\n", get_abstraction_name(fn_node->fn), crit.num_calls); - // a function can be inlined if it has exactly one inlineable call... - if (crit.num_inlineable_calls == 1) + if (crit.num_inlineable_calls <= 1 || config->optimisations.inline_everything) crit.can_be_inlined = true; // avoid inlining recursive things for now @@ -74,29 +83,44 @@ static FnInliningCriteria get_inlining_heuristic(CGNode* fn_node) { if (fn_node->is_address_captured) crit.can_be_eliminated = false; + if (!is_call_safely_removable(fn_node->fn)) + crit.can_be_eliminated = false; + + shd_debugv_print("inlining heuristic for '%s': num_calls=%d num_inlineable_calls=%d safely_removable=%d address_leaks=%d recursive=%d inlineable=%d can_be_eliminated=%d\n", + shd_get_abstraction_name(fn_node->fn), + crit.num_calls, + crit.num_inlineable_calls, + is_call_safely_removable(fn_node->fn), + fn_node->is_address_captured, + fn_node->is_recursive, + crit.can_be_inlined, + crit.can_be_eliminated); + return crit; } /// inlines the abstraction with supplied arguments -static const Node* inline_call(Context* ctx, const Node* oabs, Nodes nargs, bool separate_scope) { - assert(is_abstraction(oabs)); +static const Node* inline_call(Context* ctx, const Node* ocallee, const Node* nmem, Nodes nargs, const Node* return_to) { + assert(is_abstraction(ocallee)); + shd_log_fmt(DEBUG, "Inlining '%s' inside '%s'\n", shd_get_abstraction_name(ocallee), shd_get_abstraction_name(ctx->fun)); Context inline_context = *ctx; - if (separate_scope) - inline_context.rewriter.map = clone_dict(inline_context.rewriter.map); - Nodes oparams = get_abstraction_params(oabs); - register_processed_list(&inline_context.rewriter, oparams, nargs); + inline_context.rewriter.map = shd_clone_dict(inline_context.rewriter.map); - if (oabs->tag == Function_TAG) - inline_context.scope = new_scope(oabs); + ctx = &inline_context; + 
InlinedCall inlined_call = { + .host_fn = ctx->fun, + .return_jp = return_to, + }; + inline_context.inlined_call = &inlined_call; - const Node* nbody = rewrite_node(&inline_context.rewriter, get_abstraction_body(oabs)); + Nodes oparams = get_abstraction_params(ocallee); + shd_register_processed_list(&inline_context.rewriter, oparams, nargs); + shd_register_processed(&inline_context.rewriter, shd_get_abstraction_mem(ocallee), nmem); - if (oabs->tag == Function_TAG) - destroy_scope(inline_context.scope); + const Node* nbody = shd_rewrite_node(&inline_context.rewriter, get_abstraction_body(ocallee)); - if (separate_scope) - destroy_dict(inline_context.rewriter.map); + shd_destroy_dict(inline_context.rewriter.map); assert(is_terminator(nbody)); return nbody; @@ -104,175 +128,122 @@ static const Node* inline_call(Context* ctx, const Node* oabs, Nodes nargs, bool static const Node* process(Context* ctx, const Node* node) { IrArena* a = ctx->rewriter.dst_arena; - if (!node) return NULL; + Rewriter* r = &ctx->rewriter; assert(a != node->arena); assert(node->arena == ctx->rewriter.src_arena); - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - switch (node->tag) { case Function_TAG: { if (ctx->graph) { - CGNode* fn_node = *find_value_dict(const Node*, CGNode*, ctx->graph->fn2cgn, node); - if (get_inlining_heuristic(fn_node).can_be_eliminated) { - debugv_print("Eliminating %s because it has exactly one caller\n", get_abstraction_name(fn_node->fn)); + CGNode* fn_node = *shd_dict_find_value(const Node*, CGNode*, ctx->graph->fn2cgn, node); + if (get_inlining_heuristic(ctx->config, fn_node).can_be_eliminated) { + shd_debugv_print("Eliminating %s because it has exactly one caller\n", shd_get_abstraction_name(fn_node->fn)); return NULL; } } - Nodes annotations = rewrite_nodes(&ctx->rewriter, node->payload.fun.annotations); - Node* new = function(ctx->rewriter.dst_module, recreate_variables(&ctx->rewriter, node->payload.fun.params), 
node->payload.fun.name, annotations, rewrite_nodes(&ctx->rewriter, node->payload.fun.return_types)); - for (size_t i = 0; i < new->payload.fun.params.count; i++) - register_processed(&ctx->rewriter, node->payload.fun.params.nodes[i], new->payload.fun.params.nodes[i]); - register_processed(&ctx->rewriter, node, new); + Nodes annotations = shd_rewrite_nodes(&ctx->rewriter, node->payload.fun.annotations); + Node* new = function(ctx->rewriter.dst_module, shd_recreate_params(&ctx->rewriter, node->payload.fun.params), node->payload.fun.name, annotations, shd_rewrite_nodes(&ctx->rewriter, node->payload.fun.return_types)); + shd_register_processed(r, node, new); Context fn_ctx = *ctx; - Scope* scope = new_scope(node); - fn_ctx.rewriter.map = clone_dict(fn_ctx.rewriter.map); - fn_ctx.scope = scope; + fn_ctx.rewriter.map = shd_clone_dict(fn_ctx.rewriter.map); fn_ctx.old_fun = node; fn_ctx.fun = new; - recreate_decl_body_identity(&fn_ctx.rewriter, node, new); - destroy_dict(fn_ctx.rewriter.map); - destroy_scope(scope); + fn_ctx.inlined_call = NULL; + for (size_t i = 0; i < new->payload.fun.params.count; i++) + shd_register_processed(&fn_ctx.rewriter, node->payload.fun.params.nodes[i], new->payload.fun.params.nodes[i]); + shd_recreate_node_body(&fn_ctx.rewriter, node, new); + shd_destroy_dict(fn_ctx.rewriter.map); return new; } - case Jump_TAG: { - const Node* otarget = node->payload.jump.target; - assert(otarget && otarget->tag == BasicBlock_TAG); - assert(otarget->payload.basic_block.fn == ctx->scope->entry->node); - CFNode* cfnode = scope_lookup(ctx->scope, otarget); - assert(cfnode); - size_t preds_count = entries_count_list(cfnode->pred_edges); - assert(preds_count > 0 && "this CFG looks broken"); - if (preds_count == 1) { - debugv_print("Inlining jump to %s inside function %s\n", get_abstraction_name(otarget), get_abstraction_name(ctx->old_fun)); - Nodes nargs = rewrite_nodes(&ctx->rewriter, node->payload.jump.args); - return inline_call(ctx, otarget, nargs, false); - } 
- break; - } - // do not inline jumps in branches - case Branch_TAG: { - return branch(a, (Branch) { - .branch_condition = rewrite_node(&ctx->rewriter, node->payload.branch.branch_condition), - .true_jump = recreate_node_identity(&ctx->rewriter, node->payload.branch.true_jump), - .false_jump = recreate_node_identity(&ctx->rewriter, node->payload.branch.false_jump), - }); - } - case Switch_TAG: { - return br_switch(a, (Switch) { - .switch_value = rewrite_node(&ctx->rewriter, node->payload.br_switch.switch_value), - .case_values = rewrite_nodes(&ctx->rewriter, node->payload.br_switch.case_values), - .case_jumps = rewrite_nodes_with_fn(&ctx->rewriter, node->payload.br_switch.case_jumps, recreate_node_identity), - .default_jump = recreate_node_identity(&ctx->rewriter, node->payload.br_switch.default_jump), - }); - } case Call_TAG: { if (!ctx->graph) break; - const Node* ocallee = node->payload.call.callee; - Nodes oargs = node->payload.call.args; + Call payload = node->payload.call; + const Node* ocallee = payload.callee; ocallee = ignore_immediate_fn_addr(ocallee); if (ocallee->tag == Function_TAG) { - CGNode* fn_node = *find_value_dict(const Node*, CGNode*, ctx->graph->fn2cgn, ocallee); - if (get_inlining_heuristic(fn_node).can_be_inlined && is_call_potentially_inlineable(ctx->old_fun, ocallee)) { - debugv_print("Inlining call to %s\n", get_abstraction_name(ocallee)); - Nodes nargs = rewrite_nodes(&ctx->rewriter, oargs); + CGNode* fn_node = *shd_dict_find_value(const Node*, CGNode*, ctx->graph->fn2cgn, ocallee); + if (get_inlining_heuristic(ctx->config, fn_node).can_be_inlined && is_call_potentially_inlineable(ctx->old_fun, ocallee)) { + shd_debugv_print("Inlining call to %s\n", shd_get_abstraction_name(ocallee)); + Nodes nargs = shd_rewrite_nodes(&ctx->rewriter, payload.args); // Prepare a join point to replace the old function return - Nodes nyield_types = strip_qualifiers(a, rewrite_nodes(&ctx->rewriter, ocallee->payload.fun.return_types)); + Nodes nyield_types = 
shd_strip_qualifiers(a, shd_rewrite_nodes(&ctx->rewriter, ocallee->payload.fun.return_types)); const Type* jp_type = join_point_type(a, (JoinPointType) { .yield_types = nyield_types }); - const Node* join_point = var(a, qualified_type_helper(jp_type, true), format_string_arena(a->arena, "inlined_return_%s", get_abstraction_name(ocallee))); - insert_dict_and_get_result(const Node*, const Node*, ctx->inlined_return_sites, ocallee, join_point); + const Node* join_point = param(a, shd_as_qualified_type(jp_type, true), shd_format_string_arena(a->arena, "inlined_return_%s", shd_get_abstraction_name(ocallee))); - const Node* nbody = inline_call(ctx, ocallee, nargs, true); + Node* control_case = case_(a, shd_singleton(join_point)); + const Node* nbody = inline_call(ctx, ocallee, shd_get_abstraction_mem(control_case), nargs, join_point); + shd_set_abstraction_body(control_case, nbody); - remove_dict(const Node*, ctx->inlined_return_sites, ocallee); - - return control(a, (Control) { - .yield_types = nyield_types, - .inside = case_(a, singleton(join_point), nbody) - }); + BodyBuilder* bb = shd_bld_begin_pseudo_instr(a, shd_rewrite_node(r, payload.mem)); + return shd_bld_to_instr_yield_values(bb, shd_bld_control(bb, nyield_types, control_case)); } } break; } + /*case BasicBlock_TAG: { + Nodes nparams = recreate_params(r, get_abstraction_params(node)); + register_processed_list(r, get_abstraction_params(node), nparams); + Node* bb = basic_block(a, nparams, get_abstraction_name(node)); + register_processed(r, node, bb); + set_abstraction_body(bb, rewrite_node(r, get_abstraction_body(node))); + return bb; + }*/ case Return_TAG: { - const Node** p_ret_jp = find_value_dict(const Node*, const Node*, ctx->inlined_return_sites, node); - if (p_ret_jp) - return join(a, (Join) { .join_point = *p_ret_jp, .args = rewrite_nodes(&ctx->rewriter, node->payload.fn_ret.args )}); + Return payload = node->payload.fn_ret; + if (ctx->inlined_call) + return join(a, (Join) { .mem = shd_rewrite_node(r, 
payload.mem), .join_point = ctx->inlined_call->return_jp, .args = shd_rewrite_nodes(r, payload.args)}); break; } case TailCall_TAG: { if (!ctx->graph) break; - const Node* ocallee = node->payload.tail_call.target; + const Node* ocallee = node->payload.tail_call.callee; ocallee = ignore_immediate_fn_addr(ocallee); if (ocallee->tag == Function_TAG) { - CGNode* fn_node = *find_value_dict(const Node*, CGNode*, ctx->graph->fn2cgn, ocallee); - if (get_inlining_heuristic(fn_node).can_be_inlined) { - debugv_print("Inlining tail call to %s\n", get_abstraction_name(ocallee)); - Nodes nargs = rewrite_nodes(&ctx->rewriter, node->payload.tail_call.args); - - return inline_call(ctx, ocallee, nargs, true); + CGNode* fn_node = *shd_dict_find_value(const Node*, CGNode*, ctx->graph->fn2cgn, ocallee); + if (get_inlining_heuristic(ctx->config, fn_node).can_be_inlined) { + shd_debugv_print("Inlining tail call to %s\n", shd_get_abstraction_name(ocallee)); + Nodes nargs = shd_rewrite_nodes(&ctx->rewriter, node->payload.tail_call.args); + return inline_call(ctx, ocallee, shd_rewrite_node(r, node->payload.tail_call.mem), nargs, NULL); } } break; } - case BasicBlock_TAG: { - Nodes params = recreate_variables(&ctx->rewriter, node->payload.basic_block.params); - register_processed_list(&ctx->rewriter, node->payload.basic_block.params, params); - Node* bb = basic_block(a, (Node*) ctx->fun, params, node->payload.basic_block.name); - register_processed(&ctx->rewriter, node, bb); - bb->payload.basic_block.body = process(ctx, node->payload.basic_block.body); - return bb; - } default: break; } - const Node* new = recreate_node_identity(&ctx->rewriter, node); - if (node->tag == Case_TAG) - register_processed(&ctx->rewriter, node, new); - return new; + return shd_recreate_node(&ctx->rewriter, node); } -KeyHash hash_node(const Node**); -bool compare_node(const Node**, const Node**); +KeyHash shd_hash_node(const Node**); +bool shd_compare_node(const Node**, const Node**); -void 
opt_simplify_cf(SHADY_UNUSED const CompilerConfig* config, Module* src, Module* dst, bool allow_fn_inlining) { +static void simplify_cf(const CompilerConfig* config, Module* src, Module* dst) { Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .config = config, .graph = NULL, - .scope = NULL, .fun = NULL, - .inlined_return_sites = new_dict(const Node*, CGNode*, (HashFn) hash_node, (CmpFn) compare_node), + .inlined_call = NULL, }; - if (allow_fn_inlining) - ctx.graph = new_callgraph(src); + ctx.graph = shd_new_callgraph(src); - rewrite_module(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); if (ctx.graph) - destroy_callgraph(ctx.graph); - - destroy_rewriter(&ctx.rewriter); - destroy_dict(ctx.inlined_return_sites); -} + shd_destroy_callgraph(ctx.graph); -Module* opt_inline_jumps(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - opt_simplify_cf(config, src, dst, false); - return dst; + shd_destroy_rewriter(&ctx.rewriter); } -Module* opt_inline(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - opt_simplify_cf(config, src, dst, true); +Module* shd_pass_inline(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + simplify_cf(config, src, dst); return dst; } diff --git a/src/shady/passes/opt_mem2reg.c b/src/shady/passes/opt_mem2reg.c index 8ba377177..e0b178cb8 100644 --- a/src/shady/passes/opt_mem2reg.c +++ b/src/shady/passes/opt_mem2reg.c @@ -1,561 +1,134 @@ -#include "passes.h" +#include 
"shady/pass.h" -#include "portability.h" -#include "dict.h" -#include "arena.h" -#include "log.h" - -#include "../analysis/scope.h" -#include "../analysis/uses.h" -#include "../analysis/leak.h" -#include "../analysis/verify.h" - -#include "../transform/ir_gen_helpers.h" - -#include "../rewrite.h" -#include "../visit.h" -#include "../type.h" - -typedef struct { - AddressSpace as; - bool leaks; - bool read_from; - const Type* type; -} PtrSourceKnowledge; - -typedef struct { - const Node* ptr_address; - const Node* ptr_value; - bool ptr_has_leaked; - PtrSourceKnowledge* source; -} PtrKnowledge; - -typedef struct KB KnowledgeBase; +#include "../ir_private.h" +#include "../analysis/cfg.h" -struct KB { - CFNode* cfnode; - // when the associated node has exactly one parent edge, we can safely assume what held true - // for it will hold true for this one too, unless we have conflicting information - const KnowledgeBase* dominator_kb; - struct Dict* map; - struct Dict* potential_additional_params; - Arena* a; -}; +#include "list.h" +#include "portability.h" typedef struct { Rewriter rewriter; - Scope* scope; - const UsesMap* scope_uses; - struct Dict* abs_to_kb; - const Node* abs; - Arena* a; - - struct Dict* bb_new_args; + CFG* cfg; + bool* todo; } Context; -static PtrKnowledge* get_last_valid_ptr_knowledge(const KnowledgeBase* kb, const Node* n) { - PtrKnowledge** found = find_value_dict(const Node*, PtrKnowledge*, kb->map, n); - if (found) - return *found; - PtrKnowledge* k = NULL; - if (kb->dominator_kb) - k = get_last_valid_ptr_knowledge(kb->dominator_kb, n); - return k; -} - -static PtrKnowledge* create_ptr_knowledge(KnowledgeBase* kb, const Node* instruction, const Node* address_value) { - PtrKnowledge* k = arena_alloc(kb->a, sizeof(PtrKnowledge)); - PtrSourceKnowledge* sk = arena_alloc(kb->a, sizeof(PtrSourceKnowledge)); - *k = (PtrKnowledge) { .source = sk, .ptr_address = address_value }; - *sk = (PtrSourceKnowledge) { 0 }; - bool fresh = insert_dict(const Node*, 
PtrKnowledge*, kb->map, instruction, k); - assert(fresh); - return k; -} - -static PtrKnowledge* update_ptr_knowledge(KnowledgeBase* kb, const Node* n, PtrKnowledge* existing) { - PtrKnowledge* k = arena_alloc(kb->a, sizeof(PtrKnowledge)); - *k = *existing; // copy the data - bool fresh = insert_dict(const Node*, PtrKnowledge*, kb->map, n, k); - assert(fresh); - return k; -} - -static void insert_ptr_knowledge(KnowledgeBase* kb, const Node* n, PtrKnowledge* k) { - PtrKnowledge** found = find_value_dict(const Node*, PtrKnowledge*, kb->map, n); - assert(!found); - insert_dict(const Node*, PtrKnowledge*, kb->map, n, k); -} - -static const Node* get_known_value(Rewriter* r, const PtrKnowledge* k) { - const Node* v = NULL; - if (k && !k->ptr_has_leaked && !k->source->leaks) { - if (k->ptr_value) { - v = k->ptr_value; - } - } - if (r && v && v->arena != r->dst_arena) - return rewrite_node(r, v); - return v; -} - -static const Node* get_known_address(Rewriter* r, const PtrKnowledge* k) { - const Node* v = NULL; - if (k) { - if (k->ptr_address) { - v = k->ptr_address; - } - } - if (r && v && v->arena != r->dst_arena) - return rewrite_node(r, v); - return v; -} +typedef struct { + const Node* src; + Nodes indices; +}; -static void visit_ptr_uses(const Node* ptr_value, PtrSourceKnowledge* k, const UsesMap* map) { - const Use* use = get_first_use(map, ptr_value); - for (;use; use = use->next_use) { - if (is_abstraction(use->user) && use->operand_class == NcVariable) - continue; - else if (use->user->tag == Let_TAG && use->operand_class == NcInstruction) { - Nodes vars = get_abstraction_params(get_let_tail(use->user)); - for (size_t i = 0; i < vars.count; i++) { - debugv_print("mem2reg leak analysis: following let-bound variable: "); - log_node(DEBUGV, vars.nodes[i]); - debugv_print(".\n"); - visit_ptr_uses(vars.nodes[i], k, map); +static const Node* get_ptr_source(const Node* ptr) { + IrArena* a = ptr->arena; + while (true) { + switch (ptr->tag) { + case 
PtrCompositeElement_TAG: { + PtrCompositeElement payload = ptr->payload.ptr_composite_element; + ptr = payload.ptr; + break; } - } else if (use->user->tag == PrimOp_TAG) { - PrimOp payload = use->user->payload.prim_op; - switch (payload.op) { - case load_op: { - k->read_from = true; - continue; // loads don't leak the address. - } - case store_op: { - // stores leak the value if it's stored - if (ptr_value == payload.operands.nodes[1]) - k->leaks = true; - continue; - } - case reinterpret_op: { - debugv_print("mem2reg leak analysis: following reinterpret instr: "); - log_node(DEBUGV, use->user); - debugv_print(".\n"); - visit_ptr_uses(use->user, k, map); - continue; - } - case lea_op: - case convert_op: { - //TODO: follow where those derived pointers are used and establish whether they leak themselves - k->leaks = true; - continue; - } default: break; + case PtrArrayElementOffset_TAG: { + PtrArrayElementOffset payload = ptr->payload.ptr_array_element_offset; + ptr = payload.ptr; + break; } - if (has_primop_got_side_effects(payload.op)) - k->leaks = true; - } else if (use->user->tag == Composite_TAG) { - // todo... 
- k->leaks = true; - } else { - k->leaks = true; - } - } -} - -static void visit_instruction(Context* ctx, KnowledgeBase* kb, const Node* instruction, Nodes results) { - IrArena* a = instruction->arena; - switch (is_instruction(instruction)) { - case NotAnInstruction: assert(is_instruction(instruction)); - case Instruction_Call_TAG: - break; - case Instruction_PrimOp_TAG: { - PrimOp payload = instruction->payload.prim_op; - switch (payload.op) { - case alloca_logical_op: - case alloca_op: { - const Node* optr = first(results); - PtrKnowledge* k = create_ptr_knowledge(kb, instruction, optr); - visit_ptr_uses(optr, k->source, ctx->scope_uses); - debugv_print("mem2reg: "); - log_node(DEBUGV, optr); - if (k->source->leaks) - debugv_print(" is leaking so it will not be eliminated.\n"); - else - debugv_print(" was found to not leak.\n"); - const Type* t = instruction->type; - bool u = deconstruct_qualified_type(&t); - assert(t->tag == PtrType_TAG); - k->source->as = t->payload.ptr_type.address_space; - deconstruct_pointer_type(&t); - k->source->type = qualified_type_helper(t, u); - - insert_ptr_knowledge(kb, optr, k); - k->ptr_value = undef(a, (Undef) { .type = first(payload.type_arguments) }); - break; - } - case load_op: { - const Node* ptr = first(payload.operands); - const PtrKnowledge* k = get_last_valid_ptr_knowledge(kb, ptr); - if (!k || !k->ptr_value) { - const KnowledgeBase* phi_kb = kb; - while (phi_kb->dominator_kb) { - phi_kb = phi_kb->dominator_kb; + case PrimOp_TAG: { + PrimOp payload = ptr->payload.prim_op; + switch (payload.op) { + case reinterpret_op: + case convert_op: { + const Node* src = shd_first(payload.operands); + if (shd_get_unqualified_type(src->type)->tag == PtrType_TAG) { + ptr = src; + continue; } - debug_print("mem2reg: It'd sure be nice to know the value of "); - log_node(DEBUG, first(payload.operands)); - debug_print(" at phi-like node %s.\n", get_abstraction_name(phi_kb->cfnode->node)); - // log_node(DEBUG, phi_location->node); - 
insert_set_get_key(const Node*, phi_kb->potential_additional_params, ptr); - } - break; - } - case store_op: { - const Node* ptr = first(payload.operands); - PtrKnowledge* k = get_last_valid_ptr_knowledge(kb, ptr); - if (k) { - k = update_ptr_knowledge(kb, ptr, k); - k->ptr_value = payload.operands.nodes[1]; - } - break; // let's take care of dead stores another time - } - case reinterpret_op: { - // if we have knowledge on a particular ptr, the same knowledge propagates if we bitcast it! - PtrKnowledge* k = get_last_valid_ptr_knowledge(kb, first(payload.operands)); - if (k) { - k = update_ptr_knowledge(kb, instruction, k); - k->ptr_address = first(results); - insert_ptr_knowledge(kb, first(results), k); + break; } - break; + default: break; } - case convert_op: { - // if we convert a pointer to generic AS, we'd like to use the old address instead where possible - PtrKnowledge* k = get_last_valid_ptr_knowledge(kb, first(payload.operands)); - if (k) { - debug_print("mem2reg: the converted ptr "); - log_node(DEBUG, first(results)); - debug_print(" is the same as "); - log_node(DEBUG, first(payload.operands)); - debug_print(".\n"); - k = update_ptr_knowledge(kb, instruction, k); - k->ptr_address = first(payload.operands); - insert_ptr_knowledge(kb, first(results), k); - } - break; - } - default: break; + break; } - break; - } - case Instruction_Control_TAG: - break; - case Instruction_Block_TAG: - break; - case Instruction_Comment_TAG: - break; - case Instruction_If_TAG: - case Instruction_Match_TAG: - case Instruction_Loop_TAG: - assert(false && "unsupported"); - break; - } -} - -static void visit_terminator(Context* ctx, KnowledgeBase* kb, const Node* old) { - if (!old) - return; - switch (is_terminator(old)) { - case Terminator_LetMut_TAG: - case NotATerminator: assert(false); - case Terminator_Let_TAG: { - const Node* otail = get_let_tail(old); - visit_instruction(ctx, kb, get_let_instruction(old), get_abstraction_params(otail)); - break; + default: break; } - 
default: - break; - } -} - -KeyHash hash_node(const Node**); -bool compare_node(const Node**, const Node**); - -static void destroy_kb(KnowledgeBase* kb) { - destroy_dict(kb->map); - destroy_dict(kb->potential_additional_params); -} - -static KnowledgeBase* get_kb(Context* ctx, const Node* abs) { - KnowledgeBase** found = find_value_dict(const Node*, KnowledgeBase*, ctx->abs_to_kb, abs); - assert(found); - return *found; -} - -static void visit_cfnode(Context* ctx, CFNode* node, CFNode* dominator) { - const Node* oabs = node->node; - KnowledgeBase* kb = arena_alloc(ctx->a, sizeof(KnowledgeBase)); - *kb = (KnowledgeBase) { - .cfnode = node, - .a = ctx->a, - .map = new_dict(const Node*, PtrKnowledge*, (HashFn) hash_node, (CmpFn) compare_node), - .potential_additional_params = new_set(const Node*, (HashFn) hash_node, (CmpFn) compare_node), - .dominator_kb = NULL, - }; - if (entries_count_list(node->pred_edges) == 1) { - assert(dominator); - CFEdge edge = read_list(CFEdge, node->pred_edges)[0]; - assert(edge.dst == node); - assert(edge.src == dominator); - const KnowledgeBase* parent_kb = get_kb(ctx, dominator->node); - assert(parent_kb->map); - kb->dominator_kb = parent_kb; - } - assert(kb->map); - insert_dict(const Node*, KnowledgeBase*, ctx->abs_to_kb, node->node, kb); - assert(is_abstraction(oabs)); - visit_terminator(ctx, kb, get_abstraction_body(oabs)); - - for (size_t i = 0; i < entries_count_list(node->dominates); i++) { - CFNode* dominated = read_list(CFNode*, node->dominates)[i]; - visit_cfnode(ctx, dominated, node); + return ptr; } } -static const Node* process(Context* ctx, const Node* old) { - assert(old); - Context fn_ctx = *ctx; - if (old->tag == Function_TAG && !lookup_annotation(old, "Internal")) { - ctx = &fn_ctx; - fn_ctx.scope = new_scope(old); - fn_ctx.scope_uses = create_uses_map(old, (NcDeclaration | NcType)); - fn_ctx.abs_to_kb = new_dict(const Node*, KnowledgeBase**, (HashFn) hash_node, (CmpFn) compare_node); - visit_cfnode(&fn_ctx, 
fn_ctx.scope->entry, NULL); - fn_ctx.abs = old; - const Node* new_fn = recreate_node_identity(&fn_ctx.rewriter, old); - destroy_scope(fn_ctx.scope); - destroy_uses_map(fn_ctx.scope_uses); - size_t i = 0; - KnowledgeBase* kb; - while (dict_iter(fn_ctx.abs_to_kb, &i, NULL, &kb)) { - destroy_kb(kb); - } - destroy_dict(fn_ctx.abs_to_kb); - return new_fn; - } else if (is_abstraction(old)) { - fn_ctx.abs = old; - ctx = &fn_ctx; - } - - KnowledgeBase* kb = NULL; - if (ctx->abs && ctx->abs_to_kb) { - kb = get_kb(ctx, ctx->abs); - assert(kb); - } - if (!kb) - return recreate_node_identity(&ctx->rewriter, old); - - IrArena* a = ctx->rewriter.dst_arena; - - switch (old->tag) { - case PrimOp_TAG: { - PrimOp payload = old->payload.prim_op; - switch (payload.op) { - case load_op: { - const Node* ptr = first(payload.operands); - PtrKnowledge* k = get_last_valid_ptr_knowledge(kb, ptr); - const Node* known_value = get_known_value(&ctx->rewriter, k); - if (known_value) { - const Type* known_value_t = known_value->type; - bool kv_u = deconstruct_qualified_type(&known_value_t); - - const Type* load_result_t = ptr->type; - bool lrt_u = deconstruct_qualified_type(&load_result_t); - deconstruct_pointer_type(&load_result_t); - assert(!lrt_u || kv_u); - if (is_reinterpret_cast_legal(load_result_t, known_value_t)) - return prim_op_helper(a, reinterpret_op, singleton(rewrite_node(&ctx->rewriter, load_result_t)), singleton(known_value)); - } - const Node* other_ptr = get_known_address(&ctx->rewriter, k); - if (other_ptr && ptr != other_ptr) { - return prim_op_helper(a, load_op, empty(a), singleton(other_ptr)); - } - break; - } - case store_op: { - const Node* ptr = first(payload.operands); - const PtrKnowledge* k = get_last_valid_ptr_knowledge(kb, ptr); - if (k && !k->source->leaks && !k->source->read_from) - return quote_helper(a, empty(a)); - const Node* other_ptr = get_known_address(&ctx->rewriter, k); - if (other_ptr && ptr != other_ptr) { - return prim_op_helper(a, store_op, empty(a), 
mk_nodes(a, other_ptr, rewrite_node(&ctx->rewriter, payload.operands.nodes[1]))); - } - break; - } - case alloca_op: { - const PtrKnowledge* k = get_last_valid_ptr_knowledge(kb, old); - if (k && !k->source->leaks && !k->source->read_from) - return quote_helper(a, singleton(undef(a, (Undef) { .type = get_unqualified_type(rewrite_node(&ctx->rewriter, old->type)) }))); - break; - } - default: break; - } - break; - } - case BasicBlock_TAG: { - CFNode* cfnode = scope_lookup(ctx->scope, old); - size_t i = 0; - const Node* ptr; - Nodes params = recreate_variables(&ctx->rewriter, get_abstraction_params(old)); - register_processed_list(&ctx->rewriter, get_abstraction_params(old), params); - Nodes ptrs = empty(ctx->rewriter.src_arena); - while (dict_iter(kb->potential_additional_params, &i, &ptr, NULL)) { - PtrSourceKnowledge* source = NULL; - PtrKnowledge uk = { 0 }; - // check if all the edges have a value for this! - for (size_t j = 0; j < entries_count_list(cfnode->pred_edges); j++) { - CFEdge edge = read_list(CFEdge, cfnode->pred_edges)[j]; - if (edge.type == StructuredPseudoExitEdge) - continue; // these are not real edges... 
- KnowledgeBase* kb_at_src = get_kb(ctx, edge.src->node); - - if (get_known_value(NULL, get_last_valid_ptr_knowledge(kb_at_src, ptr))) { - log_node(DEBUG, ptr); - debug_print(" has a known value in %s ...\n", get_abstraction_name(edge.src->node)); - } else - goto next_potential_param; - - PtrKnowledge* k = get_last_valid_ptr_knowledge(kb_at_src, ptr); - if (!source) - source = k->source; - else - assert(source == k->source); - - const Type* kv_type = get_known_value(NULL, get_last_valid_ptr_knowledge(kb_at_src, ptr))->type; - deconstruct_qualified_type(&kv_type); - const Type* alloca_type_t = source->type; - deconstruct_qualified_type(&alloca_type_t); - if (kv_type != source->type && !is_reinterpret_cast_legal(kv_type, alloca_type_t)) { - log_node(DEBUG, ptr); - debug_print(" has a known value in %s, but it's type ", get_abstraction_name(edge.src->node)); - log_node(DEBUG, kv_type); - debug_print(" cannot be reinterpreted into the alloca type "); - log_node(DEBUG, source->type); - debug_print("\n."); - goto next_potential_param; - } - - uk.ptr_has_leaked |= k->ptr_has_leaked; +static const Node* get_last_stored_value(Context* ctx, const Node* ptr, const Node* mem, const Type* expected_type) { + const Node* ptr_source = get_ptr_source(ptr); + while (mem) { + switch (mem->tag) { + case AbsMem_TAG: { + const Node* abs = mem->payload.abs_mem.abs; + CFNode* n = shd_cfg_lookup(ctx->cfg, abs); + if (shd_list_count(n->pred_edges) == 1) { + CFEdge e = shd_read_list(CFEdge, n->pred_edges)[0]; + mem = get_terminator_mem(e.terminator); + continue; } - - log_node(DEBUG, ptr); - debug_print(" has a known value in all predecessors! 
Turning it into a new parameter.\n"); - - const Node* param = var(a, rewrite_node(&ctx->rewriter, source->type), unique_name(a, "ssa_phi")); - params = append_nodes(a, params, param); - ptrs = append_nodes(ctx->rewriter.src_arena, ptrs, ptr); - - PtrKnowledge* k = arena_alloc(ctx->a, sizeof(PtrKnowledge)); - *k = (PtrKnowledge) { - .ptr_value = param, - .source = source, - .ptr_has_leaked = uk.ptr_has_leaked - }; - insert_ptr_knowledge(kb, ptr, k); - - next_potential_param: continue; + break; } - - if (ptrs.count > 0) { - insert_dict(const Node*, Nodes, ctx->bb_new_args, old, ptrs); + case Store_TAG: { + Store payload = mem->payload.store; + if (payload.ptr == ptr) + return payload.value; + if (get_ptr_source(payload.ptr) == ptr_source) + return NULL; + break; } - - Node* fn = (Node*) rewrite_node(&ctx->rewriter, ctx->scope->entry->node); - Node* bb = basic_block(a, fn, params, get_abstraction_name(old)); - register_processed(&ctx->rewriter, old, bb); - bb->payload.basic_block.body = rewrite_node(&ctx->rewriter, get_abstraction_body(old)); - return bb; + default: break; } - case Jump_TAG: { - const Node* new_bb = rewrite_node(&ctx->rewriter, old->payload.jump.target); - Nodes args = rewrite_nodes(&ctx->rewriter, old->payload.jump.args); - - Nodes* additional_ssa_params = find_value_dict(const Node*, Nodes, ctx->bb_new_args, old->payload.jump.target); - if (additional_ssa_params) { - assert(additional_ssa_params->count > 0); - - LARRAY(const Type*, tr_params_arr, args.count); - for (size_t i = 0; i < args.count; i++) - tr_params_arr[i] = var(a, args.nodes[i]->type, args.nodes[i]->payload.var.name); - Nodes tr_params = nodes(a, args.count, tr_params_arr); - Node* fn = (Node*) rewrite_node(&ctx->rewriter, ctx->scope->entry->node); - Node* trampoline = basic_block(a, fn, tr_params, format_string_interned(a, "%s_trampoline", get_abstraction_name(new_bb))); - Nodes tr_args = args; - BodyBuilder* bb = begin_body(a); - - for (size_t i = 0; i < additional_ssa_params->count; 
i++) { - const Node* ptr = additional_ssa_params->nodes[i]; - PtrKnowledge* k = get_last_valid_ptr_knowledge(kb, ptr); - const Node* value = get_known_value(&ctx->rewriter, k); - - const Type* known_value_t = value->type; - deconstruct_qualified_type(&known_value_t); - - const Type* alloca_type_t = k->source->type; - deconstruct_qualified_type(&alloca_type_t); - - if (alloca_type_t != known_value_t && is_reinterpret_cast_legal(alloca_type_t, known_value_t)) - value = first(gen_primop(bb, reinterpret_op, singleton(rewrite_node(&ctx->rewriter, alloca_type_t)), singleton(value))); - - assert(value); - args = append_nodes(a, args, value); - } - - trampoline->payload.basic_block.body = finish_body(bb, jump_helper(a, new_bb, args)); + mem = shd_get_parent_mem(mem); + } + return NULL; +} - return jump_helper(a, trampoline, tr_args); +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + switch (node->tag) { + case Function_TAG: { + Node* new = shd_recreate_node_head(r, node); + Context fun_ctx = *ctx; + fun_ctx.cfg = build_fn_cfg(node); + shd_recreate_node_body(&fun_ctx.rewriter, node, new); + shd_destroy_cfg(fun_ctx.cfg); + return new; + } + case Load_TAG: { + Load payload = node->payload.load; + const Node* src = get_ptr_source(payload.ptr); + if (src->tag != LocalAlloc_TAG) + break; + // for now, only simplify loads from non-leaking allocas + const Node* ovalue = get_last_stored_value(ctx, payload.ptr, payload.mem, shd_get_unqualified_type(node->type)); + if (ovalue) { + *ctx->todo = true; + const Node* value = shd_rewrite_node(r, ovalue); + if (shd_is_qualified_type_uniform(node->type)) + value = prim_op_helper(a, subgroup_assume_uniform_op, shd_empty(a), shd_singleton(value)); + return mem_and_value(a, (MemAndValue) { .mem = shd_rewrite_node(r, payload.mem), .value = value }); } - - return jump_helper(a, new_bb, args); } default: break; } - return recreate_node_identity(&ctx->rewriter, old); + 
return shd_recreate_node(r, node); } -Module* opt_mem2reg(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* initial_arena = get_module_arena(src); - IrArena* a = new_ir_arena(aconfig); - Module* dst = src; - - for (size_t round = 0; round < 2; round++) { - dst = new_module(a, get_module_name(src)); - - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - .bb_new_args = new_dict(const Node*, Nodes, (HashFn) hash_node, (CmpFn) compare_node), - .a = new_arena(), - }; - - ctx.rewriter.config.fold_quote = false; +bool shd_opt_mem2reg(SHADY_UNUSED const CompilerConfig* config, Module** m) { + Module* src = *m; + IrArena* a = shd_module_get_arena(src); - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - destroy_dict(ctx.bb_new_args); - destroy_arena(ctx.a); - - verify_module(dst); - - if (get_module_arena(src) != initial_arena) - destroy_ir_arena(get_module_arena(src)); - - dst = cleanup(config, dst); - src = dst; - } - - destroy_ir_arena(a); - - return dst; + Module* dst = NULL; + bool todo = false; + dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .todo = &todo + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + assert(dst); + *m = dst; + return todo; } diff --git a/src/shady/passes/opt_restructure.c b/src/shady/passes/opt_restructure.c deleted file mode 100644 index a1f41fb54..000000000 --- a/src/shady/passes/opt_restructure.c +++ /dev/null @@ -1,421 +0,0 @@ -#include "passes.h" - -#include "dict.h" -#include "list.h" -#include "portability.h" -#include "log.h" - -#include "../rewrite.h" -#include "../type.h" - -#include -#include - -#pragma GCC diagnostic error "-Wswitch" - -typedef struct { - const Node* old; - Node* new; -} TodoEntry; - -typedef struct ControlEntry_ ControlEntry; -struct ControlEntry_ { - ControlEntry* 
parent; - const Node* old_token; - const Node** phis; - int depth; -}; - -typedef struct DFSStackEntry_ DFSStackEntry; -struct DFSStackEntry_ { - DFSStackEntry* parent; - const Node* old; - - ControlEntry* containing_control; - - bool loop_header; - bool in_loop; -}; - -typedef struct { - Rewriter rewriter; - struct List* tmp_alloc_stack; - - jmp_buf bail; - - bool lower; - Node* fn; - const Node* level_ptr; - DFSStackEntry* dfs_stack; - ControlEntry* control_stack; -} Context; - -static DFSStackEntry* encountered_before(Context* ctx, const Node* bb, size_t* path_len) { - DFSStackEntry* entry = ctx->dfs_stack; - if (path_len) *path_len = 0; - while (entry != NULL) { - if (entry->old == bb) - return entry; - entry = entry->parent; - if (path_len) (*path_len)++; - } - return entry; -} - -static const Node* structure(Context* ctx, const Node* abs, const Node* exit_ladder); - -static const Node* handle_bb_callsite(Context* ctx, BodyBuilder* bb, const Node* caller, const Node* j, const Node* exit_ladder) { - assert(j->tag == Jump_TAG); - IrArena* a = ctx->rewriter.dst_arena; - const Node* dst = j->payload.jump.target; - Nodes oargs = j->payload.jump.args; - - size_t path_len; - DFSStackEntry* prior_encounter = encountered_before(ctx, dst, &path_len); - if (prior_encounter) { - // Create path - LARRAY(const Node*, path, path_len); - DFSStackEntry* entry2 = ctx->dfs_stack->parent; - for (size_t i = 0; i < path_len; i++) { - assert(entry2); - path[path_len - 1 - i] = entry2->old; - if (entry2->in_loop) - longjmp(ctx->bail, 1); - if (entry2->containing_control != ctx->control_stack) - longjmp(ctx->bail, 1); - entry2->in_loop = true; - entry2 = entry2->parent; - } - prior_encounter->loop_header = true; - return finish_body(bb, merge_continue(a, (MergeContinue) { - .args = rewrite_nodes(&ctx->rewriter, oargs) - })); - } else { - Nodes oparams = get_abstraction_params(dst); - assert(oparams.count == oargs.count); - LARRAY(const Node*, nparams, oargs.count); - Context ctx2 = 
*ctx; - - // Record each step of the depth-first search on a stack so we can identify loops - DFSStackEntry dfs_entry = { .parent = ctx->dfs_stack, .old = dst, .containing_control = ctx->control_stack }; - ctx2.dfs_stack = &dfs_entry; - - struct Dict* tmp_processed = clone_dict(ctx->rewriter.map); - append_list(struct Dict*, ctx->tmp_alloc_stack, tmp_processed); - ctx2.rewriter.map = tmp_processed; - for (size_t i = 0; i < oargs.count; i++) { - nparams[i] = var(a, rewrite_node(&ctx->rewriter, oparams.nodes[i]->type), "arg"); - register_processed(&ctx2.rewriter, oparams.nodes[i], nparams[i]); - } - - // We use a basic block for the exit ladder because we don't know what the ladder needs to do ahead of time - // opt_simplify_cf will later inline this - Node* inner_exit_ladder_bb = basic_block(a, ctx->fn, empty(a), unique_name(a, "exit_ladder_inline_me")); - - // Just jumps to the actual ladder - const Node* exit_ladder_trampoline = case_(a, empty(a), jump(a, (Jump) {.target = inner_exit_ladder_bb, .args = empty(a)})); - - const Node* structured = structure(&ctx2, dst, let(a, quote_helper(a, empty(a)), exit_ladder_trampoline)); - assert(is_terminator(structured)); - // forget we rewrote all that - destroy_dict(tmp_processed); - pop_list_impl(ctx->tmp_alloc_stack); - - if (dfs_entry.loop_header) { - const Node* body = case_(a, nodes(a, oargs.count, nparams), structured); - bind_instruction(bb, loop_instr(a, (Loop) { - .body = body, - .initial_args = rewrite_nodes(&ctx->rewriter, oargs), - .yield_types = nodes(a, 0, NULL), - })); - // we decide 'late' what the exit ladder should be - inner_exit_ladder_bb->payload.basic_block.body = merge_break(a, (MergeBreak) { .args = empty(a) }); - return finish_body(bb, exit_ladder); - } else { - bind_variables(bb, nodes(a, oargs.count, nparams), rewrite_nodes(&ctx->rewriter, oargs)); - inner_exit_ladder_bb->payload.basic_block.body = exit_ladder; - return finish_body(bb, structured); - } - } -} - -static ControlEntry* 
search_containing_control(Context* ctx, const Node* old_token) { - ControlEntry* entry = ctx->control_stack; - assert(entry); - while (entry != NULL) { - if (entry->old_token == old_token) - return entry; - entry = entry->parent; - } - return entry; -} - -static const Node* rebuild_let(Context* ctx, const Node* old_let, const Node* new_instruction, const Node* exit_ladder) { - IrArena* a = ctx->rewriter.dst_arena; - const Node* old_tail = get_let_tail(old_let); - Nodes otail_params = get_abstraction_params(old_tail); - - Nodes rewritten_params = recreate_variables(&ctx->rewriter, otail_params); - register_processed_list(&ctx->rewriter, otail_params, rewritten_params); - const Node* structured_lam = case_(a, rewritten_params, structure(ctx, old_tail, exit_ladder)); - return let(a, new_instruction, structured_lam); -} - -static const Node* structure(Context* ctx, const Node* abs, const Node* exit_ladder) { - IrArena* a = ctx->rewriter.dst_arena; - - const Node* body = get_abstraction_body(abs); - assert(body); - switch (is_terminator(body)) { - case NotATerminator: - case LetMut_TAG: assert(false); - case Let_TAG: { - const Node* old_tail = get_let_tail(body); - Nodes otail_params = get_abstraction_params(old_tail); - - const Node* old_instr = get_let_instruction(body); - switch (is_instruction(old_instr)) { - case NotAnInstruction: assert(false); - case Instruction_If_TAG: - case Instruction_Loop_TAG: - case Instruction_Match_TAG: error("not supposed to exist in IR at this stage"); - case Instruction_Block_TAG: error("Should be eliminated by the compiler"); - case Instruction_Comment_TAG: - case Instruction_PrimOp_TAG: { - break; - } - case Instruction_Call_TAG: { - const Node* callee = old_instr->payload.call.callee; - if (callee->tag == FnAddr_TAG) { - const Node* fn = rewrite_node(&ctx->rewriter, callee->payload.fn_addr.fn); - // leave leaf calls alone - if (lookup_annotation(fn, "Leaf")) { - break; - } - } - // if we don't manage that, give up :( - 
assert(false); // actually that should not come up. - longjmp(ctx->bail, 1); - } - // let(control(body), tail) - // var phi = undef; level = N+1; structurize[body, if (level == N+1, _ => tail(load(phi))); structured_exit_terminator] - case Instruction_Control_TAG: { - const Node* old_control_body = old_instr->payload.control.inside; - assert(old_control_body->tag == Case_TAG); - Nodes old_control_params = get_abstraction_params(old_control_body); - assert(old_control_params.count == 1); - - // Create N temporary variables to hold the join point arguments - BodyBuilder* bb_outer = begin_body(a); - Nodes yield_types = rewrite_nodes(&ctx->rewriter, old_instr->payload.control.yield_types); - LARRAY(const Node*, phis, yield_types.count); - for (size_t i = 0; i < yield_types.count; i++) { - const Type* type = yield_types.nodes[i]; - assert(is_data_type(type)); - phis[i] = first(bind_instruction_named(bb_outer, prim_op(a, (PrimOp) { .op = alloca_logical_op, .type_arguments = singleton(type) }), (String []) {"ctrl_phi" })); - } - - // Create a new context to rewrite the body with - // TODO: Bail if we try to re-enter the same control construct - Context control_ctx = *ctx; - ControlEntry control_entry = { - .parent = ctx->control_stack, - .old_token = first(old_control_params), - .phis = phis, - .depth = ctx->control_stack ? 
ctx->control_stack->depth + 1 : 1, - }; - control_ctx.control_stack = &control_entry; - - // Set the depth for threads entering the control body - bind_instruction(bb_outer, prim_op(a, (PrimOp) { .op = store_op, .operands = mk_nodes(a, ctx->level_ptr, int32_literal(a, control_entry.depth)) })); - - // Start building out the tail, first it needs to dereference the phi variables to recover the arguments given to join() - BodyBuilder* bb2 = begin_body(a); - LARRAY(const Node*, phi_values, yield_types.count); - for (size_t i = 0; i < yield_types.count; i++) { - phi_values[i] = first(bind_instruction(bb2, prim_op(a, (PrimOp) { .op = load_op, .operands = singleton(phis[i]) }))); - register_processed(&ctx->rewriter, otail_params.nodes[i], phi_values[i]); - } - - // Wrap the tail in a guarded if, to handle 'far' joins - const Node* level_value = first(bind_instruction(bb2, prim_op(a, (PrimOp) { .op = load_op, .operands = singleton(ctx->level_ptr) }))); - const Node* guard = first(bind_instruction(bb2, prim_op(a, (PrimOp) { .op = eq_op, .operands = mk_nodes(a, level_value, int32_literal(a, ctx->control_stack ? 
ctx->control_stack->depth : 0)) }))); - const Node* true_body = structure(ctx, old_tail, yield(a, (Yield) { .args = empty(a) })); - const Node* if_true_lam = case_(a, empty(a), true_body); - bind_instruction(bb2, if_instr(a, (If) { - .condition = guard, - .yield_types = empty(a), - .if_true = if_true_lam, - .if_false = NULL - })); - - const Node* tail_lambda = case_(a, empty(a), finish_body(bb2, exit_ladder)); - return finish_body(bb_outer, structure(&control_ctx, old_control_body, let(a, quote_helper(a, empty(a)), tail_lambda))); - } - } - return rebuild_let(ctx, body, recreate_node_identity(&ctx->rewriter, old_instr), exit_ladder); - } - case Jump_TAG: { - BodyBuilder* bb = begin_body(a); - return handle_bb_callsite(ctx, bb, abs, body, exit_ladder); - } - // br(cond, true_bb, false_bb, args) - // becomes - // let(if(cond, _ => handle_bb_callsite[true_bb, args], _ => handle_bb_callsite[false_bb, args]), _ => unreachable) - case Branch_TAG: { - const Node* condition = rewrite_node(&ctx->rewriter, body->payload.branch.branch_condition); - - BodyBuilder* if_true_bb = begin_body(a); - const Node* true_body = handle_bb_callsite(ctx, if_true_bb, abs, body->payload.branch.true_jump, yield(a, (Yield) { .args = empty(a) })); - const Node* if_true_lam = case_(a, empty(a), true_body); - - BodyBuilder* if_false_bb = begin_body(a); - const Node* false_body = handle_bb_callsite(ctx, if_false_bb, abs, body->payload.branch.false_jump, yield(a, (Yield) { .args = empty(a) })); - const Node* if_false_lam = case_(a, empty(a), false_body); - - const Node* instr = if_instr(a, (If) { - .condition = condition, - .yield_types = empty(a), - .if_true = if_true_lam, - .if_false = if_false_lam, - }); - const Node* post_merge_lam = case_(a, empty(a), exit_ladder); - return let(a, instr, post_merge_lam); - } - case Switch_TAG: { - const Node* switch_value = rewrite_node(&ctx->rewriter, body->payload.br_switch.switch_value); - - BodyBuilder* default_bb = begin_body(a); - const Node* default_body 
= handle_bb_callsite(ctx, default_bb, abs, body->payload.br_switch.default_jump, yield(a, (Yield) { .args = empty(a) })); - const Node* default_case = case_(a, empty(a), default_body); - - LARRAY(const Node*, cases, body->payload.br_switch.case_jumps.count); - for (size_t i = 0; i < body->payload.br_switch.case_jumps.count; i++) { - BodyBuilder* bb = begin_body(a); - cases[i] = case_(a, empty(a), handle_bb_callsite(ctx, bb, abs, body->payload.br_switch.case_jumps.nodes[i], yield(a, (Yield) {.args = empty(a)}))); - } - - const Node* instr = match_instr(a, (Match) { - .inspect = switch_value, - .yield_types = empty(a), - .default_case = default_case, - .cases = nodes(a, body->payload.br_switch.case_jumps.count, cases), - .literals = rewrite_nodes(&ctx->rewriter, body->payload.br_switch.case_values), - }); - return let(a, instr, case_(a, empty(a), exit_ladder)); - } - case Join_TAG: { - ControlEntry* control = search_containing_control(ctx, body->payload.join.join_point); - if (!control) - longjmp(ctx->bail, 1); - - BodyBuilder* bb = begin_body(a); - bind_instruction(bb, prim_op(a, (PrimOp) { .op = store_op, .operands = mk_nodes(a, ctx->level_ptr, int32_literal(a, control->depth - 1)) })); - - Nodes args = rewrite_nodes(&ctx->rewriter, body->payload.join.args); - for (size_t i = 0; i < args.count; i++) { - bind_instruction(bb, prim_op(a, (PrimOp) { .op = store_op, .operands = mk_nodes(a, control->phis[i], args.nodes[i]) })); - } - - return finish_body(bb, exit_ladder); - } - - case Return_TAG: - case Unreachable_TAG: return recreate_node_identity(&ctx->rewriter, body); - - case TailCall_TAG: longjmp(ctx->bail, 1); - - case Terminator_MergeBreak_TAG: - case Terminator_MergeContinue_TAG: - case Yield_TAG: error("Only control nodes are tolerated here.") - } -} - -static const Node* process(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; - if (!node) return NULL; - assert(a != node->arena); - assert(node->arena == ctx->rewriter.src_arena); - - 
const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - - if (is_declaration(node)) { - String name = get_decl_name(node); - Nodes decls = get_module_declarations(ctx->rewriter.dst_module); - for (size_t i = 0; i < decls.count; i++) { - if (strcmp(get_decl_name(decls.nodes[i]), name) == 0) - return decls.nodes[i]; - } - } - - if (node->tag == Function_TAG) { - Node* new = recreate_decl_header_identity(&ctx->rewriter, node); - - size_t alloc_stack_size_now = entries_count_list(ctx->tmp_alloc_stack); - - Context ctx2 = *ctx; - ctx2.dfs_stack = NULL; - ctx2.control_stack = NULL; - bool is_builtin = lookup_annotation(node, "Builtin"); - bool is_leaf = false; - if (is_builtin || !node->payload.fun.body || lookup_annotation(node, "Structured") || setjmp(ctx2.bail)) { - ctx2.lower = false; - ctx2.rewriter.map = ctx->rewriter.map; - if (node->payload.fun.body) - new->payload.fun.body = rewrite_node(&ctx2.rewriter, node->payload.fun.body); - // builtin functions are always considered leaf functions - is_leaf = is_builtin || !node->payload.fun.body; - } else { - ctx2.lower = true; - BodyBuilder* bb = begin_body(a); - const Node* ptr = first(bind_instruction_named(bb, prim_op(a, (PrimOp) { .op = alloca_logical_op, .type_arguments = singleton(int32_type(a)) }), (String []) {"cf_depth" })); - bind_instruction(bb, prim_op(a, (PrimOp) { .op = store_op, .operands = mk_nodes(a, ptr, int32_literal(a, 0)) })); - ctx2.level_ptr = ptr; - ctx2.fn = new; - struct Dict* tmp_processed = clone_dict(ctx->rewriter.map); - append_list(struct Dict*, ctx->tmp_alloc_stack, tmp_processed); - ctx2.rewriter.map = tmp_processed; - new->payload.fun.body = finish_body(bb, structure(&ctx2, node, unreachable(a))); - is_leaf = true; - } - - //if (is_leaf) - // new->payload.fun.annotations = append_nodes(arena, new->payload.fun.annotations, annotation(arena, (Annotation) { .name = "Leaf" })); - - // if we did a longjmp, we might have orphaned a few of those - while 
(alloc_stack_size_now < entries_count_list(ctx->tmp_alloc_stack)) { - struct Dict* orphan = pop_last_list(struct Dict*, ctx->tmp_alloc_stack); - destroy_dict(orphan); - } - - new->payload.fun.annotations = filter_out_annotation(a, new->payload.fun.annotations, "MaybeLeaf"); - - return new; - } - - if (!ctx->lower) - return recreate_node_identity(&ctx->rewriter, node); - - // These should all be manually visited by 'structure' - // assert(!is_terminator(node) && !is_instruction(node)); - - switch (node->tag) { - default: return recreate_node_identity(&ctx->rewriter, node); - } -} - -Module* opt_restructurize(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - .tmp_alloc_stack = new_list(struct Dict*), - }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - destroy_list(ctx.tmp_alloc_stack); - return dst; -} diff --git a/src/shady/passes/opt_stack.c b/src/shady/passes/opt_stack.c deleted file mode 100644 index e2c9cd617..000000000 --- a/src/shady/passes/opt_stack.c +++ /dev/null @@ -1,140 +0,0 @@ -#include "passes.h" - -#include "../rewrite.h" -#include "portability.h" -#include "log.h" - -typedef struct StackState_ StackState; -struct StackState_ { - StackState* prev; - enum { VALUE, MERGE } type; - bool leaks; - const Node* value; - size_t count; - const Node** values; -}; - -typedef struct { - Rewriter rewriter; - StackState* state; -} Context; - -static void tag_leaks(Context* ctx) { - StackState* s = ctx->state; - while (s) { - s->leaks = true; - s = s->prev; - } -} - -static const Node* process(Context* ctx, const Node* node) { - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; - StackState entry; - Context child_ctx = 
*ctx; - - bool is_push = false; - bool is_pop = false; - - switch (is_terminator(node)) { - case Terminator_Unreachable_TAG: break; - case Let_TAG: { - const Node* old_instruction = node->payload.let.instruction; - const Node* ntail = NULL; - switch (is_instruction(old_instruction)) { - case PrimOp_TAG: { - switch (old_instruction->payload.prim_op.op) { - case push_stack_op: { - const Node* value = rewrite_node(&ctx->rewriter, first(old_instruction->payload.prim_op.operands)); - entry = (StackState) { - .prev = ctx->state, - .type = VALUE, - .value = value, - .leaks = false, - }; - child_ctx.state = &entry; - is_push = true; - break; - } - case pop_stack_op: { - if (ctx->state) { - child_ctx.state = ctx->state->prev; - is_pop = true; - } - break; - } - default: break; - } - break; - } - // We sadly don't handle those yet: - case Match_TAG: - case Control_TAG: - case Loop_TAG: - case If_TAG: - case Instruction_Block_TAG: - // Leaf calls and indirect calls are not analysed and so they are considered to leak the state - // we also need to forget our information about the current state - case Instruction_Call_TAG: { - tag_leaks(ctx); - child_ctx.state = NULL; - break; - } - case Instruction_Comment_TAG: break; - case NotAnInstruction: assert(false); - } - - ntail = rewrite_node(&child_ctx.rewriter, node->payload.let.tail); - - const Node* ninstruction = NULL; - if (is_push && !child_ctx.state->leaks) { - // replace stack pushes with no-ops - ninstruction = quote_helper(a, empty(a)); - } else if (is_pop) { - assert(ctx->state->type == VALUE); - const Node* value = ctx->state->value; - ninstruction = quote_helper(a, singleton(value)); - } else { - // if the stack state is observed, or this was an unrelated instruction, leave it alone - ninstruction = recreate_node_identity(&ctx->rewriter, old_instruction); - } - assert(ninstruction); - return let(a, ninstruction, ntail); - } - // Unreachable is assumed to never happen, so it doesn't observe the stack state - case 
NotATerminator: break; - default: { - // All other non-let terminators are considered to leak the stack state - tag_leaks(ctx); - break; - } - } - - // child_ctx.state = NULL; - switch (node->tag) { - case Function_TAG: { - Node* fun = recreate_decl_header_identity(&ctx->rewriter, node); - child_ctx.state = NULL; - recreate_decl_body_identity(&child_ctx.rewriter, node, fun); - return fun; - } - default: return recreate_node_identity(&child_ctx.rewriter, node); - } -} - -Module* opt_stack(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - .state = NULL, - }; - - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - return dst; -} diff --git a/src/shady/passes/passes.h b/src/shady/passes/passes.h index e1aaf0c34..d2d3f8fb2 100644 --- a/src/shady/passes/passes.h +++ b/src/shady/passes/passes.h @@ -1,21 +1,13 @@ #ifndef SHADY_PASSES_H #include "shady/ir.h" - -typedef Module* (RewritePass)(const CompilerConfig* config, Module* src); +#include "shady/pass.h" /// @name Boring, regular compiler stuff /// @{ -RewritePass import; -RewritePass cleanup; - -/// Removes all Unresolved nodes and replaces them with the appropriate decl/value -RewritePass bind_program; -/// Enforces the grammar, notably by let-binding any intermediary result -RewritePass normalize; -/// Makes sure every node is well-typed -RewritePass infer_program; +RewritePass shd_import; +RewritePass shd_cleanup; /// @} @@ -23,7 +15,15 @@ RewritePass infer_program; /// @{ /// Gets rid of structured control flow constructs, and turns them into branches, joins and tailcalls -RewritePass lower_cf_instrs; +RewritePass shd_pass_lower_cf_instrs; +/// Uses shady.scope annotations to insert control blocks +RewritePass shd_pass_scope2control; +RewritePass 
shd_pass_lift_everything; +RewritePass shd_pass_remove_critical_edges; +RewritePass shd_pass_lcssa; +RewritePass shd_pass_scope_heuristic; +/// Try to identify reconvergence points throughout the program for unstructured control flow programs +RewritePass shd_pass_reconvergence_heuristics; /// @} @@ -31,13 +31,8 @@ RewritePass lower_cf_instrs; /// @{ /// Extracts unstructured basic blocks into separate functions (including spilling) -RewritePass lift_indirect_targets; -/// Emulates uniform jumps within functions using a loop -RewritePass lower_jumps_loop; -/// Emulates uniform jumps within functions by applying a structuring transformation -RewritePass lower_jumps_structure; -RewritePass lcssa; -RewritePass normalize_builtins; +RewritePass shd_pass_lift_indirect_targets; +RewritePass shd_pass_normalize_builtins; /// @} @@ -45,33 +40,34 @@ RewritePass normalize_builtins; /// @{ /// Lowers calls to stack saves and forks, lowers returns to stack pops and joins -RewritePass lower_callf; +RewritePass shd_pass_lower_callf; /// Emulates tailcalls, forks and joins using a god function -RewritePass lower_tailcalls; -/// Turns SIMT code back into SIMD (intended for debugging with the help of the C backend) -RewritePass simt2d; +RewritePass shd_pass_lower_tailcalls; /// @} /// @name Physical memory emulation /// @{ -/// Implements stack frames, saves the stack size on function entry and restores it upon exit -RewritePass setup_stack_frames; +/// Implements stack frames: saves the stack size on function entry and restores it upon exit +RewritePass shd_pass_setup_stack_frames; +/// Implements stack frames: collects allocas into a struct placed on the stack upon function entry +RewritePass shd_pass_lower_alloca; /// Turns stack pushes and pops into accesses into pointer load and stores -RewritePass lower_stack; +RewritePass shd_pass_lower_stack; /// Eliminates lea_op on all physical address spaces -RewritePass lower_lea; +RewritePass shd_pass_lower_lea; /// Emulates generic 
pointers by replacing them with tagged integers and special load/store routines that look at those tags -RewritePass lower_generic_ptrs; +RewritePass shd_pass_lower_generic_ptrs; /// Emulates physical pointers to certain address spaces by using integer indices into global arrays -RewritePass lower_physical_ptrs; +RewritePass shd_pass_lower_physical_ptrs; /// Replaces size_of, offset_of etc with their exact values -RewritePass lower_memory_layout; -RewritePass lower_memcpy; -/// Eliminates pointers to unsized arrays from the IR. Needs lower_lea to have ran first! -RewritePass lower_decay_ptrs; -RewritePass lower_generic_globals; +RewritePass shd_pass_lower_memory_layout; +RewritePass shd_pass_lower_memcpy; +/// Eliminates pointers to unsized arrays from the IR. Needs lower_lea to have ran shd_first! +RewritePass shd_pass_lower_decay_ptrs; +RewritePass shd_pass_lower_generic_globals; +RewritePass shd_pass_lower_logical_pointers; /// @} @@ -79,11 +75,11 @@ RewritePass lower_generic_globals; /// @{ /// Emulates unsupported subgroup operations using subgroup memory -RewritePass lower_subgroup_ops; +RewritePass shd_pass_lower_subgroup_ops; /// Lowers subgroup logical variables into something that actually exists (likely a carved out portion of shared memory) -RewritePass lower_subgroup_vars; +RewritePass shd_pass_lower_subgroup_vars; /// Lowers the abstract mask type to whatever the configured target mask representation is -RewritePass lower_mask; +RewritePass shd_pass_lower_mask; /// @} @@ -91,10 +87,12 @@ RewritePass lower_mask; /// @{ /// Emulates unsupported integer datatypes and operations -RewritePass lower_int; -RewritePass lower_vec_arr; -RewritePass lower_workgroups; -RewritePass lower_fill; +RewritePass shd_pass_lower_int; +RewritePass shd_pass_lower_vec_arr; +RewritePass shd_pass_lower_workgroups; +RewritePass shd_pass_lower_fill; +RewritePass shd_pass_lower_nullptr; +RewritePass shd_pass_lower_inclusive_scan; /// @} @@ -102,29 +100,22 @@ RewritePass 
lower_fill; /// @{ /// Eliminates all Constant decls -RewritePass eliminate_constants; +RewritePass shd_pass_eliminate_constants; +/// Ditto but for @Inline ones only +RewritePass shd_pass_eliminate_inlineable_constants; /// Tags all functions that don't need special handling -RewritePass mark_leaf_functions; -/// Inlines basic blocks used exactly once, necessary after opt_restructure -RewritePass opt_inline_jumps; +RewritePass shd_pass_mark_leaf_functions; /// In addition, also inlines function calls according to heuristics -RewritePass opt_inline; -RewritePass opt_mem2reg; - -/// Try to identify reconvergence points throughout the program for unstructured control flow programs -RewritePass reconvergence_heuristics; - -RewritePass opt_stack; -RewritePass opt_restructurize; -RewritePass lower_switch_btree; +RewritePass shd_pass_inline; +OptPass shd_opt_mem2reg; -RewritePass lower_entrypoint_args; +RewritePass shd_pass_restructurize; +RewritePass shd_pass_lower_switch_btree; -RewritePass spirv_map_entrypoint_args; -RewritePass spirv_lift_globals_ssbo; +RewritePass shd_pass_lower_entrypoint_args; -RewritePass specialize_entry_point; -RewritePass specialize_execution_model; +RewritePass shd_pass_specialize_entry_point; +RewritePass shd_pass_specialize_execution_model; /// @} diff --git a/src/shady/passes/reconvergence_heuristics.c b/src/shady/passes/reconvergence_heuristics.c index 9047fb27c..15f192ce4 100644 --- a/src/shady/passes/reconvergence_heuristics.c +++ b/src/shady/passes/reconvergence_heuristics.c @@ -1,4 +1,8 @@ -#include "shady/ir.h" +#include "shady/pass.h" + +#include "../ir_private.h" +#include "../analysis/cfg.h" +#include "../analysis/looptree.h" #include "list.h" #include "dict.h" @@ -6,36 +10,30 @@ #include "portability.h" #include "util.h" -#include "../type.h" -#include "../rewrite.h" -#include "../ir_private.h" -#include "../transform/ir_gen_helpers.h" - -#include "../analysis/scope.h" -#include "../analysis/looptree.h" - #include typedef struct 
Context_ { Rewriter rewriter; + const CompilerConfig* config; + Arena* arena; const Node* current_fn; const Node* current_abstraction; - Scope* fwd_scope; - Scope* back_scope; + CFG* fwd_cfg; + CFG* rev_cfg; LoopTree* current_looptree; } Context; static bool in_loop(LoopTree* lt, const Node* entry, const Node* block) { - LTNode* lt_node = looptree_lookup(lt, block); + LTNode* lt_node = shd_loop_tree_lookup(lt, block); assert(lt_node); LTNode* parent = lt_node->parent; assert(parent); while (parent) { - if (entries_count_list(parent->cf_nodes) != 1) + if (shd_list_count(parent->cf_nodes) != 1) return false; - if (read_list(CFNode*, parent->cf_nodes)[0]->node == entry) + if (shd_read_list(CFNode*, parent->cf_nodes)[0]->node == entry) return true; parent = parent->parent; @@ -47,26 +45,55 @@ static bool in_loop(LoopTree* lt, const Node* entry, const Node* block) { //TODO: This is massively inefficient. static void gather_exiting_nodes(LoopTree* lt, const CFNode* entry, const CFNode* block, struct List* exiting_nodes) { if (!in_loop(lt, entry->node, block->node)) { - append_list(CFNode*, exiting_nodes, block); + shd_list_append(CFNode*, exiting_nodes, block); return; } - for (size_t i = 0; i < entries_count_list(block->dominates); i++) { - const CFNode* target = read_list(CFNode*, block->dominates)[i]; + for (size_t i = 0; i < shd_list_count(block->dominates); i++) { + const CFNode* target = shd_read_list(CFNode*, block->dominates)[i]; gather_exiting_nodes(lt, entry, target, exiting_nodes); } } +static void find_unbound_vars(const Node* exiting_node, struct Dict* bound_set, struct Dict* free_set, struct List* leaking) { + const Node* v; + size_t i = 0; + while (shd_dict_iter(free_set, &i, &v, NULL)) { + if (shd_dict_find_key(const Node*, bound_set, v)) + continue; + + shd_log_fmt(DEBUGVV, "Found variable used outside it's control scope: "); + shd_log_node(DEBUGVV, v); + shd_log_fmt(DEBUGVV, " (exiting_node:"); + shd_log_node(DEBUGVV, exiting_node); + 
shd_log_fmt(DEBUGVV, " )\n"); + + shd_list_append(const Node*, leaking, v); + } +} + +typedef struct { + const Node* alloca; + bool uniform; +} ExitValue; + +typedef struct { + ExitValue* params; + size_t params_count; + + Node* wrapper; +} Exit; + static const Node* process_abstraction(Context* ctx, const Node* node) { - assert(is_abstraction(node)); + assert(node && is_abstraction(node)); Context new_context = *ctx; ctx = &new_context; ctx->current_abstraction = node; Rewriter* rewriter = &ctx->rewriter; IrArena* arena = rewriter->dst_arena; - CFNode* current_node = scope_lookup(ctx->fwd_scope, node); - LTNode* lt_node = looptree_lookup(ctx->current_looptree, node); + CFNode* current_node = shd_cfg_lookup(ctx->fwd_cfg, node); + LTNode* lt_node = shd_loop_tree_lookup(ctx->current_looptree, node); LTNode* loop_header = NULL; assert(current_node); @@ -74,11 +101,11 @@ static const Node* process_abstraction(Context* ctx, const Node* node) { bool is_loop_entry = false; if (lt_node->parent && lt_node->parent->type == LF_HEAD) { - if (entries_count_list(lt_node->parent->cf_nodes) == 1) - if (read_list(CFNode*, lt_node->parent->cf_nodes)[0]->node == node) { + if (shd_list_count(lt_node->parent->cf_nodes) == 1) + if (shd_read_list(CFNode*, lt_node->parent->cf_nodes)[0]->node == node) { loop_header = lt_node->parent; assert(loop_header->type == LF_HEAD); - assert(entries_count_list(loop_header->cf_nodes) == 1 && "only reducible loops are handled"); + assert(shd_list_count(loop_header->cf_nodes) == 1 && "only reducible loops are handled"); is_loop_entry = true; } } @@ -86,260 +113,245 @@ static const Node* process_abstraction(Context* ctx, const Node* node) { if (is_loop_entry) { assert(!is_function(node)); - struct List* exiting_nodes = new_list(CFNode*); + struct List* exiting_nodes = shd_new_list(CFNode*); gather_exiting_nodes(ctx->current_looptree, current_node, current_node, exiting_nodes); - for (size_t i = 0; i < entries_count_list(exiting_nodes); i++) { - 
debugv_print("Node %s exits the loop headed at %s\n", get_abstraction_name(read_list(CFNode*, exiting_nodes)[i]->node), get_abstraction_name(node)); + for (size_t i = 0; i < shd_list_count(exiting_nodes); i++) { + shd_debugv_print("Node %s exits the loop headed at %s\n", shd_get_abstraction_name_safe(shd_read_list(CFNode * , exiting_nodes)[i]->node), shd_get_abstraction_name_safe(node)); } - BodyBuilder* outer_bb = begin_body(arena); - - size_t exiting_nodes_count = entries_count_list(exiting_nodes); + size_t exiting_nodes_count = shd_list_count(exiting_nodes); if (exiting_nodes_count > 0) { - Nodes nparams = recreate_variables(rewriter, get_abstraction_params(node)); - Nodes inner_yield_types = strip_qualifiers(arena, get_variables_types(arena, nparams)); + Nodes nparams = shd_recreate_params(rewriter, get_abstraction_params(node)); + Node* loop_container = basic_block(arena, nparams, node->payload.basic_block.name); + BodyBuilder* outer_bb = shd_bld_begin(arena, shd_get_abstraction_mem(loop_container)); + Nodes inner_yield_types = shd_strip_qualifiers(arena, shd_get_param_types(arena, nparams)); - LARRAY(Nodes, exit_allocas, exiting_nodes_count); + LARRAY(Exit, exits, exiting_nodes_count); for (size_t i = 0; i < exiting_nodes_count; i++) { - CFNode* exiting_node = read_list(CFNode*, exiting_nodes)[i]; - Nodes exit_param_types = rewrite_nodes(rewriter, get_variables_types(ctx->rewriter.src_arena, get_abstraction_params(exiting_node->node))); - LARRAY(const Node*, allocas, exit_param_types.count); - for (size_t j = 0; j < exit_param_types.count; j++) - allocas[j] = gen_primop_e(outer_bb, alloca_op, singleton(get_unqualified_type(exit_param_types.nodes[j])), empty(arena)); - exit_allocas[i] = nodes(arena, exit_param_types.count, allocas); + CFNode* exiting_node = shd_read_list(CFNode*, exiting_nodes)[i]; + Nodes exit_param_types = shd_rewrite_nodes(rewriter, shd_get_param_types(ctx->rewriter.src_arena, get_abstraction_params(exiting_node->node))); + + ExitValue* 
exit_params = shd_arena_alloc(ctx->arena, sizeof(ExitValue) * exit_param_types.count); + for (size_t j = 0; j < exit_param_types.count; j++) { + exit_params[j].alloca = shd_bld_stack_alloc(outer_bb, shd_get_unqualified_type(exit_param_types.nodes[j])); + exit_params[j].uniform = shd_is_qualified_type_uniform(exit_param_types.nodes[j]); + } + exits[i] = (Exit) { + .params = exit_params, + .params_count = exit_param_types.count, + }; } const Node* exit_destination_alloca = NULL; if (exiting_nodes_count > 1) - exit_destination_alloca = gen_primop_e(outer_bb, alloca_op, singleton(int32_type(arena)), empty(arena)); - - Node* fn = (Node*) find_processed(rewriter, ctx->current_fn); + exit_destination_alloca = shd_bld_stack_alloc(outer_bb, shd_int32_type(arena)); - const Node* join_token_exit = var(arena, qualified_type_helper(join_point_type(arena, (JoinPointType) { - .yield_types = empty(arena) + const Node* join_token_exit = param(arena, shd_as_qualified_type(join_point_type(arena, (JoinPointType) { + .yield_types = shd_empty(arena) }), true), "jp_exit"); - const Node* join_token_continue = var(arena, qualified_type_helper(join_point_type(arena, (JoinPointType) { - .yield_types = inner_yield_types - }), true), "jp_continue"); + const Node* join_token_continue = param(arena, + shd_as_qualified_type(join_point_type(arena, (JoinPointType) { + .yield_types = inner_yield_types + }), true), "jp_continue"); - - LARRAY(const Node*, exit_wrappers, exiting_nodes_count); for (size_t i = 0; i < exiting_nodes_count; i++) { - CFNode* exiting_node = read_list(CFNode*, exiting_nodes)[i]; + CFNode* exiting_node = shd_read_list(CFNode*, exiting_nodes)[i]; assert(exiting_node->node && exiting_node->node->tag != Function_TAG); - Nodes exit_wrapper_params = recreate_variables(&ctx->rewriter, get_abstraction_params(exiting_node->node)); - BodyBuilder* exit_wrapper_bb = begin_body(arena); - - for (size_t j = 0; j < exit_allocas[i].count; j++) - gen_store(exit_wrapper_bb, 
exit_allocas[i].nodes[j], exit_wrapper_params.nodes[j]); + Nodes exit_wrapper_params = shd_recreate_params(&ctx->rewriter, get_abstraction_params(exiting_node->node)); - const Node* exit_wrapper_body = finish_body(exit_wrapper_bb, join(arena, (Join) { - .join_point = join_token_exit, - .args = empty(arena) - })); - - switch (exiting_node->node->tag) { - case BasicBlock_TAG: { - Node* pre_join_exit_bb = basic_block(arena, fn, exit_wrapper_params, format_string_arena(arena->arena, "exit_wrapper_%d", i)); - pre_join_exit_bb->payload.basic_block.body = exit_wrapper_body; - exit_wrappers[i] = pre_join_exit_bb; - break; - } - case Case_TAG: - exit_wrappers[i] = case_(arena, exit_wrapper_params, exit_wrapper_body); - break; - default: - assert(false); - } + Node* wrapper = basic_block(arena, exit_wrapper_params, shd_format_string_arena(arena->arena, "exit_wrapper_%d", i)); + exits[i].wrapper = wrapper; } - Nodes continue_wrapper_params = recreate_variables(rewriter, get_abstraction_params(node)); + Nodes continue_wrapper_params = shd_recreate_params(rewriter, get_abstraction_params(node)); + Node* continue_wrapper = basic_block(arena, continue_wrapper_params, "continue"); const Node* continue_wrapper_body = join(arena, (Join) { .join_point = join_token_continue, - .args = continue_wrapper_params + .args = continue_wrapper_params, + .mem = shd_get_abstraction_mem(continue_wrapper), }); - const Node* continue_wrapper; - switch (node->tag) { - case BasicBlock_TAG: { - Node* pre_join_continue_bb = basic_block(arena, fn, continue_wrapper_params, "continue"); - pre_join_continue_bb->payload.basic_block.body = continue_wrapper_body; - continue_wrapper = pre_join_continue_bb; - break; - } - case Case_TAG: - continue_wrapper = case_(arena, continue_wrapper_params, continue_wrapper_body); - break; - default: - assert(false); - } + shd_set_abstraction_body(continue_wrapper, continue_wrapper_body); // replace the exit nodes by the exit wrappers - LARRAY(const Node*, cached_exits, 
exiting_nodes_count); + LARRAY(const Node**, cached_exits, exiting_nodes_count); for (size_t i = 0; i < exiting_nodes_count; i++) { - CFNode* exiting_node = read_list(CFNode*, exiting_nodes)[i]; - cached_exits[i] = search_processed(rewriter, exiting_node->node); + CFNode* exiting_node = shd_read_list(CFNode*, exiting_nodes)[i]; + cached_exits[i] = shd_search_processed(rewriter, exiting_node->node); if (cached_exits[i]) - remove_dict(const Node*, rewriter->map, exiting_node->node); - register_processed(rewriter, exiting_node->node, exit_wrappers[i]); + shd_dict_remove(const Node*, rewriter->map, exiting_node->node); + shd_register_processed(rewriter, exiting_node->node, exits[i].wrapper); } // ditto for the loop entry and the continue wrapper - const Node* cached_entry = search_processed(rewriter, node); + const Node** cached_entry = shd_search_processed(rewriter, node); if (cached_entry) - remove_dict(const Node*, rewriter->map, node); - register_processed(rewriter, node, continue_wrapper); + shd_dict_remove(const Node*, rewriter->map, node); + shd_register_processed(rewriter, node, continue_wrapper); // make sure we haven't started rewriting this... 
// for (size_t i = 0; i < old_params.count; i++) { // assert(!search_processed(rewriter, old_params.nodes[i])); // } - Nodes inner_loop_params = recreate_variables(rewriter, get_abstraction_params(node)); - register_processed_list(rewriter, get_abstraction_params(node), inner_loop_params); - const Node* loop_body = recreate_node_identity(rewriter, get_abstraction_body(node)); + struct Dict* old_map = rewriter->map; + rewriter->map = shd_clone_dict(rewriter->map); + Nodes inner_loop_params = shd_recreate_params(rewriter, get_abstraction_params(node)); + shd_register_processed_list(rewriter, get_abstraction_params(node), inner_loop_params); + Node* inner_control_case = case_(arena, shd_singleton(join_token_continue)); + shd_register_processed(rewriter, shd_get_abstraction_mem(node), shd_get_abstraction_mem(inner_control_case)); + const Node* loop_body = shd_rewrite_node(rewriter, get_abstraction_body(node)); - // restore the old context + // save the context for (size_t i = 0; i < exiting_nodes_count; i++) { - remove_dict(const Node*, rewriter->map, read_list(CFNode*, exiting_nodes)[i]->node); - if (cached_exits[i]) - register_processed(rewriter, read_list(CFNode*, exiting_nodes)[i]->node, cached_exits[i]); - } - remove_dict(const Node*, rewriter->map, node); - if (cached_entry) - register_processed(rewriter, node, cached_entry); + CFNode* exiting_node = shd_read_list(CFNode*, exiting_nodes)[i]; + assert(exiting_node->node && exiting_node->node->tag != Function_TAG); + Nodes exit_wrapper_params = get_abstraction_params(exits[i].wrapper); + BodyBuilder* exit_wrapper_bb = shd_bld_begin(arena, shd_get_abstraction_mem(exits[i].wrapper)); - BodyBuilder* inner_bb = begin_body(arena); - const Node* inner_control = control (arena, (Control) { - .inside = case_(arena, singleton(join_token_continue), loop_body), - .yield_types = inner_yield_types - }); - Nodes inner_control_results = bind_instruction(inner_bb, inner_control); + for (size_t j = 0; j < exits[i].params_count; 
j++) + shd_bld_store(exit_wrapper_bb, exits[i].params[j].alloca, exit_wrapper_params.nodes[j]); + // Set the destination if there's more than one option + if (exiting_nodes_count > 1) + shd_bld_store(exit_wrapper_bb, exit_destination_alloca, shd_int32_literal(arena, i)); - Node* loop_outer = basic_block(arena, fn, inner_loop_params, "loop_outer"); + shd_set_abstraction_body(exits[i].wrapper, shd_bld_join(exit_wrapper_bb, join_token_exit, shd_empty(arena))); + } - loop_outer->payload.basic_block.body = finish_body(inner_bb, jump(arena, (Jump) { - .target = loop_outer, - .args = inner_control_results - })); - const Node* outer_control = control (arena, (Control) { - .inside = case_(arena, singleton(join_token_exit), jump(arena, (Jump) { - .target = loop_outer, - .args = nparams - })), - .yield_types = empty(arena) - }); + shd_set_abstraction_body(inner_control_case, loop_body); - bind_instruction(outer_bb, outer_control); + shd_destroy_dict(rewriter->map); + rewriter->map = old_map; + //register_processed_list(rewriter, get_abstraction_params(node), nparams); - const Node* outer_body; + // restore the old context + for (size_t i = 0; i < exiting_nodes_count; i++) { + shd_dict_remove(const Node*, rewriter->map, shd_read_list(CFNode *, exiting_nodes)[i]->node); + if (cached_exits[i]) + shd_register_processed(rewriter, shd_read_list(CFNode*, exiting_nodes)[i]->node, *cached_exits[i]); + } + shd_dict_remove(const Node*, rewriter->map, node); + if (cached_entry) + shd_register_processed(rewriter, node, *cached_entry); + + Node* loop_outer = basic_block(arena, inner_loop_params, "loop_outer"); + BodyBuilder* inner_bb = shd_bld_begin(arena, shd_get_abstraction_mem(loop_outer)); + Nodes inner_control_results = shd_bld_control(inner_bb, inner_yield_types, inner_control_case); + // make sure what was uniform still is + for (size_t j = 0; j < inner_control_results.count; j++) { + if (shd_is_qualified_type_uniform(nparams.nodes[j]->type)) + inner_control_results = 
shd_change_node_at_index(arena, inner_control_results, j, prim_op_helper(arena, subgroup_assume_uniform_op, shd_empty(arena), shd_singleton(inner_control_results.nodes[j]))); + } + shd_set_abstraction_body(loop_outer, shd_bld_jump(inner_bb, loop_outer, inner_control_results)); + Node* outer_control_case = case_(arena, shd_singleton(join_token_exit)); + shd_set_abstraction_body(outer_control_case, jump(arena, (Jump) { + .target = loop_outer, + .args = nparams, + .mem = shd_get_abstraction_mem(outer_control_case), + })); + shd_bld_control(outer_bb, shd_empty(arena), outer_control_case); LARRAY(const Node*, exit_numbers, exiting_nodes_count); LARRAY(const Node*, exit_jumps, exiting_nodes_count); for (size_t i = 0; i < exiting_nodes_count; i++) { - BodyBuilder* exit_recover_bb = begin_body(arena); - - CFNode* exiting_node = read_list(CFNode*, exiting_nodes)[i]; - const Node* recreated_exit = rewrite_node(rewriter, exiting_node->node); - - LARRAY(const Node*, recovered_args, exit_allocas[i].count); - for (size_t j = 0; j < exit_allocas[i].count; j++) - recovered_args[j] = gen_load(exit_recover_bb, exit_allocas[i].nodes[j]); - - exit_numbers[i] = int32_literal(arena, i); - Node* exit_bb = basic_block(arena, fn, empty(arena), format_string_arena(arena->arena, "exit_recover_values_%s", get_abstraction_name(exiting_node->node))); - if (recreated_exit->tag == BasicBlock_TAG) { - exit_bb->payload.basic_block.body = finish_body(exit_recover_bb, jump(arena, (Jump) { - .target = recreated_exit, - .args = nodes(arena, exit_allocas[i].count, recovered_args), - })); - } else { - assert(recreated_exit->tag == Case_TAG); - exit_bb->payload.basic_block.body = finish_body(exit_recover_bb, let(arena, quote_helper(arena, nodes(arena, exit_allocas[i].count, recovered_args)), recreated_exit)); + CFNode* exiting_node = shd_read_list(CFNode*, exiting_nodes)[i]; + + Node* exit_bb = basic_block(arena, shd_empty(arena), shd_format_string_arena(arena->arena, "exit_recover_values_%s", 
shd_get_abstraction_name_safe(exiting_node->node))); + BodyBuilder* exit_recover_bb = shd_bld_begin(arena, shd_get_abstraction_mem(exit_bb)); + + const Node* recreated_exit = shd_rewrite_node(rewriter, exiting_node->node); + + LARRAY(const Node*, recovered_args, exits[i].params_count); + for (size_t j = 0; j < exits[i].params_count; j++) { + recovered_args[j] = shd_bld_load(exit_recover_bb, exits[i].params[j].alloca); + if (exits[i].params[j].uniform) + recovered_args[j] = prim_op_helper(arena, subgroup_assume_uniform_op, shd_empty(arena), shd_singleton(recovered_args[j])); } - exit_jumps[i] = jump_helper(arena, exit_bb, empty(arena)); + + exit_numbers[i] = shd_int32_literal(arena, i); + shd_set_abstraction_body(exit_bb, shd_bld_jump(exit_recover_bb, recreated_exit, shd_nodes(arena, exits[i].params_count, recovered_args))); + exit_jumps[i] = jump_helper(arena, shd_bb_mem(outer_bb), exit_bb, shd_empty(arena)); } + const Node* outer_body; if (exiting_nodes_count == 1) - outer_body = finish_body(outer_bb, exit_jumps[0]->payload.jump.target->payload.basic_block.body); + outer_body = shd_bld_finish(outer_bb, exit_jumps[0]); else { - const Node* loaded_destination = gen_load(outer_bb, exit_destination_alloca); - outer_body = finish_body(outer_bb, br_switch(arena, (Switch) { + const Node* loaded_destination = shd_bld_load(outer_bb, exit_destination_alloca); + outer_body = shd_bld_finish(outer_bb, br_switch(arena, (Switch) { .switch_value = loaded_destination, .default_jump = exit_jumps[0], - .case_values = nodes(arena, exiting_nodes_count, exit_numbers), - .case_jumps = nodes(arena, exiting_nodes_count, exit_jumps), + .case_values = shd_nodes(arena, exiting_nodes_count, exit_numbers), + .case_jumps = shd_nodes(arena, exiting_nodes_count, exit_jumps), + .mem = shd_bb_mem(outer_bb) })); } - - const Node* loop_container; - switch (node->tag) { - case BasicBlock_TAG: { - Node* bb = basic_block(arena, fn, nparams, node->payload.basic_block.name); - bb->payload.basic_block.body 
= outer_body; - loop_container = bb; - break; - } - case Case_TAG: - loop_container = case_(arena, nparams, outer_body); - break; - default: - assert(false); - } - destroy_list(exiting_nodes); + shd_set_abstraction_body(loop_container, outer_body); + shd_destroy_list(exiting_nodes); return loop_container; } - destroy_list(exiting_nodes); + shd_destroy_list(exiting_nodes); } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } static const Node* process_node(Context* ctx, const Node* node) { assert(node); - Rewriter* rewriter = &ctx->rewriter; - IrArena* arena = rewriter->dst_arena; + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + + Context new_context = *ctx; switch (node->tag) { case Function_TAG: { - Context new_context = *ctx; ctx = &new_context; + ctx->current_fn = NULL; + if (!(shd_lookup_annotation(node, "Restructure") || ctx->config->input_cf.restructure_with_heuristics)) + break; + ctx->current_fn = node; - ctx->fwd_scope = new_scope(ctx->current_fn); - ctx->back_scope = new_scope_flipped(ctx->current_fn); - ctx->current_looptree = build_loop_tree(ctx->fwd_scope); + ctx->fwd_cfg = build_fn_cfg(ctx->current_fn); + ctx->rev_cfg = build_fn_cfg_flipped(ctx->current_fn); + ctx->current_looptree = shd_new_loop_tree(ctx->fwd_cfg); const Node* new = process_abstraction(ctx, node);; - destroy_scope(ctx->fwd_scope); - destroy_scope(ctx->back_scope); - destroy_loop_tree(ctx->current_looptree); + shd_destroy_cfg(ctx->fwd_cfg); + shd_destroy_cfg(ctx->rev_cfg); + shd_destroy_loop_tree(ctx->current_looptree); return new; } - case Case_TAG: + case Constant_TAG: { + ctx = &new_context; + ctx->current_fn = NULL; + r = &ctx->rewriter; + break; + } case BasicBlock_TAG: - if (!ctx->current_fn || !lookup_annotation(ctx->current_fn, "Restructure")) + if (!ctx->current_fn || !(shd_lookup_annotation(ctx->current_fn, "Restructure") || ctx->config->input_cf.restructure_with_heuristics)) break; return 
process_abstraction(ctx, node); case Branch_TAG: { - if (!ctx->current_fn || !lookup_annotation(ctx->current_fn, "Restructure")) + Branch payload = node->payload.branch; + if (!ctx->current_fn || !(shd_lookup_annotation(ctx->current_fn, "Restructure") || ctx->config->input_cf.restructure_with_heuristics)) break; - assert(ctx->fwd_scope); + assert(ctx->fwd_cfg); - CFNode* cfnode = scope_lookup(ctx->back_scope, ctx->current_abstraction); - const Node* idom = NULL; + CFNode* cfnode = shd_cfg_lookup(ctx->rev_cfg, ctx->current_abstraction); + const Node* post_dominator = NULL; - LTNode* current_loop = looptree_lookup(ctx->current_looptree, ctx->current_abstraction)->parent; + LTNode* current_loop = shd_loop_tree_lookup(ctx->current_looptree, ctx->current_abstraction)->parent; assert(current_loop); - if (entries_count_list(current_loop->cf_nodes)) { + if (shd_list_count(current_loop->cf_nodes)) { bool leaves_loop = false; - CFNode* current_node = scope_lookup(ctx->fwd_scope, ctx->current_abstraction); - for (size_t i = 0; i < entries_count_list(current_node->succ_edges); i++) { - CFEdge edge = read_list(CFEdge, current_node->succ_edges)[i]; - LTNode* lt_target = looptree_lookup(ctx->current_looptree, edge.dst->node); + CFNode* current_node = shd_cfg_lookup(ctx->fwd_cfg, ctx->current_abstraction); + for (size_t i = 0; i < shd_list_count(current_node->succ_edges); i++) { + CFEdge edge = shd_read_list(CFEdge, current_node->succ_edges)[i]; + LTNode* lt_target = shd_loop_tree_lookup(ctx->current_looptree, edge.dst->node); if (lt_target->parent != current_loop) { leaves_loop = true; @@ -348,141 +360,142 @@ static const Node* process_node(Context* ctx, const Node* node) { } if (!leaves_loop) { - const Node* current_loop_head = read_list(CFNode*, current_loop->cf_nodes)[0]->node; - Scope* loop_scope = new_scope_lt_flipped(current_loop_head, ctx->current_looptree); - CFNode* idom_cf = scope_lookup(loop_scope, ctx->current_abstraction)->idom; + const Node* current_loop_head = 
shd_read_list(CFNode*, current_loop->cf_nodes)[0]->node; + CFG* loop_cfg = shd_new_cfg(ctx->current_fn, current_loop_head, (CFGBuildConfig) { + .include_structured_tails = true, + .lt = ctx->current_looptree, + .flipped = true + }); + CFNode* idom_cf = shd_cfg_lookup(loop_cfg, ctx->current_abstraction)->idom; if (idom_cf) - idom = idom_cf->node; - destroy_scope(loop_scope); + post_dominator = idom_cf->node; + shd_destroy_cfg(loop_cfg); } } else { - idom = cfnode->idom->node; + post_dominator = cfnode->idom->node; } - if(!idom) { + if (!post_dominator) { break; } - if (scope_lookup(ctx->fwd_scope, idom)->idom->node!= ctx->current_abstraction) + if (shd_cfg_lookup(ctx->fwd_cfg, post_dominator)->idom->node != ctx->current_abstraction) break; - assert(is_abstraction(idom) && idom->tag != Function_TAG); + assert(is_abstraction(post_dominator) && post_dominator->tag != Function_TAG); - LTNode* lt_node = looptree_lookup(ctx->current_looptree, ctx->current_abstraction); - LTNode* idom_lt_node = looptree_lookup(ctx->current_looptree, idom); - CFNode* current_node = scope_lookup(ctx->fwd_scope, ctx->current_abstraction); + LTNode* lt_node = shd_loop_tree_lookup(ctx->current_looptree, ctx->current_abstraction); + LTNode* idom_lt_node = shd_loop_tree_lookup(ctx->current_looptree, post_dominator); + CFNode* current_node = shd_cfg_lookup(ctx->fwd_cfg, ctx->current_abstraction); assert(lt_node); assert(idom_lt_node); assert(current_node); - Node* fn = (Node*) find_processed(rewriter, ctx->current_fn); + Node* fn = (Node*) shd_find_processed(r, ctx->current_fn); //Regular if/then/else case. Control flow joins at the immediate post dominator. 
Nodes yield_types; Nodes exit_args; - Nodes lambda_args; - Nodes old_params = get_abstraction_params(idom); + Nodes old_params = get_abstraction_params(post_dominator); + LARRAY(bool, uniform_param, old_params.count); if (old_params.count == 0) { - yield_types = empty(arena); - exit_args = empty(arena); - lambda_args = empty(arena); + yield_types = shd_empty(a); + exit_args = shd_empty(a); } else { - LARRAY(const Node*, types,old_params.count); + LARRAY(const Node*, types, old_params.count); LARRAY(const Node*, inner_args,old_params.count); - LARRAY(const Node*, outer_args,old_params.count); for (size_t j = 0; j < old_params.count; j++) { //TODO: Is this correct? - assert(old_params.nodes[j]->tag == Variable_TAG); - const Node* qualified_type = rewrite_node(rewriter, old_params.nodes[j]->payload.var.type); + assert(old_params.nodes[j]->tag == Param_TAG); + const Node* qualified_type = shd_rewrite_node(r, old_params.nodes[j]->payload.param.type); //const Node* qualified_type = rewrite_node(rewriter, old_params.nodes[j]->type); //This should always contain a qualified type? 
//if (contains_qualified_type(types[j])) - types[j] = get_unqualified_type(qualified_type); - - inner_args[j] = var(arena, qualified_type, old_params.nodes[j]->payload.var.name); - outer_args[j] = var(arena, qualified_type, old_params.nodes[j]->payload.var.name); + types[j] = shd_get_unqualified_type(qualified_type); + uniform_param[j] = shd_is_qualified_type_uniform(qualified_type); + inner_args[j] = param(a, qualified_type, old_params.nodes[j]->payload.param.name); } - yield_types = nodes(arena, old_params.count, types); - exit_args = nodes(arena, old_params.count, inner_args); - lambda_args = nodes(arena, old_params.count, outer_args); + yield_types = shd_nodes(a, old_params.count, types); + exit_args = shd_nodes(a, old_params.count, inner_args); } - const Node* join_token = var(arena, qualified_type_helper(join_point_type(arena, (JoinPointType) { + const Node* join_token = param(a, shd_as_qualified_type(join_point_type(a, (JoinPointType) { .yield_types = yield_types }), true), "jp_postdom"); - Node* pre_join = basic_block(arena, fn, exit_args, format_string_arena(arena->arena, "merge_%s_%s", get_abstraction_name(ctx->current_abstraction) , get_abstraction_name(idom))); - pre_join->payload.basic_block.body = join(arena, (Join) { + Node* pre_join = basic_block(a, exit_args, shd_format_string_arena(a->arena, "merge_%s_%s", shd_get_abstraction_name_safe(ctx->current_abstraction), shd_get_abstraction_name_safe(post_dominator))); + shd_set_abstraction_body(pre_join, join(a, (Join) { .join_point = join_token, - .args = exit_args - }); + .args = exit_args, + .mem = shd_get_abstraction_mem(pre_join), + })); - const Node* cached = search_processed(rewriter, idom); + const Node** cached = shd_search_processed(r, post_dominator); if (cached) - remove_dict(const Node*, is_declaration(idom) ? rewriter->decls_map : rewriter->map, idom); + shd_dict_remove(const Node*, is_declaration(post_dominator) ? 
r->decls_map : r->map, post_dominator); for (size_t i = 0; i < old_params.count; i++) { - assert(!search_processed(rewriter, old_params.nodes[i])); + assert(!shd_search_processed(r, old_params.nodes[i])); } - register_processed(rewriter, idom, pre_join); - - const Node* inner_terminator = recreate_node_identity(rewriter, node); - - remove_dict(const Node*, is_declaration(idom) ? rewriter->decls_map : rewriter->map, idom); - if (cached) - register_processed(rewriter, idom, cached); - - const Node* control_inner = case_(arena, singleton(join_token), inner_terminator); - const Node* new_target = control(arena, (Control) { - .inside = control_inner, - .yield_types = yield_types + shd_register_processed(r, post_dominator, pre_join); + + Node* control_case = case_(a, shd_singleton(join_token)); + const Node* inner_terminator = branch(a, (Branch) { + .mem = shd_get_abstraction_mem(control_case), + .condition = shd_rewrite_node(r, payload.condition), + .true_jump = jump_helper(a, shd_get_abstraction_mem(control_case), + shd_rewrite_node(r, payload.true_jump->payload.jump.target), + shd_rewrite_nodes(r, payload.true_jump->payload.jump.args)), + .false_jump = jump_helper(a, shd_get_abstraction_mem(control_case), + shd_rewrite_node(r, payload.false_jump->payload.jump.target), + shd_rewrite_nodes(r, payload.false_jump->payload.jump.args)), }); + shd_set_abstraction_body(control_case, inner_terminator); - const Node* recreated_join = rewrite_node(rewriter, idom); - - switch (idom->tag) { - case BasicBlock_TAG: { - const Node* outer_terminator = jump(arena, (Jump) { - .target = recreated_join, - .args = lambda_args - }); + shd_dict_remove(const Node*, is_declaration(post_dominator) ? 
r->decls_map : r->map, post_dominator); + if (cached) + shd_register_processed(r, post_dominator, *cached); - const Node* c = case_(arena, lambda_args, outer_terminator); - const Node* empty_let = let(arena, new_target, c); + const Node* join_target = shd_rewrite_node(r, post_dominator); - return empty_let; - } - case Case_TAG: - return let(arena, new_target, recreated_join); - default: - assert(false); + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, node->payload.branch.mem)); + Nodes results = shd_bld_control(bb, yield_types, control_case); + // make sure what was uniform still is + for (size_t j = 0; j < old_params.count; j++) { + if (uniform_param[j]) + results = shd_change_node_at_index(a, results, j, prim_op_helper(a, subgroup_assume_uniform_op, shd_empty(a), shd_singleton(results.nodes[j]))); } + return shd_bld_jump(bb, join_target, results); } default: break; } - return recreate_node_identity(rewriter, node); + return shd_recreate_node(r, node); } -Module* reconvergence_heuristics(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_reconvergence_heuristics(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + aconfig.optimisations.inline_single_use_bbs = true; + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process_node), + .config = config, .current_fn = NULL, - .fwd_scope = NULL, - .back_scope = NULL, + .fwd_cfg = NULL, + .rev_cfg = NULL, .current_looptree = NULL, + .arena = shd_new_arena(), }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + 
shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_arena(ctx.arena); return dst; } diff --git a/src/shady/passes/remove_critical_edges.c b/src/shady/passes/remove_critical_edges.c new file mode 100644 index 000000000..199d970e8 --- /dev/null +++ b/src/shady/passes/remove_critical_edges.c @@ -0,0 +1,39 @@ +#include "shady/pass.h" +#include "shady/ir/function.h" + +#include "log.h" +#include "portability.h" + +typedef struct { + Rewriter rewriter; +} Context; + +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + switch (node->tag) { + case Jump_TAG: { + Jump payload = node->payload.jump; + Node* new_block = basic_block(a, shd_empty(a), NULL); + shd_set_abstraction_body(new_block, jump_helper(a, shd_get_abstraction_mem(new_block), + shd_rewrite_node(r, payload.target), + shd_rewrite_nodes(r, payload.args))); + return jump_helper(a, shd_rewrite_node(r, payload.mem), new_block, shd_empty(a)); + } + default: break; + } + + return shd_recreate_node(r, node); +} + +Module* shd_pass_remove_critical_edges(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/shady/passes/restructure.c b/src/shady/passes/restructure.c new file mode 100644 index 000000000..1d5fa2bf7 --- /dev/null +++ b/src/shady/passes/restructure.c @@ -0,0 +1,432 @@ +#include "shady/pass.h" +#include "shady/ir/function.h" +#include "shady/ir/annotation.h" +#include "shady/ir/mem.h" +#include "shady/ir/debug.h" + +#include +#include + +#include "dict.h" +#include "list.h" +#include "portability.h" +#include "log.h" + +#pragma GCC diagnostic error 
"-Wswitch" + +typedef struct { + const Node* old; + Node* new; +} TodoEntry; + +typedef struct ControlEntry_ ControlEntry; +struct ControlEntry_ { + ControlEntry* parent; + const Node* old_token; + const Node** phis; + int depth; +}; + +typedef struct DFSStackEntry_ DFSStackEntry; +struct DFSStackEntry_ { + DFSStackEntry* parent; + const Node* old; + + ControlEntry* containing_control; + + bool loop_header; + bool in_loop; +}; + +typedef void (*TmpAllocCleanupFn)(void*); +typedef struct { + TmpAllocCleanupFn fn; + void* payload; +} TmpAllocCleanupClosure; + +static TmpAllocCleanupClosure create_delete_dict_closure(struct Dict* d) { + return (TmpAllocCleanupClosure) { + .fn = (TmpAllocCleanupFn) shd_destroy_dict, + .payload = d, + }; +} + +static TmpAllocCleanupClosure create_cancel_body_closure(BodyBuilder* bb) { + return (TmpAllocCleanupClosure) { + .fn = (TmpAllocCleanupFn) shd_bld_cancel, + .payload = bb, + }; +} + +typedef struct { + Rewriter rewriter; + struct List* cleanup_stack; + + jmp_buf bail; + + bool lower; + Node* fn; + const Node* level_ptr; + DFSStackEntry* dfs_stack; + ControlEntry* control_stack; +} Context; + +static DFSStackEntry* encountered_before(Context* ctx, const Node* bb, size_t* path_len) { + DFSStackEntry* entry = ctx->dfs_stack; + if (path_len) *path_len = 0; + while (entry != NULL) { + if (entry->old == bb) + return entry; + entry = entry->parent; + if (path_len) (*path_len)++; + } + return entry; +} + +static const Node* make_unreachable_case(IrArena* a) { + Node* c = case_(a, shd_empty(a)); + shd_set_abstraction_body(c, unreachable(a, (Unreachable) { .mem = shd_get_abstraction_mem(c) })); + return c; +} + +static const Node* make_selection_merge_case(IrArena* a) { + Node* c = case_(a, shd_empty(a)); + shd_set_abstraction_body(c, merge_selection(a, (MergeSelection) { .args = shd_empty(a), .mem = shd_get_abstraction_mem(c) })); + return c; +} + +static const Node* structure(Context* ctx, const Node* abs, const Node* exit); + +static 
const Node* handle_bb_callsite(Context* ctx, Jump jump, const Node* mem, const Node* exit) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + + const Node* old_target = jump.target; + Nodes oargs = jump.args; + + size_t path_len; + DFSStackEntry* prior_encounter = encountered_before(ctx, old_target, &path_len); + if (prior_encounter) { + // Create path + LARRAY(const Node*, path, path_len); + DFSStackEntry* entry2 = ctx->dfs_stack->parent; + for (size_t i = 0; i < path_len; i++) { + assert(entry2); + path[path_len - 1 - i] = entry2->old; + if (entry2->in_loop) + longjmp(ctx->bail, 1); + if (entry2->containing_control != ctx->control_stack) + longjmp(ctx->bail, 1); + entry2->in_loop = true; + entry2 = entry2->parent; + } + prior_encounter->loop_header = true; + return merge_continue(a, (MergeContinue) { + .args = shd_rewrite_nodes(r, oargs), + .mem = mem, + }); + } else { + Nodes oparams = get_abstraction_params(old_target); + assert(oparams.count == oargs.count); + LARRAY(const Node*, nparams, oargs.count); + Context ctx2 = *ctx; + + // Record each step of the depth-first search on a stack so we can identify loops + DFSStackEntry dfs_entry = { .parent = ctx->dfs_stack, .old = old_target, .containing_control = ctx->control_stack }; + ctx2.dfs_stack = &dfs_entry; + + BodyBuilder* bb = shd_bld_begin(a, mem); + TmpAllocCleanupClosure cj1 = create_cancel_body_closure(bb); + shd_list_append(TmpAllocCleanupClosure, ctx->cleanup_stack, cj1); + struct Dict* tmp_processed = shd_clone_dict(ctx->rewriter.map); + TmpAllocCleanupClosure cj2 = create_delete_dict_closure(tmp_processed); + shd_list_append(TmpAllocCleanupClosure, ctx->cleanup_stack, cj2); + ctx2.rewriter.map = tmp_processed; + for (size_t i = 0; i < oargs.count; i++) { + nparams[i] = param(a, shd_rewrite_node(&ctx->rewriter, oparams.nodes[i]->type), "arg"); + shd_register_processed(&ctx2.rewriter, oparams.nodes[i], nparams[i]); + } + + // We use a basic block for the exit ladder because we don't know 
what the ladder needs to do ahead of time + Node* inner_exit_ladder_bb = basic_block(a, shd_empty(a), shd_make_unique_name(a, "exit_ladder_inline_me")); + + // Just jumps to the actual ladder + Node* structured_target = case_(a, shd_nodes(a, oargs.count, nparams)); + shd_register_processed(&ctx2.rewriter, shd_get_abstraction_mem(old_target), shd_get_abstraction_mem(structured_target)); + const Node* structured = structure(&ctx2, get_abstraction_body(old_target), inner_exit_ladder_bb); + assert(is_terminator(structured)); + shd_set_abstraction_body(structured_target, structured); + + // forget we rewrote all that + shd_destroy_dict(tmp_processed); + shd_list_pop_impl(ctx->cleanup_stack); + shd_list_pop_impl(ctx->cleanup_stack); + + if (dfs_entry.loop_header) { + // Use the structured target as the body of a loop + shd_bld_loop(bb, shd_empty(a), shd_rewrite_nodes(&ctx->rewriter, oargs), structured_target); + // The exit ladder must exit that new loop + shd_set_abstraction_body(inner_exit_ladder_bb, merge_break(a, (MergeBreak) { .args = shd_empty(a), .mem = shd_get_abstraction_mem(inner_exit_ladder_bb) })); + // After that we jump to the parent exit + return shd_bld_finish(bb, jump_helper(a, shd_bb_mem(bb), exit, shd_empty(a))); + } else { + // Simply jmp to the exit once done + shd_set_abstraction_body(inner_exit_ladder_bb, jump_helper(a, shd_get_abstraction_mem(inner_exit_ladder_bb), exit, + shd_empty(a))); + // Jump into the new structured target + return shd_bld_finish(bb, jump_helper(a, shd_bb_mem(bb), structured_target, shd_rewrite_nodes(&ctx->rewriter, oargs))); + } + } +} + +static ControlEntry* search_containing_control(Context* ctx, const Node* old_token) { + ControlEntry* entry = ctx->control_stack; + assert(entry); + while (entry != NULL) { + if (entry->old_token == old_token) + return entry; + entry = entry->parent; + } + return entry; +} + +static const Node* structure(Context* ctx, const Node* body, const Node* exit) { + Rewriter* r = &ctx->rewriter; + 
IrArena* a = r->dst_arena; + + assert(body && is_terminator(body)); + switch (is_terminator(body)) { + case NotATerminator: + case Jump_TAG: { + Jump payload = body->payload.jump; + return handle_bb_callsite(ctx, payload, shd_rewrite_node(r, payload.mem), exit); + } + // br(cond, true_bb, false_bb, args) + // becomes + // let(if(cond, _ => handle_bb_callsite[true_bb, args], _ => handle_bb_callsite[false_bb, args]), _ => unreachable) + case Branch_TAG: { + Branch payload = body->payload.branch; + const Node* condition = shd_rewrite_node(&ctx->rewriter, payload.condition); + shd_rewrite_node(r, payload.mem); + + Node* true_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(true_case, handle_bb_callsite(ctx, payload.true_jump->payload.jump, shd_get_abstraction_mem(true_case), make_selection_merge_case(a))); + + Node* false_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(false_case, handle_bb_callsite(ctx, payload.false_jump->payload.jump, shd_get_abstraction_mem(false_case), make_selection_merge_case(a))); + + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + shd_bld_if(bb, shd_empty(a), condition, true_case, false_case); + return shd_bld_finish(bb, jump_helper(a, shd_bb_mem(bb), exit, shd_empty(a))); + } + case Switch_TAG: { + Switch payload = body->payload.br_switch; + const Node* switch_value = shd_rewrite_node(r, payload.switch_value); + shd_rewrite_node(r, payload.mem); + + Node* default_case = case_(a, shd_empty(a)); + shd_set_abstraction_body(default_case, handle_bb_callsite(ctx, payload.default_jump->payload.jump, shd_get_abstraction_mem(default_case), make_selection_merge_case(a))); + + LARRAY(Node*, cases, body->payload.br_switch.case_jumps.count); + for (size_t i = 0; i < body->payload.br_switch.case_jumps.count; i++) { + cases[i] = case_(a, shd_empty(a)); + shd_set_abstraction_body(cases[i], handle_bb_callsite(ctx, payload.case_jumps.nodes[i]->payload.jump, shd_get_abstraction_mem(cases[i]), 
make_selection_merge_case(a))); + } + + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + shd_bld_match(bb, shd_empty(a), switch_value, shd_rewrite_nodes(&ctx->rewriter, body->payload.br_switch.case_values), shd_nodes(a, body->payload.br_switch.case_jumps.count, (const Node**) cases), default_case); + return shd_bld_finish(bb, jump_helper(a, shd_bb_mem(bb), exit, shd_empty(a))); + } + // let(control(body), tail) + // var phi = undef; level = N+1; structurize[body, if (level == N+1, _ => tail(load(phi))); structured_exit_terminator] + case Control_TAG: { + Control payload = body->payload.control; + const Node* old_control_case = payload.inside; + Nodes old_control_params = get_abstraction_params(old_control_case); + assert(old_control_params.count == 1); + + // Create N temporary variables to hold the join point arguments + BodyBuilder* bb_prelude = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + Nodes yield_types = shd_rewrite_nodes(&ctx->rewriter, body->payload.control.yield_types); + LARRAY(const Node*, phis, yield_types.count); + for (size_t i = 0; i < yield_types.count; i++) { + const Type* type = yield_types.nodes[i]; + assert(shd_is_data_type(type)); + phis[i] = shd_bld_local_alloc(bb_prelude, type); + } + + // Create a new context to rewrite the body with + // TODO: Bail if we try to re-enter the same control construct + Context control_ctx = *ctx; + ControlEntry control_entry = { + .parent = ctx->control_stack, + .old_token = shd_first(old_control_params), + .phis = phis, + .depth = ctx->control_stack ? 
ctx->control_stack->depth + 1 : 1, + }; + control_ctx.control_stack = &control_entry; + + // Set the depth for threads entering the control body + shd_bld_store(bb_prelude, ctx->level_ptr, shd_int32_literal(a, control_entry.depth)); + + // Start building out the tail, first it needs to dereference the phi variables to recover the arguments given to join() + Node* tail = case_(a, shd_empty(a)); + BodyBuilder* bb_tail = shd_bld_begin(a, shd_get_abstraction_mem(tail)); + LARRAY(const Node*, phi_values, yield_types.count); + for (size_t i = 0; i < yield_types.count; i++) { + phi_values[i] = shd_bld_load(bb_tail, phis[i]); + shd_register_processed(&ctx->rewriter, get_abstraction_params(get_structured_construct_tail(body)).nodes[i], phi_values[i]); + } + + // Wrap the tail in a guarded if, to handle 'far' joins + const Node* level_value = shd_bld_load(bb_tail, ctx->level_ptr); + const Node* guard = prim_op(a, (PrimOp) { .op = eq_op, .operands = mk_nodes(a, level_value, shd_int32_literal(a, ctx->control_stack ? 
ctx->control_stack->depth : 0)) }); + Node* true_case = case_(a, shd_empty(a)); + shd_register_processed(r, shd_get_abstraction_mem(get_structured_construct_tail(body)), shd_get_abstraction_mem(true_case)); + shd_set_abstraction_body(true_case, structure(ctx, get_abstraction_body(get_structured_construct_tail(body)), make_selection_merge_case(a))); + shd_bld_if(bb_tail, shd_empty(a), guard, true_case, NULL); + shd_set_abstraction_body(tail, shd_bld_finish(bb_tail, jump_helper(a, shd_bb_mem(bb_tail), exit, shd_empty(a)))); + + shd_register_processed(r, shd_get_abstraction_mem(old_control_case), shd_bb_mem(bb_prelude)); + return shd_bld_finish(bb_prelude, structure(&control_ctx, get_abstraction_body(old_control_case), tail)); + } + case Join_TAG: { + Join payload = body->payload.join; + ControlEntry* control = search_containing_control(ctx, body->payload.join.join_point); + if (!control) + longjmp(ctx->bail, 1); + + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); + shd_bld_store(bb, ctx->level_ptr, shd_int32_literal(a, control->depth - 1)); + + Nodes args = shd_rewrite_nodes(&ctx->rewriter, body->payload.join.args); + for (size_t i = 0; i < args.count; i++) { + shd_bld_store(bb, control->phis[i], args.nodes[i]); + } + + return shd_bld_finish(bb, jump_helper(a, shd_bb_mem(bb), exit, shd_empty(a))); + } + + case Return_TAG: + case Unreachable_TAG: return shd_recreate_node(&ctx->rewriter, body); + + case TailCall_TAG: longjmp(ctx->bail, 1); + + case If_TAG: + case Match_TAG: + case Loop_TAG: shd_error("not supposed to exist in IR at this stage"); + case Terminator_MergeBreak_TAG: + case Terminator_MergeContinue_TAG: + case Terminator_MergeSelection_TAG: shd_error("Only control nodes are tolerated here.") + } +} + +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + assert(a != node->arena); + assert(node->arena == ctx->rewriter.src_arena); + + if (is_declaration(node)) { + 
String name = get_declaration_name(node); + Nodes decls = shd_module_get_declarations(ctx->rewriter.dst_module); + for (size_t i = 0; i < decls.count; i++) { + if (strcmp(get_declaration_name(decls.nodes[i]), name) == 0) + return decls.nodes[i]; + } + } + + if (node->tag == Function_TAG) { + Node* new = shd_recreate_node_head(&ctx->rewriter, node); + + size_t alloc_stack_size_now = shd_list_count(ctx->cleanup_stack); + + Context ctx2 = *ctx; + ctx2.dfs_stack = NULL; + ctx2.control_stack = NULL; + bool is_builtin = shd_lookup_annotation(node, "Builtin"); + bool is_leaf = false; + if (is_builtin || !node->payload.fun.body || shd_lookup_annotation(node, "Structured") || setjmp(ctx2.bail)) { + ctx2.lower = false; + ctx2.rewriter.map = ctx->rewriter.map; + if (node->payload.fun.body) + shd_set_abstraction_body(new, shd_rewrite_node(&ctx2.rewriter, node->payload.fun.body)); + // builtin functions are always considered leaf functions + is_leaf = is_builtin || !node->payload.fun.body; + } else { + ctx2.lower = true; + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new)); + TmpAllocCleanupClosure cj1 = create_cancel_body_closure(bb); + shd_list_append(TmpAllocCleanupClosure, ctx->cleanup_stack, cj1); + const Node* ptr = shd_bld_local_alloc(bb, shd_int32_type(a)); + shd_set_value_name(ptr, "cf_depth"); + shd_bld_store(bb, ptr, shd_int32_literal(a, 0)); + ctx2.level_ptr = ptr; + ctx2.fn = new; + struct Dict* tmp_processed = shd_clone_dict(ctx->rewriter.map); + TmpAllocCleanupClosure cj2 = create_delete_dict_closure(tmp_processed); + shd_list_append(TmpAllocCleanupClosure, ctx->cleanup_stack, cj2); + ctx2.rewriter.map = tmp_processed; + shd_register_processed(&ctx2.rewriter, shd_get_abstraction_mem(node), shd_bb_mem(bb)); + shd_set_abstraction_body(new, shd_bld_finish(bb, structure(&ctx2, get_abstraction_body(node), make_unreachable_case(a)))); + is_leaf = true; + // We made it! Pop off the pending cleanup stuff and do it ourselves. 
+ shd_list_pop_impl(ctx->cleanup_stack); + shd_list_pop_impl(ctx->cleanup_stack); + shd_destroy_dict(tmp_processed); + } + + //if (is_leaf) + // new->payload.fun.annotations = append_nodes(arena, new->payload.fun.annotations, annotation(arena, (Annotation) { .name = "Leaf" })); + + // if we did a longjmp, we might have orphaned a few of those + while (alloc_stack_size_now < shd_list_count(ctx->cleanup_stack)) { + TmpAllocCleanupClosure cj = shd_list_pop(TmpAllocCleanupClosure, ctx->cleanup_stack); + cj.fn(cj.payload); + } + + new->payload.fun.annotations = shd_filter_out_annotation(a, new->payload.fun.annotations, "MaybeLeaf"); + + return new; + } + + if (!ctx->lower) + return shd_recreate_node(&ctx->rewriter, node); + + // These should all be manually visited by 'structure' + // assert(!is_terminator(node) && !is_instruction(node)); + + switch (node->tag) { + case Instruction_Call_TAG: { + const Node* callee = node->payload.call.callee; + if (callee->tag == FnAddr_TAG) { + const Node* fn = shd_rewrite_node(&ctx->rewriter, callee->payload.fn_addr.fn); + // leave leaf calls alone + if (shd_lookup_annotation(fn, "Leaf")) { + break; + } + } + // if we don't manage that, give up :( + assert(false); // actually that should not come up. 
+ longjmp(ctx->bail, 1); + } + case BasicBlock_TAG: shd_error("All basic blocks should be processed explicitly") + default: break; + } + return shd_recreate_node(&ctx->rewriter, node); +} + +Module* shd_pass_restructurize(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .cleanup_stack = shd_new_list(TmpAllocCleanupClosure), + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + shd_destroy_list(ctx.cleanup_stack); + return dst; +} diff --git a/src/shady/passes/scope2control.c b/src/shady/passes/scope2control.c new file mode 100644 index 000000000..b6b9a010c --- /dev/null +++ b/src/shady/passes/scope2control.c @@ -0,0 +1,335 @@ +#include "shady/pass.h" + +#include "shady/rewrite.h" + +#include "../ir_private.h" +#include "../analysis/cfg.h" +#include "../analysis/scheduler.h" + +#include "portability.h" +#include "dict.h" +#include "list.h" +#include "log.h" +#include "arena.h" +#include "util.h" + +#include + +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +typedef struct { + Rewriter rewriter; + const CompilerConfig* config; + Arena* arena; + struct Dict* controls; + struct Dict* jump2wrapper; +} Context; + +typedef struct { + const Node* wrapper; +} Wrapped; + +typedef struct { + Node* wrapper; + const Node* token; + const Node* destination; +} AddControl; + +typedef struct { + struct Dict* control_destinations; +} Controls; + +static Nodes remake_params(Context* ctx, Nodes old) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + LARRAY(const Node*, nvars, old.count); + for (size_t i = 0; i < old.count; i++) { + const Node* node = old.nodes[i]; + const Type* t = NULL; + if (node->payload.param.type) { + if 
(node->payload.param.type->tag == QualifiedType_TAG) + t = shd_rewrite_node(r, node->payload.param.type); + else + t = shd_as_qualified_type(shd_rewrite_node(r, node->payload.param.type), false); + } + nvars[i] = param(a, t, node->payload.param.name); + assert(nvars[i]->tag == Param_TAG); + } + return shd_nodes(a, old.count, nvars); +} + +static Controls* get_or_create_controls(Context* ctx, const Node* fn_or_bb) { + Controls** found = shd_dict_find_value(const Node, Controls*, ctx->controls, fn_or_bb); + if (found) + return *found; + IrArena* a = ctx->rewriter.dst_arena; + Controls* controls = shd_arena_alloc(ctx->arena, sizeof(Controls)); + *controls = (Controls) { + .control_destinations = shd_new_dict(const Node*, AddControl, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + shd_dict_insert(const Node*, Controls*, ctx->controls, fn_or_bb, controls); + return controls; +} + +static void wrap_in_controls(Context* ctx, CFG* cfg, Node* nabs, const Node* oabs) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + const Node* obody = get_abstraction_body(oabs); + if (!obody) + return; + + CFNode* n = shd_cfg_lookup(cfg, oabs); + size_t num_dom = shd_list_count(n->dominates); + LARRAY(Node*, nbbs, num_dom); + for (size_t i = 0; i < num_dom; i++) { + CFNode* dominated = shd_read_list(CFNode*, n->dominates)[i]; + const Node* obb = dominated->node; + assert(obb->tag == BasicBlock_TAG); + Nodes nparams = remake_params(ctx, get_abstraction_params(obb)); + shd_register_processed_list(r, get_abstraction_params(obb), nparams); + nbbs[i] = basic_block(a, nparams, shd_get_abstraction_name_unsafe(obb)); + shd_register_processed(r, obb, nbbs[i]); + } + + // We introduce a dummy case now because we don't know yet whether the body of the abstraction will be wrapped + Node* c = case_(a, shd_empty(a)); + Node* oc = c; + shd_register_processed(r, shd_get_abstraction_mem(oabs), shd_get_abstraction_mem(c)); + + for (size_t k = 0; k < num_dom; k++) { + CFNode* 
dominated = shd_read_list(CFNode*, n->dominates)[k]; + const Node* obb = dominated->node; + wrap_in_controls(ctx, cfg, nbbs[k], obb); + } + + Controls* controls = get_or_create_controls(ctx, oabs); + + shd_set_abstraction_body(oc, shd_rewrite_node(r, obody)); + + size_t i = 0; + AddControl add_control; + while(shd_dict_iter(controls->control_destinations, &i, NULL, &add_control)) { + const Node* dst = add_control.destination; + Node* control_case = case_(a, shd_singleton(add_control.token)); + shd_set_abstraction_body(control_case, jump_helper(a, shd_get_abstraction_mem(control_case), c, shd_empty(a))); + + Node* c2 = case_(a, shd_empty(a)); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(c2)); + const Type* jp_type = add_control.token->type; + shd_deconstruct_qualified_type(&jp_type); + assert(jp_type->tag == JoinPointType_TAG); + Nodes results = shd_bld_control(bb, jp_type->payload.join_point_type.yield_types, control_case); + + Nodes original_params = get_abstraction_params(dst); + for (size_t j = 0; j < results.count; j++) { + if (shd_is_qualified_type_uniform(original_params.nodes[j]->type)) + results = shd_change_node_at_index(a, results, j, prim_op_helper(a, subgroup_assume_uniform_op, shd_empty(a), shd_singleton(results.nodes[j]))); + } + + c = c2; + shd_set_abstraction_body(c2, shd_bld_finish(bb, jump_helper(a, shd_bb_mem(bb), shd_find_processed(r, dst), results))); + } + + const Node* body = jump_helper(a, shd_get_abstraction_mem(nabs), c, shd_empty(a)); + shd_set_abstraction_body(nabs, body); +} + +static bool lexical_scope_is_nested(Nodes scope, Nodes parentMaybe) { + if (scope.count <= parentMaybe.count) + return false; + for (size_t i = 0; i < parentMaybe.count; i++) { + if (scope.nodes[i] != parentMaybe.nodes[i]) + return false; + } + return true; +} + +static const Nodes* find_scope_info(const Node* abs) { + assert(is_abstraction(abs)); + const Node* terminator = get_abstraction_body(abs); + const Node* mem = 
get_terminator_mem(terminator); + Nodes* info = NULL; + while (mem) { + if (mem->tag == ExtInstr_TAG && strcmp(mem->payload.ext_instr.set, "shady.scope") == 0) { + if (!info || info->count > mem->payload.ext_instr.operands.count) + info = &mem->payload.ext_instr.operands; + } + mem = shd_get_parent_mem(mem); + } + return info; +} + +bool shd_compare_nodes(Nodes* a, Nodes* b); + +static void process_edge(Context* ctx, CFG* cfg, Scheduler* scheduler, CFEdge edge) { + assert(edge.type == JumpEdge && edge.jump); + const Node* src = edge.src->node; + const Node* dst = edge.dst->node; + + Rewriter* r = &ctx->rewriter; + IrArena* a = ctx->rewriter.dst_arena; + // if (!ctx->config->hacks.recover_structure) + // break; + const Nodes* src_lexical_scope = find_scope_info(src); + const Nodes* dst_lexical_scope = find_scope_info(dst); + if (!src_lexical_scope) { + shd_warn_print("Failed to find jump source node "); + shd_log_node(WARN, src); + shd_warn_print(" in lexical_scopes map. Is debug information enabled ?\n"); + } else if (!dst_lexical_scope) { + shd_warn_print("Failed to find jump target node "); + shd_log_node(WARN, dst); + shd_warn_print(" in lexical_scopes map. 
Is debug information enabled ?\n"); + } else if (lexical_scope_is_nested(*src_lexical_scope, *dst_lexical_scope)) { + shd_debug_print("Jump from %s to %s exits one or more nested lexical scopes, it might reconverge.\n", shd_get_abstraction_name_safe(src), shd_get_abstraction_name_safe(dst)); + + CFNode* src_cfnode = shd_cfg_lookup(cfg, src); + assert(src_cfnode->node); + CFNode* dst_cfnode = shd_cfg_lookup(cfg, dst); + assert(src_cfnode && dst_cfnode); + + // if(!cfg_is_dominated(dst_cfnode, src_cfnode)) + // return; + + CFNode* dom = src_cfnode->idom; + while (dom) { + shd_debug_print("Considering %s as a location for control\n", shd_get_abstraction_name_safe(dom->node)); + Nodes* dom_lexical_scope = find_scope_info(dom->node); + if (!dom_lexical_scope) { + shd_warn_print("Basic block %s did not have an entry in the lexical_scopes map. Is debug information enabled ?\n", shd_get_abstraction_name_safe(dom->node)); + dom = dom->idom; + continue; + } else if (lexical_scope_is_nested(*dst_lexical_scope, *dom_lexical_scope)) { + shd_error_print("We went up too far: %s is a parent of the jump destination scope.\n", shd_get_abstraction_name_safe(dom->node)); + } else if (shd_compare_nodes(dom_lexical_scope, dst_lexical_scope)) { + // if (cfg_is_dominated(target_cfnode, dom)) { + if (!shd_cfg_is_dominated(dom, dst_cfnode) && dst_cfnode != dom) { + // assert(false); + } + + shd_debug_print("We need to introduce a control block at %s, pointing at %s\n.", shd_get_abstraction_name_safe(dom->node), shd_get_abstraction_name_safe(dst)); + + Controls* controls = get_or_create_controls(ctx, dom->node); + AddControl* found = shd_dict_find_value(const Node, AddControl, controls->control_destinations, dst); + Wrapped wrapped; + if (found) { + wrapped.wrapper = found->wrapper; + } else { + Nodes wrapper_params = remake_params(ctx, get_abstraction_params(dst)); + Nodes join_args = wrapper_params; + Nodes yield_types = shd_rewrite_nodes(r, shd_strip_qualifiers(a, shd_get_param_types(a, 
get_abstraction_params(dst)))); + + const Type* jp_type = join_point_type(a, (JoinPointType) { + .yield_types = yield_types + }); + const Node* join_token = param(a, shd_as_qualified_type(jp_type, false), shd_get_abstraction_name_unsafe(dst)); + + Node* wrapper = basic_block(a, wrapper_params, shd_format_string_arena(a->arena, "wrapper_to_%s", shd_get_abstraction_name_safe(dst))); + wrapper->payload.basic_block.body = join(a, (Join) { + .args = join_args, + .join_point = join_token, + .mem = shd_get_abstraction_mem(wrapper), + }); + + AddControl add_control = { + .destination = dst, + .token = join_token, + .wrapper = wrapper, + }; + wrapped.wrapper = wrapper; + shd_dict_insert(const Node*, AddControl, controls->control_destinations, dst, add_control); + } + + shd_dict_insert(const Node*, Wrapped, ctx->jump2wrapper, edge.jump, wrapped); + // return jump_helper(a, wrapper, rewrite_nodes(&ctx->rewriter, node->payload.jump.args), rewrite_node(r, node->payload.jump.mem)); + } else { + dom = dom->idom; + continue; + } + break; + } + } +} + +static void prepare_function(Context* ctx, CFG* cfg, const Node* old_fn) { + Scheduler* scheduler = shd_new_scheduler(cfg); + for (size_t i = 0; i < cfg->size; i++) { + CFNode* n = cfg->rpo[i]; + for (size_t j = 0; j < shd_list_count(n->succ_edges); j++) { + process_edge(ctx, cfg, scheduler, shd_read_list(CFEdge, n->succ_edges)[j]); + } + } + shd_destroy_scheduler(scheduler); +} + +static const Node* process_node(Context* ctx, const Node* node) { + IrArena* a = ctx->rewriter.dst_arena; + Rewriter* r = &ctx->rewriter; + switch (node->tag) { + case Function_TAG: { + CFG* cfg = build_fn_cfg(node); + prepare_function(ctx, cfg, node); + Node* decl = shd_recreate_node_head(r, node); + wrap_in_controls(ctx, cfg, decl, node); + shd_destroy_cfg(cfg); + return decl; + } + case BasicBlock_TAG: { + assert(false); + } + // Eliminate now-useless scope instructions + case ExtInstr_TAG: { + if (strcmp(node->payload.ext_instr.set, "shady.scope") == 
0) { + return shd_rewrite_node(r, node->payload.ext_instr.mem); + } + break; + } + case Jump_TAG: { + Wrapped* found = shd_dict_find_value(const Node*, Wrapped, ctx->jump2wrapper, node); + if (found) + return jump_helper(a, shd_rewrite_node(r, node->payload.jump.mem), found->wrapper, + shd_rewrite_nodes(r, node->payload.jump.args)); + break; + } + default: break; + } + + return shd_recreate_node(&ctx->rewriter, node); +} + +Module* shd_pass_scope2control(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + aconfig.optimisations.inline_single_use_bbs = true; + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process_node), + .config = config, + .arena = shd_new_arena(), + .controls = shd_new_dict(const Node*, Controls*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .jump2wrapper = shd_new_dict(const Node*, Wrapped, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + }; + + ctx.rewriter.rewrite_fn = (RewriteNodeFn) process_node; + + shd_rewrite_module(&ctx.rewriter); + + size_t i = 0; + Controls* controls; + while (shd_dict_iter(ctx.controls, &i, NULL, &controls)) { + //size_t j = 0; + //AddControl add_control; + // while (dict_iter(controls.control_destinations, &j, NULL, &add_control)) { + // destroy_list(add_control.lift); + // } + shd_destroy_dict(controls->control_destinations); + } + + shd_destroy_dict(ctx.controls); + shd_destroy_dict(ctx.jump2wrapper); + shd_destroy_arena(ctx.arena); + shd_destroy_rewriter(&ctx.rewriter); + + return dst; +} diff --git a/src/shady/passes/scope_heuristic.c b/src/shady/passes/scope_heuristic.c new file mode 100644 index 000000000..e833e1abf --- /dev/null +++ b/src/shady/passes/scope_heuristic.c @@ -0,0 +1,211 @@ +#include "shady/pass.h" +#include "shady/ir/ext.h" + +#include "../ir_private.h" +#include 
"../analysis/cfg.h" +#include "../analysis/looptree.h" + +#include "log.h" +#include "list.h" +#include "portability.h" + +typedef struct { + Rewriter rewriter; + CFG* cfg; + Nodes* depth_per_rpo; +} Context; + +static Nodes to_ids(IrArena* a, Nodes in) { + LARRAY(const Node*, arr, in.count); + for (size_t i = 0; i < in.count; i++) + arr[i] = shd_uint32_literal(a, in.nodes[i]->id); + return shd_nodes(a, in.count, arr); +} + +static void visit_looptree_prepend(IrArena* a, Nodes* arr, LTNode* node, Nodes prefix) { + if (node->type == LF_HEAD) { + for (size_t i = 0; i < shd_list_count(node->lf_children); i++) { + LTNode* n = shd_read_list(LTNode*, node->lf_children)[i]; + visit_looptree_prepend(a, arr, n, prefix); + } + } else { + for (size_t i = 0; i < shd_list_count(node->cf_nodes); i++) { + CFNode* n = shd_read_list(CFNode*, node->cf_nodes)[i]; + arr[n->rpo_index] = shd_concat_nodes(a, prefix, arr[n->rpo_index]); + } + assert(node->lf_children); + } +} + +static bool is_nested(LTNode* a, LTNode* in) { + assert(a->type == LF_HEAD && in->type == LF_HEAD); + while (a) { + if (a == in) + return true; + a = a->parent; + } + return false; +} + +static void paint_dominated_up_to_postdom(CFNode* n, IrArena* a, Nodes* arr, const Node* postdom, const Node* prefix) { + if (n->node == postdom) + return; + + for (size_t i = 0; i < shd_list_count(n->dominates); i++) { + CFNode* dominated = shd_read_list(CFNode*, n->dominates)[i]; + paint_dominated_up_to_postdom(dominated, a, arr, postdom, prefix); + } + + arr[n->rpo_index] = shd_nodes_prepend(a, arr[n->rpo_index], prefix); +} + +static void visit_acyclic_cfg_domtree(CFNode* n, IrArena* a, Nodes* arr, CFG* flipped, LTNode* loop, LoopTree* lt) { + LTNode* ltn = shd_loop_tree_lookup(lt, n->node); + if (ltn->parent != loop) + return; + + for (size_t i = 0; i < shd_list_count(n->dominates); i++) { + CFNode* dominated = shd_read_list(CFNode*, n->dominates)[i]; + visit_acyclic_cfg_domtree(dominated, a, arr, flipped, loop, lt); + } + + 
CFNode* src = n; + + if (shd_list_count(src->succ_edges) < 2) + return; // no divergence, no bother + + CFNode* f_src = shd_cfg_lookup(flipped, src->node); + CFNode* f_src_ipostdom = f_src->idom; + if (!f_src_ipostdom) + return; + + // your post-dominator can't be yourself... can it ? + assert(f_src_ipostdom->node != src->node); + + LTNode* src_lt = shd_loop_tree_lookup(lt, src->node); + LTNode* pst_lt = shd_loop_tree_lookup(lt, f_src_ipostdom->node); + assert(src_lt->type == LF_LEAF && pst_lt->type == LF_LEAF); + if (src_lt->parent == pst_lt->parent) { + shd_log_fmt(DEBUGVV, "We have a candidate for reconvergence: a branch starts at %d and ends at %d\n", src->node->id, f_src_ipostdom->node->id); + paint_dominated_up_to_postdom(n, a, arr, f_src_ipostdom->node, n->node); + } +} + +static void visit_looptree(IrArena* a, Nodes* arr, const Node* fn, CFG* flipped, LoopTree* lt, LTNode* node) { + if (node->type == LF_HEAD) { + Nodes surrounding = shd_empty(a); + bool is_loop = false; + for (size_t i = 0; i < shd_list_count(node->cf_nodes); i++) { + CFNode* n = shd_read_list(CFNode*, node->cf_nodes)[i]; + surrounding = shd_nodes_append(a, surrounding, n->node); + is_loop = true; + } + + for (size_t i = 0; i < shd_list_count(node->lf_children); i++) { + LTNode* n = shd_read_list(LTNode*, node->lf_children)[i]; + visit_looptree(a, arr, fn, flipped, lt, n); + } + + assert(shd_list_count(node->cf_nodes) < 2); + CFG* sub_cfg = shd_new_cfg(fn, is_loop ? 
shd_read_list(CFNode*, node->cf_nodes)[0]->node : fn, (CFGBuildConfig) { + .include_structured_tails = true, + .lt = lt + }); + + visit_acyclic_cfg_domtree(sub_cfg->entry, a, arr, flipped, node, lt); + + if (is_loop > 0) + surrounding = shd_nodes_prepend(a, surrounding, string_lit_helper(a, shd_make_unique_name(a, "loop_body"))); + + visit_looptree_prepend(a, arr, node, surrounding); + // Remove one level of scoping for the loop headers (forcing reconvergence) + for (size_t i = 0; i < shd_list_count(node->cf_nodes); i++) { + CFNode* n = shd_read_list(CFNode*, node->cf_nodes)[i]; + Nodes old = arr[n->rpo_index]; + assert(old.count > 1); + arr[n->rpo_index] = shd_nodes(a, old.count - 1, &old.nodes[0]); + } + + shd_destroy_cfg(sub_cfg); + } +} + +static bool loop_depth(LTNode* a) { + int i = 0; + while (a) { + if (shd_list_count(a->cf_nodes) > 0) + i++; + else { + assert(!a->parent); + } + a = a->parent; + } + return i; +} + +static Nodes* compute_scope_depth(IrArena* a, CFG* cfg) { + CFG* flipped = build_fn_cfg_flipped(cfg->entry->node); + LoopTree* lt = shd_new_loop_tree(cfg); + + Nodes* arr = calloc(sizeof(Nodes), cfg->size); + for (size_t i = 0; i < cfg->size; i++) + arr[i] = shd_empty(a); + + visit_looptree(a, arr, cfg->entry->node, flipped, lt, lt->root); + + // we don't want to cause problems by holding onto pointless references... 
+ for (size_t i = 0; i < cfg->size; i++) + arr[i] = to_ids(a, arr[i]); + + shd_destroy_loop_tree(lt); + shd_destroy_cfg(flipped); + + return arr; +} + +static const Node* process(Context* ctx, const Node* node) { + Rewriter* r = &ctx->rewriter; + IrArena* a = r->dst_arena; + switch (node->tag) { + case Function_TAG: { + Context fn_ctx = *ctx; + fn_ctx.cfg = build_fn_cfg(node); + fn_ctx.depth_per_rpo = compute_scope_depth(a, fn_ctx.cfg); + Node* new_fn = shd_recreate_node_head(r, node); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new_fn)); + shd_bld_ext_instruction(bb, "shady.scope", 0, unit_type(a), shd_empty(a)); + shd_register_processed(r, shd_get_abstraction_mem(node), shd_bb_mem(bb)); + shd_set_abstraction_body(new_fn, shd_bld_finish(bb, shd_rewrite_node(&fn_ctx.rewriter, get_abstraction_body(node)))); + shd_destroy_cfg(fn_ctx.cfg); + free(fn_ctx.depth_per_rpo); + return new_fn; + } + case BasicBlock_TAG: { + Nodes nparams = shd_recreate_params(r, get_abstraction_params(node)); + shd_register_processed_list(r, get_abstraction_params(node), nparams); + Node* new_bb = basic_block(a, nparams, shd_get_abstraction_name_unsafe(node)); + shd_register_processed(r, node, new_bb); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(new_bb)); + CFNode* n = shd_cfg_lookup(ctx->cfg, node); + shd_bld_ext_instruction(bb, "shady.scope", 0, unit_type(a), ctx->depth_per_rpo[n->rpo_index]); + shd_register_processed(r, shd_get_abstraction_mem(node), shd_bb_mem(bb)); + shd_set_abstraction_body(new_bb, shd_bld_finish(bb, shd_rewrite_node(r, get_abstraction_body(node)))); + return new_bb; + } + default: break; + } + + return shd_recreate_node(&ctx->rewriter, node); +} + +Module* shd_pass_scope_heuristic(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + Context ctx = { + 
.rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + }; + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); + return dst; +} diff --git a/src/shady/passes/setup_stack_frames.c b/src/shady/passes/setup_stack_frames.c index a8363f92f..b2dae8914 100644 --- a/src/shady/passes/setup_stack_frames.c +++ b/src/shady/passes/setup_stack_frames.c @@ -1,15 +1,13 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/visit.h" +#include "shady/ir/stack.h" +#include "shady/ir/annotation.h" +#include "shady/ir/function.h" +#include "shady/ir/debug.h" #include "log.h" #include "portability.h" #include "list.h" -#include "util.h" - -#include "../rewrite.h" -#include "../visit.h" -#include "../type.h" -#include "../ir_private.h" -#include "../transform/ir_gen_helpers.h" #include @@ -18,122 +16,57 @@ typedef struct Context_ { bool disable_lowering; const CompilerConfig* config; - const Node* entry_base_stack_ptr; - const Node* entry_stack_offset; + const Node* stack_size_on_entry; } Context; -typedef struct { - Visitor visitor; - Context* context; - BodyBuilder* bb; - Node* nom_t; - struct List* members; -} VContext; - -static void search_operand_for_alloca(VContext* vctx, const Node* node) { - IrArena* a = vctx->context->rewriter.dst_arena; - AddressSpace as; - - if (node->tag == PrimOp_TAG) { - switch (node->payload.prim_op.op) { - case alloca_op: as = AsPrivatePhysical; break; - case alloca_subgroup_op: as = AsSubgroupPhysical; break; - default: goto not_alloca; - } - - const Type* element_type = rewrite_node(&vctx->context->rewriter, node->payload.prim_op.type_arguments.nodes[0]); - assert(is_data_type(element_type)); - const Node* slot_offset = gen_primop_e(vctx->bb, offset_of_op, singleton(type_decl_ref_helper(a, vctx->nom_t)), singleton(int32_literal(a, entries_count_list(vctx->members)))); - append_list(const Type*, vctx->members, element_type); - - const Node* slot = first(bind_instruction_named(vctx->bb, prim_op(a, 
(PrimOp) { - .op = lea_op, - .operands = mk_nodes(a, vctx->context->entry_base_stack_ptr, slot_offset) }), (String []) {format_string_arena(a->arena, "stack_slot_%d", entries_count_list(vctx->members)) })); - - const Node* ptr_t = ptr_type(a, (PtrType) { .pointed_type = element_type, .address_space = as }); - slot = gen_reinterpret_cast(vctx->bb, ptr_t, slot); - - register_processed(&vctx->context->rewriter, node, quote_helper(a, singleton(slot))); - return; - } - - not_alloca: - visit_node_operands(&vctx->visitor, IGNORE_ABSTRACTIONS_MASK, node); -} - static const Node* process(Context* ctx, const Node* node) { - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - IrArena* a = ctx->rewriter.dst_arena; - Module* m = ctx->rewriter.dst_module; + Rewriter* r = &ctx->rewriter; switch (node->tag) { case Function_TAG: { - Node* fun = recreate_decl_header_identity(&ctx->rewriter, node); + Node* fun = shd_recreate_node_head(r, node); Context ctx2 = *ctx; - ctx2.disable_lowering = lookup_annotation_with_string_payload(node, "DisablePass", "setup_stack_frames"); + ctx2.disable_lowering = shd_lookup_annotation_with_string_payload(node, "DisablePass", "setup_stack_frames") || ctx->config->per_thread_stack_size == 0; - BodyBuilder* bb = begin_body(a); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); if (!ctx2.disable_lowering) { - Node* nom_t = nominal_type(m, empty(a), format_string_arena(a->arena, "%s_stack_frame", get_abstraction_name(node))); - ctx2.entry_stack_offset = first(bind_instruction_named(bb, prim_op(a, (PrimOp) { .op = get_stack_pointer_op } ), (String []) {format_string_arena(a->arena, "saved_stack_ptr_entering_%s", get_abstraction_name(fun)) })); - ctx2.entry_base_stack_ptr = gen_primop_ce(bb, get_stack_base_op, 0, NULL); - VContext vctx = { - .visitor = { - .visit_node_fn = (VisitNodeFn) search_operand_for_alloca, - }, - .context = &ctx2, - .bb = bb, - .nom_t = nom_t, - .members = new_list(const Node*), - 
}; - if (node->payload.fun.body) { - search_operand_for_alloca(&vctx, node->payload.fun.body); - visit_function_rpo(&vctx.visitor, node); - } - vctx.nom_t->payload.nom_type.body = record_type(a, (RecordType) { - .members = nodes(a, entries_count_list(vctx.members), read_list(const Node*, vctx.members)), - .names = strings(a, 0, NULL), - .special = 0 - }); - destroy_list(vctx.members); - - const Node* frame_size = gen_primop_e(bb, size_of_op, singleton(type_decl_ref_helper(a, nom_t)), empty(a)); - frame_size = convert_int_extend_according_to_src_t(bb, get_unqualified_type(ctx2.entry_stack_offset->type), frame_size); - const Node* updated_stack_ptr = gen_primop_e(bb, add_op, empty(a), mk_nodes(a, ctx2.entry_stack_offset, frame_size)); - gen_primop(bb, set_stack_pointer_op, empty(a), singleton(updated_stack_ptr)); + ctx2.stack_size_on_entry = shd_bld_get_stack_size(bb); + shd_set_value_name((Node*) ctx2.stack_size_on_entry, shd_fmt_string_irarena(a, "saved_stack_ptr_entering_%s", shd_get_abstraction_name(fun))); } + shd_register_processed(&ctx2.rewriter, shd_get_abstraction_mem(node), shd_bb_mem(bb)); if (node->payload.fun.body) - fun->payload.fun.body = finish_body(bb, rewrite_node(&ctx2.rewriter, node->payload.fun.body)); + shd_set_abstraction_body(fun, shd_bld_finish(bb, shd_rewrite_node(&ctx2.rewriter, node->payload.fun.body))); else - cancel_body(bb); + shd_bld_cancel(bb); return fun; } case Return_TAG: { - BodyBuilder* bb = begin_body(a); + Return payload = node->payload.fn_ret; + BodyBuilder* bb = shd_bld_begin(a, shd_rewrite_node(r, payload.mem)); if (!ctx->disable_lowering) { - assert(ctx->entry_stack_offset); + assert(ctx->stack_size_on_entry); // Restore SP before calling exit - bind_instruction(bb, prim_op(a, (PrimOp) { - .op = set_stack_pointer_op, - .operands = nodes(a, 1, (const Node* []) {ctx->entry_stack_offset }) - })); + shd_bld_set_stack_size(bb, ctx->stack_size_on_entry); } - return finish_body(bb, recreate_node_identity(&ctx->rewriter, node)); + 
return shd_bld_finish(bb, fn_ret(a, (Return) { + .mem = shd_bb_mem(bb), + .args = shd_rewrite_nodes(r, payload.args), + })); } - default: return recreate_node_identity(&ctx->rewriter, node); + default: break; } + return shd_recreate_node(r, node); } -Module* setup_stack_frames(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); +Module* shd_pass_setup_stack_frames(SHADY_UNUSED const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config, }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/simt2d.c b/src/shady/passes/simt2d.c deleted file mode 100644 index bad9e7197..000000000 --- a/src/shady/passes/simt2d.c +++ /dev/null @@ -1,105 +0,0 @@ -#include "passes.h" - -#include "../type.h" -#include "../rewrite.h" - -#include "portability.h" -#include "log.h" - -typedef struct { - Rewriter rewriter; - size_t width; - const Node* mask; -} Context; - -static const Node* widen(Context* ctx, const Node* value) { - IrArena* a = ctx->rewriter.dst_arena; - LARRAY(const Node*, copies, ctx->width); - for (size_t j = 0; j < ctx->width; j++) - copies[j] = value; - const Type* type = pack_type(a, (PackType) { .width = ctx->width, .element_type = get_unqualified_type(value->type)}); - return composite_helper(a, type, nodes(a, ctx->width, copies)); -} - -static const Node* process(Context* ctx, const Node* node) { - if (!node) return NULL; - const Node* found = 
search_processed(&ctx->rewriter, node); - if (found) return found; - - IrArena* a = ctx->rewriter.dst_arena; - switch (node->tag) { - case QualifiedType_TAG: { - if (!node->payload.qualified_type.is_uniform) return qualified_type(a, (QualifiedType) { - .is_uniform = true, - .type = pack_type(a, (PackType) { .width = ctx->width, .element_type = rewrite_node(&ctx->rewriter, node->payload.qualified_type.type )}) - }); - goto rewrite; - } - case PrimOp_TAG: { - Op op = node->payload.prim_op.op; - switch (op) { - case quote_op: goto rewrite; - case alloca_logical_op: { - BodyBuilder* bb = begin_body(a); - const Node* type = rewrite_node(&ctx->rewriter, first(node->payload.prim_op.type_arguments)); - LARRAY(const Node*, allocated, ctx->width); - for (size_t i = 0; i < ctx->width; i++) { - allocated[i] = first(bind_instruction_named(bb, prim_op(a, (PrimOp) { - .op = op, - .type_arguments = singleton(type), - //.type_arguments = singleton(maybe_packed_type_helper(type, ctx->width)), - .operands = empty(a) - }), (String[]) {"allocated"})); - } - //return yield_values_and_wrap_in_control(bb, singleton(widen(ctx, allocated))); - const Node* result_type = maybe_packed_type_helper(ptr_type(a, (PtrType) { .address_space = AsFunctionLogical, .pointed_type = type }), ctx->width); - const Node* packed = composite_helper(a, result_type, nodes(a, ctx->width, allocated)); - return yield_values_and_wrap_in_block(bb, singleton(packed)); - } - default: break; - } - - bool was_uniform = true; - Nodes old_operands = node->payload.prim_op.operands; - for (size_t i = 0; i < old_operands.count; i++) - was_uniform &= is_qualified_type_uniform(old_operands.nodes[i]->type); - Nodes new_type_arguments = rewrite_nodes(&ctx->rewriter, node->payload.prim_op.type_arguments); - - LARRAY(const Node*, new_operands, old_operands.count); - // Nodes new_operands = rewrite_nodes(&ctx->rewriter, node->payload.prim_op.operands); - for (size_t i = 0; i < old_operands.count; i++) { - const Node* old_operand = 
old_operands.nodes[i]; - const Type* old_operand_type = old_operand->type; - bool op_was_uniform = deconstruct_qualified_type(&old_operand_type); - // assert(was_uniform || !op_was_uniform && "result was uniform implies=> operand was uniform"); - new_operands[i] = rewrite_node(&ctx->rewriter, old_operand); - if (op_was_uniform) - new_operands[i] = widen(ctx, new_operands[i]); - } - return prim_op(a, (PrimOp) { - .op = op, - .type_arguments = new_type_arguments, - .operands = nodes(a, old_operands.count, new_operands) - }); - } - rewrite: - default: return recreate_node_identity(&ctx->rewriter, node); - } -} - -Module* simt2d(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - aconfig.is_simt = false; - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - .width = config->specialization.subgroup_size, - .mask = NULL, - }; - - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - return dst; -} diff --git a/src/shady/passes/specialize_entry_point.c b/src/shady/passes/specialize_entry_point.c index eabe889b6..ebf1ee11f 100644 --- a/src/shady/passes/specialize_entry_point.c +++ b/src/shady/passes/specialize_entry_point.c @@ -1,45 +1,39 @@ -#include "passes.h" +#include "shady/pass.h" +#include "shady/ir/builtin.h" + +#include "../ir_private.h" #include "portability.h" #include "log.h" -#include "../ir_private.h" -#include "../rewrite.h" -#include "../transform/ir_gen_helpers.h" - #include typedef struct { Rewriter rewriter; const Node* old_entry_point_decl; - const Node* old_wg_size_annotation; + const CompilerConfig* config; } Context; - static const Node* process(Context* ctx, const Node* node) { - if (!node) return NULL; - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - IrArena* a = ctx->rewriter.dst_arena; switch 
(node->tag) { case PrimOp_TAG: { Builtin b; - if (is_builtin_load_op(node, &b) && b == BuiltinWorkgroupSize) { - const Type* t = pack_type(a, (PackType) { .element_type = uint32_type(a), .width = 3 }); + if (shd_is_builtin_load_op(node, &b) && b == BuiltinWorkgroupSize) { + const Type* t = pack_type(a, (PackType) { .element_type = shd_uint32_type(a), .width = 3 }); uint32_t wg_size[3]; wg_size[0] = a->config.specializations.workgroup_size[0]; wg_size[1] = a->config.specializations.workgroup_size[1]; wg_size[2] = a->config.specializations.workgroup_size[2]; - return quote_helper(a, singleton(composite_helper(a, t, mk_nodes(a, uint32_literal(a, wg_size[0]), uint32_literal(a, wg_size[1]), uint32_literal(a, wg_size[2]) )))); + return composite_helper(a, t, mk_nodes(a, shd_uint32_literal(a, wg_size[0]), shd_uint32_literal(a, wg_size[1]), shd_uint32_literal(a, wg_size[2]) )); } break; } case GlobalVariable_TAG: { - const Node* ba = lookup_annotation(node, "Builtin"); + const Node* ba = shd_lookup_annotation(node, "Builtin"); if (ba) { - Builtin b = get_builtin_by_name(get_annotation_string_payload(ba)); + Builtin b = shd_get_builtin_by_name(shd_get_annotation_string_payload(ba)); switch (b) { case BuiltinWorkgroupSize: return NULL; @@ -50,37 +44,33 @@ static const Node* process(Context* ctx, const Node* node) { break; } case Constant_TAG: { - Node* ncnst = (Node*) recreate_node_identity(&ctx->rewriter, node); - if (strcmp(get_decl_name(ncnst), "SUBGROUP_SIZE") == 0) { - ncnst->payload.constant.instruction = quote_helper(a, singleton(uint32_literal(a, a->config.specializations.subgroup_size))); - } else if (strcmp(get_decl_name(ncnst), "SUBGROUPS_PER_WG") == 0) { - if (ctx->old_wg_size_annotation) { - // SUBGROUPS_PER_WG = (NUMBER OF INVOCATIONS IN SUBGROUP / SUBGROUP SIZE) - // Note: this computations assumes only full subgroups are launched, if subgroups can launch partially filled then this relationship does not hold. 
- uint32_t wg_size[3]; - wg_size[0] = a->config.specializations.workgroup_size[0]; - wg_size[1] = a->config.specializations.workgroup_size[1]; - wg_size[2] = a->config.specializations.workgroup_size[2]; - uint32_t subgroups_per_wg = (wg_size[0] * wg_size[1] * wg_size[2]) / a->config.specializations.subgroup_size; - if (subgroups_per_wg == 0) - subgroups_per_wg = 1; // uh-oh - ncnst->payload.constant.instruction = quote_helper(a, singleton(uint32_literal(a, subgroups_per_wg))); - } + Node* ncnst = (Node*) shd_recreate_node(&ctx->rewriter, node); + if (strcmp(get_declaration_name(ncnst), "SUBGROUPS_PER_WG") == 0) { + // SUBGROUPS_PER_WG = (NUMBER OF INVOCATIONS IN SUBGROUP / SUBGROUP SIZE) + // Note: this computations assumes only full subgroups are launched, if subgroups can launch partially filled then this relationship does not hold. + uint32_t wg_size[3]; + wg_size[0] = a->config.specializations.workgroup_size[0]; + wg_size[1] = a->config.specializations.workgroup_size[1]; + wg_size[2] = a->config.specializations.workgroup_size[2]; + uint32_t subgroups_per_wg = (wg_size[0] * wg_size[1] * wg_size[2]) / ctx->config->specialization.subgroup_size; + if (subgroups_per_wg == 0) + subgroups_per_wg = 1; // uh-oh + ncnst->payload.constant.value = shd_uint32_literal(a, subgroups_per_wg); } return ncnst; } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } static const Node* find_entry_point(Module* m, const CompilerConfig* config) { if (!config->specialization.entry_point) return NULL; const Node* found = NULL; - Nodes old_decls = get_module_declarations(m); + Nodes old_decls = shd_module_get_declarations(m); for (size_t i = 0; i < old_decls.count; i++) { - if (strcmp(get_decl_name(old_decls.nodes[i]), config->specialization.entry_point) == 0) { + if (strcmp(get_declaration_name(old_decls.nodes[i]), config->specialization.entry_point) == 0) { assert(!found); found = old_decls.nodes[i]; } @@ -90,45 
+80,51 @@ static const Node* find_entry_point(Module* m, const CompilerConfig* config) { } static void specialize_arena_config(const CompilerConfig* config, Module* src, ArenaConfig* target) { - size_t subgroup_size = config->specialization.subgroup_size; - assert(subgroup_size); - target->specializations.subgroup_size = subgroup_size; - const Node* old_entry_point_decl = find_entry_point(src, config); + if (!old_entry_point_decl) + shd_error("Entry point not found") if (old_entry_point_decl->tag != Function_TAG) - error("%s is not a function", config->specialization.entry_point); - const Node* ep = lookup_annotation(old_entry_point_decl, "EntryPoint"); + shd_error("%s is not a function", config->specialization.entry_point); + const Node* ep = shd_lookup_annotation(old_entry_point_decl, "EntryPoint"); if (!ep) - error("%s is not annotated with @EntryPoint", config->specialization.entry_point); - switch (execution_model_from_string(get_annotation_string_payload(ep))) { - case EmNone: error("Unknown entry point type: %s", get_annotation_string_payload(ep)) + shd_error("%s is not annotated with @EntryPoint", config->specialization.entry_point); + switch (shd_execution_model_from_string(shd_get_annotation_string_payload(ep))) { + case EmNone: shd_error("Unknown entry point type: %s", shd_get_annotation_string_payload(ep)) case EmCompute: { - const Node* old_wg_size_annotation = lookup_annotation(old_entry_point_decl, "WorkgroupSize"); - assert(old_wg_size_annotation && old_wg_size_annotation->tag == AnnotationValues_TAG && get_annotation_values(old_wg_size_annotation).count == 3); - Nodes wg_size_nodes = get_annotation_values(old_wg_size_annotation); - target->specializations.workgroup_size[0] = get_int_literal_value(*resolve_to_int_literal(wg_size_nodes.nodes[0]), false); - target->specializations.workgroup_size[1] = get_int_literal_value(*resolve_to_int_literal(wg_size_nodes.nodes[1]), false); - target->specializations.workgroup_size[2] = 
get_int_literal_value(*resolve_to_int_literal(wg_size_nodes.nodes[2]), false); - assert(target->specializations.workgroup_size[0] * target->specializations.workgroup_size[1] * target->specializations.workgroup_size[2]); + const Node* old_wg_size_annotation = shd_lookup_annotation(old_entry_point_decl, "WorkgroupSize"); + assert(old_wg_size_annotation && old_wg_size_annotation->tag == AnnotationValues_TAG && shd_get_annotation_values(old_wg_size_annotation).count == 3); + Nodes wg_size_nodes = shd_get_annotation_values(old_wg_size_annotation); + target->specializations.workgroup_size[0] = shd_get_int_literal_value(*shd_resolve_to_int_literal(wg_size_nodes.nodes[0]), false); + target->specializations.workgroup_size[1] = shd_get_int_literal_value(*shd_resolve_to_int_literal(wg_size_nodes.nodes[1]), false); + target->specializations.workgroup_size[2] = shd_get_int_literal_value(*shd_resolve_to_int_literal(wg_size_nodes.nodes[2]), false); + assert(target->specializations.workgroup_size[0] * target->specializations.workgroup_size[1] * target->specializations.workgroup_size[2] > 0); break; } default: break; } } -Module* specialize_entry_point(const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); +Module* shd_pass_specialize_entry_point(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); specialize_arena_config(config, src, &aconfig); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), + .config = config, }; const Node* old_entry_point_decl = find_entry_point(src, config); - rewrite_node(&ctx.rewriter, old_entry_point_decl); + shd_rewrite_node(&ctx.rewriter, 
old_entry_point_decl); + + Nodes old_decls = shd_module_get_declarations(src); + for (size_t i = 0; i < old_decls.count; i++) { + const Node* old_decl = old_decls.nodes[i]; + if (shd_lookup_annotation(old_decl, "RetainAfterSpecialization")) + shd_rewrite_node(&ctx.rewriter, old_decl); + } - destroy_rewriter(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/specialize_execution_model.c b/src/shady/passes/specialize_execution_model.c index 416e9ef2e..55089d0f2 100644 --- a/src/shady/passes/specialize_execution_model.c +++ b/src/shady/passes/specialize_execution_model.c @@ -1,12 +1,6 @@ -#include "passes.h" +#include "shady/pass.h" #include "portability.h" -#include "log.h" - -#include "../ir_private.h" -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" #include @@ -16,40 +10,46 @@ typedef struct { } Context; static const Node* process(Context* ctx, const Node* node) { - if (!node) return NULL; - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - IrArena* a = ctx->rewriter.dst_arena; switch (node->tag) { + case Constant_TAG: { + Node* ncnst = (Node*) shd_recreate_node(&ctx->rewriter, node); + if (strcmp(get_declaration_name(ncnst), "SUBGROUP_SIZE") == 0) { + ncnst->payload.constant.value = shd_uint32_literal(a, ctx->config->specialization.subgroup_size); + } + return ncnst; + } default: break; } - return recreate_node_identity(&ctx->rewriter, node); + return shd_recreate_node(&ctx->rewriter, node); } static void specialize_arena_config(const CompilerConfig* config, Module* m, ArenaConfig* target) { switch (config->specialization.execution_model) { case EmVertex: case EmFragment: { - target->allow_subgroup_memory = false; - target->allow_shared_memory = false; + target->address_spaces[AsShared].allowed = false; + target->address_spaces[AsSubgroup].allowed = false; } default: break; } } -Module* specialize_execution_model(const CompilerConfig* 
config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); +Module* shd_pass_specialize_execution_model(const CompilerConfig* config, Module* src) { + ArenaConfig aconfig = *shd_get_arena_config(shd_module_get_arena(src)); specialize_arena_config(config, src, &aconfig); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); + IrArena* a = shd_new_ir_arena(&aconfig); + Module* dst = shd_new_module(a, shd_module_get_name(src)); + + size_t subgroup_size = config->specialization.subgroup_size; + assert(subgroup_size); Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), + .rewriter = shd_create_node_rewriter(src, dst, (RewriteNodeFn) process), .config = config, }; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); + shd_rewrite_module(&ctx.rewriter); + shd_destroy_rewriter(&ctx.rewriter); return dst; } diff --git a/src/shady/passes/spirv_lift_globals_ssbo.c b/src/shady/passes/spirv_lift_globals_ssbo.c deleted file mode 100644 index 485b5e7b9..000000000 --- a/src/shady/passes/spirv_lift_globals_ssbo.c +++ /dev/null @@ -1,108 +0,0 @@ -#include "passes.h" - -#include "portability.h" -#include "log.h" - -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" -#include "../transform/memory_layout.h" - -typedef struct { - Rewriter rewriter; - const CompilerConfig* config; - BodyBuilder* bb; - Node* lifted_globals_decl; -} Context; - -static const Node* process(Context* ctx, const Node* node) { - IrArena* a = ctx->rewriter.dst_arena; - - BodyBuilder* abs_bb = NULL; - Context c = *ctx; - ctx = &c; - if (is_abstraction(node)) { - c.bb = abs_bb = begin_body(a); - } - - switch (node->tag) { - case RefDecl_TAG: { - const Node* odecl = node->payload.ref_decl.decl; - if (odecl->tag != GlobalVariable_TAG || odecl->payload.global_variable.address_space != AsGlobalPhysical) - break; - assert(ctx->bb && "this RefDecl node isn't appearing in an 
abstraction - we cannot replace it with a load!"); - const Node* ptr_addr = gen_lea(ctx->bb, ref_decl_helper(a, ctx->lifted_globals_decl), int32_literal(a, 0), singleton(rewrite_node(&ctx->rewriter, odecl))); - const Node* ptr = gen_load(ctx->bb, ptr_addr); - return ptr; - } - case GlobalVariable_TAG: - if (node->payload.global_variable.address_space != AsGlobalPhysical) - break; - assert(false); - default: break; - } - - Node* new = (Node*) recreate_node_identity(&ctx->rewriter, node); - if (abs_bb) { - assert(is_abstraction(new)); - if (get_abstraction_body(new)) - set_abstraction_body(new, finish_body(abs_bb, get_abstraction_body(new))); - else - cancel_body(abs_bb); - } - return new; -} - -Module* spirv_lift_globals_ssbo(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - .config = config - }; - - Nodes old_decls = get_module_declarations(src); - LARRAY(const Type*, member_tys, old_decls.count); - LARRAY(String, member_names, old_decls.count); - - Nodes annotations = mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" })); - annotations = empty(a); - - annotations = append_nodes(a, annotations, annotation_value(a, (AnnotationValue) { .name = "DescriptorSet", .value = int32_literal(a, 0) })); - annotations = append_nodes(a, annotations, annotation_value(a, (AnnotationValue) { .name = "DescriptorBinding", .value = int32_literal(a, 0) })); - annotations = append_nodes(a, annotations, annotation(a, (Annotation) { .name = "Constants" })); - - size_t lifted_globals_count = 0; - for (size_t i = 0; i < old_decls.count; i++) { - const Node* odecl = old_decls.nodes[i]; - if (odecl->tag != GlobalVariable_TAG || odecl->payload.global_variable.address_space != AsGlobalPhysical) - continue; - - 
member_tys[lifted_globals_count] = rewrite_node(&ctx.rewriter, odecl->type); - member_names[lifted_globals_count] = get_decl_name(odecl); - - if (odecl->payload.global_variable.init) - annotations = append_nodes(a, annotations, annotation_values(a, (AnnotationValues) { - .name = "InitialValue", - .values = mk_nodes(a, int32_literal(a, lifted_globals_count), rewrite_node(&ctx.rewriter, odecl->payload.global_variable.init)) - })); - - register_processed(&ctx.rewriter, odecl, int32_literal(a, lifted_globals_count)); - lifted_globals_count++; - } - - if (lifted_globals_count > 0) { - const Type* lifted_globals_struct_t = record_type(a, (RecordType) { - .members = nodes(a, lifted_globals_count, member_tys), - .names = strings(a, lifted_globals_count, member_names), - .special = DecorateBlock - }); - ctx.lifted_globals_decl = global_var(dst, annotations, lifted_globals_struct_t, "lifted_globals", AsShaderStorageBufferObject); - } - - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - return dst; -} diff --git a/src/shady/passes/spirv_map_entrypoint_args.c b/src/shady/passes/spirv_map_entrypoint_args.c deleted file mode 100644 index 8bb23aa64..000000000 --- a/src/shady/passes/spirv_map_entrypoint_args.c +++ /dev/null @@ -1,77 +0,0 @@ -#include "passes.h" - -#include "portability.h" -#include "log.h" - -#include "../rewrite.h" -#include "../type.h" -#include "../transform/ir_gen_helpers.h" -#include "../transform/memory_layout.h" - -typedef struct { - Rewriter rewriter; - const CompilerConfig* config; -} Context; - -static const Node* rewrite_args_type(Rewriter* rewriter, const Node* old_type) { - IrArena* a = rewriter->dst_arena; - - if (old_type->tag != RecordType_TAG || old_type->payload.record_type.special != NotSpecial) - error("EntryPointArgs type must be a plain record type"); - - const Node* new_type = record_type(a, (RecordType) { - .members = rewrite_nodes(rewriter, old_type->payload.record_type.members), - .names = 
old_type->payload.record_type.names, - .special = DecorateBlock - }); - - register_processed(rewriter, old_type, new_type); - - return new_type; -} - -static const Node* process(Context* ctx, const Node* node) { - if (!node) return NULL; - const Node* found = search_processed(&ctx->rewriter, node); - if (found) return found; - - switch (node->tag) { - case GlobalVariable_TAG: - if (lookup_annotation(node, "EntryPointArgs")) { - if (node->payload.global_variable.address_space != AsExternal) - error("EntryPointArgs address space must be extern"); - - Nodes annotations = rewrite_nodes(&ctx->rewriter, node->payload.global_variable.annotations); - const Node* type = rewrite_args_type(&ctx->rewriter, node->payload.global_variable.type); - - const Node* new_var = global_var(ctx->rewriter.dst_module, - annotations, - type, - node->payload.global_variable.name, - AsPushConstant - ); - - register_processed(&ctx->rewriter, node, new_var); - - return new_var; - } - break; - default: break; - } - - return recreate_node_identity(&ctx->rewriter, node); -} - -Module* spirv_map_entrypoint_args(SHADY_UNUSED const CompilerConfig* config, Module* src) { - ArenaConfig aconfig = get_arena_config(get_module_arena(src)); - IrArena* a = new_ir_arena(aconfig); - Module* dst = new_module(a, get_module_name(src)); - Context ctx = { - .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process), - .config = config - }; - ctx.rewriter.config.rebind_let = true; - rewrite_module(&ctx.rewriter); - destroy_rewriter(&ctx.rewriter); - return dst; -} diff --git a/src/shady/primops.c b/src/shady/primops.c index 46b96f14c..a8a860317 100644 --- a/src/shady/primops.c +++ b/src/shady/primops.c @@ -6,10 +6,10 @@ #include "primops_generated.c" -String get_primop_name(Op op) { +String shd_get_primop_name(Op op) { return primop_names[op]; } -bool has_primop_got_side_effects(Op op) { +bool shd_has_primop_got_side_effects(Op op) { return primop_side_effects[op]; } diff --git a/src/shady/print.c 
b/src/shady/print.c index 3b9a8566b..c1eb77089 100644 --- a/src/shady/print.c +++ b/src/shady/print.c @@ -1,5 +1,10 @@ +#include "shady/print.h" + +#include "shady/visit.h" + #include "ir_private.h" -#include "analysis/scope.h" +#include "analysis/cfg.h" +#include "analysis/scheduler.h" #include "analysis/uses.h" #include "analysis/leak.h" @@ -9,33 +14,130 @@ #include "growy.h" #include "printer.h" -#include "type.h" - #include #include #include typedef struct PrinterCtx_ PrinterCtx; -typedef void (*PrintFn)(PrinterCtx* ctx, char* format, ...); - -typedef struct { - bool skip_builtin; - bool skip_internal; - bool skip_generated; - bool print_ptrs; - bool color; - bool reparseable; -} PrintConfig; struct PrinterCtx_ { Printer* printer; + NodePrintConfig config; const Node* fn; - Scope* scope; - const UsesMap* scope_uses; - long int min_rpo; - PrintConfig config; + CFG* cfg; + Scheduler* scheduler; + const UsesMap* uses; + + Growy* root_growy; + Printer* root_printer; + + Growy** bb_growies; + Printer** bb_printers; + struct Dict* emitted; }; +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); + +static bool print_node_impl(PrinterCtx* ctx, const Node* node); +static void print_terminator(PrinterCtx* ctx, const Node* node); +static void print_mod_impl(PrinterCtx* ctx, Module* mod); + +static String emit_node(PrinterCtx* ctx, const Node* node); +static void print_mem(PrinterCtx* ctx, const Node* node); + +static PrinterCtx make_printer_ctx(Printer* printer, NodePrintConfig config) { + PrinterCtx ctx = { + .printer = printer, + .config = config, + .emitted = shd_new_dict(const Node*, String, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .root_growy = shd_new_growy(), + }; + ctx.root_printer = shd_new_printer_from_growy(ctx.root_growy); + return ctx; +} + +static void destroy_printer_ctx(PrinterCtx ctx) { + shd_destroy_dict(ctx.emitted); +} + +void shd_print_module(Printer* printer, NodePrintConfig config, Module* mod) { + 
PrinterCtx ctx = make_printer_ctx(printer, config); + print_mod_impl(&ctx, mod); + String s = shd_printer_growy_unwrap(ctx.root_printer); + shd_print(ctx.printer, "%s", s); + free((void*)s); + shd_printer_flush(ctx.printer); + destroy_printer_ctx(ctx); +} + +void shd_print_node(Printer* printer, NodePrintConfig config, const Node* node) { + PrinterCtx ctx = make_printer_ctx(printer, config); + String emitted = emit_node(&ctx, node); + String s = shd_printer_growy_unwrap(ctx.root_printer); + shd_print(ctx.printer, "%s%s", s, emitted); + free((void*)s); + shd_printer_flush(ctx.printer); + destroy_printer_ctx(ctx); +} + +void shd_print_node_into_str(const Node* node, char** str_ptr, size_t* size) { + Growy* g = shd_new_growy(); + Printer* p = shd_new_printer_from_growy(g); + if (node) + shd_print(p, "%%%d ", node->id); + shd_print_node(p, (NodePrintConfig) {.reparseable = true}, node); + shd_destroy_printer(p); + *size = shd_growy_size(g); + *str_ptr = shd_growy_deconstruct(g); +} + +void shd_print_module_into_str(Module* mod, char** str_ptr, size_t* size) { + Growy* g = shd_new_growy(); + Printer* p = shd_new_printer_from_growy(g); + shd_print_module(p, (NodePrintConfig) {.reparseable = true,}, mod); + shd_destroy_printer(p); + *size = shd_growy_size(g); + *str_ptr = shd_growy_deconstruct(g); +} + +void shd_dump_node(const Node* node) { + Printer* p = shd_new_printer_from_file(stdout); + if (node) + shd_print(p, "%%%d ", node->id); + shd_print_node(p, (NodePrintConfig) {.color = true}, node); + printf("\n"); +} + +void shd_dump_module(Module* mod) { + Printer* p = shd_new_printer_from_file(stdout); + shd_print_module(p, (NodePrintConfig) {.color = true}, mod); + shd_destroy_printer(p); + printf("\n"); +} + +void shd_log_node(LogLevel level, const Node* node) { + if (level <= shd_log_get_level()) { + Printer* p = shd_new_printer_from_file(stderr); + shd_print_node(p, (NodePrintConfig) {.color = true}, node); + shd_destroy_printer(p); + } +} + +void 
shd_log_module(LogLevel level, const CompilerConfig* compiler_cfg, Module* mod) { + NodePrintConfig config = { .color = true }; + if (compiler_cfg) { + config.print_generated = compiler_cfg->logging.print_generated; + config.print_builtin = compiler_cfg->logging.print_builtin; + config.print_internal = compiler_cfg->logging.print_internal; + } + if (level <= shd_log_get_level()) { + Printer* p = shd_new_printer_from_file(stderr); + shd_print_module(p, config, mod); + shd_destroy_printer(p); + } +} + #define COLOR(x) (ctx->config.color ? (x) : "") #define RESET COLOR("\033[0m") @@ -43,7 +145,7 @@ struct PrinterCtx_ { #define GREEN COLOR("\033[0;32m") #define YELLOW COLOR("\033[0;33m") #define BLUE COLOR("\033[0;34m") -#define MANGENTA COLOR("\033[0;35m") +#define MAGENTA COLOR("\033[0;35m") #define CYAN COLOR("\033[0;36m") #define WHITE COLOR("\033[0;37m") @@ -56,10 +158,16 @@ struct PrinterCtx_ { #define BCYAN COLOR("\033[0;96m") #define BWHITE COLOR("\033[0;97m") -#define printf(...) print(ctx->printer, __VA_ARGS__) -#define print_node(n) print_node_impl(ctx, n) +#define printf(...) 
shd_print(ctx->printer, __VA_ARGS__) +#define print_node(n) print_operand_helper(ctx, 0, n) +#define print_operand(nc, n) print_operand_helper(ctx, nc, n) + +static void print_operand_helper(PrinterCtx* ctx, NodeClass nc, const Node* op); -static void print_node_impl(PrinterCtx* ctx, const Node* node); +void _shd_print_node_operand(PrinterCtx* ctx, const Node* node, String op_name, NodeClass op_class, const Node* op); +void _shd_print_node_operand_list(PrinterCtx* ctx, const Node* node, String op_name, NodeClass op_class, Nodes ops); + +void _shd_print_node_generated(PrinterCtx* ctx, const Node* node); #pragma GCC diagnostic error "-Wswitch" @@ -70,7 +178,7 @@ static void print_param_list(PrinterCtx* ctx, Nodes params, const Nodes* default for (size_t i = 0; i < params.count; i++) { const Node* param = params.nodes[i]; if (ctx->config.print_ptrs) printf("%zu::", (size_t)(void*) param); - print_node(param->payload.var.type); + print_node(param->payload.param.type); printf(" "); print_node(param); printf(RESET); @@ -129,114 +237,119 @@ static void print_yield_types(PrinterCtx* ctx, Nodes types) { } } -static void print_abs_body(PrinterCtx* ctx, const Node* block); +static String emit_abs_body(PrinterCtx* ctx, const Node* abs); -static void print_basic_block(PrinterCtx* ctx, const Node* bb) { +static void print_basic_block(PrinterCtx* ctx, const CFNode* node) { + const Node* bb = node->node; printf(GREEN); - printf("\n\ncont"); + printf("\ncont"); printf(BYELLOW); - printf(" %s", bb->payload.basic_block.name); + if (bb->payload.basic_block.name && strlen(bb->payload.basic_block.name) > 0) + printf(" %s", bb->payload.basic_block.name); + else + printf(" %%%d", bb->id); printf(RESET); if (ctx->config.print_ptrs) { printf(" %zu:: ", (size_t)(void*)bb); - printf(" fn=%zu:: ", (size_t)(void*)bb->payload.basic_block.fn); } print_param_list(ctx, bb->payload.basic_block.params, NULL); printf(" {"); - indent(ctx->printer); + shd_printer_indent(ctx->printer); printf("\n"); - 
print_abs_body(ctx, bb); - deindent(ctx->printer); + printf("%s", emit_abs_body(ctx, bb)); + shd_printer_deindent(ctx->printer); printf("\n}"); } -static void print_dominated_bbs(PrinterCtx* ctx, const CFNode* dominator) { - assert(dominator); - for (size_t i = 0; i < dominator->dominates->elements_count; i++) { - const CFNode* cfnode = read_list(const CFNode*, dominator->dominates)[i]; - // ignore cases that make up basic structural dominance - if (find_key_dict(const Node*, dominator->structurally_dominates, cfnode->node)) - continue; - assert(is_basic_block(cfnode->node)); - print_basic_block(ctx, cfnode->node); - } -} - -static void print_abs_body(PrinterCtx* ctx, const Node* block) { - assert(!ctx->fn || is_function(ctx->fn)); - assert(is_abstraction(block)); +static String emit_abs_body(PrinterCtx* ctx, const Node* abs) { + Growy* g = shd_new_growy(); + Printer* p = shd_new_printer_from_growy(g); + CFNode* cfnode = ctx->cfg ? shd_cfg_lookup(ctx->cfg, abs) : NULL; + if (cfnode) + ctx->bb_printers[cfnode->rpo_index] = p; - print_node(get_abstraction_body(block)); + emit_node(ctx, get_abstraction_body(abs)); - // TODO: it's likely cleaner to instead print things according to the dominator tree in the first place. 
- if (ctx->scope != NULL) { - const CFNode* dominator = scope_lookup(ctx->scope, block); - if (ctx->min_rpo < ((long int) dominator->rpo_index)) { - size_t save_rpo = ctx->min_rpo; - ctx->min_rpo = dominator->rpo_index; - print_dominated_bbs(ctx, dominator); - ctx->min_rpo = save_rpo; + if (cfnode) { + Growy* g2 = shd_new_growy(); + Printer* p2 = shd_new_printer_from_growy(g2); + size_t count = cfnode->dominates->elements_count; + for (size_t i = 0; i < count; i++) { + const CFNode* dominated = shd_read_list(const CFNode*, cfnode->dominates)[i]; + assert(is_basic_block(dominated->node)); + PrinterCtx bb_ctx = *ctx; + bb_ctx.printer = p2; + print_basic_block(&bb_ctx, dominated); + if (i + 1 < count) + shd_newline(bb_ctx.printer); } + + String bbs = shd_printer_growy_unwrap(p2); + shd_print(p, "%s", bbs); + free((void*) bbs); } -} -static void print_case_body(PrinterCtx* ctx, const Node* case_) { - assert(is_case(case_)); - printf(" {"); - indent(ctx->printer); - printf("\n"); - print_abs_body(ctx, case_); - deindent(ctx->printer); - printf("\n}"); + String s = shd_printer_growy_unwrap(p); + String s2 = shd_string(ctx->fn->arena, s); + if (cfnode) + ctx->bb_printers[cfnode->rpo_index] = NULL; + free((void*) s); + return s2; } static void print_function(PrinterCtx* ctx, const Node* node) { assert(is_function(node)); PrinterCtx sub_ctx = *ctx; - if (node->arena->config.check_op_classes) { - Scope* scope = new_scope(node); - sub_ctx.scope = scope; - sub_ctx.fn = node; + sub_ctx.fn = node; + if (node->arena->config.name_bound) { + CFGBuildConfig cfg_config = structured_scope_cfg_build(); + CFG* cfg = shd_new_cfg(node, node, cfg_config); + sub_ctx.cfg = cfg; + sub_ctx.scheduler = shd_new_scheduler(cfg); + sub_ctx.bb_growies = calloc(sizeof(size_t), cfg->size); + sub_ctx.bb_printers = calloc(sizeof(size_t), cfg->size); if (node->arena->config.check_types && node->arena->config.allow_fold) { - sub_ctx.scope_uses = create_uses_map(node, (NcDeclaration | NcType)); + 
sub_ctx.uses = shd_new_uses_map_fn(node, (NcDeclaration | NcType)); } } ctx = &sub_ctx; - ctx->min_rpo = -1; print_yield_types(ctx, node->payload.fun.return_types); print_param_list(ctx, node->payload.fun.params, NULL); - if (!node->payload.fun.body) { + if (!get_abstraction_body(node)) { printf(";"); - return; - } + } else { + printf(" {"); + shd_printer_indent(ctx->printer); + printf("\n"); - printf(" {"); - indent(ctx->printer); - printf("\n"); + printf("%s", emit_abs_body(ctx, node)); - print_abs_body(ctx, node); - - deindent(ctx->printer); - printf("\n}"); + shd_printer_deindent(ctx->printer); + printf("\n}"); + } - if (node->arena->config.check_op_classes) { - if (sub_ctx.scope_uses) - destroy_uses_map(sub_ctx.scope_uses); - destroy_scope(sub_ctx.scope); + if (sub_ctx.cfg) { + if (sub_ctx.uses) + shd_destroy_uses_map(sub_ctx.uses); + free(sub_ctx.bb_printers); + free(sub_ctx.bb_growies); + shd_destroy_cfg(sub_ctx.cfg); + shd_destroy_scheduler(sub_ctx.scheduler); } } -static void print_annotations(PrinterCtx* ctx, Nodes annotations) { - for (size_t i = 0; i < annotations.count; i++) { - print_node(annotations.nodes[i]); - printf(" "); +static void print_nodes(PrinterCtx* ctx, Nodes nodes) { + for (size_t i = 0; i < nodes.count; i++) { + print_node(nodes.nodes[i]); + if (i + 1 < nodes.count) + printf(" "); } } -static void print_type(PrinterCtx* ctx, const Node* node) { +static bool print_type(PrinterCtx* ctx, const Node* node) { printf(BCYAN); switch (is_type(node)) { case NotAType: assert(false); break; @@ -249,7 +362,7 @@ static void print_type(PrinterCtx* ctx, const Node* node) { case FloatTy16: printf("16"); break; case FloatTy32: printf("32"); break; case FloatTy64: printf("64"); break; - default: error("Not a known valid float width") + default: shd_error("Not a known valid float width") } break; case MaskType_TAG: printf("mask"); break; @@ -266,11 +379,18 @@ static void print_type(PrinterCtx* ctx, const Node* node) { case IntTy16: printf("16"); break; 
case IntTy32: printf("32"); break; case IntTy64: printf("64"); break; - default: error("Not a known valid int width") + default: shd_error("Not a known valid int width") } break; case RecordType_TAG: - printf("struct"); + if (node->payload.record_type.members.count == 0) { + printf("unit_t"); + break; + } else if (node->payload.record_type.special & DecorateBlock) { + printf("block"); + } else { + printf("struct"); + } printf(RESET); printf(" {"); const Nodes* members = &node->payload.record_type.members; @@ -306,11 +426,11 @@ static void print_type(PrinterCtx* ctx, const Node* node) { break; } case PtrType_TAG: { - printf("ptr"); + printf(node->payload.ptr_type.is_reference ? "ref" : "ptr"); printf(RESET); printf("("); printf(BLUE); - printf(get_address_space_name(node->payload.ptr_type.address_space)); + printf(shd_get_address_space_name(node->payload.ptr_type.address_space)); printf(RESET); printf(", "); print_node(node->payload.ptr_type.pointed_type); @@ -346,28 +466,31 @@ static void print_type(PrinterCtx* ctx, const Node* node) { printf("["); print_node(node->payload.image_type.sampled_type); printf(RESET); - printf(", %d, %d, %d, %d]", node->payload.image_type.dim, node->payload.image_type.depth, node->payload.image_type.onion, node->payload.image_type.multisample); + printf(", %d, %d, %d, %d]", node->payload.image_type.dim, node->payload.image_type.depth, node->payload.image_type.arrayed, node->payload.image_type.ms); break; } case Type_SamplerType_TAG: { printf("sampler_type"); break; } - case Type_CombinedImageSamplerType_TAG: { + case Type_SampledImageType_TAG: { printf("sampled"); printf(RESET); printf("["); - print_node(node->payload.combined_image_sampler_type.image_type); + print_node(node->payload.sampled_image_type.image_type); printf(RESET); printf("]"); break; } case TypeDeclRef_TAG: { - printf("%s", get_decl_name(node->payload.type_decl_ref.decl)); + printf("%s", get_declaration_name(node->payload.type_decl_ref.decl)); break; } + 
default:_shd_print_node_generated(ctx, node); + break; } printf(RESET); + return true; } static void print_string_lit(PrinterCtx* ctx, const char* string) { @@ -391,7 +514,7 @@ static void print_string_lit(PrinterCtx* ctx, const char* string) { printf("\""); } -static void print_value(PrinterCtx* ctx, const Node* node) { +static bool print_value(PrinterCtx* ctx, const Node* node) { switch (is_value(node)) { case NotAValue: assert(false); break; case ConstrainedValue_TAG: { @@ -400,21 +523,14 @@ static void print_value(PrinterCtx* ctx, const Node* node) { print_node(node->payload.constrained.value); break; } - case Variable_TAG: - if (ctx->scope_uses) { - // if ((*find_value_dict(const Node*, Uses*, ctx->uses->map, node))->escapes_defining_block) - // printf(MANGENTA); - // else - printf(YELLOW); - } else - printf(YELLOW); - String name = get_value_name(node); - if (name) - printf("%s_%d", name, node->payload.var.id); - else - printf("v%d", node->payload.var.id); + case Value_Param_TAG: + printf(YELLOW); + String name = shd_get_value_name_unsafe(node); + if (name && strlen(name) > 0) + printf("%s_", name); + printf("%%%d", node->id); printf(RESET); - break; + return true; case UntypedNumber_TAG: printf(BBLUE); printf("%s", node->payload.untyped_number.plaintext); @@ -422,16 +538,16 @@ static void print_value(PrinterCtx* ctx, const Node* node) { break; case IntLiteral_TAG: printf(BBLUE); - uint64_t v = get_int_literal_value(node->payload.int_literal, false); + uint64_t v = shd_get_int_literal_value(node->payload.int_literal, false); switch (node->payload.int_literal.width) { case IntTy8: printf("%" PRIu8, (uint8_t) v); break; case IntTy16: printf("%" PRIu16, (uint16_t) v); break; case IntTy32: printf("%" PRIu32, (uint32_t) v); break; case IntTy64: printf("%" PRIu64, v); break; - default: error("Not a known valid int width") + default: shd_error("Not a known valid int width") } printf(RESET); - break; + return true; case FloatLiteral_TAG: printf(BBLUE); switch 
(node->payload.float_literal.width) { @@ -447,25 +563,25 @@ static void print_value(PrinterCtx* ctx, const Node* node) { memcpy(&d, &node->payload.float_literal.value, sizeof(uint64_t)); printf("%.17g", d); break; } - default: error("Not a known valid float width") + default: shd_error("Not a known valid float width") } printf(RESET); - break; + return true; case True_TAG: printf(BBLUE); printf("true"); printf(RESET); - break; + return true; case False_TAG: printf(BBLUE); printf("false"); printf(RESET); - break; + return true; case StringLiteral_TAG: printf(BBLUE); print_string_lit(ctx, node->payload.string_lit.string); printf(RESET); - break; + return true; case Value_Undef_TAG: { const Type* type = node->payload.undef.type; printf(BBLUE); @@ -475,7 +591,7 @@ static void print_value(PrinterCtx* ctx, const Node* node) { print_node(type); printf(RESET); printf("]"); - break; + return true; } case Value_NullPtr_TAG: { const Type* type = node->payload.undef.type; @@ -486,7 +602,7 @@ static void print_value(PrinterCtx* ctx, const Node* node) { print_node(type); printf(RESET); printf("]"); - break; + return true; } case Value_Composite_TAG: { const Type* type = node->payload.composite.type; @@ -497,7 +613,7 @@ static void print_value(PrinterCtx* ctx, const Node* node) { print_node(type); printf("]"); print_args_list(ctx, node->payload.composite.contents); - break; + return true; } case Value_Fill_TAG: { const Type* type = node->payload.fill.type; @@ -511,49 +627,91 @@ static void print_value(PrinterCtx* ctx, const Node* node) { printf("("); print_node(node->payload.fill.value); printf(")"); - break; + return true; } case Value_RefDecl_TAG: { printf(BYELLOW); - printf((char*) get_decl_name(node->payload.ref_decl.decl)); + printf("%s", (char*) get_declaration_name(node->payload.ref_decl.decl)); printf(RESET); - break; + return true; } case FnAddr_TAG: printf(BYELLOW); - printf((char*) get_decl_name(node->payload.fn_addr.fn)); + printf("%s", (char*) 
get_declaration_name(node->payload.fn_addr.fn)); printf(RESET); + return true; + default:_shd_print_node_generated(ctx, node); break; } + return false; } static void print_instruction(PrinterCtx* ctx, const Node* node) { + //printf("%%%d = ", node->id); switch (is_instruction(node)) { case NotAnInstruction: assert(false); break; - case Instruction_Comment_TAG: { - printf(GREY); - printf("/* %s */", node->payload.comment.string); - printf(RESET); - break; - } case PrimOp_TAG: { - printf(GREEN); - printf("%s", get_primop_name(node->payload.prim_op.op)); - printf(RESET); - Nodes ty_args = node->payload.prim_op.type_arguments; - if (ty_args.count > 0) - print_ty_args_list(ctx, node->payload.prim_op.type_arguments); - print_args_list(ctx, node->payload.prim_op.operands); - break; - } case Call_TAG: { - printf(GREEN); - printf("call"); + // case Instruction_Comment_TAG: { + // printf(MAGENTA); + // printf("/* %s */", node->payload.comment.string); + // printf(RESET); + // break; + // } case PrimOp_TAG: { + // printf(GREEN); + // printf("%s", get_primop_name(node->payload.prim_op.op)); + // printf(RESET); + // Nodes ty_args = node->payload.prim_op.type_arguments; + // if (ty_args.count > 0) + // print_ty_args_list(ctx, node->payload.prim_op.type_arguments); + // print_args_list(ctx, node->payload.prim_op.operands); + // break; + // } case Call_TAG: { + // printf(GREEN); + // printf("call"); + // printf(RESET); + // printf(" ("); + // shd_print_node(node->payload.call.callee); + // printf(")"); + // print_args_list(ctx, node->payload.call.args); + // break; + // } + default: _shd_print_node_generated(ctx, node); + } + //printf("\n"); +} + +static void print_jump(PrinterCtx* ctx, const Node* node) { + assert(node->tag == Jump_TAG); + print_node(node->payload.jump.target); + print_args_list(ctx, node->payload.jump.args); +} + +static void print_structured_construct_results(PrinterCtx* ctx, const Node* tail_case) { + Nodes params = get_abstraction_params(tail_case); + if 
(params.count > 0) { + printf(GREEN); + printf("val"); + printf(RESET); + for (size_t i = 0; i < params.count; i++) { + // TODO: fix let mut + if (tail_case->arena->config.check_types) { + printf(" "); + print_node(params.nodes[i]->type); + } + printf(" "); + print_node(params.nodes[i]); printf(RESET); - printf(" ("); - print_node(node->payload.call.callee); - printf(")"); - print_args_list(ctx, node->payload.call.args); - break; - } case If_TAG: { + } + printf(" = "); + } +} + +static void print_terminator(PrinterCtx* ctx, const Node* node) { + TerminatorTag tag = is_terminator(node); + switch (tag) { + case NotATerminator: assert(false); + /* + case If_TAG: { + print_structured_construct_results(ctx, get_structured_construct_tail(node)); printf(GREEN); printf("if"); printf(RESET); @@ -561,6 +719,8 @@ static void print_instruction(PrinterCtx* ctx, const Node* node) { printf("("); print_node(node->payload.if_instr.condition); printf(") "); + if (ctx->config.in_cfg) + break; print_case_body(ctx, node->payload.if_instr.if_true); if (node->payload.if_instr.if_false) { printf(GREEN); @@ -568,18 +728,11 @@ static void print_instruction(PrinterCtx* ctx, const Node* node) { printf(RESET); print_case_body(ctx, node->payload.if_instr.if_false); } - break; - } case Loop_TAG: { - printf(GREEN); - printf("loop"); - printf(RESET); - print_yield_types(ctx, node->payload.loop_instr.yield_types); - const Node* body = node->payload.loop_instr.body; - assert(is_case(body)); - print_param_list(ctx, body->payload.case_.params, &node->payload.loop_instr.initial_args); - print_case_body(ctx, body); + printf("\n"); + print_abs_body(ctx, get_structured_construct_tail(node)); break; } case Match_TAG: { + print_structured_construct_results(ctx, get_structured_construct_tail(node)); printf(GREEN); printf("match"); printf(RESET); @@ -587,6 +740,8 @@ static void print_instruction(PrinterCtx* ctx, const Node* node) { printf("("); print_node(node->payload.match_instr.inspect); printf(")"); + if 
(ctx->config.in_cfg) + break; printf(" {"); indent(ctx->printer); for (size_t i = 0; i < node->payload.match_instr.literals.count; i++) { @@ -609,79 +764,39 @@ static void print_instruction(PrinterCtx* ctx, const Node* node) { deindent(ctx->printer); printf("\n}"); + printf("\n"); + print_abs_body(ctx, get_structured_construct_tail(node)); + break; + } case Loop_TAG: { + print_structured_construct_results(ctx, get_structured_construct_tail(node)); + printf(GREEN); + printf("loop"); + printf(RESET); + print_yield_types(ctx, node->payload.loop_instr.yield_types); + if (ctx->config.in_cfg) + break; + const Node* body = node->payload.loop_instr.body; + print_param_list(ctx, get_abstraction_params(body), &node->payload.loop_instr.initial_args); + print_case_body(ctx, body); + printf("\n"); + print_abs_body(ctx, get_structured_construct_tail(node)); break; } case Control_TAG: { + print_structured_construct_results(ctx, get_structured_construct_tail(node)); printf(BGREEN); - if (ctx->scope_uses) { - if (is_control_static(ctx->scope_uses, node)) + if (ctx->uses) { + if (is_control_static(ctx->uses, node)) printf("static "); } printf("control"); printf(RESET); print_yield_types(ctx, node->payload.control.yield_types); - print_param_list(ctx, node->payload.control.inside->payload.case_.params, NULL); + if (ctx->config.in_cfg) + break; + print_param_list(ctx, get_abstraction_params(node->payload.control.inside), NULL); print_case_body(ctx, node->payload.control.inside); - break; - } case Block_TAG: { - printf(BGREEN); - printf("block"); - printf(RESET); - print_case_body(ctx, node->payload.block.inside); - break; - } - } -} - -static void print_jump(PrinterCtx* ctx, const Node* node) { - assert(node->tag == Jump_TAG); - print_node(node->payload.jump.target); - print_args_list(ctx, node->payload.jump.args); -} - -static void print_terminator(PrinterCtx* ctx, const Node* node) { - TerminatorTag tag = is_terminator(node); - switch (tag) { - case NotATerminator: assert(false); - 
case Let_TAG: - case LetMut_TAG: { - const Node* instruction = get_let_instruction(node); - const Node* tail = get_let_tail(node); - if (!ctx->config.reparseable) { - // if the let tail is a case, we apply some syntactic sugar - if (tail->payload.case_.params.count > 0) { - printf(GREEN); - if (tag == LetMut_TAG) - printf("var"); - else - printf("val"); - printf(RESET); - Nodes params = tail->payload.case_.params; - for (size_t i = 0; i < params.count; i++) { - if (tag == LetMut_TAG || !ctx->config.reparseable) { - printf(" "); - print_node(params.nodes[i]->payload.var.type); - } - printf(" "); - print_node(params.nodes[i]); - printf(RESET); - } - printf(" = "); - } - print_node(instruction); - printf(";\n"); - print_abs_body(ctx, tail); - } else { - printf(GREEN); - printf("let"); - printf(RESET); - printf(" "); - print_node(instruction); - printf(GREEN); - printf(" in "); - printf(RESET); - print_node(tail); - printf(";"); - } + printf("\n"); + print_abs_body(ctx, get_structured_construct_tail(node)); break; } case Return_TAG: printf(BGREEN); @@ -711,7 +826,7 @@ static void print_terminator(PrinterCtx* ctx, const Node* node) { printf("branch "); printf(RESET); printf("("); - print_node(node->payload.branch.branch_condition); + print_node(node->payload.branch.condition); printf(", "); print_jump(ctx, node->payload.branch.true_jump); printf(", "); @@ -743,7 +858,7 @@ static void print_terminator(PrinterCtx* ctx, const Node* node) { printf("join"); printf(RESET); printf("("); - print_node(node->payload.join.join_point); + shd_print_node(node->payload.join.join_point); printf(")"); print_args_list(ctx, node->payload.join.args); printf(";"); @@ -756,36 +871,42 @@ static void print_terminator(PrinterCtx* ctx, const Node* node) { break; case MergeContinue_TAG: case MergeBreak_TAG: - case Terminator_Yield_TAG: + case Terminator_MergeSelection_TAG: printf(BGREEN); printf("%s", node_tags[node->tag]); printf(RESET); - print_args_list(ctx, node->payload.yield.args); + 
print_args_list(ctx, node->payload.merge_selection.args); printf(";"); - break; + break;*/ + default:_shd_print_node_generated(ctx, node); + return; } + emit_node(ctx, get_terminator_mem(node)); } static void print_decl(PrinterCtx* ctx, const Node* node) { assert(is_declaration(node)); - if (ctx->config.skip_generated && lookup_annotation(node, "Generated")) + if (!ctx->config.print_generated && shd_lookup_annotation(node, "Generated")) return; - if (ctx->config.skip_internal && lookup_annotation(node, "Internal")) + if (!ctx->config.print_internal && shd_lookup_annotation(node, "Internal")) return; - if (ctx->config.skip_builtin && lookup_annotation(node, "Builtin")) + if (!ctx->config.print_builtin && shd_lookup_annotation(node, "Builtin")) return; PrinterCtx sub_ctx = *ctx; - sub_ctx.scope = NULL; + sub_ctx.cfg = NULL; + sub_ctx.scheduler = NULL; ctx = &sub_ctx; switch (node->tag) { case GlobalVariable_TAG: { const GlobalVariable* gvar = &node->payload.global_variable; - print_annotations(ctx, gvar->annotations); + print_nodes(ctx, gvar->annotations); + printf("\n"); printf(BLUE); - printf(get_address_space_name(gvar->address_space)); - printf(RESET); + printf("var "); + printf(BLUE); + printf(shd_get_address_space_name(gvar->address_space)); printf(" "); print_node(gvar->type); printf(BYELLOW); @@ -800,7 +921,8 @@ static void print_decl(PrinterCtx* ctx, const Node* node) { } case Constant_TAG: { const Constant* cnst = &node->payload.constant; - print_annotations(ctx, cnst->annotations); + print_nodes(ctx, cnst->annotations); + printf("\n"); printf(BLUE); printf("const "); printf(RESET); @@ -808,17 +930,16 @@ static void print_decl(PrinterCtx* ctx, const Node* node) { printf(BYELLOW); printf(" %s", cnst->name); printf(RESET); - printf(" = "); - if (get_quoted_value(cnst->instruction)) - print_node(get_quoted_value(cnst->instruction)); - else - print_node(cnst->instruction); + if (cnst->value) { + printf(" = %s", emit_node(ctx, cnst->value)); + } printf(";\n"); 
break; } case Function_TAG: { const Function* fun = &node->payload.fun; - print_annotations(ctx, fun->annotations); + print_nodes(ctx, fun->annotations); + printf("\n"); printf(BLUE); printf("fn"); printf(RESET); @@ -831,7 +952,8 @@ static void print_decl(PrinterCtx* ctx, const Node* node) { } case NominalType_TAG: { const NominalType* nom = &node->payload.nom_type; - print_annotations(ctx, nom->annotations); + print_nodes(ctx, nom->annotations); + printf("\n"); printf(BLUE); printf("type"); printf(RESET); @@ -843,50 +965,12 @@ static void print_decl(PrinterCtx* ctx, const Node* node) { printf(";\n\n"); break; } - default: error("Not a decl"); + default: shd_error("Not a decl"); } } -static void print_node_impl(PrinterCtx* ctx, const Node* node) { - if (node == NULL) { - printf("?"); - return; - } - - if (ctx->config.print_ptrs) printf("%zu::", (size_t)(void*)node); - - if (is_type(node)) - print_type(ctx, node); - else if (is_value(node)) - print_value(ctx, node); - else if (is_instruction(node)) - print_instruction(ctx, node); - else if (is_terminator(node)) - print_terminator(ctx, node); - else if (node->tag == Case_TAG) { - printf(BYELLOW); - printf("case_ "); - printf(RESET); - print_param_list(ctx, node->payload.case_.params, NULL); - indent(ctx->printer); - printf(" {\n"); - print_abs_body(ctx, node); - // printf(";"); - deindent(ctx->printer); - printf("\n}"); - } else if (is_declaration(node)) { - printf(BYELLOW); - printf("%s", get_decl_name(node)); - printf(RESET); - } else if (node->tag == Unbound_TAG) { - printf(YELLOW); - printf("`%s`", node->payload.unbound.name); - printf(RESET); - } else if (node->tag == UnboundBBs_TAG) { - print_node(node->payload.unbound_bbs.body); - for (size_t i = 0; i < node->payload.unbound_bbs.children_blocks.count; i++) - print_basic_block(ctx, node->payload.unbound_bbs.children_blocks.nodes[i]); - } else switch (node->tag) { +static void print_annotation(PrinterCtx* ctx, const Node* node) { + switch (is_annotation(node)) { 
case Annotation_TAG: { const Annotation* annotation = &node->payload.annotation; printf(RED); @@ -920,18 +1004,95 @@ static void print_node_impl(PrinterCtx* ctx, const Node* node) { print_args_list(ctx, annotation->values); break; } - case BasicBlock_TAG: { - printf(BYELLOW); + case NotAnAnnotation: shd_error(""); + } +} + +static String emit_node(PrinterCtx* ctx, const Node* node) { + if (node == NULL) { + return "?"; + } + + String* found = shd_dict_find_value(const Node*, String, ctx->emitted, node); + if (found) + return *found; + + bool print_def = true; + String r; + if (is_declaration(node)) { + r = get_declaration_name(node); + print_def = false; + } else if (is_param(node) || is_basic_block(node) || node->tag == RefDecl_TAG || node->tag == FnAddr_TAG) { + print_def = false; + r = shd_fmt_string_irarena(node->arena, "%%%d", node->id); + } else { + r = shd_fmt_string_irarena(node->arena, "%%%d", node->id); + } + shd_dict_insert(const Node*, String, ctx->emitted, node, r); + + Growy* g = shd_new_growy(); + PrinterCtx ctx2 = *ctx; + ctx2.printer = shd_new_printer_from_growy(g); + bool print_inline = print_node_impl(&ctx2, node); + + String s = shd_printer_growy_unwrap(ctx2.printer); + Printer* p = ctx->root_printer; + if (ctx->scheduler) { + CFNode* dst = shd_schedule_instruction(ctx->scheduler, node); + if (dst) + p = ctx2.bb_printers[dst->rpo_index]; + } + + if (print_def) + shd_print(p, "%%%d = %s\n", node->id, s); + + if (print_inline) { + String is = shd_string(node->arena, s); + shd_dict_insert(const Node*, String, ctx->emitted, node, is); + free((void*) s); + return is; + } else { + free((void*) s); + return r; + } +} + +static bool print_node_impl(PrinterCtx* ctx, const Node* node) { + assert(node); + + if (ctx->config.print_ptrs) printf("%zu::", (size_t)(void*)node); + + if (is_type(node)) { + return print_type(ctx, node); + } else if (is_instruction(node)) + print_instruction(ctx, node); + else if (is_value(node)) + return print_value(ctx, node); + 
else if (is_terminator(node)) + print_terminator(ctx, node); + else if (is_declaration(node)) { + printf(BYELLOW); + printf("%s", get_declaration_name(node)); + printf(RESET); + } else if (is_annotation(node)) { + print_annotation(ctx, node); + return true; + } else if (is_basic_block(node)) { + printf(BYELLOW); + if (node->payload.basic_block.name && strlen(node->payload.basic_block.name) > 0) printf("%s", node->payload.basic_block.name); - printf(RESET); - break; - } - default: error("dunno how to print %s", node_tags[node->tag]); + else + printf("%%%d", node->id); + printf(RESET); + return true; + } else { + _shd_print_node_generated(ctx, node); } + return false; } static void print_mod_impl(PrinterCtx* ctx, Module* mod) { - Nodes decls = get_module_declarations(mod); + Nodes decls = shd_module_get_declarations(mod); for (size_t i = 0; i < decls.count; i++) { const Node* decl = decls.nodes[i]; print_decl(ctx, decl); @@ -941,55 +1102,137 @@ static void print_mod_impl(PrinterCtx* ctx, Module* mod) { #undef print_node #undef printf -static void print_helper(Printer* printer, const Node* node, Module* mod, PrintConfig config) { - PrinterCtx ctx = { - .printer = printer, - .config = config, +typedef struct { + Visitor v; + PrinterCtx* ctx; +} PrinterVisitor; + +static void print_mem_visitor(PrinterVisitor* ctx, NodeClass nc, String opname, const Node* op, size_t i) { + if (nc == NcMem) + print_mem(ctx->ctx, op); +} + +static void print_mem(PrinterCtx* ctx, const Node* mem) { + PrinterVisitor pv = { + .v = { + .visit_op_fn = (VisitOpFn) print_mem_visitor, + }, + .ctx = ctx, }; - if (node) - print_node_impl(&ctx, node); - if (mod) - print_mod_impl(&ctx, mod); - flush(ctx.printer); - destroy_printer(ctx.printer); + switch (is_mem(mem)) { + case Mem_AbsMem_TAG: return; + case Mem_MemAndValue_TAG: return print_mem(ctx, mem->payload.mem_and_value.mem); + default: { + assert(is_instruction(mem)); + shd_visit_node_operands((Visitor*) &pv, 0, mem); + return; + } + } } -void 
print_node_into_str(const Node* node, char** str_ptr, size_t* size) { - Growy* g = new_growy(); - print_helper(open_growy_as_printer(g), node, NULL, (PrintConfig) { .reparseable = true }); - *size = growy_size(g); - *str_ptr = growy_deconstruct(g); +static void print_operand_name_helper(PrinterCtx* ctx, String name) { + shd_print(ctx->printer, GREY); + shd_print(ctx->printer, "%s", name); + shd_print(ctx->printer, RESET); + shd_print(ctx->printer, ": ", name); } -void print_module_into_str(Module* mod, char** str_ptr, size_t* size) { - Growy* g = new_growy(); - print_helper(open_growy_as_printer(g), NULL, mod, (PrintConfig) { .reparseable = true, }); - *size = growy_size(g); - *str_ptr = growy_deconstruct(g); +static void print_operand_helper(PrinterCtx* ctx, NodeClass nc, const Node* op) { + if (getenv("SHADY_SUPER_VERBOSE_NODE_DEBUG")) { + if (op && (is_value(op) || is_instruction(op))) + shd_print(ctx->printer, "%%%d ", op->id); + shd_print(ctx->printer, "%s", emit_node(ctx, op)); + } else { + shd_print(ctx->printer, "%s", emit_node(ctx, op)); + } } -void dump_node(const Node* node) { - print_helper(open_file_as_printer(stdout), node, NULL, (PrintConfig) { .color = true }); - printf("\n"); +void _shd_print_node_operand(PrinterCtx* ctx, const Node* n, String name, NodeClass op_class, const Node* op) { + print_operand_name_helper(ctx, name); + if (op_class == NcBasic_block) + shd_print(ctx->printer, BYELLOW); + print_operand_helper(ctx, op_class, op); + shd_print(ctx->printer, RESET); } -void dump_module(Module* mod) { - print_helper(open_file_as_printer(stdout), NULL, mod, (PrintConfig) { .color = true }); - printf("\n"); +void _shd_print_node_operand_list(PrinterCtx* ctx, const Node* n, String name, NodeClass op_class, Nodes ops) { + print_operand_name_helper(ctx, name); + shd_print(ctx->printer, "["); + for (size_t i = 0; i < ops.count; i++) { + print_operand_helper(ctx, op_class, ops.nodes[i]); + if (i + 1 < ops.count) + shd_print(ctx->printer, ", "); + } + 
shd_print(ctx->printer, "]"); } -void log_node(LogLevel level, const Node* node) { - if (level >= get_log_level()) - print_helper(open_file_as_printer(stderr), node, NULL, (PrintConfig) { .color = true }); +void _shd_print_node_operand_AddressSpace(PrinterCtx* ctx, const Node* n, String name, AddressSpace as) { + print_operand_name_helper(ctx, name); + shd_print(ctx->printer, "%s", shd_get_address_space_name(as)); } -void log_module(LogLevel level, CompilerConfig* compiler_cfg, Module* mod) { - PrintConfig config = { .color = true }; - if (compiler_cfg) { - config.skip_generated = compiler_cfg->logging.skip_generated; - config.skip_builtin = compiler_cfg->logging.skip_builtin; - config.skip_internal = compiler_cfg->logging.skip_internal; +void _shd_print_node_operand_Op(PrinterCtx* ctx, const Node* n, String name, Op op) { + print_operand_name_helper(ctx, name); + shd_print(ctx->printer, "%s", shd_get_primop_name(op)); +} + +void _shd_print_node_operand_RecordSpecialFlag(PrinterCtx* ctx, const Node* n, String name, RecordSpecialFlag flags) { + print_operand_name_helper(ctx, name); + if (flags & DecorateBlock) + shd_print(ctx->printer, "DecorateBlock"); +} + +void _shd_print_node_operand_uint32_t(PrinterCtx* ctx, const Node* n, String name, uint32_t i) { + print_operand_name_helper(ctx, name); + shd_print(ctx->printer, "%u", i); +} + +void _shd_print_node_operand_uint64_t(PrinterCtx* ctx, const Node* n, String name, uint64_t i) { + print_operand_name_helper(ctx, name); + shd_print(ctx->printer, "%zu", i); +} + +void _shd_print_node_operand_IntSizes(PrinterCtx* ctx, const Node* n, String name, IntSizes s) { + print_operand_name_helper(ctx, name); + switch (s) { + case IntTy8: shd_print(ctx->printer, "8"); break; + case IntTy16: shd_print(ctx->printer, "16"); break; + case IntTy32: shd_print(ctx->printer, "32"); break; + case IntTy64: shd_print(ctx->printer, "64"); break; } - if (level >= get_log_level()) - print_helper(open_file_as_printer(stderr), NULL, mod, 
config); } + +void _shd_print_node_operand_FloatSizes(PrinterCtx* ctx, const Node* n, String name, FloatSizes s) { + print_operand_name_helper(ctx, name); + switch (s) { + case FloatTy16: shd_print(ctx->printer, "16"); break; + case FloatTy32: shd_print(ctx->printer, "32"); break; + case FloatTy64: shd_print(ctx->printer, "64"); break; + } +} + +void _shd_print_node_operand_String(PrinterCtx* ctx, const Node* n, String name, String s ){ + print_operand_name_helper(ctx, name); + shd_print(ctx->printer, "\"%s\"", s); +} + +void _shd_print_node_operand_Strings(PrinterCtx* ctx, const Node* n, String name, Strings ops) { + print_operand_name_helper(ctx, name); + shd_print(ctx->printer, "["); + for (size_t i = 0; i < ops.count; i++) { + shd_print(ctx->printer, "\"%s\"", (size_t) ops.strings[i]); + if (i + 1 < ops.count) + shd_print(ctx->printer, ", "); + } + shd_print(ctx->printer, "]"); +} + +void _shd_print_node_operand_bool(PrinterCtx* ctx, const Node* n, String name, bool b) { + print_operand_name_helper(ctx, name); + if (b) + shd_print(ctx->printer, "true"); + else + shd_print(ctx->printer, "false"); +} + +#include "print_generated.c" diff --git a/src/shady/rewrite.c b/src/shady/rewrite.c index c11c7024e..e61c03657 100644 --- a/src/shady/rewrite.c +++ b/src/shady/rewrite.c @@ -1,220 +1,267 @@ -#include "rewrite.h" +#include "shady/rewrite.h" -#include "log.h" #include "ir_private.h" -#include "portability.h" -#include "type.h" +#include "log.h" +#include "portability.h" #include "dict.h" #include +#include -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); -Rewriter create_rewriter(Module* src, Module* dst, RewriteNodeFn fn) { +Rewriter shd_create_rewriter_base(Module* src, Module* dst) { return (Rewriter) { .src_arena = src->arena, .dst_arena = dst->arena, .src_module = src, .dst_module = dst, - .rewrite_fn = fn, .config = { .search_map = true, - //.write_map = true, - 
.rebind_let = false, - .fold_quote = true, + .write_map = true, }, - .map = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node), - .decls_map = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node), + .map = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .own_decls = true, + .decls_map = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node), + .parent = NULL, }; } -void destroy_rewriter(Rewriter* r) { +Rewriter shd_create_node_rewriter(Module* src, Module* dst, RewriteNodeFn fn) { + Rewriter r = shd_create_rewriter_base(src, dst); + r.rewrite_fn = fn; + return r; +} + +Rewriter shd_create_op_rewriter(Module* src, Module* dst, RewriteOpFn fn) { + Rewriter r = shd_create_rewriter_base(src, dst); + r.config.write_map = false; + r.rewrite_op_fn = fn; + return r; +} + +void shd_destroy_rewriter(Rewriter* r) { assert(r->map); - destroy_dict(r->map); - destroy_dict(r->decls_map); + shd_destroy_dict(r->map); + if (r->own_decls) + shd_destroy_dict(r->decls_map); } -Rewriter create_importer(Module* src, Module* dst) { - return create_rewriter(src, dst, recreate_node_identity); +Rewriter shd_create_importer(Module* src, Module* dst) { + return shd_create_node_rewriter(src, dst, shd_recreate_node); } -Module* rebuild_module(Module* src) { - IrArena* a = get_module_arena(src); - Module* dst = new_module(a, get_module_name(src)); - Rewriter r = create_importer(src, dst); - rewrite_module(&r); - destroy_rewriter(&r); - return dst; +Rewriter shd_create_children_rewriter(Rewriter* parent) { + Rewriter r = *parent; + r.map = shd_new_dict(const Node*, Node*, (HashFn) shd_hash_node, (CmpFn) shd_compare_node); + r.parent = parent; + r.own_decls = false; + return r; } -const Node* rewrite_node_with_fn(Rewriter* rewriter, const Node* node, RewriteNodeFn fn) { +Rewriter shd_create_decl_rewriter(Rewriter* parent) { + Rewriter r = *parent; + r.map = shd_new_dict(const Node*, Node*, 
(HashFn) shd_hash_node, (CmpFn) shd_compare_node); + r.own_decls = false; + return r; +} + +static bool should_memoize(const Node* node) { + if (is_declaration(node)) + return false; + if (node->tag == BasicBlock_TAG) + return false; + return true; +} + +const Node* shd_rewrite_node_with_fn(Rewriter* rewriter, const Node* node, RewriteNodeFn fn) { assert(rewriter->rewrite_fn); if (!node) return NULL; - const Node* found = NULL; + const Node** found = NULL; if (rewriter->config.search_map) { - found = search_processed(rewriter, node); + found = shd_search_processed(rewriter, node); } if (found) - return found; + return *found; const Node* rewritten = fn(rewriter, node); + // assert(rewriter->dst_arena == rewritten->arena); if (is_declaration(node)) return rewritten; - if (rewriter->config.write_map) { - register_processed(rewriter, node, rewritten); + if (rewriter->config.write_map && should_memoize(node)) { + shd_register_processed(rewriter, node, rewritten); } return rewritten; } -Nodes rewrite_nodes_with_fn(Rewriter* rewriter, Nodes values, RewriteNodeFn fn) { +Nodes shd_rewrite_nodes_with_fn(Rewriter* rewriter, Nodes values, RewriteNodeFn fn) { LARRAY(const Node*, arr, values.count); for (size_t i = 0; i < values.count; i++) - arr[i] = rewrite_node_with_fn(rewriter, values.nodes[i], fn); - return nodes(rewriter->dst_arena, values.count, arr); + arr[i] = shd_rewrite_node_with_fn(rewriter, values.nodes[i], fn); + return shd_nodes(rewriter->dst_arena, values.count, arr); } -const Node* rewrite_node(Rewriter* rewriter, const Node* node) { +const Node* shd_rewrite_node(Rewriter* rewriter, const Node* node) { assert(rewriter->rewrite_fn); - return rewrite_node_with_fn(rewriter, node, rewriter->rewrite_fn); + return shd_rewrite_node_with_fn(rewriter, node, rewriter->rewrite_fn); } -Nodes rewrite_nodes(Rewriter* rewriter, Nodes old_nodes) { +Nodes shd_rewrite_nodes(Rewriter* rewriter, Nodes old_nodes) { assert(rewriter->rewrite_fn); - return 
rewrite_nodes_with_fn(rewriter, old_nodes, rewriter->rewrite_fn); + return shd_rewrite_nodes_with_fn(rewriter, old_nodes, rewriter->rewrite_fn); } -const Node* rewrite_op_with_fn(Rewriter* rewriter, NodeClass class, String op_name, const Node* node, RewriteOpFn fn) { +const Node* shd_rewrite_op_with_fn(Rewriter* rewriter, NodeClass class, String op_name, const Node* node, RewriteOpFn fn) { assert(rewriter->rewrite_op_fn); if (!node) return NULL; - const Node* found = NULL; + const Node** found = NULL; if (rewriter->config.search_map) { - found = search_processed(rewriter, node); + found = shd_search_processed(rewriter, node); } if (found) - return found; + return *found; const Node* rewritten = fn(rewriter, class, op_name, node); if (is_declaration(node)) return rewritten; - if (rewriter->config.write_map) { - register_processed(rewriter, node, rewritten); + if (rewriter->config.write_map && should_memoize(node)) { + shd_register_processed(rewriter, node, rewritten); } return rewritten; } -Nodes rewrite_ops_with_fn(Rewriter* rewriter, NodeClass class, String op_name, Nodes values, RewriteOpFn fn) { +Nodes shd_rewrite_ops_with_fn(Rewriter* rewriter, NodeClass class, String op_name, Nodes values, RewriteOpFn fn) { LARRAY(const Node*, arr, values.count); for (size_t i = 0; i < values.count; i++) - arr[i] = rewrite_op_with_fn(rewriter, class, op_name, values.nodes[i], fn); - return nodes(rewriter->dst_arena, values.count, arr); + arr[i] = shd_rewrite_op_with_fn(rewriter, class, op_name, values.nodes[i], fn); + return shd_nodes(rewriter->dst_arena, values.count, arr); } -const Node* rewrite_op(Rewriter* rewriter, NodeClass class, String op_name, const Node* node) { +const Node* shd_rewrite_op(Rewriter* rewriter, NodeClass class, String op_name, const Node* node) { assert(rewriter->rewrite_op_fn); - return rewrite_op_with_fn(rewriter, class, op_name, node, rewriter->rewrite_op_fn); + return shd_rewrite_op_with_fn(rewriter, class, op_name, node, rewriter->rewrite_op_fn); 
} -Nodes rewrite_ops(Rewriter* rewriter, NodeClass class, String op_name, Nodes old_nodes) { +Nodes shd_rewrite_ops(Rewriter* rewriter, NodeClass class, String op_name, Nodes old_nodes) { assert(rewriter->rewrite_op_fn); - return rewrite_ops_with_fn(rewriter, class, op_name, old_nodes, rewriter->rewrite_op_fn); + return shd_rewrite_ops_with_fn(rewriter, class, op_name, old_nodes, rewriter->rewrite_op_fn); } static const Node* rewrite_op_helper(Rewriter* rewriter, NodeClass class, String op_name, const Node* node) { if (rewriter->rewrite_op_fn) - return rewrite_op_with_fn(rewriter, class, op_name, node, rewriter->rewrite_op_fn); + return shd_rewrite_op_with_fn(rewriter, class, op_name, node, rewriter->rewrite_op_fn); assert(rewriter->rewrite_fn); - return rewrite_node_with_fn(rewriter, node, rewriter->rewrite_fn); + return shd_rewrite_node_with_fn(rewriter, node, rewriter->rewrite_fn); } static Nodes rewrite_ops_helper(Rewriter* rewriter, NodeClass class, String op_name, Nodes old_nodes) { if (rewriter->rewrite_op_fn) - return rewrite_ops_with_fn(rewriter, class, op_name, old_nodes, rewriter->rewrite_op_fn); + return shd_rewrite_ops_with_fn(rewriter, class, op_name, old_nodes, rewriter->rewrite_op_fn); assert(rewriter->rewrite_fn); - return rewrite_nodes_with_fn(rewriter, old_nodes, rewriter->rewrite_fn); + return shd_rewrite_nodes_with_fn(rewriter, old_nodes, rewriter->rewrite_fn); } -const Node* search_processed(const Rewriter* ctx, const Node* old) { - struct Dict* map = is_declaration(old) ? ctx->decls_map : ctx->map; - assert(map && "this rewriter has no processed cache"); - const Node** found = find_value_dict(const Node*, const Node*, map, old); - return found ? *found : NULL; +static const Node** search_processed_(const Rewriter* ctx, const Node* old, bool deep) { + if (is_declaration(old)) { + const Node** found = shd_dict_find_value(const Node*, const Node*, ctx->decls_map, old); + return found ? 
found : NULL; + } + + while (ctx) { + assert(ctx->map && "this rewriter has no processed cache"); + const Node** found = shd_dict_find_value(const Node*, const Node*, ctx->map, old); + if (found) + return found; + if (deep) + ctx = ctx->parent; + else + ctx = NULL; + } + return NULL; } -const Node* find_processed(const Rewriter* ctx, const Node* old) { - const Node* found = search_processed(ctx, old); +const Node** shd_search_processed(const Rewriter* ctx, const Node* old) { + return search_processed_(ctx, old, true); +} + +const Node* shd_find_processed(const Rewriter* ctx, const Node* old) { + const Node** found = shd_search_processed(ctx, old); assert(found && "this node was supposed to have been processed before"); - return found; + return *found; } -void register_processed(Rewriter* ctx, const Node* old, const Node* new) { +void shd_register_processed(Rewriter* ctx, const Node* old, const Node* new) { assert(old->arena == ctx->src_arena); - assert(new->arena == ctx->dst_arena); + assert(new ? new->arena == ctx->dst_arena : true); #ifndef NDEBUG - const Node* found = search_processed(ctx, old); + const Node** found = search_processed_(ctx, old, false); if (found) { - error_print("Trying to replace "); - log_node(ERROR, old); - error_print(" with "); - log_node(ERROR, new); - error_print(" but there was already "); - log_node(ERROR, found); - error_print("\n"); - error("The same node got processed twice !"); + // this can happen and is typically harmless + // ie: when rewriting a jump into a loop, the outer jump cannot be finished until the loop body is rebuilt + // and therefore the back-edge jump inside the loop will be rebuilt while the outer one isn't done. 
+ // as long as there is no conflict, this is correct, but this might hide perf hazards if we fail to cache things + if (*found == new) + return; + shd_error_print("Trying to replace "); + shd_log_node(ERROR, old); + shd_error_print(" with "); + shd_log_node(ERROR, new); + shd_error_print(" but there was already "); + if (*found) + shd_log_node(ERROR, *found); + else + shd_log_fmt(ERROR, "NULL"); + shd_error_print("\n"); + shd_error("The same node got processed twice !"); } #endif struct Dict* map = is_declaration(old) ? ctx->decls_map : ctx->map; assert(map && "this rewriter has no processed cache"); - bool r = insert_dict_and_get_result(const Node*, const Node*, map, old, new); + bool r = shd_dict_insert_get_result(const Node*, const Node*, map, old, new); assert(r); } -void register_processed_list(Rewriter* rewriter, Nodes old, Nodes new) { +void shd_register_processed_list(Rewriter* rewriter, Nodes old, Nodes new) { assert(old.count == new.count); for (size_t i = 0; i < old.count; i++) - register_processed(rewriter, old.nodes[i], new.nodes[i]); + shd_register_processed(rewriter, old.nodes[i], new.nodes[i]); } -void clear_processed_non_decls(Rewriter* rewriter) { - clear_dict(rewriter->map); -} - -KeyHash hash_node(Node**); -bool compare_node(Node**, Node**); +KeyHash shd_hash_node(Node** pnode); +bool shd_compare_node(Node** pa, Node** pb); #pragma GCC diagnostic error "-Wswitch" #include "rewrite_generated.c" -void rewrite_module(Rewriter* rewriter) { - Nodes old_decls = get_module_declarations(rewriter->src_module); +void shd_rewrite_module(Rewriter* rewriter) { + assert(rewriter->dst_module != rewriter->src_module); + Nodes old_decls = shd_module_get_declarations(rewriter->src_module); for (size_t i = 0; i < old_decls.count; i++) { - if (old_decls.nodes[i]->tag == NominalType_TAG) continue; + if (!shd_lookup_annotation(old_decls.nodes[i], "Exported")) continue; rewrite_op_helper(rewriter, NcDeclaration, "decl", old_decls.nodes[i]); } } -const Node* 
recreate_variable(Rewriter* rewriter, const Node* old) { - assert(old->tag == Variable_TAG); - return var(rewriter->dst_arena, rewrite_op_helper(rewriter, NcType, "type", old->payload.var.type), old->payload.var.name); +const Node* shd_recreate_param(Rewriter* rewriter, const Node* old) { + assert(old->tag == Param_TAG); + return param(rewriter->dst_arena, rewrite_op_helper(rewriter, NcType, "type", old->payload.param.type), old->payload.param.name); } -Nodes recreate_variables(Rewriter* rewriter, Nodes old) { - LARRAY(const Node*, nvars, old.count); - for (size_t i = 0; i < old.count; i++) { - if (rewriter->config.process_variables) - nvars[i] = rewrite_node(rewriter, old.nodes[i]); - else - nvars[i] = recreate_variable(rewriter, old.nodes[i]); - assert(nvars[i]->tag == Variable_TAG); +Nodes shd_recreate_params(Rewriter* rewriter, Nodes oparams) { + LARRAY(const Node*, nparams, oparams.count); + for (size_t i = 0; i < oparams.count; i++) { + nparams[i] = shd_recreate_param(rewriter, oparams.nodes[i]); + assert(nparams[i]->tag == Param_TAG); } - return nodes(rewriter->dst_arena, old.count, nvars); + return shd_nodes(rewriter->dst_arena, oparams.count, nparams); } -Node* recreate_decl_header_identity(Rewriter* rewriter, const Node* old) { +Node* shd_recreate_node_head(Rewriter* rewriter, const Node* old) { Node* new = NULL; switch (is_declaration(old)) { case GlobalVariable_TAG: { @@ -238,11 +285,11 @@ Node* recreate_decl_header_identity(Rewriter* rewriter, const Node* old) { } case Function_TAG: { Nodes new_annotations = rewrite_ops_helper(rewriter, NcAnnotation, "annotations", old->payload.fun.annotations); - Nodes new_params = recreate_variables(rewriter, old->payload.fun.params); + Nodes new_params = shd_recreate_params(rewriter, old->payload.fun.params); Nodes nyield_types = rewrite_ops_helper(rewriter, NcType, "return_types", old->payload.fun.return_types); new = function(rewriter->dst_module, new_params, old->payload.fun.name, new_annotations, nyield_types); 
assert(new && new->tag == Function_TAG); - register_processed_list(rewriter, old->payload.fun.params, new->payload.fun.params); + shd_register_processed_list(rewriter, old->payload.fun.params, new->payload.fun.params); break; } case NominalType_TAG: { @@ -250,14 +297,14 @@ Node* recreate_decl_header_identity(Rewriter* rewriter, const Node* old) { new = nominal_type(rewriter->dst_module, new_annotations, old->payload.nom_type.name); break; } - case NotADeclaration: error("not a decl"); + case NotADeclaration: shd_error("not a decl"); } assert(new); - register_processed(rewriter, old, new); + shd_register_processed(rewriter, old, new); return new; } -void recreate_decl_body_identity(Rewriter* rewriter, const Node* old, Node* new) { +void shd_recreate_node_body(Rewriter* rewriter, const Node* old, Node* new) { assert(is_declaration(new)); switch (is_declaration(old)) { case GlobalVariable_TAG: { @@ -265,39 +312,24 @@ void recreate_decl_body_identity(Rewriter* rewriter, const Node* old, Node* new) break; } case Constant_TAG: { - new->payload.constant.instruction = rewrite_op_helper(rewriter, NcInstruction, "instruction", old->payload.constant.instruction); + new->payload.constant.value = rewrite_op_helper(rewriter, NcValue, "value", old->payload.constant.value); // TODO check type now ? 
break; } case Function_TAG: { assert(new->payload.fun.body == NULL); - new->payload.fun.body = rewrite_op_helper(rewriter, NcTerminator, "body", old->payload.fun.body); + shd_set_abstraction_body(new, rewrite_op_helper(rewriter, NcTerminator, "body", old->payload.fun.body)); break; } case NominalType_TAG: { new->payload.nom_type.body = rewrite_op_helper(rewriter, NcType, "body", old->payload.nom_type.body); break; } - case NotADeclaration: error("not a decl"); + case NotADeclaration: shd_error("not a decl"); } } -static const Node* rebind_results(Rewriter* rewriter, const Node* ninstruction, const Node* olam) { - assert(olam->tag == Case_TAG); - Nodes oparams = olam->payload.case_.params; - Nodes ntypes = unwrap_multiple_yield_types(rewriter->dst_arena, ninstruction->type); - assert(ntypes.count == oparams.count); - LARRAY(const Node*, new_params, oparams.count); - for (size_t i = 0; i < oparams.count; i++) { - new_params[i] = var(rewriter->dst_arena, ntypes.nodes[i], oparams.nodes[i]->payload.var.name); - register_processed(rewriter, oparams.nodes[i], new_params[i]); - } - const Node* nbody = rewrite_node(rewriter, olam->payload.case_.body); - const Node* tail = case_(rewriter->dst_arena, nodes(rewriter->dst_arena, oparams.count, new_params), nbody); - return tail; -} - -const Node* recreate_node_identity(Rewriter* rewriter, const Node* node) { +const Node* shd_recreate_node(Rewriter* rewriter, const Node* node) { if (node == NULL) return NULL; @@ -312,54 +344,35 @@ const Node* recreate_node_identity(Rewriter* rewriter, const Node* node) { case Constant_TAG: case GlobalVariable_TAG: case NominalType_TAG: { - Node* new = recreate_decl_header_identity(rewriter, node); - recreate_decl_body_identity(rewriter, node, new); + Node* new = shd_recreate_node_head(rewriter, node); + shd_recreate_node_body(rewriter, node, new); return new; } - case Variable_TAG: error("variables should be recreated as part of decl handling"); - case Let_TAG: { - const Node* instruction = 
rewrite_op_helper(rewriter, NcInstruction, "instruction", node->payload.let.instruction); - if (arena->config.allow_fold && rewriter->config.fold_quote && instruction->tag == PrimOp_TAG && instruction->payload.prim_op.op == quote_op) { - Nodes old_params = node->payload.let.tail->payload.case_.params; - Nodes new_args = instruction->payload.prim_op.operands; - assert(old_params.count == new_args.count); - register_processed_list(rewriter, old_params, new_args); - for (size_t i = 0; i < old_params.count; i++) { - String old_name = get_value_name(old_params.nodes[i]); - if (!old_name) continue; - const Node* new_arg = new_args.nodes[i]; - if (new_arg->tag == Variable_TAG && !get_value_name(new_arg)) { - set_variable_name((Node*) new_arg, old_name); - } - } - return rewrite_op_helper(rewriter, NcTerminator, "body", node->payload.let.tail->payload.case_.body); - } - const Node* tail; - if (rewriter->config.rebind_let) - tail = rebind_results(rewriter, instruction, node->payload.let.tail); - else - tail = rewrite_op_helper(rewriter, NcCase, "tail", node->payload.let.tail); - return let(arena, instruction, tail); - } - case LetMut_TAG: error("De-sugar this by hand") - case Case_TAG: { - Nodes params = recreate_variables(rewriter, node->payload.case_.params); - register_processed_list(rewriter, node->payload.case_.params, params); - const Node* nterminator = rewrite_op_helper(rewriter, NcTerminator, "body", node->payload.case_.body); - const Node* nlam = case_(rewriter->dst_arena, params, nterminator); - // register_processed(rewriter, node, nlam); - return nlam; - } + case Param_TAG: + shd_log_fmt(ERROR, "Can't rewrite: "); + shd_log_node(ERROR, node); + shd_log_fmt(ERROR, ", params should be rewritten by the abstraction rewrite logic"); + shd_error_die(); case BasicBlock_TAG: { - Nodes params = recreate_variables(rewriter, node->payload.basic_block.params); - register_processed_list(rewriter, node->payload.basic_block.params, params); - const Node* fn = 
rewrite_op_helper(rewriter, NcDeclaration, "fn", node->payload.basic_block.fn); - Node* bb = basic_block(arena, (Node*) fn, params, node->payload.basic_block.name); - register_processed(rewriter, node, bb); + Nodes params = shd_recreate_params(rewriter, node->payload.basic_block.params); + shd_register_processed_list(rewriter, node->payload.basic_block.params, params); + Node* bb = basic_block(arena, params, node->payload.basic_block.name); + shd_register_processed(rewriter, node, bb); const Node* nterminator = rewrite_op_helper(rewriter, NcTerminator, "body", node->payload.basic_block.body); - bb->payload.basic_block.body = nterminator; + shd_set_abstraction_body(bb, nterminator); return bb; } } assert(false); } + +void shd_dump_rewriter_map(Rewriter* r) { + size_t i = 0; + const Node* src, *dst; + while (shd_dict_iter(r->map, &i, &src, &dst)) { + shd_log_node(ERROR, src); + shd_log_fmt(ERROR, " -> "); + shd_log_node(ERROR, dst); + shd_log_fmt(ERROR, "\n"); + } +} \ No newline at end of file diff --git a/src/shady/rewrite.h b/src/shady/rewrite.h deleted file mode 100644 index 2ba43e86e..000000000 --- a/src/shady/rewrite.h +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef SHADY_REWRITE_H -#define SHADY_REWRITE_H - -#include "shady/ir.h" - -typedef struct Rewriter_ Rewriter; - -typedef const Node* (*RewriteNodeFn)(Rewriter*, const Node*); -typedef const Node* (*RewriteOpFn)(Rewriter*, NodeClass, String, const Node*); - -const Node* rewrite_node(Rewriter*, const Node*); -const Node* rewrite_node_with_fn(Rewriter*, const Node*, RewriteNodeFn); - -const Node* rewrite_op(Rewriter*, NodeClass, String, const Node*); -const Node* rewrite_op_with_fn(Rewriter*, NodeClass, String, const Node*, RewriteOpFn); - -/// Applies the rewriter to all nodes in the collection -Nodes rewrite_nodes(Rewriter*, Nodes); -Nodes rewrite_nodes_with_fn(Rewriter* rewriter, Nodes values, RewriteNodeFn fn); - -Nodes rewrite_ops(Rewriter*, NodeClass, String, Nodes); -Nodes rewrite_ops_with_fn(Rewriter* 
rewriter, NodeClass,String, Nodes values, RewriteOpFn fn); - -Strings import_strings(IrArena*, Strings); - -struct Rewriter_ { - RewriteNodeFn rewrite_fn; - RewriteOpFn rewrite_op_fn; - IrArena* src_arena; - IrArena* dst_arena; - Module* src_module; - Module* dst_module; - struct { - bool search_map; - bool write_map; - bool rebind_let; - bool fold_quote; - bool process_variables; - } config; - struct Dict* map; - struct Dict* decls_map; -}; - -Rewriter create_rewriter(Module* src, Module* dst, RewriteNodeFn fn); -Rewriter create_importer(Module* src, Module* dst); -Module* rebuild_module(Module*); -Rewriter create_substituter(Module* arena); -void destroy_rewriter(Rewriter*); - -void rewrite_module(Rewriter*); - -/// Rewrites a node using the rewriter to provide the node and type operands -const Node* recreate_node_identity(Rewriter*, const Node*); - -/// Rewrites a constant / function header -Node* recreate_decl_header_identity(Rewriter*, const Node*); -void recreate_decl_body_identity(Rewriter*, const Node*, Node*); - -/// Rewrites a variable under a new identity -const Node* recreate_variable(Rewriter* rewriter, const Node* old); -Nodes recreate_variables(Rewriter* rewriter, Nodes old); - -/// Looks up if the node was already processed -const Node* search_processed(const Rewriter*, const Node*); -/// Same as search_processed but asserts if it fails to find a mapping -const Node* find_processed(const Rewriter*, const Node*); -void register_processed(Rewriter*, const Node*, const Node*); -void register_processed_list(Rewriter*, Nodes, Nodes); -void clear_processed_non_decls(Rewriter*); - -#endif diff --git a/src/shady/transform/CMakeLists.txt b/src/shady/transform/CMakeLists.txt new file mode 100644 index 000000000..0666502db --- /dev/null +++ b/src/shady/transform/CMakeLists.txt @@ -0,0 +1,3 @@ +target_sources(shady PRIVATE + internal_constants.c +) diff --git a/src/shady/transform/internal_constants.c b/src/shady/transform/internal_constants.c index 
0c61bcc39..a5250c287 100644 --- a/src/shady/transform/internal_constants.c +++ b/src/shady/transform/internal_constants.c @@ -1,14 +1,16 @@ #include "internal_constants.h" #include "portability.h" +#include "ir_private.h" #include -void generate_dummy_constants(SHADY_UNUSED CompilerConfig* config, Module* mod) { - IrArena* arena = get_module_arena(mod); -#define X(name, T, placeholder) \ - Node* name##_var = constant(mod, nodes(arena, 0, NULL), T, #name); \ - name##_var->payload.constant.instruction = quote_helper(arena, singleton(placeholder)); +void shd_generate_dummy_constants(SHADY_UNUSED const CompilerConfig* config, Module* mod) { + IrArena* arena = shd_module_get_arena(mod); + Nodes annotations = mk_nodes(arena, annotation(arena, (Annotation) { .name = "RetainAfterSpecialization" }), annotation(arena, (Annotation) { .name = "Exported" })); +#define X(constant_name, T, placeholder) \ + Node* constant_name##_var = constant(mod, annotations, T, #constant_name); \ + constant_name##_var->payload.constant.value = placeholder; INTERNAL_CONSTANTS(X) #undef X } diff --git a/src/shady/transform/internal_constants.h b/src/shady/transform/internal_constants.h index 1e4d3ebda..46fb99ee5 100644 --- a/src/shady/transform/internal_constants.h +++ b/src/shady/transform/internal_constants.h @@ -4,9 +4,10 @@ #include "shady/ir.h" #define INTERNAL_CONSTANTS(X) \ -X(SUBGROUP_SIZE, int32_type(arena), uint32_literal(arena, 8)) \ -X(SUBGROUPS_PER_WG, int32_type(arena), uint32_literal(arena, 1)) \ +X(SUBGROUP_SIZE, shd_uint32_type(arena), shd_uint32_literal(arena, 64)) \ +X(SUBGROUPS_PER_WG, shd_uint32_type(arena), shd_uint32_literal(arena, 1)) \ -void generate_dummy_constants(CompilerConfig* config, Module*); +typedef struct CompilerConfig_ CompilerConfig; +void shd_generate_dummy_constants(const CompilerConfig* config, Module* mod); #endif diff --git a/src/shady/transform/ir_gen_helpers.c b/src/shady/transform/ir_gen_helpers.c deleted file mode 100644 index 0c6fcfda3..000000000 
--- a/src/shady/transform/ir_gen_helpers.c +++ /dev/null @@ -1,224 +0,0 @@ -#include "ir_gen_helpers.h" - -#include "list.h" -#include "portability.h" -#include "log.h" -#include "util.h" - -#include "../ir_private.h" -#include "../type.h" -#include "../rewrite.h" - -#include -#include - -Nodes gen_primop(BodyBuilder* bb, Op op, Nodes type_args, Nodes operands) { - assert(bb->arena->config.check_types); - const Node* instruction = prim_op(bb->arena, (PrimOp) { .op = op, .type_arguments = type_args, .operands = operands }); - return bind_instruction(bb, instruction); -} - -Nodes gen_primop_c(BodyBuilder* bb, Op op, size_t operands_count, const Node* operands[]) { - return gen_primop(bb, op, empty(bb->arena), nodes(bb->arena, operands_count, operands)); -} - -const Node* gen_primop_ce(BodyBuilder* bb, Op op, size_t operands_count, const Node* operands[]) { - Nodes result = gen_primop_c(bb, op, operands_count, operands); - assert(result.count == 1); - return result.nodes[0]; -} - -const Node* gen_primop_e(BodyBuilder* bb, Op op, Nodes ty, Nodes nodes) { - Nodes result = gen_primop(bb, op, ty, nodes); - return first(result); -} - -void gen_push_value_stack(BodyBuilder* bb, const Node* value) { - gen_primop(bb, push_stack_op, singleton(get_unqualified_type(value->type)), singleton(value)); -} - -void gen_push_values_stack(BodyBuilder* bb, Nodes values) { - for (size_t i = values.count - 1; i < values.count; i--) { - const Node* value = values.nodes[i]; - gen_push_value_stack(bb, value); - } -} - -const Node* gen_pop_value_stack(BodyBuilder* bb, const Type* type) { - const Node* instruction = prim_op(bb->arena, (PrimOp) { .op = pop_stack_op, .type_arguments = nodes(bb->arena, 1, (const Node*[]) { type }) }); - return first(bind_instruction(bb, instruction)); -} - -const Node* gen_reinterpret_cast(BodyBuilder* bb, const Type* dst, const Node* src) { - assert(is_type(dst)); - return first(bind_instruction(bb, prim_op(bb->arena, (PrimOp) { .op = reinterpret_op, .operands = 
singleton(src), .type_arguments = singleton(dst)}))); -} - -const Node* gen_conversion(BodyBuilder* bb, const Type* dst, const Node* src) { - assert(is_type(dst)); - return first(bind_instruction(bb, prim_op(bb->arena, (PrimOp) { .op = convert_op, .operands = singleton(src), .type_arguments = singleton(dst)}))); -} - -const Node* gen_merge_halves(BodyBuilder* bb, const Node* lo, const Node* hi) { - const Type* src_type = get_unqualified_type(lo->type); - assert(get_unqualified_type(hi->type) == src_type); - assert(src_type->tag == Int_TAG); - IntSizes size = src_type->payload.int_type.width; - assert(size != IntSizeMax); - const Type* dst_type = int_type(bb->arena, (Int) { .width = size + 1, .is_signed = src_type->payload.int_type.is_signed }); - // widen them - lo = gen_conversion(bb, dst_type, lo); - hi = gen_conversion(bb, dst_type, hi); - // shift hi - const Node* shift_by = int_literal(bb->arena, (IntLiteral) { .width = size + 1, .is_signed = src_type->payload.int_type.is_signed, .value = get_type_bitwidth(src_type) }); - hi = gen_primop_ce(bb, lshift_op, 2, (const Node* []) { hi, shift_by}); - // Merge the two - return gen_primop_ce(bb, or_op, 2, (const Node* []) { lo, hi }); -} - -const Node* gen_load(BodyBuilder* bb, const Node* ptr) { - return gen_primop_ce(bb, load_op, 1, (const Node* []) {ptr }); -} - -void gen_store(BodyBuilder* instructions, const Node* ptr, const Node* value) { - gen_primop_c(instructions, store_op, 2, (const Node* []) { ptr, value }); -} - -const Node* gen_lea(BodyBuilder* bb, const Node* base, const Node* offset, Nodes selectors) { - LARRAY(const Node*, ops, 2 + selectors.count); - ops[0] = base; - ops[1] = offset; - for (size_t i = 0; i < selectors.count; i++) - ops[2 + i] = selectors.nodes[i]; - return gen_primop_ce(bb, lea_op, 2 + selectors.count, ops); -} - -const Node* gen_extract(BodyBuilder* bb, const Node* base, Nodes selectors) { - LARRAY(const Node*, ops, 1 + selectors.count); - ops[0] = base; - for (size_t i = 0; i < 
selectors.count; i++) - ops[1 + i] = selectors.nodes[i]; - return gen_primop_ce(bb, extract_op, 1 + selectors.count, ops); -} - -void gen_comment(BodyBuilder* bb, String str) { - bind_instruction(bb, comment(bb->arena, (Comment) { .string = str })); -} - -const Node* get_builtin(Module* m, Builtin b, String n) { - Nodes decls = get_module_declarations(m); - for (size_t i = 0; i < decls.count; i++) { - const Node* decl = decls.nodes[i]; - if (decl->tag != GlobalVariable_TAG) - continue; - const Node* a = lookup_annotation(decl, "Builtin"); - if (!a) - continue; - String builtin_name = get_annotation_string_payload(a); - assert(builtin_name); - if (strcmp(builtin_name, get_builtin_name(b)) == 0) - return decl; - } - - AddressSpace as = get_builtin_as(b); - IrArena* a = get_module_arena(m); - Node* decl = global_var(m, singleton(annotation_value_helper(a, "Builtin", string_lit_helper(a, get_builtin_name(b)))), get_builtin_type(a, b), n ? n : format_string_arena(a->arena, "builtin_%s", get_builtin_name(b)), as); - return decl; -} - -const Node* gen_builtin_load(Module* m, BodyBuilder* bb, Builtin b) { - return gen_load(bb, ref_decl_helper(bb->arena, get_builtin(m, b, NULL))); -} - -bool is_builtin_load_op(const Node* n, Builtin* out) { - assert(is_instruction(n)); - if (n->tag == PrimOp_TAG && n->payload.prim_op.op == load_op) { - const Node* src = first(n->payload.prim_op.operands); - if (src->tag == RefDecl_TAG) - src = src->payload.ref_decl.decl; - if (src->tag == GlobalVariable_TAG) { - const Node* a = lookup_annotation(src, "Builtin"); - if (a) { - String bn = get_annotation_string_payload(a); - assert(bn); - Builtin b = get_builtin_by_name(bn); - if (b != BuiltinsCount) { - *out = b; - return true; - } - } - } - } - return false; -} - -const Node* find_or_process_decl(Rewriter* rewriter, const char* name) { - Nodes old_decls = get_module_declarations(rewriter->src_module); - for (size_t i = 0; i < old_decls.count; i++) { - const Node* decl = old_decls.nodes[i]; - 
if (strcmp(get_decl_name(decl), name) == 0) { - return rewrite_node(rewriter, decl); - } - } - assert(false); -} - -const Node* access_decl(Rewriter* rewriter, const char* name) { - const Node* decl = find_or_process_decl(rewriter, name); - if (decl->tag == Function_TAG) - return fn_addr_helper(rewriter->dst_arena, decl); - else - return ref_decl_helper(rewriter->dst_arena, decl); -} - -const Node* convert_int_extend_according_to_src_t(BodyBuilder* bb, const Type* dst_type, const Node* src) { - const Type* src_type = get_unqualified_type(src->type); - assert(src_type->tag == Int_TAG); - assert(dst_type->tag == Int_TAG); - - // first convert to final bitsize then bitcast - const Type* extended_src_t = int_type(bb->arena, (Int) { .width = dst_type->payload.int_type.width, .is_signed = src_type->payload.int_type.is_signed }); - const Node* val = src; - val = gen_primop_e(bb, convert_op, singleton(extended_src_t), singleton(val)); - val = gen_primop_e(bb, reinterpret_op, singleton(dst_type), singleton(val)); - return val; -} - -const Node* convert_int_extend_according_to_dst_t(BodyBuilder* bb, const Type* dst_type, const Node* src) { - const Type* src_type = get_unqualified_type(src->type); - assert(src_type->tag == Int_TAG); - assert(dst_type->tag == Int_TAG); - - // first bitcast then convert to final bitsize - const Type* reinterpreted_src_t = int_type(bb->arena, (Int) { .width = src_type->payload.int_type.width, .is_signed = dst_type->payload.int_type.is_signed }); - const Node* val = src; - val = gen_primop_e(bb, reinterpret_op, singleton(reinterpreted_src_t), singleton(val)); - val = gen_primop_e(bb, convert_op, singleton(dst_type), singleton(val)); - return val; -} - -const Node* get_default_zero_value(IrArena* a, const Type* t) { - switch (is_type(t)) { - case NotAType: error("") - case Type_Int_TAG: return int_literal(a, (IntLiteral) { .width = t->payload.int_type.width, .is_signed = t->payload.int_type.is_signed, .value = 0 }); - case Type_Float_TAG: return 
float_literal(a, (FloatLiteral) { .width = t->payload.float_type.width, .value = 0 }); - case Type_Bool_TAG: return false_lit(a); - case Type_PtrType_TAG: return null_ptr(a, (NullPtr) { .ptr_type = t }); - case Type_QualifiedType_TAG: return get_default_zero_value(a, t->payload.qualified_type.type); - case Type_RecordType_TAG: - case Type_ArrType_TAG: - case Type_PackType_TAG: - case Type_TypeDeclRef_TAG: { - Nodes elem_tys = get_composite_type_element_types(t); - if (elem_tys.count >= 1024) { - warn_print("Potential performance issue: creating a really composite full of zero/default values (size=%d)!\n", elem_tys.count); - } - LARRAY(const Node*, elems, elem_tys.count); - for (size_t i = 0; i < elem_tys.count; i++) - elems[i] = get_default_zero_value(a, elem_tys.nodes[i]); - return composite_helper(a, t, nodes(a, elem_tys.count, elems)); - } - default: break; - } - return NULL; -} diff --git a/src/shady/transform/ir_gen_helpers.h b/src/shady/transform/ir_gen_helpers.h deleted file mode 100644 index d93803798..000000000 --- a/src/shady/transform/ir_gen_helpers.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef SHADY_IR_GEN_HELPERS_H -#define SHADY_IR_GEN_HELPERS_H - -#include "shady/ir.h" -#include "shady/builtins.h" - -void gen_push_value_stack(BodyBuilder* bb, const Node* value); -void gen_push_values_stack(BodyBuilder* bb, Nodes values); -const Node* gen_pop_value_stack(BodyBuilder*, const Type* type); - -Nodes gen_primop(BodyBuilder*, Op, Nodes, Nodes); -Nodes gen_primop_c(BodyBuilder*, Op op, size_t operands_count, const Node* operands[]); -const Node* gen_primop_ce(BodyBuilder*, Op op, size_t operands_count, const Node* operands[]); -const Node* gen_primop_e(BodyBuilder*, Op op, Nodes, Nodes); - -const Node* gen_reinterpret_cast(BodyBuilder*, const Type* dst, const Node* src); -const Node* gen_conversion(BodyBuilder*, const Type* dst, const Node* src); -const Node* gen_merge_halves(BodyBuilder*, const Node* lo, const Node* hi); - -const Node* gen_load(BodyBuilder*, 
const Node* ptr); -void gen_store(BodyBuilder*, const Node* ptr, const Node* value); -const Node* gen_lea(BodyBuilder*, const Node* base, const Node* offset, Nodes selectors); -const Node* gen_extract(BodyBuilder*, const Node* base, Nodes selectors); -void gen_comment(BodyBuilder*, String str); -const Node* get_builtin(Module* m, Builtin b, String n); -const Node* gen_builtin_load(Module*, BodyBuilder*, Builtin); - -typedef struct Rewriter_ Rewriter; - -const Node* find_or_process_decl(Rewriter*, const char* name); -const Node* access_decl(Rewriter*, const char* name); - -const Node* convert_int_extend_according_to_src_t(BodyBuilder*, const Type* dst_type, const Node* src); -const Node* convert_int_extend_according_to_dst_t(BodyBuilder*, const Type* dst_type, const Node* src); - -const Node* get_default_zero_value(IrArena*, const Type*); - -bool is_builtin_load_op(const Node*, Builtin*); - -#endif diff --git a/src/shady/transform/memory_layout.c b/src/shady/transform/memory_layout.c deleted file mode 100644 index 6e5341fab..000000000 --- a/src/shady/transform/memory_layout.c +++ /dev/null @@ -1,137 +0,0 @@ -#include "memory_layout.h" -#include "ir_gen_helpers.h" - -#include "log.h" -#include "portability.h" - -#include "../type.h" - -#include - -inline static size_t round_up(size_t a, size_t b) { - if (b == 0) - return a; - size_t divided = (a + b - 1) / b; - return divided * b; -} - -static int maxof(int a, int b) { - if (a > b) - return a; - return b; -} - -TypeMemLayout get_record_layout(IrArena* a, const Node* record_type, FieldLayout* fields) { - assert(record_type->tag == RecordType_TAG); - - size_t offset = 0; - size_t max_align = 0; - - Nodes member_types = record_type->payload.record_type.members; - for (size_t i = 0; i < member_types.count; i++) { - TypeMemLayout member_layout = get_mem_layout(a, member_types.nodes[i]); - offset = round_up(offset, member_layout.alignment_in_bytes); - if (fields) { - fields[i].mem_layout = member_layout; - 
fields[i].offset_in_bytes = offset; - } - offset += member_layout.size_in_bytes; - if (member_layout.alignment_in_bytes > max_align) - max_align = member_layout.alignment_in_bytes; - } - - return (TypeMemLayout) { - .type = record_type, - .size_in_bytes = round_up(offset, max_align), - .alignment_in_bytes = max_align, - }; -} - -size_t get_record_field_offset_in_bytes(IrArena* a, const Type* t, size_t i) { - assert(t->tag == RecordType_TAG); - Nodes member_types = t->payload.record_type.members; - assert(i < member_types.count); - LARRAY(FieldLayout, fields, member_types.count); - get_record_layout(a, t, fields); - return fields[i].offset_in_bytes; -} - -TypeMemLayout get_mem_layout(IrArena* a, const Type* type) { - size_t base_word_size = int_size_in_bytes(a->config.memory.word_size); - assert(is_type(type)); - switch (type->tag) { - case FnType_TAG: error("Functions have an opaque memory representation"); - case PtrType_TAG: switch (type->payload.ptr_type.address_space) { - case AsPrivatePhysical: - case AsSubgroupPhysical: - case AsSharedPhysical: - case AsGlobalPhysical: - case AsGeneric: return get_mem_layout(a, int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false })); - default: error_print("as: %d", type->payload.ptr_type.address_space); error("unhandled address space") - } - case Int_TAG: return (TypeMemLayout) { - .type = type, - .size_in_bytes = int_size_in_bytes(type->payload.int_type.width), - .alignment_in_bytes = maxof(int_size_in_bytes(type->payload.int_type.width), base_word_size), - }; - case Float_TAG: return (TypeMemLayout) { - .type = type, - .size_in_bytes = float_size_in_bytes(type->payload.float_type.width), - .alignment_in_bytes = maxof(float_size_in_bytes(type->payload.float_type.width), base_word_size), - }; - case Bool_TAG: return (TypeMemLayout) { - .type = type, - .size_in_bytes = base_word_size, - .alignment_in_bytes = base_word_size, - }; - case ArrType_TAG: { - const Node* size = type->payload.arr_type.size; - 
assert(size && "We can't know the full layout of arrays of unknown size !"); - size_t actual_size = get_int_literal_value(*resolve_to_int_literal(size), false); - TypeMemLayout element_layout = get_mem_layout(a, type->payload.arr_type.element_type); - return (TypeMemLayout) { - .type = type, - .size_in_bytes = actual_size * element_layout.size_in_bytes, - .alignment_in_bytes = element_layout.alignment_in_bytes - }; - } - case PackType_TAG: { - size_t width = type->payload.pack_type.width; - TypeMemLayout element_layout = get_mem_layout(a, type->payload.pack_type.element_type); - return (TypeMemLayout) { - .type = type, - .size_in_bytes = width * element_layout.size_in_bytes /* TODO Vulkan vec3 -> vec4 alignment rules ? */, - .alignment_in_bytes = element_layout.alignment_in_bytes - }; - } - case QualifiedType_TAG: return get_mem_layout(a, type->payload.qualified_type.type); - case TypeDeclRef_TAG: return get_mem_layout(a, type->payload.type_decl_ref.decl->payload.nom_type.body); - case RecordType_TAG: return get_record_layout(a, type, NULL); - default: error("not a known type"); - } -} - -const Node* size_t_literal(IrArena* a, uint64_t value) { - return int_literal(a, (IntLiteral) { .width = a->config.memory.ptr_size, .is_signed = false, .value = value }); -} - -const Node* bytes_to_words(BodyBuilder* bb, const Node* bytes) { - IrArena* a = bytes->arena; - const Type* word_type = int_type(a, (Int) { .width = a->config.memory.word_size, .is_signed = false }); - size_t word_width = get_type_bitwidth(word_type); - const Node* bytes_per_word = size_t_literal(a, word_width / 8); - return gen_primop_e(bb, div_op, empty(a), mk_nodes(a, bytes, bytes_per_word)); -} - -uint64_t bytes_to_words_static(const IrArena* a, uint64_t bytes) { - uint64_t word_width = int_size_in_bytes(a->config.memory.word_size); - return bytes / word_width; -} - -IntSizes float_to_int_width(FloatSizes width) { - switch (width) { - case FloatTy16: return IntTy16; - case FloatTy32: return IntTy32; - 
case FloatTy64: return IntTy64; - } -} diff --git a/src/shady/transform/memory_layout.h b/src/shady/transform/memory_layout.h deleted file mode 100644 index 925041e32..000000000 --- a/src/shady/transform/memory_layout.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef SHADY_MEMORY_LAYOUT_H -#define SHADY_MEMORY_LAYOUT_H - -#include "../ir_private.h" - -typedef struct { - const Type* type; - size_t size_in_bytes; - size_t alignment_in_bytes; -} TypeMemLayout; - -typedef struct { - TypeMemLayout mem_layout; - size_t offset_in_bytes; -} FieldLayout; - -TypeMemLayout get_mem_layout(IrArena*, const Type*); - -TypeMemLayout get_record_layout(IrArena* a, const Node* record_type, FieldLayout* fields); -size_t get_record_field_offset_in_bytes(IrArena*, const Type*, size_t); - -const Node* size_t_literal(IrArena* a, uint64_t value); -const Node* bytes_to_words(BodyBuilder* bb, const Node* bytes); -uint64_t bytes_to_words_static(const IrArena*, uint64_t bytes); -IntSizes float_to_int_width(FloatSizes width); - -#endif diff --git a/src/shady/type.c b/src/shady/type.c deleted file mode 100644 index e7f2c9580..000000000 --- a/src/shady/type.c +++ /dev/null @@ -1,1332 +0,0 @@ -#include "type.h" - -#include "log.h" -#include "ir_private.h" -#include "portability.h" -#include "dict.h" -#include "util.h" - -#include "shady/builtins.h" - -#include -#include - -#pragma GCC diagnostic error "-Wswitch" - -static bool are_types_identical(size_t num_types, const Type* types[]) { - for (size_t i = 0; i < num_types; i++) { - assert(types[i]); - if (types[0] != types[i]) - return false; - } - return true; -} - -bool is_subtype(const Type* supertype, const Type* type) { - assert(supertype && type); - if (supertype->tag != type->tag) - return false; - switch (is_type(supertype)) { - case NotAType: error("supplied not a type to is_subtype"); - case QualifiedType_TAG: { - // uniform T <: varying T - if (supertype->payload.qualified_type.is_uniform && !type->payload.qualified_type.is_uniform) - return 
false; - return is_subtype(supertype->payload.qualified_type.type, type->payload.qualified_type.type); - } - case RecordType_TAG: { - const Nodes* supermembers = &supertype->payload.record_type.members; - const Nodes* members = &type->payload.record_type.members; - for (size_t i = 0; i < members->count; i++) { - if (!is_subtype(supermembers->nodes[i], members->nodes[i])) - return false; - } - return true; - } - case JoinPointType_TAG: { - const Nodes* superparams = &supertype->payload.join_point_type.yield_types; - const Nodes* params = &type->payload.join_point_type.yield_types; - if (params->count != superparams->count) return false; - for (size_t i = 0; i < params->count; i++) { - if (!is_subtype(params->nodes[i], superparams->nodes[i])) - return false; - } - return true; - } - case FnType_TAG: { - // check returns - if (supertype->payload.fn_type.return_types.count != type->payload.fn_type.return_types.count) - return false; - for (size_t i = 0; i < type->payload.fn_type.return_types.count; i++) - if (!is_subtype(supertype->payload.fn_type.return_types.nodes[i], type->payload.fn_type.return_types.nodes[i])) - return false; - // check params - const Nodes* superparams = &supertype->payload.fn_type.param_types; - const Nodes* params = &type->payload.fn_type.param_types; - if (params->count != superparams->count) return false; - for (size_t i = 0; i < params->count; i++) { - if (!is_subtype(params->nodes[i], superparams->nodes[i])) - return false; - } - return true; - } case BBType_TAG: { - // check params - const Nodes* superparams = &supertype->payload.bb_type.param_types; - const Nodes* params = &type->payload.bb_type.param_types; - if (params->count != superparams->count) return false; - for (size_t i = 0; i < params->count; i++) { - if (!is_subtype(params->nodes[i], superparams->nodes[i])) - return false; - } - return true; - } case LamType_TAG: { - // check params - const Nodes* superparams = &supertype->payload.lam_type.param_types; - const Nodes* params = 
&type->payload.lam_type.param_types; - if (params->count != superparams->count) return false; - for (size_t i = 0; i < params->count; i++) { - if (!is_subtype(params->nodes[i], superparams->nodes[i])) - return false; - } - return true; - } case PtrType_TAG: { - if (supertype->payload.ptr_type.address_space != type->payload.ptr_type.address_space) - return false; - // if either pointer type is untyped, both need to be - if (supertype->arena->config.untyped_ptrs && (!supertype->payload.ptr_type.pointed_type || !type->payload.ptr_type.pointed_type)) - return !supertype->payload.ptr_type.pointed_type && !type->payload.ptr_type.pointed_type; - return is_subtype(supertype->payload.ptr_type.pointed_type, type->payload.ptr_type.pointed_type); - } - case Int_TAG: return supertype->payload.int_type.width == type->payload.int_type.width && supertype->payload.int_type.is_signed == type->payload.int_type.is_signed; - case ArrType_TAG: { - if (!is_subtype(supertype->payload.arr_type.element_type, type->payload.arr_type.element_type)) - return false; - // unsized arrays are supertypes of sized arrays (even though they're not datatypes...) 
- // TODO: maybe change this so it's only valid when talking about to pointer-to-arrays - const IntLiteral* size_literal = resolve_to_int_literal(supertype->payload.arr_type.size); - if (size_literal && size_literal->value == 0) - return true; - return supertype->payload.arr_type.size == type->payload.arr_type.size || !supertype->payload.arr_type.size; - } - case PackType_TAG: { - if (!is_subtype(supertype->payload.pack_type.element_type, type->payload.pack_type.element_type)) - return false; - return supertype->payload.pack_type.width == type->payload.pack_type.width; - } - case Type_TypeDeclRef_TAG: { - return supertype->payload.type_decl_ref.decl == type->payload.type_decl_ref.decl; - } - case Type_ImageType_TAG: { - if (!is_subtype(supertype->payload.image_type.sampled_type, type->payload.image_type.sampled_type)) - return false; - if (supertype->payload.image_type.depth != type->payload.image_type.depth) - return false; - if (supertype->payload.image_type.dim != type->payload.image_type.dim) - return false; - if (supertype->payload.image_type.onion != type->payload.image_type.onion) - return false; - if (supertype->payload.image_type.multisample != type->payload.image_type.multisample) - return false; - if (supertype->payload.image_type.sampled != type->payload.image_type.sampled) - return false; - return true; - } - case Type_CombinedImageSamplerType_TAG: - return is_subtype(supertype->payload.combined_image_sampler_type.image_type, type->payload.combined_image_sampler_type.image_type); - case SamplerType_TAG: - case NoRet_TAG: - case Bool_TAG: - case MaskType_TAG: - return true; - case Float_TAG: - return supertype->payload.float_type.width == type->payload.float_type.width; - } - SHADY_UNREACHABLE; -} - -void check_subtype(const Type* supertype, const Type* type) { - if (!is_subtype(supertype, type)) { - log_node(ERROR, type); - error_print(" isn't a subtype of "); - log_node(ERROR, supertype); - error_print("\n"); - error("failed check_subtype") - } -} - 
-size_t get_type_bitwidth(const Type* t) { - switch (t->tag) { - case Int_TAG: return int_size_in_bytes(t->payload.int_type.width) * 8; - case Float_TAG: return float_size_in_bytes(t->payload.float_type.width) * 8; - case PtrType_TAG: { - if (is_physical_as(t->payload.ptr_type.address_space)) - return int_size_in_bytes(t->arena->config.memory.ptr_size) * 8; - break; - } - default: break; - } - return SIZE_MAX; -} - -bool is_addr_space_uniform(IrArena* arena, AddressSpace as) { - switch (as) { - case AsFunctionLogical: - case AsPrivateLogical: - case AsPrivatePhysical: - case AsInput: - return !arena->config.is_simt; - default: - return true; - } -} - -const Type* get_actual_mask_type(IrArena* arena) { - switch (arena->config.specializations.subgroup_mask_representation) { - case SubgroupMaskAbstract: return mask_type(arena); - case SubgroupMaskInt64: return uint64_type(arena); - default: assert(false); - } -} - -String name_type_safe(IrArena* arena, const Type* t) { - switch (is_type(t)) { - case NotAType: assert(false); - case Type_MaskType_TAG: return "mask_t"; - case Type_JoinPointType_TAG: return "join_type_t"; - case Type_NoRet_TAG: return "no_ret"; - case Type_Int_TAG: - if (t->payload.int_type.is_signed) - return format_string_arena(arena->arena, "i%s", ((String[]) {"8", "16", "32", "64" })[t->payload.int_type.width]); - else - return format_string_arena(arena->arena, "u%s", ((String[]) {"8", "16", "32", "64" })[t->payload.int_type.width]); - case Type_Float_TAG: - return format_string_arena(arena->arena, "f%s", ((String[]) {"16", "32", "64" })[t->payload.float_type.width]); - case Type_Bool_TAG: return "bool"; - case Type_RecordType_TAG: - case Type_FnType_TAG: - case Type_BBType_TAG: - case Type_LamType_TAG: - case Type_PtrType_TAG: - case Type_QualifiedType_TAG: - case Type_ArrType_TAG: - case Type_PackType_TAG: - case Type_ImageType_TAG: - case Type_SamplerType_TAG: - case Type_CombinedImageSamplerType_TAG: - break; - case Type_TypeDeclRef_TAG: return 
t->payload.type_decl_ref.decl->payload.nom_type.name; - } - return unique_name(arena, node_tags[t->tag]); -} - -/// Is this a type that a value in the language can have ? -bool is_value_type(const Type* type) { - //if (type->tag == RecordType_TAG && type->payload.record_type.special == MultipleReturn) - // return true; - if (type->tag != QualifiedType_TAG) - return false; - return is_data_type(get_unqualified_type(type)); -} - -/// Is this a valid data type (for usage in other types and as type arguments) ? -bool is_data_type(const Type* type) { - switch (is_type(type)) { - case Type_MaskType_TAG: - case Type_JoinPointType_TAG: - case Type_Int_TAG: - case Type_Float_TAG: - case Type_Bool_TAG: - return true; - case Type_PtrType_TAG: - return true; - case Type_ArrType_TAG: - // array types _must_ be sized to be real data types - return type->payload.arr_type.size != NULL; - case Type_PackType_TAG: - return is_data_type(type->payload.pack_type.element_type); - case Type_RecordType_TAG: { - if (type->payload.record_type.members.count == 0) - return false; - for (size_t i = 0; i < type->payload.record_type.members.count; i++) - if (!is_data_type(type->payload.record_type.members.nodes[i])) - return false; - // multi-return record types are the results of instructions, but are not values themselves - return type->payload.record_type.special == NotSpecial; - } - case Type_TypeDeclRef_TAG: - return !get_nominal_type_body(type) || is_data_type(get_nominal_type_body(type)); - // qualified types are not data types because that information is only meant for values - case Type_QualifiedType_TAG: return false; - // values cannot contain abstractions - case Type_FnType_TAG: - case Type_BBType_TAG: - case Type_LamType_TAG: - return false; - // this type has no values to begin with - case Type_NoRet_TAG: - return false; - case NotAType: - return false; - // Image stuff is data (albeit opaque) - case Type_CombinedImageSamplerType_TAG: - case Type_SamplerType_TAG: - case 
Type_ImageType_TAG: - return true; - } -} - -bool is_arithm_type(const Type* t) { - return t->tag == Int_TAG || t->tag == Float_TAG; -} - -bool is_shiftable_type(const Type* t) { - return t->tag == Int_TAG || t->tag == MaskType_TAG; -} - -bool has_boolean_ops(const Type* t) { - return t->tag == Int_TAG || t->tag == Bool_TAG || t->tag == MaskType_TAG; -} - -bool is_comparable_type(const Type* t) { - return true; // TODO this is fine to allow, but we'll need to lower it for composite and native ptr types ! -} - -bool is_ordered_type(const Type* t) { - return is_arithm_type(t); -} - -bool is_physical_ptr_type(const Type* t) { - if (t->tag != PtrType_TAG) - return false; - AddressSpace as = t->payload.ptr_type.address_space; - return is_physical_as(as); -} - -bool is_generic_ptr_type(const Type* t) { - if (t->tag != PtrType_TAG) - return false; - AddressSpace as = t->payload.ptr_type.address_space; - return as == AsGeneric; -} - -bool is_reinterpret_cast_legal(const Type* src_type, const Type* dst_type) { - assert(is_data_type(src_type) && is_data_type(dst_type)); - if (src_type == dst_type) - return true; // folding will eliminate those, but we need to pass type-checking first :) - if (!(is_arithm_type(src_type) || src_type->tag == MaskType_TAG || is_physical_ptr_type(src_type))) - return false; - if (!(is_arithm_type(dst_type) || dst_type->tag == MaskType_TAG || is_physical_ptr_type(dst_type))) - return false; - assert(get_type_bitwidth(src_type) == get_type_bitwidth(dst_type)); - // either both pointers need to be in the generic address space, and we're only casting the element type, OR neither can be - if ((is_physical_ptr_type(src_type) && is_physical_ptr_type(dst_type)) && (is_generic_ptr_type(src_type) != is_generic_ptr_type(dst_type))) - return false; - return true; -} - -bool is_conversion_legal(const Type* src_type, const Type* dst_type) { - assert(is_data_type(src_type) && is_data_type(dst_type)); - if (!(is_arithm_type(src_type) || 
(is_physical_ptr_type(src_type) && get_type_bitwidth(src_type) == get_type_bitwidth(dst_type)))) - return false; - if (!(is_arithm_type(dst_type) || (is_physical_ptr_type(dst_type) && get_type_bitwidth(src_type) == get_type_bitwidth(dst_type)))) - return false; - // we only allow ptr-ptr conversions, use reinterpret otherwise - if (is_physical_ptr_type(src_type) != is_physical_ptr_type(dst_type)) - return false; - // exactly one of the pointers needs to be in the generic address space - if (is_generic_ptr_type(src_type) && is_generic_ptr_type(dst_type)) - return false; - if (src_type->tag == Int_TAG && dst_type->tag == Int_TAG) { - bool changes_sign = src_type->payload.int_type.is_signed != dst_type->payload.int_type.is_signed; - bool changes_width = src_type->payload.int_type.width != dst_type->payload.int_type.width; - if (changes_sign && changes_width) - return false; - } - // element types have to match (use reinterpret_cast for changing it) - if (is_physical_ptr_type(src_type) && is_physical_ptr_type(dst_type)) { - AddressSpace src_as = src_type->payload.ptr_type.address_space; - AddressSpace dst_as = dst_type->payload.ptr_type.address_space; - if (src_type->payload.ptr_type.pointed_type != dst_type->payload.ptr_type.pointed_type) - return false; - } - return true; -} - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" - -const Type* check_type_join_point_type(IrArena* arena, JoinPointType type) { - for (size_t i = 0; i < type.yield_types.count; i++) { - assert(is_data_type(type.yield_types.nodes[i])); - } - return NULL; -} - -const Type* check_type_record_type(IrArena* arena, RecordType type) { - assert(type.names.count == 0 || type.names.count == type.members.count); - for (size_t i = 0; i < type.members.count; i++) { - // member types are value types iff this is a return tuple - if (type.special == MultipleReturn) - assert(is_value_type(type.members.nodes[i])); - else - assert(is_data_type(type.members.nodes[i])); - } - 
return NULL; -} - -const Type* check_type_qualified_type(IrArena* arena, QualifiedType qualified_type) { - assert(is_data_type(qualified_type.type)); - assert(arena->config.is_simt || qualified_type.is_uniform); - return NULL; -} - -const Type* check_type_arr_type(IrArena* arena, ArrType type) { - assert(is_data_type(type.element_type)); - return NULL; -} - -const Type* check_type_pack_type(IrArena* arena, PackType pack_type) { - assert(is_data_type(pack_type.element_type)); - return NULL; -} - -const Type* check_type_ptr_type(IrArena* arena, PtrType ptr_type) { - assert((arena->config.untyped_ptrs || ptr_type.pointed_type) && "Shady does not support untyped pointers, but can infer them, see infer.c"); - if (!arena->config.allow_subgroup_memory) { - assert(ptr_type.address_space != AsSubgroupPhysical); - assert(ptr_type.address_space != AsSubgroupLogical); - } - if (!arena->config.allow_shared_memory) { - assert(ptr_type.address_space != AsSharedPhysical); - assert(ptr_type.address_space != AsSharedLogical); - } - if (ptr_type.pointed_type) { - if (ptr_type.pointed_type->tag == ArrType_TAG) { - assert(is_data_type(ptr_type.pointed_type->payload.arr_type.element_type)); - return NULL; - } - if (ptr_type.pointed_type->tag == FnType_TAG || ptr_type.pointed_type == unit_type(arena)) { - // no diagnostic required, we just allow these - return NULL; - } - const Node* maybe_record_type = ptr_type.pointed_type; - if (maybe_record_type->tag == TypeDeclRef_TAG) - maybe_record_type = get_nominal_type_body(maybe_record_type); - if (maybe_record_type->tag == RecordType_TAG && maybe_record_type->payload.record_type.special == DecorateBlock) { - return NULL; - } - assert(is_data_type(ptr_type.pointed_type)); - } - return NULL; -} - -const Type* check_type_var(IrArena* arena, Variable variable) { - assert(is_value_type(variable.type)); - return variable.type; -} - -const Type* check_type_untyped_number(IrArena* arena, UntypedNumber untyped) { - error("should never happen"); -} - 
-const Type* check_type_int_literal(IrArena* arena, IntLiteral lit) { - return qualified_type(arena, (QualifiedType) { - .is_uniform = true, - .type = int_type(arena, (Int) { .width = lit.width, .is_signed = lit.is_signed }) - }); -} - -const Type* check_type_float_literal(IrArena* arena, FloatLiteral lit) { - return qualified_type(arena, (QualifiedType) { - .is_uniform = true, - .type = float_type(arena, (Float) { .width = lit.width }) - }); -} - -const Type* check_type_true_lit(IrArena* arena) { return qualified_type(arena, (QualifiedType) { .type = bool_type(arena), .is_uniform = true }); } -const Type* check_type_false_lit(IrArena* arena) { return qualified_type(arena, (QualifiedType) { .type = bool_type(arena), .is_uniform = true }); } - -const Type* check_type_string_lit(IrArena* arena, StringLiteral str_lit) { - const Type* t = arr_type(arena, (ArrType) { - .element_type = int8_type(arena), - .size = int32_literal(arena, strlen(str_lit.string)) - }); - return qualified_type(arena, (QualifiedType) { - .type = t, - .is_uniform = true, - }); -} - -const Type* check_type_null_ptr(IrArena* a, NullPtr payload) { - assert(is_data_type(payload.ptr_type) && payload.ptr_type->tag == PtrType_TAG); - return qualified_type_helper(payload.ptr_type, true); -} - -const Type* check_type_composite(IrArena* arena, Composite composite) { - assert(is_data_type(composite.type)); - Nodes expected_member_types = get_composite_type_element_types(composite.type); - bool is_uniform = true; - assert(composite.contents.count == expected_member_types.count); - for (size_t i = 0; i < composite.contents.count; i++) { - const Type* element_type = composite.contents.nodes[i]->type; - is_uniform &= deconstruct_qualified_type(&element_type); - assert(is_subtype(expected_member_types.nodes[i], element_type)); - } - return qualified_type(arena, (QualifiedType) { - .is_uniform = is_uniform, - .type = composite.type - }); -} - -const Type* check_type_fill(IrArena* arena, Fill payload) { - 
assert(is_data_type(payload.type)); - const Node* element_t = get_fill_type_element_type(payload.type); - const Node* value_t = payload.value->type; - bool u = deconstruct_qualified_type(&value_t); - assert(is_subtype(element_t, value_t)); - return qualified_type(arena, (QualifiedType) { - .is_uniform = u, - .type = payload.type - }); -} - -const Type* check_type_undef(IrArena* arena, Undef payload) { - assert(is_data_type(payload.type)); - return qualified_type(arena, (QualifiedType) { - .is_uniform = true, - .type = payload.type - }); -} - -const Type* check_type_fn_addr(IrArena* arena, FnAddr fn_addr) { - assert(fn_addr.fn->type->tag == FnType_TAG); - assert(fn_addr.fn->tag == Function_TAG); - return qualified_type(arena, (QualifiedType) { - .is_uniform = true, - .type = ptr_type(arena, (PtrType) { - .pointed_type = fn_addr.fn->type, - .address_space = AsGeneric /* the actual AS does not matter because these are opaque anyways */, - }) - }); -} - -const Type* check_type_ref_decl(IrArena* arena, RefDecl ref_decl) { - const Type* t = ref_decl.decl->type; - assert(t && "RefDecl needs to be applied on a decl with a non-null type. Did you forget to set 'type' on a constant ?"); - switch (ref_decl.decl->tag) { - case GlobalVariable_TAG: - case Constant_TAG: break; - default: error("You can only use RefDecl on a global or a constant. 
See FnAddr for taking addresses of functions.") - } - assert(t->tag != QualifiedType_TAG && "decl types may not be qualified"); - return qualified_type(arena, (QualifiedType) { - .type = t, - .is_uniform = true, - }); -} - -const Type* check_type_prim_op(IrArena* arena, PrimOp prim_op) { - for (size_t i = 0; i < prim_op.type_arguments.count; i++) { - const Node* ta = prim_op.type_arguments.nodes[i]; - assert(ta && is_type(ta)); - } - for (size_t i = 0; i < prim_op.operands.count; i++) { - const Node* operand = prim_op.operands.nodes[i]; - assert(operand && is_value(operand)); - } - - bool extended = false; - bool ordered = false; - AddressSpace as; - switch (prim_op.op) { - case deref_op: - case assign_op: - case addrof_op: - case subscript_op: error("These ops are only allowed in untyped IR before desugaring. They don't type to anything."); - case quote_op: { - assert(prim_op.type_arguments.count == 0); - return wrap_multiple_yield_types(arena, get_values_types(arena, prim_op.operands)); - } - case neg_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 1); - - const Type* type = first(prim_op.operands)->type; - assert(is_arithm_type(get_maybe_packed_type_element(get_unqualified_type(type)))); - return type; - } - case rshift_arithm_op: - case rshift_logical_op: - case lshift_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 2); - const Type* first_operand_type = first(prim_op.operands)->type; - const Type* second_operand_type = prim_op.operands.nodes[1]->type; - - bool uniform_result = deconstruct_qualified_type(&first_operand_type); - uniform_result &= deconstruct_qualified_type(&second_operand_type); - - size_t value_simd_width = deconstruct_maybe_packed_type(&first_operand_type); - size_t shift_simd_width = deconstruct_maybe_packed_type(&second_operand_type); - assert(value_simd_width == shift_simd_width); - - assert(first_operand_type->tag == Int_TAG); - assert(second_operand_type->tag == 
Int_TAG); - - return qualified_type_helper(maybe_packed_type_helper(first_operand_type, value_simd_width), uniform_result); - } - case add_carry_op: - case sub_borrow_op: - case mul_extended_op: extended = true; SHADY_FALLTHROUGH; - case min_op: - case max_op: - case add_op: - case sub_op: - case mul_op: - case div_op: - case mod_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 2); - const Type* first_operand_type = get_unqualified_type(first(prim_op.operands)->type); - - bool result_uniform = true; - for (size_t i = 0; i < prim_op.operands.count; i++) { - const Node* arg = prim_op.operands.nodes[i]; - const Type* operand_type = arg->type; - bool operand_uniform = deconstruct_qualified_type(&operand_type); - - assert(is_arithm_type(get_maybe_packed_type_element(operand_type))); - assert(first_operand_type == operand_type && "operand type mismatch"); - - result_uniform &= operand_uniform; - } - - const Type* result_t = first_operand_type; - if (extended) { - // TODO: assert unsigned - result_t = record_type(arena, (RecordType) {.members = mk_nodes(arena, result_t, result_t)}); - } - return qualified_type_helper(result_t, result_uniform); - } - - case not_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 1); - - const Type* type = first(prim_op.operands)->type; - assert(has_boolean_ops(get_maybe_packed_type_element(get_unqualified_type(type)))); - return type; - } - case or_op: - case xor_op: - case and_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 2); - const Type* first_operand_type = get_unqualified_type(first(prim_op.operands)->type); - - bool result_uniform = true; - for (size_t i = 0; i < prim_op.operands.count; i++) { - const Node* arg = prim_op.operands.nodes[i]; - const Type* operand_type = arg->type; - bool operand_uniform = deconstruct_qualified_type(&operand_type); - - assert(has_boolean_ops(get_maybe_packed_type_element(operand_type))); - 
assert(first_operand_type == operand_type && "operand type mismatch"); - - result_uniform &= operand_uniform; - } - - return qualified_type_helper(first_operand_type, result_uniform); - } - case lt_op: - case lte_op: - case gt_op: - case gte_op: ordered = true; SHADY_FALLTHROUGH - case eq_op: - case neq_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 2); - const Type* first_operand_type = get_unqualified_type(first(prim_op.operands)->type); - size_t first_operand_width = get_maybe_packed_type_width(first_operand_type); - - bool result_uniform = true; - for (size_t i = 0; i < prim_op.operands.count; i++) { - const Node* arg = prim_op.operands.nodes[i]; - const Type* operand_type = arg->type; - bool operand_uniform = deconstruct_qualified_type(&operand_type); - - assert((ordered ? is_ordered_type : is_comparable_type)(get_maybe_packed_type_element(operand_type))); - assert(first_operand_type == operand_type && "operand type mismatch"); - - result_uniform &= operand_uniform; - } - - return qualified_type_helper(maybe_packed_type_helper(bool_type(arena), first_operand_width), result_uniform); - } - case sqrt_op: - case inv_sqrt_op: - case floor_op: - case ceil_op: - case round_op: - case fract_op: - case sin_op: - case cos_op: - case exp_op: - { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 1); - const Node* src_type = first(prim_op.operands)->type; - bool uniform = deconstruct_qualified_type(&src_type); - size_t width = deconstruct_maybe_packed_type(&src_type); - assert(src_type->tag == Float_TAG); - return qualified_type_helper(maybe_packed_type_helper(src_type, width), uniform); - } - case pow_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 2); - const Type* first_operand_type = get_unqualified_type(first(prim_op.operands)->type); - - bool result_uniform = true; - for (size_t i = 0; i < prim_op.operands.count; i++) { - const Node* arg = 
prim_op.operands.nodes[i]; - const Type* operand_type = arg->type; - bool operand_uniform = deconstruct_qualified_type(&operand_type); - - assert(get_maybe_packed_type_element(operand_type)->tag == Float_TAG); - assert(first_operand_type == operand_type && "operand type mismatch"); - - result_uniform &= operand_uniform; - } - - return qualified_type_helper(first_operand_type, result_uniform); - } - case abs_op: - case sign_op: - { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 1); - const Node* src_type = first(prim_op.operands)->type; - bool uniform = deconstruct_qualified_type(&src_type); - size_t width = deconstruct_maybe_packed_type(&src_type); - assert(src_type->tag == Float_TAG || src_type->tag == Int_TAG && src_type->payload.int_type.is_signed); - return qualified_type_helper(maybe_packed_type_helper(src_type, width), uniform); - } - case load_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 1); - - const Node* ptr = first(prim_op.operands); - const Node* ptr_type = ptr->type; - bool ptr_uniform = deconstruct_qualified_type(&ptr_type); - size_t width = deconstruct_maybe_packed_type(&ptr_type); - - assert(ptr_type->tag == PtrType_TAG); - const PtrType* node_ptr_type_ = &ptr_type->payload.ptr_type; - const Type* elem_type = node_ptr_type_->pointed_type; - elem_type = maybe_packed_type_helper(elem_type, width); - return qualified_type_helper(elem_type, ptr_uniform && is_addr_space_uniform(arena, ptr_type->payload.ptr_type.address_space)); - } - case store_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 2); - - const Node* ptr = first(prim_op.operands); - const Node* ptr_type = ptr->type; - bool ptr_uniform = deconstruct_qualified_type(&ptr_type); - size_t width = deconstruct_maybe_packed_type(&ptr_type); - assert(ptr_type->tag == PtrType_TAG); - const PtrType* ptr_type_payload = &ptr_type->payload.ptr_type; - const Type* elem_type = 
ptr_type_payload->pointed_type; - assert(elem_type); - elem_type = maybe_packed_type_helper(elem_type, width); - // we don't enforce uniform stores - but we care about storing the right thing :) - const Type* val_expected_type = qualified_type(arena, (QualifiedType) { - .is_uniform = !arena->config.is_simt, - .type = elem_type - }); - - const Node* val = prim_op.operands.nodes[1]; - assert(is_subtype(val_expected_type, val->type)); - return empty_multiple_return_type(arena); - } - case alloca_logical_op: as = AsFunctionLogical; goto alloca_case; - case alloca_subgroup_op: as = AsSubgroupPhysical; goto alloca_case; - case alloca_op: as = AsPrivatePhysical; goto alloca_case; - alloca_case: { - assert(prim_op.type_arguments.count == 1); - assert(prim_op.operands.count == 0); - const Type* elem_type = prim_op.type_arguments.nodes[0]; - assert(is_type(elem_type)); - return qualified_type(arena, (QualifiedType) { - .is_uniform = is_addr_space_uniform(arena, as), - .type = ptr_type(arena, (PtrType) { - .pointed_type = elem_type, - .address_space = as, - }) - }); - } - case lea_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count >= 2); - - const Node* base = prim_op.operands.nodes[0]; - bool uniform = is_qualified_type_uniform(base->type); - - const Type* base_ptr_type = get_unqualified_type(base->type); - assert(base_ptr_type->tag == PtrType_TAG && "lea expects a pointer as a base"); - - const Node* offset = prim_op.operands.nodes[1]; - assert(offset); - const Type* offset_type = offset->type; - bool offset_uniform = deconstruct_qualified_type(&offset_type); - assert(offset_type->tag == Int_TAG && "lea expects an integer offset"); - const Type* pointee_type = base_ptr_type->payload.ptr_type.pointed_type; - - const IntLiteral* lit = resolve_to_int_literal(offset); - bool offset_is_zero = lit && lit->value == 0; - assert(offset_is_zero || pointee_type->tag == ArrType_TAG && "if an offset is used, the base pointer must point to an array"); - 
uniform &= offset_uniform; - - Nodes indices = nodes(arena, prim_op.operands.count - 2, &prim_op.operands.nodes[2]); - enter_composite(&pointee_type, &uniform, indices, true); - - return qualified_type(arena, (QualifiedType) { - .is_uniform = uniform, - .type = ptr_type(arena, (PtrType) { .pointed_type = pointee_type, .address_space = base_ptr_type->payload.ptr_type.address_space }) - }); - } - case memcpy_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 3); - const Type* dst_t = prim_op.operands.nodes[0]->type; - deconstruct_qualified_type(&dst_t); - assert(dst_t->tag == PtrType_TAG); - const Type* src_t = prim_op.operands.nodes[1]->type; - deconstruct_qualified_type(&src_t); - assert(src_t->tag == PtrType_TAG); - const Type* cnt_t = prim_op.operands.nodes[2]->type; - deconstruct_qualified_type(&cnt_t); - assert(cnt_t->tag == Int_TAG); - return empty_multiple_return_type(arena); - } - case memset_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 3); - const Type* dst_t = prim_op.operands.nodes[0]->type; - deconstruct_qualified_type(&dst_t); - assert(dst_t->tag == PtrType_TAG); - const Type* src_t = prim_op.operands.nodes[1]->type; - deconstruct_qualified_type(&src_t); - assert(src_t); - const Type* cnt_t = prim_op.operands.nodes[2]->type; - deconstruct_qualified_type(&cnt_t); - assert(cnt_t->tag == Int_TAG); - return empty_multiple_return_type(arena); - } - case align_of_op: - case size_of_op: { - assert(prim_op.type_arguments.count == 1); - assert(prim_op.operands.count == 0); - return qualified_type(arena, (QualifiedType) { - .is_uniform = true, - .type = int_type(arena, (Int) { .width = arena->config.memory.ptr_size, .is_signed = false }) - }); - } - case offset_of_op: { - assert(prim_op.type_arguments.count == 1); - assert(prim_op.operands.count == 1); - const Type* optype = first(prim_op.operands)->type; - bool uniform = deconstruct_qualified_type(&optype); - assert(uniform && 
optype->tag == Int_TAG); - return qualified_type(arena, (QualifiedType) { - .is_uniform = true, - .type = int_type(arena, (Int) { .width = arena->config.memory.ptr_size, .is_signed = false }) - }); - } - case select_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 3); - const Type* condition_type = prim_op.operands.nodes[0]->type; - bool condition_uniform = deconstruct_qualified_type(&condition_type); - size_t width = deconstruct_maybe_packed_type(&condition_type); - - const Type* alternatives_types[2]; - bool alternatives_all_uniform = true; - for (size_t i = 0; i < 2; i++) { - alternatives_types[i] = prim_op.operands.nodes[1 + i]->type; - alternatives_all_uniform &= deconstruct_qualified_type(&alternatives_types[i]); - size_t alternative_width = deconstruct_maybe_packed_type(&alternatives_types[i]); - assert(alternative_width == width); - } - - assert(is_subtype(bool_type(arena), condition_type)); - // todo find true supertype - assert(are_types_identical(2, alternatives_types)); - - return qualified_type_helper(maybe_packed_type_helper(alternatives_types[0], width), alternatives_all_uniform && condition_uniform); - } - case insert_op: - case extract_dynamic_op: - case extract_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count >= 2); - const Node* source = first(prim_op.operands); - - size_t indices_start = prim_op.op == insert_op ? 
2 : 1; - Nodes indices = nodes(arena, prim_op.operands.count - indices_start, &prim_op.operands.nodes[indices_start]); - - const Type* t = source->type; - bool uniform = deconstruct_qualified_type(&t); - enter_composite(&t, &uniform, indices, true); - - if (prim_op.op == insert_op) { - const Node* inserted_data = prim_op.operands.nodes[1]; - const Type* inserted_data_type = inserted_data->type; - bool is_uniform = uniform & deconstruct_qualified_type(&inserted_data_type); - assert(is_subtype(t, inserted_data_type) && "inserting data into a composite, but it doesn't match the target and indices"); - return qualified_type(arena, (QualifiedType) { - .is_uniform = is_uniform, - .type = get_unqualified_type(source->type), - }); - } - - return qualified_type_helper(t, uniform); - } - case shuffle_op: { - assert(prim_op.operands.count >= 2); - assert(prim_op.type_arguments.count == 0); - const Node* lhs = prim_op.operands.nodes[0]; - const Node* rhs = prim_op.operands.nodes[1]; - const Type* lhs_t = lhs->type; - const Type* rhs_t = rhs->type; - bool lhs_u = deconstruct_qualified_type(&lhs_t); - bool rhs_u = deconstruct_qualified_type(&rhs_t); - assert(lhs_t->tag == PackType_TAG && rhs_t->tag == PackType_TAG); - size_t total_size = lhs_t->payload.pack_type.width + rhs_t->payload.pack_type.width; - const Type* element_t = lhs_t->payload.pack_type.element_type; - assert(element_t == rhs_t->payload.pack_type.element_type); - - size_t indices_count = prim_op.operands.count - 2; - const Node** indices = &prim_op.operands.nodes[2]; - bool u = lhs_u & rhs_u; - for (size_t i = 0; i < indices_count; i++) { - u &= is_qualified_type_uniform(indices[i]->type); - int64_t index = get_int_literal_value(*resolve_to_int_literal(indices[i]), true); - assert(index < 0 /* poison */ || (index >= 0 && index < total_size && "shuffle element out of range")); - } - return qualified_type_helper(pack_type(arena, (PackType) { .element_type = element_t, .width = indices_count }), u); - } - case 
reinterpret_op: { - assert(prim_op.type_arguments.count == 1); - assert(prim_op.operands.count == 1); - const Node* source = first(prim_op.operands); - const Type* src_type = source->type; - bool src_uniform = deconstruct_qualified_type(&src_type); - - const Type* dst_type = first(prim_op.type_arguments); - assert(is_data_type(dst_type)); - assert(is_reinterpret_cast_legal(src_type, dst_type)); - - return qualified_type(arena, (QualifiedType) { - .is_uniform = src_uniform, - .type = dst_type - }); - } - case convert_op: { - assert(prim_op.type_arguments.count == 1); - assert(prim_op.operands.count == 1); - const Node* source = first(prim_op.operands); - const Type* src_type = source->type; - bool src_uniform = deconstruct_qualified_type(&src_type); - - const Type* dst_type = first(prim_op.type_arguments); - assert(is_data_type(dst_type)); - assert(is_conversion_legal(src_type, dst_type)); - - // TODO check the conversion is legal - return qualified_type(arena, (QualifiedType) { - .is_uniform = src_uniform, - .type = dst_type - }); - } - // Mask management - case empty_mask_op: { - assert(prim_op.type_arguments.count == 0 && prim_op.operands.count == 0); - return qualified_type_helper(get_actual_mask_type(arena), true); - } - case mask_is_thread_active_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 2); - return qualified_type(arena, (QualifiedType) { - .is_uniform = is_qualified_type_uniform(prim_op.operands.nodes[0]->type) && is_qualified_type_uniform(prim_op.operands.nodes[1]->type), - .type = bool_type(arena) - }); - } - // Subgroup ops - case subgroup_active_mask_op: { - assert(prim_op.type_arguments.count == 0 && prim_op.operands.count == 0); - return qualified_type_helper(get_actual_mask_type(arena), true); - } - case subgroup_ballot_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 1); - return qualified_type(arena, (QualifiedType) { - .is_uniform = true, - .type = 
get_actual_mask_type(arena) - }); - } - case subgroup_elect_first_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 0); - return qualified_type(arena, (QualifiedType) { - .is_uniform = false, - .type = bool_type(arena) - }); - } - case subgroup_assume_uniform_op: - case subgroup_broadcast_first_op: - case subgroup_reduce_sum_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 1); - const Type* operand_type = get_unqualified_type(prim_op.operands.nodes[0]->type); - return qualified_type(arena, (QualifiedType) { - .is_uniform = true, - .type = operand_type - }); - } - // Intermediary ops - case create_joint_point_op: { - assert(prim_op.operands.count == 2); - const Node* join_point = first(prim_op.operands); - assert(is_qualified_type_uniform(join_point->type)); - return qualified_type(arena, (QualifiedType) { .type = join_point_type(arena, (JoinPointType) { .yield_types = prim_op.type_arguments }), .is_uniform = false }); - } - case default_join_point_op: { - assert(prim_op.operands.count == 0); - assert(prim_op.type_arguments.count == 0); - return qualified_type(arena, (QualifiedType) { .type = join_point_type(arena, (JoinPointType) { .yield_types = empty(arena) }), .is_uniform = true }); - } - // Stack stuff - case get_stack_pointer_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 0); - return qualified_type(arena, (QualifiedType) { .is_uniform = false, .type = uint32_type(arena) }); - } - case get_stack_base_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 0); - const Node* ptr = ptr_type(arena, (PtrType) { .pointed_type = arr_type(arena, (ArrType) { .element_type = uint8_type(arena), .size = NULL }), .address_space = prim_op.op == get_stack_base_op ? 
AsPrivatePhysical : AsSubgroupPhysical}); - return qualified_type(arena, (QualifiedType) { .is_uniform = false, .type = ptr }); - } - case set_stack_pointer_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 1); - assert(get_unqualified_type(prim_op.operands.nodes[0]->type) == uint32_type(arena)); - return empty_multiple_return_type(arena); - } - case push_stack_op: { - assert(prim_op.type_arguments.count == 1); - assert(prim_op.operands.count == 1); - const Type* element_type = first(prim_op.type_arguments); - assert(is_data_type(element_type)); - const Type* qual_element_type = qualified_type(arena, (QualifiedType) { - .is_uniform = false, - .type = element_type - }); - // the operand has to be a subtype of the annotated type - assert(is_subtype(qual_element_type, first(prim_op.operands)->type)); - return empty_multiple_return_type(arena); - } - case pop_stack_op:{ - assert(prim_op.operands.count == 0); - assert(prim_op.type_arguments.count == 1); - const Type* element_type = prim_op.type_arguments.nodes[0]; - assert(is_data_type(element_type)); - return qualified_type(arena, (QualifiedType) { .is_uniform = false, .type = element_type}); - } - // Debugging ops - case debug_printf_op: { - assert(prim_op.type_arguments.count == 0); - // TODO ? 
- return empty_multiple_return_type(arena); - } - case sample_texture_op: { - assert(prim_op.type_arguments.count == 0); - assert(prim_op.operands.count == 2); - const Type* sampled_image_t = first(prim_op.operands)->type; - bool uniform_src = deconstruct_qualified_type(&sampled_image_t); - const Type* coords_t = prim_op.operands.nodes[1]->type; - deconstruct_qualified_type(&coords_t); - assert(sampled_image_t->tag == CombinedImageSamplerType_TAG); - const Type* image_t = sampled_image_t->payload.combined_image_sampler_type.image_type; - assert(image_t->tag == ImageType_TAG); - size_t coords_dim = deconstruct_packed_type(&coords_t); - return qualified_type(arena, (QualifiedType) { .is_uniform = false, .type = maybe_packed_type_helper(image_t->payload.image_type.sampled_type, 4) }); - } - case PRIMOPS_COUNT: assert(false); - } -} - -static void check_arguments_types_against_parameters_helper(Nodes param_types, Nodes arg_types) { - if (param_types.count != arg_types.count) - error("Mismatched number of arguments/parameters"); - for (size_t i = 0; i < param_types.count; i++) - check_subtype(param_types.nodes[i], arg_types.nodes[i]); -} - -/// Shared logic between indirect calls and tailcalls -static Nodes check_value_call(const Node* callee, Nodes argument_types) { - assert(is_value(callee)); - - const Type* callee_type = callee->type; - SHADY_UNUSED bool callee_uniform = deconstruct_qualified_type(&callee_type); - AddressSpace as = deconstruct_pointer_type(&callee_type); - assert(as == AsGeneric); - - assert(callee_type->tag == FnType_TAG); - - const FnType* fn_type = &callee_type->payload.fn_type; - check_arguments_types_against_parameters_helper(fn_type->param_types, argument_types); - // TODO force the return types to be varying if the callee is not uniform - return fn_type->return_types; -} - -const Type* check_type_call(IrArena* arena, Call call) { - Nodes args = call.args; - for (size_t i = 0; i < args.count; i++) { - const Node* argument = args.nodes[i]; - 
assert(is_value(argument)); - } - Nodes argument_types = get_values_types(arena, args); - return wrap_multiple_yield_types(arena, check_value_call(call.callee, argument_types)); -} - -static void ensure_types_are_data_types(const Nodes* yield_types) { - for (size_t i = 0; i < yield_types->count; i++) { - assert(is_data_type(yield_types->nodes[i])); - } -} - -static void ensure_types_are_value_types(const Nodes* yield_types) { - for (size_t i = 0; i < yield_types->count; i++) { - assert(is_value_type(yield_types->nodes[i])); - } -} - -const Type* check_type_if_instr(IrArena* arena, If if_instr) { - ensure_types_are_data_types(&if_instr.yield_types); - if (get_unqualified_type(if_instr.condition->type) != bool_type(arena)) - error("condition of an if should be bool"); - // TODO check the contained Merge instrs - if (if_instr.yield_types.count > 0) - assert(if_instr.if_false); - - return wrap_multiple_yield_types(arena, add_qualifiers(arena, if_instr.yield_types, false)); -} - -const Type* check_type_loop_instr(IrArena* arena, Loop loop_instr) { - ensure_types_are_data_types(&loop_instr.yield_types); - // TODO check param against initial_args - // TODO check the contained Merge instrs - return wrap_multiple_yield_types(arena, add_qualifiers(arena, loop_instr.yield_types, false)); -} - -const Type* check_type_match_instr(IrArena* arena, Match match_instr) { - ensure_types_are_data_types(&match_instr.yield_types); - // TODO check param against initial_args - // TODO check the contained Merge instrs - return wrap_multiple_yield_types(arena, add_qualifiers(arena, match_instr.yield_types, false)); -} - -const Type* check_type_control(IrArena* arena, Control control) { - ensure_types_are_data_types(&control.yield_types); - // TODO check it then ! 
- assert(is_case(control.inside)); - const Node* join_point = first(control.inside->payload.case_.params); - - const Type* join_point_type = join_point->type; - deconstruct_qualified_type(&join_point_type); - assert(join_point_type->tag == JoinPointType_TAG); - - Nodes join_point_yield_types = join_point_type->payload.join_point_type.yield_types; - assert(join_point_yield_types.count == control.yield_types.count); - for (size_t i = 0; i < control.yield_types.count; i++) { - assert(is_subtype(control.yield_types.nodes[i], join_point_yield_types.nodes[i])); - } - - return wrap_multiple_yield_types(arena, add_qualifiers(arena, join_point_yield_types, false)); -} - -const Type* check_type_block(IrArena* arena, Block payload) { - ensure_types_are_value_types(&payload.yield_types); - assert(is_case(payload.inside)); - assert(payload.inside->payload.case_.params.count == 0); - - /*const Node* lam = payload.inside; - const Node* yield_instr = NULL; - while (true) { - assert(lam->tag == Case_TAG); - const Node* terminator = lam->payload.case_.body; - switch (terminator->tag) { - case Let_TAG: { - lam = terminator->payload.let.tail; - continue; - } - case Yield_TAG: - yield_instr = terminator; - break; - default: assert(false); - } - break; - } - - Nodes yield_values = yield_instr->payload.yield.args;*/ - return wrap_multiple_yield_types(arena, payload.yield_types); -} - -const Type* check_type_comment(IrArena* arena, SHADY_UNUSED Comment payload) { - return empty_multiple_return_type(arena); -} - -const Type* check_type_let(IrArena* arena, Let let) { - assert(is_instruction(let.instruction)); - assert(is_case(let.tail)); - Nodes produced_types = unwrap_multiple_yield_types(arena, let.instruction->type); - Nodes param_types = get_variables_types(arena, let.tail->payload.case_.params); - - check_arguments_types_against_parameters_helper(param_types, produced_types); - return noret_type(arena); -} - -const Type* check_type_tail_call(IrArena* arena, TailCall tail_call) { - 
Nodes args = tail_call.args; - for (size_t i = 0; i < args.count; i++) { - const Node* argument = args.nodes[i]; - assert(is_value(argument)); - } - assert(check_value_call(tail_call.target, get_values_types(arena, tail_call.args)).count == 0); - return noret_type(arena); -} - -static void check_basic_block_call(const Node* block, Nodes argument_types) { - assert(is_basic_block(block)); - assert(block->type->tag == BBType_TAG); - BBType bb_type = block->type->payload.bb_type; - check_arguments_types_against_parameters_helper(bb_type.param_types, argument_types); -} - -const Type* check_type_jump(IrArena* arena, Jump jump) { - for (size_t i = 0; i < jump.args.count; i++) { - const Node* argument = jump.args.nodes[i]; - assert(is_value(argument)); - } - - check_basic_block_call(jump.target, get_values_types(arena, jump.args)); - return noret_type(arena); -} - -const Type* check_type_branch(IrArena* arena, Branch payload) { - assert(payload.true_jump->tag == Jump_TAG); - assert(payload.false_jump->tag == Jump_TAG); - return noret_type(arena); -} - -const Type* check_type_br_switch(IrArena* arena, Switch payload) { - for (size_t i = 0; i < payload.case_jumps.count; i++) - assert(payload.case_jumps.nodes[i]->tag == Jump_TAG); - assert(payload.case_values.count == payload.case_jumps.count); - assert(payload.default_jump->tag == Jump_TAG); - return noret_type(arena); -} - -const Type* check_type_join(IrArena* arena, Join join) { - for (size_t i = 0; i < join.args.count; i++) { - const Node* argument = join.args.nodes[i]; - assert(is_value(argument)); - } - - const Type* join_target_type = join.join_point->type; - - deconstruct_qualified_type(&join_target_type); - assert(join_target_type->tag == JoinPointType_TAG); - - Nodes join_point_param_types = join_target_type->payload.join_point_type.yield_types; - join_point_param_types = add_qualifiers(arena, join_point_param_types, !arena->config.is_simt); - - 
check_arguments_types_against_parameters_helper(join_point_param_types, get_values_types(arena, join.args)); - - return noret_type(arena); -} - -const Type* check_type_unreachable(IrArena* arena) { - return noret_type(arena); -} - -const Type* check_type_merge_continue(IrArena* arena, MergeContinue mc) { - // TODO check it - return noret_type(arena); -} - -const Type* check_type_merge_break(IrArena* arena, MergeBreak mc) { - // TODO check it - return noret_type(arena); -} - -const Type* check_type_yield(IrArena* arena, SHADY_UNUSED Yield payload) { - // TODO check it - return noret_type(arena); -} - -const Type* check_type_fn_ret(IrArena* arena, Return ret) { - // assert(ret.fn); - // TODO check it then ! - return noret_type(arena); -} - -const Type* check_type_fun(IrArena* arena, Function fn) { - for (size_t i = 0; i < fn.return_types.count; i++) { - assert(is_value_type(fn.return_types.nodes[i])); - } - return fn_type(arena, (FnType) { .param_types = get_variables_types(arena, (&fn)->params), .return_types = (&fn)->return_types }); -} - -const Type* check_type_basic_block(IrArena* arena, BasicBlock bb) { - return bb_type(arena, (BBType) { .param_types = get_variables_types(arena, (&bb)->params) }); -} - -const Type* check_type_case_(IrArena* arena, Case lam) { - return lam_type(arena, (LamType) { .param_types = get_variables_types(arena, (&lam)->params) }); -} - -const Type* check_type_global_variable(IrArena* arena, GlobalVariable global_variable) { - assert(is_type(global_variable.type)); - - const Node* ba = lookup_annotation_list(global_variable.annotations, "Builtin"); - if (ba && arena->config.validate_builtin_types) { - Builtin b = get_builtin_by_name(get_annotation_string_payload(ba)); - assert(b != BuiltinsCount); - const Type* t = get_builtin_type(arena, b); - if (t != global_variable.type) { - error_print("Creating a @Builtin global variable '%s' with the incorrect type: ", global_variable.name); - log_node(ERROR, global_variable.type); - error_print(" 
instead of the expected "); - log_node(ERROR, t); - error_print(".\n"); - error_die(); - } - } - - assert(global_variable.address_space < NumAddressSpaces); - - return ptr_type(arena, (PtrType) { - .pointed_type = global_variable.type, - .address_space = global_variable.address_space - }); -} - -const Type* check_type_constant(IrArena* arena, Constant cnst) { - assert(is_data_type(cnst.type_hint)); - return cnst.type_hint; -} - -#pragma GCC diagnostic pop diff --git a/src/shady/type.h b/src/shady/type.h deleted file mode 100644 index 14ad8350d..000000000 --- a/src/shady/type.h +++ /dev/null @@ -1,88 +0,0 @@ -#ifndef SHADY_TYPE_H -#define SHADY_TYPE_H - -#include "shady/ir.h" - -bool is_subtype(const Type* supertype, const Type* type); -void check_subtype(const Type* supertype, const Type* type); - -/// Is this a type that a value in the language can have ? -bool is_value_type(const Type*); - -/// Is this a valid data type (for usage in other types and as type arguments) ? -bool is_data_type(const Type*); - -size_t get_type_bitwidth(const Type* t); - -bool is_arithm_type(const Type*); -bool is_shiftable_type(const Type*); -bool has_boolean_ops(const Type*); -bool is_comparable_type(const Type*); -bool is_ordered_type(const Type*); -bool is_physical_ptr_type(const Type* t); -bool is_generic_ptr_type(const Type* t); - -bool is_reinterpret_cast_legal(const Type* src_type, const Type* dst_type); -bool is_conversion_legal(const Type* src_type, const Type* dst_type); - -#include "type_generated.h" - -const Type* get_actual_mask_type(IrArena* arena); - -const Type* wrap_multiple_yield_types(IrArena* arena, Nodes types); -Nodes unwrap_multiple_yield_types(IrArena* arena, const Type* type); - -/// Returns the (possibly qualified) pointee type from a (possibly qualified) ptr type -const Type* get_pointee_type(IrArena*, const Type*); - -void enter_composite(const Type** datatype, bool* u, Nodes indices, bool allow_entering_pack); - -/// Collects the annotated types in the list 
of variables -/// NB: this is different from get_values_types, that function uses node.type, whereas this one uses node.payload.var.type -/// This means this function works in untyped modules where node.type is NULL. -Nodes get_variables_types(IrArena*, Nodes); - -Nodes get_values_types(IrArena*, Nodes); - -// Qualified type helpers -/// Ensures an operand has divergence-annotated type and extracts it -const Type* get_unqualified_type(const Type*); -bool is_qualified_type_uniform(const Type*); -bool deconstruct_qualified_type(const Type**); - -const Type* qualified_type_helper(const Type*, bool uniform); - -Nodes strip_qualifiers(IrArena*, Nodes); -Nodes add_qualifiers(IrArena*, Nodes, bool); - -// Pack (vector) type helpers -const Type* get_packed_type_element(const Type*); -size_t get_packed_type_width(const Type*); -size_t deconstruct_packed_type(const Type**); - -/// Helper for creating pack types, wraps type in a pack_type if width > 1 -const Type* maybe_packed_type_helper(const Type*, size_t width); - -/// 'Maybe' variants that work with any types, and assume width=1 for non-packed types -/// Useful for writing generic type checking code ! 
-const Type* get_maybe_packed_type_element(const Type*); -size_t get_maybe_packed_type_width(const Type*); -size_t deconstruct_maybe_packed_type(const Type**); - -// Pointer type helpers -const Type* get_pointer_type_element(const Type*); -AddressSpace get_pointer_type_address_space(const Type*); -AddressSpace deconstruct_pointer_type(const Type**); - -// Nominal type helpers -const Node* get_nominal_type_decl(const Type*); -const Type* get_nominal_type_body(const Type*); -const Node* get_maybe_nominal_type_decl(const Type*); -const Type* get_maybe_nominal_type_body(const Type*); - -// Composite type helpers -Nodes get_composite_type_element_types(const Type*); -const Node* get_fill_type_element_type(const Type*); -const Node* get_fill_type_size(const Type*); - -#endif diff --git a/src/shady/type_helpers.c b/src/shady/type_helpers.c deleted file mode 100644 index a4fb84720..000000000 --- a/src/shady/type_helpers.c +++ /dev/null @@ -1,295 +0,0 @@ -#include "ir_private.h" -#include "type.h" -#include "log.h" -#include "portability.h" - -#include - -const Type* wrap_multiple_yield_types(IrArena* arena, Nodes types) { - switch (types.count) { - case 0: return empty_multiple_return_type(arena); - case 1: return types.nodes[0]; - default: return record_type(arena, (RecordType) { - .members = types, - .names = strings(arena, 0, NULL), - .special = MultipleReturn, - }); - } - SHADY_UNREACHABLE; -} - -Nodes unwrap_multiple_yield_types(IrArena* arena, const Type* type) { - switch (type->tag) { - case RecordType_TAG: - if (type->payload.record_type.special == MultipleReturn) - return type->payload.record_type.members; - // fallthrough - default: return nodes(arena, 1, (const Node* []) { type }); - } -} - -bool is_arrow_type(const Node* node) { - NodeTag tag = node->tag; - return tag == FnType_TAG || tag == BBType_TAG || tag == LamType_TAG; -} - -const Type* get_pointee_type(IrArena* arena, const Type* type) { - bool qualified = false, uniform = false; - if 
(is_value_type(type)) { - qualified = true; - uniform = is_qualified_type_uniform(type); - type = get_unqualified_type(type); - } - assert(type->tag == PtrType_TAG); - uniform &= is_addr_space_uniform(arena, type->payload.ptr_type.address_space); - type = type->payload.ptr_type.pointed_type; - if (qualified) - type = qualified_type(arena, (QualifiedType) { - .type = type, - .is_uniform = uniform - }); - return type; -} - -void enter_composite(const Type** datatype, bool* uniform, Nodes indices, bool allow_entering_pack) { - const Type* current_type = *datatype; - - for(size_t i = 0; i < indices.count; i++) { - const Node* selector = indices.nodes[i]; - const Type* selector_type = selector->type; - bool selector_uniform = deconstruct_qualified_type(&selector_type); - - assert(selector_type->tag == Int_TAG && "selectors must be integers"); - *uniform &= selector_uniform; - - try_again: - switch (current_type->tag) { - case RecordType_TAG: { - size_t selector_value = get_int_literal_value(*resolve_to_int_literal(selector), false); - assert(selector_value < current_type->payload.record_type.members.count); - current_type = current_type->payload.record_type.members.nodes[selector_value]; - continue; - } - case ArrType_TAG: { - current_type = current_type->payload.arr_type.element_type; - continue; - } - case TypeDeclRef_TAG: { - const Node* nom_decl = current_type->payload.type_decl_ref.decl; - assert(nom_decl->tag == NominalType_TAG); - current_type = nom_decl->payload.nom_type.body; - goto try_again; - } - case PackType_TAG: { - assert(allow_entering_pack); - assert(selector->tag == IntLiteral_TAG && "selectors when indexing into a pack type need to be constant"); - size_t selector_value = get_int_literal_value(*resolve_to_int_literal(selector), false); - assert(selector_value < current_type->payload.pack_type.width); - current_type = current_type->payload.pack_type.element_type; - continue; - } - // also remember to assert literals for the selectors ! 
- default: { - log_string(ERROR, "Trying to enter non-composite type '"); - log_node(ERROR, current_type); - log_string(ERROR, "' with selector '"); - log_node(ERROR, selector); - log_string(ERROR, "'."); - error(""); - } - } - i++; - } - - *datatype = current_type; -} - -Nodes get_variables_types(IrArena* arena, Nodes variables) { - LARRAY(const Type*, arr, variables.count); - for (size_t i = 0; i < variables.count; i++) { - assert(variables.nodes[i]->tag == Variable_TAG); - arr[i] = variables.nodes[i]->payload.var.type; - } - return nodes(arena, variables.count, arr); -} - -Nodes get_values_types(IrArena* arena, Nodes values) { - assert(arena->config.check_types); - LARRAY(const Type*, arr, values.count); - for (size_t i = 0; i < values.count; i++) - arr[i] = values.nodes[i]->type; - return nodes(arena, values.count, arr); -} - -bool is_qualified_type_uniform(const Type* type) { - const Type* result_type = type; - bool is_uniform = deconstruct_qualified_type(&result_type); - return is_uniform; -} - -const Type* get_unqualified_type(const Type* type) { - assert(is_type(type)); - const Type* result_type = type; - deconstruct_qualified_type(&result_type); - return result_type; -} - -bool deconstruct_qualified_type(const Type** type_out) { - const Type* type = *type_out; - if (type->tag == QualifiedType_TAG) { - *type_out = type->payload.qualified_type.type; - return type->payload.qualified_type.is_uniform; - } else error("Expected a value type (annotated with qual_type)") -} - -const Type* qualified_type_helper(const Type* type, bool uniform) { - return qualified_type(type->arena, (QualifiedType) { .type = type, .is_uniform = uniform }); -} - -Nodes strip_qualifiers(IrArena* arena, Nodes tys) { - LARRAY(const Type*, arr, tys.count); - for (size_t i = 0; i < tys.count; i++) - arr[i] = get_unqualified_type(tys.nodes[i]); - return nodes(arena, tys.count, arr); -} - -Nodes add_qualifiers(IrArena* arena, Nodes tys, bool uniform) { - LARRAY(const Type*, arr, tys.count); - 
for (size_t i = 0; i < tys.count; i++) - arr[i] = qualified_type_helper(tys.nodes[i], uniform || !arena->config.is_simt /* SIMD arenas ban varying value types */); - return nodes(arena, tys.count, arr); -} - -const Type* get_packed_type_element(const Type* type) { - const Type* t = type; - deconstruct_packed_type(&t); - return t; -} - -size_t get_packed_type_width(const Type* type) { - const Type* t = type; - return deconstruct_packed_type(&t); -} - -size_t deconstruct_packed_type(const Type** type) { - assert((*type)->tag == PackType_TAG); - return deconstruct_maybe_packed_type(type); -} - -const Type* get_maybe_packed_type_element(const Type* type) { - const Type* t = type; - deconstruct_maybe_packed_type(&t); - return t; -} - -size_t get_maybe_packed_type_width(const Type* type) { - const Type* t = type; - return deconstruct_maybe_packed_type(&t); -} - -size_t deconstruct_maybe_packed_type(const Type** type) { - const Type* t = *type; - assert(is_data_type(t)); - if (t->tag == PackType_TAG) { - *type = t->payload.pack_type.element_type; - return t->payload.pack_type.width; - } - return 1; -} - -const Type* maybe_packed_type_helper(const Type* type, size_t width) { - assert(width > 0); - if (width == 1) - return type; - return pack_type(type->arena, (PackType) { - .width = width, - .element_type = type, - }); -} - -const Type* get_pointer_type_element(const Type* type) { - const Type* t = type; - deconstruct_pointer_type(&t); - return t; -} - -AddressSpace get_pointer_type_address_space(const Type* type) { - const Type* t = type; - return deconstruct_pointer_type(&t); -} - -AddressSpace deconstruct_pointer_type(const Type** type) { - const Type* t = *type; - assert(t->tag == PtrType_TAG); - *type = t->payload.ptr_type.pointed_type; - return t->payload.ptr_type.address_space; -} - -const Node* get_nominal_type_decl(const Type* type) { - assert(type->tag == TypeDeclRef_TAG); - return get_maybe_nominal_type_decl(type); -} - -const Type* get_nominal_type_body(const 
Type* type) { - assert(type->tag == TypeDeclRef_TAG); - return get_maybe_nominal_type_body(type); -} - -const Node* get_maybe_nominal_type_decl(const Type* type) { - if (type->tag == TypeDeclRef_TAG) { - const Node* decl = type->payload.type_decl_ref.decl; - assert(decl->tag == NominalType_TAG); - return decl; - } - return NULL; -} - -const Type* get_maybe_nominal_type_body(const Type* type) { - const Node* decl = get_maybe_nominal_type_decl(type); - if (decl) - return decl->payload.nom_type.body; - return type; -} - -Nodes get_composite_type_element_types(const Type* type) { - switch (is_type(type)) { - case Type_TypeDeclRef_TAG: { - type = get_nominal_type_body(type); - assert(type->tag == RecordType_TAG); - SHADY_FALLTHROUGH - } - case RecordType_TAG: { - return type->payload.record_type.members; - } - case Type_ArrType_TAG: - case Type_PackType_TAG: { - size_t size = get_int_literal_value(*resolve_to_int_literal(get_fill_type_size(type)), false); - if (size >= 1024) { - warn_print("Potential performance issue: creating a really big array of composites of types (size=%d)!\n", size); - } - const Type* element_type = get_fill_type_element_type(type); - LARRAY(const Type*, types, size); - for (size_t i = 0; i < size; i++) { - types[i] = element_type; - } - return nodes(type->arena, size, types); - } - default: error("Not a composite type !") - } -} - -const Node* get_fill_type_element_type(const Type* composite_t) { - switch (composite_t->tag) { - case ArrType_TAG: return composite_t->payload.arr_type.element_type; - case PackType_TAG: return composite_t->payload.pack_type.element_type; - default: error("fill values need to be either array or pack types") - } -} - -const Node* get_fill_type_size(const Type* composite_t) { - switch (composite_t->tag) { - case ArrType_TAG: return composite_t->payload.arr_type.size; - case PackType_TAG: return int32_literal(composite_t->arena, composite_t->payload.pack_type.width); - default: error("fill values need to be either array 
or pack types") - } -} diff --git a/src/shady/visit.c b/src/shady/visit.c index 3d0f4ac13..29cabb205 100644 --- a/src/shady/visit.c +++ b/src/shady/visit.c @@ -1,52 +1,65 @@ #include "shady/ir.h" #include "log.h" -#include "visit.h" -#include "analysis/scope.h" +#include "shady/visit.h" +#include "analysis/cfg.h" #include -void visit_node(Visitor* visitor, const Node* node) { +void shd_visit_node(Visitor* visitor, const Node* node) { assert(visitor->visit_node_fn); if (node) visitor->visit_node_fn(visitor, node); } -void visit_nodes(Visitor* visitor, Nodes nodes) { +void shd_visit_nodes(Visitor* visitor, Nodes nodes) { for (size_t i = 0; i < nodes.count; i++) { - visit_node(visitor, nodes.nodes[i]); + shd_visit_node(visitor, nodes.nodes[i]); } } -void visit_op(Visitor* visitor, NodeClass op_class, String op_name, const Node* op) { +void shd_visit_op(Visitor* visitor, NodeClass op_class, String op_name, const Node* op, size_t i) { if (!op) return; if (visitor->visit_op_fn) - visitor->visit_op_fn(visitor, op_class, op_name, op); + visitor->visit_op_fn(visitor, op_class, op_name, op, i); else visitor->visit_node_fn(visitor, op); } -void visit_ops(Visitor* visitor, NodeClass op_class, String op_name, Nodes ops) { +void shd_visit_ops(Visitor* visitor, NodeClass op_class, String op_name, Nodes ops) { for (size_t i = 0; i < ops.count; i++) - visit_op(visitor, op_class, op_name, ops.nodes[i]); + shd_visit_op(visitor, op_class, op_name, ops.nodes[i], i); } -void visit_function_rpo(Visitor* visitor, const Node* function) { +void shd_visit_function_rpo(Visitor* visitor, const Node* function) { assert(function->tag == Function_TAG); - Scope* scope = new_scope(function); - assert(scope->rpo[0]->node == function); - for (size_t i = 1; i < scope->size; i++) { - const Node* node = scope->rpo[i]->node; - visit_node(visitor, node); + CFG* cfg = build_fn_cfg(function); + assert(cfg->rpo[0]->node == function); + for (size_t i = 0; i < cfg->size; i++) { + const Node* node = 
cfg->rpo[i]->node; + shd_visit_node(visitor, node); } - destroy_scope(scope); + shd_destroy_cfg(cfg); +} + +void shd_visit_function_bodies_rpo(Visitor* visitor, const Node* function) { + assert(function->tag == Function_TAG); + CFG* cfg = build_fn_cfg(function); + assert(cfg->rpo[0]->node == function); + for (size_t i = 0; i < cfg->size; i++) { + const Node* node = cfg->rpo[i]->node; + assert(is_abstraction(node)); + if (get_abstraction_body(node)) + shd_visit_node(visitor, get_abstraction_body(node)); + } + shd_destroy_cfg(cfg); } #pragma GCC diagnostic error "-Wswitch" #include "visit_generated.c" -void visit_module(Visitor* visitor, Module* mod) { - Nodes decls = get_module_declarations(mod); - visit_nodes(visitor, decls); +void shd_visit_module(Visitor* visitor, Module* mod) { + Nodes decls = shd_module_get_declarations(mod); + shd_visit_nodes(visitor, decls); } diff --git a/src/shady/visit.h b/src/shady/visit.h deleted file mode 100644 index 789c94dc4..000000000 --- a/src/shady/visit.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef SHADY_VISIT_H -#define SHADY_VISIT_H - -#include "shady/ir.h" - -typedef struct Visitor_ Visitor; -typedef void (*VisitNodeFn)(Visitor*, const Node*); -typedef void (*VisitOpFn)(Visitor*, NodeClass, String, const Node*); - -struct Visitor_ { - VisitNodeFn visit_node_fn; - VisitOpFn visit_op_fn; -}; - -void visit_node_operands(Visitor*, NodeClass exclude, const Node*); -void visit_module(Visitor* visitor, Module*); - -void visit_node(Visitor* visitor, const Node*); -void visit_nodes(Visitor* visitor, Nodes nodes); - -void visit_op(Visitor* visitor, NodeClass, String, const Node*); -void visit_ops(Visitor* visitor, NodeClass, String, Nodes nodes); - -// visits the abstractions in the function, _except_ for the entry block (ie the function itself) -void visit_function_rpo(Visitor* visitor, const Node* function); -// use this in visit_node_operands to avoid visiting nodes in non-rpo order -#define IGNORE_ABSTRACTIONS_MASK NcBasic_block | 
NcCase | NcDeclaration - -#endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 061fead62..e3159cf2d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,53 +1,61 @@ -add_executable(test_math test_math.c) -target_link_libraries(test_math shady driver) -add_test(NAME test_math COMMAND test_math) - -list(APPEND BASIC_TESTS empty.slim) -list(APPEND BASIC_TESTS entrypoint_args1.slim) -list(APPEND BASIC_TESTS basic_blocks1.slim) -list(APPEND BASIC_TESTS control_flow1.slim) -list(APPEND BASIC_TESTS control_flow2.slim) -list(APPEND BASIC_TESTS functions1.slim) -list(APPEND BASIC_TESTS identity.slim) -list(APPEND BASIC_TESTS memory1.slim) -list(APPEND BASIC_TESTS memory2.slim) -list(APPEND BASIC_TESTS rec_pow.slim) -list(APPEND BASIC_TESTS rec_pow2.slim) -list(APPEND BASIC_TESTS restructure1.slim) -list(APPEND BASIC_TESTS restructure2.slim) -list(APPEND BASIC_TESTS simplify_control.slim) -list(APPEND BASIC_TESTS float.slim) -list(APPEND BASIC_TESTS constant_in_use.slim) -list(APPEND BASIC_TESTS arrays.slim) -list(APPEND BASIC_TESTS fn_decl.slim) -list(APPEND BASIC_TESTS math.slim) -list(APPEND BASIC_TESTS comments.slim) -list(APPEND BASIC_TESTS generic_ptrs1.slim) -list(APPEND BASIC_TESTS generic_ptrs2.slim) -list(APPEND BASIC_TESTS subgroup_var.slim) - -list(APPEND BASIC_TESTS reconvergence_heuristics/acyclic1.slim) -list(APPEND BASIC_TESTS reconvergence_heuristics/acyclic2.slim) -list(APPEND BASIC_TESTS reconvergence_heuristics/acyclic_evil.slim) -list(APPEND BASIC_TESTS reconvergence_heuristics/loops1.slim) -list(APPEND BASIC_TESTS reconvergence_heuristics/loops2.slim) -list(APPEND BASIC_TESTS reconvergence_heuristics/multi_exit_loop.slim) -list(APPEND BASIC_TESTS reconvergence_heuristics/nested_loops.slim) - -foreach(T IN LISTS BASIC_TESTS) - add_test(NAME "test/${T}" COMMAND slim ${PROJECT_SOURCE_DIR}/test/${T} -o test.spv) -endforeach() - -add_subdirectory(opt) - -function(spv_outputting_test) - cmake_parse_arguments(PARSE_ARGV 0 F "" 
"NAME;COMPILER" "EXTRA_ARGS" ) - add_test(NAME ${F_NAME} COMMAND ${CMAKE_COMMAND} -DCOMPILER=$ -DT=${F_NAME} "-DTARGS=${F_EXTRA_ARGS}" -DSRC=${PROJECT_SOURCE_DIR} -DDST=${PROJECT_BINARY_DIR} -P ${PROJECT_SOURCE_DIR}/test/test_with_val.cmake) -endfunction() - -spv_outputting_test(NAME samples/fib.slim COMPILER slim EXTRA_ARGS --entry-point main) -spv_outputting_test(NAME samples/hello_world.slim COMPILER slim EXTRA_ARGS --entry-point main) - -if (TARGET vcc) - add_subdirectory(vcc) -endif () +if (BUILD_TESTING) + add_executable(test_math test_math.c) + target_link_libraries(test_math driver) + add_test(NAME test_math COMMAND test_math) + + add_executable(test_builder test_builder.c) + target_link_libraries(test_builder driver) + add_test(NAME test_builder COMMAND test_builder) + + list(APPEND BASIC_TESTS empty.slim) + list(APPEND BASIC_TESTS entrypoint_args1.slim) + list(APPEND BASIC_TESTS basic_blocks1.slim) + list(APPEND BASIC_TESTS basic_blocks2.slim) + list(APPEND BASIC_TESTS control_flow1.slim) + list(APPEND BASIC_TESTS control_flow2.slim) + list(APPEND BASIC_TESTS functions1.slim) + list(APPEND BASIC_TESTS identity.slim) + list(APPEND BASIC_TESTS memory1.slim) + list(APPEND BASIC_TESTS memory2.slim) + list(APPEND BASIC_TESTS rec_pow.slim) + list(APPEND BASIC_TESTS rec_pow2.slim) + list(APPEND BASIC_TESTS restructure1.slim) + list(APPEND BASIC_TESTS restructure2.slim) + list(APPEND BASIC_TESTS simplify_control.slim) + list(APPEND BASIC_TESTS float.slim) + list(APPEND BASIC_TESTS constant_in_use.slim) + list(APPEND BASIC_TESTS arrays.slim) + list(APPEND BASIC_TESTS fn_decl.slim) + list(APPEND BASIC_TESTS math.slim) + list(APPEND BASIC_TESTS comments.slim) + list(APPEND BASIC_TESTS generic_ptrs1.slim) + list(APPEND BASIC_TESTS generic_ptrs2.slim) + list(APPEND BASIC_TESTS subgroup_var.slim) + + list(APPEND BASIC_TESTS reconvergence_heuristics/acyclic1.slim) + list(APPEND BASIC_TESTS reconvergence_heuristics/acyclic2.slim) + list(APPEND BASIC_TESTS 
reconvergence_heuristics/acyclic_evil.slim) + list(APPEND BASIC_TESTS reconvergence_heuristics/acyclic_simple_with_arg.slim) + list(APPEND BASIC_TESTS reconvergence_heuristics/loops1.slim) + list(APPEND BASIC_TESTS reconvergence_heuristics/loops2.slim) + list(APPEND BASIC_TESTS reconvergence_heuristics/multi_exit_loop.slim) + list(APPEND BASIC_TESTS reconvergence_heuristics/nested_loops.slim) + + foreach(T IN LISTS BASIC_TESTS) + add_test(NAME "test/${T}" COMMAND slim ${PROJECT_SOURCE_DIR}/test/${T} -o test.spv) + endforeach() + + add_subdirectory(opt) + + function(spv_outputting_test) + cmake_parse_arguments(PARSE_ARGV 0 F "" "NAME;COMPILER" "EXTRA_ARGS" ) + add_test(NAME ${F_NAME} COMMAND ${CMAKE_COMMAND} -DCOMPILER=$ -DT=${F_NAME} "-DTARGS=${F_EXTRA_ARGS}" -DSRC=${PROJECT_SOURCE_DIR} -DDST=${PROJECT_BINARY_DIR} -P ${PROJECT_SOURCE_DIR}/test/test_with_val.cmake) + endfunction() + + spv_outputting_test(NAME samples/fib.slim COMPILER slim EXTRA_ARGS --entry-point main) + spv_outputting_test(NAME samples/hello_world.slim COMPILER slim EXTRA_ARGS --entry-point main) + + if (TARGET vcc) + add_subdirectory(vcc) + endif() +endif() diff --git a/test/arrays.slim b/test/arrays.slim index 5c5cc054a..368cf6a47 100644 --- a/test/arrays.slim +++ b/test/arrays.slim @@ -1,4 +1,3 @@ - type T = struct { f32 x; f32 y; diff --git a/test/basic_blocks2.slim b/test/basic_blocks2.slim new file mode 100644 index 000000000..1dc6e00a0 --- /dev/null +++ b/test/basic_blocks2.slim @@ -0,0 +1,12 @@ +@EntryPoint("Compute") @WorkgroupSize(64, 1, 1) +fn main() { + jump bb1(7); + + cont bb1(varying i32 n) { + jump bb2(n); + } + + cont bb2(varying i32 n) { + return (); + } +} diff --git a/test/comments.slim b/test/comments.slim index 2c7ac5435..3a1587957 100644 --- a/test/comments.slim +++ b/test/comments.slim @@ -1,4 +1,3 @@ - // trivial function that does nothing @EntryPoint("Compute") @WorkgroupSize(SUBGROUP_SIZE, 1, 1) fn main() { diff --git a/test/constant_in_use.slim 
b/test/constant_in_use.slim index ae800e585..f8495cc58 100644 --- a/test/constant_in_use.slim +++ b/test/constant_in_use.slim @@ -1,4 +1,3 @@ - const i32 NINE = 9; const i32 TEN = 10; diff --git a/test/control_flow1.slim b/test/control_flow1.slim index 55f70e77e..9d8d3f03e 100644 --- a/test/control_flow1.slim +++ b/test/control_flow1.slim @@ -1,8 +1,9 @@ +@Exported fn extend varying i32(varying bool b) { val extended = if i32 (b) { - yield(1); + merge_selection(1); } else { - yield(0); + merge_selection(0); } return (extended); } diff --git a/test/control_flow2.slim b/test/control_flow2.slim index 4c833b3a9..52d6be049 100644 --- a/test/control_flow2.slim +++ b/test/control_flow2.slim @@ -1,3 +1,4 @@ +@Exported fn fac varying i32(varying i32 count) { val x = loop i32 (varying i32 i = 1, varying i32 a = 1) { val r = lt(i, count); // if i < count diff --git a/test/driver/test_elect_first.slim b/test/driver/test_elect_first.slim index b07873218..cd1ee0df9 100644 --- a/test/driver/test_elect_first.slim +++ b/test/driver/test_elect_first.slim @@ -1,4 +1,7 @@ -@EntryPoint("compute") @WorkgroupSize(64, 1, 1) fn main() { +@Internal @Builtin("SubgroupLocalInvocationId") +var input u32 subgroup_local_id; + +@EntryPoint("Compute") @WorkgroupSize(64, 1, 1) fn main() { val tid = subgroup_local_id; val x = tid / u32 4; debug_printf("tid = %d x = %d\n", tid, x); diff --git a/test/driver/test_scalarisation_loop.slim b/test/driver/test_scalarisation_loop.slim index ef9afc4af..999e7ac28 100644 --- a/test/driver/test_scalarisation_loop.slim +++ b/test/driver/test_scalarisation_loop.slim @@ -1,4 +1,7 @@ -@EntryPoint("compute") @WorkgroupSize(64, 1, 1) fn main() { +@Internal @Builtin("SubgroupLocalInvocationId") +var input u32 subgroup_local_id; + +@EntryPoint("Compute") @WorkgroupSize(32, 1, 1) fn main() { val tid = subgroup_local_id; val x = tid / u32 4; debug_printf("tid = %d x = %d\n", tid, x); @@ -13,6 +16,6 @@ break(); } } - debug_printf("Done SP=%d.\n", get_stack_pointer()); + 
debug_printf("Done SP=%d.\n", get_stack_size()); return (); } diff --git a/test/functions1.slim b/test/functions1.slim index 36241b2f1..01ea3d5e0 100644 --- a/test/functions1.slim +++ b/test/functions1.slim @@ -1,7 +1,9 @@ +@Exported fn identity varying i32(varying i32 i) { return(i); } +@Exported fn f varying i32 (varying i32 i) { val j = identity(i); val k = add(j, 1); diff --git a/test/generic_ptrs1.slim b/test/generic_ptrs1.slim index 26db2aafe..bd627ffd5 100644 --- a/test/generic_ptrs1.slim +++ b/test/generic_ptrs1.slim @@ -1,15 +1,19 @@ +@Exported fn foo1 ptr generic i32(varying ptr global i32 x) { return (convert[ptr generic i32](x)); } +@Exported fn foo2 ptr generic i32(varying ptr shared i32 x) { return (convert[ptr generic i32](x)); } +@Exported fn foo3 ptr generic i32(varying ptr subgroup i32 x) { return (convert[ptr generic i32](x)); } +@Exported fn foo4 ptr generic i32(varying ptr private i32 x) { return (convert[ptr generic i32](x)); } diff --git a/test/generic_ptrs2.slim b/test/generic_ptrs2.slim index f1745a6f7..cb01c8beb 100644 --- a/test/generic_ptrs2.slim +++ b/test/generic_ptrs2.slim @@ -1,3 +1,4 @@ +@Exported fn foo i32(varying ptr generic i32 x) { - return (load(x)); + return (*x); } \ No newline at end of file diff --git a/test/identity.slim b/test/identity.slim index df8ac2f39..b152635b8 100644 --- a/test/identity.slim +++ b/test/identity.slim @@ -1,4 +1,5 @@ // trivial function that returns its argument +@Exported fn identity i32(varying i32 i) { return (i); } diff --git a/test/math.slim b/test/math.slim index d2a6a946f..e55854f89 100644 --- a/test/math.slim +++ b/test/math.slim @@ -1,3 +1,4 @@ +@Exported fn foo f32(varying f32 x) { return (sqrt(x)); } diff --git a/test/memory1.slim b/test/memory1.slim index 23c6d6479..ea5cfb49c 100644 --- a/test/memory1.slim +++ b/test/memory1.slim @@ -1,12 +1,14 @@ @DescriptorSet(0) @DescriptorBinding(0) -global i32 extern_int; +var global i32 extern_int; +@Exported fn read_from_extern i32() { return 
(extern_int); } +@Exported fn read_from_global_ptr i32(uniform ptr global i32 global_ptr) { - val loaded = load(global_ptr); + val loaded = *global_ptr; return(loaded); } diff --git a/test/memory2.slim b/test/memory2.slim index ccb8dfcf8..323b45925 100644 --- a/test/memory2.slim +++ b/test/memory2.slim @@ -1,6 +1,7 @@ +@Exported fn alloca_load_store i32() { val a = alloca[i32](); - store(a, 9); - val i = load(a); + *a = 9; + val i = *a; return(i); } diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index 6029ae600..f6537f26d 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -3,8 +3,8 @@ target_link_libraries(opt_oracle PRIVATE driver) add_test(NAME "mem2reg1" COMMAND opt_oracle ${CMAKE_CURRENT_SOURCE_DIR}/mem2reg1.slim --no-dynamic-scheduling) set_property(TEST "mem2reg1" PROPERTY ENVIRONMENT "ASAN_OPTIONS=detect_leaks=0") -add_test(NAME "mem2reg2" COMMAND opt_oracle ${CMAKE_CURRENT_SOURCE_DIR}/mem2reg2.slim --no-dynamic-scheduling) -set_property(TEST "mem2reg2" PROPERTY ENVIRONMENT "ASAN_OPTIONS=detect_leaks=0") +# add_test(NAME "mem2reg2" COMMAND opt_oracle ${CMAKE_CURRENT_SOURCE_DIR}/mem2reg2.slim --no-dynamic-scheduling) +# set_property(TEST "mem2reg2" PROPERTY ENVIRONMENT "ASAN_OPTIONS=detect_leaks=0") add_test(NAME "mem2reg3" COMMAND opt_oracle ${CMAKE_CURRENT_SOURCE_DIR}/mem2reg3.slim --no-dynamic-scheduling) set_property(TEST "mem2reg3" PROPERTY ENVIRONMENT "ASAN_OPTIONS=detect_leaks=0") diff --git a/test/opt/mem2reg1.slim b/test/opt/mem2reg1.slim index 992833fa8..d0dcd1d87 100644 --- a/test/opt/mem2reg1.slim +++ b/test/opt/mem2reg1.slim @@ -1,3 +1,4 @@ +@Exported fn f varying i32() { var i32 i = 0; i = 1; diff --git a/test/opt/mem2reg2.slim b/test/opt/mem2reg2.slim index f87f56fc0..7b288fb60 100644 --- a/test/opt/mem2reg2.slim +++ b/test/opt/mem2reg2.slim @@ -1,6 +1,23 @@ +@Exported fn f varying i32(varying i32 x) { var i32 i = 0; - branch ((x > 0), t, j); + branch ((x > 0), t(), j()); + + cont t() { + i = x; + jump j(); + 
} + + cont j() { + i = 1; + return (i); + } +} + +@Exported +fn g varying i32(varying i32 x) { + var i32 i = 0; + branch ((x > 0), t(), j()); cont t() { i = x; diff --git a/test/opt/mem2reg3.slim b/test/opt/mem2reg3.slim index 2ace39f60..2bf5addbd 100644 --- a/test/opt/mem2reg3.slim +++ b/test/opt/mem2reg3.slim @@ -1,4 +1,20 @@ -fn f varying f32(varying i32 x) { - var i32 i = 0; +@Exported +fn f varying f32() { + var i32 i = 1116340224; // 69.0f return (*(reinterpret[ptr private f32](&i))); } + +@Exported +fn g varying i32() { + var i32 i = 420; + return (*(convert[ptr generic i32](&i))); +} + +@Exported +fn h varying f32() { + var i32 i = 1116340224; // 69.0f + val p = &i; + val p1 = (reinterpret[ptr private f32](&i)); + val p2 = convert[ptr generic f32](p1); + return (*p2); +} diff --git a/test/opt/mem2reg_should_fail.slim b/test/opt/mem2reg_should_fail.slim index 32fd526f0..b41d9e4c3 100644 --- a/test/opt/mem2reg_should_fail.slim +++ b/test/opt/mem2reg_should_fail.slim @@ -1,5 +1,6 @@ fn leak(varying ptr private i32 p); +@Exported fn f varying f32(varying i32 x) { var i32 i = 42; leak(&i); diff --git a/test/opt/opt_oracle.c b/test/opt/opt_oracle.c index b66aad4a0..425162dd4 100644 --- a/test/opt/opt_oracle.c +++ b/test/opt/opt_oracle.c @@ -1,9 +1,12 @@ #include "shady/ir.h" #include "shady/driver.h" +#include "shady/print.h" +#include "shady/visit.h" -#include "log.h" +#include "../shady/passes/passes.h" -#include "../src/shady/visit.h" +#include "log.h" +#include "portability.h" #include #include @@ -13,39 +16,35 @@ static bool expect_memstuff = false; static bool found_memstuff = false; static void search_for_memstuff(Visitor* v, const Node* n) { - if (n->tag == PrimOp_TAG) { - PrimOp payload = n->payload.prim_op; - switch (payload.op) { - case alloca_op: - case alloca_logical_op: - case load_op: - case store_op: - case memcpy_op: { - found_memstuff = true; - break; - } - default: break; + switch (n->tag) { + case Load_TAG: + case Store_TAG: + case 
CopyBytes_TAG: + case FillBytes_TAG: + case StackAlloc_TAG: + case LocalAlloc_TAG: { + found_memstuff = true; + break; } + default: break; } - visit_node_operands(v, NcDeclaration, n); + shd_visit_node_operands(v, ~(NcMem | NcDeclaration | NcTerminator), n); } -static void after_pass(void* uptr, String pass_name, Module* mod) { - if (strcmp(pass_name, "opt_mem2reg") == 0) { - Visitor v = {.visit_node_fn = search_for_memstuff}; - visit_module(&v, mod); - if (expect_memstuff != found_memstuff) { - error_print("Expected "); - if (!expect_memstuff) - error_print("no more "); - error_print("memory primops in the output.\n"); - dump_module(mod); - exit(-1); - } - dump_module(mod); - exit(0); +static void check_module(Module* mod) { + Visitor v = { .visit_node_fn = search_for_memstuff }; + shd_visit_module(&v, mod); + if (expect_memstuff != found_memstuff) { + shd_error_print("Expected "); + if (!expect_memstuff) + shd_error_print("no more "); + shd_error_print("memory primops in the output.\n"); + shd_dump_module(mod); + exit(-1); } + shd_dump_module(mod); + exit(0); } static void cli_parse_oracle_args(int* pargc, char** argv) { @@ -61,14 +60,45 @@ static void cli_parse_oracle_args(int* pargc, char** argv) { } } - cli_pack_remaining_args(pargc, argv); + shd_pack_remaining_args(pargc, argv); } -static void hook(DriverConfig* args, int* pargc, char** argv) { - args->config.hooks.after_pass.fn = after_pass; - cli_parse_oracle_args(pargc, argv); +static Module* oracle_passes(const CompilerConfig* config, Module* initial_mod) { + IrArena* initial_arena = shd_module_get_arena(initial_mod); + Module** pmod = &initial_mod; + + RUN_PASS(shd_cleanup) + check_module(*pmod); + + return *pmod; } -#define HOOK_STUFF hook(&args, &argc, argv); +int main(int argc, char** argv) { + shd_platform_specific_terminal_init_extras(); + + DriverConfig args = shd_default_driver_config(); + shd_parse_driver_args(&args, &argc, argv); + shd_parse_common_args(&argc, argv); + 
shd_parse_compiler_config_args(&args.config, &argc, argv); + cli_parse_oracle_args(&argc, argv); + shd_driver_parse_input_files(args.input_filenames, &argc, argv); + + ArenaConfig aconfig = shd_default_arena_config(&args.config.target); + aconfig.optimisations.weaken_non_leaking_allocas = true; + IrArena* arena = shd_new_ir_arena(&aconfig); + Module* mod = shd_new_module(arena, "my_module"); // TODO name module after first filename, or perhaps the last one + + ShadyErrorCodes err = shd_driver_load_source_files(&args, mod); + if (err) + exit(err); + + Module* mod2 = oracle_passes(&args.config, mod); + shd_destroy_ir_arena(shd_module_get_arena(mod2)); + + if (err) + exit(err); + shd_info_print("Compilation successful\n"); -#include "../../src/driver/slim.c" \ No newline at end of file + shd_destroy_ir_arena(arena); + shd_destroy_driver_config(&args); +} \ No newline at end of file diff --git a/test/rec_pow.slim b/test/rec_pow.slim index a1f72803e..42ed9ff0d 100644 --- a/test/rec_pow.slim +++ b/test/rec_pow.slim @@ -1,3 +1,4 @@ +@Exported fn rec_pow i32(varying i32 x, varying i32 y) { if (y > 1) { return (x * rec_pow(x, y - 1)); diff --git a/test/rec_pow2.slim b/test/rec_pow2.slim index 91af62059..101505bc1 100644 --- a/test/rec_pow2.slim +++ b/test/rec_pow2.slim @@ -1,7 +1,9 @@ +@Exported fn rec_pow_chain_helper i32(varying i32 x, varying i32 y) { return (rec_pow_chain(x, y)); } +@Exported fn rec_pow_chain i32(varying i32 x, varying i32 y) { if (y > 1) { return (x * rec_pow_chain_helper(x, y - 1)); diff --git a/test/reconvergence_heuristics/acyclic1.slim b/test/reconvergence_heuristics/acyclic1.slim index 7625c6d2f..e3967d59a 100644 --- a/test/reconvergence_heuristics/acyclic1.slim +++ b/test/reconvergence_heuristics/acyclic1.slim @@ -1,4 +1,4 @@ -@Restructure +@Exported @Restructure fn f1 i32(varying bool b) { jump bb1(); @@ -19,7 +19,7 @@ fn f1 i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn f2 i32(varying bool b) { jump bb1(); @@ -40,9 +40,9 @@ 
fn f2 i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn g i32(varying bool b) { - branch(b, bb1, bb2)(); + branch(b, bb1(), bb2()); cont bb1() { jump bb3(); @@ -61,7 +61,7 @@ fn g i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn hfixed i32(varying bool b) { branch(b, bb1(), bb2()); @@ -90,7 +90,7 @@ fn hfixed i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn hnot1 i32(varying bool b) { val not_b = ! b; @@ -109,7 +109,7 @@ fn hnot1 i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn hnot2 i32(varying bool b) { branch(b, bb1(), bb2()); @@ -140,9 +140,9 @@ fn hnot2 i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn i i32(varying bool b) { - branch(b, bb1, bb2)(); + branch(b, bb1(), bb2()); cont bb1() { jump bb2(); diff --git a/test/reconvergence_heuristics/acyclic2.slim b/test/reconvergence_heuristics/acyclic2.slim index d3c7c8c8a..187508c7d 100644 --- a/test/reconvergence_heuristics/acyclic2.slim +++ b/test/reconvergence_heuristics/acyclic2.slim @@ -1,4 +1,4 @@ -@Restructure +@Exported @Restructure fn h1 i32(varying bool b) { branch(b, bb1(), bb2()); diff --git a/test/reconvergence_heuristics/acyclic_evil.slim b/test/reconvergence_heuristics/acyclic_evil.slim index 0cb16b328..db34620e5 100644 --- a/test/reconvergence_heuristics/acyclic_evil.slim +++ b/test/reconvergence_heuristics/acyclic_evil.slim @@ -1,4 +1,4 @@ -@Restructure +@Exported @Restructure fn f i32(varying bool b) { jump A(); diff --git a/test/reconvergence_heuristics/acyclic_simple_with_arg.slim b/test/reconvergence_heuristics/acyclic_simple_with_arg.slim new file mode 100644 index 000000000..8c66ec657 --- /dev/null +++ b/test/reconvergence_heuristics/acyclic_simple_with_arg.slim @@ -0,0 +1,20 @@ +@Exported @Restructure +fn f1 i32(varying bool b, varying i32 i) { + jump bb1(); + + cont bb1() { + branch(b, bb2(i), bb3(6, 6)); + } + + cont bb2(varying i32 j) { + jump bb4(j); + } + + cont bb3(uniform i32 x, uniform i32 y) { + 
jump bb4(x + y); + } + + cont bb4(varying i32 l) { + return (l + i); + } +} \ No newline at end of file diff --git a/test/reconvergence_heuristics/loops1.slim b/test/reconvergence_heuristics/loops1.slim index baeac8680..cb5d435e3 100644 --- a/test/reconvergence_heuristics/loops1.slim +++ b/test/reconvergence_heuristics/loops1.slim @@ -1,4 +1,4 @@ -@Restructure +@Exported @Restructure fn exitingloop_controlblock varying i32(varying bool b) { jump pre_entry(); @@ -27,7 +27,7 @@ fn exitingloop_controlblock varying i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn exitingloop_function varying i32(varying bool b) { jump entry(); @@ -52,7 +52,7 @@ fn exitingloop_function varying i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn exitingloop_values varying i32(varying i32 n) { var i32 r = n + 0; var i32 k = 0; @@ -75,3 +75,27 @@ fn exitingloop_values varying i32(varying i32 n) { return (k); } } + +@Exported @Restructure +fn exitingloop_values_from_loop varying i32(varying i32 n) { + var i32 r = n + 0; + var i32 k = 0; + + jump entry(); + + cont entry() { + val r1 = r + 1; + val loop_cond = (r > 0); + branch(loop_cond, loop_body(), loop_exit()); + + cont loop_exit() { + return (r1); + } + } + + cont loop_body() { + k = k + r; + r = r - 1; + jump entry(); + } +} diff --git a/test/reconvergence_heuristics/loops2.slim b/test/reconvergence_heuristics/loops2.slim index e82fd979c..15113ac67 100644 --- a/test/reconvergence_heuristics/loops2.slim +++ b/test/reconvergence_heuristics/loops2.slim @@ -1,4 +1,4 @@ -@Restructure +@Exported @Restructure fn minimal_loop varying i32(varying bool b) { jump entry(); @@ -15,7 +15,7 @@ fn minimal_loop varying i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn reconverge_inside_loop varying i32(varying bool b) { jump entry(); diff --git a/test/reconvergence_heuristics/multi_exit_loop.slim b/test/reconvergence_heuristics/multi_exit_loop.slim index 348d802ca..2d9f37b10 100644 --- 
a/test/reconvergence_heuristics/multi_exit_loop.slim +++ b/test/reconvergence_heuristics/multi_exit_loop.slim @@ -1,4 +1,4 @@ -@Restructure +@Exported @Restructure fn loop_with_two_exits varying i32(varying bool b) { jump entry(); @@ -23,7 +23,7 @@ fn loop_with_two_exits varying i32(varying bool b) { } } -@Restructure +@Exported @Restructure fn loop_with_two_exits_and_values varying i32(varying bool b) { jump entry(); diff --git a/test/reconvergence_heuristics/nested_loops.slim b/test/reconvergence_heuristics/nested_loops.slim index d2bbf3ff4..2b9727092 100644 --- a/test/reconvergence_heuristics/nested_loops.slim +++ b/test/reconvergence_heuristics/nested_loops.slim @@ -1,4 +1,4 @@ -@Restructure +@Exported @Restructure fn f i32(varying bool b, varying i32 i) { jump A(); diff --git a/test/restructure1.slim b/test/restructure1.slim index b51eba4e0..ca028114b 100644 --- a/test/restructure1.slim +++ b/test/restructure1.slim @@ -4,6 +4,7 @@ // /!\ there is no reconvergence happening in this function (except for the return itself) // to reconverge, one needs to use the control() construct, such as demonstrated in restructure3 // to be absolutely clear: two distinct sets of threads will execute two different dynamic instances bb3 +@Exported fn f i32(varying bool b) { branch (b, bb1(), bb2()); @@ -23,6 +24,7 @@ fn f i32(varying bool b) { // f is equivalent (and should be turned back into) this function: // (modulo some dataflow jank that opt passes may or may not cleanup) +@Exported fn g i32(varying bool b) { val r = if i32(b) { // blah diff --git a/test/restructure2.slim b/test/restructure2.slim index aa89e2e03..9bf6238ce 100644 --- a/test/restructure2.slim +++ b/test/restructure2.slim @@ -2,6 +2,7 @@ // such loops are only allowed if there is only one path through them that involves a back-edge // in other words: no implicit synchronisation at the loop header +@Exported fn f i32(varying bool b) { jump bb1(); @@ -15,6 +16,7 @@ fn f i32(varying bool b) { } // the behaviour 
should be equivalent to this structured code: +@Exported fn g i32(varying bool b) { loop() { if (b) { diff --git a/test/subgroup_var.slim b/test/subgroup_var.slim index 594ab32ee..e4a24eeba 100644 --- a/test/subgroup_var.slim +++ b/test/subgroup_var.slim @@ -1,5 +1,6 @@ -subgroup i32 x; +var subgroup i32 x; +@Exported fn foo uniform i32() { return (x); } diff --git a/test/test_builder.c b/test/test_builder.c new file mode 100644 index 000000000..51efa033b --- /dev/null +++ b/test/test_builder.c @@ -0,0 +1,167 @@ +#include "shady/ir.h" +#include "shady/driver.h" +#include "shady/be/dump.h" + +#include "../shady/analysis/cfg.h" + +#include "log.h" + +#include +#include +#include + +#define CHECK(x, failure_handler) { if (!(x)) { shd_error_print(#x " failed\n"); failure_handler; } } + +static void test_body_builder_constants(IrArena* a) { + BodyBuilder* bb = shd_bld_begin_pure(a); + const Node* sum = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, shd_int32_literal(a, 4), shd_int32_literal(a, 38))); + const Node* result = shd_bld_to_instr_yield_value(bb, sum); + CHECK(sum == result, exit(-1)); + CHECK(result->tag == IntLiteral_TAG, exit(-1)); + CHECK(shd_get_int_literal_value(result->payload.int_literal, false) == 42, exit(-1)); +} + +static void test_body_builder_fun_body(IrArena* a) { + Module* m = shd_new_module(a, "test_module"); + const Node* p1 = param(a, shd_as_qualified_type(ptr_type(a, (PtrType) { + .address_space = AsGeneric, + .pointed_type = shd_uint32_type(a), + }), false), NULL); + const Node* p2 = param(a, shd_as_qualified_type(ptr_type(a, (PtrType) { + .address_space = AsGeneric, + .pointed_type = shd_uint32_type(a), + }), false), NULL); + // const Node* p3 = param(a, shd_as_qualified_type(bool_type(a), false), NULL); + // const Node* p4 = param(a, shd_as_qualified_type(uint32_type(a), false), NULL); + Node* fun = function(m, mk_nodes(a, p1, p2), "fun", shd_empty(a), shd_empty(a)); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); 
+ + const Node* p1_value = shd_bld_load(bb, p1); + CHECK(p1_value->tag == Load_TAG, exit(-1)); + Node* true_case = case_(a, shd_empty(a)); + BodyBuilder* tc_builder = shd_bld_begin(a, shd_get_abstraction_mem(true_case)); + shd_bld_store(tc_builder, p1, shd_uint32_literal(a, 0)); + shd_set_abstraction_body(true_case, shd_bld_selection_merge(tc_builder, shd_empty(a))); + shd_bld_if(bb, shd_empty(a), prim_op_helper(a, gt_op, shd_empty(a), mk_nodes(a, p1_value, shd_uint32_literal(a, 0))), true_case, NULL); + + const Node* p2_value = shd_bld_load(bb, p2); + + const Node* sum = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, p1_value, p2_value)); + const Node* return_terminator = fn_ret(a, (Return) { + .mem = shd_bb_mem(bb), + .args = shd_singleton(sum) + }); + shd_set_abstraction_body(fun, shd_bld_finish(bb, return_terminator)); + // set_abstraction_body(fun, finish_body_with_return(bb, singleton(sum))); + + shd_dump_module(m); + + // Follow the CFG and the mems to make sure we arrive back at the initial start ! + CFG* cfg = build_fn_cfg(fun); + const Node* mem = get_terminator_mem(return_terminator); + do { + const Node* omem = shd_get_original_mem(mem); + if (!omem) + break; + mem = omem; + CHECK(mem->tag == AbsMem_TAG, exit(-1)); + CFNode* n = shd_cfg_lookup(cfg, mem->payload.abs_mem.abs); + if (n->idom) { + mem = get_terminator_mem(get_abstraction_body(n->idom->node)); + continue; + } + if (n->structured_idom) { + mem = get_terminator_mem(get_abstraction_body(n->structured_idom->node)); + continue; + } + break; + } while (1); + mem = shd_get_original_mem(mem); + CHECK(mem == shd_get_abstraction_mem(fun), exit(-1)); + shd_destroy_cfg(cfg); +} + +/// There is some "magic" code in body_builder and shd_set_abstraction_body to enable inserting control-flow +/// where there is only a mem dependency. This is useful when writing some complex polyfills. 
+static void test_body_builder_impure_block(IrArena* a) { + Module* m = shd_new_module(a, "test_module"); + const Node* p1 = param(a, shd_as_qualified_type(ptr_type(a, (PtrType) { + .address_space = AsGeneric, + .pointed_type = shd_uint32_type(a), + }), false), NULL); + Node* fun = function(m, mk_nodes(a, p1), "fun", shd_empty(a), shd_empty(a)); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); + + const Node* first_load = shd_bld_load(bb, p1); + + BodyBuilder* block_builder = shd_bld_begin_pseudo_instr(a, shd_bb_mem(bb)); + shd_bld_store(block_builder, p1, shd_uint32_literal(a, 0)); + shd_bld_add_instruction_extract(bb, shd_bld_to_instr_yield_values(block_builder, shd_empty(a))); + + const Node* second_load = shd_bld_load(bb, p1); + + const Node* sum = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, first_load, second_load)); + const Node* return_terminator = fn_ret(a, (Return) { + .mem = shd_bb_mem(bb), + .args = shd_singleton(sum) + }); + shd_set_abstraction_body(fun, shd_bld_finish(bb, return_terminator)); + + shd_dump_module(m); + + bool found_store = false; + const Node* mem = get_terminator_mem(return_terminator); + while (mem) { + if (mem->tag == Store_TAG) + found_store = true; + mem = shd_get_parent_mem(mem); + } + + CHECK(found_store, exit(-1)); +} + +/// There is some "magic" code in body_builder and shd_set_abstraction_body to enable inserting control-flow +/// where there is only a mem dependency. This is useful when writing some complex polyfills. 
+static void test_body_builder_impure_block_with_control_flow(IrArena* a) { + Module* m = shd_new_module(a, "test_module"); + const Node* p1 = param(a, shd_as_qualified_type(ptr_type(a, (PtrType) { + .address_space = AsGeneric, + .pointed_type = shd_uint32_type(a), + }), false), NULL); + Node* fun = function(m, mk_nodes(a, p1), "fun", shd_empty(a), shd_empty(a)); + BodyBuilder* bb = shd_bld_begin(a, shd_get_abstraction_mem(fun)); + + const Node* first_load = shd_bld_load(bb, p1); + + BodyBuilder* block_builder = shd_bld_begin_pseudo_instr(a, shd_bb_mem(bb)); + Node* if_true_case = case_(a, shd_empty(a)); + BodyBuilder* if_true_builder = shd_bld_begin(a, shd_get_abstraction_mem(if_true_case)); + shd_bld_store(if_true_builder, p1, shd_uint32_literal(a, 0)); + shd_set_abstraction_body(if_true_case, shd_bld_selection_merge(if_true_builder, shd_empty(a))); + shd_bld_if(block_builder, shd_empty(a), prim_op_helper(a, neq_op, shd_empty(a), mk_nodes(a, first_load, shd_uint32_literal(a, 0))), if_true_case, NULL); + shd_bld_add_instruction_extract(bb, shd_bld_to_instr_yield_values(block_builder, shd_empty(a))); + + const Node* second_load = shd_bld_load(bb, p1); + + const Node* sum = prim_op_helper(a, add_op, shd_empty(a), mk_nodes(a, first_load, second_load)); + const Node* return_terminator = fn_ret(a, (Return) { + .mem = shd_bb_mem(bb), + .args = shd_singleton(sum) + }); + shd_set_abstraction_body(fun, shd_bld_finish(bb, return_terminator)); + + shd_dump_module(m); +} + +int main(int argc, char** argv) { + shd_parse_common_args(&argc, argv); + + TargetConfig target_config = shd_default_target_config(); + ArenaConfig aconfig = shd_default_arena_config(&target_config); + IrArena* a = shd_new_ir_arena(&aconfig); + test_body_builder_constants(a); + test_body_builder_fun_body(a); + test_body_builder_impure_block(a); + test_body_builder_impure_block_with_control_flow(a); + shd_destroy_ir_arena(a); +} diff --git a/test/test_math.c b/test/test_math.c index de7335761..6c89a61bd 
100644 --- a/test/test_math.c +++ b/test/test_math.c @@ -1,13 +1,13 @@ -#include -#include -#include - #include "shady/ir.h" #include "shady/driver.h" #include "log.h" -#define CHECK(x, failure_handler) { if (!(x)) { error_print(#x " failed\n"); failure_handler; } } +#include +#include +#include + +#define CHECK(x, failure_handler) { if (!(x)) { shd_error_print(#x " failed\n"); failure_handler; } } static bool check_same_bytes(char* a, char* b, size_t size) { if (memcmp(a, b, size) == 0) @@ -32,7 +32,7 @@ static bool check_same_bytes(char* a, char* b, size_t size) { } static void check_int_literal_against_reference(IrArena* a, const Node* lit, IntLiteral reference) { - const IntLiteral* ptr = resolve_to_int_literal(lit); + const IntLiteral* ptr = shd_resolve_to_int_literal(lit); CHECK(ptr, exit(-1)); IntLiteral got = *ptr; CHECK(got.is_signed == reference.is_signed, exit(-1)); @@ -48,25 +48,25 @@ static void test_int_literals(IrArena* a) { .is_signed = false, .value = 0 }; - check_int_literal_against_reference(a, uint8_literal(a, 0), ref_zero_u8); + check_int_literal_against_reference(a, shd_uint8_literal(a, 0), ref_zero_u8); IntLiteral ref_one_u8 = { .width = IntTy8, .is_signed = false, .value = 1 }; - check_int_literal_against_reference(a, uint8_literal(a, 1), ref_one_u8); + check_int_literal_against_reference(a, shd_uint8_literal(a, 1), ref_one_u8); IntLiteral ref_one_i8 = { .width = IntTy8, .is_signed = true, .value = 1 }; - check_int_literal_against_reference(a, int8_literal(a, 1), ref_one_i8); + check_int_literal_against_reference(a, shd_int8_literal(a, 1), ref_one_i8); IntLiteral ref_minus_one_i8 = { .width = IntTy8, .is_signed = true, .value = 255 }; - check_int_literal_against_reference(a, int8_literal(a, -1), ref_minus_one_i8); + check_int_literal_against_reference(a, shd_int8_literal(a, -1), ref_minus_one_i8); // Check sign extension works right int64_t i64_test_values[] = { 0, 1, 255, 256, -1, 65536, 65535, INT64_MAX, INT64_MIN }; for (size_t i = 0; i < 
sizeof(i64_test_values) / sizeof(i64_test_values[0]); i++) { @@ -76,25 +76,26 @@ static void test_int_literals(IrArena* a) { .width = IntTy64, .is_signed = true }; - uint64_t extracted_literal_value = get_int_literal_value(reference_literal, true); + uint64_t extracted_literal_value = shd_get_int_literal_value(reference_literal, true); int16_t reference_minus_one_i16 = test_value; CHECK(check_same_bytes((char*) &extracted_literal_value, (char*) &reference_minus_one_i16, sizeof(uint16_t)), exit(-1)); - uint64_t minus_one_u32 = get_int_literal_value(reference_literal, true); + uint64_t minus_one_u32 = shd_get_int_literal_value(reference_literal, true); int32_t reference_minus_one_i32 = test_value; CHECK(check_same_bytes((char*) &minus_one_u32, (char*) &reference_minus_one_i32, sizeof(uint32_t)), exit(-1)); - uint64_t minus_one_u64 = get_int_literal_value(reference_literal, true); + uint64_t minus_one_u64 = shd_get_int_literal_value(reference_literal, true); int64_t reference_minus_one_i64 = test_value; CHECK(check_same_bytes((char*) &minus_one_u64, (char*) &reference_minus_one_i64, sizeof(uint64_t)), exit(-1)); } } int main(int argc, char** argv) { - cli_parse_common_args(&argc, argv); + shd_parse_common_args(&argc, argv); - ArenaConfig acfg = default_arena_config(); - acfg.check_types = true; - acfg.allow_fold = true; - IrArena* a = new_ir_arena(acfg); + TargetConfig target_config = shd_default_target_config(); + ArenaConfig aconfig = shd_default_arena_config(&target_config); + aconfig.check_types = true; + aconfig.allow_fold = true; + IrArena* a = shd_new_ir_arena(&aconfig); test_int_literals(a); - destroy_ir_arena(a); -} \ No newline at end of file + shd_destroy_ir_arena(a); +} diff --git a/test/vcc/CMakeLists.txt b/test/vcc/CMakeLists.txt index a0af7c595..b0fca7fce 100644 --- a/test/vcc/CMakeLists.txt +++ b/test/vcc/CMakeLists.txt @@ -1,18 +1,25 @@ list(APPEND VCC_SIMPLE_TESTS empty.c) list(APPEND VCC_SIMPLE_TESTS address_spaces.c) +set(VCC_TEST_ARGS 
--vcc-include-path "${PROJECT_BINARY_DIR}/share/vcc/include/") + foreach(T IN LISTS VCC_SIMPLE_TESTS) - add_test(NAME "test/vcc/${T}" COMMAND vcc ${PROJECT_SOURCE_DIR}/test/vcc/${T}) + add_test(NAME "test/vcc/${T}" COMMAND vcc ${PROJECT_SOURCE_DIR}/test/vcc/${T} ${VCC_TEST_ARGS}) endforeach() -spv_outputting_test(NAME test/vcc/branch.c COMPILER vcc EXTRA_ARGS) -spv_outputting_test(NAME test/vcc/loop.c COMPILER vcc EXTRA_ARGS) -spv_outputting_test(NAME test/vcc/goto.c COMPILER vcc EXTRA_ARGS) +spv_outputting_test(NAME test/vcc/branch.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS}) +spv_outputting_test(NAME test/vcc/loop.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS}) +# spv_outputting_test(NAME test/vcc/loop_closed.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS}) +spv_outputting_test(NAME test/vcc/goto.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS}) +spv_outputting_test(NAME test/vcc/ternary.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS}) +spv_outputting_test(NAME test/vcc/string.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS}) + +spv_outputting_test(NAME test/vcc/vec_swizzle.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS} --entry-point test --execution-model Fragment) -spv_outputting_test(NAME test/vcc/vec_swizzle.c COMPILER vcc EXTRA_ARGS --entry-point test --no-dynamic-scheduling --execution-model Fragment) +spv_outputting_test(NAME test/vcc/empty.comp.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS} --entry-point main) -spv_outputting_test(NAME test/vcc/empty.comp.c COMPILER vcc EXTRA_ARGS --entry-point main) +spv_outputting_test(NAME test/vcc/simple.frag.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS} --entry-point main --execution-model Fragment) +spv_outputting_test(NAME test/vcc/checkerboard.frag.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS} --entry-point main --execution-model Fragment) +spv_outputting_test(NAME test/vcc/textured.frag.c COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS} --entry-point main --execution-model Fragment) -spv_outputting_test(NAME test/vcc/simple.frag.c COMPILER vcc EXTRA_ARGS 
--entry-point main --no-dynamic-scheduling --execution-model Fragment) -spv_outputting_test(NAME test/vcc/checkerboard.frag.c COMPILER vcc EXTRA_ARGS --entry-point main --no-dynamic-scheduling --execution-model Fragment) -spv_outputting_test(NAME test/vcc/textured.frag.c COMPILER vcc EXTRA_ARGS --entry-point main --no-dynamic-scheduling --execution-model Fragment) +add_subdirectory(cpp) diff --git a/test/vcc/address_spaces.c b/test/vcc/address_spaces.c index 0e4dab529..f0948d141 100644 --- a/test/vcc/address_spaces.c +++ b/test/vcc/address_spaces.c @@ -1,3 +1,3 @@ -#define global __attribute__((address_space(1))) +#include global int buffer[256]; diff --git a/test/vcc/cpp/CMakeLists.txt b/test/vcc/cpp/CMakeLists.txt new file mode 100644 index 000000000..515d0b6d0 --- /dev/null +++ b/test/vcc/cpp/CMakeLists.txt @@ -0,0 +1,2 @@ +spv_outputting_test(NAME test/vcc/cpp/vec_swizzle.cpp COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS} --entry-point test --execution-model Fragment --std=c++20) +spv_outputting_test(NAME test/vcc/cpp/textured.frag.cpp COMPILER vcc EXTRA_ARGS ${VCC_TEST_ARGS} --entry-point main --execution-model Fragment --std=c++20) diff --git a/test/vcc/cpp/textured.frag.cpp b/test/vcc/cpp/textured.frag.cpp new file mode 100644 index 000000000..d8fc4a90a --- /dev/null +++ b/test/vcc/cpp/textured.frag.cpp @@ -0,0 +1,19 @@ +#include +#include + +using namespace vcc; + +descriptor_set(0) descriptor_binding(1) uniform_constant sampler2D texSampler; + +location(0) input native_vec3 fragColor; +location(1) input native_vec2 fragTexCoord; + +location(0) output native_vec4 outColor; + +extern "C" { + +fragment_shader void main() { + outColor = texture2D(texSampler, fragTexCoord) * (vec4) { fragColor.x * 2.5f, fragColor.y * 2.5f, fragColor.z * 2.5f, 1.0f }; +} + +} diff --git a/test/vcc/cpp/vec_swizzle.cpp b/test/vcc/cpp/vec_swizzle.cpp new file mode 100644 index 000000000..6635367e3 --- /dev/null +++ b/test/vcc/cpp/vec_swizzle.cpp @@ -0,0 +1,17 @@ +#include + +using 
namespace vcc; + +extern "C" { + +location(0) vec3 vertexColor; +location(0) vec4 outColor; + +fragment_shader void test() { + vec4 a; + a.xyz = vertexColor; + a.w = 1.0f; + outColor = a; +} + +} \ No newline at end of file diff --git a/test/vcc/loop_closed.c b/test/vcc/loop_closed.c new file mode 100644 index 000000000..a2f897231 --- /dev/null +++ b/test/vcc/loop_closed.c @@ -0,0 +1,11 @@ +int square(int i) { + while (1) { + if (i == 9) + return i; + // if (i % 3 == 0) { + // i-=2; + // continue; + // } + i--; + } +} diff --git a/test/vcc/loop_for_simple.c b/test/vcc/loop_for_simple.c new file mode 100644 index 000000000..d5f6250a6 --- /dev/null +++ b/test/vcc/loop_for_simple.c @@ -0,0 +1,10 @@ +int square(int num) { + for (int i = 0; i < num; i++) { + // if (i == 9) + // break; + // if (i % 2 == 0) + // continue; + num--; + } + return 0; +} diff --git a/test/vcc/loop_two_backedges.c b/test/vcc/loop_two_backedges.c new file mode 100644 index 000000000..1bb67537a --- /dev/null +++ b/test/vcc/loop_two_backedges.c @@ -0,0 +1,10 @@ +int square(int i) { + while (1) { + i--; + if (i % 2 == 0) { + i -= 3; + continue; + } + } + return 0; +} diff --git a/test/vcc/string.c b/test/vcc/string.c new file mode 100644 index 000000000..c2ed032f3 --- /dev/null +++ b/test/vcc/string.c @@ -0,0 +1,3 @@ +void f(void) { + const char* s = "hi"; +} diff --git a/test/vcc/ternary.c b/test/vcc/ternary.c new file mode 100644 index 000000000..ce927cc7d --- /dev/null +++ b/test/vcc/ternary.c @@ -0,0 +1,5 @@ +#include + +int pick(int a, int b, bool c) { + return c ? 
a : b; +} \ No newline at end of file diff --git a/test/vcc/textured.frag.c b/test/vcc/textured.frag.c index e575f957a..a4b043009 100644 --- a/test/vcc/textured.frag.c +++ b/test/vcc/textured.frag.c @@ -1,7 +1,7 @@ #include #include -descriptor_set(0) descriptor_binding(1) uniform sampler2D texSampler; +descriptor_set(0) descriptor_binding(1) uniform_constant sampler2D texSampler; location(0) input vec3 fragColor; location(1) input vec2 fragTexCoord; diff --git a/vcc-std/CMakeLists.txt b/vcc-std/CMakeLists.txt index 882824b1f..a08742e2a 100644 --- a/vcc-std/CMakeLists.txt +++ b/vcc-std/CMakeLists.txt @@ -1,4 +1,12 @@ add_custom_target(copy-vcc-files ALL COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_BINARY_DIR}/../share/vcc/include/) add_dependencies(shady copy-vcc-files) +if (NOT MSVC) + enable_language(CXX) + add_executable(test_vcc_vec src/test_vec.cpp) + target_include_directories(test_vcc_vec PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/) + set_property(TARGET test_vcc_vec PROPERTY CXX_STANDARD 20) + add_test(NAME test_vcc_vec COMMAND test_vcc_vec) +endif () + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/../share/vcc/ DESTINATION share/vcc) diff --git a/vcc-std/include/shady.h b/vcc-std/include/shady.h index 0e81f9788..8af86cc0f 100644 --- a/vcc-std/include/shady.h +++ b/vcc-std/include/shady.h @@ -18,47 +18,76 @@ namespace vcc { #define descriptor_binding(i) __attribute__((annotate("shady::descriptor_binding::"#i))) #define local_size(x, y, z) __attribute__((annotate("shady::workgroup_size::"#x"::"#y"::"#z))) -#define input __attribute__((address_space(389))) -#define output __attribute__((address_space(390))) -#define uniform __attribute__((annotate("shady::uniform"))) -#define push_constant __attribute__((address_space(392))) -#define private __attribute__((address_space(5))) -#define private_logical __attribute__((address_space(385))) +#define input __attribute__((annotate("shady::extern::389"))) +#define 
output __attribute__((annotate("shady::extern::390"))) +// maybe deprecate it ? +#define uniform_constant __attribute__((annotate("shady::extern::398"))) +#define uniform_block __attribute__((annotate("shady::extern::395"))) +#define push_constant __attribute__((annotate("shady::extern::392"))) +#define global __attribute__((annotate("shady::extern::1"))) +#define shared __attribute__((annotate("shady::extern::3"))) +#define private __attribute__((annotate("shady::extern::5"))) + +float sqrtf(float f) __asm__("shady::prim_op::sqrt"); -typedef float vec4 __attribute__((ext_vector_type(4))); -typedef float vec3 __attribute__((ext_vector_type(3))); -typedef float vec2 __attribute__((ext_vector_type(2))); - -typedef int ivec4 __attribute__((ext_vector_type(4))); -typedef int ivec3 __attribute__((ext_vector_type(3))); -typedef int ivec2 __attribute__((ext_vector_type(2))); +#if defined(__cplusplus) & !defined(SHADY_CPP_NO_NAMESPACE) +} +#endif -typedef unsigned uvec4 __attribute__((ext_vector_type(4))); -typedef unsigned uvec3 __attribute__((ext_vector_type(3))); -typedef unsigned uvec2 __attribute__((ext_vector_type(2))); +#include "shady_vec.h" +#include "shady_mat.h" -typedef struct __shady_builtin_sampler2D {} sampler2D; +#if defined(__cplusplus) & !defined(SHADY_CPP_NO_NAMESPACE) +namespace vcc { +#endif -vec4 texture2D(const sampler2D, vec2) __asm__("shady::prim_op::sample_texture"); +typedef __attribute__((address_space(0x1000))) struct __shady_builtin_sampler1D* sampler1D; +typedef __attribute__((address_space(0x1001))) struct __shady_builtin_sampler2D* sampler2D; +typedef __attribute__((address_space(0x1002))) struct __shady_builtin_sampler3D* sampler3D; +typedef __attribute__((address_space(0x1003))) struct __shady_builtin_sampler3D* samplerCube; + +native_vec4 texture1D(const sampler1D, float) __asm__("shady::prim_op::sample_texture"); +native_vec4 texture2D(const sampler2D, native_vec2) __asm__("shady::prim_op::sample_texture"); +native_vec4 texture3D(const 
sampler3D, native_vec3) __asm__("shady::prim_op::sample_texture"); +native_vec4 textureCube(const samplerCube, native_vec3) __asm__("shady::prim_op::sample_texture"); + +#if defined(__cplusplus) +native_vec4 texture(const sampler1D, float) __asm__("shady::prim_op::sample_texture"); +native_vec4 texture(const sampler2D, native_vec2) __asm__("shady::prim_op::sample_texture"); +native_vec4 texture(const sampler3D, native_vec3) __asm__("shady::prim_op::sample_texture"); +native_vec4 texture(const samplerCube, native_vec3) __asm__("shady::prim_op::sample_texture"); +#endif // builtins __attribute__((annotate("shady::builtin::FragCoord"))) -input vec4 gl_FragCoord; +input native_vec4 gl_FragCoord; __attribute__((annotate("shady::builtin::Position"))) -output vec4 gl_Position; +output native_vec4 gl_Position; __attribute__((annotate("shady::builtin::WorkgroupId"))) __attribute__((address_space(389))) -uvec3 gl_WorkGroupID; +native_uvec3 gl_WorkGroupID; __attribute__((annotate("shady::builtin::VertexIndex"))) __attribute__((address_space(389))) -input int gl_VertexIndex; +unsigned gl_VertexIndex; + +__attribute__((annotate("shady::builtin::SubgroupId"))) +__attribute__((address_space(389))) +unsigned subgroup_id; + +__attribute__((annotate("shady::builtin::SubgroupLocalInvocationId"))) +__attribute__((address_space(389))) +unsigned subgroup_local_id; __attribute__((annotate("shady::builtin::WorkgroupSize"))) __attribute__((address_space(389))) -uvec3 gl_WorkGroupSize; +native_uvec3 gl_WorkGroupSize; + +__attribute__((annotate("shady::builtin::GlobalInvocationId"))) +__attribute__((address_space(389))) +native_uvec3 gl_GlobalInvocationID; #if defined(__cplusplus) & !defined(SHADY_CPP_NO_NAMESPACE) } diff --git a/vcc-std/include/shady_mat.h b/vcc-std/include/shady_mat.h new file mode 100644 index 000000000..4500fed96 --- /dev/null +++ b/vcc-std/include/shady_mat.h @@ -0,0 +1,195 @@ +#if defined(__cplusplus) & !defined(SHADY_CPP_NO_NAMESPACE) +namespace vcc { +#endif + 
+typedef union mat4_ mat4; + +static inline mat4 transpose_mat4(mat4 src); +static inline mat4 mul_mat4(mat4 l, mat4 r); +static inline vec4 mul_mat4_vec4f(mat4 l, vec4 r); + +union mat4_ { + struct { + // we use row-major ordering + float m00, m01, m02, m03, + m10, m11, m12, m13, + m20, m21, m22, m23, + m30, m31, m32, m33; + }; + //vec4 rows[4]; + float arr[16]; + + +#if defined(__cplusplus) + mat4 operator*(const mat4& other) { + return mul_mat4(*this, other); + } + + vec4 operator*(const vec4& other) { + return mul_mat4_vec4f(*this, other); + } +#endif +}; + +static const mat4 identity_mat4 = { + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, +}; + +static inline mat4 transpose_mat4(mat4 src) { + return (mat4) { + src.m00, src.m10, src.m20, src.m30, + src.m01, src.m11, src.m21, src.m31, + src.m02, src.m12, src.m22, src.m32, + src.m03, src.m13, src.m23, src.m33, + }; +} + +static inline mat4 invert_mat4(mat4 m) { + float a = m.m00 * m.m11 - m.m01 * m.m10; + float b = m.m00 * m.m12 - m.m02 * m.m10; + float c = m.m00 * m.m13 - m.m03 * m.m10; + float d = m.m01 * m.m12 - m.m02 * m.m11; + float e = m.m01 * m.m13 - m.m03 * m.m11; + float f = m.m02 * m.m13 - m.m03 * m.m12; + float g = m.m20 * m.m31 - m.m21 * m.m30; + float h = m.m20 * m.m32 - m.m22 * m.m30; + float i = m.m20 * m.m33 - m.m23 * m.m30; + float j = m.m21 * m.m32 - m.m22 * m.m31; + float k = m.m21 * m.m33 - m.m23 * m.m31; + float l = m.m22 * m.m33 - m.m23 * m.m32; + float det = a * l - b * k + c * j + d * i - e * h + f * g; + det = 1.0f / det; + mat4 r; + r.m00 = ( m.m11 * l - m.m12 * k + m.m13 * j) * det; + r.m01 = (-m.m01 * l + m.m02 * k - m.m03 * j) * det; + r.m02 = ( m.m31 * f - m.m32 * e + m.m33 * d) * det; + r.m03 = (-m.m21 * f + m.m22 * e - m.m23 * d) * det; + r.m10 = (-m.m10 * l + m.m12 * i - m.m13 * h) * det; + r.m11 = ( m.m00 * l - m.m02 * i + m.m03 * h) * det; + r.m12 = (-m.m30 * f + m.m32 * c - m.m33 * b) * det; + r.m13 = ( m.m20 * f - m.m22 * c + m.m23 * b) * det; + r.m20 = ( m.m10 * k - 
m.m11 * i + m.m13 * g) * det; + r.m21 = (-m.m00 * k + m.m01 * i - m.m03 * g) * det; + r.m22 = ( m.m30 * e - m.m31 * c + m.m33 * a) * det; + r.m23 = (-m.m20 * e + m.m21 * c - m.m23 * a) * det; + r.m30 = (-m.m10 * j + m.m11 * h - m.m12 * g) * det; + r.m31 = ( m.m00 * j - m.m01 * h + m.m02 * g) * det; + r.m32 = (-m.m30 * d + m.m31 * b - m.m32 * a) * det; + r.m33 = ( m.m20 * d - m.m21 * b + m.m22 * a) * det; + return r; +} + +/*mat4 perspective_mat4(float a, float fov, float n, float f) { + float pi = M_PI; + float s = 1.0f / tanf(fov * 0.5f * (pi / 180.0f)); + return (mat4) { + s / a, 0, 0, 0, + 0, s, 0, 0, + 0, 0, -f / (f - n), -1.f, + 0, 0, - (f * n) / (f - n), 0 + }; +}*/ + +static inline mat4 translate_mat4(vec3 offset) { + mat4 m = identity_mat4; + m.m30 = offset.x; + m.m31 = offset.y; + m.m32 = offset.z; + return m; +} + +/*mat4 rotate_axis_mat4(unsigned int axis, float f) { + mat4 m = { 0 }; + m.m33 = 1; + + unsigned int t = (axis + 2) % 3; + unsigned int s = (axis + 1) % 3; + + m.rows[t].arr[t] = cosf(f); + m.rows[t].arr[s] = -sinf(f); + m.rows[s].arr[t] = sinf(f); + m.rows[s].arr[s] = cosf(f); + + // leave that unchanged + m.rows[axis].arr[axis] = 1; + + return m; +}*/ + +static inline mat4 mul_mat4(mat4 l, mat4 r) { + mat4 dst = { 0 }; +#define a(i, j) m##i##j +#define t(bc, br, i) l.a(i, br) * r.a(bc, i) +#define e(bc, br) dst.a(bc, br) = t(bc, br, 0) + t(bc, br, 1) + t(bc, br, 2) + t(bc, br, 3); +#define row(c) e(c, 0) e(c, 1) e(c, 2) e(c, 3) +#define genmul() row(0) row(1) row(2) row(3) + genmul() + return dst; +#undef a +#undef t +#undef e +#undef row +#undef genmul +} + +static inline vec4 mul_mat4_vec4f(mat4 l, vec4 r) { + float src[4] = { r.x, r.y, r.z, r.w }; + float dst[4]; +#define a(i, j) m##i##j +#define t(bc, br, i) l.a(i, br) * src[i] +#define e(bc, br) dst[br] = t(bc, br, 0) + t(bc, br, 1) + t(bc, br, 2) + t(bc, br, 3); +#define row(c) e(c, 0) e(c, 1) e(c, 2) e(c, 3) +#define genmul() row(0) + genmul() + return (vec4) { dst[0], dst[1], dst[2], 
dst[3] }; +} + +typedef union { + struct { + // we use row-major ordering + float m00, m01, m02, + m10, m11, m12, + m20, m21, m22; + }; + //vec4 rows[4]; + float arr[9]; +} Mat3f; + +static const Mat3f identity_mat3f = { + 1, 0, 0, + 0, 1, 0, + 0, 0, 1, +}; + +static Mat3f transpose_mat3f(Mat3f src) { + return (Mat3f) { + src.m00, src.m10, src.m20, + src.m01, src.m11, src.m21, + src.m02, src.m12, src.m22, + }; +} + +static Mat3f mul_mat3f(Mat3f l, Mat3f r) { + Mat3f dst = { 0 }; +#define a(i, j) m##i##j +#define t(bc, br, i) l.a(i, br) * r.a(bc, i) +#define e(bc, br) dst.a(bc, br) = t(bc, br, 0) + t(bc, br, 1) + t(bc, br, 2); +#define row(c) e(c, 0) e(c, 1) e(c, 2) +#define genmul() row(0) row(1) row(2) + genmul() + return dst; +#undef a +#undef t +#undef e +#undef row +#undef genmul +} + +typedef Mat3f mat3; + +#if defined(__cplusplus) & !defined(SHADY_CPP_NO_NAMESPACE) +} +#endif diff --git a/vcc-std/include/shady_vec.h b/vcc-std/include/shady_vec.h new file mode 100644 index 000000000..5ecd33ad8 --- /dev/null +++ b/vcc-std/include/shady_vec.h @@ -0,0 +1,302 @@ +#if defined(__cplusplus) & !defined(SHADY_CPP_NO_NAMESPACE) +namespace vcc { +#endif + +#if defined(__cplusplus) & !defined(SHADY_CPP_NO_WRAPPER_CLASSES) +#define SHADY_ENABLE_WRAPPER_CLASSES +#endif + +#ifdef __clang__ +typedef float native_vec4 __attribute__((ext_vector_type(4))); +typedef float native_vec3 __attribute__((ext_vector_type(3))); +typedef float native_vec2 __attribute__((ext_vector_type(2))); + +typedef int native_ivec4 __attribute__((ext_vector_type(4))); +typedef int native_ivec3 __attribute__((ext_vector_type(3))); +typedef int native_ivec2 __attribute__((ext_vector_type(2))); + +typedef unsigned native_uvec4 __attribute__((ext_vector_type(4))); +typedef unsigned native_uvec3 __attribute__((ext_vector_type(3))); +typedef unsigned native_uvec2 __attribute__((ext_vector_type(2))); +#else +// gcc can't cope with this +typedef float native_vec4 __attribute__((vector_size(16))); +typedef 
float native_vec3 __attribute__((vector_size(16))); +typedef float native_vec2 __attribute__((vector_size(8))); + +typedef int native_ivec4 __attribute__((vector_size(16))); +typedef int native_ivec3 __attribute__((vector_size(16))); +typedef int native_ivec2 __attribute__((vector_size(8))); + +typedef unsigned native_uvec4 __attribute__((vector_size(16))); +typedef unsigned native_uvec3 __attribute__((vector_size(16))); +typedef unsigned native_uvec2 __attribute__((vector_size(8))); +#endif + +#ifdef SHADY_ENABLE_WRAPPER_CLASSES +static_assert(__cplusplus >= 202002L, "C++20 is required"); +template +struct vec_native_type {}; + +template<> struct vec_native_type { using Native = native_vec4; }; +template<> struct vec_native_type { using Native = native_ivec4; }; +template<> struct vec_native_type { using Native = native_uvec4; }; +template<> struct vec_native_type { using Native = native_vec3; }; +template<> struct vec_native_type { using Native = native_ivec3; }; +template<> struct vec_native_type { using Native = native_uvec3; }; +template<> struct vec_native_type { using Native = native_vec2; }; +template<> struct vec_native_type { using Native = native_ivec2; }; +template<> struct vec_native_type { using Native = native_uvec2; }; +template<> struct vec_native_type { using Native = float; }; +template<> struct vec_native_type { using Native = int; }; +template<> struct vec_native_type { using Native = unsigned; }; + +template +struct Mapping { + int data[len]; +}; + +template +static consteval bool fits(unsigned len, Mapping mapping) { + for (unsigned i = 0; i < dst_len; i++) { + if (mapping.data[i] >= len) + return false; + } + return true; +} + +template +constexpr void for_range(F f) +{ + if constexpr (B < E) + { + f.template operator()(); + for_range((f)); + } +} + +template +struct vec_impl { + using This = vec_impl; + using Native = typename vec_native_type::Native; + + vec_impl() = default; + vec_impl(T s) { + for_range<0, len>([&](){ + arr[i] = s; + }); 
+ } + + vec_impl(T x, T y, T z, T w) requires (len >= 4) { + this->arr[0] = x; + this->arr[1] = y; + this->arr[2] = z; + this->arr[3] = w; + } + + vec_impl(vec_impl xy, T z, T w) requires (len >= 4) : vec_impl(xy.x, xy.y, z, w) {} + vec_impl(T x, vec_impl yz, T w) requires (len >= 4) : vec_impl(x, yz.x, yz.y, w) {} + vec_impl(T x, T y, vec_impl zw) requires (len >= 4) : vec_impl(x, y, zw.x, zw.y) {} + vec_impl(vec_impl xy, vec_impl zw) requires (len >= 4) : vec_impl(xy.x, xy.y, zw.x, zw.y) {} + + vec_impl(vec_impl xyz, T w) requires (len >= 4) : vec_impl(xyz.x, xyz.y, xyz.z, w) {} + vec_impl(T x, vec_impl yzw) requires (len >= 4) : vec_impl(x, yzw.x, yzw.y, yzw.z) {} + + vec_impl(T x, T y, T z) requires (len >= 3) { + this->arr[0] = x; + this->arr[1] = y; + this->arr[2] = z; + } + + vec_impl(vec_impl xy, T z) requires (len >= 3) : vec_impl(xy.x, xy.y, z) {} + vec_impl(T x, vec_impl yz) requires (len >= 3) : vec_impl(x, yz.x, yz.y) {} + + vec_impl(T x, T y) { + this->arr[0] = x; + this->arr[1] = y; + } + + vec_impl(Native n) { + for_range<0, len>([&](){ + arr[i] = n[i]; + }); + } + + operator Native() const { + Native n; + for_range<0, len>([&](){ + n[i] = arr[i]; + }); + return n; + } + + template mapping> // requires(fits(len, mapping)) + struct Swizzler { + using That = vec_impl; + using ThatNative = typename vec_native_type::Native; + + operator That() const requires(dst_len > 1 && fits(len, mapping)) { + auto src = reinterpret_cast(this); + That dst; + for_range<0, dst_len>([&](){ + dst.arr[i] = src->arr[mapping.data[i]]; + }); + return dst; + } + + operator ThatNative() const requires(dst_len > 1 && fits(len, mapping)) { + That that = *this; + return that; + } + + operator T() const requires(dst_len == 1 && fits(len, mapping)) { + auto src = reinterpret_cast(this); + return src->arr[mapping.data[0]]; + } + + void operator=(const T& t) requires(dst_len == 1 && fits(len, mapping)) { + auto src = reinterpret_cast(this); + src->arr[mapping.data[0]] = t; + } + + 
void operator=(const That& src) requires(dst_len > 1 && fits(len, mapping)) { + auto dst = reinterpret_cast(this); + for_range<0, dst_len>([&](){ + dst->arr[mapping.data[i]] = src.arr[i]; + }); + } + }; + + This operator +(This other) { + This result; + for_range<0, len>([&](){ + result.arr[i] = this->arr[i] + other.arr[i]; + }); + return result; + } + This operator -(This other) { + This result; + for_range<0, len>([&](){ + result.arr[i] = this->arr[i] - other.arr[i]; + }); + return result; + } + This operator *(This other) { + This result; + for_range<0, len>([&](){ + result.arr[i] = this->arr[i] * other.arr[i]; + }); + return result; + } + This operator /(This other) { + This result; + for_range<0, len>([&](){ + result.arr[i] = this->arr[i] / other.arr[i]; + }); + return result; + } + This operator *(T s) { + This result; + for_range<0, len>([&](){ + result.arr[i] = this->arr[i] * s; + }); + return result; + } + This operator /(T s) { + This result; + for_range<0, len>([&](){ + result.arr[i] = this->arr[i] / s; + }); + return result; + } + +#define COMPONENT_0 x +#define COMPONENT_1 y +#define COMPONENT_2 z +#define COMPONENT_3 w + +#define CONCAT_4_(a, b, c, d) a##b##c##d +#define CONCAT_4(a, b, c, d) CONCAT_4_(a, b, c, d) +#define SWIZZLER_4(a, b, c, d) Swizzler<4, Mapping<4> { a, b, c, d }> CONCAT_4(COMPONENT_##a, COMPONENT_##b, COMPONENT_##c, COMPONENT_##d); +#define GEN_SWIZZLERS_4_D(D, C, B, A) SWIZZLER_4(A, B, C, D) +#define GEN_SWIZZLERS_4_C(C, B, A) GEN_SWIZZLERS_4_D(0, C, B, A) GEN_SWIZZLERS_4_D(1, C, B, A) GEN_SWIZZLERS_4_D(2, C, B, A) GEN_SWIZZLERS_4_D(3, C, B, A) +#define GEN_SWIZZLERS_4_B(B, A) GEN_SWIZZLERS_4_C(0, B, A) GEN_SWIZZLERS_4_C(1, B, A) GEN_SWIZZLERS_4_C(2, B, A) GEN_SWIZZLERS_4_C(3, B, A) +#define GEN_SWIZZLERS_4_A(A) GEN_SWIZZLERS_4_B(0, A) GEN_SWIZZLERS_4_B(1, A) GEN_SWIZZLERS_4_B(2, A) GEN_SWIZZLERS_4_B(3, A) +#define GEN_SWIZZLERS_4() GEN_SWIZZLERS_4_A(0) GEN_SWIZZLERS_4_A(1) GEN_SWIZZLERS_4_A(2) GEN_SWIZZLERS_4_A(3) + +#define 
CONCAT_3_(a, b, c) a##b##c +#define CONCAT_3(a, b, c) CONCAT_3_(a, b, c) +#define SWIZZLER_3(a, b, c) Swizzler<3, Mapping<3> { a, b, c }> CONCAT_3(COMPONENT_##a, COMPONENT_##b, COMPONENT_##c); +#define GEN_SWIZZLERS_3_C(C, B, A) SWIZZLER_3(A, B, C) +#define GEN_SWIZZLERS_3_B(B, A) GEN_SWIZZLERS_3_C(0, B, A) GEN_SWIZZLERS_3_C(1, B, A) GEN_SWIZZLERS_3_C(2, B, A) GEN_SWIZZLERS_3_C(3, B, A) +#define GEN_SWIZZLERS_3_A(A) GEN_SWIZZLERS_3_B(0, A) GEN_SWIZZLERS_3_B(1, A) GEN_SWIZZLERS_3_B(2, A) GEN_SWIZZLERS_3_B(3, A) +#define GEN_SWIZZLERS_3() GEN_SWIZZLERS_3_A(0) GEN_SWIZZLERS_3_A(1) GEN_SWIZZLERS_3_A(2) GEN_SWIZZLERS_3_A(3) + +#define CONCAT_2_(a, b) a##b +#define CONCAT_2(a, b) CONCAT_2_(a, b) +#define SWIZZLER_2(a, b) Swizzler<2, Mapping<2> { a, b }> CONCAT_2(COMPONENT_##a, COMPONENT_##b); +#define GEN_SWIZZLERS_2_B(B, A) SWIZZLER_2(A, B) +#define GEN_SWIZZLERS_2_A(A) GEN_SWIZZLERS_2_B(0, A) GEN_SWIZZLERS_2_B(1, A) GEN_SWIZZLERS_2_B(2, A) GEN_SWIZZLERS_2_B(3, A) +#define GEN_SWIZZLERS_2() GEN_SWIZZLERS_2_A(0) GEN_SWIZZLERS_2_A(1) GEN_SWIZZLERS_2_A(2) GEN_SWIZZLERS_2_A(3) + +#define SWIZZLER_1(a) Swizzler<1, Mapping<1> { a }> COMPONENT_##a; +#define GEN_SWIZZLERS_1_A(A) SWIZZLER_1(A) +#define GEN_SWIZZLERS_1() GEN_SWIZZLERS_1_A(0) GEN_SWIZZLERS_1_A(1) GEN_SWIZZLERS_1_A(2) GEN_SWIZZLERS_1_A(3) + + union { + GEN_SWIZZLERS_1() + GEN_SWIZZLERS_2() + GEN_SWIZZLERS_3() + GEN_SWIZZLERS_4() + T arr[len]; + }; + + static_assert(sizeof(T) * len == sizeof(arr)); +}; + +typedef vec_impl vec4; +typedef vec_impl uvec4; +typedef vec_impl ivec4; + +typedef vec_impl vec3; +typedef vec_impl uvec3; +typedef vec_impl ivec3; + +typedef vec_impl vec2; +typedef vec_impl uvec2; +typedef vec_impl ivec2; + +template +float lengthSquared(vec_impl vec) { + float acc = 0.0f; + for_range<0, len>([&](){ + acc += vec.arr[i] * vec.arr[i]; + }); + return acc; +} + +template +float length(vec_impl vec) { + return sqrtf(lengthSquared(vec)); +} + +template +vec_impl normalize(vec_impl vec) { + return vec 
/ length(vec); +} + +#else +typedef native_vec4 vec4; +typedef native_vec3 vec3; +typedef native_vec2 vec2; +typedef native_ivec4 ivec4; +typedef native_ivec3 ivec3; +typedef native_ivec2 ivec2; +typedef native_uvec4 uvec4; +typedef native_uvec3 uvec3; +typedef native_uvec2 uvec2; +#endif + +#if defined(__cplusplus) & !defined(SHADY_CPP_NO_NAMESPACE) +} +#endif diff --git a/vcc-std/src/test_vec.cpp b/vcc-std/src/test_vec.cpp new file mode 100644 index 000000000..9ba6ba556 --- /dev/null +++ b/vcc-std/src/test_vec.cpp @@ -0,0 +1,66 @@ +#define __SHADY__ + +#include +#include "shady_vec.h" + +using namespace vcc; + +void check_native_casts(const vec4& v4, const uvec4& u4, const ivec4& i4) { + native_vec4 nv4 = v4; + native_uvec4 nu4 = u4; + native_ivec4 ni4 = i4; + vec4 rv4 = nv4; + native_vec3 nv3; + nv3 = v4.xyz; +} + +void check_vector_scalar_ctors() { + vec4 x4 = vec4(0.5f); + vec4 y4 = { 0.5f }; + vec4 z4(0.5f); + vec4 w4 = 0.5f; + + vec3 x3 = vec3(0.5f); + vec3 y3 = { 0.5f }; + vec3 z3(0.5f); + vec3 w3 = 0.5f; + + vec2 x2 = vec2(0.5f); + vec2 y2 = { 0.5f }; + vec2 z2(0.5f); + vec2 w2 = 0.5f; +} + +void check_swizzle_const(const vec4& v4, const uvec4& u4, const ivec4& i4) { + v4.x; + v4.xy; + v4.xyz; + v4.xyzw; + + v4.xxxx; + v4.xyww; +} + +void check_ctor_weird() { + vec4(vec2(0.5f), vec2(0.5f)); + vec4(0.5f, vec2(0.5f), 0.5f); + vec4(0.5f, vec3(0.5f)); + vec4(vec3(0.5f), 0.5f); +} + +void check_swizzle_mut(vec4& v) { + v.x = 0.5f; + v.xy = vec2(0.5f, 0.9f); +} + +#include +#include +int main(int argc, char** argv) { + vec4 v(1.0f, 0.5f, 0.0f, -1.0f); + float f; + f = v.x; printf("f = %f;\n", f); assert(f == 1.0f); + f = v.y; printf("f = %f;\n", f); assert(f == 0.5f); + f = v.z; printf("f = %f;\n", f); assert(f == 0.0f); + f = v.w; printf("f = %f;\n", f); assert(f == -1.0f); + std::unique_ptr uptr; +} \ No newline at end of file diff --git a/vcc/CMakeLists.txt b/vcc/CMakeLists.txt new file mode 100644 index 000000000..6973e0806 --- /dev/null +++ 
b/vcc/CMakeLists.txt @@ -0,0 +1,25 @@ +if (NOT TARGET shady_fe_llvm) + message("LLVM front-end unavailable. Skipping Vcc.") +else() + option(SHADY_ENABLE_VCC "Allows compiling C and C++ code with Shady." ON) +endif() + +if (SHADY_ENABLE_VCC) + set (VCC_CLANG_EXECUTABLE_NAME "clang" CACHE STRING "What 'clang' executable Vcc should call into") + + add_library(vcc_lib STATIC vcc_lib.c) + target_link_libraries(vcc_lib PUBLIC driver api) + + add_executable(vcc vcc.c) + target_compile_definitions(vcc_lib PRIVATE "VCC_CLANG_EXECUTABLE_NAME=${VCC_CLANG_EXECUTABLE_NAME}") + target_link_libraries(vcc PRIVATE api vcc_lib) + install(TARGETS vcc_lib vcc EXPORT shady_export_set) + + if (WIN32) + add_custom_command(TARGET vcc POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy -t $ $ + COMMAND_EXPAND_LISTS + ) + endif () + #message("Vcc will be built together with shady") +endif () \ No newline at end of file diff --git a/vcc/vcc.c b/vcc/vcc.c new file mode 100644 index 000000000..fe2587a1f --- /dev/null +++ b/vcc/vcc.c @@ -0,0 +1,47 @@ +#include "vcc/driver.h" + +#include "shady/ir/arena.h" +#include "shady/ir/module.h" + +#include "log.h" +#include "list.h" +#include "util.h" +#include "portability.h" + +int main(int argc, char** argv) { + shd_platform_specific_terminal_init_extras(); + + DriverConfig args = shd_default_driver_config(); + VccConfig vcc_options = vcc_init_config(&args.config); + shd_parse_driver_args(&args, &argc, argv); + shd_parse_common_args(&argc, argv); + shd_parse_compiler_config_args(&args.config, &argc, argv); + cli_parse_vcc_args(&vcc_options, &argc, argv); + shd_driver_parse_input_files(args.input_filenames, &argc, argv); + + if (shd_list_count(args.input_filenames) == 0) { + shd_error_print("Missing input file. 
See --help for proper usage"); + exit(MissingInputArg); + } + + ArenaConfig aconfig = shd_default_arena_config(&args.config.target); + IrArena* arena = shd_new_ir_arena(&aconfig); + + vcc_check_clang(); + + if (vcc_options.only_run_clang) + vcc_options.tmp_filename = shd_format_string_new("%s", args.output_filename); + vcc_run_clang(&vcc_options, shd_list_count(args.input_filenames), shd_read_list(String, args.input_filenames)); + + if (!vcc_options.only_run_clang) { + Module* mod = vcc_parse_back_into_module(&args.config, &vcc_options, "my_module"); + shd_driver_compile(&args, mod); + shd_destroy_ir_arena(shd_module_get_arena(mod)); + } + + shd_info_print("Done\n"); + + destroy_vcc_options(vcc_options); + shd_destroy_ir_arena(arena); + shd_destroy_driver_config(&args); +} diff --git a/vcc/vcc_lib.c b/vcc/vcc_lib.c new file mode 100644 index 000000000..e15504fe8 --- /dev/null +++ b/vcc/vcc_lib.c @@ -0,0 +1,156 @@ +#include "vcc/driver.h" + +#include "log.h" +#include "list.h" +#include "util.h" +#include "growy.h" +#include "portability.h" + +#include +#include + +#define STRINGIFY2(x) #x +#define STRINGIFY(x) STRINGIFY2(x) +#define VCC_CLANG STRINGIFY(VCC_CLANG_EXECUTABLE_NAME) + +uint32_t shd_hash(const void* data, size_t size); + +void cli_parse_vcc_args(VccConfig* options, int* pargc, char** argv) { + int argc = *pargc; + + for (int i = 1; i < argc; i++) { + if (argv[i] == NULL) + continue; + else if (strcmp(argv[i], "--vcc-keep-tmp-file") == 0) { + argv[i] = NULL; + options->delete_tmp_file = false; + options->tmp_filename = shd_format_string_new("vcc_tmp.ll"); + continue; + } else if (strcmp(argv[i], "--vcc-include-path") == 0) { + argv[i] = NULL; + i++; + if (i == argc) + shd_error("Missing subgroup size name"); + if (options->include_path) + free((void*) options->include_path); + options->include_path = shd_format_string_new("%s", argv[i]); + continue; + } else if (strcmp(argv[i], "--only-run-clang") == 0) { + argv[i] = NULL; + options->only_run_clang = 
true; + continue; + } + } + + shd_pack_remaining_args(pargc, argv); +} + +void vcc_check_clang(void) { + int clang_retval = system(VCC_CLANG" --version"); + if (clang_retval != 0) + shd_error("clang not present in path or otherwise broken (retval=%d)", clang_retval); +} + +VccConfig vcc_init_config(CompilerConfig* compiler_config) { + VccConfig vcc_config = { + .only_run_clang = false, + }; + + // magic! + //compiler_config->hacks.recover_structure = true; + compiler_config->input_cf.add_scope_annotations = true; + compiler_config->input_cf.has_scope_annotations = true; + + String self_path = shd_get_executable_location(); + String working_dir = shd_strip_path(self_path); + if (!vcc_config.include_path) { + vcc_config.include_path = shd_format_string_new("%s/../share/vcc/include/", working_dir); + } + free((void*) working_dir); + free((void*) self_path); + return vcc_config; +} + +void destroy_vcc_options(VccConfig vcc_options) { + if (vcc_options.include_path) + free((void*) vcc_options.include_path); + if (vcc_options.tmp_filename) + free((void*) vcc_options.tmp_filename); +} + +void vcc_run_clang(VccConfig* vcc_options, size_t num_source_files, String* input_filenames) { + Growy* g = shd_new_growy(); + shd_growy_append_string(g, VCC_CLANG); + String self_path = shd_get_executable_location(); + String working_dir = shd_strip_path(self_path); + shd_growy_append_formatted(g, " -c -emit-llvm -S -g -O0 -ffreestanding -Wno-main-return-type -Xclang -fpreserve-vec3-type --target=spir64-unknown-unknown -isystem\"%s\" -D__SHADY__=1", vcc_options->include_path); + free((void*) working_dir); + free((void*) self_path); + + if (!vcc_options->tmp_filename) { + if (vcc_options->only_run_clang) { + shd_error_print("Please provide an output filename.\n"); + shd_error_die(); + } + char* tmp_alloc; + vcc_options->tmp_filename = tmp_alloc = malloc(33); + tmp_alloc[32] = '\0'; + uint32_t hash = 0; + for (size_t i = 0; i < num_source_files; i++) { + String filename = 
input_filenames[i]; + hash ^= shd_hash(filename, strlen(filename)); + } + srand(hash); + for (size_t i = 0; i < 32; i++) { + tmp_alloc[i] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"[rand() % (10 + 26 * 2)]; + } + } + shd_growy_append_formatted(g, " -o %s", vcc_options->tmp_filename); + + for (size_t i = 0; i < num_source_files; i++) { + String filename = input_filenames[i]; + + shd_growy_append_string(g, " \""); + shd_growy_append_bytes(g, strlen(filename), filename); + shd_growy_append_string(g, "\""); + } + + shd_growy_append_bytes(g, 1, "\0"); + char* arg_string = shd_growy_deconstruct(g); + + shd_info_print("built command: %s\n", arg_string); + + FILE* stream = popen(arg_string, "r"); + free(arg_string); + + Growy* json_bytes = shd_new_growy(); + while (true) { + char buf[4096]; + int read = fread(buf, 1, sizeof(buf), stream); + if (read == 0) + break; + shd_growy_append_bytes(json_bytes, read, buf); + } + shd_growy_append_string(json_bytes, "\0"); + char* llvm_result = shd_growy_deconstruct(json_bytes); + int clang_returned = pclose(stream); + shd_info_print("Clang returned %d and replied: \n%s", clang_returned, llvm_result); + free(llvm_result); + if (clang_returned) + exit(ClangInvocationFailed); +} + +Module* vcc_parse_back_into_module(CompilerConfig* config, VccConfig* vcc_options, String module_name) { + size_t len; + char* llvm_ir; + if (!shd_read_file(vcc_options->tmp_filename, &len, &llvm_ir)) + exit(InputFileIOError); + Module* mod; + shd_driver_load_source_file(config, SrcLLVM, len, llvm_ir, module_name, &mod); + free(llvm_ir); + + if (vcc_options->delete_tmp_file) + remove(vcc_options->tmp_filename); + + return mod; +} \ No newline at end of file diff --git a/zhady/CMakeLists.txt b/zhady/CMakeLists.txt new file mode 100644 index 000000000..6ad881b34 --- /dev/null +++ b/zhady/CMakeLists.txt @@ -0,0 +1,51 @@ +find_package(SWIG) +find_package(JNI) +find_package(Java) + +if (NOT SWIG_FOUND) + message("SWIG not found. 
Skipping Java bindings.") +elseif (NOT SWIG_FOUND) + message("JNI not found. Skipping Java bindings.") +elseif (NOT SWIG_FOUND) + message("Java not found. Skipping Java bindings.") +else() + option(SHADY_ENABLE_JAVA_BINDINGS "Allows using shady (and if enabled, Vcc) in JVM applications" ON) +endif() + +if (SHADY_ENABLE_JAVA_BINDINGS) + message("Enabling Java bindings") + + add_custom_target(zhady_dir COMMAND cmake -E ${CMAKE_CURRENT_BINARY_DIR}/java_sources/de/unisaarland/zhady) + + include(UseSWIG) + swig_add_library(zhady_shared_lib TYPE SHARED LANGUAGE java SOURCES shady.i OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/java_sources/de/unisaarland/zhady) + set_property(TARGET zhady_shared_lib PROPERTY SWIG_USE_TARGET_INCLUDE_DIRECTORIES TRUE) + set_property(TARGET zhady_shared_lib PROPERTY SWIG_COMPILE_OPTIONS -package de.unisaarland.zhady) + target_link_libraries(zhady_shared_lib PUBLIC JNI::JNI api driver runtime) + + if (TARGET vcc_lib) + swig_add_library(vcc_swig_c TYPE OBJECT LANGUAGE java SOURCES vcc.i OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/java_sources/de/unisaarland/zhady) + set_property(TARGET vcc_swig_c PROPERTY SWIG_USE_TARGET_INCLUDE_DIRECTORIES TRUE) + set_property(TARGET vcc_swig_c PROPERTY SWIG_COMPILE_OPTIONS -package de.unisaarland.zhady) + target_link_libraries(vcc_swig_c PUBLIC JNI::JNI api vcc_lib) + target_link_libraries(zhady_shared_lib PRIVATE vcc_swig_c vcc_lib) + endif () + + include(UseJava) + get_property(zhady_jar_sources TARGET zhady_shared_lib PROPERTY SWIG_SUPPORT_FILES) + get_property(zhady_jar_sources_dir TARGET zhady_shared_lib PROPERTY SWIG_SUPPORT_FILES_DIRECTORY) + #set(CMAKE_JAVA_COMPILE_FLAGS ${CMAKE_CURRENT_BINARY_DIR}/java_sources/de/unisaarland/zhady/*.java) + #set(CMAKE_JAVA_COMPILE_FLAGS "-sourcepath" "${CMAKE_CURRENT_BINARY_DIR}/java_sources/") + + if (TARGET vcc_lib) + get_property(vcc_java_sources TARGET vcc_swig_c PROPERTY SWIG_SUPPORT_FILES) + list(APPEND zhady_jar_sources ${vcc_java_sources}) + endif () + + # 
message("Zhady sources: ${zhady_jar_sources}") + # message("Zhady sources dir: ${zhady_jar_sources_dir}") + set(CMAKE_JAVA_INCLUDE_PATH ${CMAKE_CURRENT_BINARY_DIR}/java_sources/) + add_jar(zhady_jar SOURCES ${zhady_jar_sources}) + + install(TARGETS zhady_shared_lib EXPORT shady_export_set) +endif () diff --git a/zhady/common.i b/zhady/common.i new file mode 100644 index 000000000..3b5cef7ed --- /dev/null +++ b/zhady/common.i @@ -0,0 +1,3 @@ +SWIG_JAVABODY_METHODS(public, public, SWIGTYPE) +SWIG_JAVABODY_PROXY(public, public, SWIGTYPE) +SWIG_JAVABODY_TYPEWRAPPER(public, public, public, SWIGTYPE) diff --git a/zhady/shady.i b/zhady/shady.i new file mode 100644 index 000000000..58d910071 --- /dev/null +++ b/zhady/shady.i @@ -0,0 +1,38 @@ +%include "common.i" +%include +//%apply int { _Bool }; + +%module shady +%{ +#include "shady/ir.h" + +* bb + +* bb + +#include "shady/runtime.h" + +* d + +* r + +* runtime + +#include "shady/driver.h" + +* mod + +#include "shady/config.h" +#include "shady/be/c.h" +#include "shady/be/spirv.h" +#include "shady/be/dump.h" +%} + +%include "shady/ir.h" +%include "grammar_generated.h" +%include "shady/driver.h" +%include "shady/runtime.h" +%include "shady/config.h" +%include "shady/be/c.h" +%include "shady/be/spirv.h" +%include "shady/be/dump.h" \ No newline at end of file diff --git a/zhady/vcc.i b/zhady/vcc.i new file mode 100644 index 000000000..ad84606aa --- /dev/null +++ b/zhady/vcc.i @@ -0,0 +1,8 @@ +%include common.i + +%module vcc +%{ +#include "vcc/driver.h" +%} + +%include "vcc/driver.h"