[Original] Writing Your Own Service Gateway

Published: 2019-01-17

Introduction

What is a gateway, and why do we need one?

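As shown in the figure, without a gateway our services are exposed directly to their callers. As the number of callers grows, we inevitably have to add custom access control, validation and similar logic. Adding an API gateway puts up a wall between third-party callers and the service providers. This wall talks to callers directly and enforces access control. The gateway in this article copies... oh, no, make that "borrows from" the source code of the Zuul gateway. It distills Zuul's core idea into a simple gateway, which I have renamed Eatuul.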

A Digression

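As far as I could find by searching, this is the first article in the industry that walks through implementing a gateway by hand. My hands-on series aims to expose the core principles of a piece of middleware in the simplest possible way, so readers can quickly grasp the heart of the implementation. Note that this is not a source-code analysis series. The code here leaves out some complex parts, because it is enough for you to understand the core principles of the middleware. If you want source-code analysis, follow me: later I will cover open-source frameworks such as Spring, Spring Boot, Dubbo and MyBatis one by one.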

Main Content

Design

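In outline, a Servlet receives each request. The request then passes through three filters:

- a preFilter, which wraps the request parameters;
- a routeFilter, which forwards the request;
- a postFilter, which writes the output.

The three filters share the request, the response and a few other global variables, as shown in the figure below.

How does this differ from the real Zuul? The main differences are:
(1) Zuul has an ErrorFilter in its exception-handling module. I was lazy and left it out.
(2) Zuul ships a default set of PreFilters, RoutingFilters and PostFilters, listed in the table below. I can hardly implement every one of them for you, so I cut corners and implemented only one of each. The invocation order is unchanged, though: PreFilters -> RoutingFilters -> PostFilters.
(3) Zuul's routeFilters do include filters that forward requests, but I pulled a switch and used RestTemplate instead.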
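The flow above can be sketched in plain Java with no servlet or Spring dependency. This is only an illustration under my own assumptions:

- `Phase`, `MiniPipeline` and the `path`/`target`/`upstream`/`output` keys are hypothetical names.
- Each filter is a lambda over one shared map, which plays the role of RequestContext.
- However the filters are registered, they run in the order PRE -> ROUTE -> POST.

```java
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

// Hypothetical phase names; the declaration order is the run order
enum Phase { PRE, ROUTE, POST }

// Hypothetical mini gateway pipeline: filters are lambdas over one shared map,
// playing the role that RequestContext plays in Eatuul
class MiniPipeline {
    private final Map<Phase, List<Consumer<Map<String, Object>>>> filters = new EnumMap<>(Phase.class);

    void register(Phase phase, Consumer<Map<String, Object>> filter) {
        filters.computeIfAbsent(phase, p -> new ArrayList<>()).add(filter);
    }

    // Runs PRE, then ROUTE, then POST, whatever order the filters were registered in
    Map<String, Object> handle(String path) {
        Map<String, Object> shared = new HashMap<>();
        shared.put("path", path);
        for (Phase phase : Phase.values()) {
            for (Consumer<Map<String, Object>> filter : filters.getOrDefault(phase, List.of())) {
                filter.accept(shared);
            }
        }
        return shared;
    }
}
```

For example, you might register a POST filter first, then a ROUTE filter, then a PRE filter, and call `handle("/index")`. PRE still runs first, so the POST filter sees the value that ROUTE computed from the target URL PRE stored.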

Code Structure

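Generate a Spring Boot project from the Spring website. I won't show the pom here. The project structure is shown in the figure below.

EatuulServlet.java

This is the entry point of the gateway. Its logic is simple and has three steps:
(1) put the request and response into a ThreadLocal;
(2) run the three groups of filters;
(3) clear the variables held in the ThreadLocal.
The source code: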

package com.rjzheng.eatuul.http;

import java.io.IOException;

import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

@WebServlet(name = "eatuul", urlPatterns = "/*")
public class EatuulServlet extends HttpServlet {

    private EatRunner eatRunner = new EatRunner();

    @Override
    public void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        // Put the request and response into the context object
        eatRunner.init(req, resp);
        try {
            // Run the pre filters
            eatRunner.preRoute();
            // Run the route filters
            eatRunner.route();
            // Run the post filters
            eatRunner.postRoute();
        } catch (Throwable e) {
            RequestContext.getCurrentContext().getResponse()
                    .sendError(HttpServletResponse.SC_NOT_FOUND, e.getMessage());
        } finally {
            // Clear the thread-local variables
            RequestContext.getCurrentContext().unset();
        }
    }
}
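Why does step (3) matter? Servlet containers such as Tomcat typically serve requests from a thread pool, so a thread that handled one request will later handle another. The sketch below shows what a second request sees with and without the cleanup in `finally`. It assumes:

- a single reused worker thread, with a single-thread executor standing in for the container's pool;
- a hypothetical `CURRENT_USER` slot in place of RequestContext.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class ThreadLocalLeakDemo {
    // Hypothetical per-request slot standing in for RequestContext
    static final ThreadLocal<String> CURRENT_USER = new ThreadLocal<>();

    // One simulated request: reports what was already in the slot, then stores its own user
    static String handle(String user, boolean cleanUp) {
        String leftover = CURRENT_USER.get();
        try {
            if (user != null) {
                CURRENT_USER.set(user);
            }
            return leftover;
        } finally {
            if (cleanUp) {
                CURRENT_USER.remove(); // the equivalent of EatuulServlet calling unset()
            }
        }
    }

    // Two requests back to back on one reused worker thread; returns what the second one saw
    static String secondRequestSees(boolean cleanUp) {
        ExecutorService worker = Executors.newSingleThreadExecutor();
        try {
            worker.submit(() -> handle("alice", cleanUp)).get();
            return worker.submit(() -> handle(null, cleanUp)).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            worker.shutdown();
        }
    }
}
```

`secondRequestSees(false)` returns "alice", the previous request's value. `secondRequestSees(true)` returns null.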

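EatRunner.java

This is the actual executor. In Zuul, the filters ZuulRunner executes are not fixed: a file loader can load them dynamically from configuration. In our own runner (the class is named EatRunner) I skipped the dynamic loading and hard-coded the filters. The source code: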

package com.rjzheng.eatuul.http;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.filter.post.SendResponseFilter;
import com.rjzheng.eatuul.filter.pre.RequestWrapperFilter;
import com.rjzheng.eatuul.filter.route.RoutingFilter;

public class EatRunner {

    // Filters are hard-coded here instead of being loaded dynamically
    private ConcurrentHashMap<String, List<EatuulFilter>> hashFiltersByType = new ConcurrentHashMap<String, List<EatuulFilter>>() {{
        put("pre", new ArrayList<EatuulFilter>() {{
            add(new RequestWrapperFilter());
        }});
        put("route", new ArrayList<EatuulFilter>() {{
            add(new RoutingFilter());
        }});
        put("post", new ArrayList<EatuulFilter>() {{
            add(new SendResponseFilter());
        }});
    }};

    public void init(HttpServletRequest req, HttpServletResponse resp) {
        RequestContext ctx = RequestContext.getCurrentContext();
        ctx.setRequest(req);
        ctx.setResponse(resp);
    }

    public void preRoute() throws Throwable {
        runFilters("pre");
    }

    public void route() throws Throwable {
        runFilters("route");
    }

    public void postRoute() throws Throwable {
        runFilters("post");
    }

    public void runFilters(String sType) throws Throwable {
        List<EatuulFilter> list = this.hashFiltersByType.get(sType);
        if (list != null) {
            for (int i = 0; i < list.size(); i++) {
                EatuulFilter zuulFilter = list.get(i);
                zuulFilter.run();
            }
        }
    }
}
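EatRunner declares `filterOrder()` but never uses it. That is harmless while each type has a single filter. Zuul, by contrast, sorts the filters of each type by `filterOrder()` before running them.

Below is a minimal sketch of that idea, under these assumptions:

- `SketchFilter` mirrors EatuulFilter's methods, except that `run` takes a trace list so the order is observable.
- `OrderedRunner` and `named(...)` are hypothetical names, not Zuul's API.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Mirrors EatuulFilter's shape; run() takes a trace list so the order is observable
abstract class SketchFilter {
    abstract String filterType();
    abstract int filterOrder();
    abstract void run(List<String> trace);
}

// Hypothetical runner that keeps each type's filters sorted by filterOrder() (lowest first)
class OrderedRunner {
    private final Map<String, List<SketchFilter>> byType = new HashMap<>();

    void register(SketchFilter filter) {
        List<SketchFilter> list = byType.computeIfAbsent(filter.filterType(), t -> new ArrayList<>());
        list.add(filter);
        list.sort(Comparator.comparingInt(SketchFilter::filterOrder));
    }

    void runFilters(String type, List<String> trace) {
        for (SketchFilter filter : byType.getOrDefault(type, List.of())) {
            filter.run(trace);
        }
    }

    // Hypothetical helper: a filter that only records its name when run
    static SketchFilter named(String type, int order, String name) {
        return new SketchFilter() {
            @Override String filterType() { return type; }
            @Override int filterOrder() { return order; }
            @Override void run(List<String> trace) { trace.add(name); }
        };
    }
}
```

Suppose you register a hypothetical "auth" pre filter with order 10 and a "wrap" pre filter with order -1. "wrap" runs first, whatever order they were registered in.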

EatuulFilter.java

Next come the filters. First, their parent class, EatuulFilter:

package com.rjzheng.eatuul.filter;

public abstract class EatuulFilter {

    public abstract String filterType();

    public abstract int filterOrder();

    public abstract void run();
}

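RequestWrapperFilter.java

This is the PreFilter. It runs first and wraps the request in four steps:
(1) wrap the request headers;
(2) wrap the request body;
(3) build a RequestEntity that RestTemplate understands;
(4) put the RequestEntity into the global ThreadLocal.
The code: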

package com.rjzheng.eatuul.filter.pre;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.List;

import javax.servlet.http.HttpServletRequest;

import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StreamUtils;

import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;

public class RequestWrapperFilter extends EatuulFilter {

    @Override
    public String filterType() {
        return "pre";
    }

    @Override
    public int filterOrder() {
        return -1;
    }

    @Override
    public void run() {
        String rootURL = "http://localhost:9090";
        RequestContext ctx = RequestContext.getCurrentContext();
        HttpServletRequest servletRequest = ctx.getRequest();
        String targetURL = rootURL + servletRequest.getRequestURI();
        // getRequestURI() excludes the query string, so re-attach it
        String queryString = servletRequest.getQueryString();
        if (queryString != null) {
            targetURL += "?" + queryString;
        }
        RequestEntity<byte[]> requestEntity = null;
        try {
            requestEntity = createRequestEntity(servletRequest, targetURL);
        } catch (Exception e) {
            e.printStackTrace();
        }
        // 4. Put the RequestEntity into the global ThreadLocal
        ctx.setRequestEntity(requestEntity);
    }

    private RequestEntity<byte[]> createRequestEntity(HttpServletRequest request, String url)
            throws URISyntaxException, IOException {
        String method = request.getMethod();
        HttpMethod httpMethod = HttpMethod.resolve(method);
        // 1. Wrap the request headers
        MultiValueMap<String, String> headers = createRequestHeaders(request);
        // 2. Wrap the request body
        byte[] body = createRequestBody(request);
        // 3. Build a RequestEntity that RestTemplate understands
        return new RequestEntity<byte[]>(body, headers, httpMethod, new URI(url));
    }

    private byte[] createRequestBody(HttpServletRequest request) throws IOException {
        InputStream inputStream = request.getInputStream();
        return StreamUtils.copyToByteArray(inputStream);
    }

    private MultiValueMap<String, String> createRequestHeaders(HttpServletRequest request) {
        HttpHeaders headers = new HttpHeaders();
        List<String> headerNames = Collections.list(request.getHeaderNames());
        for (String headerName : headerNames) {
            List<String> headerValues = Collections.list(request.getHeaders(headerName));
            for (String headerValue : headerValues) {
                headers.add(headerName, headerValue);
            }
        }
        return headers;
    }
}
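`getRequestURI()` returns only the path. The target URL therefore has to re-attach `getQueryString()`, or parameters such as `?id=1` are lost on the way to the backend. Here is that step on its own in plain Java. `TargetUrls.target` is a hypothetical helper; its last two arguments stand for the two servlet getters.

```java
import java.net.URI;

// Hypothetical helper isolating RequestWrapperFilter's URL step
class TargetUrls {
    // requestUri and queryString stand for HttpServletRequest.getRequestURI() and getQueryString();
    // the query string arrives undecoded, so it can be appended verbatim
    static URI target(String rootUrl, String requestUri, String queryString) {
        String url = rootUrl + requestUri;
        if (queryString != null && !queryString.isEmpty()) {
            url += "?" + queryString;
        }
        return URI.create(url);
    }
}
```

For instance, `target("http://localhost:9090", "/index", "x=1")` yields `http://localhost:9090/index?x=1`. Calling `URI.getQuery()` on the result returns the decoded form.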

RoutingFilter.java

This is the routeFilter. I cut corners here: it simply forwards the request and puts the returned ResponseEntity into the global ThreadLocal.

package com.rjzheng.eatuul.filter.route;

import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;

import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;

public class RoutingFilter extends EatuulFilter {

    // RestTemplate is thread-safe once constructed, so one instance can serve every request
    private final RestTemplate restTemplate = new RestTemplate();

    @Override
    public String filterType() {
        return "route";
    }

    @Override
    public int filterOrder() {
        return 0;
    }

    @Override
    public void run() {
        RequestContext ctx = RequestContext.getCurrentContext();
        RequestEntity<?> requestEntity = ctx.getRequestEntity();
        // Forward the request to the backend and keep the raw bytes of the reply
        ResponseEntity<byte[]> responseEntity = restTemplate.exchange(requestEntity, byte[].class);
        // Hand the backend response to the post filter through the context
        ctx.setResponseEntity(responseEntity);
    }
}
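RoutingFilter depends on Spring's RestTemplate. The sketch below shows the same forwarding step without Spring. It makes these substitutions and assumptions:

- the JDK's own `java.net.http.HttpClient` (Java 11+) replaces RestTemplate;
- the JDK's built-in `com.sun.net.httpserver.HttpServer` acts as a throwaway backend answering `/index` with `hello!world`, in place of the port-9090 service;
- it only handles GET;
- the class name is hypothetical.

```java
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

class JdkForwardSketch {
    // One shared client, pinned to HTTP/1.1 for simplicity
    private static final HttpClient CLIENT =
            HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build();

    // Forwards a GET to the backend and returns status + body, like RoutingFilter's exchange(...)
    static HttpResponse<byte[]> forward(URI target) throws IOException, InterruptedException {
        HttpRequest request = HttpRequest.newBuilder(target).GET().build();
        return CLIENT.send(request, HttpResponse.BodyHandlers.ofByteArray());
    }

    // Starts a throwaway backend on a free local port that answers /index with "hello!world"
    static HttpServer startBackend() throws IOException {
        HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0);
        server.createContext("/index", exchange -> {
            byte[] body = "hello!world".getBytes(StandardCharsets.UTF_8);
            exchange.sendResponseHeaders(200, body.length);
            try (OutputStream out = exchange.getResponseBody()) {
                out.write(body);
            }
        });
        server.start();
        return server;
    }

    // Starts the backend, forwards one request to it, and returns "<status> <body>"
    static String roundTrip() {
        try {
            HttpServer backend = startBackend();
            try {
                int port = backend.getAddress().getPort();
                HttpResponse<byte[]> response = forward(URI.create("http://127.0.0.1:" + port + "/index"));
                return response.statusCode() + " " + new String(response.body(), StandardCharsets.UTF_8);
            } finally {
                backend.stop(0);
            }
        } catch (IOException | InterruptedException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

`roundTrip()` should return `200 hello!world`. Pinning the client to HTTP/1.1 is a simplifying choice; the default client prefers HTTP/2.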

SendResponseFilter.java

This is the postFilter. It simply writes the ResponseEntity out to the client.

package com.rjzheng.eatuul.filter.post;

import java.util.List;
import java.util.Map;

import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;

import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;

import com.rjzheng.eatuul.filter.EatuulFilter;
import com.rjzheng.eatuul.http.RequestContext;

public class SendResponseFilter extends EatuulFilter {

    @Override
    public String filterType() {
        return "post";
    }

    @Override
    public int filterOrder() {
        return 1000;
    }

    @Override
    public void run() {
        try {
            addResponseHeaders();
            writeResponse();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void addResponseHeaders() {
        RequestContext ctx = RequestContext.getCurrentContext();
        HttpServletResponse servletResponse = ctx.getResponse();
        ResponseEntity<?> responseEntity = ctx.getResponseEntity();
        // Pass the backend's status code through (e.g. 201), not just its headers
        servletResponse.setStatus(responseEntity.getStatusCodeValue());
        HttpHeaders httpHeaders = responseEntity.getHeaders();
        for (Map.Entry<String, List<String>> entry : httpHeaders.entrySet()) {
            String headerName = entry.getKey();
            List<String> headerValues = entry.getValue();
            for (String headerValue : headerValues) {
                servletResponse.addHeader(headerName, headerValue);
            }
        }
    }

    private void writeResponse() throws Exception {
        RequestContext ctx = RequestContext.getCurrentContext();
        HttpServletResponse servletResponse = ctx.getResponse();
        if (servletResponse.getCharacterEncoding() == null) { // only set if not set
            servletResponse.setCharacterEncoding("UTF-8");
        }
        ResponseEntity<?> responseEntity = ctx.getResponseEntity();
        if (responseEntity.hasBody()) {
            byte[] body = (byte[]) responseEntity.getBody();
            ServletOutputStream outputStream = servletResponse.getOutputStream();
            outputStream.write(body);
            outputStream.flush();
        }
    }
}
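SendResponseFilter copies every backend header onto the servlet response. Proxies usually drop hop-by-hop headers such as `Transfer-Encoding` and `Connection`, because they describe only the connection between the gateway and the backend. Copying something like `Transfer-Encoding: chunked` while the container does its own framing can cause trouble.

The plain-Java sketch below shows that filtering step. It assumes the header set listed in RFC 2616 section 13.5.1, plus any header named inside `Connection`. `HopByHop.endToEnd` is a hypothetical helper, not part of the original code.

```java
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

// Hypothetical helper: keeps only end-to-end headers from a backend response
class HopByHop {
    // Hop-by-hop headers as listed in RFC 2616 section 13.5.1, lower-cased
    private static final Set<String> HOP_BY_HOP = Set.of(
            "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
            "te", "trailer", "transfer-encoding", "upgrade");

    static Map<String, List<String>> endToEnd(Map<String, List<String>> upstream) {
        Set<String> drop = new HashSet<>(HOP_BY_HOP);
        // Headers named inside Connection are hop-by-hop for this connection as well
        for (Map.Entry<String, List<String>> e : upstream.entrySet()) {
            if ("connection".equalsIgnoreCase(e.getKey())) {
                for (String value : e.getValue()) {
                    for (String token : value.split(",")) {
                        drop.add(token.trim().toLowerCase(Locale.ROOT));
                    }
                }
            }
        }
        Map<String, List<String>> result = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> e : upstream.entrySet()) {
            // Skip null keys (e.g. the status line in HttpURLConnection's header map) and dropped names
            if (e.getKey() != null && !drop.contains(e.getKey().toLowerCase(Locale.ROOT))) {
                result.put(e.getKey(), e.getValue());
            }
        }
        return result;
    }
}
```

In SendResponseFilter, the loop in `addResponseHeaders()` would then iterate over the filtered map instead of the raw `HttpHeaders`.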

RequestContext.java

Last comes the global ThreadLocal variable I have been mentioning all along:

package com.rjzheng.eatuul.http;

import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;

public class RequestContext extends ConcurrentHashMap<String, Object> {

    protected static Class<? extends RequestContext> contextClass = RequestContext.class;

    protected static final ThreadLocal<? extends RequestContext> threadLocal = new ThreadLocal<RequestContext>() {
        @Override
        protected RequestContext initialValue() {
            try {
                return contextClass.newInstance();
            } catch (Throwable e) {
                throw new RuntimeException(e);
            }
        }
    };

    public static RequestContext getCurrentContext() {
        RequestContext context = threadLocal.get();
        return context;
    }

    public HttpServletRequest getRequest() {
        return (HttpServletRequest) get("request");
    }

    public void setRequest(HttpServletRequest request) {
        put("request", request);
    }

    public HttpServletResponse getResponse() {
        return (HttpServletResponse) get("response");
    }

    public void setResponse(HttpServletResponse response) {
        set("response", response);
    }

    public void setRequestEntity(RequestEntity requestEntity) {
        set("requestEntity", requestEntity);
    }

    public RequestEntity getRequestEntity() {
        return (RequestEntity) get("requestEntity");
    }

    public void setResponseEntity(ResponseEntity responseEntity) {
        set("responseEntity", responseEntity);
    }

    public ResponseEntity getResponseEntity() {
        return (ResponseEntity) get("responseEntity");
    }

    public void set(String key, Object value) {
        if (value != null) {
            put(key, value);
        } else {
            remove(key);
        }
    }

    public void unset() {
        threadLocal.remove();
    }
}
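RequestContext combines two ideas. First, each thread gets its own map through ThreadLocal. Second, ConcurrentHashMap rejects null values, so `set()` turns a null into a removal.

The stripped-down sketch below shows both behaviors in plain Java. `MiniContext` is a hypothetical name. It uses `ThreadLocal.withInitial` in place of the anonymous subclass and reflective `newInstance()`, which is only a stylistic substitution.

```java
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, stripped-down stand-in for RequestContext: one ConcurrentHashMap per thread
class MiniContext extends ConcurrentHashMap<String, Object> {
    private static final ThreadLocal<MiniContext> CURRENT = ThreadLocal.withInitial(MiniContext::new);

    static MiniContext current() {
        return CURRENT.get();
    }

    // ConcurrentHashMap rejects null values, so a null here means "remove the key"
    void set(String key, Object value) {
        if (value != null) {
            put(key, value);
        } else {
            remove(key);
        }
    }

    static void unset() {
        CURRENT.remove();
    }

    // Reads a key from the context of a freshly started thread
    static Object seenFromOtherThread(String key) {
        Object[] seen = new Object[1];
        Thread t = new Thread(() -> seen[0] = current().get(key));
        t.start();
        try {
            t.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen[0];
    }
}
```

The sketch shows three behaviors:

- A value set on the current thread is invisible from a newly started thread: `seenFromOtherThread` returns null.
- `set(key, null)` removes the key.
- A direct `put(key, null)` throws NullPointerException.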

How to Test It

Start a separate server of your own on port 9090, as follows:

package com.rjzheng.eatservice;

import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.boot.web.servlet.ServletComponentScan;

import com.rjzheng.eatservice.controller.IndexController;

@SpringBootApplication
@ServletComponentScan(basePackageClasses = IndexController.class)
public class Application {

    public static void main(String[] args) {
        new SpringApplicationBuilder(Application.class).properties("server.port=9090").run(args);
    }
}

Then add a controller:

package com.rjzheng.eatservice.controller;

import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class IndexController {

    @RequestMapping("/index")
    public String index() {
        return "hello!world";
    }
}

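Finally, start the gateway itself on the default port 8080. EatuulServlet is registered through @WebServlet, so the gateway's own Spring Boot main class needs @ServletComponentScan for the embedded container to pick it up. You will then find that localhost:8080/index is forwarded to the service on port 9090 and returns hello!world.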

Conclusion

This article imitated the source code of the Zuul gateway and borrowed its essential parts. I hope you found it useful.