Baselines module

class activemri.baselines.CVPR19Evaluator(evaluator_path: str, device: torch.device, add_mask: bool = False)

Bases: activemri.baselines.Policy

get_action(obs: Dict[str, Any], **_kwargs) → List[int]

Returns a list of actions for a batch of observations.
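A minimal usage sketch (the checkpoint path is a placeholder, and the observation dict is assumed to come from an activemri.envs.ActiveMRIEnv):

    import torch

    import activemri.baselines as baselines

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # "checkpoints/evaluator.pt" is a placeholder for a pretrained CVPR'19 evaluator.
    policy = baselines.CVPR19Evaluator("checkpoints/evaluator.pt", device, add_mask=False)

    # `obs` would be the observation dict returned by env.reset() or env.step():
    # actions = policy.get_action(obs)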

class activemri.baselines.DDQN(device: torch.device, memory: Optional[activemri.baselines.replay_buffer.ReplayMemory], opts: argparse.Namespace)

Bases: torch.nn.modules.module.Module, activemri.baselines.Policy

Implementation of Double DQN value network.

The configuration is given by the opts argument, which must contain the following fields (a construction sketch follows the parameter list below):

  • mask_embedding_dim(int): See cvpr19_models.models.evaluator.EvaluatorNetwork.

  • gamma(float): Discount factor for target updates.

  • dqn_model_type(str): Describes the architecture of the neural net. Options are “simple_mlp” and “evaluator”, to use SimpleMLP and EvaluatorBasedValueNetwork, respectively.

  • budget(int): The environment’s budget.

  • image_width(int): The width of the input images.

Parameters
  • device (torch.device) – Device to use.

  • memory (optional(replay_buffer.ReplayMemory)) – Replay buffer to sample transitions from. Can be None, for example, if this is a target network.

  • opts (argparse.Namespace) – Options for the algorithm as explained above.
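A minimal construction sketch, assuming only the fields documented above are required; the values are illustrative, not recommended settings:

    import argparse

    import torch

    import activemri.baselines as baselines

    opts = argparse.Namespace(
        mask_embedding_dim=0,
        gamma=0.5,
        dqn_model_type="simple_mlp",
        budget=10,
        image_width=128,
    )
    # memory can be None, e.g., when instantiating a target network.
    target_net = baselines.DDQN(torch.device("cpu"), None, opts)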

forward(x: torch.Tensor) → torch.Tensor

Predicts action values.

Parameters

x (torch.Tensor) – The observation tensor.

Returns

The predicted Q-values.

Return type

torch.Tensor

Note

Values corresponding to active k-space columns in the observation are manually set to 1e-10.
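Conceptually, the note amounts to the following standalone snippet (an illustration, not the network's actual code; it assumes non-negative value estimates, so that 1e-10 is effectively the minimum):

    import torch

    q_values = torch.rand(4, 128)  # one row of column values per batch element
    mask = torch.zeros(4, 128)     # 1 marks an already-acquired (active) column
    mask[:, :16] = 1
    q_values[mask.bool()] = 1e-10  # active columns can no longer win the argmax
    greedy_actions = q_values.argmax(dim=1)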

get_action(obs: Dict[str, Any], eps_threshold: float = 0.0) → List[int]

Returns an action sampled from an epsilon-greedy policy.

With probability eps_threshold, samples a random k-space column (ignoring active columns); otherwise, returns the column with the highest estimated Q-value for the observation. The selection rule is illustrated in the sketch after the parameter list.

Parameters
  • obs (dict(str, any)) – The observation for which an action is required.

  • eps_threshold (float) – The probability of sampling a random action instead of using a greedy action.
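The documented selection rule, written out as a standalone illustration (not the class's actual code):

    import random

    import torch

    def epsilon_greedy(q_values: torch.Tensor, mask: torch.Tensor, eps_threshold: float) -> int:
        # q_values: (num_cols,) predicted values; mask: (num_cols,) with 1 = active.
        if random.random() < eps_threshold:
            inactive = (mask == 0).nonzero(as_tuple=True)[0]
            return int(inactive[torch.randint(len(inactive), (1,))])
        q_values = q_values.clone()
        q_values[mask.bool()] = float("-inf")  # never pick an active column greedily
        return int(q_values.argmax())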

class activemri.baselines.DDQNTrainer(options: argparse.Namespace, env: activemri.envs.envs.ActiveMRIEnv, device: torch.device)

Bases: object

DDQN Trainer for active MRI acquisition.

Configuration for the trainer is provided by the options argument, which must contain the following fields (a construction sketch follows the parameter list below):

  • checkpoints_dir(str): The directory where the model will be saved to (or loaded from).

  • dqn_batch_size(int): The batch size to use for updates.

  • dqn_burn_in(int): How many environment steps to collect before parameter updates begin.

  • dqn_normalize(bool): True if a running mean/standard deviation should be maintained for observations.

  • dqn_only_test(bool): If True, the model will not be trained; the trainer will only attempt to read the checkpoint and load the network weights (ignoring training-related information).

  • dqn_test_episode_freq(optional(int)): How frequently (in number of env steps) to perform test episodes.

  • freq_dqn_checkpoint_save(int): How often (in episodes) to save the model.

  • num_train_steps(int): How many environment steps to train for.

  • replay_buffer_size(int): The capacity of the replay buffer.

  • resume(bool): If True, will try to load weights from the checkpoints dir.

  • num_test_episodes(int): How many test episodes to periodically evaluate for.

  • seed(int): Sets the seed for the environment when running evaluation episodes.

  • reward_metric(str): Which of the env.score_keys() is used as the reward. Mainly used for logging purposes.

  • target_net_update_freq(int): How often (in environment steps) to update the target network.

Parameters
  • options (argparse.Namespace) – Options for the trainer.

  • env (activemri.envs.ActiveMRIEnv) – Env for which the policy is trained.

  • device (torch.device) – Device to use.
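A construction sketch spelling out the documented fields with illustrative values. If the trainer builds the DDQN network internally, the same namespace presumably also needs the DDQN fields listed earlier (gamma, dqn_model_type, and so on); env creation is elided since it depends on your configuration:

    import argparse

    import torch

    import activemri.baselines as baselines

    options = argparse.Namespace(
        checkpoints_dir="checkpoints/ddqn",
        dqn_batch_size=16,
        dqn_burn_in=100,
        dqn_normalize=True,
        dqn_only_test=False,
        dqn_test_episode_freq=None,
        freq_dqn_checkpoint_save=5,
        num_train_steps=10000,
        replay_buffer_size=50000,
        resume=False,
        num_test_episodes=2,
        seed=0,
        reward_metric="ssim",
        target_net_update_freq=500,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # env = ...  # an activemri.envs.ActiveMRIEnv instance
    # trainer = baselines.DDQNTrainer(options, env, device)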

class activemri.baselines.LowestIndexPolicy(alternate_sides: bool, centered: bool = True)

Bases: activemri.baselines.Policy

A policy that represents low-to-high frequency k-space selection.

Parameters
  • alternate_sides (bool) – If True, the indices of selected actions will alternate between the two sides of the mask. For example, for an image with 100 columns and non-centered k-space, the order will be 0, 99, 1, 98, 2, 97, and so on. For the same size and centered k-space, the order will be 49, 50, 48, 51, 47, 52, and so on (see the sketch after this list).

  • centered (bool) – If True (default), low frequencies are in the center of the mask. Otherwise, they are at the edges of the mask.
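The documented alternating orders can be reproduced with a small standalone helper (illustrative only, not the class's internals):

    def alternating_low_to_high(num_cols: int, centered: bool = True) -> list:
        # Reproduces the orders documented above for a mask with num_cols columns.
        if centered:
            left = range(num_cols // 2 - 1, -1, -1)             # 49, 48, 47, ...
            right = range(num_cols // 2, num_cols)              # 50, 51, 52, ...
        else:
            left = range(num_cols // 2)                         # 0, 1, 2, ...
            right = range(num_cols - 1, num_cols // 2 - 1, -1)  # 99, 98, 97, ...
        return [col for pair in zip(left, right) for col in pair]

    assert alternating_low_to_high(100, centered=False)[:6] == [0, 99, 1, 98, 2, 97]
    assert alternating_low_to_high(100, centered=True)[:6] == [49, 50, 48, 51, 47, 52]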

get_action(obs: Dict[str, Any], **_kwargs) → List[int]

Returns the lowest non-active k-space column, one per batch element.

Parameters

obs (dict(str, any)) – As returned by activemri.envs.ActiveMRIEnv.

Returns

A list of k-space column indices, one per batch element in the observation, equal to the lowest non-active k-space column in the corresponding observation mask.

Return type

list(int)

class activemri.baselines.OneStepGreedyOracle(env: activemri.envs.envs.ActiveMRIEnv, metric: str, num_samples: Optional[int] = None, rng: Optional[numpy.random.mtrand.RandomState] = None)

Bases: activemri.baselines.Policy

A policy that returns the k-space column leading to the best reconstruction score (a construction sketch follows the parameter list below).

Parameters
  • env (activemri.envs.ActiveMRIEnv) – The environment for which the policy is computed.

  • metric (str) – The name of the score metric to use (must be in env.score_keys()).

  • num_samples (optional(int)) – If given, only num_samples random actions will be tested. Defaults to None, which means that the method will consider all actions.

  • rng (optional(numpy.random.RandomState)) – A random number generator to use for sampling.
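A construction sketch; "ssim" is an illustrative metric name and must be one of env.score_keys():

    import numpy as np

    import activemri.baselines as baselines

    def make_oracle(env):
        # Score only 10 random candidate columns per step instead of all of them.
        rng = np.random.RandomState(0)
        return baselines.OneStepGreedyOracle(env, "ssim", num_samples=10, rng=rng)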

get_action(obs: Dict[str, Any], **_kwargs) → List[int]

Returns a one-step greedy action maximizing reconstruction score.

Parameters

obs (dict(str, any)) – As returned by activemri.envs.ActiveMRIEnv.

Returns

A list of k-space column indices, one per batch element in the observation, equal to the action that maximizes the reconstruction score (e.g., SSIM or negative MSE).

Return type

list(int)

class activemri.baselines.RandomLowBiasPolicy(acceleration: float, centered: bool = True, seed: Optional[int] = None)

Bases: activemri.baselines.Policy

get_action(obs: Dict[str, Any], **_kwargs) → List[int]

Returns a list of actions for a batch of observations.

class activemri.baselines.RandomPolicy(seed: Optional[int] = None)

Bases: activemri.baselines.Policy

A policy representing random k-space selection.

Returns one of the valid actions uniformly at random.

Parameters

seed (optional(int)) – The seed to use for the random number generator, which is based on torch.Generator().

get_action(obs: Dict[str, Any], **_kwargs) → List[int]

Returns a random action without replacement.

Parameters

obs (dict(str, any)) – As returned by activemri.envs.ActiveMRIEnv.

Returns

A list of random k-space column indices, one per batch element in the observation. The indices are sampled from the set of inactive (0) columns of each batch element.

Return type

list(int)
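Finally, any of these policies can be paired with an environment in a gym-style loop. The reset()/step() signatures below are an assumption and should be checked against activemri.envs.ActiveMRIEnv:

    import activemri.baselines as baselines

    def run_episode(env, policy, budget):
        # Assumes reset() returns (obs, metadata) and step() returns
        # (obs, reward, done, metadata), with done holding one flag per batch element.
        obs, _ = env.reset()
        for _ in range(budget):
            action = policy.get_action(obs)  # one column index per batch element
            obs, reward, done, _ = env.step(action)
            if all(done):
                break

    # policy = baselines.RandomPolicy(seed=0)
    # run_episode(env, policy, budget=10)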