#include "auth_bearer.hpp"
#include "user_info_cache.hpp"
#include <algorithm>
namespace samples::pg {
public:
AuthCheckerBearer(const AuthCache& auth_cache,
std::vector<server::auth::UserScope> required_scopes)
: auth_cache_(auth_cache), required_scopes_(std::move(required_scopes)) {}
[[nodiscard]] AuthCheckResult CheckAuth(
[[nodiscard]] bool SupportsUserAuth() const noexcept override { return true; }
private:
const AuthCache& auth_cache_;
const std::vector<server::auth::UserScope> required_scopes_;
};
/// Authenticates @p request by its `Authorization: Bearer <token>` header.
///
/// @return
///   - Status::kTokenNotFound with HandlerErrorCode::kUnauthorized if the
///     header is empty, or if the text before its first space is not exactly
///     "Bearer";
///   - Status::kForbidden if the token is not present in the auth cache;
///   - Status::kForbidden with "No '<scope>' permission" for the first entry
///     of `required_scopes_` that the cached user does not have;
///   - a default-constructed result otherwise (the success path). In that case
///     the user's name is first stored in @p request_context under "name".
AuthCheckerBearer::AuthCheckResult AuthCheckerBearer::CheckAuth(
    const server::http::HttpRequest& request,
    server::request::RequestContext& request_context) const {
  const auto& auth_value = request.GetHeader(http::headers::kAuthorization);
  if (auth_value.empty()) {
    return AuthCheckResult{AuthCheckResult::Status::kTokenNotFound,
                           {},
                           "Empty 'Authorization' header",
                           server::handlers::HandlerErrorCode::kUnauthorized};
  }

  // The scheme is everything before the first space.
  // NOTE(review): RFC 7235 defines auth-scheme as case-insensitive, but this
  // comparison is exact, so e.g. "bearer <token>" is rejected.
  const auto bearer_sep_pos = auth_value.find(' ');
  if (bearer_sep_pos == std::string::npos ||
      std::string_view{auth_value.data(), bearer_sep_pos} != "Bearer") {
    return AuthCheckResult{
        AuthCheckResult::Status::kTokenNotFound,
        {},
        "'Authorization' header should have 'Bearer some-token' format",
        server::handlers::HandlerErrorCode::kUnauthorized};
  }

  // The token is the remainder after the first space. Any additional leading
  // spaces are kept as part of the token.
  const server::auth::UserAuthInfo::Ticket token{auth_value.data() +
                                                 bearer_sep_pos + 1};
  // `info` below refers into this snapshot, so it must stay alive while
  // `info` is used.
  const auto cache_snapshot = auth_cache_.Get();

  auto it = cache_snapshot->find(token);
  if (it == cache_snapshot->end()) {
    return AuthCheckResult{AuthCheckResult::Status::kForbidden};
  }

  const UserDbInfo& info = it->second;
  // Every configured scope must appear in the user's scope list.
  for (const auto& scope : required_scopes_) {
    if (std::find(info.scopes.begin(), info.scopes.end(), scope.GetValue()) ==
        info.scopes.end()) {
      return AuthCheckResult{AuthCheckResult::Status::kForbidden,
                             {},
                             "No '" + scope.GetValue() + "' permission"};
    }
  }

  request_context.SetData("name", info.name);
  return {};
}
/// Creates an AuthCheckerBearer for a handler.
///
/// Reads the "scopes" list from the handler's auth config, using an empty
/// default when the key is absent. Binds the checker to the AuthCache
/// component found in @p context.
server::handlers::auth::AuthCheckerBasePtr CheckerFactory::operator()(
    const ::components::ComponentContext& context,
    const server::handlers::auth::HandlerAuthConfig& auth_config,
    const server::handlers::auth::AuthCheckerSettings& /*settings*/) const {
  auto scopes = auth_config["scopes"].As<server::auth::UserScopes>({});
  const auto& auth_cache = context.FindComponent<AuthCache>();
  return std::make_shared<AuthCheckerBearer>(auth_cache, std::move(scopes));
}
}