I should start by saying that I am totally new to this branch of programming, but I think that SciPy optimization could be the solution.
I need to find the parameters that make a function return the highest result, but only if the result respects a condition.
The function is long and takes more than 40 parameters, so brute-forcing them is impractically slow. The function returns 2 arrays of the same length as output.
constant = [1,2,3,4,4,3,5,6,7,8]
def fun(constant, length, period, multiplier, factor, ... ):
do long and complicated calculations
return array1, array2
Now, what I need is to find the parameters that return the highest array1[-1] value if max(array2) < 40 (for example), and then print them.
All parameters (length, period, multiplier, factor) work in a range from 2 to 200. The constant argument, on the other hand, obviously should not be affected by the optimization.
I tried looping over all parameters across their ranges and running the calculation one combination at a time, but it's very inefficient and intricate, and I think it is not giving the best results.
How can I perform this type of parameter optimization?
CodePudding user response:
If you want to build from scratch, a simple completely random "good enough" solver might look like this.
The solver is the first function, the rest are your (user) functions.
You need
- your target long and complicated function
- a function that returns a score for a given generated result (or zero if the result is invalid)
- a dictionary of functions that return values for each argument to try.
import random
import time
def find_solution(
    target_function,
    score_solution,
    param_generators,
    max_iterations=10_000_000,
    max_time=60,
):
    """Randomly sample parameter sets and keep the best-scoring one.

    Every iteration calls each generator in ``param_generators`` to build a
    keyword-argument dict, passes it to ``target_function`` and scores the
    result with ``score_solution``.

    Args:
        target_function: Callable invoked as ``target_function(**params)``.
        score_solution: Callable mapping a result of ``target_function`` to a
            comparable score; higher is better. Only scores strictly greater
            than the starting best of 0 are ever recorded.
        param_generators: Mapping of parameter name to a zero-argument
            callable that produces a candidate value for that parameter.
        max_iterations: Upper bound on the number of samples tried.
        max_time: Time budget in seconds, checked after every sample.

    Returns:
        A ``(best_solution, best_score)`` tuple. ``best_solution`` is a
        ``(params, solution)`` tuple, or ``None`` when no sample scored
        above 0.
    """
    best_solution = None
    best_score = 0
    # Use the monotonic clock for elapsed time: unlike time.time() it cannot
    # jump when the system clock is adjusted, so the budget stays reliable.
    start_time = time.monotonic()
    for i in range(max_iterations):
        params = {param: gen() for param, gen in param_generators.items()}
        solution = target_function(**params)
        score = score_solution(solution)
        if score > best_score:
            best_score = score
            best_solution = (params, solution)
            print(f"{i} / New best solution: {best_solution}")
        if time.monotonic() - start_time > max_time:
            print(f"{i} / Time limit reached")
            break
    return (best_solution, best_score)
def fun(constant, length, period, multiplier, factor):
    """Toy stand-in for the real "long and complicated" target function.

    Returns a 2-tuple: the product of all five arguments, and a one-element
    list holding ``length * period``.
    """
    window = length * period
    total = constant * length * period * multiplier * factor
    return (total, [window])
def sol_scorer(sol):
    """Score a result tuple produced by the target function.

    The score is ``sol[0]``, unless the largest value in ``sol[1]`` is below
    40, in which case the result is treated as invalid and scores 0.

    NOTE(review): the question text asks for results with
    ``max(array2) < 40`` to be the *valid* ones, i.e. the opposite test.
    Confirm which direction is intended before reusing this scorer.
    """
    peak = max(sol[1])
    # Invalid results score 0 so they can never beat the initial best score.
    return 0 if peak < 40 else sol[0]
def main():
    """Run the random solver against the toy ``fun`` and print the outcome."""
    constants = [1, 2, 3, 4, 4, 3, 5, 6, 7, 8]

    def draw_1_to_100():
        # Every numeric parameter is sampled uniformly from 1..100 inclusive.
        return random.randint(1, 100)

    # Insertion order matters: generators are invoked in dict order.
    param_generators = {"constant": lambda: random.choice(constants)}
    for name in ("length", "period", "multiplier", "factor"):
        param_generators[name] = draw_1_to_100

    outcome = find_solution(
        fun,
        sol_scorer,
        param_generators,
        max_iterations=10_000_000,
        max_time=10,
    )
    print(outcome)


if __name__ == "__main__":
    main()
On my machine, this prints out e.g.
227838 / New best solution: ({'constant': 8, 'length': 98, 'period': 93, 'multiplier': 96, 'factor': 99}, (692955648, [9114]))
1085159 / New best solution: ({'constant': 8, 'length': 98, 'period': 99, 'multiplier': 91, 'factor': 100}, (706305600, [9702]))
1447216 / New best solution: ({'constant': 8, 'length': 99, 'period': 97, 'multiplier': 97, 'factor': 97}, (722837016, [9603]))
2325989 / Time limit reached
(({'constant': 8, 'length': 99, 'period': 97, 'multiplier': 97, 'factor': 97}, (722837016, [9603])), 722837016)
With sequential combinations
Adding an option to always try some combinations isn't much more code; see below.
import random
import time
from itertools import product
from typing import Any, Callable, Optional, Iterable
def find_solution(
    target_function: Callable[..., Any],
    score_solution: Callable[[Any], float],
    param_generators: dict[str, Callable[[], Any]],
    sequential_combination_generator: Optional[Iterable[dict]] = None,
    max_iterations: int = 10_000_000,
    max_time: float = 60.0,
) -> tuple[Any, float]:
    """Search for the best-scoring parameters, mixing fixed and random values.

    For every dict produced by ``sequential_combination_generator`` the solver
    runs up to ``max_iterations`` random trials. In each trial the fixed
    values are merged with freshly generated random values, passed to
    ``target_function`` as keyword arguments, and the result is scored.

    Args:
        target_function: Callable invoked as ``target_function(**params)``.
        score_solution: Callable mapping a result of ``target_function`` to a
            score; higher is better. Only scores strictly greater than the
            starting best of 0.0 are ever recorded.
        param_generators: Mapping of parameter name to a zero-argument
            callable producing a random candidate value.
        sequential_combination_generator: Iterable of dicts with fixed
            parameter values, each tried in turn. ``None`` means a single
            round with no fixed parameters.
        max_iterations: Number of random trials per sequential combination.
        max_time: Total time budget in seconds across all combinations,
            checked after every trial.

    Returns:
        A ``(best_solution, best_score)`` tuple. ``best_solution`` is a
        ``(params, solution)`` tuple, or ``None`` when no trial scored
        above 0.0.
    """
    best_solution = None
    best_score = 0.0
    # Use the monotonic clock for elapsed time: unlike time.time() it cannot
    # jump when the system clock is adjusted.
    start_time = time.monotonic()
    if sequential_combination_generator is None:
        sequential_combination_generator = [{}]
    for sequential_combination in sequential_combination_generator:
        print(f"Trying {max_iterations} w/: {sequential_combination}")
        for i in range(max_iterations):
            # Merge the sequential params with the randomly generated params;
            # on a key clash the randomly generated value wins.
            params = {
                **sequential_combination,
                **{param: gen() for param, gen in param_generators.items()},
            }
            solution = target_function(**params)
            score = score_solution(solution)
            if score > best_score:
                best_score = score
                best_solution = (params, solution)
                print(f"Iteration {i}: New best solution: {best_solution}")
            if time.monotonic() - start_time > max_time:
                # Return directly instead of raising and catching TimeoutError:
                # a TimeoutError raised by user code must not be mistaken for
                # the solver's own time limit and silently swallowed.
                print("Time limit reached")
                return (best_solution, best_score)
    return (best_solution, best_score)
def generate_parameter_combinations(sequential_params: dict[str, list]) -> Iterable[dict]:
    """Yield every combination of the given per-parameter value lists.

    Args:
        sequential_params: Mapping of parameter name to the list of values
            that should be tried for it.

    Yields:
        One dict per element of the Cartesian product of the value lists,
        keyed by parameter name. An empty mapping yields a single empty dict
        (the one combination of "no parameters"); a parameter with an empty
        value list yields nothing.
    """
    # Take keys and values separately rather than via zip(*items()): the
    # unpacking form raises ValueError when the mapping is empty.
    keys = list(sequential_params)
    for combination in product(*sequential_params.values()):
        yield dict(zip(keys, combination))
def fun(constant, length, period, multiplier, factor):
    """Toy target function used to demonstrate the solver.

    Returns ``(a, [b])`` where ``a`` is the product of all five arguments
    and ``b`` is ``length * period``.
    """
    span = length * period
    overall = constant * length * period * multiplier * factor
    return (overall, [span])
def sol_scorer(sol):
    """Return the score for one result tuple of the target function.

    A result whose ``sol[1]`` has a maximum below 40 is considered invalid
    and scores 0; otherwise the score is ``sol[0]``.

    NOTE(review): the question text asks for results with
    ``max(array2) < 40`` to be the *valid* ones, i.e. the opposite test.
    Confirm which direction is intended before reusing this scorer.
    """
    largest = max(sol[1])
    if not largest < 40:
        return sol[0]
    # Invalid; a zero score can never beat the solver's initial best.
    return 0
def main():
    """Run the combined sequential/random solver on the toy ``fun``."""
    constants = [1, 2, 3, 4, 4, 3, 5, 6, 7, 8]

    # Every combination of these values is tried exhaustively.
    fixed_grid = {
        "length": [10, 20, 30, 40],
        "period": [40, 30, 20, 10],
    }
    sequential_params = generate_parameter_combinations(fixed_grid)
    # A plain list of dicts is accepted as well, for example:
    # sequential_params = [
    #     {"length": 10, "period": 40},
    #     {"length": 20, "period": 30},
    # ]

    def draw_1_to_100():
        return random.randint(1, 100)

    # The remaining parameters are drawn at random on every trial.
    param_generators = {
        "constant": lambda: random.choice(constants),
        "multiplier": draw_1_to_100,
        "factor": draw_1_to_100,
    }

    outcome = find_solution(
        fun,
        sol_scorer,
        param_generators=param_generators,
        sequential_combination_generator=sequential_params,
        max_iterations=10_000,  # Limit for each sequential combination
        max_time=10,  # Total time limit
    )
    print(outcome)


if __name__ == "__main__":
    main()
This prints out e.g.
Trying 10000 w/: {'length': 10, 'period': 40}
Iteration 0: New best solution: ({'length': 10, 'period': 40, 'constant': 2, 'multiplier': 95, 'factor': 64}, (4864000, [400]))
Iteration 1: New best solution: ({'length': 10, 'period': 40, 'constant': 7, 'multiplier': 73, 'factor': 93}, (19009200, [400]))
Iteration 71: New best solution: ({'length': 10, 'period': 40, 'constant': 6, 'multiplier': 96, 'factor': 93}, (21427200, [400]))
Iteration 248: New best solution: ({'length': 10, 'period': 40, 'constant': 8, 'multiplier': 80, 'factor': 89}, (22784000, [400]))
Iteration 595: New best solution: ({'length': 10, 'period': 40, 'constant': 8, 'multiplier': 79, 'factor': 97}, (24521600, [400]))
Iteration 679: New best solution: ({'length': 10, 'period': 40, 'constant': 7, 'multiplier': 96, 'factor': 99}, (26611200, [400]))
Iteration 722: New best solution: ({'length': 10, 'period': 40, 'constant': 8, 'multiplier': 98, 'factor': 93}, (29164800, [400]))
Iteration 6065: New best solution: ({'length': 10, 'period': 40, 'constant': 8, 'multiplier': 98, 'factor': 94}, (29478400, [400]))
Trying 10000 w/: {'length': 10, 'period': 30}
Trying 10000 w/: {'length': 20, 'period': 40}
Iteration 0: New best solution: ({'length': 20, 'period': 40, 'constant': 7, 'multiplier': 70, 'factor': 79}, (30968000, [800]))
Iteration 26: New best solution: ({'length': 20, 'period': 40, 'constant': 8, 'multiplier': 96, 'factor': 63}, (38707200, [800]))
Iteration 54: New best solution: ({'length': 20, 'period': 40, 'constant': 8, 'multiplier': 81, 'factor': 78}, (40435200, [800]))
Iteration 80: New best solution: ({'length': 20, 'period': 40, 'constant': 8, 'multiplier': 80, 'factor': 97}, (49664000, [800]))
...
Iteration 4500: New best solution: ({'length': 40, 'period': 40, 'constant': 8, 'multiplier': 94, 'factor': 93}, (111897600, [1600]))
Iteration 5638: New best solution: ({'length': 40, 'period': 40, 'constant': 8, 'multiplier': 96, 'factor': 97}, (119193600, [1600]))
Iteration 6006: New best solution: ({'length': 40, 'period': 40, 'constant': 8, 'multiplier': 99, 'factor': 99}, (125452800, [1600]))
Trying 10000 w/: {'length': 40, 'period': 30}
(({'length': 40, 'period': 40, 'constant': 8, 'multiplier': 99, 'factor': 99}, (125452800, [1600])), 125452800)