diff options
Diffstat (limited to 'libgo/go/net/http/h2_bundle.go')
-rw-r--r-- | libgo/go/net/http/h2_bundle.go | 2269 |
1 files changed, 1733 insertions, 536 deletions
diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 5826bb7d858..25fdf09d92b 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -1,5 +1,5 @@ // Code generated by golang.org/x/tools/cmd/bundle. -//go:generate bundle -o h2_bundle.go -prefix http2 golang.org/x/net/http2 +//go:generate bundle -o h2_bundle.go -prefix http2 -underscore golang.org/x/net/http2 // Package http2 implements the HTTP/2 protocol. // @@ -21,6 +21,7 @@ import ( "bytes" "compress/gzip" "context" + "crypto/rand" "crypto/tls" "encoding/binary" "errors" @@ -43,6 +44,7 @@ import ( "time" "golang_org/x/net/http2/hpack" + "golang_org/x/net/idna" "golang_org/x/net/lex/httplex" ) @@ -853,10 +855,12 @@ type http2Framer struct { // If the limit is hit, MetaHeadersFrame.Truncated is set true. MaxHeaderListSize uint32 - logReads bool + logReads, logWrites bool - debugFramer *http2Framer // only use for logging written writes - debugFramerBuf *bytes.Buffer + debugFramer *http2Framer // only use for logging written writes + debugFramerBuf *bytes.Buffer + debugReadLoggerf func(string, ...interface{}) + debugWriteLoggerf func(string, ...interface{}) } func (fr *http2Framer) maxHeaderListSize() uint32 { @@ -890,7 +894,7 @@ func (f *http2Framer) endWrite() error { byte(length>>16), byte(length>>8), byte(length)) - if http2logFrameWrites { + if f.logWrites { f.logWrite() } @@ -912,10 +916,10 @@ func (f *http2Framer) logWrite() { f.debugFramerBuf.Write(f.wbuf) fr, err := f.debugFramer.ReadFrame() if err != nil { - log.Printf("http2: Framer %p: failed to decode just-written frame", f) + f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f) return } - log.Printf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) + f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) } func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } @@ -936,9 +940,12 @@ const ( // NewFramer returns a Framer that writes frames to w and reads them from r. func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { fr := &http2Framer{ - w: w, - r: r, - logReads: http2logFrameReads, + w: w, + r: r, + logReads: http2logFrameReads, + logWrites: http2logFrameWrites, + debugReadLoggerf: log.Printf, + debugWriteLoggerf: log.Printf, } fr.getReadBuf = func(size uint32) []byte { if cap(fr.readBuf) >= int(size) { @@ -1020,7 +1027,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { return nil, err } if fr.logReads { - log.Printf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) + fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { return fr.readMetaFrame(f.(*http2HeadersFrame)) @@ -1254,7 +1261,7 @@ func (f *http2Framer) WriteSettings(settings ...http2Setting) error { return f.endWrite() } -// WriteSettings writes an empty SETTINGS frame with the ACK bit set. +// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set. // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. @@ -1920,8 +1927,8 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr hdec.SetEmitEnabled(true) hdec.SetMaxStringLength(fr.maxHeaderStringLen()) hdec.SetEmitFunc(func(hf hpack.HeaderField) { - if http2VerboseLogs && http2logFrameReads { - log.Printf("http2: decoded hpack field %+v", hf) + if http2VerboseLogs && fr.logReads { + fr.debugReadLoggerf("http2: decoded hpack field %+v", hf) } if !httplex.ValidHeaderFieldValue(hf.Value) { invalid = http2headerFieldValueError(hf.Value) @@ -2091,6 +2098,13 @@ type http2clientTrace httptrace.ClientTrace func http2reqContext(r *Request) context.Context { return r.Context() } +func (t *http2Transport) idleConnTimeout() time.Duration { + if t.t1 != nil { + return t.t1.IdleConnTimeout + } + return 0 +} + func http2setResponseUncompressed(res *Response) { res.Uncompressed = true } func http2traceGotConn(req *Request, cc *http2ClientConn) { @@ -2145,6 +2159,48 @@ func http2requestTrace(req *Request) *http2clientTrace { return (*http2clientTrace)(trace) } +// Ping sends a PING frame to the server and waits for the ack. +func (cc *http2ClientConn) Ping(ctx context.Context) error { + return cc.ping(ctx) +} + +func http2cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } + +var _ Pusher = (*http2responseWriter)(nil) + +// Push implements http.Pusher. +func (w *http2responseWriter) Push(target string, opts *PushOptions) error { + internalOpts := http2pushOptions{} + if opts != nil { + internalOpts.Method = opts.Method + internalOpts.Header = opts.Header + } + return w.push(target, internalOpts) +} + +func http2configureServer18(h1 *Server, h2 *http2Server) error { + if h2.IdleTimeout == 0 { + if h1.IdleTimeout != 0 { + h2.IdleTimeout = h1.IdleTimeout + } else { + h2.IdleTimeout = h1.ReadTimeout + } + } + return nil +} + +func http2shouldLogPanic(panicValue interface{}) bool { + return panicValue != nil && panicValue != ErrAbortHandler +} + +func http2reqGetBody(req *Request) func() (io.ReadCloser, error) { + return req.GetBody +} + +func http2reqBodyIsNoBody(body io.ReadCloser) bool { + return body == NoBody +} + var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" type http2goroutineLock uint64 @@ -2368,6 +2424,7 @@ var ( http2VerboseLogs bool http2logFrameWrites bool http2logFrameReads bool + http2inTests bool ) func init() { @@ -2409,13 +2466,23 @@ var ( type http2streamState int +// HTTP/2 stream states. +// +// See http://tools.ietf.org/html/rfc7540#section-5.1. +// +// For simplicity, the server code merges "reserved (local)" into +// "half-closed (remote)". This is one less state transition to track. +// The only downside is that we send PUSH_PROMISEs slightly less +// liberally than allowable. More discussion here: +// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html +// +// "reserved (remote)" is omitted since the client code does not +// support server push. const ( http2stateIdle http2streamState = iota http2stateOpen http2stateHalfClosedLocal http2stateHalfClosedRemote - http2stateResvLocal - http2stateResvRemote http2stateClosed ) @@ -2424,8 +2491,6 @@ var http2stateName = [...]string{ http2stateOpen: "Open", http2stateHalfClosedLocal: "HalfClosedLocal", http2stateHalfClosedRemote: "HalfClosedRemote", - http2stateResvLocal: "ResvLocal", - http2stateResvRemote: "ResvRemote", http2stateClosed: "Closed", } @@ -2586,13 +2651,27 @@ func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { return &http2bufferedWriter{w: w} } +// bufWriterPoolBufferSize is the size of bufio.Writer's +// buffers created using bufWriterPool. +// +// TODO: pick a less arbitrary value? this is a bit under +// (3 x typical 1500 byte MTU) at least. Other than that, +// not much thought went into it. +const http2bufWriterPoolBufferSize = 4 << 10 + var http2bufWriterPool = sync.Pool{ New: func() interface{} { - - return bufio.NewWriterSize(nil, 4<<10) + return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) }, } +func (w *http2bufferedWriter) Available() int { + if w.bw == nil { + return http2bufWriterPoolBufferSize + } + return w.bw.Available() +} + func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { if w.bw == nil { bw := http2bufWriterPool.Get().(*bufio.Writer) @@ -2686,6 +2765,19 @@ func (s *http2sorter) SortStrings(ss []string) { s.v = save } +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// *) a non-empty string starting with '/', but not with with "//", +// *) the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +func http2validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" +} + // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) @@ -2882,6 +2974,15 @@ type http2Server struct { // PermitProhibitedCipherSuites, if true, permits the use of // cipher suites prohibited by the HTTP/2 spec. PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration + + // NewWriteScheduler constructs a write scheduler for a connection. + // If nil, a default scheduler is chosen. + NewWriteScheduler func() http2WriteScheduler } func (s *http2Server) maxReadFrameSize() uint32 { @@ -2904,9 +3005,15 @@ func (s *http2Server) maxConcurrentStreams() uint32 { // // ConfigureServer must be called before s begins serving. func http2ConfigureServer(s *Server, conf *http2Server) error { + if s == nil { + panic("nil *http.Server") + } if conf == nil { conf = new(http2Server) } + if err := http2configureServer18(s, conf); err != nil { + return err + } if s.TLSConfig == nil { s.TLSConfig = new(tls.Config) @@ -2945,8 +3052,6 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) } - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2-14") - if s.TLSNextProto == nil { s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){} } @@ -2960,7 +3065,6 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { }) } s.TLSNextProto[http2NextProtoTLS] = protoHandler - s.TLSNextProto["h2-14"] = protoHandler return nil } @@ -3014,29 +3118,39 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { defer cancel() sc := &http2serverConn{ - srv: s, - hs: opts.baseConfig(), - conn: c, - baseCtx: baseCtx, - remoteAddrStr: c.RemoteAddr().String(), - bw: http2newBufferedWriter(c), - handler: opts.handler(), - streams: make(map[uint32]*http2stream), - readFrameCh: make(chan http2readFrameResult), - wantWriteFrameCh: make(chan http2frameWriteMsg, 8), - wroteFrameCh: make(chan http2frameWriteResult, 1), - bodyReadCh: make(chan http2bodyReadMsg), - doneServing: make(chan struct{}), - advMaxStreams: s.maxConcurrentStreams(), - writeSched: http2writeScheduler{ - maxFrameSize: http2initialMaxFrameSize, - }, + srv: s, + hs: opts.baseConfig(), + conn: c, + baseCtx: baseCtx, + remoteAddrStr: c.RemoteAddr().String(), + bw: http2newBufferedWriter(c), + handler: opts.handler(), + streams: make(map[uint32]*http2stream), + readFrameCh: make(chan http2readFrameResult), + wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), + wantStartPushCh: make(chan http2startPushRequest, 8), + wroteFrameCh: make(chan http2frameWriteResult, 1), + bodyReadCh: make(chan http2bodyReadMsg), + doneServing: make(chan struct{}), + clientMaxStreams: math.MaxUint32, + advMaxStreams: s.maxConcurrentStreams(), initialWindowSize: http2initialWindowSize, + maxFrameSize: http2initialMaxFrameSize, headerTableSize: http2initialHeaderTableSize, serveG: http2newGoroutineLock(), pushEnabled: true, } + if sc.hs.WriteTimeout != 0 { + sc.conn.SetWriteDeadline(time.Time{}) + } + + if s.NewWriteScheduler != nil { + sc.writeSched = s.NewWriteScheduler() + } else { + sc.writeSched = http2NewRandomWriteScheduler() + } + sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) @@ -3090,16 +3204,18 @@ type http2serverConn struct { handler Handler baseCtx http2contextContext framer *http2Framer - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan http2readFrameResult // written by serverConn.readFrames - wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve - wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes - bodyReadCh chan http2bodyReadMsg // from handlers -> serve - testHookCh chan func(int) // code to run on the serve loop - flow http2flow // conn-wide (not stream-specific) outbound flow control - inflow http2flow // conn-wide inbound flow control - tlsState *tls.ConnectionState // shared by all handlers, like net/http + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan http2readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve + wantStartPushCh chan http2startPushRequest // from handlers -> serve + wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan http2bodyReadMsg // from handlers -> serve + testHookCh chan func(int) // code to run on the serve loop + flow http2flow // conn-wide (not stream-specific) outbound flow control + inflow http2flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string + writeSched http2WriteScheduler // Everything following is owned by the serve loop; use serveG.check(): serveG http2goroutineLock // used to verify funcs are on serve() @@ -3109,22 +3225,27 @@ type http2serverConn struct { unackedSettings int // how many SETTINGS have we sent without ACKs? clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client - curOpenStreams uint32 // client's number of open streams - maxStreamID uint32 // max ever seen + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests + maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes streams map[uint32]*http2stream initialWindowSize int32 + maxFrameSize int32 headerTableSize uint32 peerMaxHeaderListSize uint32 // zero means unknown (default) canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh needsFrameFlush bool // last frame write wasn't a flush - writeSched http2writeScheduler - inGoAway bool // we've started to or sent GOAWAY - needToSendGoAway bool // we need to schedule a GOAWAY frame write + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write goAwayCode http2ErrCode shutdownTimerCh <-chan time.Time // nil until used shutdownTimer *time.Timer // nil until used - freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf + idleTimer *time.Timer // nil if unused + idleTimerCh <-chan time.Time // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -3143,6 +3264,11 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 { return uint32(n + typicalHeaders*perFieldOverhead) } +func (sc *http2serverConn) curOpenStreams() uint32 { + sc.serveG.check() + return sc.curClientStreams + sc.curPushedStreams +} + // stream represents a stream. This is the minimal metadata needed by // the serve goroutine. Most of the actual stream state is owned by // the http.Handler's goroutine in the responseWriter. Because the @@ -3168,11 +3294,10 @@ type http2stream struct { numTrailerValues int64 weight uint8 state http2streamState - sentReset bool // only true once detached from streams map - gotReset bool // only true once detacted from streams map - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - reqBuf []byte + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + reqBuf []byte // if non-nil, body pipe buffer to return later at EOF trailer Header // accumulated trailers reqTrailer Header // handler's Request.Trailer @@ -3195,8 +3320,14 @@ func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2strea return st.state, st } - if streamID <= sc.maxStreamID { - return http2stateClosed, nil + if streamID%2 == 1 { + if streamID <= sc.maxClientStreamID { + return http2stateClosed, nil + } + } else { + if streamID <= sc.maxPushPromiseID { + return http2stateClosed, nil + } } return http2stateIdle, nil } @@ -3328,17 +3459,17 @@ func (sc *http2serverConn) readFrames() { // frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. type http2frameWriteResult struct { - wm http2frameWriteMsg // what was written (or attempted) - err error // result of the writeFrame call + wr http2FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call } // writeFrameAsync runs in its own goroutine and writes a single frame // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *http2serverConn) writeFrameAsync(wm http2frameWriteMsg) { - err := wm.write.writeFrame(sc) - sc.wroteFrameCh <- http2frameWriteResult{wm, err} +func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { + err := wr.write.writeFrame(sc) + sc.wroteFrameCh <- http2frameWriteResult{wr, err} } func (sc *http2serverConn) closeAllStreamsOnConnClose() { @@ -3382,7 +3513,7 @@ func (sc *http2serverConn) serve() { sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) } - sc.writeFrame(http2frameWriteMsg{ + sc.writeFrame(http2FrameWriteRequest{ write: http2writeSettings{ {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, @@ -3399,6 +3530,17 @@ func (sc *http2serverConn) serve() { sc.setConnState(StateActive) sc.setConnState(StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.NewTimer(sc.srv.IdleTimeout) + defer sc.idleTimer.Stop() + sc.idleTimerCh = sc.idleTimer.C + } + + var gracefulShutdownCh <-chan struct{} + if sc.hs != nil { + gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs) + } + go sc.readFrames() settingsTimer := time.NewTimer(http2firstSettingsTimeout) @@ -3406,8 +3548,10 @@ func (sc *http2serverConn) serve() { for { loopNum++ select { - case wm := <-sc.wantWriteFrameCh: - sc.writeFrame(wm) + case wr := <-sc.wantWriteFrameCh: + sc.writeFrame(wr) + case spr := <-sc.wantStartPushCh: + sc.startPush(spr) case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: @@ -3424,12 +3568,22 @@ func (sc *http2serverConn) serve() { case <-settingsTimer.C: sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) return + case <-gracefulShutdownCh: + gracefulShutdownCh = nil + sc.startGracefulShutdown() case <-sc.shutdownTimerCh: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return + case <-sc.idleTimerCh: + sc.vlogf("connection is idle") + sc.goAway(http2ErrCodeNo) case fn := <-sc.testHookCh: fn(loopNum) } + + if sc.inGoAway && sc.curOpenStreams() == 0 && !sc.needToSendGoAway && !sc.writingFrame { + return + } } } @@ -3477,7 +3631,7 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte ch := http2errChanPool.Get().(chan error) writeArg := http2writeDataPool.Get().(*http2writeData) *writeArg = http2writeData{stream.id, data, endStream} - err := sc.writeFrameFromHandler(http2frameWriteMsg{ + err := sc.writeFrameFromHandler(http2FrameWriteRequest{ write: writeArg, stream: stream, done: ch, @@ -3507,17 +3661,17 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte return err } -// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts +// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts // if the connection has gone away. // // This must not be run from the serve goroutine itself, else it might // deadlock writing to sc.wantWriteFrameCh (which is only mildly // buffered and is read by serve itself). If you're on the serve // goroutine, call writeFrame instead. -func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { +func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { sc.serveG.checkNotOn() select { - case sc.wantWriteFrameCh <- wm: + case sc.wantWriteFrameCh <- wr: return nil case <-sc.doneServing: @@ -3533,53 +3687,81 @@ func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { // make it onto the wire // // If you're not on the serve goroutine, use writeFrameFromHandler instead. -func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) { +func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { sc.serveG.check() + // If true, wr will not be written and wr.done will not be signaled. var ignoreWrite bool - switch wm.write.(type) { + if wr.StreamID() != 0 { + _, isReset := wr.write.(http2StreamError) + if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { + ignoreWrite = true + } + } + + switch wr.write.(type) { case *http2writeResHeaders: - wm.stream.wroteHeaders = true + wr.stream.wroteHeaders = true case http2write100ContinueHeadersFrame: - if wm.stream.wroteHeaders { + if wr.stream.wroteHeaders { + + if wr.done != nil { + panic("wr.done != nil for write100ContinueHeadersFrame") + } ignoreWrite = true } } if !ignoreWrite { - sc.writeSched.add(wm) + sc.writeSched.Push(wr) } sc.scheduleFrameWrite() } -// startFrameWrite starts a goroutine to write wm (in a separate +// startFrameWrite starts a goroutine to write wr (in a separate // goroutine since that might block on the network), and updates the -// serve goroutine's state about the world, updated from info in wm. -func (sc *http2serverConn) startFrameWrite(wm http2frameWriteMsg) { +// serve goroutine's state about the world, updated from info in wr. +func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { sc.serveG.check() if sc.writingFrame { panic("internal error: can only be writing one frame at a time") } - st := wm.stream + st := wr.stream if st != nil { switch st.state { case http2stateHalfClosedLocal: - panic("internal error: attempt to send frame on half-closed-local stream") - case http2stateClosed: - if st.sentReset || st.gotReset { + switch wr.write.(type) { + case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: - sc.scheduleFrameWrite() - return + default: + panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) } - panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm)) + case http2stateClosed: + panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) + } + } + if wpp, ok := wr.write.(*http2writePushPromise); ok { + var err error + wpp.promisedID, err = wpp.allocatePromisedID() + if err != nil { + sc.writingFrameAsync = false + wr.replyToWriter(err) + return } } sc.writingFrame = true sc.needsFrameFlush = true - go sc.writeFrameAsync(wm) + if wr.write.staysWithinBuffer(sc.bw.Available()) { + sc.writingFrameAsync = false + err := wr.write.writeFrame(sc) + sc.wroteFrame(http2frameWriteResult{wr, err}) + } else { + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr) + } } // errHandlerPanicked is the error given to any callers blocked in a read from @@ -3595,26 +3777,12 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { panic("internal error: expected to be already writing a frame") } sc.writingFrame = false + sc.writingFrameAsync = false - wm := res.wm - st := wm.stream - - closeStream := http2endsStream(wm.write) - - if _, ok := wm.write.(http2handlerPanicRST); ok { - sc.closeStream(st, http2errHandlerPanicked) - } + wr := res.wr - if ch := wm.done; ch != nil { - select { - case ch <- res.err: - default: - panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write)) - } - } - wm.write = nil - - if closeStream { + if http2writeEndsStream(wr.write) { + st := wr.stream if st == nil { panic("internal error: expecting non-nil stream") } @@ -3622,13 +3790,24 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { case http2stateOpen: st.state = http2stateHalfClosedLocal - errCancel := http2streamError(st.id, http2ErrCodeCancel) - sc.resetStream(errCancel) + sc.resetStream(http2streamError(st.id, http2ErrCodeCancel)) case http2stateHalfClosedRemote: sc.closeStream(st, http2errHandlerComplete) } + } else { + switch v := wr.write.(type) { + case http2StreamError: + + if st, ok := sc.streams[v.StreamID]; ok { + sc.closeStream(st, v) + } + case http2handlerPanicRST: + sc.closeStream(wr.stream, http2errHandlerPanicked) + } } + wr.replyToWriter(res.err) + sc.scheduleFrameWrite() } @@ -3646,47 +3825,68 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { // flush the write buffer. func (sc *http2serverConn) scheduleFrameWrite() { sc.serveG.check() - if sc.writingFrame { - return - } - if sc.needToSendGoAway { - sc.needToSendGoAway = false - sc.startFrameWrite(http2frameWriteMsg{ - write: &http2writeGoAway{ - maxStreamID: sc.maxStreamID, - code: sc.goAwayCode, - }, - }) - return - } - if sc.needToSendSettingsAck { - sc.needToSendSettingsAck = false - sc.startFrameWrite(http2frameWriteMsg{write: http2writeSettingsAck{}}) + if sc.writingFrame || sc.inFrameScheduleLoop { return } - if !sc.inGoAway { - if wm, ok := sc.writeSched.take(); ok { - sc.startFrameWrite(wm) - return + sc.inFrameScheduleLoop = true + for !sc.writingFrameAsync { + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(http2FrameWriteRequest{ + write: &http2writeGoAway{ + maxStreamID: sc.maxClientStreamID, + code: sc.goAwayCode, + }, + }) + continue } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) + continue + } + if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { + if wr, ok := sc.writeSched.Pop(); ok { + sc.startFrameWrite(wr) + continue + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) + sc.needsFrameFlush = false + continue + } + break } - if sc.needsFrameFlush { - sc.startFrameWrite(http2frameWriteMsg{write: http2flushFrameWriter{}}) - sc.needsFrameFlush = false - return - } + sc.inFrameScheduleLoop = false +} + +// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the +// client we're gracefully shutting down. The connection isn't closed +// until all current streams are done. +func (sc *http2serverConn) startGracefulShutdown() { + sc.goAwayIn(http2ErrCodeNo, 0) } func (sc *http2serverConn) goAway(code http2ErrCode) { sc.serveG.check() - if sc.inGoAway { - return - } + var forceCloseIn time.Duration if code != http2ErrCodeNo { - sc.shutDownIn(250 * time.Millisecond) + forceCloseIn = 250 * time.Millisecond } else { - sc.shutDownIn(1 * time.Second) + forceCloseIn = 1 * time.Second + } + sc.goAwayIn(code, forceCloseIn) +} + +func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) { + sc.serveG.check() + if sc.inGoAway { + return + } + if forceCloseIn != 0 { + sc.shutDownIn(forceCloseIn) } sc.inGoAway = true sc.needToSendGoAway = true @@ -3702,10 +3902,9 @@ func (sc *http2serverConn) shutDownIn(d time.Duration) { func (sc *http2serverConn) resetStream(se http2StreamError) { sc.serveG.check() - sc.writeFrame(http2frameWriteMsg{write: se}) + sc.writeFrame(http2FrameWriteRequest{write: se}) if st, ok := sc.streams[se.StreamID]; ok { - st.sentReset = true - sc.closeStream(st, se) + st.resetQueued = true } } @@ -3782,6 +3981,8 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { return sc.processResetStream(f) case *http2PriorityFrame: return sc.processPriority(f) + case *http2GoAwayFrame: + return sc.processGoAway(f) case *http2PushPromiseFrame: return http2ConnectionError(http2ErrCodeProtocol) @@ -3801,7 +4002,10 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } - sc.writeFrame(http2frameWriteMsg{write: http2writePingAck{f}}) + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } + sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) return nil } @@ -3809,7 +4013,11 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error sc.serveG.check() switch { case f.StreamID != 0: - st := sc.streams[f.StreamID] + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + + return http2ConnectionError(http2ErrCodeProtocol) + } if st == nil { return nil @@ -3835,7 +4043,6 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } if st != nil { - st.gotReset = true st.cancelCtx() sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) } @@ -3848,11 +4055,21 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) } st.state = http2stateClosed - sc.curOpenStreams-- - if sc.curOpenStreams == 0 { - sc.setConnState(StateIdle) + if st.isPushed() { + sc.curPushedStreams-- + } else { + sc.curClientStreams-- } delete(sc.streams, st.id) + if len(sc.streams) == 0 { + sc.setConnState(StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } + if http2h1ServerKeepAlivesDisabled(sc.hs) { + sc.startGracefulShutdown() + } + } if p := st.body; p != nil { sc.sendWindowUpdate(nil, p.Len()) @@ -3860,11 +4077,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { p.CloseWithError(err) } st.cw.Close() - sc.writeSched.forgetStream(st.id) - if st.reqBuf != nil { - - sc.freeRequestBodyBuf = st.reqBuf - } + sc.writeSched.CloseStream(st.id) } func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { @@ -3904,7 +4117,7 @@ func (sc *http2serverConn) processSetting(s http2Setting) error { case http2SettingInitialWindowSize: return sc.processSettingInitialWindowSize(s.Val) case http2SettingMaxFrameSize: - sc.writeSched.maxFrameSize = s.Val + sc.maxFrameSize = int32(s.Val) case http2SettingMaxHeaderListSize: sc.peerMaxHeaderListSize = s.Val default: @@ -3933,11 +4146,18 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.serveG.check() + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } data := f.Data() id := f.Header().StreamID - st, ok := sc.streams[id] - if !ok || st.state != http2stateOpen || st.gotTrailerHeader { + state, st := sc.state(id) + if id == 0 || state == http2stateIdle { + + return http2ConnectionError(http2ErrCodeProtocol) + } + if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { if sc.inflow.available() < int32(f.Length) { return http2streamError(id, http2ErrCodeFlowControl) @@ -3946,6 +4166,10 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) + if st != nil && st.resetQueued { + + return nil + } return http2streamError(id, http2ErrCodeStreamClosed) } if st.body == nil { @@ -3985,6 +4209,24 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { return nil } +func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { + sc.serveG.check() + if f.ErrCode != http2ErrCodeNo { + sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } else { + sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } + sc.startGracefulShutdown() + + sc.pushEnabled = false + return nil +} + +// isPushed reports whether the stream is server-initiated. +func (st *http2stream) isPushed() bool { + return st.id%2 == 0 +} + // endStream closes a Request.Body's pipe. It is called when a DATA // frame says a request body is over (or after trailers). func (st *http2stream) endStream() { @@ -4014,7 +4256,7 @@ func (st *http2stream) copyTrailersToHandlerRequest() { func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.serveG.check() - id := f.Header().StreamID + id := f.StreamID if sc.inGoAway { return nil @@ -4024,50 +4266,43 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } - st := sc.streams[f.Header().StreamID] - if st != nil { + if st := sc.streams[f.StreamID]; st != nil { + if st.resetQueued { + + return nil + } return st.processTrailerHeaders(f) } - if id <= sc.maxStreamID { + if id <= sc.maxClientStreamID { return http2ConnectionError(http2ErrCodeProtocol) } - sc.maxStreamID = id + sc.maxClientStreamID = id - ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) - st = &http2stream{ - sc: sc, - id: id, - state: http2stateOpen, - ctx: ctx, - cancelCtx: cancelCtx, - } - if f.StreamEnded() { - st.state = http2stateHalfClosedRemote + if sc.idleTimer != nil { + sc.idleTimer.Stop() } - st.cw.Init() - st.flow.conn = &sc.flow - st.flow.add(sc.initialWindowSize) - st.inflow.conn = &sc.inflow - st.inflow.add(http2initialWindowSize) + if sc.curClientStreams+1 > sc.advMaxStreams { + if sc.unackedSettings == 0 { - sc.streams[id] = st - if f.HasPriority() { - http2adjustStreamPriority(sc.streams, st.id, f.Priority) - } - sc.curOpenStreams++ - if sc.curOpenStreams == 1 { - sc.setConnState(StateActive) + return http2streamError(id, http2ErrCodeProtocol) + } + + return http2streamError(id, http2ErrCodeRefusedStream) } - if sc.curOpenStreams > sc.advMaxStreams { - if sc.unackedSettings == 0 { + initialState := http2stateOpen + if f.StreamEnded() { + initialState = http2stateHalfClosedRemote + } + st := sc.newStream(id, 0, initialState) - return http2streamError(st.id, http2ErrCodeProtocol) + if f.HasPriority() { + if err := http2checkPriority(f.StreamID, f.Priority); err != nil { + return err } - - return http2streamError(st.id, http2ErrCodeRefusedStream) + sc.writeSched.AdjustStream(st.id, f.Priority) } rw, req, err := sc.newWriterAndRequest(st, f) @@ -4085,10 +4320,14 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if f.Truncated { handler = http2handleHeaderListTooLong - } else if err := http2checkValidHTTP2Request(req); err != nil { + } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { handler = http2new400Handler(err) } + if sc.hs.ReadTimeout != 0 { + sc.conn.SetReadDeadline(time.Time{}) + } + go sc.runHandler(rw, req, handler) return nil } @@ -4121,90 +4360,138 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { return nil } -func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { - http2adjustStreamPriority(sc.streams, f.StreamID, f.http2PriorityParam) +func http2checkPriority(streamID uint32, p http2PriorityParam) error { + if streamID == p.StreamDep { + + return http2streamError(streamID, http2ErrCodeProtocol) + } return nil } -func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, priority http2PriorityParam) { - st, ok := streams[streamID] - if !ok { +func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { + if sc.inGoAway { + return nil + } + if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + return err + } + sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) + return nil +} - return +func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { + sc.serveG.check() + if id == 0 { + panic("internal error: cannot create stream with id 0") } - st.weight = priority.Weight - parent := streams[priority.StreamDep] - if parent == st { - return + ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) + st := &http2stream{ + sc: sc, + id: id, + state: state, + ctx: ctx, + cancelCtx: cancelCtx, } + st.cw.Init() + st.flow.conn = &sc.flow + st.flow.add(sc.initialWindowSize) + st.inflow.conn = &sc.inflow + st.inflow.add(http2initialWindowSize) - for piter := parent; piter != nil; piter = piter.parent { - if piter == st { - parent.parent = st.parent - break - } + sc.streams[id] = st + sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) + if st.isPushed() { + sc.curPushedStreams++ + } else { + sc.curClientStreams++ } - st.parent = parent - if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) { - for _, openStream := range streams { - if openStream != st && openStream.parent == st.parent { - openStream.parent = st - } - } + if sc.curOpenStreams() == 1 { + sc.setConnState(StateActive) } + + return st } func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) { sc.serveG.check() - method := f.PseudoValue("method") - path := f.PseudoValue("path") - scheme := f.PseudoValue("scheme") - authority := f.PseudoValue("authority") + rp := http2requestParam{ + method: f.PseudoValue("method"), + scheme: f.PseudoValue("scheme"), + authority: f.PseudoValue("authority"), + path: f.PseudoValue("path"), + } - isConnect := method == "CONNECT" + isConnect := rp.method == "CONNECT" if isConnect { - if path != "" || scheme != "" || authority == "" { + if rp.path != "" || rp.scheme != "" || rp.authority == "" { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } - } else if method == "" || path == "" || - (scheme != "https" && scheme != "http") { + } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } bodyOpen := !f.StreamEnded() - if method == "HEAD" && bodyOpen { + if rp.method == "HEAD" && bodyOpen { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } - var tlsState *tls.ConnectionState // nil if not scheme https - if scheme == "https" { - tlsState = sc.tlsState + rp.header = make(Header) + for _, hf := range f.RegularFields() { + rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if rp.authority == "" { + rp.authority = rp.header.Get("Host") } - header := make(Header) - for _, hf := range f.RegularFields() { - header.Add(sc.canonicalHeader(hf.Name), hf.Value) + rw, req, err := sc.newWriterAndRequestNoBody(st, rp) + if err != nil { + return nil, nil, err } + if bodyOpen { + st.reqBuf = http2getRequestBodyBuf() + req.Body.(*http2requestBody).pipe = &http2pipe{ + b: &http2fixedBuffer{buf: st.reqBuf}, + } - if authority == "" { - authority = header.Get("Host") + if vv, ok := rp.header["Content-Length"]; ok { + req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) + } else { + req.ContentLength = -1 + } } - needsContinue := header.Get("Expect") == "100-continue" + return rw, req, nil +} + +type http2requestParam struct { + method string + scheme, authority, path string + header Header +} + +func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *Request, error) { + sc.serveG.check() + + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState + } + + needsContinue := rp.header.Get("Expect") == "100-continue" if needsContinue { - header.Del("Expect") + rp.header.Del("Expect") } - if cookies := header["Cookie"]; len(cookies) > 1 { - header.Set("Cookie", strings.Join(cookies, "; ")) + if cookies := rp.header["Cookie"]; len(cookies) > 1 { + rp.header.Set("Cookie", strings.Join(cookies, "; ")) } // Setup Trailers var trailer Header - for _, v := range header["Trailer"] { + for _, v := range rp.header["Trailer"] { for _, key := range strings.Split(v, ",") { key = CanonicalHeaderKey(strings.TrimSpace(key)) switch key { @@ -4218,55 +4505,42 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead } } } - delete(header, "Trailer") + delete(rp.header, "Trailer") - body := &http2requestBody{ - conn: sc, - stream: st, - needsContinue: needsContinue, - } var url_ *url.URL var requestURI string - if isConnect { - url_ = &url.URL{Host: authority} - requestURI = authority + if rp.method == "CONNECT" { + url_ = &url.URL{Host: rp.authority} + requestURI = rp.authority } else { var err error - url_, err = url.ParseRequestURI(path) + url_, err = url.ParseRequestURI(rp.path) if err != nil { - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, http2streamError(st.id, http2ErrCodeProtocol) } - requestURI = path + requestURI = rp.path + } + + body := &http2requestBody{ + conn: sc, + stream: st, + needsContinue: needsContinue, } req := &Request{ - Method: method, + Method: rp.method, URL: url_, RemoteAddr: sc.remoteAddrStr, - Header: header, + Header: rp.header, RequestURI: requestURI, Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, TLS: tlsState, - Host: authority, + Host: rp.authority, Body: body, Trailer: trailer, } req = http2requestWithContext(req, st.ctx) - if bodyOpen { - - buf := make([]byte, http2initialWindowSize) - - body.pipe = &http2pipe{ - b: &http2fixedBuffer{buf: buf}, - } - - if vv, ok := header["Content-Length"]; ok { - req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) - } else { - req.ContentLength = -1 - } - } rws := http2responseWriterStatePool.Get().(*http2responseWriterState) bwSave := rws.bw @@ -4282,13 +4556,22 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead return rw, req, nil } -func (sc *http2serverConn) getRequestBodyBuf() []byte { - sc.serveG.check() - if buf := sc.freeRequestBodyBuf; buf != nil { - sc.freeRequestBodyBuf = nil - return buf +var http2reqBodyCache = make(chan []byte, 8) + +func http2getRequestBodyBuf() []byte { + select { + case b := <-http2reqBodyCache: + return b + default: + return make([]byte, http2initialWindowSize) + } +} + +func http2putRequestBodyBuf(b []byte) { + select { + case http2reqBodyCache <- b: + default: } - return make([]byte, http2initialWindowSize) } // Run on its own goroutine. @@ -4298,15 +4581,17 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, han rw.rws.stream.cancelCtx() if didPanic { e := recover() - // Same as net/http: - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - sc.writeFrameFromHandler(http2frameWriteMsg{ + sc.writeFrameFromHandler(http2FrameWriteRequest{ write: http2handlerPanicRST{rw.rws.stream.id}, stream: rw.rws.stream, }) - sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + + if http2shouldLogPanic(e) { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + } return } rw.handlerDone() @@ -4334,7 +4619,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR errc = http2errChanPool.Get().(chan error) } - if err := sc.writeFrameFromHandler(http2frameWriteMsg{ + if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ write: headerData, stream: st, done: errc, @@ -4357,7 +4642,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR // called from handler goroutines. func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { - sc.writeFrameFromHandler(http2frameWriteMsg{ + sc.writeFrameFromHandler(http2FrameWriteRequest{ write: http2write100ContinueHeadersFrame{st.id}, stream: st, }) @@ -4373,11 +4658,19 @@ type http2bodyReadMsg struct { // called from handler goroutines. // Notes that the handler for the given stream ID read n bytes of its body // and schedules flow control tokens to be sent. -func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int) { +func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { sc.serveG.checkNotOn() - select { - case sc.bodyReadCh <- http2bodyReadMsg{st, n}: - case <-sc.doneServing: + if n > 0 { + select { + case sc.bodyReadCh <- http2bodyReadMsg{st, n}: + case <-sc.doneServing: + } + } + if err == io.EOF { + if buf := st.reqBuf; buf != nil { + st.reqBuf = nil + http2putRequestBodyBuf(buf) + } } } @@ -4419,7 +4712,7 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { if st != nil { streamID = st.id } - sc.writeFrame(http2frameWriteMsg{ + sc.writeFrame(http2FrameWriteRequest{ write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, stream: st, }) @@ -4434,16 +4727,19 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { } } +// requestBody is the Handler's Request.Body type. +// Read and Close may be called concurrently. type http2requestBody struct { stream *http2stream conn *http2serverConn - closed bool + closed bool // for use by Close only + sawEOF bool // for use by Read only pipe *http2pipe // non-nil if we have a HTTP entity message body needsContinue bool // need to send a 100-continue } func (b *http2requestBody) Close() error { - if b.pipe != nil { + if b.pipe != nil && !b.closed { b.pipe.BreakWithError(http2errClosedBody) } b.closed = true @@ -4455,13 +4751,17 @@ func (b *http2requestBody) Read(p []byte) (n int, err error) { b.needsContinue = false b.conn.write100ContinueHeaders(b.stream) } - if b.pipe == nil { + if b.pipe == nil || b.sawEOF { return 0, io.EOF } n, err = b.pipe.Read(p) - if n > 0 { - b.conn.noteBodyReadFromHandler(b.stream, n) + if err == io.EOF { + b.sawEOF = true } + if b.conn == nil && http2inTests { + return + } + b.conn.noteBodyReadFromHandler(b.stream, n, err) return } @@ -4696,8 +4996,9 @@ func (w *http2responseWriter) CloseNotify() <-chan bool { if ch == nil { ch = make(chan bool, 1) rws.closeNotifierCh = ch + cw := rws.stream.cw go func() { - rws.stream.cw.Wait() + cw.Wait() ch <- true }() } @@ -4793,6 +5094,172 @@ func (w *http2responseWriter) handlerDone() { http2responseWriterStatePool.Put(rws) } +// Push errors. +var ( + http2ErrRecursivePush = errors.New("http2: recursive push not allowed") + http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) + +// pushOptions is the internal version of http.PushOptions, which we +// cannot include here because it's only defined in Go 1.8 and later. +type http2pushOptions struct { + Method string + Header Header +} + +func (w *http2responseWriter) push(target string, opts http2pushOptions) error { + st := w.rws.stream + sc := st.sc + sc.serveG.checkNotOn() + + if st.isPushed() { + return http2ErrRecursivePush + } + + if opts.Method == "" { + opts.Method = "GET" + } + if opts.Header == nil { + opts.Header = Header{} + } + wantScheme := "http" + if w.rws.req.TLS != nil { + wantScheme = "https" + } + + u, err := url.Parse(target) + if err != nil { + return err + } + if u.Scheme == "" { + if !strings.HasPrefix(target, "/") { + return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) + } + u.Scheme = wantScheme + u.Host = w.rws.req.Host + } else { + if u.Scheme != wantScheme { + return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + } + if u.Host == "" { + return errors.New("URL must have a host") + } + } + for k := range opts.Header { + if strings.HasPrefix(k, ":") { + return fmt.Errorf("promised request headers cannot include pseudo header %q", k) + } + + switch strings.ToLower(k) { + case "content-length", "content-encoding", "trailer", "te", "expect", "host": + return fmt.Errorf("promised request headers cannot include %q", k) + } + } + if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { + return err + } + + if opts.Method != "GET" && opts.Method != "HEAD" { + return fmt.Errorf("method %q must be GET or HEAD", opts.Method) + } + + msg := http2startPushRequest{ + parent: st, + method: opts.Method, + url: u, + header: http2cloneHeader(opts.Header), + done: http2errChanPool.Get().(chan error), + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case sc.wantStartPushCh <- msg: + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case err := <-msg.done: + http2errChanPool.Put(msg.done) + return err + } +} + +type http2startPushRequest struct { + parent *http2stream + method string + url *url.URL + header Header + done chan error +} + +func (sc *http2serverConn) startPush(msg http2startPushRequest) { + sc.serveG.check() + + if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { + + msg.done <- http2errStreamClosed + return + } + + if !sc.pushEnabled { + msg.done <- ErrNotSupported + return + } + + allocatePromisedID := func() (uint32, error) { + sc.serveG.check() + + if !sc.pushEnabled { + return 0, ErrNotSupported + } + + if sc.curPushedStreams+1 > sc.clientMaxStreams { + return 0, http2ErrPushLimitReached + } + + if sc.maxPushPromiseID+2 >= 1<<31 { + sc.startGracefulShutdown() + return 0, http2ErrPushLimitReached + } + sc.maxPushPromiseID += 2 + promisedID := sc.maxPushPromiseID + + promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ + method: msg.method, + scheme: msg.url.Scheme, + authority: msg.url.Host, + path: msg.url.RequestURI(), + header: http2cloneHeader(msg.header), + }) + if err != nil { + + panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) + return promisedID, nil + } + + sc.writeFrame(http2FrameWriteRequest{ + write: &http2writePushPromise{ + streamID: msg.parent.id, + method: msg.method, + url: msg.url, + h: msg.header, + allocatePromisedID: allocatePromisedID, + }, + stream: msg.parent, + done: msg.done, + }) +} + // foreachHeaderElement splits v according to the "#rule" construction // in RFC 2616 section 2.1 and calls fn for each non-empty element. func http2foreachHeaderElement(v string, fn func(string)) { @@ -4820,16 +5287,16 @@ var http2connHeaders = []string{ "Upgrade", } -// checkValidHTTP2Request checks whether req is a valid HTTP/2 request, +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, // per RFC 7540 Section 8.1.2.2. // The returned error is reported to users. -func http2checkValidHTTP2Request(req *Request) error { - for _, h := range http2connHeaders { - if _, ok := req.Header[h]; ok { - return fmt.Errorf("request header %q is not valid in HTTP/2", h) +func http2checkValidHTTP2RequestHeaders(h Header) error { + for _, k := range http2connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) } } - te := req.Header["Te"] + te := h["Te"] if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) } @@ -4877,6 +5344,45 @@ var http2badTrailer = map[string]bool{ "Www-Authenticate": true, } +// h1ServerShutdownChan returns a channel that will be closed when the +// provided *http.Server wants to shut down. +// +// This is a somewhat hacky way to get at http1 innards. It works +// when the http2 code is bundled into the net/http package in the +// standard library. The alternatives ended up making the cmd/go tool +// depend on http Servers. This is the lightest option for now. +// This is tested via the TestServeShutdown* tests in net/http. +func http2h1ServerShutdownChan(hs *Server) <-chan struct{} { + if fn := http2testh1ServerShutdownChan; fn != nil { + return fn(hs) + } + var x interface{} = hs + type I interface { + getDoneChan() <-chan struct{} + } + if hs, ok := x.(I); ok { + return hs.getDoneChan() + } + return nil +} + +// optional test hook for h1ServerShutdownChan. +var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{} + +// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives +// disabled. See comments on h1ServerShutdownChan above for why +// the code is written this way. +func http2h1ServerKeepAlivesDisabled(hs *Server) bool { + var x interface{} = hs + type I interface { + doKeepAlives() bool + } + if hs, ok := x.(I); ok { + return !hs.doKeepAlives() + } + return false +} + const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -4997,6 +5503,9 @@ type http2ClientConn struct { readerDone chan struct{} // closed on error readerErr error // set before readerDone is closed + idleTimeout time.Duration // or 0 for never + idleTimer *time.Timer + mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes flow http2flow // our conn-level flow control quota (cs.flow is per stream) @@ -5007,6 +5516,7 @@ type http2ClientConn struct { goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*http2clientStream // client-initiated nextStreamID uint32 + pings map[[8]byte]chan struct{} // in flight ping data to notification channel bw *bufio.Writer br *bufio.Reader fr *http2Framer @@ -5033,6 +5543,7 @@ type http2clientStream struct { ID uint32 resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload + startedWrite bool // started request body write; guarded by cc.mu requestedGzip bool on100 func() // optional code to run if get a 100 continue response @@ -5041,6 +5552,7 @@ type http2clientStream struct { bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu + didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu peerReset chan struct{} // closed on peer reset resetErr error // populated before peerReset is closed @@ -5068,15 +5580,26 @@ func (cs *http2clientStream) awaitRequestCancel(req *Request) { } select { case <-req.Cancel: + cs.cancelStream() cs.bufPipe.CloseWithError(http2errRequestCanceled) - cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) case <-ctx.Done(): + cs.cancelStream() cs.bufPipe.CloseWithError(ctx.Err()) - cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) case <-cs.done: } } +func (cs *http2clientStream) cancelStream() { + cs.cc.mu.Lock() + didReset := cs.didReset + cs.didReset = true + cs.cc.mu.Unlock() + + if !didReset { + cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } +} + // checkResetOrDone reports any error sent in a RST_STREAM frame by the // server, or errStreamClosed if the stream is complete. func (cs *http2clientStream) checkResetOrDone() error { @@ -5133,14 +5656,22 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. func http2authorityAddr(scheme string, authority string) (addr string) { - if _, _, err := net.SplitHostPort(authority); err == nil { - return authority + host, port, err := net.SplitHostPort(authority) + if err != nil { + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a } - port := "443" - if scheme == "http" { - port = "80" + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port } - return net.JoinHostPort(authority, port) + return net.JoinHostPort(host, port) } // RoundTripOpt is like RoundTrip, but takes options. @@ -5158,8 +5689,10 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res } http2traceGotConn(req, cc) res, err := cc.RoundTrip(req) - if http2shouldRetryRequest(req, err) { - continue + if err != nil { + if req, err = http2shouldRetryRequest(req, err); err == nil { + continue + } } if err != nil { t.vlogf("RoundTrip failure: %v", err) @@ -5181,11 +5714,39 @@ func (t *http2Transport) CloseIdleConnections() { var ( http2errClientConnClosed = errors.New("http2: client conn is closed") http2errClientConnUnusable = errors.New("http2: client conn not usable") + + http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + http2errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written") ) -func http2shouldRetryRequest(req *Request, err error) bool { +// shouldRetryRequest is called by RoundTrip when a request fails to get +// response headers. It is always called with a non-nil error. +// It returns either a request to retry (either the same request, or a +// modified clone), or an error if the request can't be replayed. +func http2shouldRetryRequest(req *Request, err error) (*Request, error) { + switch err { + default: + return nil, err + case http2errClientConnUnusable, http2errClientConnGotGoAway: + return req, nil + case http2errClientConnGotGoAwayAfterSomeReqBody: + + if req.Body == nil || http2reqBodyIsNoBody(req.Body) { + return req, nil + } - return err == http2errClientConnUnusable + getBody := http2reqGetBody(req) + if getBody == nil { + return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error") + } + body, err := getBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = body + return &newReq, nil + } } func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2ClientConn, error) { @@ -5203,7 +5764,7 @@ func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2Clie func (t *http2Transport) newTLSConfig(host string) *tls.Config { cfg := new(tls.Config) if t.TLSClientConfig != nil { - *cfg = *t.TLSClientConfig + *cfg = *http2cloneTLSConfig(t.TLSClientConfig) } if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) @@ -5273,6 +5834,11 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client streams: make(map[uint32]*http2clientStream), singleUse: singleUse, wantSettingsAck: true, + pings: make(map[[8]byte]chan struct{}), + } + if d := t.idleConnTimeout(); d != 0 { + cc.idleTimeout = d + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } if http2VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -5328,6 +5894,15 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { if old != nil && old.ErrCode != http2ErrCodeNo { cc.goAway.ErrCode = old.ErrCode } + last := f.LastStreamID + for streamID, cs := range cc.streams { + if streamID > last { + select { + case cs.resc <- http2resAndError{err: http2errClientConnGotGoAway}: + default: + } + } + } } func (cc *http2ClientConn) CanTakeNewRequest() bool { @@ -5345,6 +5920,16 @@ func (cc *http2ClientConn) canTakeNewRequestLocked() bool { cc.nextStreamID < math.MaxInt32 } +// onIdleTimeout is called from a time.AfterFunc goroutine. It will +// only be called when we're idle, but because we're coming from a new +// goroutine, there could be a new request coming in at the same time, +// so this simply calls the synchronized closeIfIdle to shut down this +// connection. The timer could just call closeIfIdle, but this is more +// clear. +func (cc *http2ClientConn) onIdleTimeout() { + cc.closeIfIdle() +} + func (cc *http2ClientConn) closeIfIdle() { cc.mu.Lock() if len(cc.streams) > 0 { @@ -5437,48 +6022,37 @@ func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { // Certain headers are special-cased as okay but not transmitted later. func http2checkConnHeaders(req *Request) error { if v := req.Header.Get("Upgrade"); v != "" { - return errors.New("http2: invalid Upgrade request header") + return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) } - if v := req.Header.Get("Transfer-Encoding"); (v != "" && v != "chunked") || len(req.Header["Transfer-Encoding"]) > 1 { - return errors.New("http2: invalid Transfer-Encoding request header") + if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) } - if v := req.Header.Get("Connection"); (v != "" && v != "close" && v != "keep-alive") || len(req.Header["Connection"]) > 1 { - return errors.New("http2: invalid Connection request header") + if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "close" && vv[0] != "keep-alive") { + return fmt.Errorf("http2: invalid Connection request header: %q", vv) } return nil } -func http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) { - body = req.Body - if body == nil { - return nil, 0 +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func http2actualContentLength(req *Request) int64 { + if req.Body == nil { + return 0 } if req.ContentLength != 0 { - return req.Body, req.ContentLength - } - - // We have a body but a zero content length. Test to see if - // it's actually zero or just unset. - var buf [1]byte - n, rerr := body.Read(buf[:]) - if rerr != nil && rerr != io.EOF { - return http2errorReader{rerr}, -1 - } - if n == 1 { - - if rerr == io.EOF { - return bytes.NewReader(buf[:]), 1 - } - return io.MultiReader(bytes.NewReader(buf[:]), body), -1 + return req.ContentLength } - - return nil, 0 + return -1 } func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { if err := http2checkConnHeaders(req); err != nil { return nil, err } + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } trailers, err := http2commaSeparatedTrailers(req) if err != nil { @@ -5486,9 +6060,6 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } hasTrailers := trailers != "" - body, contentLen := http2bodyAndLength(req) - hasBody := body != nil - cc.mu.Lock() cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { @@ -5496,6 +6067,10 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { return nil, http2errClientConnUnusable } + body := req.Body + hasBody := body != nil + contentLen := http2actualContentLength(req) + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? var requestedGzip bool if !cc.t.disableCompression() && @@ -5561,6 +6136,13 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cs.abortRequestBodyWrite(http2errStopReqBodyWrite) } if re.err != nil { + if re.err == http2errClientConnGotGoAway { + cc.mu.Lock() + if cs.startedWrite { + re.err = http2errClientConnGotGoAwayAfterSomeReqBody + } + cc.mu.Unlock() + } cc.forgetStreamID(cs.ID) return nil, re.err } @@ -5806,6 +6388,26 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail if host == "" { host = req.URL.Host } + host, err := httplex.PunycodeHostPort(host) + if err != nil { + return nil, err + } + + var path string + if req.Method != "CONNECT" { + path = req.URL.RequestURI() + if !http2validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !http2validPseudoPath(path) { + if req.URL.Opaque != "" { + return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return nil, fmt.Errorf("invalid request :path %q", orig) + } + } + } + } for k, vv := range req.Header { if !httplex.ValidHeaderFieldName(k) { @@ -5821,8 +6423,8 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail cc.writeHeader(":authority", host) cc.writeHeader(":method", req.Method) if req.Method != "CONNECT" { - cc.writeHeader(":path", req.URL.RequestURI()) - cc.writeHeader(":scheme", "https") + cc.writeHeader(":path", path) + cc.writeHeader(":scheme", req.URL.Scheme) } if trailers != "" { cc.writeHeader("trailer", trailers) @@ -5940,6 +6542,9 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr if andRemove && cs != nil && !cc.closed { cc.lastActive = time.Now() delete(cc.streams, id) + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + } close(cs.done) cc.cond.Broadcast() } @@ -5996,6 +6601,10 @@ func (rl *http2clientConnReadLoop) cleanup() { defer cc.t.connPool().MarkDead(cc) defer close(cc.readerDone) + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + err := cc.readerErr cc.mu.Lock() if cc.goAway != nil && http2isEOFOrNetReadError(err) { @@ -6398,9 +7007,10 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc.bw.Flush() cc.wmu.Unlock() } + didReset := cs.didReset cc.mu.Unlock() - if len(data) > 0 { + if len(data) > 0 && !didReset { if _, err := cs.bufPipe.Write(data); err != nil { rl.endStreamError(cs, err) return err @@ -6551,9 +7161,56 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er return nil } +// Ping sends a PING frame to the server and waits for the ack. +// Public implementation is in go17.go and not_go17.go +func (cc *http2ClientConn) ping(ctx http2contextContext) error { + c := make(chan struct{}) + // Generate a random payload + var p [8]byte + for { + if _, err := rand.Read(p[:]); err != nil { + return err + } + cc.mu.Lock() + + if _, found := cc.pings[p]; !found { + cc.pings[p] = c + cc.mu.Unlock() + break + } + cc.mu.Unlock() + } + cc.wmu.Lock() + if err := cc.fr.WritePing(false, p); err != nil { + cc.wmu.Unlock() + return err + } + if err := cc.bw.Flush(); err != nil { + cc.wmu.Unlock() + return err + } + cc.wmu.Unlock() + select { + case <-c: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cc.readerDone: + + return cc.readerErr + } +} + func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { if f.IsAck() { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + if c, ok := cc.pings[f.Data]; ok { + close(c) + delete(cc.pings, f.Data) + } return nil } cc := rl.cc @@ -6666,6 +7323,9 @@ func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reade resc := make(chan error, 1) s.resc = resc s.fn = func() { + cs.cc.mu.Lock() + cs.startedWrite = true + cs.cc.mu.Unlock() resc <- cs.writeRequestBody(body, cs.req.Body) } s.delay = t.expectContinueTimeout() @@ -6728,6 +7388,11 @@ func http2isConnectionCloseRequest(req *Request) bool { // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error + + // staysWithinBuffer reports whether this writer promises that + // it will only write less than or equal to size bytes, and it + // won't Flush the write context. + staysWithinBuffer(size int) bool } // writeContext is the interface needed by the various frame writer @@ -6749,9 +7414,10 @@ type http2writeContext interface { HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) } -// endsStream reports whether the given frame writer w will locally -// close the stream. -func http2endsStream(w http2writeFramer) bool { +// writeEndsStream reports whether w writes a frame that will transition +// the stream to a half-closed local state. This returns false for RST_STREAM, +// which closes the entire stream (not just the local half). +func http2writeEndsStream(w http2writeFramer) bool { switch v := w.(type) { case *http2writeData: return v.endStream @@ -6759,7 +7425,7 @@ func http2endsStream(w http2writeFramer) bool { return v.endStream case nil: - panic("endsStream called on nil writeFramer") + panic("writeEndsStream called on nil writeFramer") } return false } @@ -6770,8 +7436,16 @@ func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { return ctx.Flush() } +func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } + type http2writeSettings []http2Setting +func (s http2writeSettings) staysWithinBuffer(max int) bool { + const settingSize = 6 // uint16 + uint32 + return http2frameHeaderLen+settingSize*len(s) <= max + +} + func (s http2writeSettings) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteSettings([]http2Setting(s)...) } @@ -6791,6 +7465,8 @@ func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { return err } +func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } + type http2writeData struct { streamID uint32 p []byte @@ -6805,6 +7481,10 @@ func (w *http2writeData) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) } +func (w *http2writeData) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.p) <= max +} + // handlerPanicRST is the message sent from handler goroutines when // the handler panics. type http2handlerPanicRST struct { @@ -6815,22 +7495,59 @@ func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) } +func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + func (se http2StreamError) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) } +func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + type http2writePingAck struct{ pf *http2PingFrame } func (w http2writePingAck) writeFrame(ctx http2writeContext) error { return ctx.Framer().WritePing(true, w.pf.Data) } +func (w http2writePingAck) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.pf.Data) <= max +} + type http2writeSettingsAck struct{} func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteSettingsAck() } +func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } + +// splitHeaderBlock splits headerBlock into fragments so that each fragment fits +// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true +// for the first/last fragment, respectively. +func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { + return err + } + first = false + } + return nil +} + // writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames // for HTTP response headers or trailers from a server handler. type http2writeResHeaders struct { @@ -6852,6 +7569,11 @@ func http2encKV(enc *hpack.Encoder, k, v string) { enc.WriteField(hpack.HeaderField{Name: k, Value: v}) } +func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { + + return false +} + func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { enc, buf := ctx.HeaderEncoder() buf.Reset() @@ -6877,39 +7599,69 @@ func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { panic("unexpected empty hpack") } - // For now we're lazy and just pick the minimum MAX_FRAME_SIZE - // that all peers must support (16KB). Later we could care - // more and send larger frames if the peer advertised it, but - // there's little point. Most headers are small anyway (so we - // generally won't have CONTINUATION frames), and extra frames - // only waste 9 bytes anyway. - const maxFrameSize = 16384 + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} - first := true - for len(headerBlock) > 0 { - frag := headerBlock - if len(frag) > maxFrameSize { - frag = frag[:maxFrameSize] - } - headerBlock = headerBlock[len(frag):] - endHeaders := len(headerBlock) == 0 - var err error - if first { - first = false - err = ctx.Framer().WriteHeaders(http2HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: frag, - EndStream: w.endStream, - EndHeaders: endHeaders, - }) - } else { - err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag) - } - if err != nil { - return err - } +func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. +type http2writePushPromise struct { + streamID uint32 // pusher stream + method string // for :method + url *url.URL // for :scheme, :authority, :path + h Header + + // Creates an ID for a pushed stream. This runs on serveG just before + // the frame is written. The returned ID is copied to promisedID. + allocatePromisedID func() (uint32, error) + promisedID uint32 +} + +func (w *http2writePushPromise) staysWithinBuffer(max int) bool { + + return false +} + +func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + http2encKV(enc, ":method", w.method) + http2encKV(enc, ":scheme", w.url.Scheme) + http2encKV(enc, ":authority", w.url.Host) + http2encKV(enc, ":path", w.url.RequestURI()) + http2encodeHeaders(enc, w.h, nil) + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WritePushPromise(http2PushPromiseParam{ + StreamID: w.streamID, + PromiseID: w.promisedID, + BlockFragment: frag, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } - return nil } type http2write100ContinueHeadersFrame struct { @@ -6928,15 +7680,24 @@ func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) err }) } +func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { + + return 9+2*(len(":status")+len("100")) <= max +} + type http2writeWindowUpdate struct { streamID uint32 // or 0 for conn-level n uint32 } +func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) } +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only only if k is in keys. func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { if keys == nil { sorter := http2sorterPool.Get().(*http2sorter) @@ -6966,14 +7727,53 @@ func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { } } -// frameWriteMsg is a request to write a frame. -type http2frameWriteMsg struct { +// WriteScheduler is the interface implemented by HTTP/2 write schedulers. +// Methods are never called concurrently. +type http2WriteScheduler interface { + // OpenStream opens a new stream in the write scheduler. + // It is illegal to call this with streamID=0 or with a streamID that is + // already open -- the call may panic. + OpenStream(streamID uint32, options http2OpenStreamOptions) + + // CloseStream closes a stream in the write scheduler. Any frames queued on + // this stream should be discarded. It is illegal to call this on a stream + // that is not open -- the call may panic. + CloseStream(streamID uint32) + + // AdjustStream adjusts the priority of the given stream. This may be called + // on a stream that has not yet been opened or has been closed. Note that + // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: + // https://tools.ietf.org/html/rfc7540#section-5.1 + AdjustStream(streamID uint32, priority http2PriorityParam) + + // Push queues a frame in the scheduler. In most cases, this will not be + // called with wr.StreamID()!=0 unless that stream is currently open. The one + // exception is RST_STREAM frames, which may be sent on idle or closed streams. + Push(wr http2FrameWriteRequest) + + // Pop dequeues the next frame to write. Returns false if no frames can + // be written. Frames with a given wr.StreamID() are Pop'd in the same + // order they are Push'd. + Pop() (wr http2FrameWriteRequest, ok bool) +} + +// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. +type http2OpenStreamOptions struct { + // PusherID is zero if the stream was initiated by the client. Otherwise, + // PusherID names the stream that pushed the newly opened stream. + PusherID uint32 +} + +// FrameWriteRequest is a request to write a frame. +type http2FrameWriteRequest struct { // write is the interface value that does the writing, once the - // writeScheduler (below) has decided to select this frame - // to write. The write functions are all defined in write.go. + // WriteScheduler has selected this frame to write. The write + // functions are all defined in write.go. write http2writeFramer - stream *http2stream // used for prioritization. nil for non-stream frames. + // stream is the stream on which this frame will be written. + // nil for non-stream frames like PING and SETTINGS. + stream *http2stream // done, if non-nil, must be a buffered channel with space for // 1 message and is sent the return value from write (or an @@ -6981,247 +7781,644 @@ type http2frameWriteMsg struct { done chan error } -// for debugging only: -func (wm http2frameWriteMsg) String() string { - var streamID uint32 - if wm.stream != nil { - streamID = wm.stream.id +// StreamID returns the id of the stream this frame will be written to. +// 0 is used for non-stream frames such as PING and SETTINGS. +func (wr http2FrameWriteRequest) StreamID() uint32 { + if wr.stream == nil { + if se, ok := wr.write.(http2StreamError); ok { + + return se.StreamID + } + return 0 + } + return wr.stream.id +} + +// DataSize returns the number of flow control bytes that must be consumed +// to write this entire frame. This is 0 for non-DATA frames. +func (wr http2FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*http2writeData); ok { + return len(wd.p) + } + return 0 +} + +// Consume consumes min(n, available) bytes from this frame, where available +// is the number of flow control bytes available on the stream. Consume returns +// 0, 1, or 2 frames, where the integer return value gives the number of frames +// returned. +// +// If flow control prevents consuming any bytes, this returns (_, _, 0). If +// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this +// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and +// 'rest' contains the remaining bytes. The consumed bytes are deducted from the +// underlying stream's flow control budget. +func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { + var empty http2FrameWriteRequest + + wd, ok := wr.write.(*http2writeData) + if !ok || len(wd.p) == 0 { + return wr, empty, 1 + } + + allowed := wr.stream.flow.available() + if n < allowed { + allowed = n + } + if wr.stream.sc.maxFrameSize < allowed { + allowed = wr.stream.sc.maxFrameSize + } + if allowed <= 0 { + return empty, empty, 0 + } + if len(wd.p) > int(allowed) { + wr.stream.flow.take(allowed) + consumed := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[:allowed], + + endStream: false, + }, + + done: nil, + } + rest := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[allowed:], + endStream: wd.endStream, + }, + done: wr.done, + } + return consumed, rest, 2 } + + wr.stream.flow.take(int32(len(wd.p))) + return wr, empty, 1 +} + +// String is for debugging only. +func (wr http2FrameWriteRequest) String() string { var des string - if s, ok := wm.write.(fmt.Stringer); ok { + if s, ok := wr.write.(fmt.Stringer); ok { des = s.String() } else { - des = fmt.Sprintf("%T", wm.write) + des = fmt.Sprintf("%T", wr.write) } - return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des) + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) } -// writeScheduler tracks pending frames to write, priorities, and decides -// the next one to use. It is not thread-safe. -type http2writeScheduler struct { - // zero are frames not associated with a specific stream. - // They're sent before any stream-specific freams. - zero http2writeQueue +// replyToWriter sends err to wr.done and panics if the send must block +// This does nothing if wr.done is nil. +func (wr *http2FrameWriteRequest) replyToWriter(err error) { + if wr.done == nil { + return + } + select { + case wr.done <- err: + default: + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) + } + wr.write = nil +} - // maxFrameSize is the maximum size of a DATA frame - // we'll write. Must be non-zero and between 16K-16M. - maxFrameSize uint32 +// writeQueue is used by implementations of WriteScheduler. +type http2writeQueue struct { + s []http2FrameWriteRequest +} - // sq contains the stream-specific queues, keyed by stream ID. - // when a stream is idle, it's deleted from the map. - sq map[uint32]*http2writeQueue +func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } - // canSend is a slice of memory that's reused between frame - // scheduling decisions to hold the list of writeQueues (from sq) - // which have enough flow control data to send. After canSend is - // built, the best is selected. - canSend []*http2writeQueue +func (q *http2writeQueue) push(wr http2FrameWriteRequest) { + q.s = append(q.s, wr) +} - // pool of empty queues for reuse. - queuePool []*http2writeQueue +func (q *http2writeQueue) shift() http2FrameWriteRequest { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wr := q.s[0] + + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = http2FrameWriteRequest{} + q.s = q.s[:len(q.s)-1] + return wr } -func (ws *http2writeScheduler) putEmptyQueue(q *http2writeQueue) { - if len(q.s) != 0 { - panic("queue must be empty") +// consume consumes up to n bytes from q.s[0]. If the frame is +// entirely consumed, it is removed from the queue. If the frame +// is partially consumed, the frame is kept with the consumed +// bytes removed. Returns true iff any bytes were consumed. +func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { + if len(q.s) == 0 { + return http2FrameWriteRequest{}, false } - ws.queuePool = append(ws.queuePool, q) + consumed, rest, numresult := q.s[0].Consume(n) + switch numresult { + case 0: + return http2FrameWriteRequest{}, false + case 1: + q.shift() + case 2: + q.s[0] = rest + } + return consumed, true } -func (ws *http2writeScheduler) getEmptyQueue() *http2writeQueue { - ln := len(ws.queuePool) +type http2writeQueuePool []*http2writeQueue + +// put inserts an unused writeQueue into the pool. +func (p *http2writeQueuePool) put(q *http2writeQueue) { + for i := range q.s { + q.s[i] = http2FrameWriteRequest{} + } + q.s = q.s[:0] + *p = append(*p, q) +} + +// get returns an empty writeQueue. +func (p *http2writeQueuePool) get() *http2writeQueue { + ln := len(*p) if ln == 0 { return new(http2writeQueue) } - q := ws.queuePool[ln-1] - ws.queuePool = ws.queuePool[:ln-1] + x := ln - 1 + q := (*p)[x] + (*p)[x] = nil + *p = (*p)[:x] return q } -func (ws *http2writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 } +// RFC 7540, Section 5.3.5: the default weight is 16. +const http2priorityDefaultWeight = 15 // 16 = 15 + 1 -func (ws *http2writeScheduler) add(wm http2frameWriteMsg) { - st := wm.stream - if st == nil { - ws.zero.push(wm) +// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. +type http2PriorityWriteSchedulerConfig struct { + // MaxClosedNodesInTree controls the maximum number of closed streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // "It is possible for a stream to become closed while prioritization + // information ... is in transit. ... This potentially creates suboptimal + // prioritization, since the stream could be given a priority that is + // different from what is intended. To avoid these problems, an endpoint + // SHOULD retain stream prioritization state for a period after streams + // become closed. The longer state is retained, the lower the chance that + // streams are assigned incorrect or default priority values." + MaxClosedNodesInTree int + + // MaxIdleNodesInTree controls the maximum number of idle streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // Similarly, streams that are in the "idle" state can be assigned + // priority or become a parent of other streams. This allows for the + // creation of a grouping node in the dependency tree, which enables + // more flexible expressions of priority. Idle streams begin with a + // default priority (Section 5.3.5). + MaxIdleNodesInTree int + + // ThrottleOutOfOrderWrites enables write throttling to help ensure that + // data is delivered in priority order. This works around a race where + // stream B depends on stream A and both streams are about to call Write + // to queue DATA frames. If B wins the race, a naive scheduler would eagerly + // write as much data from B as possible, but this is suboptimal because A + // is a higher-priority stream. With throttling enabled, we write a small + // amount of data from B to minimize the amount of bandwidth that B can + // steal from A. + ThrottleOutOfOrderWrites bool +} + +// NewPriorityWriteScheduler constructs a WriteScheduler that schedules +// frames by following HTTP/2 priorities as described in RFC 7340 Section 5.3. +// If cfg is nil, default options are used. +func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { + if cfg == nil { + + cfg = &http2PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 10, + MaxIdleNodesInTree: 10, + ThrottleOutOfOrderWrites: false, + } + } + + ws := &http2priorityWriteScheduler{ + nodes: make(map[uint32]*http2priorityNode), + maxClosedNodesInTree: cfg.MaxClosedNodesInTree, + maxIdleNodesInTree: cfg.MaxIdleNodesInTree, + enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, + } + ws.nodes[0] = &ws.root + if cfg.ThrottleOutOfOrderWrites { + ws.writeThrottleLimit = 1024 } else { - ws.streamQueue(st.id).push(wm) + ws.writeThrottleLimit = math.MaxInt32 } + return ws +} + +type http2priorityNodeState int + +const ( + http2priorityNodeOpen http2priorityNodeState = iota + http2priorityNodeClosed + http2priorityNodeIdle +) + +// priorityNode is a node in an HTTP/2 priority tree. +// Each node is associated with a single stream ID. +// See RFC 7540, Section 5.3. +type http2priorityNode struct { + q http2writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state http2priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree + + // These links form the priority tree. + parent *http2priorityNode + kids *http2priorityNode // start of the kids list + prev, next *http2priorityNode // doubly-linked list of siblings } -func (ws *http2writeScheduler) streamQueue(streamID uint32) *http2writeQueue { - if q, ok := ws.sq[streamID]; ok { - return q +func (n *http2priorityNode) setParent(parent *http2priorityNode) { + if n == parent { + panic("setParent to self") } - if ws.sq == nil { - ws.sq = make(map[uint32]*http2writeQueue) + if n.parent == parent { + return + } + + if parent := n.parent; parent != nil { + if n.prev == nil { + parent.kids = n.next + } else { + n.prev.next = n.next + } + if n.next != nil { + n.next.prev = n.prev + } + } + + n.parent = parent + if parent == nil { + n.next = nil + n.prev = nil + } else { + n.next = parent.kids + n.prev = nil + if n.next != nil { + n.next.prev = n + } + parent.kids = n } - q := ws.getEmptyQueue() - ws.sq[streamID] = q - return q } -// take returns the most important frame to write and removes it from the scheduler. -// It is illegal to call this if the scheduler is empty or if there are no connection-level -// flow control bytes available. -func (ws *http2writeScheduler) take() (wm http2frameWriteMsg, ok bool) { - if ws.maxFrameSize == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") +func (n *http2priorityNode) addBytes(b int64) { + n.bytes += b + for ; n != nil; n = n.parent { + n.subtreeBytes += b } +} - if !ws.zero.empty() { - return ws.zero.shift(), true +// walkReadyInOrder iterates over the tree in priority order, calling f for each node +// with a non-empty write queue. When f returns true, this funcion returns true and the +// walk halts. tmp is used as scratch space for sorting. +// +// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true +// if any ancestor p of n is still open (ignoring the root node). +func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { + if !n.q.empty() && f(n, openParent) { + return true } - if len(ws.sq) == 0 { - return + if n.kids == nil { + return false + } + + if n.id != 0 { + openParent = openParent || (n.state == http2priorityNodeOpen) } - for id, q := range ws.sq { - if q.firstIsNoCost() { - return ws.takeFrom(id, q) + w := n.kids.weight + needSort := false + for k := n.kids.next; k != nil; k = k.next { + if k.weight != w { + needSort = true + break } } + if !needSort { + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false + } - if len(ws.canSend) != 0 { - panic("should be empty") + *tmp = (*tmp)[:0] + for n.kids != nil { + *tmp = append(*tmp, n.kids) + n.kids.setParent(nil) } - for _, q := range ws.sq { - if n := ws.streamWritableBytes(q); n > 0 { - ws.canSend = append(ws.canSend, q) + sort.Sort(http2sortPriorityNodeSiblings(*tmp)) + for i := len(*tmp) - 1; i >= 0; i-- { + (*tmp)[i].setParent(n) + } + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true } } - if len(ws.canSend) == 0 { - return + return false +} + +type http2sortPriorityNodeSiblings []*http2priorityNode + +func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } + +func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } + +func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { + + wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + if bi == 0 && bk == 0 { + return wi >= wk } - defer ws.zeroCanSend() + if bk == 0 { + return false + } + return bi/bk <= wi/wk +} - q := ws.canSend[0] +type http2priorityWriteScheduler struct { + // root is the root of the priority tree, where root.id = 0. + // The root queues control frames that are not associated with any stream. + root http2priorityNode - return ws.takeFrom(q.streamID(), q) + // nodes maps stream ids to priority tree nodes. + nodes map[uint32]*http2priorityNode + + // maxID is the maximum stream id in nodes. + maxID uint32 + + // lists of nodes that have been closed or are idle, but are kept in + // the tree for improved prioritization. When the lengths exceed either + // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. + closedNodes, idleNodes []*http2priorityNode + + // From the config. + maxClosedNodesInTree int + maxIdleNodesInTree int + writeThrottleLimit int32 + enableWriteThrottle bool + + // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. + tmp []*http2priorityNode + + // pool of empty queues for reuse. + queuePool http2writeQueuePool } -// zeroCanSend is defered from take. -func (ws *http2writeScheduler) zeroCanSend() { - for i := range ws.canSend { - ws.canSend[i] = nil +func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + + if curr := ws.nodes[streamID]; curr != nil { + if curr.state != http2priorityNodeIdle { + panic(fmt.Sprintf("stream %d already opened", streamID)) + } + curr.state = http2priorityNodeOpen + return + } + + parent := ws.nodes[options.PusherID] + if parent == nil { + parent = &ws.root + } + n := &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeOpen, + } + n.setParent(parent) + ws.nodes[streamID] = n + if streamID > ws.maxID { + ws.maxID = streamID } - ws.canSend = ws.canSend[:0] } -// streamWritableBytes returns the number of DATA bytes we could write -// from the given queue's stream, if this stream/queue were -// selected. It is an error to call this if q's head isn't a -// *writeData. -func (ws *http2writeScheduler) streamWritableBytes(q *http2writeQueue) int32 { - wm := q.head() - ret := wm.stream.flow.available() - if ret == 0 { - return 0 +func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { + if streamID == 0 { + panic("violation of WriteScheduler interface: cannot close stream 0") } - if int32(ws.maxFrameSize) < ret { - ret = int32(ws.maxFrameSize) + if ws.nodes[streamID] == nil { + panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) } - if ret == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") + if ws.nodes[streamID].state != http2priorityNodeOpen { + panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) } - wd := wm.write.(*http2writeData) - if len(wd.p) < int(ret) { - ret = int32(len(wd.p)) + + n := ws.nodes[streamID] + n.state = http2priorityNodeClosed + n.addBytes(-n.bytes) + + q := n.q + ws.queuePool.put(&q) + n.q.s = nil + if ws.maxClosedNodesInTree > 0 { + ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) + } else { + ws.removeNode(n) } - return ret } -func (ws *http2writeScheduler) takeFrom(id uint32, q *http2writeQueue) (wm http2frameWriteMsg, ok bool) { - wm = q.head() - - if wd, ok := wm.write.(*http2writeData); ok && len(wd.p) > 0 { - allowed := wm.stream.flow.available() - if allowed == 0 { +func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + if streamID == 0 { + panic("adjustPriority on root") + } - return http2frameWriteMsg{}, false + n := ws.nodes[streamID] + if n == nil { + if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { + return } - if int32(ws.maxFrameSize) < allowed { - allowed = int32(ws.maxFrameSize) + ws.maxID = streamID + n = &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeIdle, } + n.setParent(&ws.root) + ws.nodes[streamID] = n + ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) + } - if len(wd.p) > int(allowed) { - wm.stream.flow.take(allowed) - chunk := wd.p[:allowed] - wd.p = wd.p[allowed:] + parent := ws.nodes[priority.StreamDep] + if parent == nil { + n.setParent(&ws.root) + n.weight = http2priorityDefaultWeight + return + } - return http2frameWriteMsg{ - stream: wm.stream, - write: &http2writeData{ - streamID: wd.streamID, - p: chunk, + if n == parent { + return + } - endStream: false, - }, + for x := parent.parent; x != nil; x = x.parent { + if x == n { + parent.setParent(n.parent) + break + } + } - done: nil, - }, true + if priority.Exclusive { + k := parent.kids + for k != nil { + next := k.next + if k != n { + k.setParent(n) + } + k = next } - wm.stream.flow.take(int32(len(wd.p))) } - q.shift() - if q.empty() { - ws.putEmptyQueue(q) - delete(ws.sq, id) + n.setParent(parent) + n.weight = priority.Weight +} + +func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { + var n *http2priorityNode + if id := wr.StreamID(); id == 0 { + n = &ws.root + } else { + n = ws.nodes[id] + if n == nil { + + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + n = &ws.root + } } - return wm, true + n.q.push(wr) } -func (ws *http2writeScheduler) forgetStream(id uint32) { - q, ok := ws.sq[id] - if !ok { +func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { + limit := int32(math.MaxInt32) + if openParent { + limit = ws.writeThrottleLimit + } + wr, ok = n.q.consume(limit) + if !ok { + return false + } + n.addBytes(int64(wr.DataSize())) + + if openParent { + ws.writeThrottleLimit += 1024 + if ws.writeThrottleLimit < 0 { + ws.writeThrottleLimit = math.MaxInt32 + } + } else if ws.enableWriteThrottle { + ws.writeThrottleLimit = 1024 + } + return true + }) + return wr, ok +} + +func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) { + if maxSize == 0 { return } - delete(ws.sq, id) + if len(*list) == maxSize { - for i := range q.s { - q.s[i] = http2frameWriteMsg{} + ws.removeNode((*list)[0]) + x := (*list)[1:] + copy(*list, x) + *list = (*list)[:len(x)] } - q.s = q.s[:0] - ws.putEmptyQueue(q) + *list = append(*list, n) } -type http2writeQueue struct { - s []http2frameWriteMsg +func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { + for k := n.kids; k != nil; k = k.next { + k.setParent(n.parent) + } + n.setParent(nil) + delete(ws.nodes, n.id) } -// streamID returns the stream ID for a non-empty stream-specific queue. -func (q *http2writeQueue) streamID() uint32 { return q.s[0].stream.id } +// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 +// priorities. Control frames like SETTINGS and PING are written before DATA +// frames, but if no control frames are queued and multiple streams have queued +// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. +func http2NewRandomWriteScheduler() http2WriteScheduler { + return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +} -func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } +type http2randomWriteScheduler struct { + // zero are frames not associated with a specific stream. + zero http2writeQueue + + // sq contains the stream-specific queues, keyed by stream ID. + // When a stream is idle or closed, it's deleted from the map. + sq map[uint32]*http2writeQueue -func (q *http2writeQueue) push(wm http2frameWriteMsg) { - q.s = append(q.s, wm) + // pool of empty queues for reuse. + queuePool http2writeQueuePool } -// head returns the next item that would be removed by shift. -func (q *http2writeQueue) head() http2frameWriteMsg { - if len(q.s) == 0 { - panic("invalid use of queue") - } - return q.s[0] +func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + } -func (q *http2writeQueue) shift() http2frameWriteMsg { - if len(q.s) == 0 { - panic("invalid use of queue") +func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { + q, ok := ws.sq[streamID] + if !ok { + return } - wm := q.s[0] + delete(ws.sq, streamID) + ws.queuePool.put(q) +} + +func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { - copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = http2frameWriteMsg{} - q.s = q.s[:len(q.s)-1] - return wm } -func (q *http2writeQueue) firstIsNoCost() bool { - if df, ok := q.s[0].write.(*http2writeData); ok { - return len(df.p) == 0 +func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { + id := wr.StreamID() + if id == 0 { + ws.zero.push(wr) + return } - return true + q, ok := ws.sq[id] + if !ok { + q = ws.queuePool.get() + ws.sq[id] = q + } + q.push(wr) +} + +func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { + + if !ws.zero.empty() { + return ws.zero.shift(), true + } + + for _, q := range ws.sq { + if wr, ok := q.consume(math.MaxInt32); ok { + return wr, true + } + } + return http2FrameWriteRequest{}, false } |