Skip to content

Commit 1b30aa8

Browse files
committed
Provide first-class support for @⁠ContextHierarchy with Bean Overrides
Closes spring-projectsgh-34597
1 parent f68fb97 commit 1b30aa8

22 files changed

+972
-37
lines changed

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideContextCustomizerFactory.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,21 @@ class BeanOverrideContextCustomizerFactory implements ContextCustomizerFactory {
4242
public BeanOverrideContextCustomizer createContextCustomizer(Class<?> testClass,
4343
List<ContextConfigurationAttributes> configAttributes) {
4444

45+
String contextName = configAttributes.get(0).getName();
4546
Set<BeanOverrideHandler> handlers = new LinkedHashSet<>();
46-
findBeanOverrideHandlers(testClass, handlers);
47+
findBeanOverrideHandlers(testClass, contextName, handlers);
4748
if (handlers.isEmpty()) {
4849
return null;
4950
}
5051
return new BeanOverrideContextCustomizer(handlers);
5152
}
5253

53-
private void findBeanOverrideHandlers(Class<?> testClass, Set<BeanOverrideHandler> handlers) {
54-
BeanOverrideHandler.findAllHandlers(testClass).forEach(handler ->
55-
Assert.state(handlers.add(handler), () ->
56-
"Duplicate BeanOverrideHandler discovered in test class %s: %s"
57-
.formatted(testClass.getName(), handler)));
54+
private void findBeanOverrideHandlers(Class<?> testClass, @Nullable String contextName, Set<BeanOverrideHandler> handlers) {
55+
BeanOverrideHandler.findAllHandlers(testClass).stream()
56+
.filter(handler -> handler.getContextName().isEmpty() || handler.getContextName().equals(contextName))
57+
.forEach(handler -> Assert.state(handlers.add(handler),
58+
() -> "Duplicate BeanOverrideHandler discovered in test class %s: %s"
59+
.formatted(testClass.getName(), handler)));
5860
}
5961

6062
}

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideHandler.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,26 @@ public abstract class BeanOverrideHandler {
8787
@Nullable
8888
private final String beanName;
8989

90+
private final String contextName;
91+
9092
private final BeanOverrideStrategy strategy;
9193

9294

9395
protected BeanOverrideHandler(@Nullable Field field, ResolvableType beanType, @Nullable String beanName,
9496
BeanOverrideStrategy strategy) {
9597

98+
this(field, beanType, beanName, "", strategy);
99+
}
100+
101+
protected BeanOverrideHandler(@Nullable Field field, ResolvableType beanType, @Nullable String beanName,
102+
String contextName, BeanOverrideStrategy strategy) {
103+
96104
this.field = field;
97105
this.qualifierAnnotations = getQualifierAnnotations(field);
98106
this.beanType = beanType;
99107
this.beanName = beanName;
100108
this.strategy = strategy;
109+
this.contextName = contextName;
101110
}
102111

103112
/**
@@ -238,6 +247,21 @@ public final String getBeanName() {
238247
return this.beanName;
239248
}
240249

250+
/**
251+
* Get the name of the context hierarchy level in which this handler should
252+
* be applied.
253+
* <p>An empty string indicates that this handler should be applied to all
254+
* application contexts within a context hierarchy.
255+
* <p>If a context name is configured for this handler, it must match a name
256+
* configured via {@code @ContextConfiguration(name=...)}.
257+
* @since 6.2.6
258+
* @see org.springframework.test.context.ContextHierarchy @ContextHierarchy
259+
* @see org.springframework.test.context.ContextConfiguration#name()
260+
*/
261+
public final String getContextName() {
262+
return this.contextName;
263+
}
264+
241265
/**
242266
* Get the {@link BeanOverrideStrategy} for this {@code BeanOverrideHandler},
243267
* which influences how and when the bean override instance should be created.
@@ -341,6 +365,7 @@ public String toString() {
341365
.append("field", this.field)
342366
.append("beanType", this.beanType)
343367
.append("beanName", this.beanName)
368+
.append("contextName", this.contextName)
344369
.append("strategy", this.strategy)
345370
.toString();
346371
}

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideRegistry.java

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@
2020
import java.util.LinkedHashMap;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.Map.Entry;
23+
import java.util.Objects;
2424

2525
import org.apache.commons.logging.Log;
2626
import org.apache.commons.logging.LogFactory;
2727

2828
import org.springframework.beans.factory.BeanCreationException;
29+
import org.springframework.beans.factory.BeanFactory;
2930
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
31+
import org.springframework.lang.Nullable;
3032
import org.springframework.util.Assert;
3133
import org.springframework.util.ReflectionUtils;
32-
import org.springframework.util.StringUtils;
34+
35+
import static org.springframework.test.context.bean.override.BeanOverrideContextCustomizer.REGISTRY_BEAN_NAME;
3336

3437
/**
3538
* An internal class used to track {@link BeanOverrideHandler}-related state after
@@ -45,16 +48,22 @@ class BeanOverrideRegistry {
4548
private static final Log logger = LogFactory.getLog(BeanOverrideRegistry.class);
4649

4750

48-
private final Map<BeanOverrideHandler, String> handlerToBeanNameMap = new LinkedHashMap<>();
51+
private final Map<HandlerCacheKey, String> handlerToBeanNameMap = new LinkedHashMap<>();
4952

5053
private final Map<String, BeanOverrideHandler> wrappingBeanOverrideHandlers = new LinkedHashMap<>();
5154

5255
private final ConfigurableBeanFactory beanFactory;
5356

57+
@Nullable
58+
private final BeanOverrideRegistry parent;
59+
5460

5561
BeanOverrideRegistry(ConfigurableBeanFactory beanFactory) {
5662
Assert.notNull(beanFactory, "ConfigurableBeanFactory must not be null");
5763
this.beanFactory = beanFactory;
64+
BeanFactory parentBeanFactory = beanFactory.getParentBeanFactory();
65+
this.parent = (parentBeanFactory != null && parentBeanFactory.containsBean(REGISTRY_BEAN_NAME) ?
66+
parentBeanFactory.getBean(REGISTRY_BEAN_NAME, BeanOverrideRegistry.class) : null);
5867
}
5968

6069
/**
@@ -65,20 +74,21 @@ class BeanOverrideRegistry {
6574
* bean via {@link #wrapBeanIfNecessary(Object, String)}.
6675
*/
6776
void registerBeanOverrideHandler(BeanOverrideHandler handler, String beanName) {
68-
Assert.state(!this.handlerToBeanNameMap.containsKey(handler), () ->
77+
HandlerCacheKey handlerKey = HandlerCacheKey.from(handler);
78+
Assert.state(!this.handlerToBeanNameMap.containsKey(handlerKey), () ->
6979
"Cannot register BeanOverrideHandler for bean with name '%s'; detected multiple registrations for %s"
7080
.formatted(beanName, handler));
7181

7282
// Check if beanName was already registered, before adding the new mapping.
7383
boolean beanNameAlreadyRegistered = this.handlerToBeanNameMap.containsValue(beanName);
7484
// Add new mapping before potentially logging a warning, to ensure that
7585
// the current handler is logged as well.
76-
this.handlerToBeanNameMap.put(handler, beanName);
86+
this.handlerToBeanNameMap.put(handlerKey, beanName);
7787

7888
if (beanNameAlreadyRegistered && logger.isWarnEnabled()) {
7989
List<BeanOverrideHandler> competingHandlers = this.handlerToBeanNameMap.entrySet().stream()
8090
.filter(entry -> entry.getValue().equals(beanName))
81-
.map(Entry::getKey)
91+
.map(entry -> entry.getKey().handler)
8292
.toList();
8393
logger.warn("Bean with name '%s' was overridden by multiple handlers: %s"
8494
.formatted(beanName, competingHandlers));
@@ -110,14 +120,13 @@ Object wrapBeanIfNecessary(Object bean, String beanName) {
110120
void inject(Object target, BeanOverrideHandler handler) {
111121
Field field = handler.getField();
112122
Assert.notNull(field, () -> "BeanOverrideHandler must have a non-null field: " + handler);
113-
String beanName = this.handlerToBeanNameMap.get(handler);
114-
Assert.state(StringUtils.hasLength(beanName), () -> "No bean found for BeanOverrideHandler: " + handler);
115-
inject(field, target, beanName);
123+
Object bean = getBeanForHandler(handler, field.getType());
124+
Assert.state(bean != null, () -> "No bean found for BeanOverrideHandler: " + handler);
125+
inject(field, target, bean);
116126
}
117127

118-
private void inject(Field field, Object target, String beanName) {
128+
private void inject(Field field, Object target, Object bean) {
119129
try {
120-
Object bean = this.beanFactory.getBean(beanName, field.getType());
121130
ReflectionUtils.makeAccessible(field);
122131
ReflectionUtils.setField(field, target, bean);
123132
}
@@ -126,4 +135,56 @@ private void inject(Field field, Object target, String beanName) {
126135
}
127136
}
128137

138+
@Nullable
139+
private Object getBeanForHandler(BeanOverrideHandler handler, Class<?> requiredType) {
140+
String beanName = this.handlerToBeanNameMap.get(HandlerCacheKey.from(handler));
141+
if (beanName != null) {
142+
return this.beanFactory.getBean(beanName, requiredType);
143+
}
144+
if (this.parent != null) {
145+
return this.parent.getBeanForHandler(handler, requiredType);
146+
}
147+
return null;
148+
}
149+
150+
151+
/**
152+
* Cache key for a {@link BeanOverrideHandler}, which also takes the
153+
* {@linkplain BeanOverrideHandler#getContextName() context name} into account.
154+
*
155+
* @since 6.2.6
156+
*/
157+
private static final class HandlerCacheKey {
158+
159+
static HandlerCacheKey from(BeanOverrideHandler handler) {
160+
return new HandlerCacheKey(handler);
161+
}
162+
163+
164+
private final BeanOverrideHandler handler;
165+
166+
167+
private HandlerCacheKey(BeanOverrideHandler handler) {
168+
this.handler = handler;
169+
}
170+
171+
172+
@Override
173+
public boolean equals(@Nullable Object other) {
174+
return (this == other || (other instanceof HandlerCacheKey that &&
175+
this.handler.equals(that.handler) &&
176+
Objects.equals(this.handler.getContextName(), that.handler.getContextName())));
177+
}
178+
179+
@Override
180+
public int hashCode() {
181+
return this.handler.hashCode() + (29 * this.handler.getContextName().hashCode());
182+
}
183+
184+
@Override
185+
public String toString() {
186+
return this.handler.toString();
187+
}
188+
}
189+
129190
}

spring-test/src/main/java/org/springframework/test/context/bean/override/convention/TestBean.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@
164164
*/
165165
String methodName() default "";
166166

167+
/**
168+
* The name of the context hierarchy level in which this {@code @TestBean}
169+
* should be applied.
170+
* <p>Defaults to an empty string which indicates that this {@code @TestBean}
171+
* should be applied to all application contexts within a context hierarchy.
172+
* <p>If a context name is configured, it must match a name configured via
173+
* {@code @ContextConfiguration(name=...)}.
174+
* @since 6.2.6
175+
* @see org.springframework.test.context.ContextHierarchy @ContextHierarchy
176+
* @see org.springframework.test.context.ContextConfiguration#name()
177+
*/
178+
String contextName() default "";
179+
167180
/**
168181
* Whether to require the existence of the bean being overridden.
169182
* <p>Defaults to {@code false} which means that a bean will be created if a

spring-test/src/main/java/org/springframework/test/context/bean/override/convention/TestBeanOverrideHandler.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ final class TestBeanOverrideHandler extends BeanOverrideHandler {
4343

4444

4545
TestBeanOverrideHandler(Field field, ResolvableType beanType, @Nullable String beanName,
46-
BeanOverrideStrategy strategy, Method factoryMethod) {
46+
String contextName, BeanOverrideStrategy strategy, Method factoryMethod) {
4747

48-
super(field, beanType, beanName, strategy);
48+
super(field, beanType, beanName, contextName, strategy);
4949
this.factoryMethod = factoryMethod;
5050
}
5151

@@ -90,6 +90,7 @@ public String toString() {
9090
.append("field", getField())
9191
.append("beanType", getBeanType())
9292
.append("beanName", getBeanName())
93+
.append("contextName", getContextName())
9394
.append("strategy", getStrategy())
9495
.append("factoryMethod", this.factoryMethod)
9596
.toString();

spring-test/src/main/java/org/springframework/test/context/bean/override/convention/TestBeanOverrideProcessor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public TestBeanOverrideHandler createHandler(Annotation overrideAnnotation, Clas
8282
}
8383

8484
return new TestBeanOverrideHandler(
85-
field, ResolvableType.forField(field, testClass), beanName, strategy, factoryMethod);
85+
field, ResolvableType.forField(field, testClass), beanName, testBean.contextName(), strategy, factoryMethod);
8686
}
8787

8888
/**

spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/AbstractMockitoBeanOverrideHandler.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ abstract class AbstractMockitoBeanOverrideHandler extends BeanOverrideHandler {
3939

4040

4141
protected AbstractMockitoBeanOverrideHandler(@Nullable Field field, ResolvableType beanType,
42-
@Nullable String beanName, BeanOverrideStrategy strategy, MockReset reset) {
42+
@Nullable String beanName, String contextName, BeanOverrideStrategy strategy,
43+
MockReset reset) {
4344

44-
super(field, beanType, beanName, strategy);
45+
super(field, beanType, beanName, contextName, strategy);
4546
this.reset = (reset != null ? reset : MockReset.AFTER);
4647
}
4748

@@ -92,6 +93,7 @@ public String toString() {
9293
.append("field", getField())
9394
.append("beanType", getBeanType())
9495
.append("beanName", getBeanName())
96+
.append("contextName", getContextName())
9597
.append("strategy", getStrategy())
9698
.append("reset", getReset())
9799
.toString();

spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBean.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,19 @@
144144
*/
145145
Class<?>[] types() default {};
146146

147+
/**
148+
* The name of the context hierarchy level in which this {@code @MockitoBean}
149+
* should be applied.
150+
* <p>Defaults to an empty string which indicates that this {@code @MockitoBean}
151+
* should be applied to all application contexts within a context hierarchy.
152+
* <p>If a context name is configured, it must match a name configured via
153+
* {@code @ContextConfiguration(name=...)}.
154+
* @since 6.2.6
155+
* @see org.springframework.test.context.ContextHierarchy @ContextHierarchy
156+
* @see org.springframework.test.context.ContextConfiguration#name()
157+
*/
158+
String contextName() default "";
159+
147160
/**
148161
* Extra interfaces that should also be declared by the mock.
149162
* <p>Defaults to none.

spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoBeanOverrideHandler.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ class MockitoBeanOverrideHandler extends AbstractMockitoBeanOverrideHandler {
6363

6464
MockitoBeanOverrideHandler(@Nullable Field field, ResolvableType typeToMock, MockitoBean mockitoBean) {
6565
this(field, typeToMock, (!mockitoBean.name().isBlank() ? mockitoBean.name() : null),
66-
(mockitoBean.enforceOverride() ? REPLACE : REPLACE_OR_CREATE),
67-
mockitoBean.reset(), mockitoBean.extraInterfaces(), mockitoBean.answers(), mockitoBean.serializable());
66+
mockitoBean.contextName(), (mockitoBean.enforceOverride() ? REPLACE : REPLACE_OR_CREATE),
67+
mockitoBean.reset(), mockitoBean.extraInterfaces(), mockitoBean.answers(), mockitoBean.serializable());
6868
}
6969

7070
private MockitoBeanOverrideHandler(@Nullable Field field, ResolvableType typeToMock, @Nullable String beanName,
71-
BeanOverrideStrategy strategy, MockReset reset, Class<?>[] extraInterfaces, Answers answers,
72-
boolean serializable) {
71+
String contextName, BeanOverrideStrategy strategy, MockReset reset, Class<?>[] extraInterfaces,
72+
Answers answers, boolean serializable) {
7373

74-
super(field, typeToMock, beanName, strategy, reset);
74+
super(field, typeToMock, beanName, contextName, strategy, reset);
7575
Assert.notNull(typeToMock, "'typeToMock' must not be null");
7676
this.extraInterfaces = asClassSet(extraInterfaces);
7777
this.answers = answers;
@@ -160,6 +160,7 @@ public String toString() {
160160
.append("field", getField())
161161
.append("beanType", getBeanType())
162162
.append("beanName", getBeanName())
163+
.append("contextName", getContextName())
163164
.append("strategy", getStrategy())
164165
.append("reset", getReset())
165166
.append("extraInterfaces", getExtraInterfaces())

spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBean.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,19 @@
136136
*/
137137
Class<?>[] types() default {};
138138

139+
/**
140+
* The name of the context hierarchy level in which this {@code @MockitoSpyBean}
141+
* should be applied.
142+
* <p>Defaults to an empty string which indicates that this {@code @MockitoSpyBean}
143+
* should be applied to all application contexts within a context hierarchy.
144+
* <p>If a context name is configured, it must match a name configured via
145+
* {@code @ContextConfiguration(name=...)}.
146+
* @since 6.2.6
147+
* @see org.springframework.test.context.ContextHierarchy @ContextHierarchy
148+
* @see org.springframework.test.context.ContextConfiguration#name()
149+
*/
150+
String contextName() default "";
151+
139152
/**
140153
* The reset mode to apply to the spied bean.
141154
* <p>The default is {@link MockReset#AFTER} meaning that spies are automatically

spring-test/src/main/java/org/springframework/test/context/bean/override/mockito/MockitoSpyBeanOverrideHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class MockitoSpyBeanOverrideHandler extends AbstractMockitoBeanOverrideHandler {
5454

5555
MockitoSpyBeanOverrideHandler(@Nullable Field field, ResolvableType typeToSpy, MockitoSpyBean spyBean) {
5656
super(field, typeToSpy, (StringUtils.hasText(spyBean.name()) ? spyBean.name() : null),
57-
BeanOverrideStrategy.WRAP, spyBean.reset());
57+
spyBean.contextName(), BeanOverrideStrategy.WRAP, spyBean.reset());
5858
Assert.notNull(typeToSpy, "typeToSpy must not be null");
5959
}
6060

0 commit comments

Comments
 (0)