Skip to content

Commit 891c873

Browse files
authored
[sglang, rollout] refactor: use torch.Tensor in async rollout schemas (verl-project#2362)
1 parent 2a01b21 commit 891c873

4 files changed

Lines changed: 226 additions & 246 deletions

File tree

tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -76,26 +76,26 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript
7676
tool_schemas=[],
7777
tools_kwargs={},
7878
interaction_kwargs={},
79-
input_ids=[],
80-
prompt_ids=[],
81-
response_ids=[],
82-
attention_mask=[],
83-
prompt_attention_mask=[],
84-
response_attention_mask=[],
85-
position_ids=[],
86-
prompt_position_ids=[],
87-
response_position_ids=[],
88-
loss_mask=[],
89-
prompt_loss_mask=[],
90-
response_loss_mask=[],
79+
input_ids=None,
80+
prompt_ids=None,
81+
response_ids=None,
82+
attention_mask=None,
83+
prompt_attention_mask=None,
84+
response_attention_mask=None,
85+
position_ids=None,
86+
prompt_position_ids=None,
87+
response_position_ids=None,
88+
loss_mask=None,
89+
prompt_loss_mask=None,
90+
response_loss_mask=None,
9191
reward_scores={},
9292
max_prompt_len=8192,
9393
max_response_len=8192,
9494
max_model_len=16384,
9595
metrics={},
9696
use_inference_chat_template=True,
9797
tokenization_sanity_check_mode=TokenizationSanityCheckModeEnum.STRICT,
98-
generation_prompt_ids=[],
98+
generation_prompt_ids=None,
9999
base_conv_wo_gen_prompt_end_pos=0,
100100
base_conv_with_gen_prompt_end_pos=0,
101101
processing_class=processor,
@@ -108,9 +108,9 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript
108108
continue
109109
_ = req.get_generation_prompt_ids(processor)
110110
req.add_assistant_message(processor, content=description_list[idx - 1])
111-
before_tool_call_len = len(req.input_ids)
111+
before_tool_call_len = req.input_ids.shape[-1]
112112
req.add_tool_response_messages(processor, [{"image": [img], "text": "Here is the new image you requested: "}])
113-
after_tool_call_len = len(req.input_ids)
113+
after_tool_call_len = req.input_ids.shape[-1]
114114
if prev_generated_len == 0:
115115
prev_generated_len = after_tool_call_len - before_tool_call_len
116116
else:
@@ -133,11 +133,11 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript
133133
return_dict=True,
134134
)
135135
full_prompt_ids = full_prompt_info["input_ids"]
136-
assert full_prompt_ids == req.input_ids
136+
assert full_prompt_ids.eq(req.input_ids).all()
137137

138138
# We must use dict(full_prompt_info) to convert BatchFeature values to a new dict
139139
# because np.array() only keeps the keys for BatchFeature.
140-
full_prompt_multi_modal_inputs = dict(full_prompt_info)
140+
full_prompt_multi_modal_inputs = full_prompt_info.copy()
141141
full_prompt_multi_modal_inputs.pop("input_ids", None)
142142
full_prompt_multi_modal_inputs.pop("attention_mask", None)
143143

tests/workers/rollout/test_sglang_multi_interaction.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -193,18 +193,18 @@ def test_interaction_selection_by_name(self):
193193
state=AsyncRolloutRequestStateEnum.INTERACTING,
194194
messages=[Message(role="user", content="test message")],
195195
interaction_kwargs={"name": "mock_agent2", "test_param": "value"},
196-
input_ids=[],
197-
prompt_ids=[],
198-
response_ids=[],
199-
attention_mask=[],
200-
prompt_attention_mask=[],
201-
response_attention_mask=[],
202-
position_ids=[],
203-
prompt_position_ids=[],
204-
response_position_ids=[],
205-
loss_mask=[],
206-
prompt_loss_mask=[],
207-
response_loss_mask=[],
196+
input_ids=None,
197+
prompt_ids=None,
198+
response_ids=None,
199+
attention_mask=None,
200+
prompt_attention_mask=None,
201+
response_attention_mask=None,
202+
position_ids=None,
203+
prompt_position_ids=None,
204+
response_position_ids=None,
205+
loss_mask=None,
206+
prompt_loss_mask=None,
207+
response_loss_mask=None,
208208
reward_scores={},
209209
max_prompt_len=32,
210210
max_response_len=16,

0 commit comments

Comments (0)