Baselines module

class activemri.baselines.CVPR19Evaluator(evaluator_path: str, device: torch.device, add_mask: bool = False)

Bases: activemri.baselines.Policy

    get_action(obs: Dict[str, Any], **_kwargs) → List[int]
        Returns a list of actions for a batch of observations.
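Like every Policy in this module, CVPR19Evaluator is queried through get_action() with an observation produced by an environment. The sketch below is a minimal illustration; the environment class, its constructor arguments, the reset() return shape, and the checkpoint path are assumptions, not part of this reference.

```python
import torch

import activemri.baselines as baselines
import activemri.envs as envs

# Assumed environment and arguments; substitute the env you actually use.
env = envs.MICCAI2020Env(num_parallel_episodes=1, budget=10)
obs, _ = env.reset()  # assuming reset() returns (observation, metadata)

policy = baselines.CVPR19Evaluator(
    evaluator_path="checkpoints/evaluator.pth",  # hypothetical path
    device=torch.device("cpu"),
    add_mask=False,
)
actions = policy.get_action(obs)  # one k-space column index per batch element
```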
class activemri.baselines.DDQN(device: torch.device, memory: Optional[activemri.baselines.replay_buffer.ReplayMemory], opts: argparse.Namespace)

Bases: torch.nn.modules.module.Module, activemri.baselines.Policy

    Implementation of a Double DQN value network.

    The configuration is given by the opts argument, which must contain the following fields:

    - mask_embedding_dim (int): See cvpr19_models.models.evaluator.EvaluatorNetwork.
    - gamma (float): Discount factor for target updates.
    - dqn_model_type (str): Describes the architecture of the neural net. Options are "simple_mlp" and "evaluator", to use SimpleMLP and EvaluatorBasedValueNetwork, respectively.
    - budget (int): The environment's budget.
    - image_width (int): The width of the input images.

    Parameters
    - device (torch.device) – Device to use.
    - memory (optional(replay_buffer.ReplayMemory)) – Replay buffer to sample transitions from. Can be None, for example, if this is a target network.
    - opts (argparse.Namespace) – Options for the algorithm as explained above (see the sketch after this list).
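Since the configuration travels through an argparse.Namespace, one way to assemble it is to fill a Namespace with the fields listed above. A minimal sketch; the concrete values are placeholders, not recommended settings.

```python
import argparse

import torch

import activemri.baselines as baselines

# The fields documented above; the values are illustrative placeholders.
opts = argparse.Namespace(
    mask_embedding_dim=0,         # see cvpr19_models.models.evaluator.EvaluatorNetwork
    gamma=0.99,                   # discount factor for target updates
    dqn_model_type="simple_mlp",  # or "evaluator"
    budget=10,                    # the environment's budget
    image_width=128,              # width of the input images
)

# memory=None is allowed, e.g., when this instance is a target network.
target_net = baselines.DDQN(torch.device("cpu"), memory=None, opts=opts)
```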
    forward(x: torch.Tensor) → torch.Tensor
        Predicts action values.

        Parameters
        - x (torch.Tensor) – The observation tensor.

        Returns
            The predicted Q-values.

        Return type
            torch.Tensor

        Note: Values corresponding to active k-space columns in the observation are manually set to 1e-10.
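The note above describes how already-acquired columns are suppressed. The standalone snippet below illustrates the idea with plain tensor arithmetic; it is a sketch of the masking trick, not the module's actual code.

```python
import torch

q_values = torch.randn(2, 8)  # Q-values, one per k-space column, batch of 2
active = torch.zeros(2, 8)    # 1.0 marks columns that are already acquired
active[:, :3] = 1.0

# Active columns are overwritten with the tiny constant 1e-10, so any
# inactive column with a positive Q-value wins a greedy argmax.
masked_q = q_values * (1 - active) + 1e-10 * active
greedy = masked_q.argmax(dim=1)
```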
    get_action(obs: Dict[str, Any], eps_threshold: float = 0.0) → List[int]
        Returns an action sampled from an epsilon-greedy policy.

        With probability eps_threshold, samples a random k-space column (ignoring active columns); otherwise, returns the column with the highest estimated Q-value for the observation.

        Parameters
        - obs (dict(str, any)) – The observation for which an action is required.
        - eps_threshold (float) – The probability of sampling a random action instead of using a greedy action.
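The rule described above is the standard epsilon-greedy recipe. For reference, a generic sketch that is independent of this module's internals (the function name and observation handling are illustrative):

```python
import random

import torch

def epsilon_greedy(q_values: torch.Tensor, active_mask: torch.Tensor, eps: float) -> int:
    """Illustrative epsilon-greedy choice over inactive k-space columns."""
    inactive = (active_mask == 0).nonzero(as_tuple=True)[0]
    if random.random() < eps:
        # Explore: a random column among those not yet acquired.
        return int(inactive[random.randrange(len(inactive))])
    # Exploit: the inactive column with the highest estimated Q-value.
    q = q_values.clone()
    q[active_mask == 1] = -float("inf")
    return int(q.argmax())
```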
class activemri.baselines.DDQNTrainer(options: argparse.Namespace, env: activemri.envs.envs.ActiveMRIEnv, device: torch.device)

Bases: object

    DDQN trainer for active MRI acquisition.

    Configuration for the trainer is provided by the options argument, which must contain the following fields:

    - checkpoints_dir (str): The directory where the model will be saved to (or loaded from).
    - dqn_batch_size (int): The batch size to use for updates.
    - dqn_burn_in (int): How many steps to take before parameter updates begin.
    - dqn_normalize (bool): True if a running mean/standard deviation should be maintained for observations.
    - dqn_only_test (bool): True if the model will not be trained; in this case, the trainer only attempts to load the network weights from checkpoint (ignoring training-related information).
    - dqn_test_episode_freq (optional(int)): How frequently (in number of env steps) to perform test episodes.
    - freq_dqn_checkpoint_save (int): How often (in episodes) to save the model.
    - num_train_steps (int): How many environment steps to train for.
    - replay_buffer_size (int): The capacity of the replay buffer.
    - resume (bool): If True, will try to load weights from the checkpoints dir.
    - num_test_episodes (int): How many test episodes to run at each periodic evaluation.
    - seed (int): Sets the seed for the environment when running evaluation episodes.
    - reward_metric (str): Which of the env.score_keys() is used as reward. Mainly used for logging purposes.
    - target_net_update_freq (int): How often (in env steps) to update the target network.

    Parameters
    - options (argparse.Namespace) – Options for the trainer, as explained above (see the sketch after this list).
    - env (activemri.envs.ActiveMRIEnv) – Env for which the policy is trained.
    - device (torch.device) – Device to use.
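Putting the pieces together, a training run could be wired up roughly as follows. The environment class, the field values, and the way the trainer is launched are assumptions for illustration; consult the repository's training scripts for the exact invocation. The same Namespace will typically also need the DDQN fields listed earlier (gamma, dqn_model_type, and so on), since the trainer builds the value network from it.

```python
import argparse

import torch

import activemri.baselines as baselines
import activemri.envs as envs

options = argparse.Namespace(
    checkpoints_dir="checkpoints/ddqn",  # hypothetical path
    dqn_batch_size=16,
    dqn_burn_in=100,
    dqn_normalize=True,
    dqn_only_test=False,
    dqn_test_episode_freq=None,
    freq_dqn_checkpoint_save=5,
    num_train_steps=1000,
    replay_buffer_size=10000,
    resume=False,
    num_test_episodes=2,
    seed=0,
    reward_metric="ssim",  # must be one of env.score_keys()
    target_net_update_freq=500,
    # ...plus the DDQN fields documented above (gamma, dqn_model_type, etc.)
)

env = envs.MICCAI2020Env(num_parallel_episodes=1, budget=10)  # assumed environment
trainer = baselines.DDQNTrainer(options, env, torch.device("cpu"))
trainer()  # assumption: training is launched by calling the trainer; check the repo
```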
class activemri.baselines.LowestIndexPolicy(alternate_sides: bool, centered: bool = True)

Bases: activemri.baselines.Policy

    A policy that represents low-to-high frequency k-space selection.

    Parameters
    - alternate_sides (bool) – If True, the indices of selected actions will alternate between the two sides of the mask. For example, for an image with 100 columns and non-centered k-space, the order will be 0, 99, 1, 98, 2, 97, and so on. For the same size and centered k-space, the order will be 49, 50, 48, 51, 47, 52, and so on.
    - centered (bool) – If True (default), low frequencies are in the center of the mask. Otherwise, they are at the edges of the mask.

    get_action(obs: Dict[str, Any], **_kwargs) → List[int]
        Returns the lowest-frequency inactive k-space column for each observation in the batch.

        Parameters
        - obs (dict(str, any)) – As returned by activemri.envs.ActiveMRIEnv.

        Returns
            A list of k-space column indices, one per batch element in the observation, equal to the lowest non-active k-space column in their corresponding observation masks.

        Return type
            list(int)
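The ordering documented above can be observed by repeatedly querying the policy while marking the returned columns as acquired. In this sketch, the observation format (a dict with a "mask" row per batch element) is an assumption based on the Returns description; check activemri.envs.ActiveMRIEnv for the exact layout.

```python
import torch

import activemri.baselines as baselines

policy = baselines.LowestIndexPolicy(alternate_sides=True, centered=False)
obs = {"mask": torch.zeros(1, 100, dtype=torch.int64)}  # assumed observation format

order = []
for _ in range(6):
    (col,) = policy.get_action(obs)
    order.append(col)
    obs["mask"][0, col] = 1  # mark the column as acquired
print(order)  # per the docs above, expect [0, 99, 1, 98, 2, 97]
```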
class activemri.baselines.OneStepGreedyOracle(env: activemri.envs.envs.ActiveMRIEnv, metric: str, num_samples: Optional[int] = None, rng: Optional[numpy.random.mtrand.RandomState] = None)

Bases: activemri.baselines.Policy

    A policy that returns the k-space column leading to the best reconstruction score.

    Parameters
    - env (activemri.envs.ActiveMRIEnv) – The environment for which the policy is computed.
    - metric (str) – The name of the score metric to use (must be in env.score_keys()).
    - num_samples (optional(int)) – If given, only num_samples random actions will be tested. Defaults to None, which means that the method will consider all actions.
    - rng (numpy.random.RandomState) – A random number generator to use for sampling.

    get_action(obs: Dict[str, Any], **_kwargs) → List[int]
        Returns a one-step greedy action maximizing reconstruction score.

        Parameters
        - obs (dict(str, any)) – As returned by activemri.envs.ActiveMRIEnv.

        Returns
            A list of k-space column indices, one per batch element in the observation, equal to the action that maximizes reconstruction score (e.g., SSIM or negative MSE).

        Return type
            list(int)
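A construction sketch; the environment class and metric name are assumptions, and the metric must be one of env.score_keys(). Because the oracle scores candidate columns through the environment's reconstruction, restricting the search with num_samples can make it substantially cheaper.

```python
import numpy as np

import activemri.baselines as baselines
import activemri.envs as envs

env = envs.MICCAI2020Env(num_parallel_episodes=1, budget=10)  # assumed environment

policy = baselines.OneStepGreedyOracle(
    env,
    metric="ssim",                 # must be one of env.score_keys()
    num_samples=10,                # test only 10 random candidate columns per step
    rng=np.random.RandomState(0),
)

obs, _ = env.reset()
actions = policy.get_action(obs)  # columns maximizing the one-step score
```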
class activemri.baselines.RandomLowBiasPolicy(acceleration: float, centered: bool = True, seed: Optional[int] = None)

Bases: activemri.baselines.Policy

    get_action(obs: Dict[str, Any], **_kwargs) → List[int]
        Returns a list of actions for a batch of observations.
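This class carries no docstring here; judging by its name and constructor, it samples random columns with a bias toward low frequencies, controlled by the acceleration and centered arguments. A usage sketch, with the same assumed observation format as above:

```python
import torch

import activemri.baselines as baselines

policy = baselines.RandomLowBiasPolicy(acceleration=4.0, centered=True, seed=0)
obs = {"mask": torch.zeros(2, 100, dtype=torch.int64)}  # assumed observation format
actions = policy.get_action(obs)  # one column index per batch element
```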
class activemri.baselines.RandomPolicy(seed: Optional[int] = None)

Bases: activemri.baselines.Policy

    A policy representing random k-space selection.

    Returns one of the valid actions uniformly at random.

    Parameters
    - seed (optional(int)) – The seed to use for the random number generator, which is based on torch.Generator().

    get_action(obs: Dict[str, Any], **_kwargs) → List[int]
        Returns a random action without replacement.

        Parameters
        - obs (dict(str, any)) – As returned by activemri.envs.ActiveMRIEnv.

        Returns
            A list of random k-space column indices, one per batch element in the observation. The indices are sampled from the set of inactive (0) columns of each batch element.

        Return type
            list(int)
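To close, a complete single-episode evaluation loop with the simplest baseline. As before, the environment class and the reset()/step() return shapes are assumptions modeled on a gym-style interface; adjust them to the actual ActiveMRIEnv API.

```python
import activemri.baselines as baselines
import activemri.envs as envs

env = envs.MICCAI2020Env(num_parallel_episodes=1, budget=10)  # assumed environment
policy = baselines.RandomPolicy(seed=0)

obs, _ = env.reset()
done = [False]
total_reward = 0.0
while not all(done):
    actions = policy.get_action(obs)          # a random inactive column per batch element
    obs, reward, done, _ = env.step(actions)  # assuming a gym-style step() signature
    total_reward += float(sum(reward))        # reward assumed to be per-batch-element
print("episode reward:", total_reward)
```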