|
28 | 28 | Collection, |
29 | 29 | Coroutine, |
30 | 30 | Iterable, |
| 31 | + Iterator, |
| 32 | + MutableMapping, |
31 | 33 | Sequence, |
32 | 34 | ) |
33 | 35 | from concurrent.futures import Future |
@@ -351,8 +353,12 @@ def get_callable_name(func: Callable) -> str: |
351 | 353 |
|
def _task_started(task: asyncio.Task) -> bool:
    """Return ``True`` if the task has been started and has not finished.

    :param task: the task to inspect
    :return: ``True`` if the task's coroutine is currently running or suspended;
        ``False`` if it has not started yet, has finished, or the task no longer
        holds a coroutine at all
    :raises Exception: if the state of the task's awaitable cannot be determined
        (e.g. it is an ``async_generator_asend`` object,
        see https://bugs.python.org/issue37771)

    """
    coro = task.get_coro()
    if coro is None:
        # Task.get_coro() returns None for a task that already completed eagerly
        # (eager task factories). Such a task has finished, so it is not "started
        # and not finished". Finished tasks are never expected in the task list,
        # but unlike an ``assert`` this check also holds under ``python -O``
        # instead of surfacing as a misleading "cannot determine" error below.
        return False

    try:
        return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
    except AttributeError:
        # task coro is async_generator_asend https://bugs.python.org/issue37771
        raise Exception(f"Cannot determine if task {task} has started or not") from None
@@ -409,8 +415,10 @@ def __enter__(self) -> CancelScope: |
409 | 415 | self._parent_scope = task_state.cancel_scope |
410 | 416 | task_state.cancel_scope = self |
411 | 417 | if self._parent_scope is not None: |
| 418 | + # If using an eager task factory, the parent scope may not even contain |
| 419 | + # the host task |
412 | 420 | self._parent_scope._child_scopes.add(self) |
413 | | - self._parent_scope._tasks.remove(host_task) |
| 421 | + self._parent_scope._tasks.discard(host_task) |
414 | 422 |
|
415 | 423 | self._timeout() |
416 | 424 | self._active = True |
@@ -667,7 +675,45 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None): |
667 | 675 | self.cancel_scope = cancel_scope |
668 | 676 |
|
669 | 677 |
|
670 | | -_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary() |
class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]):
    """Mapping from tasks to their :class:`TaskState`.

    Two kinds of keys are accepted:

    * :class:`asyncio.Task` objects, held weakly (in a ``WeakKeyDictionary``) so
      that an entry does not keep its task alive
    * bare awaitables (coroutines), held strongly in a separate "preliminary"
      table, for state registered before the awaitable has been wrapped in a task

    Looking up a task that has no entry of its own falls back to the preliminary
    entry stored under the task's coroutine. This lets code running inside a task
    that an eager task factory started synchronously (before ``create_task()``
    returned) find its state.
    """

    def __init__(self) -> None:
        # Weak keys: entries never keep the tasks themselves alive
        self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
        # Strongly referenced entries keyed by awaitables not yet wrapped in a task
        self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}

    def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
        """Return the state for *key*.

        :param key: a task, or an awaitable registered before becoming a task
        :raises KeyError: if no state is registered for *key* (nor, for a task,
            for its coroutine)

        """
        if not isinstance(key, asyncio.Task):
            # Plain awaitables only ever live in the preliminary table. Handling
            # them here keeps the inherited MutableMapping methods (__contains__,
            # get, items, pop, clear, ...) working for the keys __iter__ yields.
            return self._preliminary_task_states[key]

        try:
            return self._task_states[key]
        except KeyError:
            pass

        # No entry for the task itself: fall back to the state registered under
        # its coroutine. get_coro() may return None for an eagerly completed task.
        coro = key.get_coro()
        if coro is not None:
            state = self._preliminary_task_states.get(coro)
            if state is not None:
                return state

        raise KeyError(key)

    def __setitem__(
        self, key: asyncio.Task | Awaitable[Any], value: TaskState, /
    ) -> None:
        """Register *value* for *key*, routing tasks and awaitables separately."""
        if isinstance(key, asyncio.Task):
            self._task_states[key] = value
        else:
            self._preliminary_task_states[key] = value

    def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
        """Remove the entry for *key*; raises :exc:`KeyError` if absent."""
        if isinstance(key, asyncio.Task):
            del self._task_states[key]
        else:
            del self._preliminary_task_states[key]

    def __len__(self) -> int:
        """Return the number of task entries plus preliminary entries."""
        return len(self._task_states) + len(self._preliminary_task_states)

    def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
        """Yield task keys first, then preliminary (awaitable) keys."""
        yield from self._task_states
        yield from self._preliminary_task_states


# Module-wide store of TaskState objects, keyed by task (or by coroutine while a
# task is still being created)
_task_states = TaskStateStore()
671 | 717 |
|
672 | 718 |
|
673 | 719 | # |
@@ -787,7 +833,7 @@ def _spawn( |
787 | 833 | task_status_future: asyncio.Future | None = None, |
788 | 834 | ) -> asyncio.Task: |
789 | 835 | def task_done(_task: asyncio.Task) -> None: |
790 | | - task_state = _task_states[_task] |
| 836 | + # task_state = _task_states[_task] |
791 | 837 | assert task_state.cancel_scope is not None |
792 | 838 | assert _task in task_state.cancel_scope._tasks |
793 | 839 | task_state.cancel_scope._tasks.remove(_task) |
@@ -844,16 +890,26 @@ def task_done(_task: asyncio.Task) -> None: |
844 | 890 | f"the return value ({coro!r}) is not a coroutine object" |
845 | 891 | ) |
846 | 892 |
|
847 | | - name = get_callable_name(func) if name is None else str(name) |
848 | | - task = create_task(coro, name=name) |
849 | | - task.add_done_callback(task_done) |
850 | | - |
851 | 893 | # Make the spawned task inherit the task group's cancel scope |
852 | | - _task_states[task] = TaskState( |
| 894 | + _task_states[coro] = task_state = TaskState( |
853 | 895 | parent_id=parent_id, cancel_scope=self.cancel_scope |
854 | 896 | ) |
| 897 | + name = get_callable_name(func) if name is None else str(name) |
| 898 | + try: |
| 899 | + task = create_task(coro, name=name) |
| 900 | + finally: |
| 901 | + del _task_states[coro] |
| 902 | + |
| 903 | + _task_states[task] = task_state |
855 | 904 | self.cancel_scope._tasks.add(task) |
856 | 905 | self._tasks.add(task) |
| 906 | + |
| 907 | + if task.done(): |
| 908 | + # This can happen with eager task factories |
| 909 | + task_done(task) |
| 910 | + else: |
| 911 | + task.add_done_callback(task_done) |
| 912 | + |
857 | 913 | return task |
858 | 914 |
|
859 | 915 | def start_soon( |
@@ -2086,7 +2142,9 @@ def __init__(self, task: asyncio.Task): |
2086 | 2142 | else: |
2087 | 2143 | parent_id = task_state.parent_id |
2088 | 2144 |
|
2089 | | - super().__init__(id(task), parent_id, task.get_name(), task.get_coro()) |
| 2145 | + coro = task.get_coro() |
| 2146 | + assert coro is not None, "created TaskInfo from a completed Task" |
| 2147 | + super().__init__(id(task), parent_id, task.get_name(), coro) |
2090 | 2148 | self._task = weakref.ref(task) |
2091 | 2149 |
|
2092 | 2150 | def has_pending_cancellation(self) -> bool: |
@@ -2339,10 +2397,11 @@ def create_cancel_scope( |
2339 | 2397 |
|
2340 | 2398 | @classmethod |
2341 | 2399 | def current_effective_deadline(cls) -> float: |
| 2400 | + if (task := current_task()) is None: |
| 2401 | + return math.inf |
| 2402 | + |
2342 | 2403 | try: |
2343 | | - cancel_scope = _task_states[ |
2344 | | - current_task() # type: ignore[index] |
2345 | | - ].cancel_scope |
| 2404 | + cancel_scope = _task_states[task].cancel_scope |
2346 | 2405 | except KeyError: |
2347 | 2406 | return math.inf |
2348 | 2407 |
|
|
0 commit comments