@@ -416,45 +416,124 @@ public static IAwaitQuery<TResult> AwaitCompletion<T, TTaskResult, TResult>(
416416
417417 return
418418 AwaitQuery . Create (
419- options => _ ( options . MaxConcurrency ?? int . MaxValue ,
419+ options => _ ( options . MaxConcurrency ,
420420 options . Scheduler ?? TaskScheduler . Default ,
421421 options . PreserveOrder ) ) ;
422422
423- IEnumerable < TResult > _ ( int maxConcurrency , TaskScheduler scheduler , bool ordered )
423+ IEnumerable < TResult > _ ( int ? maxConcurrency , TaskScheduler scheduler , bool ordered )
424424 {
425+ // A separate task will enumerate the source and launch tasks.
426+ // It will post all progress as notices to the collection below.
427+ // A notice is essentially a discriminated union like:
428+ //
429+ // type Notice<'a, 'b> =
430+ // | End
431+ // | Result of (int * 'a * Task<'b>)
432+ // | Error of ExceptionDispatchInfo
433+ //
434+ // Note that BlockingCollection.CompleteAdding is never used to
435+ // to mark the end (which its own notice above) because
436+ // BlockingCollection.Add throws if called after CompleteAdding
437+ // and we want to deliberately tolerate the race condition.
438+
425439 var notices = new BlockingCollection < ( Notice , ( int , T , Task < TTaskResult > ) , ExceptionDispatchInfo ) > ( ) ;
426- var cancellationTokenSource = new CancellationTokenSource ( ) ;
427- var cancellationToken = cancellationTokenSource . Token ;
428- var completed = false ;
429440
430- var enumerator =
431- source . Index ( )
432- . Select ( e => ( e . Key , Item : e . Value , Task : evaluator ( e . Value , cancellationToken ) ) )
433- . GetEnumerator ( ) ;
441+ var consumerCancellationTokenSource = new CancellationTokenSource ( ) ;
442+ ( Exception , Exception ) lastCriticalErrors = default ;
443+
444+ void PostNotice ( Notice notice ,
445+ ( int , T , Task < TTaskResult > ) item ,
446+ Exception error )
447+ {
448+ // If a notice fails to post then assume critical error
449+ // conditions (like low memory), capture the error without
450+ // further allocation of resources and trip the cancellation
451+ // token source used by the main loop waiting on notices.
452+ // Note that only the "last" critical error is reported
453+ // as maintaining a list would incur allocations. The idea
454+ // here is to make a best effort attempt to report any of
455+ // the error conditions that may be occuring, which is still
456+ // better than nothing.
457+
458+ try
459+ {
460+ var edi = error != null
461+ ? ExceptionDispatchInfo . Capture ( error )
462+ : null ;
463+ notices . Add ( ( notice , item , edi ) ) ;
464+ }
465+ catch ( Exception e )
466+ {
467+ // Don't use ExceptionDispatchInfo.Capture here to avoid
468+ // inducing allocations if already under low memory
469+ // conditions.
470+
471+ lastCriticalErrors = ( e , error ) ;
472+ consumerCancellationTokenSource . Cancel ( ) ;
473+ throw ;
474+ }
475+ }
476+
477+ var completed = false ;
478+ var cancellationTokenSource = new CancellationTokenSource ( ) ;
434479
480+ var enumerator = source . Index ( ) . GetEnumerator ( ) ;
435481 IDisposable disposable = enumerator ; // disables AccessToDisposedClosure warnings
436482
437483 try
438484 {
485+ var cancellationToken = cancellationTokenSource . Token ;
486+
487+ // Fire-up a parallel loop to iterate through the source and
488+ // launch tasks, posting a result-notice as each task
489+ // completes and another, an end-notice, when all tasks have
490+ // completed.
491+
439492 Task . Factory . StartNew (
440- ( ) =>
441- CollectToAsync (
442- enumerator ,
443- e => e . Task ,
444- notices ,
445- ( e , r ) => ( Notice . Result , ( e . Key , e . Item , e . Task ) , default ) ,
446- ex => ( Notice . Error , default , ExceptionDispatchInfo . Capture ( ex ) ) ,
447- ( Notice . End , default , default ) ,
448- maxConcurrency , cancellationTokenSource ) ,
493+ async ( ) =>
494+ {
495+ try
496+ {
497+ await enumerator . StartAsync (
498+ e => evaluator ( e . Value , cancellationToken ) ,
499+ ( e , r ) => PostNotice ( Notice . Result , ( e . Key , e . Value , r ) , default ) ,
500+ ( ) => PostNotice ( Notice . End , default , default ) ,
501+ maxConcurrency , cancellationToken ) ;
502+ }
503+ catch ( Exception e )
504+ {
505+ PostNotice ( Notice . Error , default , e ) ;
506+ }
507+ } ,
449508 CancellationToken . None ,
450509 TaskCreationOptions . DenyChildAttach ,
451510 scheduler ) ;
452511
512+ // Remainder here is the main loop that waits for and
513+ // processes notices.
514+
453515 var nextKey = 0 ;
454516 var holds = ordered ? new List < ( int , T , Task < TTaskResult > ) > ( ) : null ;
455517
456- foreach ( var ( kind , result , error ) in notices . GetConsumingEnumerable ( ) )
518+ using ( var notice = notices . GetConsumingEnumerable ( consumerCancellationTokenSource . Token )
519+ . GetEnumerator ( ) )
520+ while ( true )
457521 {
522+ try
523+ {
524+ if ( ! notice . MoveNext ( ) )
525+ break ;
526+ }
527+ catch ( OperationCanceledException e ) when ( e . CancellationToken == consumerCancellationTokenSource . Token )
528+ {
529+ var ( error1 , error2 ) = lastCriticalErrors ;
530+ throw new Exception ( "One or more critical errors have occurred." ,
531+ error2 != null ? new AggregateException ( error1 , error2 )
532+ : new AggregateException ( error1 ) ) ;
533+ }
534+
535+ var ( kind , result , error ) = notice . Current ;
536+
458537 if ( kind == Notice . Error )
459538 error . Throw ( ) ;
460539
@@ -531,149 +610,76 @@ IEnumerable<TResult> _(int maxConcurrency, TaskScheduler scheduler, bool ordered
531610 }
532611 }
533612
534- enum Notice { Result , Error , End }
535-
536- static async Task CollectToAsync < T , TResult , TNotice > (
537- this IEnumerator < T > e ,
538- Func < T , Task < TResult > > taskSelector ,
539- BlockingCollection < TNotice > collection ,
540- Func < T , Task < TResult > , TNotice > completionNoticeSelector ,
541- Func < Exception , TNotice > errorNoticeSelector ,
542- TNotice endNotice ,
543- int maxConcurrency ,
544- CancellationTokenSource cancellationTokenSource )
613+ enum Notice { End , Result , Error }
614+
615+ static async Task StartAsync < T , TResult > (
616+ this IEnumerator < T > enumerator ,
617+ Func < T , Task < TResult > > starter ,
618+ Action < T , Task < TResult > > onTaskCompletion ,
619+ Action onEnd ,
620+ int ? maxConcurrency ,
621+ CancellationToken cancellationToken )
545622 {
546- Reader < T > reader = null ;
623+ if ( enumerator == null ) throw new ArgumentNullException ( nameof ( enumerator ) ) ;
624+ if ( starter == null ) throw new ArgumentNullException ( nameof ( starter ) ) ;
625+ if ( onTaskCompletion == null ) throw new ArgumentNullException ( nameof ( onTaskCompletion ) ) ;
626+ if ( onEnd == null ) throw new ArgumentNullException ( nameof ( onEnd ) ) ;
627+ if ( maxConcurrency < 1 ) throw new ArgumentOutOfRangeException ( nameof ( maxConcurrency ) ) ;
547628
548- try
629+ using ( enumerator )
549630 {
550- reader = new Reader < T > ( e ) ;
551-
552- var cancellationToken = cancellationTokenSource . Token ;
553- var cancellationTaskSource = new TaskCompletionSource < bool > ( ) ;
554- cancellationToken . Register ( ( ) => cancellationTaskSource . TrySetResult ( true ) ) ;
631+ var pendingCount = 1 ; // terminator
555632
556- var tasks = new List < ( T Item , Task < TResult > Task ) > ( ) ;
557-
558- for ( var i = 0 ; i < maxConcurrency ; i ++ )
633+ void OnPendingCompleted ( )
559634 {
560- if ( ! reader . TryRead ( out var item ) )
561- break ;
562- tasks . Add ( ( item , taskSelector ( item ) ) ) ;
635+ if ( Interlocked . Decrement ( ref pendingCount ) == 0 )
636+ onEnd ( ) ;
563637 }
564638
565- while ( tasks . Count > 0 )
639+ var concurrencyGate = maxConcurrency is int count
640+ ? new ConcurrencyGate ( count )
641+ : ConcurrencyGate . Unbounded ;
642+
643+ while ( enumerator . MoveNext ( ) )
566644 {
567- // Task.WaitAny is synchronous and blocking but allows the
568- // waiting to be cancelled via a CancellationToken.
569- // Task.WhenAny can be awaited so it is better since the
570- // thread won't be blocked and can return to the pool.
571- // However, it doesn't support cancellation so instead a
572- // task is built on top of the CancellationToken that
573- // completes when the CancellationToken trips.
574- //
575- // Also, Task.WhenAny returns the task (Task) object that
576- // completed but task objects may not be unique due to
577- // caching, e.g.:
578- //
579- // async Task<bool> Foo() => true;
580- // async Task<bool> Bar() => true;
581- // var foo = Foo();
582- // var bar = Bar();
583- // var same = foo.Equals(bar); // == true
584- //
585- // In this case, the task returned by Task.WhenAny will
586- // match `foo` and `bar`:
587- //
588- // var done = Task.WhenAny(foo, bar);
589- //
590- // Logically speaking, the uniqueness of a task does not
591- // matter but here it does, especially when Await (the main
592- // user of CollectAsync) needs to return results ordered.
593- // Fortunately, we compose our own task on top of the
594- // original that links each item with the task result and as
595- // a consequence generate new and unique task objects.
596-
597- var completedTask = await
598- Task . WhenAny ( tasks . Select ( it => ( Task ) it . Task ) . Concat ( cancellationTaskSource . Task ) )
599- . ConfigureAwait ( continueOnCapturedContext : false ) ;
600-
601- if ( completedTask == cancellationTaskSource . Task )
645+ try
602646 {
603- // Cancellation during the wait means the enumeration
604- // has been stopped by the user so the results of the
605- // remaining tasks are no longer needed. Those tasks
606- // should cancel as a result of sharing the same
607- // cancellation token and provided that they passed it
608- // on to any downstream asynchronous operations. Either
609- // way, this loop is done so exit hard here.
610-
611- return ;
647+ await concurrencyGate . EnterAsync ( cancellationToken ) ;
612648 }
613-
614- var i = tasks . FindIndex ( it => it . Task . Equals ( completedTask ) ) ;
615-
649+ catch ( OperationCanceledException e ) when ( e . CancellationToken == cancellationToken )
616650 {
617- var ( item , task ) = tasks [ i ] ;
618- tasks . RemoveAt ( i ) ;
651+ return ;
652+ }
619653
620- // Await the task rather than using its result directly
621- // to avoid having the task's exception bubble up as
622- // AggregateException if the task failed.
654+ Interlocked . Increment ( ref pendingCount ) ;
623655
624- collection . Add ( completionNoticeSelector ( item , task ) ) ;
625- }
656+ var item = enumerator . Current ;
657+ var task = starter ( item ) ;
626658
627- {
628- if ( reader . TryRead ( out var item ) )
629- tasks . Add ( ( item , taskSelector ( item ) ) ) ;
630- }
631- }
659+ // Add a continutation that notifies completion of the task,
660+ // along with the necessary housekeeping, in case it
661+ // completes before maximum concurrency is reached.
632662
633- collection . Add ( endNotice ) ;
634- }
635- catch ( Exception ex )
636- {
637- cancellationTokenSource . Cancel ( ) ;
638- collection . Add ( errorNoticeSelector ( ex ) ) ;
639- }
640- finally
641- {
642- reader ? . Dispose ( ) ;
643- }
663+ #pragma warning disable 4014 // https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/compiler-messages/cs4014
644664
645- collection . CompleteAdding ( ) ;
646- }
665+ task . ContinueWith ( cancellationToken : cancellationToken ,
666+ continuationOptions : TaskContinuationOptions . ExecuteSynchronously ,
667+ scheduler : TaskScheduler . Current ,
668+ continuationAction : t =>
669+ {
670+ concurrencyGate . Exit ( ) ;
647671
648- sealed class Reader < T > : IDisposable
649- {
650- IEnumerator < T > _enumerator ;
672+ if ( cancellationToken . IsCancellationRequested )
673+ return ;
651674
652- public Reader ( IEnumerator < T > enumerator ) =>
653- _enumerator = enumerator ;
675+ onTaskCompletion ( item , t ) ;
676+ OnPendingCompleted ( ) ;
677+ } ) ;
654678
655- public bool TryRead ( out T item )
656- {
657- var ended = false ;
658- if ( _enumerator == null || ( ended = ! _enumerator . MoveNext ( ) ) )
659- {
660- if ( ended )
661- Dispose ( ) ;
662- item = default ;
663- return false ;
679+ #pragma warning restore 4014
664680 }
665681
666- item = _enumerator . Current ;
667- return true ;
668- }
669-
670- public void Dispose ( )
671- {
672- var e = _enumerator ;
673- if ( e == null )
674- return ;
675- _enumerator = null ;
676- e . Dispose ( ) ;
682+ OnPendingCompleted ( ) ;
677683 }
678684 }
679685
@@ -720,6 +726,53 @@ static class TupleComparer<T1, T2, T3>
720726 public static readonly IComparer < ( T1 , T2 , T3 ) > Item3 =
721727 Comparer < ( T1 , T2 , T3 ) > . Create ( ( x , y ) => Comparer < T3 > . Default . Compare ( x . Item3 , y . Item3 ) ) ;
722728 }
729+
730+ static class CompletedTask
731+ {
732+ #if NET451 || NETSTANDARD1_0
733+
734+ public static readonly Task Instance ;
735+
736+ static CompletedTask ( )
737+ {
738+ var tcs = new TaskCompletionSource < object > ( ) ;
739+ tcs . SetResult ( null ) ;
740+ Instance = tcs . Task ;
741+ }
742+
743+ #else
744+
745+ public static readonly Task Instance = Task . CompletedTask ;
746+
747+ #endif
748+ }
749+
750+ sealed class ConcurrencyGate
751+ {
752+ public static readonly ConcurrencyGate Unbounded = new ConcurrencyGate ( ) ;
753+
754+ readonly SemaphoreSlim _semaphore ;
755+
756+ ConcurrencyGate ( SemaphoreSlim semaphore = null ) =>
757+ _semaphore = semaphore ;
758+
759+ public ConcurrencyGate ( int max ) :
760+ this ( new SemaphoreSlim ( max , max ) ) { }
761+
762+ public Task EnterAsync ( CancellationToken token )
763+ {
764+ if ( _semaphore == null )
765+ {
766+ token . ThrowIfCancellationRequested ( ) ;
767+ return CompletedTask . Instance ;
768+ }
769+
770+ return _semaphore . WaitAsync ( token ) ;
771+ }
772+
773+ public void Exit ( ) =>
774+ _semaphore ? . Release ( ) ;
775+ }
723776 }
724777}
725778
0 commit comments