diff --git a/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java index dda603995a16..9bd7b6467210 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverter.java @@ -121,7 +121,7 @@ public boolean canRead(Class clazz, @Nullable MediaType mediaType) { @Override public boolean canWrite(Class clazz, @Nullable MediaType mediaType) { - return (AnnotationUtils.findAnnotation(clazz, XmlRootElement.class) != null && canWrite(mediaType)); + return ((JAXBElement.class.isAssignableFrom(clazz) || AnnotationUtils.findAnnotation(clazz, XmlRootElement.class) != null) && canWrite(mediaType)); } @Override @@ -192,7 +192,7 @@ protected Source processSource(Source source) { @Override protected void writeToResult(Object o, HttpHeaders headers, Result result) throws Exception { try { - Class clazz = ClassUtils.getUserClass(o); + Class clazz = getMarshallerType(o); Marshaller marshaller = createMarshaller(clazz); setCharset(headers.getContentType(), marshaller); marshaller.marshal(o, result); @@ -205,6 +205,15 @@ protected void writeToResult(Object o, HttpHeaders headers, Result result) throw } } + private static Class getMarshallerType(Object o) { + if (o instanceof JAXBElement jaxbElement) { + return jaxbElement.getDeclaredType(); + } + else { + return ClassUtils.getUserClass(o); + } + } + private void setCharset(@Nullable MediaType contentType, Marshaller marshaller) throws PropertyException { if (contentType != null && contentType.getCharset() != null) { marshaller.setProperty(Marshaller.JAXB_ENCODING, contentType.getCharset().name()); diff --git a/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java index a54aeb779464..4c5c1e6cf870 100644 --- a/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java +++ b/spring-web/src/test/java/org/springframework/http/converter/xml/Jaxb2RootElementHttpMessageConverterTests.java @@ -18,6 +18,9 @@ import java.nio.charset.StandardCharsets; +import javax.xml.namespace.QName; + +import jakarta.xml.bind.JAXBElement; import jakarta.xml.bind.Marshaller; import jakarta.xml.bind.Unmarshaller; import jakarta.xml.bind.annotation.XmlAttribute; @@ -93,6 +96,8 @@ void canWrite() { .as("Converter does not support writing @XmlRootElement subclass").isTrue(); assertThat(converter.canWrite(rootElementCglib.getClass(), null)) .as("Converter does not support writing @XmlRootElement subclass").isTrue(); + assertThat(converter.canWrite(JAXBElement.class, null)) + .as("Converter does not support writing JAXBElement").isTrue(); assertThat(converter.canWrite(Type.class, null)) .as("Converter supports writing @XmlType").isFalse(); } @@ -186,6 +191,18 @@ void writeXmlRootElement() throws Exception { .isSimilarTo("", ev); } + @Test + void writeJaxbElementRootElement() throws Exception { + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + JAXBElement jaxbElement = new JAXBElement<>(new QName("custom"), MyCustomElement.class, new MyCustomElement("field1", "field2")); + converter.write(jaxbElement, null, outputMessage); + assertThat(outputMessage.getHeaders().getContentType()) + .as("Invalid content-type").isEqualTo(MediaType.APPLICATION_XML); + DifferenceEvaluator ev = chain(Default, downgradeDifferencesToEqual(XML_STANDALONE)); + assertThat(XmlContent.of(outputMessage.getBodyAsString(StandardCharsets.UTF_8))) + .isSimilarTo("field1field2", ev); + } + @Test void writeXmlRootElementSubclass() throws Exception { MockHttpOutputMessage outputMessage = new MockHttpOutputMessage();