diff --git a/http/httpProxy.go b/http/httpProxy.go index 779ace69..dc048722 100644 --- a/http/httpProxy.go +++ b/http/httpProxy.go @@ -33,6 +33,7 @@ const ( ProxySchemaKey = "proxySchema" MaxConnectionsKey = "maxConnections" EnableRewriteKey = "enableRewrite" + EnableHttpExceptionKey = "enableHttpException" ) const ( @@ -415,9 +416,10 @@ func (m *LocationMatcher) NeedURLQueryString() bool { // We use meta element HTTP_Method as http method, HTTP_QueryString as query string // Request method as request uri // Body will transform to a http body with following rules: -// if body is a map[string]string we transform it as a form data -// if body is a string or []byte just use it -// else is unsupported +// +// if body is a map[string]string we transform it as a form data +// if body is a string or []byte just use it +// else is unsupported func MotanRequestToFasthttpRequest(motanRequest core.Request, fasthttpRequest *fasthttp.Request, defaultHTTPMethod string) error { httpMethod := motanRequest.GetAttachment(Method) if httpMethod == "" { diff --git a/provider/httpProvider.go b/provider/httpProvider.go index 7f45f8f6..f19cfe69 100644 --- a/provider/httpProvider.go +++ b/provider/httpProvider.go @@ -32,14 +32,15 @@ type HTTPProvider struct { gctx *motan.Context mixVars []string // for transparent http proxy - fastClient *fasthttp.HostClient - proxyAddr string - proxySchema string - locationMatcher *mhttp.LocationMatcher - maxConnections int - domain string - defaultHTTPMethod string - enableRewrite bool + fastClient *fasthttp.HostClient + proxyAddr string + proxySchema string + locationMatcher *mhttp.LocationMatcher + maxConnections int + domain string + defaultHTTPMethod string + enableRewrite bool + enableHttpException bool } const ( @@ -95,6 +96,13 @@ func (h *HTTPProvider) Initialize() { } else { h.enableRewrite = enableRewrite } + h.enableHttpException = false + enableHttpExceptionStr := h.url.GetParam(mhttp.EnableHttpExceptionKey, "false") + if enableHttpException, err := strconv.ParseBool(enableHttpExceptionStr); err != nil { + vlog.Errorf("%s should be a bool value, but got: %s", mhttp.EnableHttpExceptionKey, enableHttpExceptionStr) + } else { + h.enableHttpException = enableHttpException + } h.fastClient = &fasthttp.HostClient{ Name: "motan", Addr: h.proxyAddr, @@ -286,6 +294,12 @@ func (h *HTTPProvider) Call(request motan.Request) motan.Response { fillExceptionWithCode(resp, http.StatusServiceUnavailable, t, err) return resp } + if h.enableHttpException { + if httpRes.StatusCode() >= 400 { + fillHttpException(resp, httpRes.StatusCode(), t, httpRes.Body()) + return resp + } + } headerBuffer := &bytes.Buffer{} httpRes.Header.Del("Connection") httpRes.Header.WriteTo(headerBuffer) @@ -341,6 +355,12 @@ func (h *HTTPProvider) Call(request motan.Request) motan.Response { fillExceptionWithCode(resp, http.StatusServiceUnavailable, t, err) return resp } + if h.enableHttpException { + if httpRes.StatusCode() >= 400 { + fillHttpException(resp, httpRes.StatusCode(), t, httpRes.Body()) + return resp + } + } mhttp.FasthttpResponseToMotanResponse(resp, httpRes) resp.ProcessTime = (time.Now().UnixNano() - t) / 1e6 updateUpstreamStatusCode(resp, httpRes.StatusCode()) @@ -407,7 +427,6 @@ func (h *HTTPProvider) Call(request motan.Request) motan.Response { defer httpResp.Body.Close() headers := httpResp.Header statusCode := httpResp.StatusCode - body, err := ioutil.ReadAll(httpResp.Body) l := len(body) if l == 0 { @@ -420,6 +439,12 @@ func (h *HTTPProvider) Call(request motan.Request) motan.Response { ErrMsg: fmt.Sprintf("%s", err), ErrType: http.StatusServiceUnavailable} return resp } + if h.enableHttpException { + if statusCode >= 400 { + fillHttpException(resp, statusCode, t, body) + return resp + } + } request.GetAttachments().Range(func(k, v string) bool { resp.SetAttachment(k, v) return true @@ -477,6 +502,11 @@ func fillExceptionWithCode(resp *motan.MotanResponse, code int, start int64, err resp.Exception = &motan.Exception{ErrCode: code, ErrMsg: fmt.Sprintf("%s", err), ErrType: code} } +func fillHttpException(resp *motan.MotanResponse, statusCode int, start int64, body []byte) { + resp.ProcessTime = int64((time.Now().UnixNano() - start) / 1e6) + resp.Exception = &motan.Exception{ErrCode: statusCode, ErrMsg: string(body), ErrType: motan.BizException} +} + func fillException(resp *motan.MotanResponse, start int64, err error) { fillExceptionWithCode(resp, http.StatusServiceUnavailable, start, err) } diff --git a/provider/httpProvider_test.go b/provider/httpProvider_test.go index 1e10d955..cf587864 100644 --- a/provider/httpProvider_test.go +++ b/provider/httpProvider_test.go @@ -70,6 +70,45 @@ func TestHTTPProvider_Call(t *testing.T) { assert.Equal(t, "/2/p1/test?a=b", string(provider.Call(req).GetValue().([]interface{})[1].([]byte))) } +func TestHTTPProvider_Http_Exception(t *testing.T) { + context := &core.Context{} + context.Config, _ = config.NewConfigFromReader(bytes.NewReader([]byte(httpProviderTestData))) + providerURL := &core.URL{Protocol: "http", Path: "test4"} + providerURL.PutParam(mhttp.DomainKey, "test.domain") + providerURL.PutParam("requestTimeout", "2000") + providerURL.PutParam("proxyAddress", "localhost:8091") + provider := &HTTPProvider{url: providerURL, gctx: context} + provider.Initialize() + req := &core.MotanRequest{} + req.ServiceName = "test4" + req.Method = "/p1/test" + req.SetAttachment("Host", "test.domain") + req.SetAttachment(mhttp.QueryString, "a=b") + assert.Nil(t, provider.Call(req).GetException()) +} + +func TestHTTPProvider_Http_EnableException(t *testing.T) { + context := &core.Context{} + context.Config, _ = config.NewConfigFromReader(bytes.NewReader([]byte(httpProviderTestData))) + providerURL := &core.URL{Protocol: "http", Path: "test4"} + providerURL.PutParam(mhttp.DomainKey, "test.domain") + providerURL.PutParam("requestTimeout", "2000") + providerURL.PutParam("proxyAddress", "localhost:8091") + providerURL.PutParam(mhttp.EnableHttpExceptionKey, "true") + provider := &HTTPProvider{url: providerURL, gctx: context} + provider.Initialize() + req := &core.MotanRequest{} + req.ServiceName = "test4" + req.Method = "/p1/test" + req.SetAttachment("Host", "test.domain") + req.SetAttachment(mhttp.QueryString, "a=b") + exception := provider.Call(req).GetException() + assert.NotNil(t, exception) + assert.Equal(t, exception.ErrCode, 500) + assert.Equal(t, exception.ErrType, core.BizException) + assert.Equal(t, exception.ErrMsg, "request failed") +} + func TestMain(m *testing.M) { go func() { var addr = ":8090" @@ -80,6 +119,17 @@ func TestMain(m *testing.M) { }) http.ListenAndServe(addr, handler) }() + go func() { + // 返回500的server + var addr = ":8091" + handler := &http.ServeMux{} + handler.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { + request.ParseForm() + writer.WriteHeader(500) + writer.Write([]byte("request failed")) + }) + http.ListenAndServe(addr, handler) + }() time.Sleep(time.Second * 3) os.Exit(m.Run()) }