diff --git a/curl.go b/curl.go index 291d38c..80aa9ba 100644 --- a/curl.go +++ b/curl.go @@ -44,3 +44,10 @@ func scheme(r *http.Request) string { func requestURL(r *http.Request) string { return fmt.Sprintf("%s://%s%s", scheme(r), r.Host, r.URL) } + +func bytesRead(r *countingReader) int64 { + if r == nil { + return 0 + } + return r.bytesRead +} diff --git a/middleware.go b/middleware.go index a31f3d8..2abc123 100644 --- a/middleware.go +++ b/middleware.go @@ -41,8 +41,16 @@ func RequestLogger(logger *slog.Logger, o *Options) func(http.Handler) http.Hand logReqBody := o.LogRequestBody != nil && o.LogRequestBody(r) logRespBody := o.LogResponseBody != nil && o.LogResponseBody(r) + hasReqBody := r.Body != nil && r.Body != http.NoBody + var bodyReader *countingReader + if hasReqBody { + bodyReader = &countingReader{reader: r.Body} + r.Body = bodyReader + } + + consumeBody := hasReqBody && (logReqBody || o.LogExtraAttrs != nil) var reqBody bytes.Buffer - if logReqBody || o.LogExtraAttrs != nil { + if consumeBody { r.Body = io.NopCloser(io.TeeReader(r.Body, &reqBody)) } @@ -130,6 +138,7 @@ func RequestLogger(logger *slog.Logger, o *Options) func(http.Handler) http.Hand slog.String(s.RequestProto, r.Proto), slog.Any(s.RequestHeaders, slog.GroupValue(getHeaderAttrs(r.Header, o.LogRequestHeaders)...)), slog.Int64(s.RequestBytes, r.ContentLength), + slog.Int64(s.RequestBytesRead, bytesRead(bodyReader)), slog.String(s.RequestUserAgent, r.UserAgent()), slog.String(s.RequestReferer, r.Referer()), slog.Any(s.ResponseHeaders, slog.GroupValue(getHeaderAttrs(ww.Header(), o.LogResponseHeaders)...)), @@ -142,7 +151,7 @@ func RequestLogger(logger *slog.Logger, o *Options) func(http.Handler) http.Hand logAttrs = appendAttrs(logAttrs, slog.Any(ErrorKey, ErrClientAborted), slog.String(s.ErrorType, "ClientAborted")) } - if logReqBody || o.LogExtraAttrs != nil { + if consumeBody { // Ensure the request body is fully read if the underlying HTTP handler didn't do so. n, _ := io.Copy(io.Discard, r.Body) if n > 0 { @@ -231,3 +240,18 @@ func logBody(body *bytes.Buffer, header http.Header, o *Options) string { } return fmt.Sprintf("[body redacted for Content-Type: %s]", contentType) } + +type countingReader struct { + reader io.ReadCloser + bytesRead int64 +} + +func (cr *countingReader) Read(p []byte) (int, error) { + n, err := cr.reader.Read(p) + cr.bytesRead += int64(n) + return n, err +} + +func (cr *countingReader) Close() error { + return cr.reader.Close() +} diff --git a/schema.go b/schema.go index ae33cb5..66c1b59 100644 --- a/schema.go +++ b/schema.go @@ -37,6 +37,7 @@ type Schema struct { RequestHeaders string // Selected request headers RequestBody string // Request body content, if logged. RequestBytes string // Size of request body in bytes + RequestBytesRead string // Read bytes in request body RequestBytesUnread string // Unread bytes in request body RequestUserAgent string // User-Agent header value RequestReferer string // Referer header value @@ -78,6 +79,7 @@ var ( RequestHeaders: "http.request.headers", RequestBody: "http.request.body.content", RequestBytes: "http.request.body.bytes", + RequestBytesRead: "http.request.body.read.bytes", RequestBytesUnread: "http.request.body.unread.bytes", RequestUserAgent: "user_agent.original", RequestReferer: "http.request.referrer", @@ -112,6 +114,7 @@ var ( RequestHeaders: "http.request.header", RequestBody: "http.request.body.content", RequestBytes: "http.request.body.size", + RequestBytesRead: "http.request.body.read.size", RequestBytesUnread: "http.request.body.unread.size", RequestUserAgent: "user_agent.original", RequestReferer: "http.request.header.referer", @@ -148,6 +151,7 @@ var ( RequestHeaders: "httpRequest:requestHeaders", RequestBody: "httpRequest:requestBody", RequestBytes: "httpRequest:requestSize", + RequestBytesRead: "httpRequest:requestReadSize", RequestBytesUnread: "httpRequest:requestUnreadSize", RequestUserAgent: "httpRequest:userAgent", RequestReferer: "httpRequest:referer",