@@ -3,6 +3,7 @@ mod response;
33mod gvl_helpers;
44mod grpc;
55
6+ use hyper_util:: server:: graceful:: GracefulShutdown ;
67use request:: { Request , GrpcRequest } ;
78use response:: { Response , GrpcResponse } ;
89use gvl_helpers:: nogvl;
@@ -11,11 +12,12 @@ use magnus::block::block_proc;
1112use magnus:: typed_data:: Obj ;
1213use magnus:: { function, method, prelude:: * , Error as MagnusError , IntoValue , Ruby , Value , RString } ;
1314use bytes:: Bytes ;
15+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
1416
1517use std:: cell:: RefCell ;
1618use std:: net:: SocketAddr ;
1719
18- use tokio:: net:: UnixListener ;
20+ use tokio:: net:: { TcpListener , UnixListener } ;
1921
2022use std:: sync:: Arc ;
2123use tokio:: sync:: { Mutex , oneshot} ;
@@ -31,18 +33,45 @@ use http_body_util::BodyExt;
3133
3234use jemallocator:: Jemalloc ;
3335
34- use log:: { debug, info, warn} ;
36+ use log:: { debug, info, warn, error } ;
3537
3638use env_logger;
3739use crate :: response:: BodyWithTrailers ;
3840use std:: sync:: Once ;
3941use tokio:: time:: timeout;
4042
43+ use std:: io;
44+
45+ use tokio:: sync:: broadcast;
46+
4147static LOGGER_INIT : Once = Once :: new ( ) ;
4248
4349#[ global_allocator]
4450static GLOBAL : Jemalloc = Jemalloc ;
4551
52+ trait AsyncStream : AsyncRead + AsyncWrite + Unpin + Send { }
53+ impl < T : AsyncRead + AsyncWrite + Unpin + Send > AsyncStream for T { }
54+
55+ enum Listener {
56+ Unix ( UnixListener ) ,
57+ Tcp ( TcpListener ) ,
58+ }
59+
60+ impl Listener {
61+ async fn accept ( & self ) -> io:: Result < ( Box < dyn AsyncStream > , SocketAddr ) > {
62+ match self {
63+ Listener :: Unix ( l) => {
64+ let ( stream, _) = l. accept ( ) . await ?;
65+ Ok ( ( Box :: new ( stream) , "0.0.0.0:0" . parse ( ) . unwrap ( ) ) )
66+ }
67+ Listener :: Tcp ( l) => {
68+ let ( stream, addr) = l. accept ( ) . await ?;
69+ Ok ( ( Box :: new ( stream) , addr) )
70+ }
71+ }
72+ }
73+ }
74+
4675#[ derive( Clone ) ]
4776struct ServerConfig {
4877 bind_address : String ,
@@ -75,18 +104,19 @@ struct Server {
75104 work_rx : RefCell < Option < crossbeam_channel:: Receiver < RequestWithCompletion > > > ,
76105 work_tx : RefCell < Option < Arc < crossbeam_channel:: Sender < RequestWithCompletion > > > > ,
77106 runtime : RefCell < Option < Arc < tokio:: runtime:: Runtime > > > ,
107+ shutdown : RefCell < Option < broadcast:: Sender < ( ) > > > ,
78108}
79109
80110impl Server {
81111 pub fn new ( ) -> Self {
82112 let ( work_tx, work_rx) = crossbeam_channel:: bounded ( 1000 ) ;
83-
84113 Self {
85114 server_handle : Arc :: new ( Mutex :: new ( None ) ) ,
86115 config : RefCell :: new ( ServerConfig :: new ( ) ) ,
87116 work_rx : RefCell :: new ( Some ( work_rx) ) ,
88117 work_tx : RefCell :: new ( Some ( Arc :: new ( work_tx) ) ) ,
89118 runtime : RefCell :: new ( None ) ,
119+ shutdown : RefCell :: new ( None ) ,
90120 }
91121 }
92122
@@ -211,6 +241,9 @@ impl Server {
211241 . ok_or_else ( || MagnusError :: new ( magnus:: exception:: runtime_error ( ) , "Work channel not initialized" ) ) ?
212242 . clone ( ) ;
213243
244+ let ( shutdown_tx, shutdown_rx) = broadcast:: channel ( 1 ) ;
245+ * self . shutdown . borrow_mut ( ) = Some ( shutdown_tx) ;
246+
214247 let mut rt_builder = tokio:: runtime:: Builder :: new_multi_thread ( ) ;
215248
216249 rt_builder. enable_all ( ) ;
@@ -225,65 +258,94 @@ impl Server {
225258
226259 * self . runtime . borrow_mut ( ) = Some ( rt. clone ( ) ) ;
227260
228- rt. block_on ( async {
229- let work_tx = work_tx. clone ( ) ;
230-
261+
262+ rt. block_on ( async move {
231263 let server_task = tokio:: spawn ( async move {
232264 let timer = hyper_util:: rt:: TokioTimer :: new ( ) ;
265+ let mut builder = auto:: Builder :: new ( hyper_util:: rt:: TokioExecutor :: new ( ) ) ;
266+ builder. http1 ( )
267+ . header_read_timeout ( std:: time:: Duration :: from_millis ( config. recv_timeout ) )
268+ . timer ( timer. clone ( ) ) ;
269+ builder. http2 ( )
270+ . keep_alive_interval ( std:: time:: Duration :: from_secs ( 10 ) )
271+ . timer ( timer) ;
272+
273+ let listener = if config. bind_address . starts_with ( "unix:" ) {
274+ Listener :: Unix ( UnixListener :: bind ( config. bind_address . trim_start_matches ( "unix:" ) ) . unwrap ( ) )
275+ } else {
276+ let addr: SocketAddr = config. bind_address . parse ( ) . expect ( "invalid address format" ) ;
277+ Listener :: Tcp ( TcpListener :: bind ( addr) . await . unwrap ( ) )
278+ } ;
233279
234- if config. bind_address . starts_with ( "unix:" ) {
235- let path = config. bind_address . trim_start_matches ( "unix:" ) ;
236- let listener = UnixListener :: bind ( path) . unwrap ( ) ;
237-
238- loop {
239- let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
240- let work_tx = work_tx. clone ( ) ;
241- let timer = timer. clone ( ) ;
242-
243- tokio:: task:: spawn ( async move {
244- handle_connection ( stream, work_tx, config. recv_timeout , timer) . await ;
245- } ) ;
280+ let graceful_shutdown = GracefulShutdown :: new ( ) ;
281+ let mut shutdown_rx = shutdown_rx;
282+
283+ loop {
284+ tokio:: select! {
285+ Ok ( ( stream, _) ) = listener. accept( ) => {
286+ info!( "New connection established" ) ;
287+
288+ let io = TokioIo :: new( stream) ;
289+
290+ debug!( "Setting up connection" ) ;
291+
292+ let builder = builder. clone( ) ;
293+ let work_tx = work_tx. clone( ) ;
294+ let conn = builder. serve_connection( io, service_fn( move |req: HyperRequest <Incoming >| {
295+ debug!( "Service handling request" ) ;
296+ handle_request( req, work_tx. clone( ) , config. recv_timeout)
297+ } ) ) ;
298+ let fut = graceful_shutdown. watch( conn. into_owned( ) ) ;
299+ tokio:: task:: spawn( async move {
300+ if let Err ( err) = fut. await {
301+ warn!( "Error serving connection: {:?}" , err) ;
302+ }
303+ } ) ;
304+ } ,
305+ _ = shutdown_rx. recv( ) => {
306+ debug!( "Graceful shutdown requested; shutting down" ) ;
307+ break ;
308+ }
246309 }
247- } else {
248- let addr: SocketAddr = config. bind_address . parse ( )
249- . expect ( "invalid address format" ) ;
250- let listener = tokio:: net:: TcpListener :: bind ( addr) . await . unwrap ( ) ;
251-
252- loop {
253- let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
254- let work_tx = work_tx. clone ( ) ;
255- let timer = timer. clone ( ) ;
256-
257- tokio:: task:: spawn ( async move {
258- handle_connection ( stream, work_tx, config. recv_timeout , timer) . await ;
259- } ) ;
310+ }
311+
312+ tokio:: select! {
313+ _ = graceful_shutdown. shutdown( ) => {
314+ debug!( "all connections gracefully closed" ) ;
315+ } ,
316+ _ = tokio:: time:: sleep( std:: time:: Duration :: from_secs( 10 ) ) => {
317+ error!( "timed out wait for all connections to close" ) ;
260318 }
261319 }
262320 } ) ;
263321
264322 let mut handle = self . server_handle . lock ( ) . await ;
265323 * handle = Some ( server_task) ;
266-
324+
267325 Ok :: < ( ) , MagnusError > ( ( ) )
268326 } ) ?;
269-
327+
270328 Ok ( ( ) )
271329 }
272330
273331 pub fn stop ( & self ) -> Result < ( ) , MagnusError > {
274- // Use the stored runtime instead of creating a new one
275332 if let Some ( rt) = self . runtime . borrow ( ) . as_ref ( ) {
333+ if let Some ( shutdown) = self . shutdown . borrow ( ) . as_ref ( ) {
334+ let _ = shutdown. send ( ( ) ) ;
335+ }
336+
276337 rt. block_on ( async {
277338 let mut handle = self . server_handle . lock ( ) . await ;
278339 if let Some ( task) = handle. take ( ) {
279- task. abort ( ) ;
340+ task. await . unwrap_or_else ( |e| warn ! ( "Server task failed: {:?}" , e ) ) ;
280341 }
281342 } ) ;
282343 }
283344
284345 // Drop the channel and runtime
285346 self . work_tx . borrow_mut ( ) . take ( ) ;
286347 self . runtime . borrow_mut ( ) . take ( ) ;
348+ self . shutdown . borrow_mut ( ) . take ( ) ;
287349
288350 let bind_address = self . config . borrow ( ) . bind_address . clone ( ) ;
289351 if bind_address. starts_with ( "unix:" ) {
@@ -371,41 +433,6 @@ fn create_timeout_response() -> HyperResponse<BodyWithTrailers> {
371433 . unwrap ( )
372434}
373435
374- async fn handle_connection (
375- stream : impl tokio:: io:: AsyncRead + tokio:: io:: AsyncWrite + Unpin + Send + ' static ,
376- work_tx : Arc < crossbeam_channel:: Sender < RequestWithCompletion > > ,
377- recv_timeout : u64 ,
378- timer : hyper_util:: rt:: TokioTimer ,
379- ) {
380- info ! ( "New connection established" ) ;
381-
382- let service = service_fn ( move |req : HyperRequest < Incoming > | {
383- debug ! ( "Service handling request" ) ;
384- let work_tx = work_tx. clone ( ) ;
385- handle_request ( req, work_tx, recv_timeout)
386- } ) ;
387-
388- let io = TokioIo :: new ( stream) ;
389-
390- debug ! ( "Setting up connection" ) ;
391- let mut builder = auto:: Builder :: new ( hyper_util:: rt:: TokioExecutor :: new ( ) ) ;
392-
393- builder. http1 ( )
394- . header_read_timeout ( std:: time:: Duration :: from_millis ( recv_timeout) )
395- . timer ( timer. clone ( ) ) ;
396-
397- builder. http2 ( )
398- . keep_alive_interval ( std:: time:: Duration :: from_secs ( 10 ) )
399- . timer ( timer) ;
400-
401- if let Err ( err) = builder
402- . serve_connection ( io, service)
403- . await
404- {
405- warn ! ( "Error serving connection: {:?}" , err) ;
406- }
407- }
408-
409436// Helper function to create error responses
410437fn create_error_response ( error_message : & str ) -> HyperResponse < BodyWithTrailers > {
411438 // For non-gRPC requests, return a plain HTTP error
0 commit comments