objective.py
from benchopt import BaseObjective, safe_import_context

# Protect the import with `safe_import_context()`. This allows:
# - skipping import to speed up autocompletion in CLI.
# - getting requirements info when all dependencies are not installed.
with safe_import_context() as import_ctx:
    import numpy as np


# The benchmark objective must be named `Objective` and
# inherit from `BaseObjective` for `benchopt` to work properly.
class Objective(BaseObjective):

    # Name to select the objective in the CLI and to display the results.
    name = "Ordinary Least Squares"

    # URL of the main repo for this benchmark.
    url = "https://github.com/#ORG/#BENCHMARK_NAME"

    # List of parameters for the objective. The benchmark will consider
    # the cross product for each key in the dictionary.
    # All parameters 'p' defined here are available as 'self.p'.
    # This means the OLS objective will have a parameter `self.whiten_y`.
    parameters = {
        'whiten_y': [False, True],
    }
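    # For example, with the grid above the objective is evaluated once with
    # `self.whiten_y = False` and once with `self.whiten_y = True`; with
    # several keys, every combination of values would be run.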

    # List of packages needed to run the benchmark.
    # They are installed with conda; to use pip, use 'pip:packagename'. To
    # install from a specific conda channel, use 'channelname:packagename'.
    # Packages that are not necessary to the whole benchmark but only to some
    # solvers or datasets should be declared in Dataset or Solver (see
    # simulated.py and python-gd.py).
    # Example syntax: requirements = ['numpy', 'pip:jax', 'pytorch:pytorch']
    requirements = ["numpy"]

    # Minimal version of benchopt required to run this benchmark.
    # Bump it up if the benchmark depends on a new feature of benchopt.
    min_benchopt_version = "1.5"
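
    # Note: the whole benchmark (objective, datasets and solvers) is run
    # from the benchmark's root folder with the `benchopt run .` command.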

    def set_data(self, X, y):
        # The keyword arguments of this function are the keys of the
        # dictionary returned by `Dataset.get_data`. This defines the
        # benchmark's API to pass data. This is customizable for each
        # benchmark.
        self.X, self.y = X, y

        # `set_data` can be used to preprocess the data. For instance,
        # if `whiten_y` is True, remove the mean of `y`.
        if self.whiten_y:
            y -= y.mean(axis=0)
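            # Since `self.y` refers to the same array as `y`, this in-place
            # centering also updates `self.y`.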

    def evaluate_result(self, beta):
        # The keyword arguments of this function are the keys of the
        # dictionary returned by `Solver.get_result`. This defines the
        # benchmark's API to pass solvers' result. This is customizable for
        # each benchmark.
        diff = self.y - self.X @ beta
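
        # The reported `value` is the least-squares loss
        # 0.5 * ||y - X @ beta||_2^2.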
        # This method can return many metrics in a dictionary. One of these
        # metrics needs to be `value` for convergence detection purposes.
        return dict(
            value=.5 * diff @ diff,
        )

    def get_one_result(self):
        # Return one solution. The return value should be an object compatible
        # with `self.evaluate_result`. This is mainly for testing purposes.
        return dict(beta=np.zeros(self.X.shape[1]))

    def get_objective(self):
        # Define the information to pass to each solver to run the benchmark.
        # The outputs of this function are the keyword arguments
        # for `Solver.set_objective`. This defines the
        # benchmark's API for passing the objective to the solver.
        # It is customizable for each benchmark.
        return dict(
            X=self.X,
            y=self.y,
        )
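

# Illustrative sketch only (not part of this file): a solver consuming this
# objective, such as the template's `solvers/python-gd.py`, mirrors the API
# defined above. It receives the keys returned by `get_objective` in
# `Solver.set_objective` and returns the keys expected by `evaluate_result`
# from `Solver.get_result`. Names and implementation details below are
# assumptions made for this sketch, roughly along these lines:
#
#     from benchopt import BaseSolver, safe_import_context
#
#     with safe_import_context() as import_ctx:
#         import numpy as np
#
#     class Solver(BaseSolver):
#         name = "python-gd"  # hypothetical name for this sketch
#
#         def set_objective(self, X, y):
#             # Keys of `Objective.get_objective` arrive as keyword arguments.
#             self.X, self.y = X, y
#
#         def run(self, n_iter):
#             # Plain gradient descent on 0.5 * ||y - X @ beta||_2^2 with a
#             # 1/L step size, L being the Lipschitz constant of the gradient.
#             X, y = self.X, self.y
#             L = np.linalg.norm(X, ord=2) ** 2
#             beta = np.zeros(X.shape[1])
#             for _ in range(n_iter):
#                 beta -= X.T @ (X @ beta - y) / L
#             self.beta = beta
#
#         def get_result(self):
#             # Keys must match the arguments of `Objective.evaluate_result`.
#             return dict(beta=self.beta)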