Source code for torchts.nn.model
from abc import abstractmethod
from functools import partial

import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, TensorDataset
class TimeSeriesModel(LightningModule):
    """Base class for all TorchTS models.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer
        optimizer_args (dict): Arguments for the optimizer
        criterion: Loss function
        criterion_args (dict): Arguments for the loss function
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler
        scheduler_args (dict): Arguments for the scheduler
        scaler (torchts.utils.scaler.Scaler): Scaler
    """
    def __init__(
        self,
        optimizer,
        optimizer_args=None,
        criterion=F.mse_loss,
        criterion_args=None,
        scheduler=None,
        scheduler_args=None,
        scaler=None,
    ):
        super().__init__()
        self.criterion = criterion
        self.criterion_args = criterion_args
        self.scaler = scaler

        if optimizer_args is not None:
            self.optimizer = partial(optimizer, **optimizer_args)
        else:
            self.optimizer = optimizer

        if scheduler is not None and scheduler_args is not None:
            self.scheduler = partial(scheduler, **scheduler_args)
        else:
            self.scheduler = scheduler
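    # The optimizer (and optional scheduler) is stored as a class, pre-bound to
    # its keyword arguments via functools.partial when arguments are given. A
    # hedged sketch of how such a partial is typically consumed in a Lightning
    # optimizer hook (``configure_optimizers`` is not shown in this listing and
    # is only an illustration here):
    #
    #     def configure_optimizers(self):
    #         optimizer = self.optimizer(self.parameters())
    #         if self.scheduler is not None:
    #             return [optimizer], [self.scheduler(optimizer)]
    #         return optimizer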
    def fit(self, x, y, max_epochs=10, batch_size=128):
        """Fits model to the given data.

        Args:
            x (torch.Tensor): Input data
            y (torch.Tensor): Output data
            max_epochs (int): Number of training epochs
            batch_size (int): Batch size for torch.utils.data.DataLoader
        """
        dataset = TensorDataset(x, y)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        trainer = Trainer(max_epochs=max_epochs)
        trainer.fit(self, loader)
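    # Usage sketch: ``fit`` wraps the tensors in a TensorDataset/DataLoader and
    # runs a fresh Trainer. ``MyModel`` below is a hypothetical concrete
    # subclass, not part of this module:
    #
    #     x = torch.rand(1000, 10)
    #     y = torch.rand(1000, 1)
    #     model = MyModel(optimizer=torch.optim.Adam,
    #                     optimizer_args={"lr": 1e-3})
    #     model.fit(x, y, max_epochs=5, batch_size=64)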
    def prepare_batch(self, batch):
        return batch

    def _step(self, batch, batch_idx, num_batches):
        """Runs a single training/validation/test step.

        Args:
            batch: Output of the torch.utils.data.DataLoader
            batch_idx: Index of this batch
            num_batches: Number of batches per epoch

        Returns:
            Loss for the batch
        """
        x, y = self.prepare_batch(batch)

        if self.training:
            batches_seen = batch_idx + self.current_epoch * num_batches
        else:
            batches_seen = batch_idx

        pred = self(x, y, batches_seen)

        if self.scaler is not None:
            y = self.scaler.inverse_transform(y)
            pred = self.scaler.inverse_transform(pred)

        if self.criterion_args is not None:
            loss = self.criterion(pred, y, **self.criterion_args)
        else:
            loss = self.criterion(pred, y)

        return loss
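    # ``_step`` only assumes the scaler exposes an ``inverse_transform`` method
    # that maps tensors back to the original data scale, so the loss is computed
    # on unscaled values. A minimal sketch of such a scaler (an illustration,
    # not the torchts.utils.scaler implementation):
    #
    #     class StandardScaler:
    #         def __init__(self, mean, std):
    #             self.mean = mean
    #             self.std = std
    #
    #         def transform(self, data):
    #             return (data - self.mean) / self.std
    #
    #         def inverse_transform(self, data):
    #             return data * self.std + self.mean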
    def training_step(self, batch, batch_idx):
        """Trains the model for one step.

        Args:
            batch (torch.Tensor): Output of the torch.utils.data.DataLoader
            batch_idx (int): Index of this batch
        """
        train_loss = self._step(batch, batch_idx, len(self.trainer.train_dataloader))
        self.log(
            "train_loss",
            train_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return train_loss
    def validation_step(self, batch, batch_idx):
        """Validates the model for one step.

        Args:
            batch (torch.Tensor): Output of the torch.utils.data.DataLoader
            batch_idx (int): Index of this batch
        """
        val_loss = self._step(batch, batch_idx, len(self.trainer.val_dataloader))
        self.log("val_loss", val_loss)
        return val_loss
    def test_step(self, batch, batch_idx):
        """Tests the model for one step.

        Args:
            batch (torch.Tensor): Output of the torch.utils.data.DataLoader
            batch_idx (int): Index of this batch
        """
        test_loss = self._step(batch, batch_idx, len(self.trainer.test_dataloader))
        self.log("test_loss", test_loss)
        return test_loss
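    # These hooks are invoked by the Lightning Trainer rather than called
    # directly. A sketch of running validation and testing with prepared
    # DataLoaders (``train_loader``, ``val_loader``, and ``test_loader`` are
    # assumed to exist):
    #
    #     trainer = Trainer(max_epochs=10)
    #     trainer.fit(model, train_loader, val_loader)
    #     trainer.test(model, test_loader)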
    @abstractmethod
    def forward(self, x, y=None, batches_seen=None):
        """Forward pass.

        Args:
            x (torch.Tensor): Input data
            y (torch.Tensor, optional): Target data
            batches_seen (int, optional): Number of batches seen so far

        Returns:
            torch.Tensor: Predicted data
        """
    def predict(self, x):
        """Runs model inference.

        Args:
            x (torch.Tensor): Input data

        Returns:
            torch.Tensor: Predicted data
        """
        return self(x).detach()
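# End-to-end usage sketch, assuming a concrete subclass such as the
# ``LinearModel`` illustrated above, an optimizer hook (e.g.
# ``configure_optimizers``) that wires ``self.optimizer`` to the model's
# parameters, and tensors ``x_train``, ``y_train``, ``x_test`` prepared by the
# caller; none of these are defined in this listing:
#
#     import torch
#
#     model = LinearModel(
#         input_dim=10,
#         output_dim=1,
#         optimizer=torch.optim.Adam,
#         optimizer_args={"lr": 1e-3},
#     )
#     model.fit(x_train, y_train, max_epochs=10, batch_size=64)
#     y_pred = model.predict(x_test)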