Skip to content

Commit 9eefa13

Browse files
committed
try fix data race in Connect(), Close() functions sacOO7#2
1 parent 3f2d5cd commit 9eefa13

File tree

1 file changed

+54
-34
lines changed

1 file changed

+54
-34
lines changed

gowebsocket.go

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ import (
1414
)
1515

1616
// Empty struct for logger initialization
17-
type Empty struct {
18-
}
17+
type Empty struct{}
1918

2019
// Initialize logger
2120
var logger = logging.GetLogger(reflect.TypeOf(Empty{}).PkgPath()).SetLevel(logging.OFF)
@@ -46,9 +45,11 @@ type Socket struct {
4645
OnPongReceived func(data string, socket Socket)
4746
IsConnected bool
4847
Timeout time.Duration
49-
sendMu *sync.Mutex // Mutex to prevent concurrent writes
50-
receiveMu *sync.Mutex // Mutex to prevent concurrent reads
51-
connStateMu sync.Mutex // Mutex to protect connection state
48+
sendMu sync.Mutex // Mutex to prevent concurrent writes
49+
receiveMu sync.Mutex // Mutex to prevent concurrent reads
50+
connStateMu sync.Mutex // Mutex to protect connection state
51+
closeChan chan struct{} // Channel to signal closing
52+
closeWg sync.WaitGroup // WaitGroup to wait for goroutines
5253
}
5354

5455
// Connection options structure
@@ -61,6 +62,7 @@ type ConnectionOptions struct {
6162

6263
// Reconnection options (to be implemented)
6364
type ReconnectionOptions struct {
65+
// Fields for reconnection options
6466
}
6567

6668
// Create a new Socket instance
@@ -74,8 +76,8 @@ func New(url string) Socket {
7476
},
7577
WebsocketDialer: &websocket.Dialer{},
7678
Timeout: 0,
77-
sendMu: &sync.Mutex{},
78-
receiveMu: &sync.Mutex{},
79+
closeChan: make(chan struct{}),
80+
// Other fields are zero-initialized
7981
}
8082
}
8183

@@ -93,6 +95,7 @@ func (socket *Socket) Connect() {
9395
var resp *http.Response
9496
socket.setConnectionOptions()
9597

98+
// Dial the websocket connection
9699
socket.Conn, resp, err = socket.WebsocketDialer.Dial(socket.Url, socket.RequestHeader)
97100

98101
if err != nil {
@@ -156,37 +159,48 @@ func (socket *Socket) Connect() {
156159
return result
157160
})
158161

162+
// Initialize close channel and WaitGroup
163+
socket.closeChan = make(chan struct{})
164+
socket.closeWg.Add(1)
165+
159166
// Start reading messages
160167
go func() {
168+
defer socket.closeWg.Done()
161169
for {
162-
socket.receiveMu.Lock()
163-
if socket.Timeout != 0 {
164-
socket.Conn.SetReadDeadline(time.Now().Add(socket.Timeout))
165-
}
166-
messageType, message, err := socket.Conn.ReadMessage()
167-
socket.receiveMu.Unlock()
168-
if err != nil {
169-
logger.Error.Println("read:", err)
170-
socket.connStateMu.Lock()
171-
socket.IsConnected = false
172-
onDisconnected := socket.OnDisconnected
173-
socket.connStateMu.Unlock()
174-
175-
if onDisconnected != nil {
176-
onDisconnected(err, *socket)
177-
}
170+
select {
171+
case <-socket.closeChan:
172+
// Received close signal, exiting goroutine
178173
return
179-
}
180-
logger.Info.Printf("recv: %s", message)
181-
182-
switch messageType {
183-
case websocket.TextMessage:
184-
if socket.OnTextMessage != nil {
185-
socket.OnTextMessage(string(message), *socket)
174+
default:
175+
socket.receiveMu.Lock()
176+
if socket.Timeout != 0 {
177+
socket.Conn.SetReadDeadline(time.Now().Add(socket.Timeout))
178+
}
179+
messageType, message, err := socket.Conn.ReadMessage()
180+
socket.receiveMu.Unlock()
181+
if err != nil {
182+
logger.Error.Println("read:", err)
183+
socket.connStateMu.Lock()
184+
socket.IsConnected = false
185+
onDisconnected := socket.OnDisconnected
186+
socket.connStateMu.Unlock()
187+
188+
if onDisconnected != nil {
189+
onDisconnected(err, *socket)
190+
}
191+
return
186192
}
187-
case websocket.BinaryMessage:
188-
if socket.OnBinaryMessage != nil {
189-
socket.OnBinaryMessage(message, *socket)
193+
logger.Info.Printf("recv: %s", message)
194+
195+
switch messageType {
196+
case websocket.TextMessage:
197+
if socket.OnTextMessage != nil {
198+
socket.OnTextMessage(string(message), *socket)
199+
}
200+
case websocket.BinaryMessage:
201+
if socket.OnBinaryMessage != nil {
202+
socket.OnBinaryMessage(message, *socket)
203+
}
190204
}
191205
}
192206
}
@@ -227,9 +241,15 @@ func (socket *Socket) Close() {
227241
logger.Error.Println("write close:", err)
228242
}
229243

230-
// Close the connection
244+
// Close the websocket connection
231245
socket.Conn.Close()
232246

247+
// Signal the goroutine to exit
248+
close(socket.closeChan)
249+
250+
// Wait for the goroutine to finish
251+
socket.closeWg.Wait()
252+
233253
// Protect access to IsConnected and OnDisconnected
234254
socket.connStateMu.Lock()
235255
socket.IsConnected = false

0 commit comments

Comments
 (0)