@@ -63,6 +63,9 @@ func createApp(opts ...OptionFunc) *HTTPBin {
6363 DripDelay : 0 ,
6464 DripDuration : 100 * time .Millisecond ,
6565 DripNumBytes : 10 ,
66+ SSECount : 10 ,
67+ SSEDelay : 0 ,
68+ SSEDuration : 100 * time .Millisecond ,
6669 }),
6770 WithMaxBodySize (maxBodySize ),
6871 WithMaxDuration (maxDuration ),
@@ -2957,6 +2960,246 @@ func TestHostname(t *testing.T) {
29572960 })
29582961}
29592962
2963+ func TestSSE (t * testing.T ) {
2964+ t .Parallel ()
2965+
2966+ parseServerSentEvent := func (t * testing.T , buf * bufio.Reader ) (serverSentEvent , error ) {
2967+ t .Helper ()
2968+
2969+ // match "event: ping" line
2970+ eventLine , err := buf .ReadBytes ('\n' )
2971+ if err != nil {
2972+ return serverSentEvent {}, err
2973+ }
2974+ _ , eventType , _ := bytes .Cut (eventLine , []byte (":" ))
2975+ assert .Equal (t , string (bytes .TrimSpace (eventType )), "ping" , "unexpected event type" )
2976+
2977+ // match "data: {...}" line
2978+ dataLine , err := buf .ReadBytes ('\n' )
2979+ if err != nil {
2980+ return serverSentEvent {}, err
2981+ }
2982+ _ , data , _ := bytes .Cut (dataLine , []byte (":" ))
2983+ var event serverSentEvent
2984+ assert .NilError (t , json .Unmarshal (data , & event ))
2985+
2986+ // match newline after event data
2987+ b , err := buf .ReadByte ()
2988+ if err != nil && err != io .EOF {
2989+ assert .NilError (t , err )
2990+ }
2991+ if b != '\n' {
2992+ t .Fatalf ("expected newline after event data, got %q" , b )
2993+ }
2994+
2995+ return event , nil
2996+ }
2997+
2998+ parseServerSentEventStream := func (t * testing.T , resp * http.Response ) []serverSentEvent {
2999+ t .Helper ()
3000+ buf := bufio .NewReader (resp .Body )
3001+ var events []serverSentEvent
3002+ for {
3003+ event , err := parseServerSentEvent (t , buf )
3004+ if err == io .EOF {
3005+ break
3006+ }
3007+ assert .NilError (t , err )
3008+ events = append (events , event )
3009+ }
3010+ return events
3011+ }
3012+
3013+ okTests := []struct {
3014+ params * url.Values
3015+ duration time.Duration
3016+ count int
3017+ }{
3018+ // there are useful defaults for all values
3019+ {& url.Values {}, 0 , 10 },
3020+
3021+ // go-style durations are accepted
3022+ {& url.Values {"duration" : {"5ms" }}, 5 * time .Millisecond , 10 },
3023+ {& url.Values {"duration" : {"10ns" }}, 0 , 10 },
3024+ {& url.Values {"delay" : {"5ms" }}, 5 * time .Millisecond , 10 },
3025+ {& url.Values {"delay" : {"0h" }}, 0 , 10 },
3026+
3027+ // or floating point seconds
3028+ {& url.Values {"duration" : {"0.25" }}, 250 * time .Millisecond , 10 },
3029+ {& url.Values {"duration" : {"1" }}, 1 * time .Second , 10 },
3030+ {& url.Values {"delay" : {"0.25" }}, 250 * time .Millisecond , 10 },
3031+ {& url.Values {"delay" : {"0" }}, 0 , 10 },
3032+
3033+ {& url.Values {"count" : {"1" }}, 0 , 1 },
3034+ {& url.Values {"count" : {"011" }}, 0 , 11 },
3035+ {& url.Values {"count" : {fmt .Sprintf ("%d" , app .maxSSECount )}}, 0 , int (app .maxSSECount )},
3036+
3037+ {& url.Values {"duration" : {"250ms" }, "delay" : {"250ms" }}, 500 * time .Millisecond , 10 },
3038+ {& url.Values {"duration" : {"250ms" }, "delay" : {"0.25s" }}, 500 * time .Millisecond , 10 },
3039+ }
3040+ for _ , test := range okTests {
3041+ test := test
3042+ t .Run (fmt .Sprintf ("ok/%s" , test .params .Encode ()), func (t * testing.T ) {
3043+ t .Parallel ()
3044+
3045+ url := "/sse?" + test .params .Encode ()
3046+
3047+ start := time .Now ()
3048+ req := newTestRequest (t , "GET" , url )
3049+ resp := must .DoReq (t , client , req )
3050+ assert .StatusCode (t , resp , http .StatusOK )
3051+ events := parseServerSentEventStream (t , resp )
3052+
3053+ if elapsed := time .Since (start ); elapsed < test .duration {
3054+ t .Fatalf ("expected minimum duration of %s, request took %s" , test .duration , elapsed )
3055+ }
3056+ assert .ContentType (t , resp , sseContentType )
3057+ assert .DeepEqual (t , resp .TransferEncoding , []string {"chunked" }, "unexpected Transfer-Encoding header" )
3058+ assert .Equal (t , len (events ), test .count , "unexpected number of events" )
3059+ })
3060+ }
3061+
3062+ badTests := []struct {
3063+ params * url.Values
3064+ code int
3065+ }{
3066+ {& url.Values {"duration" : {"0" }}, http .StatusBadRequest },
3067+ {& url.Values {"duration" : {"0s" }}, http .StatusBadRequest },
3068+ {& url.Values {"duration" : {"1m" }}, http .StatusBadRequest },
3069+ {& url.Values {"duration" : {"-1ms" }}, http .StatusBadRequest },
3070+ {& url.Values {"duration" : {"1001" }}, http .StatusBadRequest },
3071+ {& url.Values {"duration" : {"-1" }}, http .StatusBadRequest },
3072+ {& url.Values {"duration" : {"foo" }}, http .StatusBadRequest },
3073+
3074+ {& url.Values {"delay" : {"1m" }}, http .StatusBadRequest },
3075+ {& url.Values {"delay" : {"-1ms" }}, http .StatusBadRequest },
3076+ {& url.Values {"delay" : {"1001" }}, http .StatusBadRequest },
3077+ {& url.Values {"delay" : {"-1" }}, http .StatusBadRequest },
3078+ {& url.Values {"delay" : {"foo" }}, http .StatusBadRequest },
3079+
3080+ {& url.Values {"count" : {"foo" }}, http .StatusBadRequest },
3081+ {& url.Values {"count" : {"0" }}, http .StatusBadRequest },
3082+ {& url.Values {"count" : {"-1" }}, http .StatusBadRequest },
3083+ {& url.Values {"count" : {"0xff" }}, http .StatusBadRequest },
3084+ {& url.Values {"count" : {fmt .Sprintf ("%d" , app .maxSSECount + 1 )}}, http .StatusBadRequest },
3085+
3086+ // request would take too long
3087+ {& url.Values {"duration" : {"750ms" }, "delay" : {"500ms" }}, http .StatusBadRequest },
3088+ }
3089+ for _ , test := range badTests {
3090+ test := test
3091+ t .Run (fmt .Sprintf ("bad/%s" , test .params .Encode ()), func (t * testing.T ) {
3092+ t .Parallel ()
3093+ url := "/sse?" + test .params .Encode ()
3094+ req := newTestRequest (t , "GET" , url )
3095+ resp := must .DoReq (t , client , req )
3096+ defer consumeAndCloseBody (resp )
3097+ assert .StatusCode (t , resp , test .code )
3098+ })
3099+ }
3100+
3101+ t .Run ("writes are actually incremmental" , func (t * testing.T ) {
3102+ t .Parallel ()
3103+
3104+ var (
3105+ duration = 100 * time .Millisecond
3106+ count = 3
3107+ endpoint = fmt .Sprintf ("/sse?duration=%s&count=%d" , duration , count )
3108+
3109+ // Match server logic for calculating the delay between writes
3110+ wantPauseBetweenWrites = duration / time .Duration (count - 1 )
3111+ )
3112+
3113+ req := newTestRequest (t , "GET" , endpoint )
3114+ resp := must .DoReq (t , client , req )
3115+ buf := bufio .NewReader (resp .Body )
3116+ eventCount := 0
3117+
3118+ // Here we read from the response one byte at a time, and ensure that
3119+ // at least the expected delay occurs for each read.
3120+ //
3121+ // The request above includes an initial delay equal to the expected
3122+ // wait between writes so that even the first iteration of this loop
3123+ // expects to wait the same amount of time for a read.
3124+ for i := 0 ; ; i ++ {
3125+ start := time .Now ()
3126+ event , err := parseServerSentEvent (t , buf )
3127+ if err == io .EOF {
3128+ break
3129+ }
3130+ assert .NilError (t , err )
3131+ gotPause := time .Since (start )
3132+
3133+ // We expect to read exactly one byte on each iteration. On the
3134+ // last iteration, we expct to hit EOF after reading the final
3135+ // byte, because the server does not pause after the last write.
3136+ assert .Equal (t , event .ID , i , "unexpected SSE event ID" )
3137+
3138+ // only ensure that we pause for the expected time between writes
3139+ // (allowing for minor mismatch in local timers and server timers)
3140+ // after the first byte.
3141+ if i > 0 {
3142+ assert .RoughDuration (t , gotPause , wantPauseBetweenWrites , 3 * time .Millisecond )
3143+ }
3144+
3145+ eventCount ++
3146+ }
3147+
3148+ assert .Equal (t , eventCount , count , "unexpected number of events" )
3149+ })
3150+
3151+ t .Run ("handle cancelation during initial delay" , func (t * testing.T ) {
3152+ t .Parallel ()
3153+
3154+ // For this test, we expect the client to time out and cancel the
3155+ // request after 10ms. The handler should still be in its intitial
3156+ // delay period, so this will result in a request error since no status
3157+ // code will be written before the cancelation.
3158+ ctx , cancel := context .WithTimeout (context .Background (), 25 * time .Millisecond )
3159+ defer cancel ()
3160+
3161+ req := newTestRequest (t , "GET" , "/sse?duration=500ms&delay=500ms" ).WithContext (ctx )
3162+ if _ , err := client .Do (req ); ! os .IsTimeout (err ) {
3163+ t .Fatalf ("expected timeout error, got %s" , err )
3164+ }
3165+ })
3166+
3167+ t .Run ("handle cancelation during stream" , func (t * testing.T ) {
3168+ t .Parallel ()
3169+
3170+ ctx , cancel := context .WithTimeout (context .Background (), 100 * time .Millisecond )
3171+ defer cancel ()
3172+
3173+ req := newTestRequest (t , "GET" , "/sse?duration=900ms&delay=0&count=2" ).WithContext (ctx )
3174+ resp := must .DoReq (t , client , req )
3175+ defer consumeAndCloseBody (resp )
3176+
3177+ // In this test, the server should have started an OK response before
3178+ // our client timeout cancels the request, so we should get an OK here.
3179+ assert .StatusCode (t , resp , http .StatusOK )
3180+
3181+ // But, we should time out while trying to read the whole response
3182+ // body.
3183+ body , err := io .ReadAll (resp .Body )
3184+ if ! os .IsTimeout (err ) {
3185+ t .Fatalf ("expected timeout reading body, got %s" , err )
3186+ }
3187+
3188+ // partial read should include the first whole event
3189+ event , err := parseServerSentEvent (t , bufio .NewReader (bytes .NewReader (body )))
3190+ assert .NilError (t , err )
3191+ assert .Equal (t , event .ID , 0 , "unexpected SSE event ID" )
3192+ })
3193+
3194+ t .Run ("ensure HEAD request works with streaming responses" , func (t * testing.T ) {
3195+ t .Parallel ()
3196+ req := newTestRequest (t , "HEAD" , "/sse?duration=900ms&delay=100ms" )
3197+ resp := must .DoReq (t , client , req )
3198+ assert .StatusCode (t , resp , http .StatusOK )
3199+ assert .BodySize (t , resp , 0 )
3200+ })
3201+ }
3202+
29603203func TestWebSocketEcho (t * testing.T ) {
29613204 // ========================================================================
29623205 // Note: Here we only test input validation for the websocket endpoint.
@@ -3028,6 +3271,7 @@ func TestWebSocketEcho(t *testing.T) {
30283271 })
30293272 }
30303273}
3274+
30313275func newTestServer (handler http.Handler ) (* httptest.Server , * http.Client ) {
30323276 srv := httptest .NewServer (handler )
30333277 client := srv .Client ()
0 commit comments