-
Notifications
You must be signed in to change notification settings - Fork 809
/
Copy pathazure_event_hubs_entra_test.go
executable file
·146 lines (109 loc) · 3.62 KB
/
azure_event_hubs_entra_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package azure_event_hubs_entra
import (
"context"
"errors"
"fmt"
"strings"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/segmentio/kafka-go/sasl"
)
type MockTokenCredential struct {
getTokenFunc func() (string, error)
}
func (c *MockTokenCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
if len(options.Scopes) != 1 {
return azcore.AccessToken{}, fmt.Errorf("Scopes must contain 1 element! Contains %d elements.", len(options.Scopes))
}
scope := options.Scopes[0]
if !strings.HasPrefix(scope, "https://") {
return azcore.AccessToken{}, fmt.Errorf("Scope must start with https, and it did not.")
}
if !strings.HasSuffix(scope, "/.default") {
return azcore.AccessToken{}, fmt.Errorf("Scope must end with /.default, and it did not.")
}
if options.EnableCAE {
return azcore.AccessToken{}, fmt.Errorf("CAE must be false. It was true.")
}
token, err := c.getTokenFunc()
if err != nil {
return azcore.AccessToken{}, err
}
return azcore.AccessToken{Token: token}, nil
}
func TestName(t *testing.T) {
mechanism := NewMechanism(&MockTokenCredential{
getTokenFunc: func() (string, error) { return "testtoken", nil },
})
expected := "OAUTHBEARER"
actual := mechanism.Name()
if expected != actual {
t.Fatalf("Expected: %s - Actual: %s", expected, actual)
}
}
// TestStart verifies that Start, given a context carrying SASL metadata,
// returns a non-nil state machine and an initial client response that
// carries the credential's token as a Bearer value.
func TestStart(t *testing.T) {
	mechanism := NewMechanism(&MockTokenCredential{
		getTokenFunc: func() (string, error) { return "testtoken", nil },
	})
	ctx := sasl.WithMetadata(context.Background(), &sasl.Metadata{
		Host: "test.servicebus.windows.net",
		Port: 9093,
	})
	stateMachine, saslBytes, err := mechanism.Start(ctx)
	// Check the error first: if Start failed, the other return values are not
	// meaningful and asserting on them would hide the real cause.
	if err != nil {
		t.Fatalf("expected err to be nil, got %v", err)
	}
	if stateMachine == nil {
		t.Fatalf("Expected stateMachine to be non-nil")
	}
	// OAUTHBEARER initial client response (RFC 7628): GS2 header "n,,",
	// then key/value pairs separated by \x01, terminated by a double \x01.
	expectedSaslData := "n,,\x01auth=Bearer testtoken\x01\x01"
	if string(saslBytes) != expectedSaslData {
		// %q renders the \x01 separators visibly in the failure output.
		t.Fatalf("expected saslData to be %q. Received %q.", expectedSaslData, string(saslBytes))
	}
}
func TestStartNoMetadata(t *testing.T) {
mechanism := NewMechanism(&MockTokenCredential{
getTokenFunc: func() (string, error) { return "testtoken", nil },
})
ctx := context.Background()
stateMachine, saslBytes, err := mechanism.Start(ctx)
assertStartError(stateMachine, t, saslBytes, err, "missing sasl metadata")
}
func TestStartTokenError(t *testing.T) {
mechanism := NewMechanism(&MockTokenCredential{
getTokenFunc: func() (string, error) { return "", errors.New("Failed to acquire token") },
})
ctx := sasl.WithMetadata(context.Background(), &sasl.Metadata{
Host: "test.servicebus.windows.net",
Port: 9093,
})
stateMachine, saslBytes, err := mechanism.Start(ctx)
assertStartError(stateMachine, t, saslBytes, err, "failed to request an Azure Entra Token: Failed to acquire token")
}
// assertStartError asserts that a Start call failed cleanly: no state
// machine, nil SASL bytes, and an error whose message equals expectedError.
//
// The unusual parameter order (stateMachine before t) is preserved so
// existing callers keep compiling.
func assertStartError(stateMachine sasl.StateMachine, t *testing.T, saslBytes []byte, err error, expectedError string) {
	// Attribute failures to the calling test's line, not to this helper.
	t.Helper()
	if stateMachine != nil {
		t.Fatalf("Expected stateMachine to be nil")
	}
	if saslBytes != nil {
		t.Fatalf("Expected saslBytes to be nil")
	}
	// Without this guard, err.Error() on a nil error panics and crashes the
	// test binary instead of reporting a readable failure.
	if err == nil {
		t.Fatalf("expected err to be %q, was nil", expectedError)
	}
	if err.Error() != expectedError {
		t.Fatalf("expected err to be %q. was %q", expectedError, err.Error())
	}
}
// TestNext verifies that calling Next on the mechanism with a server
// challenge reports done=true, a nil response, and no error.
func TestNext(t *testing.T) {
	mechanism := NewMechanism(&MockTokenCredential{
		getTokenFunc: func() (string, error) { return "testtoken", nil },
	})
	done, response, err := mechanism.Next(context.Background(), []byte("challenge"))
	// Check the error first and print it, so an unexpected failure is
	// diagnosable rather than masked by the done/response assertions.
	if err != nil {
		t.Fatalf("Expected nil error, got %v", err)
	}
	if !done {
		t.Fatalf("Expected done to be true")
	}
	if response != nil {
		t.Fatalf("Expected nil response, got %q", response)
	}
}