Skip to content

Commit b141071

Browse files
fixes session refresh for JWT backed sessions
1 parent 3ea434b commit b141071

File tree

3 files changed

+84
-16
lines changed

3 files changed

+84
-16
lines changed

edge-apis/oidc.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/rand"
66
"crypto/tls"
77
"fmt"
8+
"github.com/golang-jwt/jwt/v5"
89
"github.com/google/uuid"
910
"github.com/michaelquigley/pfxlog"
1011
"github.com/zitadel/oidc/v2/pkg/client/rp"
@@ -16,6 +17,16 @@ import (
1617
"time"
1718
)
1819

20+
const JwtTokenPrefix = "ey"
21+
22+
type ServiceAccessClaims struct {
23+
jwt.RegisteredClaims
24+
ApiSessionId string `json:"z_asid"`
25+
IdentityId string `json:"z_iid"`
26+
TokenType string `json:"z_t"`
27+
Type string `json:"z_st"`
28+
}
29+
1930
type localRpServer struct {
2031
Server *http.Server
2132
Port string

ziti/client.go

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"crypto/x509/pkix"
2626
"encoding/pem"
2727
"fmt"
28+
"github.com/golang-jwt/jwt/v5"
2829
"github.com/openziti/foundation/v2/genext"
2930
"github.com/openziti/transport/v2"
3031
"github.com/pkg/errors"
@@ -174,7 +175,7 @@ func (self *CtrlClient) GetCurrentIdentity() (*rest_model.IdentityDetail, error)
174175
return resp.Payload.Data, nil
175176
}
176177

177-
// GetSession returns the full rest_model.SessionDetail for a specific id
178+
// GetSession returns the full rest_model.SessionDetail for a specific id. Does not function with JWT backed sessions.
178179
func (self *CtrlClient) GetSession(id string) (*rest_model.SessionDetail, error) {
179180
params := session.NewDetailSessionParams()
180181
params.ID = id
@@ -188,6 +189,50 @@ func (self *CtrlClient) GetSession(id string) (*rest_model.SessionDetail, error)
188189
return resp.Payload.Data, nil
189190
}
190191

192+
func (self *CtrlClient) GetSessionFromJwt(sessionToken string) (*rest_model.SessionDetail, error) {
193+
parser := jwt.NewParser()
194+
serviceAccessClaims := &apis.ServiceAccessClaims{}
195+
196+
_, _, err := parser.ParseUnverified(sessionToken, serviceAccessClaims)
197+
198+
if err != nil {
199+
return nil, err
200+
}
201+
202+
params := service.NewListServiceEdgeRoutersParams()
203+
params.SessionToken = &sessionToken
204+
params.ID = serviceAccessClaims.Subject //service id
205+
206+
resp, err := self.API.Service.ListServiceEdgeRouters(params, self.ApiSession)
207+
208+
if err != nil {
209+
return nil, rest_util.WrapErr(err)
210+
}
211+
createdAt := strfmt.DateTime(serviceAccessClaims.IssuedAt.Time)
212+
sessionType := rest_model.DialBind(serviceAccessClaims.Type)
213+
214+
sessionDetail := &rest_model.SessionDetail{
215+
BaseEntity: rest_model.BaseEntity{
216+
Links: nil,
217+
CreatedAt: &createdAt,
218+
ID: &serviceAccessClaims.ID,
219+
},
220+
APISessionID: &serviceAccessClaims.ApiSessionId,
221+
IdentityID: &serviceAccessClaims.IdentityId,
222+
ServiceID: &serviceAccessClaims.Subject,
223+
Token: &sessionToken,
224+
Type: &sessionType,
225+
}
226+
227+
for _, er := range resp.Payload.Data.EdgeRouters {
228+
sessionDetail.EdgeRouters = append(sessionDetail.EdgeRouters, &rest_model.SessionEdgeRouter{
229+
CommonEdgeRouterProperties: *er,
230+
})
231+
}
232+
233+
return sessionDetail, nil
234+
}
235+
191236
// GetIdentity returns the identity.Identity used to facilitate authentication. Each identity.Identity instance
192237
// may provide authentication material in the form of x509 certificates and private keys and/or trusted CA pools.
193238
func (self *CtrlClient) GetIdentity() (identity.Identity, error) {

ziti/ziti.go

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"net"
3030
"reflect"
3131
"strconv"
32+
"strings"
3233
"sync"
3334
"sync/atomic"
3435
"time"
@@ -555,7 +556,7 @@ func (context *ContextImpl) refreshSessions() {
555556
session := entry.Val
556557
log.Debugf("refreshing session for %s", key)
557558

558-
if s, err := context.refreshSession(*session.ID); err != nil {
559+
if s, err := context.refreshSession(session); err != nil {
559560
log.WithError(err).Errorf("failed to refresh session for %s", key)
560561
toDelete = append(toDelete, *session.ID)
561562
} else {
@@ -881,7 +882,7 @@ func (context *ContextImpl) DialWithOptions(serviceName string, options *DialOpt
881882
}
882883

883884
var refreshErr error
884-
if _, refreshErr = context.refreshSession(*session.ID); refreshErr == nil {
885+
if _, refreshErr = context.refreshSession(session); refreshErr == nil {
885886
// if the session wasn't expired, no reason to try again, return the failure
886887
return nil, errors.Wrapf(err, "unable to dial service '%s'", serviceName)
887888
}
@@ -1034,7 +1035,7 @@ func (context *ContextImpl) listenSession(service *rest_model.ServiceDetail, opt
10341035
func (context *ContextImpl) getEdgeRouterConn(session *rest_model.SessionDetail, options edge.ConnOptions) (edge.RouterConn, error) {
10351036
logger := pfxlog.Logger().WithField("sessionId", *session.ID)
10361037

1037-
if refreshedSession, err := context.refreshSession(*session.ID); err != nil {
1038+
if refreshedSession, err := context.refreshSession(session); err != nil {
10381039
target := &rest_session.DetailSessionNotFound{}
10391040
if errors.As(err, &target) {
10401041
sessionKey := fmt.Sprintf("%s:%s", session.Service.ID, *session.Type)
@@ -1055,9 +1056,11 @@ func (context *ContextImpl) getEdgeRouterConn(session *rest_model.SessionDetail,
10551056
var bestER edge.RouterConn
10561057
var unconnected []*rest_model.SessionEdgeRouter
10571058
for _, edgeRouter := range session.EdgeRouters {
1058-
for _, routerUrl := range edgeRouter.Urls {
1059-
if er, found := context.routerConnections.Get(routerUrl); found {
1060-
h := context.metrics.Histogram("latency." + routerUrl).(metrics2.Histogram)
1059+
for proto, addr := range edgeRouter.SupportedProtocols {
1060+
addr = strings.Replace(addr, "://", ":", 1)
1061+
edgeRouter.SupportedProtocols[proto] = addr
1062+
if er, found := context.routerConnections.Get(addr); found {
1063+
h := context.metrics.Histogram("latency." + addr).(metrics2.Histogram)
10611064
if h.Mean() < float64(bestLatency) {
10621065
bestLatency = time.Duration(int64(h.Mean()))
10631066
bestER = er
@@ -1074,9 +1077,9 @@ func (context *ContextImpl) getEdgeRouterConn(session *rest_model.SessionDetail,
10741077
}
10751078

10761079
for _, edgeRouter := range unconnected {
1077-
for _, routerUrl := range edgeRouter.Urls {
1078-
if context.options.isEdgeRouterUrlAccepted(routerUrl) {
1079-
go context.connectEdgeRouter(*edgeRouter.Name, routerUrl, ch)
1080+
for _, addr := range edgeRouter.SupportedProtocols {
1081+
if context.options.isEdgeRouterUrlAccepted(addr) {
1082+
go context.connectEdgeRouter(*edgeRouter.Name, addr, ch)
10801083
}
10811084
}
10821085
}
@@ -1373,13 +1376,21 @@ func (context *ContextImpl) createSession(service *rest_model.ServiceDetail, ses
13731376
return session, nil
13741377
}
13751378

1376-
func (context *ContextImpl) refreshSession(id string) (*rest_model.SessionDetail, error) {
1377-
session, err := context.CtrlClt.GetSession(id)
1379+
func (context *ContextImpl) refreshSession(session *rest_model.SessionDetail) (*rest_model.SessionDetail, error) {
1380+
var refreshedSession *rest_model.SessionDetail
1381+
var err error
1382+
if strings.HasPrefix(*session.Token, apis.JwtTokenPrefix) {
1383+
refreshedSession, err = context.CtrlClt.GetSessionFromJwt(*session.Token)
1384+
} else {
1385+
refreshedSession, err = context.CtrlClt.GetSession(*session.ID)
1386+
}
1387+
13781388
if err != nil {
13791389
return nil, err
13801390
}
1381-
context.cacheSession("refresh", session)
1382-
return session, nil
1391+
1392+
context.cacheSession("refresh", refreshedSession)
1393+
return refreshedSession, nil
13831394
}
13841395

13851396
func (context *ContextImpl) cacheSession(op string, session *rest_model.SessionDetail) {
@@ -1609,7 +1620,8 @@ func (mgr *listenerManager) refreshSession() {
16091620
return
16101621
}
16111622

1612-
session, err := mgr.context.refreshSession(*mgr.session.ID)
1623+
session, err := mgr.context.refreshSession(mgr.session)
1624+
16131625
if err != nil {
16141626
var target error = &rest_session.DetailSessionNotFound{}
16151627
if errors.As(err, &target) {
@@ -1630,7 +1642,7 @@ func (mgr *listenerManager) refreshSession() {
16301642
}
16311643
}
16321644

1633-
session, err = mgr.context.refreshSession(*mgr.session.ID)
1645+
session, err = mgr.context.refreshSession(mgr.session)
16341646
if err != nil {
16351647
target = &rest_session.DetailSessionUnauthorized{}
16361648
if errors.As(err, &target) {

0 commit comments

Comments
 (0)