diff --git a/lib/omniauth/strategies/oauth2.rb b/lib/omniauth/strategies/oauth2.rb index 3ffff1b..11f353f 100644 --- a/lib/omniauth/strategies/oauth2.rb +++ b/lib/omniauth/strategies/oauth2.rb @@ -55,7 +55,11 @@ def authorize_params @env ||= {} @env["rack.session"] ||= {} end - session["omniauth.state"] = params[:state] + session["omniauth.states"] ||= [] + session["omniauth.states"] << params[:state] + + session["omniauth.state_origins"] ||= {} + session["omniauth.state_origins"][params[:state]] = session['omniauth.origin'] params end @@ -67,11 +71,15 @@ def callback_phase # rubocop:disable AbcSize, CyclomaticComplexity, MethodLength error = request.params["error_reason"] || request.params["error"] if error fail!(error, CallbackError.new(request.params["error"], request.params["error_description"] || request.params["error_reason"], request.params["error_uri"])) - elsif !options.provider_ignores_state && (request.params["state"].to_s.empty? || request.params["state"] != session.delete("omniauth.state")) + elsif !options.provider_ignores_state && (request.params["state"].to_s.empty? || session["omniauth.states"].empty? || !session["omniauth.states"].delete(request.params["state"])) fail!(:csrf_detected, CallbackError.new(:csrf_detected, "CSRF detected")) else self.access_token = build_access_token self.access_token = access_token.refresh! if access_token.expired? + + if session['omniauth.state_origins'] && session['omniauth.state_origins'][request.params['state']] + env['omniauth.origin'] = session['omniauth.state_origins'].delete(request.params['state']) + end super end rescue ::OAuth2::Error, CallbackError => e