@@ -372,6 +372,193 @@ mod test {
372372 ( blocking_client, async_client)
373373 }
374374
375+ #[ cfg( feature = "async-ohttp" ) ]
376+ fn find_free_port ( ) -> u16 {
377+ let listener = std:: net:: TcpListener :: bind ( "0.0.0.0:0" ) . unwrap ( ) ;
378+ listener. local_addr ( ) . unwrap ( ) . port ( )
379+ }
380+
381+ #[ cfg( feature = "async-ohttp" ) ]
382+ async fn start_ohttp_relay (
383+ gateway_url : ohttp_relay:: GatewayUri ,
384+ ) -> (
385+ u16 ,
386+ tokio:: task:: JoinHandle < Result < ( ) , Box < dyn std:: error:: Error + std:: marker:: Send + Sync > > > ,
387+ ) {
388+ let port = find_free_port ( ) ;
389+ let relay = ohttp_relay:: listen_tcp ( port, gateway_url) . await . unwrap ( ) ;
390+
391+ ( port, relay)
392+ }
393+
394+ #[ cfg( feature = "async-ohttp" ) ]
395+ async fn start_ohttp_gateway ( ) -> ( u16 , tokio:: task:: JoinHandle < ( ) > ) {
396+ use http_body_util:: Full ;
397+ use hyper:: body:: Incoming ;
398+ use hyper:: service:: service_fn;
399+ use hyper:: Response ;
400+ use hyper:: { Method , Request } ;
401+ use hyper_util:: rt:: TokioIo ;
402+ use tokio:: net:: TcpListener ;
403+
404+ let port = find_free_port ( ) ;
405+ let listener = TcpListener :: bind ( format ! ( "0.0.0.0:{}" , port) )
406+ . await
407+ . unwrap ( ) ;
408+
409+ let handle = tokio:: spawn ( async move {
410+ let key_config = bitcoin_ohttp:: KeyConfig :: new (
411+ 0 ,
412+ bitcoin_ohttp:: hpke:: Kem :: K256Sha256 ,
413+ vec ! [ bitcoin_ohttp:: SymmetricSuite :: new(
414+ bitcoin_ohttp:: hpke:: Kdf :: HkdfSha256 ,
415+ bitcoin_ohttp:: hpke:: Aead :: ChaCha20Poly1305 ,
416+ ) ] ,
417+ )
418+ . expect ( "valid key config" ) ;
419+ let server = bitcoin_ohttp:: Server :: new ( key_config) . expect ( "valid server" ) ;
420+ let server = std:: sync:: Arc :: new ( server) ;
421+ loop {
422+ match listener. accept ( ) . await {
423+ Ok ( ( stream, _) ) => {
424+ let io = TokioIo :: new ( stream) ;
425+ let server = server. clone ( ) ;
426+ let service = service_fn ( move |req : Request < Incoming > | {
427+ let server = server. clone ( ) ;
428+ async move {
429+ let path = req. uri ( ) . path ( ) ;
430+ if path == "/.well-known/ohttp-gateway"
431+ && req. method ( ) == Method :: GET
432+ {
433+ let key_config = server. config ( ) . encode ( ) . unwrap ( ) ;
434+ Ok :: < _ , hyper:: Error > (
435+ Response :: builder ( )
436+ . status ( 200 )
437+ . header ( "content-type" , "application/ohttp-keys" )
438+ . body ( Full :: new ( hyper:: body:: Bytes :: from ( key_config) ) )
439+ . unwrap ( ) ,
440+ )
441+ } else if path == "/.well-known/ohttp-gateway"
442+ && req. method ( ) == Method :: POST
443+ {
444+ use http_body_util:: BodyExt ;
445+
446+ // Assert that the content-type header is set to "message/ohttp-req".
447+ let content_type_header = req
448+ . headers ( )
449+ . get ( "content-type" )
450+ . expect ( "content-type header should be set by the client" ) ;
451+ assert_eq ! ( content_type_header, "message/ohttp-req" ) ;
452+
453+ let bytes = req. collect ( ) . await ?. to_bytes ( ) ;
454+ let ( bhttp_body, response_ctx) =
455+ server. decapsulate ( bytes. iter ( ) . as_slice ( ) ) . unwrap ( ) ;
456+ // Reconstruct the inner HTTP message from the bhttp message.
457+ let mut r = std:: io:: Cursor :: new ( bhttp_body) ;
458+ let m: bhttp:: Message = bhttp:: Message :: read_bhttp ( & mut r)
459+ . expect ( "Should be valid bhttp message" ) ;
460+ let base_url = format ! (
461+ "http://{}" ,
462+ ELECTRSD . esplora_url. as_ref( ) . unwrap( )
463+ ) ;
464+ let path =
465+ String :: from_utf8 ( m. control ( ) . path ( ) . unwrap ( ) . to_vec ( ) )
466+ . unwrap ( ) ;
467+ let _ =
468+ Method :: from_bytes ( m. control ( ) . method ( ) . unwrap ( ) ) . unwrap ( ) ;
469+ // TODO: Use the actual method from the bhttp message
470+ // This will be refactored out to use bitreq
471+ let req = reqwest:: Request :: new (
472+ Method :: GET ,
473+ url:: Url :: parse ( & ( base_url + & path) ) . unwrap ( ) ,
474+ ) ;
475+ let mut req_builder = reqwest:: RequestBuilder :: from_parts (
476+ reqwest:: Client :: new ( ) ,
477+ req,
478+ ) ;
479+ for field in m. header ( ) . iter ( ) {
480+ req_builder =
481+ req_builder. header ( field. name ( ) , field. value ( ) ) ;
482+ }
483+
484+ let res = req_builder. send ( ) . await . unwrap ( ) ;
485+ // Convert HTTP response to bhttp response
486+ let mut m: bhttp:: Message = bhttp:: Message :: response (
487+ res. status ( ) . as_u16 ( ) . try_into ( ) . unwrap ( ) ,
488+ ) ;
489+ m. write_content ( res. bytes ( ) . await . unwrap ( ) ) ;
490+ let mut bhttp_res = vec ! [ ] ;
491+ m. write_bhttp ( bhttp:: Mode :: IndeterminateLength , & mut bhttp_res)
492+ . unwrap ( ) ;
493+ // Now we need to encapsulate the response
494+ let encapsulated_response =
495+ response_ctx. encapsulate ( & bhttp_res) . unwrap ( ) ;
496+
497+ Ok :: < _ , hyper:: Error > (
498+ Response :: builder ( )
499+ . status ( 200 )
500+ . header ( "content-type" , "message/ohttp-res" )
501+ . body ( Full :: new ( hyper:: body:: Bytes :: copy_from_slice (
502+ & encapsulated_response,
503+ ) ) )
504+ . unwrap ( ) ,
505+ )
506+ } else {
507+ Ok :: < _ , hyper:: Error > (
508+ Response :: builder ( )
509+ . status ( 404 )
510+ . body ( Full :: new ( hyper:: body:: Bytes :: from ( "Not Found" ) ) )
511+ . unwrap ( ) ,
512+ )
513+ }
514+ }
515+ } ) ;
516+
517+ tokio:: spawn ( async move {
518+ if let Err ( err) = hyper:: server:: conn:: http1:: Builder :: new ( )
519+ . serve_connection ( io, service)
520+ . await
521+ {
522+ eprintln ! ( "Error serving connection: {:?}" , err) ;
523+ }
524+ } ) ;
525+ }
526+ Err ( e) => {
527+ eprintln ! ( "Error accepting connection: {:?}" , e) ;
528+ break ;
529+ }
530+ }
531+ }
532+ } ) ;
533+ println ! ( "OHTTP gateway started on port {}" , port) ;
534+
535+ ( port, handle)
536+ }
537+ #[ cfg( feature = "async-ohttp" ) ]
538+ #[ tokio:: test]
539+ async fn test_ohttp_e2e ( ) {
540+ let ( _, async_client) = setup_clients ( ) . await ;
541+ let block_hash = async_client. get_block_hash ( 1 ) . await . unwrap ( ) ;
542+ let esplora_url = ELECTRSD . esplora_url . as_ref ( ) . unwrap ( ) ;
543+ let ( gateway_port, _) = start_ohttp_gateway ( ) . await ;
544+ let gateway_origin = format ! ( "http://localhost:{gateway_port}" ) ;
545+ let ( relay_port, _) =
546+ start_ohttp_relay ( gateway_origin. parse :: < ohttp_relay:: GatewayUri > ( ) . unwrap ( ) ) . await ;
547+ let gateway_url = format ! (
548+ "http://localhost:{}/.well-known/ohttp-gateway" ,
549+ gateway_port
550+ ) ;
551+ let relay_url = format ! ( "http://localhost:{}" , relay_port) ;
552+
553+ let ohttp_client = Builder :: new ( & format ! ( "http://{}" , esplora_url) )
554+ . build_async_with_ohttp ( & relay_url, & gateway_url)
555+ . await
556+ . unwrap ( ) ;
557+
558+ let res = ohttp_client. get_block_hash ( 1 ) . await . unwrap ( ) ;
559+ assert_eq ! ( res, block_hash) ;
560+ }
561+
375562 #[ cfg( all( feature = "blocking" , feature = "async" ) ) ]
376563 fn generate_blocks_and_wait ( num : usize ) {
377564 let cur_height = BITCOIND . client . get_block_count ( ) . unwrap ( ) . 0 ;
0 commit comments