Adding your own reconstructor model

Configuring the RL environment to use your own reconstruction model involves three steps:

  1. Create a transform function to convert data loader outputs into a batched input for your reconstructor model.

  2. Refactor your reconstructor model to follow our Reconstructor interface.

  3. Modify the environment’s JSON configuration file accordingly.

We explain the first two steps below. The third step is explained in Environment’s JSON configuration.

Transform function

Communication between the fastMRI data loader and your reconstructor is done via a transform function, which should follow the signature of transform_template(). The environment will first load data from the fastMRI dataset, and collate it to meet the input format indicated in the documentation of transform_template(). The environment will then pass this input to your provided transform function (as separate keyword arguments), and subsequently pass the output of the transform to the reconstructor model. The complete sequence will roughly conform to the following pseudocode.

kspace, _, ground_truth, attrs, fname, slice_id = data_handler[item]
mask = get_current_active_mask()
reconstructor_input = transform(
    kspace=kspace,
    mask=mask,
    ground_truth=ground_truth,
    attrs=attrs,
    fname=fname,
    slice_id=slice_id
)
reconstructor_output = reconstructor(*reconstructor_input)

Some example transform functions are available in the repository; see also the note below.

Note

If your reconstructor only needs as input a zero-filled reconstruction (i.e., the inverse Fourier transform of the non-zero k-space columns), and perhaps a mean and standard deviation to use for normalization, then fastmri_unet_transform_singlecoil() and fastmri_unet_transform_multicoil() should be a good starting point for your own transform.
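To make this concrete, here is a minimal sketch of a transform that produces a normalized zero-filled reconstruction, along the lines of the note above. The function name, tensor shapes, and normalization details are illustrative assumptions, not part of the environment's API; adapt them to your data loader and model.

```python
import torch


def zero_filled_transform(kspace, mask, ground_truth, attrs, fname, slice_id):
    # Hypothetical transform sketch. Assumes `kspace` is a complex-valued
    # tensor of shape (batch, height, width) and that `mask` broadcasts
    # over the k-space columns; adapt shapes to your data loader's output.
    masked_kspace = kspace * mask
    # Zero-filled reconstruction: inverse FFT of the subsampled k-space.
    image = torch.fft.ifft2(masked_kspace).abs()
    # Per-slice normalization statistics, also returned so the
    # reconstructor can undo the normalization afterwards.
    mean = image.mean(dim=(-2, -1), keepdim=True)
    std = image.std(dim=(-2, -1), keepdim=True)
    normalized = (image - mean) / (std + 1e-11)
    # Returned as a tuple, which the environment unpacks as positional
    # arguments when calling the reconstructor.
    return normalized, mean, std
```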

Reconstructor interface

A reconstructor model is essentially a torch.nn.Module that must follow a few additional conventions. In terms of the class interface, besides the usual torch.nn.Module methods, the reconstructor must also include a method init_from_checkpoint, which receives a model checkpoint as a dictionary and initializes the model from its data.

import abc
from typing import Any, Dict, Optional

import torch


class Reconstructor(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        pass

    @abc.abstractmethod
    def init_from_checkpoint(self, checkpoint: Dict[str, Any]) -> Optional[Any]:
        pass

The other conventions concern the model initialization and the output format of the forward method. We explain these below.

Initializing reconstructor

The environment expects the reconstructor's initialization arguments to be passed as keywords. The set of keywords and their values is read from the environment's configuration file. However, if your checkpoint dictionary contains a key called "options", its value will take precedence over the configuration file, and the environment will emit a warning. The sequence will roughly conform to the following pseudocode:

reconstructor_cls, reconstructor_cfg_dict, checkpoint_path = read_from_env_config()
checkpoint_dict = torch.load(checkpoint_path)
# replaces reconstructor_cfg_dict if checkpoint_dict contains an "options" key
reconstructor_cfg_dict = override_if_options_key_present(
    checkpoint_dict, reconstructor_cfg_dict
)
reconstructor = reconstructor_cls(**reconstructor_cfg_dict)
reconstructor.init_from_checkpoint(checkpoint_dict)  # load weights, additional bookkeeping
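As a sketch, init_from_checkpoint can be as simple as loading a state dict and returning any stored options. The class name, the single convolution layer, and the "state_dict"/"options" checkpoint keys below are assumptions for illustration, not requirements of the interface:

```python
import torch
from typing import Any, Dict, Optional


class ToyReconstructor(torch.nn.Module):
    # Illustrative reconstructor; a single convolution stands in
    # for a real reconstruction network.
    def __init__(self, channels: int = 1):
        super().__init__()
        self.net = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def init_from_checkpoint(self, checkpoint: Dict[str, Any]) -> Optional[Any]:
        # Assumes the weights were saved under a "state_dict" key;
        # adapt this to however your checkpoint file is organized.
        self.load_state_dict(checkpoint["state_dict"])
        return checkpoint.get("options")
```

Note that the keyword arguments accepted by __init__ (here, `channels`) are exactly what the environment fills in from the configuration file.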

Forward signature

The other important convention we adopt is that forward() returns a dictionary with the output of the model. This dictionary must contain the key "reconstruction", whose value is the reconstructed image tensor. Note that the model can also return additional outputs, which will also be included in the observation returned by the environment, as explained in the basic example.
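For example, a forward() following this convention might look like the sketch below. The class and the extra "residual" key are made up for illustration; only the "reconstruction" key is mandated by the environment.

```python
import torch
from typing import Any, Dict


class TinyReconstructor(torch.nn.Module):
    # Illustrative only; a single convolution stands in for a real network.
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, image: torch.Tensor) -> Dict[str, Any]:
        reconstruction = self.net(image)
        # "reconstruction" is required by the environment; any extra keys
        # (here, a hypothetical residual map) are passed through into the
        # observation unchanged.
        return {"reconstruction": reconstruction, "residual": reconstruction - image}
```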

Examples

Some examples are available in the models directory of the repository. Note that no changes to the reconstructor model itself are required; the coupling between environment and reconstructor can be done via short wrapper classes.