"""GKLR Config module."""
from typing import Optional, Dict, Any
import os
import sys
import multiprocessing
import numpy as np
from .logger import *
__all__ = ["Config"]
def init_environment_variables(num_cores: Optional[int] = None):
    """Initializes the thread-count environment variables.

    Sets ``MKL_NUM_THREADS`` and ``OMP_NUM_THREADS`` to ``num_cores``.

    Args:
        num_cores: The number of cores to use. If None, it uses all the cores
            available. Default: None.
    """
    if num_cores is None:
        # str(None) would export the literal "None", which is not a valid
        # thread count; honour the documented "use all cores" default.
        num_cores = multiprocessing.cpu_count()
    os.environ["MKL_NUM_THREADS"] = str(num_cores)
    os.environ["OMP_NUM_THREADS"] = str(num_cores)
class Config:
    """Configuration class for the GKLR package.

    This class stores the configuration and hyperparameters for the GKLR
    package. Hyperparameters are read with ``config[name]`` (returns None for
    unknown names), updated with ``config[name] = value`` (only for existing
    names), and the underlying dict is returned by calling ``config()``.
    """

    def __init__(self):
        """Constructor.

        Records run information, sets the default hyperparameters and calls
        ``init_environment_variables`` with the default ``num_cores``.
        """
        # Local import, presumably to avoid a circular import with the
        # package __init__ -- TODO confirm.
        from gklr import __version__
        self.info = {
            "python_version": sys.version,
            "GKLR_version": __version__,
            "directory": os.getcwd(),
        }
        self.hyperparameters = {
            "num_cores": multiprocessing.cpu_count(),
            "kernel": "rbf",
            "kernel_params": {"gamma": 1.0},
            "nystrom": False,
            "nystrom_sampling": "uniform",
            "ridge_leverage_lambda": 1,
            "compression": None,
        }
        init_environment_variables(self.hyperparameters["num_cores"])

    def __str__(self):
        """Returns a human-readable table of the hyperparameters."""
        rows = [f" - {key:<24}: {val}\n" for key, val in self.hyperparameters.items()]
        return "\nGKLR hyperparameters:\n---------------\n" + "".join(rows) + "\n"

    def __getitem__(self, name: str) -> Any:
        """Returns the hyperparameter ``name``, or None if it is not defined."""
        return self.hyperparameters.get(name)

    def __setitem__(self, name: str, val: Any) -> None:
        """Updates an existing hyperparameter.

        Raises:
            NameError: If ``name`` is not an existing hyperparameter. Use
                ``set_hyperparameter`` to add a new one.
        """
        if name not in self.hyperparameters:
            raise NameError(f"Hyperparameter {name} is not a valid option.")
        self.hyperparameters[name] = val

    def __call__(self) -> Dict[str, Any]:
        """Returns the hyperparameter dict itself (not a copy)."""
        return self.hyperparameters

    def set_hyperparameter(self, key: str, value: Any):
        """Helper method to set the hyperparameters of GKLR.

        Unlike item assignment, this also accepts keys that are not yet
        defined, adding them to the configuration.

        Args:
            key: The hyperparameter to set.
            value: The value to set the hyperparameter to.
        """
        self.hyperparameters[key] = value
        logger_debug(f"Set hyperparameter {key} = {value}")

    def remove_hyperparameter(self, key: str):
        """Helper method to remove a hyperparameter from GKLR.

        Args:
            key: The hyperparameter to remove.

        Raises:
            ValueError: If ``key`` is not a defined hyperparameter.
        """
        if key not in self.hyperparameters:
            self._fail(ValueError, f"Hyperparameter {key} not found")
        del self.hyperparameters[key]
        logger_debug(f"Removed hyperparameter {key}")

    def check_values(self):
        """Checks validity of hyperparameter values.

        Raises:
            TypeError: If 'num_cores' is not an integer, the 'gamma' kernel
                parameter is not a float, 'nystrom' is not a bool, or
                'compression' is neither None nor a number.
            ValueError: If 'compression' is > 1 but not integral, or <= 0.
        """
        # Explicit checks rather than `assert`: assertions are stripped when
        # Python runs with -O, which would silently disable validation.
        if not isinstance(self["num_cores"], (int, np.integer)):
            self._fail(TypeError, "'num_cores' hyperparameter must be an integer.")
        if not isinstance(self["kernel_params"]["gamma"], (float, np.floating)):
            self._fail(TypeError, "'gamma' kernel parameter must be a float.")
        if not isinstance(self["nystrom"], bool):
            self._fail(TypeError, "'nystrom' hyperparameter must be a boolean.")
        compression = self["compression"]
        # None is a valid value (no compression) and must not reach the
        # numeric comparisons below, where `None > 1` raises TypeError.
        if compression is not None:
            self._check_compression(compression)
        # TODO: Assert kernel

    def _check_compression(self, compression: Any) -> None:
        """Validates a non-None 'compression' hyperparameter value."""
        if not isinstance(compression, (int, np.integer, float, np.floating)):
            self._fail(TypeError, "'compression' hyperparameter must be None, "
                                  "an integer or a float.")
        if compression > 1 and (compression != int(compression)):
            self._fail(ValueError, ("When 'compression' hyperparameter is > 1, it must "
                                    "be an integer representing the number of "
                                    "Nyström components."))
        elif compression <= 0:
            self._fail(ValueError, "'compression' hyperparameter must be a positive number.")

    @staticmethod
    def _fail(exc_type: type, msg: str):
        """Logs ``msg`` as an error and raises ``exc_type(msg)``. Never returns."""
        logger_error(msg)
        raise exc_type(msg)