diff --git a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/JsrWebSocketMessages.java b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/JsrWebSocketMessages.java index 97f9a45012..1b3d621469 100644 --- a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/JsrWebSocketMessages.java +++ b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/JsrWebSocketMessages.java @@ -86,7 +86,7 @@ public interface JsrWebSocketMessages { @Message(id = 3015, value = "No decoder accepted message %s") String noDecoderAcceptedMessage(List decoders); - @Message(id = 3016, value = "Cannot send in middle of fragmeneted message") + @Message(id = 3016, value = "Cannot send in middle of fragmented message") IllegalStateException cannotSendInMiddleOfFragmentedMessage(); @Message(id = 3017, value = "Cannot add endpoint after deployment") diff --git a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/WebSocketSessionRemoteEndpoint.java b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/WebSocketSessionRemoteEndpoint.java index ff628e93f6..5015958cba 100644 --- a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/WebSocketSessionRemoteEndpoint.java +++ b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/WebSocketSessionRemoteEndpoint.java @@ -17,24 +17,24 @@ */ package io.undertow.websockets.jsr; +import java.io.IOException; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.Future; + import io.undertow.websockets.core.BinaryOutputStream; import io.undertow.websockets.core.StreamSinkFrameChannel; import io.undertow.websockets.core.WebSocketCallback; import io.undertow.websockets.core.WebSocketFrameType; import io.undertow.websockets.core.WebSocketUtils; import io.undertow.websockets.core.WebSockets; -import org.xnio.channels.Channels; - import jakarta.websocket.EncodeException; import jakarta.websocket.RemoteEndpoint; import jakarta.websocket.SendHandler; -import java.io.IOException; -import java.io.OutputStream; -import java.io.OutputStreamWriter; -import java.io.Writer; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.concurrent.Future; +import org.xnio.channels.Channels; /** * {@link RemoteEndpoint} implementation which uses a WebSocketSession for all its operation. @@ -243,7 +243,7 @@ class BasicWebSocketSessionRemoteEndpoint implements Basic { private StreamSinkFrameChannel binaryFrameSender; private StreamSinkFrameChannel textFrameSender; - public void assertNotInFragment() { + public synchronized void assertNotInFragment() { if (textFrameSender != null || binaryFrameSender != null) { throw JsrWebSocketMessages.MESSAGES.cannotSendInMiddleOfFragmentedMessage(); } @@ -268,57 +268,78 @@ public void sendBinary(final ByteBuffer data) throws IOException { data.clear(); //for some reason the TCK expects this, might as well just match the RI behaviour } - @Override + @Override public void sendText(final String partialMessage, final boolean isLast) throws IOException { - if(partialMessage == null) { + if (partialMessage == null) { throw JsrWebSocketMessages.MESSAGES.messageInNull(); } - if (binaryFrameSender != null) { - throw JsrWebSocketMessages.MESSAGES.cannotSendInMiddleOfFragmentedMessage(); - } - if (textFrameSender == null) { - textFrameSender = undertowSession.getWebSocketChannel().send(WebSocketFrameType.TEXT); - } + + StreamSinkFrameChannel sender = getTextFrameSender(); + try { - Channels.writeBlocking(textFrameSender, WebSocketUtils.fromUtf8String(partialMessage)); - if(isLast) { - textFrameSender.shutdownWrites(); + Channels.writeBlocking(sender, WebSocketUtils.fromUtf8String(partialMessage)); + if (isLast) { + sender.shutdownWrites(); } - Channels.flushBlocking(textFrameSender); + Channels.flushBlocking(sender); } finally { if (isLast) { - textFrameSender = null; + clearTextFrameSender(); } } - } @Override public void sendBinary(final ByteBuffer partialByte, final boolean isLast) throws IOException { - if(partialByte == null) { throw JsrWebSocketMessages.MESSAGES.messageInNull(); } - if (textFrameSender != null) { - throw JsrWebSocketMessages.MESSAGES.cannotSendInMiddleOfFragmentedMessage(); - } - if (binaryFrameSender == null) { - binaryFrameSender = undertowSession.getWebSocketChannel().send(WebSocketFrameType.BINARY); - } + + StreamSinkFrameChannel sender = getBinaryFrameSender(); + try { - Channels.writeBlocking(binaryFrameSender, partialByte); - if(isLast) { - binaryFrameSender.shutdownWrites(); + Channels.writeBlocking(sender, partialByte); + if (isLast) { + sender.shutdownWrites(); } - Channels.flushBlocking(binaryFrameSender); - } finally { + Channels.flushBlocking(sender); + } + finally { if (isLast) { - binaryFrameSender = null; + clearBinaryFrameSender(); } } partialByte.clear(); } + private synchronized StreamSinkFrameChannel getTextFrameSender() throws IOException { + if (binaryFrameSender != null) { + throw JsrWebSocketMessages.MESSAGES.cannotSendInMiddleOfFragmentedMessage(); + } + if (textFrameSender == null) { + textFrameSender = undertowSession.getWebSocketChannel().send(WebSocketFrameType.TEXT); + } + return textFrameSender; + } + + private synchronized void clearTextFrameSender() { + textFrameSender = null; + } + + private synchronized StreamSinkFrameChannel getBinaryFrameSender() throws IOException { + if (textFrameSender != null) { + throw JsrWebSocketMessages.MESSAGES.cannotSendInMiddleOfFragmentedMessage(); + } + if (binaryFrameSender == null) { + binaryFrameSender = undertowSession.getWebSocketChannel().send(WebSocketFrameType.BINARY); + } + return binaryFrameSender; + } + + private synchronized void clearBinaryFrameSender() { + binaryFrameSender = null; + } + @Override public OutputStream getSendStream() throws IOException { assertNotInFragment();