Shortcuts

Source code for torchts.nn.model

from abc import abstractmethod
from functools import partial

import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, TensorDataset


class TimeSeriesModel(LightningModule):
    """Base class for all TorchTS models.

    Subclasses implement :meth:`forward`; this class supplies the shared
    training/validation/test step logic, optimizer configuration and simple
    ``fit``/``predict`` helpers.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer
        optimizer_args (dict): Arguments for the optimizer
        criterion: Loss function
        criterion_args (dict): Arguments for the loss function
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler
        scheduler_args (dict): Arguments for the scheduler
        scaler (torchts.utils.scaler.Scaler): Scaler
    """

    def __init__(
        self,
        optimizer,
        optimizer_args=None,
        criterion=F.mse_loss,
        criterion_args=None,
        scheduler=None,
        scheduler_args=None,
        scaler=None,
    ):
        super().__init__()
        self.criterion = criterion
        self.criterion_args = criterion_args
        self.scaler = scaler

        # Bind the keyword arguments now so that configure_optimizers() only
        # has to supply the model parameters / the optimizer instance.
        if optimizer_args is not None:
            self.optimizer = partial(optimizer, **optimizer_args)
        else:
            self.optimizer = optimizer

        if scheduler is not None and scheduler_args is not None:
            self.scheduler = partial(scheduler, **scheduler_args)
        else:
            self.scheduler = scheduler

    def fit(self, x, y, max_epochs=10, batch_size=128):
        """Fits model to the given data.

        Wraps ``x`` and ``y`` in a shuffled ``DataLoader`` and trains with a
        freshly created ``Trainer``.

        Args:
            x (torch.Tensor): Input data
            y (torch.Tensor): Output data
            max_epochs (int): Number of training epochs
            batch_size (int): Batch size for torch.utils.data.DataLoader
        """
        dataset = TensorDataset(x, y)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        trainer = Trainer(max_epochs=max_epochs)
        trainer.fit(self, loader)

    def prepare_batch(self, batch):
        """Returns the batch unchanged; hook for subclasses to override.

        Args:
            batch: Output of the torch.utils.data.DataLoader

        Returns:
            The ``(x, y)`` pair consumed by :meth:`_step`
        """
        return batch

    def _step(self, batch, batch_idx, num_batches=None):
        """Runs the forward pass on one batch and computes its loss.

        Args:
            batch: Output of the torch.utils.data.DataLoader
            batch_idx: Integer displaying index of this batch
            num_batches: Number of batches per epoch. Only used while the
                module is in training mode, to offset ``batch_idx`` by the
                batches of the epochs already completed. May be ``None``.

        Returns:
            loss for the batch
        """
        x, y = self.prepare_batch(batch)

        # The running batch counter spans epochs only during training; in
        # evaluation (or when no batch count is known) it is the plain index.
        if self.training and num_batches is not None:
            batches_seen = batch_idx + self.current_epoch * num_batches
        else:
            batches_seen = batch_idx

        pred = self(x, y, batches_seen)

        # Compute the loss on inverse-transformed values when a scaler is set.
        if self.scaler is not None:
            y = self.scaler.inverse_transform(y)
            pred = self.scaler.inverse_transform(pred)

        criterion_args = self.criterion_args or {}
        return self.criterion(pred, y, **criterion_args)

    def training_step(self, batch, batch_idx):
        """Trains model for one step.

        Args:
            batch (torch.Tensor): Output of the torch.utils.data.DataLoader
            batch_idx (int): Integer displaying index of this batch

        Returns:
            Training loss for the batch, also logged as ``train_loss``
        """
        train_loss = self._step(batch, batch_idx, len(self.trainer.train_dataloader))
        self.log(
            "train_loss",
            train_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Validates model for one step.

        Args:
            batch (torch.Tensor): Output of the torch.utils.data.DataLoader
            batch_idx (int): Integer displaying index of this batch

        Returns:
            Validation loss for the batch, also logged as ``val_loss``
        """
        # The per-epoch batch count only matters in training mode, so none is
        # passed here. NOTE(review): the Trainer documents ``val_dataloaders``
        # (plural); the former ``self.trainer.val_dataloader`` lookup could
        # raise AttributeError for a value that was never used.
        val_loss = self._step(batch, batch_idx, None)
        self.log("val_loss", val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Tests model for one step.

        Args:
            batch (torch.Tensor): Output of the torch.utils.data.DataLoader
            batch_idx (int): Integer displaying index of this batch

        Returns:
            Test loss for the batch, also logged as ``test_loss``
        """
        # See validation_step: the batch count is unused outside training.
        # NOTE(review): the Trainer documents ``test_dataloaders`` (plural).
        test_loss = self._step(batch, batch_idx, None)
        self.log("test_loss", test_loss)
        return test_loss

    @abstractmethod
    def forward(self, x, y=None, batches_seen=None):
        """Forward pass.

        Args:
            x (torch.Tensor): Input data
            y (torch.Tensor): Output data; passed by :meth:`_step`
            batches_seen (int): Batch counter computed by :meth:`_step`

        Returns:
            torch.Tensor: Predicted data
        """

    def predict(self, x):
        """Runs model inference.

        Calls the model on ``x`` only (no ``y`` or ``batches_seen``) and
        detaches the result from the autograd graph.

        Args:
            x (torch.Tensor): Input data

        Returns:
            torch.Tensor: Predicted data
        """
        return self(x).detach()

    def configure_optimizers(self):
        """Configure optimizer.

        Returns:
            torch.optim.Optimizer: Optimizer, when no scheduler is set;
            otherwise a ``([optimizer], [scheduler])`` pair of lists.
        """
        optimizer = self.optimizer(self.parameters())

        if self.scheduler is not None:
            scheduler = self.scheduler(optimizer)
            return [optimizer], [scheduler]

        return optimizer

Was this helpful?
Thank you

© Copyright 2020-present UCSD STL Lab.

Built with Sphinx using a theme provided by Read the Docs.