models
Useful model tools
models.loss
Several loss functions and a high-level loss function getter.
WeightedLosses Objects
class WeightedLosses()
Class: Weighted loss depending on the forecast horizon.
__init__
def __init__(decay_rate: Optional[int] = None, forecast_length: int = 6)
Want to set up the MSE loss function so the weights only have to be calculated once.
Arguments:
decay_rate
- The weights decay exponentially at the rate given by 'decay_rate'.
forecast_length
- The forecast length, needed to make sure the weights sum to 1.
get_mse_exp
def get_mse_exp(output, target)
Loss function weighted MSE
get_mae_exp
def get_mae_exp(output, target)
Loss function weighted MAE
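As a rough illustration of the weighting described above, the sketch below builds exponentially decaying weights that sum to 1 and applies them to MSE and MAE. It uses NumPy instead of torch. The decay formula `exp(-decay_rate * step)` and all helper names are assumptions for illustration, not the library's implementation.

```python
import numpy as np


def exp_decay_weights(decay_rate: float, forecast_length: int = 6) -> np.ndarray:
    """Hypothetical helper: weights exp(-decay_rate * step), normalized to sum to 1."""
    steps = np.arange(forecast_length)
    weights = np.exp(-decay_rate * steps)
    return weights / weights.sum()


def weighted_mse(output: np.ndarray, target: np.ndarray, weights: np.ndarray) -> float:
    """Weighted MSE for arrays of shape (batch_size, forecast_length)."""
    per_sample = ((output - target) ** 2 * weights).sum(axis=1)
    return float(per_sample.mean())


def weighted_mae(output: np.ndarray, target: np.ndarray, weights: np.ndarray) -> float:
    """Weighted MAE for arrays of shape (batch_size, forecast_length)."""
    per_sample = (np.abs(output - target) * weights).sum(axis=1)
    return float(per_sample.mean())


# The weights are computed once and reused for every batch.
weights = exp_decay_weights(decay_rate=np.log(2), forecast_length=6)
target = np.zeros((4, 6))
output = np.ones((4, 6))
mse_value = weighted_mse(output, target, weights)
mae_value = weighted_mae(output, target, weights)
```

Because the weights sum to 1, a constant error of 1 at every horizon gives a weighted loss of 1, while errors at early horizons count more than errors at late ones.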
GradientDifferenceLoss Objects
class GradientDifferenceLoss(nn.Module)
Gradient Difference Loss that penalizes blurry images more than MSE.
__init__
def __init__(alpha: int = 2)
Initialize the loss class.
Arguments:
alpha
- #TODO
forward
def forward(x: torch.Tensor, y: torch.Tensor)
Calculate the Gradient Difference Loss.
Arguments:
x
- vector one
y
- vector two
Returns:
the Gradient Difference Loss value
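The idea behind a gradient difference loss can be sketched as follows: compare the absolute finite differences of the two images along height and width, and penalize their mismatch. This NumPy sketch follows the general form from Mathieu et al.; treating `alpha` as the exponent and averaging rather than summing are assumptions, since the source leaves `alpha` undocumented.

```python
import numpy as np


def gradient_difference_loss_sketch(x: np.ndarray, y: np.ndarray, alpha: int = 2) -> float:
    """Hypothetical NumPy version for arrays shaped [..., height, width]."""
    # Absolute finite differences along the height axis.
    x_dy = np.abs(np.diff(x, axis=-2))
    y_dy = np.abs(np.diff(y, axis=-2))
    # Absolute finite differences along the width axis.
    x_dx = np.abs(np.diff(x, axis=-1))
    y_dx = np.abs(np.diff(y, axis=-1))
    # Penalize mismatched edge strength (assumption: alpha is the exponent).
    loss_y = np.abs(x_dy - y_dy) ** alpha
    loss_x = np.abs(x_dx - y_dx) ** alpha
    return float(loss_y.mean() + loss_x.mean())


# A sharp vertical edge compared with a flat (fully blurred) image.
sharp = np.array([[[[0.0, 1.0], [0.0, 1.0]]]])
flat = np.zeros_like(sharp)
gdl_value = gradient_difference_loss_sketch(sharp, flat)
```

A flat prediction has no edges, so it is penalized wherever the target has them, which is what discourages blurry outputs.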
GridCellLoss Objects
class GridCellLoss(nn.Module)
Grid Cell Regularizer loss from Skillful Nowcasting,
see https://arxiv.org/pdf/2104.00954.pdf.
__init__
def __init__(weight_fn=None)
Initialize the model.
Arguments:
weight_fn
- the weight function to be called when #TODO?
forward
def forward(generated_images, targets)
Calculates the grid cell regularizer value.
This assumes generated images are the mean predictions from 6 calls to the generator (Monte Carlo estimation of the expectations for the latent variable).
Arguments:
generated_images
- Mean generated images from the generator
targets
- Ground truth future frames
Returns:
Grid Cell Regularizer term
NowcastingLoss Objects
class NowcastingLoss(nn.Module)
Loss described in Skillful-Nowcasting GAN, see https://arxiv.org/pdf/2104.00954.pdf.
__init__
def __init__()
Initialize the model.
forward
def forward(x, real_flag)
Forward step.
Arguments:
x
- the data to work with
real_flag
- boolean indicating whether the data is real or not
Returns:
#TODO
get_loss
def get_loss(loss: str = "mse", **kwargs) -> torch.nn.Module
Function to get different losses easily.
Arguments:
loss
- name of the loss, or a torch.nn.Module; if a Module, that Module is returned
**kwargs
- kwargs to pass to the loss function
Returns:
torch.nn.Module
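A name-based getter of this kind can be sketched with a small registry. The example below is pure Python with toy loss classes. The registry contents, the `scale` keyword, and all names are hypothetical and are not the library's actual set of losses.

```python
from typing import Callable, Dict, Sequence, Union


class _PointwiseLoss:
    """Toy base class: averages a per-element error and applies a scale."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def _error(self, output: float, target: float) -> float:
        raise NotImplementedError

    def __call__(self, output: Sequence[float], target: Sequence[float]) -> float:
        errors = [self._error(o, t) for o, t in zip(output, target)]
        return self.scale * sum(errors) / len(errors)


class ToyMSE(_PointwiseLoss):
    def _error(self, output: float, target: float) -> float:
        return (output - target) ** 2


class ToyMAE(_PointwiseLoss):
    def _error(self, output: float, target: float) -> float:
        return abs(output - target)


# Hypothetical registry mapping loss names to loss classes.
LOSS_REGISTRY: Dict[str, Callable[..., _PointwiseLoss]] = {"mse": ToyMSE, "mae": ToyMAE}


def get_loss_sketch(loss: Union[str, Callable] = "mse", **kwargs) -> Callable:
    """Return the loss as-is if it is already an object, else build it by name."""
    if not isinstance(loss, str):
        return loss
    key = loss.lower()
    if key not in LOSS_REGISTRY:
        raise ValueError(f"Unknown loss: {loss!r}")
    return LOSS_REGISTRY[key](**kwargs)


mse_fn = get_loss_sketch("mse")
scaled_mae_fn = get_loss_sketch("mae", scale=2.0)
```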
models.metrics
Metrics used for different forecast horizons
mse_each_forecast_horizon
def mse_each_forecast_horizon(output: torch.Tensor,
target: torch.Tensor) -> torch.Tensor
Get MSE for each forecast horizon
Arguments:
output
- The model estimate of size (batch_size, forecast_length)
target
- The truth of size (batch_size, forecast_length)
Returns:
A tensor of size (forecast_length)
mae_each_forecast_horizon
def mae_each_forecast_horizon(output: torch.Tensor,
target: torch.Tensor) -> torch.Tensor
Get MAE for each forecast horizon
Arguments:
output
- The model estimate of size (batch_size, forecast_length)
target
- The truth of size (batch_size, forecast_length)
Returns:
A tensor of size (forecast_length)
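Both metrics reduce over the batch dimension only, so one value remains per forecast horizon. A minimal NumPy sketch of that reduction is shown below; the function names are hypothetical and the library itself operates on torch tensors.

```python
import numpy as np


def mse_per_horizon_sketch(output: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Mean squared error over the batch axis, one value per forecast horizon."""
    return ((output - target) ** 2).mean(axis=0)


def mae_per_horizon_sketch(output: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Mean absolute error over the batch axis, one value per forecast horizon."""
    return np.abs(output - target).mean(axis=0)


# batch_size = 2, forecast_length = 3
output = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
target = np.ones((2, 3))
mse_values = mse_per_horizon_sketch(output, target)
mae_values = mae_per_horizon_sketch(output, target)
```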
models.hub
Originally taken from https://github.com/rwightman/pytorch-image-models/blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
get_cache_dir
def get_cache_dir(child_dir="")
Returns the location of the directory where models are cached (and creates it if necessary).
has_hf_hub
def has_hf_hub(necessary: bool = False) -> bool
Determines if HuggingFace hub is available
Arguments:
necessary
- Whether having HuggingFace access is required
Returns:
Whether HuggingFace is available
hf_split
def hf_split(hf_id: str) -> Tuple[str, str]
Splits the string of the HuggingFace ID to give the model ID and revision
Arguments:
hf_id
- ID for HuggingFace
Returns:
Tuple consisting of the HuggingFace model ID and the revision
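One way such a split could work is sketched below. The `@` separator between model ID and revision is an assumption borrowed from the timm convention, and returning `None` when no revision is given is also an assumption; the source does not state either.

```python
from typing import Optional, Tuple


def hf_split_sketch(hf_id: str) -> Tuple[str, Optional[str]]:
    """Split 'model_id@revision' into its two parts (assumed '@' separator)."""
    parts = hf_id.split("@")
    if len(parts) > 2:
        raise ValueError("hf_id should contain at most one '@'")
    model_id = parts[0]
    # Assumption: no '@' means no explicit revision was requested.
    revision = parts[1] if len(parts) == 2 else None
    return model_id, revision


with_revision = hf_split_sketch("username/mymodel@v1")
without_revision = hf_split_sketch("username/mymodel")
```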
load_cfg_from_json
def load_cfg_from_json(json_file: Union[str, os.PathLike]) -> dict
Load the configuration from a JSON file
Arguments:
json_file
- The JSON file which contains the configuration file
Returns:
Dictionary containing the model configuration
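Loading a configuration this way needs only the standard library. The sketch below is a hypothetical stand-in, not the library's code, and the configuration keys in the round trip are made up for illustration.

```python
import json
import os
import tempfile
from typing import Union


def load_cfg_from_json_sketch(json_file: Union[str, os.PathLike]) -> dict:
    """Read a JSON file and return its contents as a dictionary."""
    with open(json_file, "r", encoding="utf-8") as reader:
        return json.load(reader)


# Round trip through a temporary file (keys are illustrative only).
example_cfg = {"input_channels": 12, "forecast_steps": 48}
with tempfile.TemporaryDirectory() as tmp_dir:
    cfg_path = os.path.join(tmp_dir, "config.json")
    with open(cfg_path, "w", encoding="utf-8") as writer:
        json.dump(example_cfg, writer)
    loaded_cfg = load_cfg_from_json_sketch(cfg_path)
```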
load_model_config_from_hf
def load_model_config_from_hf(model_id: str) -> Tuple[dict, str]
Downloads and loads the model configuration from HuggingFace
Arguments:
model_id
- The HuggingFace model ID
Returns:
A tuple consisting of the default configuration for the model as well as the model name
load_state_dict_from_hf
def load_state_dict_from_hf(model_id: str) -> dict
Load the state dict of the model from HuggingFace.
Arguments:
model_id
- The HuggingFace model ID
Returns:
The model's state_dict
cache_file_from_hf
def cache_file_from_hf(model_id: str) -> Union[str, os.PathLike]
Caches the model from HuggingFace and returns the path
Arguments:
model_id
- HuggingFace model ID
Returns:
The path to the cached file
load_pretrained
def load_pretrained(
model,
default_cfg: Optional[dict] = None,
in_chans: int = 12,
strict: bool = True
) -> Union[torch.nn.Module, pytorch_lightning.LightningModule]
Load pretrained checkpoint
Taken from https://github.com/rwightman/pytorch-image-models/blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/helpers.py
Arguments:
model
nn.Module - PyTorch model module, or LightningModule
default_cfg
Optional[Dict] - default configuration for pretrained weights / target dataset
in_chans
int - in_chans for model
strict
bool - strict load of checkpoint
NowcastingModelHubMixin Objects
class NowcastingModelHubMixin(ModelHubMixin)
HuggingFace ModelHubMixin containing specific adaptations for Nowcasting models
__init__
def __init__(*args, **kwargs)
Mixin for pl.LightningModule and Hugging Face
Mix this class with your pl.LightningModule class to easily push / download the model via the Hugging Face Hub
Example::
    from torch import nn
    from nowcasting_utils.models.hub import NowcastingModelHubMixin

    class MyModel(nn.Module, NowcastingModelHubMixin):
        def __init__(self, **kwargs):
            super().__init__()
            self.layer = ...
        def forward(self, ...):
            return ...

    model = MyModel()
    model.push_to_hub("mymodel")  # Pushing model weights to hf-hub
    # Downloading weights from hf-hub; the model will be initialized from those weights
    model = MyModel.from_pretrained("username/mymodel")
models.normalization
File containing different normalization methods
metnet_normalization
def metnet_normalization(data: np.ndarray) -> np.ndarray
Perform normalization from the MetNet paper
This involves subtracting the median, dividing by the interquartile range, then squashing to [-1,1] with the hyperbolic tangent.
Arguments:
data
- input image data
Returns:
Normalized image data
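The three steps above can be sketched in NumPy as follows. Computing the median and interquartile range over the whole array (rather than per channel) is an assumption, and the function name is hypothetical.

```python
import numpy as np


def metnet_normalization_sketch(data: np.ndarray) -> np.ndarray:
    """Subtract the median, divide by the IQR, then squash with tanh."""
    median = np.median(data)
    q75, q25 = np.percentile(data, [75, 25])
    iqr = q75 - q25  # assumed non-zero; constant inputs would need a guard
    return np.tanh((data - median) / iqr)


data = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
normalized = metnet_normalization_sketch(data)
```

For this input the median is 2 and the interquartile range is 2, so the values passed to `tanh` are -1, -0.5, 0, 0.5 and 1.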
standard_normalization
def standard_normalization(data: np.ndarray, std: np.ndarray,
mean: np.ndarray) -> np.ndarray
Performs standard normalization to get values with a mean of 0 and standard deviation of 1
Arguments:
data
- The data to normalize
std
- Standard deviation of each channel
mean
- Mean of each channel
Returns:
The normalized data
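A minimal sketch of per-channel standardization is shown below. It assumes channels-first data of shape (channels, height, width) so that the per-channel statistics broadcast correctly; the source does not state the layout, and the function name is hypothetical.

```python
import numpy as np


def standard_normalization_sketch(
    data: np.ndarray, std: np.ndarray, mean: np.ndarray
) -> np.ndarray:
    """Standardize each channel: (data - mean) / std."""
    # Assumption: data is (channels, height, width); reshape stats to broadcast.
    mean = np.asarray(mean).reshape(-1, 1, 1)
    std = np.asarray(std).reshape(-1, 1, 1)
    return (data - mean) / std


# Two channels on very different scales.
data = np.array(
    [
        [[1.0, 3.0], [1.0, 3.0]],
        [[10.0, 30.0], [10.0, 30.0]],
    ]
)
channel_mean = data.mean(axis=(1, 2))
channel_std = data.std(axis=(1, 2))
normalized = standard_normalization_sketch(data, channel_std, channel_mean)
```

After normalization both channels hold the same values, since each has been shifted to mean 0 and scaled to standard deviation 1.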
models.base
Base model class for all ML models.
Useful things like:
- Same validation set
- Interface with HuggingFace
register_model
def register_model(cls: Type[pl.LightningModule])
Register model
Arguments:
cls
- the model to be registered
Returns:
the registered model
get_model
def get_model(name: str) -> Type[pl.LightningModule]
Get model from registered models
list_models
def list_models()
List of the registered models
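The three registry functions above fit a common decorator pattern, sketched here in pure Python. Keying the registry by the lowercased class name is an assumption, and all names in the block are hypothetical stand-ins for the library's functions.

```python
from typing import Dict, List, Type

# Hypothetical module-level registry: lowercased class name -> class.
_MODEL_REGISTRY: Dict[str, Type] = {}


def register_model_sketch(cls: Type) -> Type:
    """Class decorator that records the class and returns it unchanged."""
    _MODEL_REGISTRY[cls.__name__.lower()] = cls
    return cls


def get_model_sketch(name: str) -> Type:
    """Look up a registered class by (case-insensitive) name."""
    key = name.lower()
    if key not in _MODEL_REGISTRY:
        raise KeyError(f"Model {name!r} is not registered")
    return _MODEL_REGISTRY[key]


def list_models_sketch() -> List[str]:
    """Return the names of all registered models."""
    return sorted(_MODEL_REGISTRY)


@register_model_sketch
class ToyModel:
    """Placeholder standing in for a pl.LightningModule subclass."""


found = get_model_sketch("ToyModel")
names = list_models_sketch()
```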
split_model_name
def split_model_name(model_name)
Split model name with ':'
Arguments:
model_name
- the original model name
Returns:
source name, and the model name
safe_model_name
def safe_model_name(model_name, remove_source=True)
Make a safe model name
Arguments:
model_name
- the original model name
remove_source
- whether to remove the source from the name
Returns:
the new model name
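The two name helpers can be sketched together. The version below follows the approach used in timm, which is an assumption here: an empty source is returned when no ':' is present, and every non-alphanumeric character is replaced with an underscore. The function names are hypothetical.

```python
from typing import Tuple


def split_model_name_sketch(model_name: str) -> Tuple[str, str]:
    """Split 'source:name' on the first ':'; the source is '' if absent."""
    parts = model_name.split(":", 1)
    if len(parts) == 1:
        return "", parts[0]
    return parts[0], parts[1]


def safe_model_name_sketch(model_name: str, remove_source: bool = True) -> str:
    """Replace non-alphanumeric characters so the name is filesystem friendly."""
    if remove_source:
        model_name = split_model_name_sketch(model_name)[1]
    safe = "".join(c if c.isalnum() else "_" for c in model_name)
    return safe.rstrip("_")


source, name = split_model_name_sketch("hf_hub:user/my-model")
safe_short = safe_model_name_sketch("hf_hub:user/my-model")
safe_full = safe_model_name_sketch("hf_hub:user/my-model", remove_source=False)
```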
create_model
def create_model(model_name, pretrained=False, checkpoint_path=None, **kwargs)
Create a model
Almost entirely taken from timm https://github.com/rwightman/pytorch-image-models
Arguments:
model_name
str - name of model to instantiate
pretrained
bool - load pretrained ImageNet-1k weights if true
checkpoint_path
str - path of checkpoint to load after model is initialized
Keyword Arguments:
drop_rate
float - dropout rate for training (default: 0.0)
global_pool
str - global pool type (default: 'avg')
input_channels
int - number of input channels (default: 12)
forecast_steps
int - number of steps to forecast (default: 48)
lr
float - learning rate (default: 0.001)
**
- other kwargs are model specific
BaseModel Objects
class BaseModel(pl.LightningModule, NowcastingModelHubMixin)
Base Model for ML models
__init__
def __init__(pretrained: bool = False,
forecast_steps: int = 48,
input_channels: int = 12,
output_channels: int = 12,
lr: float = 0.001,
visualize: bool = False)
Setup the base model class.
Arguments:
pretrained
- flag indicating whether the model is pretrained
forecast_steps
- the number of forecast steps
input_channels
- the number of input channels
output_channels
- the number of output channels
lr
- the learning rate
visualize
- whether to visualize the results
from_config
@classmethod
def from_config(cls, config)
Get the model from a config file.
Arguments:
config
- config file
Returns:
Error, as the model needs to implement this method
training_step
def training_step(batch, batch_idx)
The training step.
Arguments:
batch
- the batch data
batch_idx
- the batch index
Returns:
The model outputs
validation_step
def validation_step(batch, batch_idx)
Validation step
Arguments:
batch
- the batch data
batch_idx
- the batch index
Returns:
The model outputs
forward
def forward(x, **kwargs) -> Any
Forward method for the model.
Arguments:
x
- the input data
**kwargs
- other input needed
Returns:
the model outputs
visualize_step
def visualize_step(x: torch.Tensor, y: torch.Tensor, y_hat: torch.Tensor,
batch_idx: int, step: str) -> None
Visualization Step
Arguments:
x
- input data
y
- the truth
y_hat
- the predictions
batch_idx
- what batch index this is
step
- what step number this is
models.losses
Many Loss functions to be used across different models
models.losses.TotalVariationLoss
Implementation of Total Variation Loss
(https://en.wikipedia.org/wiki/Total_variation_denoising), copied and slightly modified from the original by the traiNNer Authors (Apache License 2.0): https://github.com/victorca25/traiNNer/tree/master
Copyright 2021 traiNNer Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
get_outnorm
def get_outnorm(x: torch.Tensor, out_norm: str = "") -> torch.Tensor
Common function to get a loss normalization value.
Can normalize by:
- the batch size ('b'),
- the number of channels ('c'),
- the image size ('i'),
- or combinations ('bi', 'bci', etc.)
Arguments:
x
- the tensor to be normalized
out_norm
- the string dimension to be normalized
Returns:
the normalized tensor
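How the `out_norm` string could translate into a normalization factor is sketched below. The function works on a shape tuple rather than a tensor and returns a multiplicative factor; both choices, and the function name, are assumptions for illustration and may differ from the library.

```python
from typing import Sequence


def outnorm_factor_sketch(shape: Sequence[int], out_norm: str = "") -> float:
    """Return 1 / (product of the selected dimensions) for a [b, c, h, w] shape."""
    norm = 1.0
    if "b" in out_norm:
        norm /= shape[0]  # batch size
    if "c" in out_norm:
        norm /= shape[-3]  # number of channels
    if "i" in out_norm:
        norm /= shape[-1] * shape[-2]  # image size (width * height)
    return norm


shape = (4, 3, 8, 8)
no_norm = outnorm_factor_sketch(shape)
batch_norm = outnorm_factor_sketch(shape, "b")
full_norm = outnorm_factor_sketch(shape, "bci")
```

A summed loss multiplied by `full_norm` becomes a per-element average, which matches the note in TVLoss that 'sum' is typically paired with 'bci'.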
get_4dim_image_gradients
def get_4dim_image_gradients(image: torch.Tensor)
Returns image gradients (dy, dx, dp, dn) for each color channel
This uses the finite-difference approximation. Similar to get_image_gradients(), but additionally calculates the gradients in the two diagonal directions: 'dp' (the positive diagonal: bottom left to top right) and 'dn' (the negative diagonal: top left to bottom right). Only 1-step finite difference has been tested and is available.
Arguments:
image
- Tensor with shape [b, c, h, w].
Returns:
tensors (dy, dx, dp, dn) holding the vertical, horizontal and diagonal image gradients (1-step finite difference). dx will always have zeros in the last column, dy will always have zeros in the last row, dp will always have zeros in the last row.
get_image_gradients
def get_image_gradients(image: torch.Tensor, step: int = 1)
Returns image gradients (dy, dx) for each color channel.
This uses the finite-difference approximation. Places the gradient [i.e. I(x+1,y) - I(x,y)] on the base pixel (x, y). Both output tensors have the same shape as the input: [b, c, h, w].
Arguments:
image
- Tensor with shape [b, c, h, w].
step
- the size of the step for the finite difference
Returns:
Pair of tensors (dy, dx) holding the vertical and horizontal image gradients (i.e. 1-step finite difference). To match the original image size, for example with step=1, dy will always have zeros in the last row, and dx will always have zeros in the last column.
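The forward finite difference with zero padding described above can be sketched in NumPy. This is a hypothetical re-implementation for illustration, not the library's torch code.

```python
from typing import Tuple

import numpy as np


def image_gradients_sketch(image: np.ndarray, step: int = 1) -> Tuple[np.ndarray, np.ndarray]:
    """Forward differences placed on the base pixel, zero-padded at the far edge."""
    dy = np.zeros_like(image)
    dx = np.zeros_like(image)
    # Vertical gradient: I(row + step) - I(row); the last `step` rows stay zero.
    dy[..., :-step, :] = image[..., step:, :] - image[..., :-step, :]
    # Horizontal gradient: I(col + step) - I(col); the last `step` columns stay zero.
    dx[..., :, :-step] = image[..., :, step:] - image[..., :, :-step]
    return dy, dx


# A [b, c, h, w] ramp: values grow by 4 per row and by 1 per column.
image = np.arange(16, dtype=float).reshape(1, 1, 4, 4)
dy, dx = image_gradients_sketch(image)
```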
TVLoss Objects
class TVLoss(nn.Module)
Calculate the L1 or L2 total variation regularization.
Also can calculate experimental 4D directional total variation. Ref: Mahendran et al. https://arxiv.org/pdf/1412.0035.pdf
__init__
def __init__(tv_type: str = "tv",
p=2,
reduction: str = "mean",
out_norm: str = "b",
beta: int = 2) -> None
Init
Arguments:
tv_type
- regular 'tv' or 4D 'dtv'
p
- use the absolute values '1' or Euclidean distance '2' to calculate the tv. (alt names: 'l1' and 'l2')
reduction
- aggregate results per image either by their 'mean' or by the total 'sum'. Note: typically, 'sum' should be normalized with out_norm: 'bci', while 'mean' needs only 'b'.
out_norm
- normalizes the TV loss by either the batch size ('b'), the number of channels ('c'), the image size ('i') or combinations ('bi', 'bci', etc).
beta
- β factor to control the balance between sharp edges (1<β<2) and washed out results (penalizing edges) with β >= 2.
forward
def forward(x: torch.Tensor) -> torch.Tensor
Forward method
Arguments:
x
- data
Returns:
the total variation loss value
models.losses.FocalLoss
Focal Loss - https://arxiv.org/abs/1708.02002
FocalLoss Objects
class FocalLoss(nn.Module)
Focal Loss
__init__
def __init__(gamma: Union[int, float, List] = 0,
alpha: Optional[Union[int, float, List]] = None,
size_average: bool = True)
Focal loss is described in https://arxiv.org/abs/1708.02002
Copied from: https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
Courtesy of carwin, MIT License
Arguments:
alpha
- (tensor, float, or list of floats) The scalar factor for this criterion
gamma
- (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5), putting more focus on hard, misclassified examples
size_average
- (bool, optional) By default, the losses are averaged over each loss element in the batch.
forward
def forward(x: torch.Tensor, target: torch.Tensor)
Forward method
Arguments:
x
- prediction
target
- truth
Returns:
loss value
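The effect of `gamma` and `alpha` can be illustrated with a small NumPy sketch of the focal loss from the paper linked above, FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t). Treating the input as (N, C) logits with integer class targets and `alpha` as per-class weights are assumptions; the function name is hypothetical.

```python
import numpy as np


def focal_loss_sketch(logits, target, gamma=0.0, alpha=None, size_average=True):
    """Multi-class focal loss for logits (N, C) and integer targets (N,)."""
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability and probability of the true class.
    logpt = log_probs[np.arange(len(target)), target]
    pt = np.exp(logpt)
    # Down-weight well-classified examples by (1 - pt) ** gamma.
    loss = -((1.0 - pt) ** gamma) * logpt
    if alpha is not None:
        # Assumption: alpha holds one weight per class.
        loss = np.asarray(alpha)[target] * loss
    return float(loss.mean() if size_average else loss.sum())


logits = np.array([[0.0, 0.0]])  # both classes equally likely, so pt = 0.5
target = np.array([0])
cross_entropy = focal_loss_sketch(logits, target, gamma=0.0)
focal = focal_loss_sketch(logits, target, gamma=2.0)
```

With gamma = 0 the expression reduces to ordinary cross-entropy; with gamma = 2 the same example is scaled by (1 - 0.5)^2 = 0.25.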
models.losses.StructuralSimilarity
This file contains various versions of losses using the Structural Similarity Index Measure.
See https://en.wikipedia.org/wiki/Structural_similarity for more details
SSIMLoss Objects
class SSIMLoss(nn.Module)
SSIM Loss, optionally converting input range from [-1,1] to [0,1]
__init__
def __init__(convert_range: bool = False, **kwargs)
Init
Arguments:
convert_range
- Convert input from -1,1 to 0,1 range
**kwargs
- Kwargs to pass through to SSIM
forward
def forward(x: torch.Tensor, y: torch.Tensor)
Forward method
Arguments:
x
- one tensor
y
- second tensor
Returns:
SSIM loss
MS_SSIMLoss Objects
class MS_SSIMLoss(nn.Module)
Multi-Scale SSIM Loss, optionally converting input range from [-1,1] to [0,1]
__init__
def __init__(convert_range: bool = False, **kwargs)
Initialize
Arguments:
convert_range
- Convert input from -1,1 to 0,1 range
**kwargs
- Kwargs to pass through to MS_SSIM
forward
def forward(x: torch.Tensor, y: torch.Tensor)
Forward method
Arguments:
x
- tensor one
y
- tensor two
Returns:
MS-SSIM loss
SSIMLossDynamic Objects
class SSIMLossDynamic(nn.Module)
SSIM Loss on only dynamic part of the images
Optionally converting input range from [-1,1] to [0,1]
In Mathieu et al., to stop SSIM from regressing towards the mean and predicting only the background, SSIM is run only on the dynamic parts of the image. We can accomplish that by subtracting the current image from the future ones.
__init__
def __init__(convert_range: bool = False, **kwargs)
Initialize
Arguments:
convert_range
- Whether to convert from -1,1 to 0,1 as required for SSIM
**kwargs
- Kwargs for the ssim_module
forward
def forward(current_image: torch.Tensor, x: torch.Tensor, y: torch.Tensor)
Forward method
Arguments:
current_image
- The last 'real' image given to the model
x
- The target future sequence
y
- The predicted future sequence
Returns:
The SSIM loss computed only for the parts of the image that have changed
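The "dynamic part" idea can be sketched as follows: optionally map inputs from [-1,1] to [0,1], subtract the current image from both sequences, and compute 1 - SSIM on the differences. To stay self-contained, this NumPy sketch uses a simplified SSIM computed from global statistics instead of the usual sliding window. That simplification and all function names are assumptions, not the library's ssim_module.

```python
import numpy as np


def global_ssim_sketch(x: np.ndarray, y: np.ndarray, data_range: float = 1.0) -> float:
    """Simplified SSIM from whole-array statistics (no sliding window)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = ((x - mx) * (y - my)).mean()
    numerator = (2 * mx * my + c1) * (2 * cov + c2)
    denominator = (mx**2 + my**2 + c1) * (vx + vy + c2)
    return float(numerator / denominator)


def ssim_loss_dynamic_sketch(current_image, x, y, convert_range=False) -> float:
    """1 - SSIM computed on the change relative to the current image."""
    if convert_range:
        # Map [-1, 1] to [0, 1].
        current_image, x, y = [(t + 1) / 2 for t in (current_image, x, y)]
    # Keep only what changed since the last real image.
    x_dynamic = x - current_image
    y_dynamic = y - current_image
    return 1.0 - global_ssim_sketch(x_dynamic, y_dynamic)


current = np.zeros((1, 1, 4, 4))
future = np.arange(16, dtype=float).reshape(1, 1, 4, 4) / 15.0
perfect_loss = ssim_loss_dynamic_sketch(current, future, future)
static_loss = ssim_loss_dynamic_sketch(current, future, current)
```

A prediction that simply repeats the current image has no dynamic content after the subtraction, so it scores poorly even if the background matches.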