diff --git a/src/common/file_util.cpp b/src/common/file_util.cpp index 6dda955fe..2e3874de8 100644 --- a/src/common/file_util.cpp +++ b/src/common/file_util.cpp @@ -1516,6 +1516,46 @@ std::size_t IOFile::WriteImpl(const void* data, std::size_t length, std::size_t #endif } +bool IOFile::ReadLine(std::string& line) { + line.clear(); + + char ch; + bool read_anything = false; + + while (true) { + const std::size_t read = ReadImpl(&ch, sizeof(ch), 1); + + if (read != sizeof(ch)) { + return read_anything; + } + read_anything = true; + + if (ch == '\n') { + return true; + } + + // Always convert to UNIX style + if (ch != '\r') { + line.push_back(ch); + } + } +} + +size_t IOFile::WriteLine(const std::string_view line) { + const size_t written_line = WriteImpl(line.data(), line.size(), 1); + if (written_line != line.size()) { + return written_line; + } + + char nl = '\n'; + const size_t written_nl = WriteImpl(&nl, sizeof(nl), 1); + if (written_nl != sizeof(nl)) { + return written_nl; + } + + return written_line + written_nl; +} + bool IOFile::Resize(u64 size) { if (!IsOpen() || 0 != #if defined(HAVE_LIBRETRO_VFS) diff --git a/src/common/file_util.h b/src/common/file_util.h index 8484b02a1..f12cac5bc 100644 --- a/src/common/file_util.h +++ b/src/common/file_util.h @@ -428,6 +428,27 @@ public: return WriteImpl(data.data(), data.size(), sizeof(T)); } + /** + * Reads the file line by line, returning true if data + * was read and false when reaching the end of file. + * + * @param line The output string to write the read data to + * + * @returns Whether the line was read or not + */ + bool ReadLine(std::string& line); + + /** + * Writes the specified line to the file + * automatically appending a newline + * character to it. + * + * @param line The input string to write + * + * @returns Count of bytes written, including the newline. + */ + size_t WriteLine(const std::string_view line); + [[nodiscard]] virtual bool IsOpen() const { return nullptr != m_file; } diff --git a/src/core/hle/service/http/http_c.cpp b/src/core/hle/service/http/http_c.cpp index 59cc47b3a..5fc4ec2f5 100644 --- a/src/core/hle/service/http/http_c.cpp +++ b/src/core/hle/service/http/http_c.cpp @@ -11,6 +11,7 @@ #include #include "common/archives.h" #include "common/assert.h" +#include "common/file_util.h" #include "common/scope_exit.h" #include "common/string_util.h" #include "core/core.h" @@ -297,6 +298,9 @@ void Context::MakeRequest() { request.method = request_method_strings.at(method); request.path = url_info.path; + // Apply URL replacements if any + url_info.host = url_replacer->Apply(url_info.host); + request.progress = [this](u64 current, u64 total) -> bool { // TODO(B3N30): Is there a state that shows response header are available current_download_size_bytes = current; @@ -450,8 +454,6 @@ void Context::MakeRequestSSL(httplib::Request& request, const URLInfo& url_info, } bool Context::ContentProvider(size_t offset, size_t length, httplib::DataSink& sink) { - state = RequestState::SendingRequest; - if (!post_data_raw.empty()) { sink.write(post_data_raw.data() + offset, length); } @@ -462,8 +464,6 @@ bool Context::ContentProvider(size_t offset, size_t length, httplib::DataSink& s } bool Context::ChunkedContentProvider(size_t offset, httplib::DataSink& sink) { - state = RequestState::SendingRequest; - finish_post_data.Wait(); switch (post_data_type) { @@ -788,6 +788,7 @@ void HTTP_C::CreateContext(Kernel::HLERequestContext& ctx) { contexts[context_counter].socket_buffer_size = 0; contexts[context_counter].handle = context_counter; contexts[context_counter].session_id = session_data->session_id; + contexts[context_counter].url_replacer = &url_replacer; session_data->num_http_contexts++; @@ -858,8 +859,6 @@ void HTTP_C::GetRequestState(Kernel::HLERequestContext& ctx) { return; } - LOG_DEBUG(Service_HTTP, "called, context_handle={}", context_handle); - Context& http_context = GetContext(context_handle); RequestState state = http_context.state; @@ -1414,7 +1413,6 @@ void HTTP_C::NotifyFinishSendPostData(Kernel::HLERequestContext& ctx) { } http_context.finish_post_data.Set(); - http_context.post_pending_request = false; http_context.current_copied_data = 0; http_context.request_future = @@ -2017,6 +2015,61 @@ void HTTP_C::Finalize(Kernel::HLERequestContext& ctx) { LOG_WARNING(Service_HTTP, "(STUBBED) called"); } +void HTTP_C::RegisterURLReplacement(Kernel::HLERequestContext& ctx) { + IPC::RequestParser rp(ctx); + const u32 pattern_size = rp.Pop(); + const u32 replacement_size = rp.Pop(); + + const std::vector& pattern_buf = rp.PopStaticBuffer(); + const std::vector& replacement_buf = rp.PopStaticBuffer(); + + std::string pattern(reinterpret_cast(pattern_buf.data()), + std::min(static_cast(pattern_size), pattern_buf.size())); + std::string replacement( + reinterpret_cast(replacement_buf.data()), + std::min(static_cast(replacement_size), replacement_buf.size())); + + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + if (url_replacer.HasRule(pattern)) { + rb.Push(Result{ErrorDescription::AlreadyExists, ErrorModule::HTTP, + ErrorSummary::InvalidArgument, ErrorLevel::Status}); + return; + } + + Result res = url_replacer.AddRule(pattern, replacement) + ? ResultSuccess + : Result{ErrorDescription::InvalidCombination, ErrorModule::HTTP, + ErrorSummary::InvalidArgument, ErrorLevel::Status}; + if (res.IsSuccess()) { + res = url_replacer.Save() ? res + : Result{ErrorDescription::OutOfMemory, ErrorModule::HTTP, + ErrorSummary::Internal, ErrorLevel::Permanent}; + } + + rb.Push(res); +} + +void HTTP_C::UnregisterURLReplacement(Kernel::HLERequestContext& ctx) { + IPC::RequestParser rp(ctx); + const u32 pattern_size = rp.Pop(); + + const std::vector& pattern_buf = rp.PopStaticBuffer(); + + std::string pattern(reinterpret_cast(pattern_buf.data()), + std::min(static_cast(pattern_size), pattern_buf.size())); + + bool deleted = url_replacer.DeleteRule(pattern); + Result res = deleted ? ResultSuccess + : Result{ErrorDescription::NotFound, ErrorModule::HTTP, + ErrorSummary::NotFound, ErrorLevel::Info}; + if (deleted) { + url_replacer.Save(); + } + + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(res); +} + void HTTP_C::GetDownloadSizeState(Kernel::HLERequestContext& ctx) { IPC::RequestParser rp(ctx); const Context::Handle context_handle = rp.Pop(); @@ -2185,6 +2238,96 @@ void HTTP_C::DecryptClCertA() { ClCertA.init = true; } +URLReplacer::URLReplacer() { + const std::string path{fmt::format("{}/http_hle_replace_rules.txt", + FileUtil::GetUserPath(FileUtil::UserPath::SysDataDir))}; + + FileUtil::IOFile f(path, "rb"); + if (!f.IsOpen()) { + return; + } + + std::string pattern; + std::string replacement; + while (f.ReadLine(pattern) && f.ReadLine(replacement)) { + try { + rules.push_back(Rule{ + .regex = boost::regex(pattern), + .pattern = pattern, + .replacement = replacement, + }); + } catch (const boost::regex_error& e) { + LOG_ERROR(Service_HTTP, "Failed to load HTTP HLE replacement pattern \"{}\": {}", + pattern, e.what()); + } + } +} + +bool URLReplacer::HasRule(const std::string& pattern) { + for (const auto& rule : rules) { + if (rule.pattern == pattern) { + return true; + } + } + return false; +} + +bool URLReplacer::AddRule(const std::string& pattern, const std::string& replacement) { + try { + rules.push_back(Rule{ + .regex = boost::regex(pattern), + .pattern = pattern, + .replacement = replacement, + }); + } catch (const boost::regex_error& e) { + return false; + } + return true; +} + +bool URLReplacer::DeleteRule(const std::string& pattern) { + const auto old_size = rules.size(); + + std::erase_if(rules, [&](const Rule& rule) { return rule.pattern == pattern; }); + + return rules.size() != old_size; +} + +std::string URLReplacer::Apply(const std::string& url) const { + std::string result = url; + + for (const auto& rule : rules) { + if (boost::regex_search(result, rule.regex)) { + result = boost::regex_replace(result, rule.regex, rule.replacement, + boost::match_default | boost::format_all); + LOG_WARNING(Service_HTTP, "rule \"{}\" has replaced URL \"{}\" to \"{}\"", rule.pattern, + url, result); + break; + } + } + + return result; +} + +bool URLReplacer::Save() { + const std::string path{fmt::format("{}/http_hle_replace_rules.txt", + FileUtil::GetUserPath(FileUtil::UserPath::SysDataDir))}; + + FileUtil::IOFile f(path, "wb"); + + for (const auto& rule : rules) { + if ((f.WriteLine(rule.pattern) != rule.pattern.size() + 1) || + (f.WriteLine(rule.replacement) != rule.replacement.size() + 1)) { + LOG_ERROR(Service_HTTP, "failed to write URL replacement rules"); + f.Close(); + FileUtil::Delete(path); + return false; + } + } + + return true; +} + HTTP_C::HTTP_C() : ServiceFramework("http:C", 32) { static const FunctionInfo functions[] = { // clang-format off @@ -2245,6 +2388,9 @@ HTTP_C::HTTP_C() : ServiceFramework("http:C", 32) { {0x0037, &HTTP_C::SetKeepAlive, "SetKeepAlive"}, {0x0038, &HTTP_C::SetPostDataTypeSize, "SetPostDataTypeSize"}, {0x0039, &HTTP_C::Finalize, "Finalize"}, + // Custom + {0x0C00, &HTTP_C::RegisterURLReplacement, "RegisterURLReplacement"}, + {0x0C01, &HTTP_C::UnregisterURLReplacement, "UnregisterURLReplacement"}, // clang-format on }; RegisterHandlers(functions); diff --git a/src/core/hle/service/http/http_c.h b/src/core/hle/service/http/http_c.h index 972a1b22f..52eccb8be 100644 --- a/src/core/hle/service/http/http_c.h +++ b/src/core/hle/service/http/http_c.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -162,6 +163,28 @@ struct ClCertAData { bool init = false; }; +class URLReplacer { +private: + struct Rule { + boost::regex regex; + + std::string pattern; + std::string replacement; + }; + + std::vector rules; + +public: + URLReplacer(); + + bool HasRule(const std::string& pattern); + bool AddRule(const std::string& pattern, const std::string& replacement); + bool DeleteRule(const std::string& pattern); + std::string Apply(const std::string& url) const; + + bool Save(); +}; + /// Represents an HTTP context. class Context final { public: @@ -276,6 +299,7 @@ public: u32 socket_buffer_size; std::vector headers; const ClCertAData* clcert_data; + const URLReplacer* url_replacer; bool post_data_added = false; bool post_pending_request = false; Params post_data; @@ -866,6 +890,10 @@ private: */ void Finalize(Kernel::HLERequestContext& ctx); + void RegisterURLReplacement(Kernel::HLERequestContext& ctx); + + void UnregisterURLReplacement(Kernel::HLERequestContext& ctx); + [[nodiscard]] SessionData* EnsureSessionInitialized(Kernel::HLERequestContext& ctx, IPC::RequestParser rp); @@ -900,6 +928,8 @@ private: ClCertAData ClCertA; + URLReplacer url_replacer; + private: template void serialize(Archive& ar, const unsigned int) {