Skip to content

Commit 97d5fe6

Browse files
authored
Made asyncio TaskGroup work with eager task factories (#822)
Fixes #764.
1 parent 44405f4 commit 97d5fe6

6 files changed

Lines changed: 107 additions & 18 deletions

File tree

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,14 +22,14 @@ repos:
2222
- id: trailing-whitespace
2323

2424
- repo: https://github.com/astral-sh/ruff-pre-commit
25-
rev: v0.6.9
25+
rev: v0.8.1
2626
hooks:
2727
- id: ruff
2828
args: [--fix, --show-fixes]
2929
- id: ruff-format
3030

3131
- repo: https://github.com/pre-commit/mirrors-mypy
32-
rev: v1.11.2
32+
rev: v1.13.0
3333
hooks:
3434
- id: mypy
3535
additional_dependencies:

docs/versionhistory.rst

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
55

66
**UNRELEASED**
77

8+
- Updated ``TaskGroup`` to work with asyncio's eager task factories
9+
(`#764 <https://github.com/agronholm/anyio/issues/764>`_)
810
- Fixed a misleading ``ValueError`` in the context of DNS failures
911
(`#815 <https://github.com/agronholm/anyio/issues/815>`_; PR by @graingert)
1012
- Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept

src/anyio/_backends/_asyncio.py

Lines changed: 72 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,8 @@
2828
Collection,
2929
Coroutine,
3030
Iterable,
31+
Iterator,
32+
MutableMapping,
3133
Sequence,
3234
)
3335
from concurrent.futures import Future
@@ -351,8 +353,12 @@ def get_callable_name(func: Callable) -> str:
351353

352354
def _task_started(task: asyncio.Task) -> bool:
353355
"""Return ``True`` if the task has been started and has not finished."""
356+
# The task coro should never be None here, as we never add finished tasks to the
357+
# task list
358+
coro = task.get_coro()
359+
assert coro is not None
354360
try:
355-
return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
361+
return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
356362
except AttributeError:
357363
# task coro is async_genenerator_asend https://bugs.python.org/issue37771
358364
raise Exception(f"Cannot determine if task {task} has started or not") from None
@@ -409,8 +415,10 @@ def __enter__(self) -> CancelScope:
409415
self._parent_scope = task_state.cancel_scope
410416
task_state.cancel_scope = self
411417
if self._parent_scope is not None:
418+
# If using an eager task factory, the parent scope may not even contain
419+
# the host task
412420
self._parent_scope._child_scopes.add(self)
413-
self._parent_scope._tasks.remove(host_task)
421+
self._parent_scope._tasks.discard(host_task)
414422

415423
self._timeout()
416424
self._active = True
@@ -667,7 +675,45 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
667675
self.cancel_scope = cancel_scope
668676

669677

670-
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
678+
class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]):
679+
def __init__(self) -> None:
680+
self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
681+
self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}
682+
683+
def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
684+
assert isinstance(key, asyncio.Task)
685+
try:
686+
return self._task_states[key]
687+
except KeyError:
688+
if coro := key.get_coro():
689+
if state := self._preliminary_task_states.get(coro):
690+
return state
691+
692+
raise KeyError(key)
693+
694+
def __setitem__(
695+
self, key: asyncio.Task | Awaitable[Any], value: TaskState, /
696+
) -> None:
697+
if isinstance(key, asyncio.Task):
698+
self._task_states[key] = value
699+
else:
700+
self._preliminary_task_states[key] = value
701+
702+
def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
703+
if isinstance(key, asyncio.Task):
704+
del self._task_states[key]
705+
else:
706+
del self._preliminary_task_states[key]
707+
708+
def __len__(self) -> int:
709+
return len(self._task_states) + len(self._preliminary_task_states)
710+
711+
def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
712+
yield from self._task_states
713+
yield from self._preliminary_task_states
714+
715+
716+
_task_states = TaskStateStore()
671717

672718

673719
#
@@ -787,7 +833,7 @@ def _spawn(
787833
task_status_future: asyncio.Future | None = None,
788834
) -> asyncio.Task:
789835
def task_done(_task: asyncio.Task) -> None:
790-
task_state = _task_states[_task]
836+
# task_state = _task_states[_task]
791837
assert task_state.cancel_scope is not None
792838
assert _task in task_state.cancel_scope._tasks
793839
task_state.cancel_scope._tasks.remove(_task)
@@ -844,16 +890,26 @@ def task_done(_task: asyncio.Task) -> None:
844890
f"the return value ({coro!r}) is not a coroutine object"
845891
)
846892

847-
name = get_callable_name(func) if name is None else str(name)
848-
task = create_task(coro, name=name)
849-
task.add_done_callback(task_done)
850-
851893
# Make the spawned task inherit the task group's cancel scope
852-
_task_states[task] = TaskState(
894+
_task_states[coro] = task_state = TaskState(
853895
parent_id=parent_id, cancel_scope=self.cancel_scope
854896
)
897+
name = get_callable_name(func) if name is None else str(name)
898+
try:
899+
task = create_task(coro, name=name)
900+
finally:
901+
del _task_states[coro]
902+
903+
_task_states[task] = task_state
855904
self.cancel_scope._tasks.add(task)
856905
self._tasks.add(task)
906+
907+
if task.done():
908+
# This can happen with eager task factories
909+
task_done(task)
910+
else:
911+
task.add_done_callback(task_done)
912+
857913
return task
858914

859915
def start_soon(
@@ -2086,7 +2142,9 @@ def __init__(self, task: asyncio.Task):
20862142
else:
20872143
parent_id = task_state.parent_id
20882144

2089-
super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
2145+
coro = task.get_coro()
2146+
assert coro is not None, "created TaskInfo from a completed Task"
2147+
super().__init__(id(task), parent_id, task.get_name(), coro)
20902148
self._task = weakref.ref(task)
20912149

20922150
def has_pending_cancellation(self) -> bool:
@@ -2339,10 +2397,11 @@ def create_cancel_scope(
23392397

23402398
@classmethod
23412399
def current_effective_deadline(cls) -> float:
2400+
if (task := current_task()) is None:
2401+
return math.inf
2402+
23422403
try:
2343-
cancel_scope = _task_states[
2344-
current_task() # type: ignore[index]
2345-
].cancel_scope
2404+
cancel_scope = _task_states[task].cancel_scope
23462405
except KeyError:
23472406
return math.inf
23482407

src/anyio/abc/_tasks.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,12 @@ class TaskGroup(metaclass=ABCMeta):
4040
4141
:ivar cancel_scope: the cancel scope inherited by all child tasks
4242
:vartype cancel_scope: CancelScope
43+
44+
.. note:: On asyncio, support for eager task factories is considered to be
45+
**experimental**. In particular, they don't follow the usual semantics of new
46+
tasks being scheduled on the next iteration of the event loop, and may thus
47+
cause unexpected behavior in code that wasn't written with such semantics in
48+
mind.
4349
"""
4450

4551
cancel_scope: CancelScope

tests/test_sockets.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ def _identity(v: _T) -> _T:
149149
)
150150

151151

152-
@_ignore_win32_resource_warnings # type: ignore[operator]
152+
@_ignore_win32_resource_warnings
153153
class TestTCPStream:
154154
@pytest.fixture
155155
def server_sock(self, family: AnyIPAddressFamily) -> Iterator[socket.socket]:

tests/test_taskgroups.py

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,8 @@
1010
from typing import Any, NoReturn, cast
1111

1212
import pytest
13-
from exceptiongroup import ExceptionGroup, catch
13+
from exceptiongroup import catch
14+
from pytest import FixtureRequest
1415
from pytest_mock import MockerFixture
1516

1617
import anyio
@@ -783,7 +784,7 @@ async def host_agen_fn() -> AsyncGenerator[None, None]:
783784
host_agen = host_agen_fn()
784785
try:
785786
loop = asyncio.get_running_loop()
786-
await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type]
787+
await loop.create_task(host_agen.__anext__())
787788
finally:
788789
await host_agen.aclose()
789790

@@ -1704,3 +1705,24 @@ async def typetest_optional_status(
17041705
task_status: TaskStatus[int] = TASK_STATUS_IGNORED,
17051706
) -> None:
17061707
task_status.started(1)
1708+
1709+
1710+
@pytest.mark.skipif(
1711+
sys.version_info < (3, 12),
1712+
reason="Eager task factories require Python 3.12",
1713+
)
1714+
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
1715+
async def test_eager_task_factory(request: FixtureRequest) -> None:
1716+
async def sync_coro() -> None:
1717+
# This should trigger fetching the task state
1718+
with CancelScope(): # noqa: ASYNC100
1719+
pass
1720+
1721+
loop = asyncio.get_running_loop()
1722+
old_task_factory = loop.get_task_factory()
1723+
loop.set_task_factory(asyncio.eager_task_factory)
1724+
request.addfinalizer(lambda: loop.set_task_factory(old_task_factory))
1725+
1726+
async with create_task_group() as tg:
1727+
tg.start_soon(sync_coro)
1728+
tg.cancel_scope.cancel()

0 commit comments

Comments (0)