Approximators
MushroomRL exposes the high-level class Regressor, which can manage any type of function regressor. This class is a wrapper for any kind of function approximator, e.g. a scikit-learn approximator or a PyTorch neural network.
Regressor
- class Regressor(approximator, input_shape, output_shape=None, n_actions=None, n_models=None, **params)
Bases:
Serializable
This class implements the functions needed to manage a function approximator. It selects the appropriate kind of regressor according to the parameters provided by the user, so it is the only class needed for every kind of task. The implementation is inferred from the value of the n_actions parameter. If n_actions is provided, the user is assumed to want an approximator of the Q-function: if n_actions is equal to the output dimension given by output_shape, a QRegressor is created; otherwise (output_shape should be (1,)) an ActionRegressor is created. If n_actions is not provided, a GenericRegressor is created. An Ensemble model can be used with any of these implementations simply by providing an n_models parameter greater than 1.
- __init__(approximator, input_shape, output_shape=None, n_actions=None, n_models=None, **params)
Constructor.
- Parameters:
approximator (class) – the approximator class to use to create the model;
input_shape (tuple) – the shape of the input of the model;
output_shape (tuple, None) – the shape of the output of the model;
n_actions (int, None) – number of actions considered to create a QRegressor or an ActionRegressor;
n_models (int, 1) – number of models to create;
**params – other parameters to create each model.
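The selection rule described above can be sketched in plain Python. The helper select_regressor is hypothetical: it does not import MushroomRL and only returns the name of the implementation the documented rule would pick, assuming output_shape is a one-element tuple whenever n_actions is given.

```python
def select_regressor(output_shape, n_actions=None, n_models=None):
    """Return the name of the implementation the documented rule would pick.

    Hypothetical helper: it mirrors the rule described for Regressor and does
    not create any MushroomRL object.
    """
    if n_actions is not None:
        # A Q-function approximator was requested.
        if output_shape == (n_actions,):
            kind = "QRegressor"       # one model with one output per action
        else:
            kind = "ActionRegressor"  # output_shape is expected to be (1,)
    else:
        kind = "GenericRegressor"
    # n_models > 1 keeps the same implementation but uses Ensemble models.
    if n_models is not None and n_models > 1:
        kind += f" (ensemble of {n_models} models)"
    return kind


print(select_regressor((3,), n_actions=3))              # QRegressor
print(select_regressor((1,), n_actions=3))              # ActionRegressor
print(select_regressor((2,)))                           # GenericRegressor
print(select_regressor((3,), n_actions=3, n_models=5))  # QRegressor (ensemble of 5 models)
```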
- fit(*z, **fit_params)
Fit the model.
- Parameters:
*z – list of inputs to the model;
**fit_params – parameters to use to fit the model.
- predict(*z, **predict_params)
Predict the output of the model given an input.
- Parameters:
*z – list of inputs to the model;
**predict_params – parameters to use to predict with the model.
- Returns:
The model prediction.
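When n_actions is set, the two Q-function layouts answer the same query in different ways. The toy classes below (ToyQRegressor and ToyActionRegressor, both hypothetical and working on a single state instead of a batch) illustrate the difference as we understand it: a QRegressor wraps one model that outputs one value per action, while an ActionRegressor keeps one single-output model per action and uses the one matching the requested action.

```python
class ToyQRegressor:
    """One model whose output has one entry per action (hypothetical sketch)."""

    def __init__(self, model):
        self.model = model  # callable: state -> list with one value per action

    def predict(self, state, action=None):
        q = self.model(state)
        # Without an action, return every Q-value; otherwise pick one entry.
        return q if action is None else q[action]


class ToyActionRegressor:
    """One single-output model per action (hypothetical sketch)."""

    def __init__(self, models):
        self.models = models  # list of callables: state -> float

    def predict(self, state, action=None):
        if action is None:
            return [m(state) for m in self.models]
        # Only the model associated with the requested action is evaluated.
        return self.models[action](state)


# Both wrappers expose the same Q-values for a 2-dimensional state and 2 actions.
q_reg = ToyQRegressor(lambda s: [s[0] + s[1], s[0] - s[1]])
a_reg = ToyActionRegressor([lambda s: s[0] + s[1], lambda s: s[0] - s[1]])
state = [3.0, 1.0]
print(q_reg.predict(state), a_reg.predict(state))        # [4.0, 2.0] [4.0, 2.0]
print(q_reg.predict(state, 1), a_reg.predict(state, 1))  # 2.0 2.0
```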
- property model
Returns: The model object.
- property input_shape
Returns: The shape of the input of the model.
- property output_shape
Returns: The shape of the output of the model.
- property weights_size
Returns: The shape of the weights of the model.
- _add_save_attr(**attr_dict)
Add attributes that should be saved for an agent. For every attribute, it is necessary to specify the method to be used to save and load it. Available methods are: numpy, mushroom, torch, json, pickle, primitive and none. The primitive method can be used to store primitive attributes, while the none method always skips the attribute but ensures that it is initialized to None after loading. The mushroom method can be used with classes that implement the Serializable interface. All other methods use the library of the same name. If a “!” character is added at the end of the method, the field is saved only if full_save is set to True.
- Parameters:
**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them.
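In a Serializable subclass, a call would look like self._add_save_attr(_w='numpy', _replay_memory='mushroom!'). The hypothetical helper fields_to_save below only mirrors the documented rules, to show which attributes end up on disk: a trailing “!” restricts an attribute to full saves, and none is never written. It is a sketch, not MushroomRL's serialization code.

```python
def fields_to_save(attr_dict, full_save):
    """Return {attribute: method} for the attributes that would be written.

    Hypothetical helper mirroring the documented rules: a trailing "!" means
    "save only when full_save is True", and "none" never writes the attribute
    (it is reset to None on load instead).
    """
    saved = {}
    for name, method in attr_dict.items():
        full_only = method.endswith("!")
        method = method.rstrip("!")
        if method == "none":
            continue  # skipped, re-initialized to None after loading
        if full_only and not full_save:
            continue  # only stored in full saves
        saved[name] = method
    return saved


attrs = dict(_w="numpy", _n_updates="primitive", _replay_memory="mushroom!", _cache="none")
print(fields_to_save(attrs, full_save=False))
# {'_w': 'numpy', '_n_updates': 'primitive'}
print(fields_to_save(attrs, full_save=True))
# {'_w': 'numpy', '_n_updates': 'primitive', '_replay_memory': 'mushroom'}
```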
- _post_load()
This method can be overridden to implement logic that is executed after the agent is loaded.
- copy()
- Returns:
A deepcopy of the agent.
- classmethod load(path)
Load and deserialize the agent from the given location on disk.
- Parameters:
path (Path, string) – Relative or absolute path to the agent's save location.
- Returns:
The loaded agent.
- save(path, full_save=False)
Serialize and save the object to the given path on disk.
- Parameters:
path (Path, str) – Relative or absolute path to the object save location;
full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
- save_zip(zip_file, full_save, folder='')
Serialize and save the object into the given zip file.
- Parameters:
zip_file (ZipFile) – ZipFile where the object needs to be saved;
full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
folder (string, '') – subfolder to be used by the save method.
Approximator
Linear
- class LinearApproximator(weights=None, input_shape=None, output_shape=(1,), **kwargs)
Bases:
Serializable
This class implements a linear approximator.
- __init__(weights=None, input_shape=None, output_shape=(1,), **kwargs)
Constructor.
- Parameters:
weights (np.ndarray) – array of weights to initialize the weights of the approximator;
input_shape (np.ndarray, None) – the shape of the input of the model;
output_shape (np.ndarray, (1,)) – the shape of the output of the model;
**kwargs – other params of the approximator.
- fit(x, y, **fit_params)
Fit the model.
- Parameters:
x (np.ndarray) – input;
y (np.ndarray) – target;
**fit_params – other parameters used by the fit method of the regressor.
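This page does not say how fit computes the weights. A common choice for a linear model is ordinary least squares; the sketch below assumes that, with a single output, and solves the normal equations in pure Python. fit_least_squares is a hypothetical name, and the library may use a different solver (for example a pseudo-inverse).

```python
def fit_least_squares(x, y):
    """Return weights w minimizing sum_i (x_i . w - y_i)^2 (single output).

    Solves the normal equations (X^T X) w = X^T y by Gauss-Jordan
    elimination with partial pivoting. Assumes X^T X is invertible.
    """
    n = len(x[0])
    # Augmented matrix [X^T X | X^T y], one row per weight.
    a = [[sum(row[i] * row[j] for row in x) for j in range(n)]
         + [sum(row[i] * t for row, t in zip(x, y))] for i in range(n)]
    for col in range(n):
        # Swap in the row with the largest pivot for numerical stability.
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        # Eliminate this column from every other row.
        for r in range(n):
            if r != col:
                factor = a[r][col] / a[col][col]
                a[r] = [v - factor * p for v, p in zip(a[r], a[col])]
    return [a[i][n] / a[i][i] for i in range(n)]


# Features [1, s] (bias + state); targets generated by y = 2 + 3 s.
x = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
y = [2.0, 5.0, 8.0, 11.0]
w = fit_least_squares(x, y)
print([round(v, 6) for v in w])  # [2.0, 3.0]
```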
- predict(x, **predict_params)
Predict.
- Parameters:
x (np.ndarray) – input;
**predict_params – other parameters used by the predict method of the regressor.
- Returns:
The predictions of the model.
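For a linear approximator, prediction is a matrix-vector product per input. The sketch assumes the weights are stored as an n_outputs by n_inputs matrix, consistent with output_shape and input_shape; linear_predict is a hypothetical helper, not the class method.

```python
def linear_predict(weights, x):
    """Compute y = W x for each input row (hypothetical sketch).

    weights: list of n_outputs rows, each with n_inputs entries.
    x: list of input rows, each with n_inputs entries.
    Returns one list of n_outputs values per input row.
    """
    return [[sum(w_j * x_j for w_j, x_j in zip(w_row, x_row)) for w_row in weights]
            for x_row in x]


W = [[1.0, 2.0, 0.0],   # weights of output 0
     [0.0, -1.0, 3.0]]  # weights of output 1
print(linear_predict(W, [[1.0, 1.0, 1.0], [2.0, 0.0, 1.0]]))  # [[3.0, 2.0], [2.0, 3.0]]
```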
- property weights_size
Returns: The size of the array of weights.
- _add_save_attr(**attr_dict)
Add attributes that should be saved for an agent. For every attribute, it is necessary to specify the method to be used to save and load it. Available methods are: numpy, mushroom, torch, json, pickle, primitive and none. The primitive method can be used to store primitive attributes, while the none method always skips the attribute but ensures that it is initialized to None after loading. The mushroom method can be used with classes that implement the Serializable interface. All other methods use the library of the same name. If a “!” character is added at the end of the method, the field is saved only if full_save is set to True.
- Parameters:
**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them.
- _post_load()
This method can be overridden to implement logic that is executed after the agent is loaded.
- copy()
- Returns:
A deepcopy of the agent.
- diff(state, action=None)
Compute the derivative of the output w.r.t. state, and action if provided.
- Parameters:
state (np.ndarray) – the state;
action (np.ndarray, None) – the action.
- Returns:
The derivative of the output w.r.t. state, and action if provided.
- classmethod load(path)
Load and deserialize the agent from the given location on disk.
- Parameters:
path (Path, string) – Relative or absolute path to the agent's save location.
- Returns:
The loaded agent.
- save(path, full_save=False)
Serialize and save the object to the given path on disk.
- Parameters:
path (Path, str) – Relative or absolute path to the object save location;
full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
- save_zip(zip_file, full_save, folder='')
Serialize and save the object into the given zip file.
- Parameters:
zip_file (ZipFile) – ZipFile where the object needs to be saved;
full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
folder (string, '') – subfolder to be used by the save method.
CMAC
- class CMAC(tilings, weights=None, output_shape=(1,), **kwargs)
Bases:
LinearApproximator
This class implements a Cerebellar Model Arithmetic Computer.
- __init__(tilings, weights=None, output_shape=(1,), **kwargs)
Constructor.
- Parameters:
tilings (list) – list of tilings to discretize the input space;
weights (np.ndarray) – array of weights to initialize the weights of the approximator;
input_shape (np.ndarray, None) – the shape of the input of the model;
output_shape (np.ndarray, (1,)) – the shape of the output of the model;
**kwargs – other params of the approximator.
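A CMAC turns the input into sparse binary features: each tiling is a grid over the input space, and an input activates exactly one tile per tiling. The sketch below assumes one-dimensional inputs and uniformly offset grids; active_tiles_1d is a hypothetical helper and does not reproduce the tiling objects MushroomRL expects.

```python
def active_tiles_1d(x, n_tilings, n_tiles, low, high):
    """Index of the active tile in each of n_tilings offset 1-D grids (hypothetical).

    Tiling k is shifted left by k / n_tilings of a tile width, and every
    tiling gets one extra tile so that the shifted grids still cover
    [low, high]. Indices are global: tiling k owns the block of n_tiles + 1
    indices starting at k * (n_tiles + 1).
    """
    width = (high - low) / n_tiles
    indices = []
    for k in range(n_tilings):
        origin = low - k * width / n_tilings
        local = int((x - origin) // width)  # tile of x inside tiling k
        indices.append(k * (n_tiles + 1) + local)
    return indices


# Two tilings of 4 tiles over [0, 1]: 2 * (4 + 1) = 10 binary features.
print(active_tiles_1d(0.3, 2, 4, 0.0, 1.0))  # [1, 6]
print(active_tiles_1d(0.2, 2, 4, 0.0, 1.0))  # [0, 6]
```

Nearby inputs such as 0.2 and 0.3 share tile 6 in the second tiling, which is what lets a CMAC generalize between them while still telling them apart through the first tiling.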
- fit(x, y, alpha=1.0, **kwargs)
Fit the model.
- Parameters:
x (np.ndarray) – input;
y (np.ndarray) – target;
alpha (float) – learning rate;
**kwargs – other parameters used by the fit method of the regressor.
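The alpha parameter acts as a learning rate. A common CMAC update, assumed here rather than taken from MushroomRL, moves each active weight by alpha times the prediction error divided by the number of active tiles (one per tiling), so that alpha=1.0 makes the prediction for that input match the target. cmac_predict and cmac_update are hypothetical helpers that take precomputed active tile indices.

```python
def cmac_predict(weights, active):
    """CMAC output: sum of the weights of the active tiles (one per tiling)."""
    return sum(weights[i] for i in active)


def cmac_update(weights, active, target, alpha=1.0):
    """One step towards target, sharing the correction among the active tiles.

    Sketch of a common convention, not necessarily MushroomRL's exact rule.
    Modifies weights in place and returns it.
    """
    error = target - cmac_predict(weights, active)
    for i in active:
        weights[i] += alpha * error / len(active)
    return weights


w = [0.0] * 10
cmac_update(w, [1, 6], target=4.0, alpha=1.0)
print(cmac_predict(w, [1, 6]))  # 4.0, the target is reached in one step
cmac_update(w, [0, 6], target=0.0, alpha=0.5)
print(cmac_predict(w, [1, 6]))  # 3.5, the shared tile 6 moved too
print(cmac_predict(w, [0, 6]))  # 1.0
```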
- predict(x, **predict_params)
Predict.
- Parameters:
x (np.ndarray) – input;
**predict_params – other parameters used by the predict method of the regressor.
- Returns:
The predictions of the model.
- diff(state, action=None)
Compute the derivative of the output w.r.t. state, and action if provided.
- Parameters:
state (np.ndarray) – the state;
action (np.ndarray, None) – the action.
- Returns:
The derivative of the output w.r.t. state, and action if provided.
- _add_save_attr(**attr_dict)
Add attributes that should be saved for an agent. For every attribute, it is necessary to specify the method to be used to save and load it. Available methods are: numpy, mushroom, torch, json, pickle, primitive and none. The primitive method can be used to store primitive attributes, while the none method always skips the attribute but ensures that it is initialized to None after loading. The mushroom method can be used with classes that implement the Serializable interface. All other methods use the library of the same name. If a “!” character is added at the end of the method, the field is saved only if full_save is set to True.
- Parameters:
**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them.
- _post_load()
This method can be overridden to implement logic that is executed after the agent is loaded.
- copy()
- Returns:
A deepcopy of the agent.
- get_weights()
Getter.
- Returns:
The set of weights of the approximator.
- classmethod load(path)
Load and deserialize the agent from the given location on disk.
- Parameters:
path (Path, string) – Relative or absolute path to the agent's save location.
- Returns:
The loaded agent.
- save(path, full_save=False)
Serialize and save the object to the given path on disk.
- Parameters:
path (Path, str) – Relative or absolute path to the object save location;
full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
- save_zip(zip_file, full_save, folder='')
Serialize and save the object into the given zip file.
- Parameters:
zip_file (ZipFile) – ZipFile where the object needs to be saved;
full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
folder (string, '') – subfolder to be used by the save method.
- set_weights(w)
Setter.
- Parameters:
w (np.ndarray) – the set of weights to set.
- property weights_size
Returns: The size of the array of weights.
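get_weights, set_weights and weights_size suggest that the weights are exchanged as a flat vector. The sketch below assumes a 2-D (n_outputs by n_features) matrix flattened in row-major order, with weights_size equal to the number of entries; get_flat_weights and set_flat_weights are hypothetical helpers.

```python
def get_flat_weights(matrix):
    """Flatten an (n_outputs x n_features) weight matrix row by row."""
    return [w for row in matrix for w in row]


def set_flat_weights(flat, n_outputs):
    """Rebuild the matrix from a flat vector whose length is weights_size."""
    n_features, rest = divmod(len(flat), n_outputs)
    if rest:
        raise ValueError("flat vector size is not a multiple of n_outputs")
    return [list(flat[i * n_features:(i + 1) * n_features]) for i in range(n_outputs)]


W = [[0.5, -1.0, 2.0], [0.0, 3.0, 1.5]]
flat = get_flat_weights(W)
print(len(flat))                       # 6, what weights_size would report
print(set_flat_weights(flat, 2) == W)  # True
```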
Torch Approximator
- class TorchApproximator(input_shape, output_shape, network, optimizer=None, loss=None, batch_size=0, n_fit_targets=1, use_cuda=False, reinitialize=False, dropout=False, quiet=True, **params)
Bases:
Serializable
Class to interface a PyTorch model with the MushroomRL Regressor interface. This class implements everything that is needed to use a generic PyTorch model and train it using a specified optimizer and objective function. It also supports minibatches.
- __init__(input_shape, output_shape, network, optimizer=None, loss=None, batch_size=0, n_fit_targets=1, use_cuda=False, reinitialize=False, dropout=False, quiet=True, **params)
Constructor.
- Parameters:
input_shape (tuple) – shape of the input of the network;
output_shape (tuple) – shape of the output of the network;
network (torch.nn.Module) – the network class to use;
optimizer (dict) – the optimizer used for every fit step;
loss (torch.nn.functional) – the loss function to optimize in the fit method;
batch_size (int, 0) – the size of each minibatch. If 0, the whole dataset is fed to the optimizer at each epoch;
n_fit_targets (int, 1) – the number of fit targets used by the fit method of the network;
use_cuda (bool, False) – if True, runs the network on the GPU;
reinitialize (bool, False) – if True, the approximator is reinitialized at every fit call. To perform the initialization, the weights_init method must be properly defined for the selected network model;
dropout (bool, False) – if True, dropout is applied only during training;
quiet (bool, True) – if False, shows two progress bars, one for epochs and one for the minibatches;
**params – dictionary of parameters needed to construct the network.
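The batch_size convention (0 meaning the whole dataset in a single batch) can be illustrated with a small index generator. minibatches is a hypothetical helper, and shuffling the samples at every epoch is an assumption of this sketch rather than documented behavior.

```python
import random


def minibatches(n_samples, batch_size, seed=None):
    """Yield lists of sample indices covering one epoch (hypothetical sketch).

    Following the documented convention, batch_size == 0 puts the whole
    dataset in a single batch. Indices are shuffled with a seeded RNG.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    size = batch_size if batch_size > 0 else max(n_samples, 1)
    for start in range(0, n_samples, size):
        yield indices[start:start + size]


print([len(b) for b in minibatches(10, 4, seed=0)])  # [4, 4, 2]
print([len(b) for b in minibatches(10, 0)])          # [10]
```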
- predict(*args, output_tensor=False, **kwargs)
Predict.
- Parameters:
*args – input;
output_tensor (bool, False) – whether to return the output as a tensor or not;
**kwargs – other parameters used by the predict method of the regressor.
- Returns:
The predictions of the model.
- fit(*args, n_epochs=None, weights=None, epsilon=None, patience=1, validation_split=1.0, **kwargs)
Fit the model.
- Parameters:
*args – input, where the last n_fit_targets elements are considered as the target, while the others are considered as input;
n_epochs (int, None) – the number of training epochs;
weights (np.ndarray, None) – the weights of each sample in the computation of the loss;
epsilon (float, None) – the coefficient used for early stopping;
patience (float, 1.) – the number of epochs to wait before stopping training if the loss does not improve;
validation_split (float, 1.) – the fraction of the dataset to use as the training set;
**kwargs – other parameters used by the fit method of the regressor.
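Two parts of fit can be sketched without PyTorch: how *args is split according to n_fit_targets, and one plausible reading of epsilon and patience as an early-stopping rule. Both helpers are hypothetical; in particular, the exact improvement test TorchApproximator uses is an assumption here.

```python
def split_fit_args(args, n_fit_targets=1):
    """Split fit(*args) into (inputs, targets): the last n_fit_targets are targets."""
    if n_fit_targets < 1:
        raise ValueError("n_fit_targets must be at least 1")
    return args[:-n_fit_targets], args[-n_fit_targets:]


def epochs_until_stop(losses, epsilon, patience=1):
    """Return how many epochs run before early stopping (assumed rule).

    An epoch counts as an improvement when its loss is more than epsilon
    below the best loss seen so far; training stops after `patience`
    consecutive epochs without improvement.
    """
    best, waited = float("inf"), 0
    for epoch, loss in enumerate(losses, start=1):
        if best - loss > epsilon:
            best, waited = loss, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return len(losses)


inputs, targets = split_fit_args(("states", "actions", "q_targets"))
print(inputs, targets)  # ('states', 'actions') ('q_targets',)
print(epochs_until_stop([1.0, 0.5, 0.49, 0.48, 0.3], epsilon=0.05, patience=2))  # 4
```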
- property weights_size
Returns: The size of the array of weights.
- diff(*args, **kwargs)
Compute the derivative of the output w.r.t. state, and action if provided.
- Parameters:
state (np.ndarray) – the state;
action (np.ndarray, None) – the action.
- Returns:
The derivative of the output w.r.t. state, and action if provided.
- property loss_fit
Returns: The average loss of the last epoch of the last fit call.
- _post_load()
This method can be overridden to implement logic that is executed after the agent is loaded.
- _add_save_attr(**attr_dict)
Add attributes that should be saved for an agent. For every attribute, it is necessary to specify the method to be used to save and load it. Available methods are: numpy, mushroom, torch, json, pickle, primitive and none. The primitive method can be used to store primitive attributes, while the none method always skips the attribute but ensures that it is initialized to None after loading. The mushroom method can be used with classes that implement the Serializable interface. All other methods use the library of the same name. If a “!” character is added at the end of the method, the field is saved only if full_save is set to True.
- Parameters:
**attr_dict – dictionary of attributes mapped to the method that should be used to save and load them.
- copy()
- Returns:
A deepcopy of the agent.
- classmethod load(path)
Load and deserialize the agent from the given location on disk.
- Parameters:
path (Path, string) – Relative or absolute path to the agent's save location.
- Returns:
The loaded agent.
- save(path, full_save=False)
Serialize and save the object to the given path on disk.
- Parameters:
path (Path, str) – Relative or absolute path to the object save location;
full_save (bool) – Flag to specify the amount of data to save for MushroomRL data structures.
- save_zip(zip_file, full_save, folder='')
Serialize and save the object into the given zip file.
- Parameters:
zip_file (ZipFile) – ZipFile where the object needs to be saved;
full_save (bool) – flag to specify the amount of data to save for MushroomRL data structures;
folder (string, '') – subfolder to be used by the save method.