diff options
Diffstat (limited to 'libgo/go/net/http/httputil/reverseproxy_test.go')
-rw-r--r-- | libgo/go/net/http/httputil/reverseproxy_test.go | 149 |
1 files changed, 141 insertions, 8 deletions
diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 2232042d3ed..2f75b4e34ec 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -17,6 +17,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "reflect" "strconv" "strings" @@ -28,6 +29,7 @@ import ( const fakeHopHeader = "X-Fake-Hop-Header-For-Test" func init() { + inOurTests = true hopHeaders = append(hopHeaders, fakeHopHeader) } @@ -49,6 +51,9 @@ func TestReverseProxy(t *testing.T) { if c := r.Header.Get("Connection"); c != "" { t.Errorf("handler got Connection header value %q", c) } + if c := r.Header.Get("Te"); c != "trailers" { + t.Errorf("handler got Te header value %q; want 'trailers'", c) + } if c := r.Header.Get("Upgrade"); c != "" { t.Errorf("handler got Upgrade header value %q", c) } @@ -85,6 +90,7 @@ func TestReverseProxy(t *testing.T) { getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq.Host = "some-name" getReq.Header.Set("Connection", "close") + getReq.Header.Set("Te", "trailers") getReq.Header.Set("Proxy-Connection", "should be deleted") getReq.Header.Set("Upgrade", "foo") getReq.Close = true @@ -631,6 +637,93 @@ func TestReverseProxyModifyResponse(t *testing.T) { } } +type failingRoundTripper struct{} + +func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, errors.New("some error") +} + +type staticResponseRoundTripper struct{ res *http.Response } + +func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return rt.res, nil +} + +func TestReverseProxyErrorHandler(t *testing.T) { + tests := []struct { + name string + wantCode int + errorHandler func(http.ResponseWriter, *http.Request, error) + transport http.RoundTripper // defaults to failingRoundTripper + modifyResponse func(*http.Response) error + }{ + { + name: "default", + wantCode: http.StatusBadGateway, + }, + { + name: "errorhandler", + wantCode: http.StatusTeapot, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + }, + { + name: "modifyresponse_noerr", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return nil + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: 346, + }, + { + name: "modifyresponse_err", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return errors.New("some error to trigger errorHandler") + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: http.StatusTeapot, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target := &url.URL{ + Scheme: "http", + Host: "dummy.tld", + Path: "/", + } + rproxy := NewSingleHostReverseProxy(target) + rproxy.Transport = tt.transport + rproxy.ModifyResponse = tt.modifyResponse + if rproxy.Transport == nil { + rproxy.Transport = failingRoundTripper{} + } + rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + if tt.errorHandler != nil { + rproxy.ErrorHandler = tt.errorHandler + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL + "/test") + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + resp.Body.Close() + }) + } +} + // Issue 16659: log errors from short read func TestReverseProxy_CopyBuffer(t *testing.T) { backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -649,18 +742,22 @@ func TestReverseProxy_CopyBuffer(t *testing.T) { var proxyLog bytes.Buffer rproxy := NewSingleHostReverseProxy(rpURL) rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) - frontendProxy := httptest.NewServer(rproxy) + donec := make(chan bool, 1) + frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { donec <- true }() + rproxy.ServeHTTP(w, r) + })) defer frontendProxy.Close() - resp, err := http.Get(frontendProxy.URL) - if err != nil { - t.Fatalf("failed to reach proxy: %v", err) - } - defer resp.Body.Close() - - if _, err := ioutil.ReadAll(resp.Body); err == nil { + if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { t.Fatalf("want non-nil error") } + // The race detector complains about the proxyLog usage in logf in copyBuffer + // and our usage below with proxyLog.Bytes() so we're explicitly using a + // channel to ensure that the ReverseProxy's ServeHTTP is done before we + // continue after Get. + <-donec + expected := []string{ "EOF", "read", @@ -740,6 +837,8 @@ func TestServeHTTPDeepCopy(t *testing.T) { // Issue 18327: verify we always do a deep copy of the Request.Header map // before any mutations. func TestClonesRequestHeaders(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) req, _ := http.NewRequest("GET", "http://foo.tld/", nil) req.RemoteAddr = "1.2.3.4:56789" rp := &ReverseProxy{ @@ -813,3 +912,37 @@ func (cc *checkCloser) Close() error { func (cc *checkCloser) Read(b []byte) (int, error) { return len(b), nil } + +// Issue 23643: panic on body copy error +func TestReverseProxy_PanicBodyError(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.ErrUnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + rproxy := NewSingleHostReverseProxy(rpURL) + + // Ensure that the handler panics when the body read encounters an + // io.ErrUnexpectedEOF + defer func() { + err := recover() + if err == nil { + t.Fatal("handler should have panicked") + } + if err != http.ErrAbortHandler { + t.Fatal("expected ErrAbortHandler, got", err) + } + }() + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + rproxy.ServeHTTP(httptest.NewRecorder(), req) +} |