Skip to content

Commit

Permalink
Fix memory problems with fruit injection
Browse files Browse the repository at this point in the history
The bindInstance must outlive the Injector
  • Loading branch information
netheril96 committed Mar 12, 2024
1 parent 0afb992 commit dcb8da6
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 119 deletions.
206 changes: 104 additions & 102 deletions sources/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cryptopp/secblock.h>
#include <fruit/component.h>
#include <fruit/fruit.h>
#include <fruit/fruit_forward_decls.h>
#include <json/json.h>
#include <optional>
#include <tclap/CmdLine.h>
Expand Down Expand Up @@ -1061,6 +1062,7 @@ class MountCommand : public _SinglePasswordCommandBase
false};

FSConfig config{};
lite_format::NameNormalizationFlags name_norm_flags{};

private:
std::vector<const char*> to_c_style_args(const std::vector<std::string>& args)
Expand Down Expand Up @@ -1103,124 +1105,125 @@ class MountCommand : public _SinglePasswordCommandBase

fruit::Component<FuseHighLevelOpsBase> get_fuse_high_ops_component()
{
auto name_norm_flags = std::make_shared<lite_format::NameNormalizationFlags>();
if (plain_text_names.getValue())
{
name_norm_flags->no_op = true;
name_norm_flags.no_op = true;
}
if (normalization.getValue() == "nfc")
{
name_norm_flags->should_normalize_nfc = true;
name_norm_flags.should_normalize_nfc = true;
}
else if (normalization.getValue() == "casefold")
{
name_norm_flags->should_case_fold = true;
name_norm_flags.should_case_fold = true;
}
else if (normalization.getValue() == "casefold+nfc")
{
name_norm_flags->should_normalize_nfc = true;
name_norm_flags->should_case_fold = true;
name_norm_flags.should_normalize_nfc = true;
name_norm_flags.should_case_fold = true;
}
else if (normalization.getValue() != "none")
{
throw_runtime_error("Invalid flag of --normalization: " + normalization.getValue());
}
name_norm_flags->supports_long_name = config.long_name_component;

auto partial
= fruit::createComponent()
.bindInstance(*this)
.install(::securefs::lite_format::get_name_translator_component, name_norm_flags)
.install(full_format::get_table_io_component, config.version)
.registerProvider<fruit::Annotated<tSkipVerification, bool>(const MountCommand&)>(
[](const MountCommand& cmd) { return cmd.insecure.getValue(); })
.registerProvider<fruit::Annotated<tVerify, bool>(const MountCommand&)>(
[](const MountCommand& cmd) { return !cmd.insecure.getValue(); })
.registerProvider<fruit::Annotated<tStoreTimeWithinFs, bool>(
const MountCommand&)>([](const MountCommand& cmd)
{ return cmd.config.version == 3; })
.registerProvider<fruit::Annotated<tReadOnly, bool>(const MountCommand&)>(
[](const MountCommand& cmd)
{
// TODO: Support readonly mounts.
return false;
})
.registerProvider(
[]() { return new BS::thread_pool(std::thread::hardware_concurrency() * 2); })
.bind<Directory, BtreeDirectory>()
.bindInstance<fruit::Annotated<tMaxPaddingSize, unsigned>>(config.max_padding)
.bindInstance<fruit::Annotated<tIvSize, unsigned>>(config.iv_size)
.bindInstance<fruit::Annotated<tBlockSize, unsigned>>(config.block_size)
.bindInstance<fruit::Annotated<tMasterKey, CryptoPP::AlignedSecByteBlock>>(
config.master_key)
.registerProvider<fruit::Annotated<tMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key)
{
if (master_key.size() != key_type::size())
{
throw_runtime_error("Master key size mismatch");
}
return key_type(master_key.data(), key_type::size());
})
.registerProvider<fruit::Annotated<tNameMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key)
{
if (master_key.size() < key_type::size())
{
throw_runtime_error("Master key too short");
}
return key_type(master_key.data(), key_type::size());
})
.registerProvider<fruit::Annotated<tContentMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key)
{
if (master_key.size() < key_type::size() * 2)
{
throw_runtime_error("Master key too short");
}
return key_type(master_key.data() + key_type::size(), key_type::size());
})
.registerProvider<fruit::Annotated<tXattrMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key)
{
if (master_key.size() < 3 * key_type::size())
{
throw_runtime_error("Master key too short");
}
return key_type(master_key.data() + 2 * key_type::size(),
key_type::size());
})
.registerProvider<fruit::Annotated<tPaddingMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>,
fruit::Annotated<tMaxPaddingSize, const unsigned&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key,
const unsigned& max_padding)
{
if (max_padding <= 0)
{
return key_type{};
}
if (master_key.size() < 4 * key_type::size())
{
throw_runtime_error("Master key too short");
}
return key_type(master_key.data() + 3 * key_type::size(),
key_type::size());
})
.registerProvider([](const MountCommand& cmd)
{ return new OSService(cmd.data_dir.getValue()); });
if (config.version == 4)
{
return partial.bind<FuseHighLevelOpsBase, lite_format::FuseHighLevelOps>();
}
else
name_norm_flags.supports_long_name = config.long_name_component;

auto internal_binder = [](unsigned format_version)
-> fruit::Component<
fruit::Required<lite_format::FuseHighLevelOps, full_format::FuseHighLevelOps>,
FuseHighLevelOpsBase>
{
return partial.bind<FuseHighLevelOpsBase, full_format::FuseHighLevelOps>();
}
if (format_version < 4)
{
return fruit::createComponent()
.bind<FuseHighLevelOpsBase, full_format::FuseHighLevelOps>();
}
return fruit::createComponent()
.bind<FuseHighLevelOpsBase, lite_format::FuseHighLevelOps>();
};

return fruit::createComponent()
.bindInstance(*this)
.install(+internal_binder, config.version)
.install(::securefs::lite_format::get_name_translator_component, &name_norm_flags)
.install(full_format::get_table_io_component, config.version)
.registerProvider<fruit::Annotated<tSkipVerification, bool>(const MountCommand&)>(
[](const MountCommand& cmd) { return cmd.insecure.getValue(); })
.registerProvider<fruit::Annotated<tVerify, bool>(const MountCommand&)>(
[](const MountCommand& cmd) { return !cmd.insecure.getValue(); })
.registerProvider<fruit::Annotated<tStoreTimeWithinFs, bool>(const MountCommand&)>(
[](const MountCommand& cmd) { return cmd.config.version == 3; })
.registerProvider<fruit::Annotated<tReadOnly, bool>(const MountCommand&)>(
[](const MountCommand& cmd)
{
// TODO: Support readonly mounts.
return false;
})
.registerProvider(
[]() { return new BS::thread_pool(std::thread::hardware_concurrency() * 2); })
.bind<Directory, BtreeDirectory>()
.bindInstance<fruit::Annotated<tMaxPaddingSize, unsigned>>(config.max_padding)
.bindInstance<fruit::Annotated<tIvSize, unsigned>>(config.iv_size)
.bindInstance<fruit::Annotated<tBlockSize, unsigned>>(config.block_size)
.bindInstance<fruit::Annotated<tMasterKey, CryptoPP::AlignedSecByteBlock>>(
config.master_key)
.registerProvider<fruit::Annotated<tMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key)
{
if (master_key.size() != key_type::size())
{
throw_runtime_error("Master key size mismatch");
}
return key_type(master_key.data(), key_type::size());
})
.registerProvider<fruit::Annotated<tNameMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key)
{
if (master_key.size() < key_type::size())
{
throw_runtime_error("Master key too short");
}
return key_type(master_key.data(), key_type::size());
})
.registerProvider<fruit::Annotated<tContentMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key)
{
if (master_key.size() < key_type::size() * 2)
{
throw_runtime_error("Master key too short");
}
return key_type(master_key.data() + key_type::size(), key_type::size());
})
.registerProvider<fruit::Annotated<tXattrMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key)
{
if (master_key.size() < 3 * key_type::size())
{
throw_runtime_error("Master key too short");
}
return key_type(master_key.data() + 2 * key_type::size(), key_type::size());
})
.registerProvider<fruit::Annotated<tPaddingMasterKey, key_type>(
fruit::Annotated<tMasterKey, const CryptoPP::AlignedSecByteBlock&>,
fruit::Annotated<tMaxPaddingSize, const unsigned&>)>(
[](const CryptoPP::AlignedSecByteBlock& master_key, const unsigned& max_padding)
{
if (max_padding <= 0)
{
return key_type{};
}
if (master_key.size() < 4 * key_type::size())
{
throw_runtime_error("Master key too short");
}
return key_type(master_key.data() + 3 * key_type::size(), key_type::size());
})
.registerProvider([](const MountCommand& cmd)
{ return new OSService(cmd.data_dir.getValue()); });
}

public:
Expand Down Expand Up @@ -1475,7 +1478,6 @@ class MountCommand : public _SinglePasswordCommandBase
#endif
auto op = FuseHighLevelOpsBase::build_ops(native_xattr);
VERBOSE_LOG("Calling fuse_main with arguments: %s", escape_args(fuse_args));
recreate_logger();
return fuse_main(static_cast<int>(fuse_args.size()),
const_cast<char**>(to_c_style_args(fuse_args).data()),
&op,
Expand Down
4 changes: 3 additions & 1 deletion sources/lite_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include <absl/utility/utility.h>
#include <cryptopp/blake2.h>
#include <cryptopp/sha.h>
#include <fruit/component.h>
#include <fruit/fruit.h>
#include <fruit/fruit_forward_decls.h>
#include <uni_algo/case.h>
#include <uni_algo/norm.h>

Expand Down Expand Up @@ -1117,7 +1119,7 @@ namespace
} // namespace

fruit::Component<fruit::Required<fruit::Annotated<tNameMasterKey, key_type>>, NameTranslator>
get_name_translator_component(std::shared_ptr<NameNormalizationFlags> flags)
get_name_translator_component(const NameNormalizationFlags* flags)
{
if (flags->no_op)
{
Expand Down
2 changes: 1 addition & 1 deletion sources/lite_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ struct NameNormalizationFlags
};

fruit::Component<fruit::Required<fruit::Annotated<tNameMasterKey, key_type>>, NameTranslator>
get_name_translator_component(std::shared_ptr<NameNormalizationFlags> args);
get_name_translator_component(const NameNormalizationFlags* args);

class FuseHighLevelOps : public ::securefs::FuseHighLevelOpsBase
{
Expand Down
30 changes: 15 additions & 15 deletions test/test_lite_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,33 +62,33 @@ namespace

TEST_CASE("case folding name translator")
{
auto flags = std::make_shared<NameNormalizationFlags>();
flags->should_case_fold = true;
NameNormalizationFlags flags{};
flags.should_case_fold = true;
fruit::Injector<NameTranslator> injector(
+[](std::shared_ptr<NameNormalizationFlags> flags) -> fruit::Component<NameTranslator>
+[](const NameNormalizationFlags* flags) -> fruit::Component<NameTranslator>
{
return fruit::createComponent()
.install(get_name_translator_component, flags)
.install(get_test_component);
},
std::move(flags));
&flags);
auto t = injector.get<NameTranslator*>();
CHECK(t->encrypt_full_path(u8"/abCDe/ß", nullptr)
== t->encrypt_full_path(u8"/ABCde/ss", nullptr));
}

TEST_CASE("Unicode normalizing name translator")
{
auto flags = std::make_shared<NameNormalizationFlags>();
flags->should_normalize_nfc = true;
NameNormalizationFlags flags{};
flags.should_normalize_nfc = true;
fruit::Injector<NameTranslator> injector(
+[](std::shared_ptr<NameNormalizationFlags> flags) -> fruit::Component<NameTranslator>
+[](const NameNormalizationFlags* flags) -> fruit::Component<NameTranslator>
{
return fruit::createComponent()
.install(get_name_translator_component, flags)
.install(get_test_component);
},
std::move(flags));
&flags);
auto t = injector.get<NameTranslator*>();

CHECK(t->encrypt_full_path(u8"/aaa/ÄÄÄ", nullptr)
Expand All @@ -101,24 +101,24 @@ namespace
TEST_CASE("Lite FuseHighLevelOps")
{
auto whole_component
= [](std::shared_ptr<OSService> os) -> fruit::Component<FuseHighLevelOps>
= [](OSService* os,
const NameNormalizationFlags* flags) -> fruit::Component<FuseHighLevelOps>
{
auto flags = std::make_shared<NameNormalizationFlags>();
flags->supports_long_name = true;
return fruit::createComponent()
.install(get_name_translator_component, flags)
.install(get_test_component)

.bindInstance(*os);
};

auto temp_dir_name = OSService::temp_name("tmp/lite", "dir");
OSService::get_default().ensure_directory(temp_dir_name, 0755);
auto root = std::make_shared<OSService>(temp_dir_name);
OSService root(temp_dir_name);
NameNormalizationFlags flags{};
flags.supports_long_name = true;

fruit::Injector<FuseHighLevelOps> injector(+whole_component, root);
fruit::Injector<FuseHighLevelOps> injector(+whole_component, &root, &flags);
auto& ops = injector.get<FuseHighLevelOps&>();
testing::test_fuse_ops(ops, *root);
testing::test_fuse_ops(ops, root);
}
} // namespace
} // namespace securefs::lite_format

0 comments on commit dcb8da6

Please sign in to comment.