diff --git a/magnolia-core/src/main/java/info/magnolia/cms/filters/ServletDispatchingFilter.java b/magnolia-core/src/main/java/info/magnolia/cms/filters/ServletDispatchingFilter.java index 5ca88ae..9764950 100644 --- a/magnolia-core/src/main/java/info/magnolia/cms/filters/ServletDispatchingFilter.java +++ b/magnolia-core/src/main/java/info/magnolia/cms/filters/ServletDispatchingFilter.java @@ -40,6 +40,7 @@ import java.io.IOException; import java.util.Map; import java.util.regex.Matcher; +import javax.inject.Inject; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.Servlet; @@ -48,6 +49,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; +import info.magnolia.objectfactory.ComponentProvider; import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -65,6 +67,8 @@ public class ServletDispatchingFilter extends AbstractMgnlFilter { static final Logger log = LoggerFactory.getLogger(ServletDispatchingFilter.class); + private final ComponentProvider componentProvider; + private String servletName; private String servletClass; @@ -75,7 +79,9 @@ public class ServletDispatchingFilter extends AbstractMgnlFilter { private Servlet servlet; - public ServletDispatchingFilter() { + @Inject + public ServletDispatchingFilter(ComponentProvider componentProvider) { + this.componentProvider = componentProvider; } @Override @@ -92,7 +98,7 @@ public class ServletDispatchingFilter extends AbstractMgnlFilter { if (servletClass != null) { try { - servlet = Classes.newInstance(servletClass); + servlet = componentProvider.newInstance(Classes.getClassFactory().forName(servletClass)); servlet.init(new CustomServletConfig(servletName, filterConfig.getServletContext(), parameters)); } catch (Throwable e) { diff --git a/magnolia-core/src/test/java/info/magnolia/cms/filters/ServletDispatchingFilterTest.java b/magnolia-core/src/test/java/info/magnolia/cms/filters/ServletDispatchingFilterTest.java index b6c8b6a..54fb1e4 100644 --- a/magnolia-core/src/test/java/info/magnolia/cms/filters/ServletDispatchingFilterTest.java +++ b/magnolia-core/src/test/java/info/magnolia/cms/filters/ServletDispatchingFilterTest.java @@ -44,6 +44,9 @@ import static org.easymock.EasyMock.replay; import static org.easymock.EasyMock.same; import static org.easymock.EasyMock.verify; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import info.magnolia.cms.core.AggregationState; import info.magnolia.cms.util.CustomFilterConfig; import info.magnolia.cms.util.ServletUtils; @@ -57,16 +60,22 @@ import info.magnolia.jcr.node2bean.impl.TypeMappingImpl; import info.magnolia.test.ComponentsTestUtil; import info.magnolia.test.MgnlTestCase; import info.magnolia.test.mock.MockAggregationState; +import info.magnolia.test.mock.MockComponentProvider; import info.magnolia.test.mock.MockHierarchyManager; import info.magnolia.test.mock.MockUtil; import info.magnolia.voting.DefaultVoting; import info.magnolia.voting.Voting; import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; import java.lang.reflect.Field; +import java.util.Collections; +import javax.inject.Inject; import javax.jcr.RepositoryException; import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; import javax.servlet.Servlet; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; @@ -81,6 +90,7 @@ import org.junit.Before; import org.junit.Test; import com.mockrunner.mock.web.MockHttpServletRequest; +import org.mockito.Mockito; /** * @version $Id$ @@ -225,7 +235,7 @@ public class ServletDispatchingFilterTest extends MgnlTestCase { replay(chain, res, req, servlet, ctx); state.setCurrentURI(requestPath); - final AbstractMgnlFilter filter = new ServletDispatchingFilter(); + final AbstractMgnlFilter filter = new ServletDispatchingFilter(new MockComponentProvider()); final Field servletField = ServletDispatchingFilter.class.getDeclaredField("servlet"); servletField.setAccessible(true); servletField.set(filter, servlet); @@ -278,6 +288,34 @@ public class ServletDispatchingFilterTest extends MgnlTestCase { filter.destroy(); } + @Test + public void servletCanHaveInjectedComponents() throws Exception { + // given + final ServletDispatchingFilter filter = new ServletDispatchingFilter(new MockComponentProvider()); + filter.setServletClass(TestInjectedServlet.class.getName()); + filter.addMapping("/"); + filter.setParameters(Collections.emptyMap()); + + // when + final FilterConfig filterConfig = mock(FilterConfig.class); + final HttpServletRequest req = mock(HttpServletRequest.class); + final HttpServletResponse res = mock(HttpServletResponse.class); + final FilterChain chain = mock(FilterChain.class); + when(req.getContextPath()).thenReturn("/does"); + when(req.getRequestURI()).thenReturn("/does/not/matter"); + when(req.getParameter("thing")).thenReturn("check"); + when(req.getAttribute("thing")).thenReturn("use parameter, not attr"); + final StringWriter out = new StringWriter(); + when(res.getWriter()).thenReturn(new PrintWriter(out)); + + filter.init(filterConfig); + filter.doFilter(req, res, chain); + + // then + assertEquals("This is the thing: check", out.getBuffer().toString()); +// Mockito.verifyNoMoreInteractions(chain, res, req); + } + public static class TestServlet extends HttpServlet { @Override @@ -304,4 +342,25 @@ public class ServletDispatchingFilterTest extends MgnlTestCase { assertEquals(null, req.getQueryString()); } } + + public static class TestInjectedServlet extends HttpServlet { + private final Thing thing; + + @Inject + public TestInjectedServlet(Thing thing) { + this.thing = thing; + } + + @Override + protected void service(HttpServletRequest req, HttpServletResponse res) throws ServletException, IOException { + assertNotNull(thing); + res.getWriter().append(thing.getThing(req)); + } + } + + public static class Thing { + public String getThing(HttpServletRequest req) { + return "This is the thing: " + req.getParameter("thing"); + } + } }