@@ -76,26 +76,26 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript
7676 tool_schemas = [],
7777 tools_kwargs = {},
7878 interaction_kwargs = {},
79- input_ids = [] ,
80- prompt_ids = [] ,
81- response_ids = [] ,
82- attention_mask = [] ,
83- prompt_attention_mask = [] ,
84- response_attention_mask = [] ,
85- position_ids = [] ,
86- prompt_position_ids = [] ,
87- response_position_ids = [] ,
88- loss_mask = [] ,
89- prompt_loss_mask = [] ,
90- response_loss_mask = [] ,
79+ input_ids = None ,
80+ prompt_ids = None ,
81+ response_ids = None ,
82+ attention_mask = None ,
83+ prompt_attention_mask = None ,
84+ response_attention_mask = None ,
85+ position_ids = None ,
86+ prompt_position_ids = None ,
87+ response_position_ids = None ,
88+ loss_mask = None ,
89+ prompt_loss_mask = None ,
90+ response_loss_mask = None ,
9191 reward_scores = {},
9292 max_prompt_len = 8192 ,
9393 max_response_len = 8192 ,
9494 max_model_len = 16384 ,
9595 metrics = {},
9696 use_inference_chat_template = True ,
9797 tokenization_sanity_check_mode = TokenizationSanityCheckModeEnum .STRICT ,
98- generation_prompt_ids = [] ,
98+ generation_prompt_ids = None ,
9999 base_conv_wo_gen_prompt_end_pos = 0 ,
100100 base_conv_with_gen_prompt_end_pos = 0 ,
101101 processing_class = processor ,
@@ -108,9 +108,9 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript
108108 continue
109109 _ = req .get_generation_prompt_ids (processor )
110110 req .add_assistant_message (processor , content = description_list [idx - 1 ])
111- before_tool_call_len = len ( req .input_ids )
111+ before_tool_call_len = req .input_ids . shape [ - 1 ]
112112 req .add_tool_response_messages (processor , [{"image" : [img ], "text" : "Here is the new image you requested: " }])
113- after_tool_call_len = len ( req .input_ids )
113+ after_tool_call_len = req .input_ids . shape [ - 1 ]
114114 if prev_generated_len == 0 :
115115 prev_generated_len = after_tool_call_len - before_tool_call_len
116116 else :
@@ -133,11 +133,11 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript
133133 return_dict = True ,
134134 )
135135 full_prompt_ids = full_prompt_info ["input_ids" ]
136- assert full_prompt_ids == req .input_ids
136+ assert full_prompt_ids . eq ( req .input_ids ). all ()
137137
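138138 # Pop "input_ids" and "attention_mask" from a copy of full_prompt_info so the processor output itself is not mutated.
139139 # NOTE(review): dict(full_prompt_info) was used before because np.array() only keeps the keys for BatchFeature; confirm that .copy() still yields a plain dict (or otherwise avoids this).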
140- full_prompt_multi_modal_inputs = dict ( full_prompt_info )
140+ full_prompt_multi_modal_inputs = full_prompt_info . copy ( )
141141 full_prompt_multi_modal_inputs .pop ("input_ids" , None )
142142 full_prompt_multi_modal_inputs .pop ("attention_mask" , None )
143143
0 commit comments