diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java
index dda603995a16..9bd7b6467210 100644
--- a/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java
+++ b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java
@@ -121,7 +121,7 @@ public boolean canRead(Class> clazz, @Nullable MediaType mediaType) {
@Override
public boolean canWrite(Class> clazz, @Nullable MediaType mediaType) {
- return (AnnotationUtils.findAnnotation(clazz, XmlRootElement.class) != null && canWrite(mediaType));
+ return ((JAXBElement.class.isAssignableFrom(clazz) || AnnotationUtils.findAnnotation(clazz, XmlRootElement.class) != null) && canWrite(mediaType));
}
@Override
@@ -192,7 +192,7 @@ protected Source processSource(Source source) {
@Override
protected void writeToResult(Object o, HttpHeaders headers, Result result) throws Exception {
try {
- Class> clazz = ClassUtils.getUserClass(o);
+ Class> clazz = getMarshallerType(o);
Marshaller marshaller = createMarshaller(clazz);
setCharset(headers.getContentType(), marshaller);
marshaller.marshal(o, result);
@@ -205,6 +205,15 @@ protected void writeToResult(Object o, HttpHeaders headers, Result result) throw
}
}
+ private static Class> getMarshallerType(Object o) {
+ if (o instanceof JAXBElement> jaxbElement) {
+ return jaxbElement.getDeclaredType();
+ }
+ else {
+ return ClassUtils.getUserClass(o);
+ }
+ }
+
private void setCharset(@Nullable MediaType contentType, Marshaller marshaller) throws PropertyException {
if (contentType != null && contentType.getCharset() != null) {
marshaller.setProperty(Marshaller.JAXB_ENCODING, contentType.getCharset().name());
diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java
index a54aeb779464..4c5c1e6cf870 100644
--- a/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java
+++ b/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java
@@ -18,6 +18,9 @@
import java.nio.charset.StandardCharsets;
+import javax.xml.namespace.QName;
+
+import jakarta.xml.bind.JAXBElement;
import jakarta.xml.bind.Marshaller;
import jakarta.xml.bind.Unmarshaller;
import jakarta.xml.bind.annotation.XmlAttribute;
@@ -93,6 +96,8 @@ void canWrite() {
.as("Converter does not support writing @XmlRootElement subclass").isTrue();
assertThat(converter.canWrite(rootElementCglib.getClass(), null))
.as("Converter does not support writing @XmlRootElement subclass").isTrue();
+ assertThat(converter.canWrite(JAXBElement.class, null))
+ .as("Converter does not support writing JAXBElement").isTrue();
assertThat(converter.canWrite(Type.class, null))
.as("Converter supports writing @XmlType").isFalse();
}
@@ -186,6 +191,18 @@ void writeXmlRootElement() throws Exception {
.isSimilarTo("", ev);
}
+ @Test
+ void writeJaxbElementRootElement() throws Exception {
+ MockHttpOutputMessage outputMessage = new MockHttpOutputMessage();
+ JAXBElement jaxbElement = new JAXBElement<>(new QName("custom"), MyCustomElement.class, new MyCustomElement("field1", "field2"));
+ converter.write(jaxbElement, null, outputMessage);
+ assertThat(outputMessage.getHeaders().getContentType())
+ .as("Invalid content-type").isEqualTo(MediaType.APPLICATION_XML);
+ DifferenceEvaluator ev = chain(Default, downgradeDifferencesToEqual(XML_STANDALONE));
+ assertThat(XmlContent.of(outputMessage.getBodyAsString(StandardCharsets.UTF_8)))
+ .isSimilarTo("field1field2", ev);
+ }
+
@Test
void writeXmlRootElementSubclass() throws Exception {
MockHttpOutputMessage outputMessage = new MockHttpOutputMessage();