@@ -54,8 +54,10 @@ def __init__(
5454 exptr_problem_statement = exptr .problem_statement ()
5555
5656 if exptr_problem_statement .search_space .is_conditional :
57- raise ValueError ('Search space should not have conditional'
58- f' parameters { exptr_problem_statement } ' )
57+ raise ValueError (
58+ 'Search space should not have conditional'
59+ f' parameters { exptr_problem_statement } '
60+ )
5961 dimension = len (exptr_problem_statement .search_space .parameters )
6062 if dimension <= 0 :
6163 raise ValueError (f'Invalid dimension: { dimension } ' )
@@ -64,8 +66,8 @@ def __init__(
6466 self ._shift = np .broadcast_to (shift , (dimension ,))
6567 except ValueError as broadcast_err :
6668 raise ValueError (
67- f'Shift { shift } is not broadcastable for dim: { dimension } .'
68- ' \n ' ) from broadcast_err
69+ f'Shift { shift } is not broadcastable for dim: { dimension } .\n '
70+ ) from broadcast_err
6971
7072 # Converter should be in the underlying extpr space.
7173 self ._converter = converters .TrialToArrayConverter .from_study_config (
@@ -83,26 +85,28 @@ def __init__(
8385 ):
8486 if parameter .type != pyvizier .ParameterType .DOUBLE :
8587 raise ValueError (f'Non-double parameters { parameter } ' )
86- if (bounds := parameter .bounds ) is not None :
87- if abs (shift ) >= bounds [1 ] - bounds [0 ]:
88- raise ValueError (
89- f'Bounds { bounds } may need to be extended'
90- f'as shift { shift } is too large '
91- )
92- # Shift the bounds to maintain valid bounds.
93- if shift >= 0 :
94- new_bounds = (bounds [0 ] + shift , bounds [1 ])
95- else :
96- new_bounds = (bounds [0 ], bounds [1 ] + shift )
97- self ._problem_statement .search_space .add (
98- pyvizier .ParameterConfig .factory (
99- name = parameter .name ,
100- bounds = new_bounds ,
101- scale_type = parameter .scale_type ,
102- default_value = parameter .default_value ,
103- external_type = parameter .external_type ,
104- )
88+ if (bounds := parameter .bounds ) is None :
89+ raise ValueError (f'Parameter { parameter } has no bounds' )
90+
91+ if abs (shift ) >= bounds [1 ] - bounds [0 ]:
92+ raise ValueError (
93+ f'Bounds { bounds } may need to be extended'
94+ f'as shift { shift } is too large '
10595 )
96+ # Shift the bounds to maintain valid bounds.
97+ if shift >= 0 :
98+ new_bounds = (bounds [0 ] + shift , bounds [1 ])
99+ else :
100+ new_bounds = (bounds [0 ], bounds [1 ] + shift )
101+ self ._problem_statement .search_space .add (
102+ pyvizier .ParameterConfig .factory (
103+ name = parameter .name ,
104+ bounds = new_bounds ,
105+ scale_type = parameter .scale_type ,
106+ default_value = parameter .default_value ,
107+ external_type = parameter .external_type ,
108+ ),
109+ )
106110
107111 def problem_statement (self ) -> pyvizier .ProblemStatement :
108112 return copy .deepcopy (self ._problem_statement )
@@ -116,8 +120,9 @@ def evaluate(self, suggestions: Sequence[pyvizier.Trial]) -> None:
116120 for parameters , suggestion in zip (previous_parameters , suggestions ):
117121 suggestion .parameters = parameters
118122
119- def _offset (self , suggestions : Sequence [pyvizier .Trial ],
120- shift : np .ndarray ) -> None :
123+ def _offset (
124+ self , suggestions : Sequence [pyvizier .Trial ], shift : np .ndarray
125+ ) -> None :
121126 """Offsets parameter values (OOB values are clipped)."""
122127 for suggestion in suggestions :
123128 features = self ._converter .to_features ([suggestion ])
0 commit comments