Supervised learning
Use ptls.frames.supervised.SeqToTargetDataset and ptls.frames.supervised.SequenceToTarget for supervised learning.
SeqToTargetDataset
Works similarly to the other datasets described in common patterns.
Source data should have a scalar field with the target value.
Example:
import torch
from ptls.frames.supervised import SeqToTargetDataset

# Four toy sequences of random length, each with a scalar target.
dataset = SeqToTargetDataset([{
    'mcc_code': torch.randint(1, 10, (seq_len,)),
    'amount': torch.randn(seq_len),
    'event_time': torch.arange(seq_len),  # shows order between transactions
    'target': target,
} for seq_len, target in zip(
    torch.randint(100, 200, (4,)),
    [0, 0, 1, 1],
)], target_col_name='target')

# The dataset's collate_fn builds the batch; y holds the targets.
dl = torch.utils.data.DataLoader(dataset, batch_size=10, collate_fn=dataset.collate_fn)
x, y = next(iter(dl))
torch.testing.assert_close(y, torch.LongTensor([0, 0, 1, 1]))
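To make the batch layout more concrete, here is a minimal sketch in plain PyTorch of what collating such records involves: variable-length sequences are padded to a common length and the scalar targets are stacked into one tensor. The helper `toy_collate` is hypothetical and is not the `collate_fn` that `SeqToTargetDataset` provides; it only illustrates the idea on a single `amount` field.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy records in the same layout as the example above:
# a per-event field plus one scalar target per sequence.
records = [
    {'amount': torch.tensor([1.0, 2.0, 3.0]), 'target': 0},
    {'amount': torch.tensor([4.0, 5.0]), 'target': 1},
]

def toy_collate(batch):
    """Hypothetical helper: pad sequences to equal length and stack targets."""
    seq_lens = torch.tensor([len(rec['amount']) for rec in batch])
    # Shorter sequences are right-padded with zeros.
    amounts = pad_sequence([rec['amount'] for rec in batch], batch_first=True)
    targets = torch.tensor([rec['target'] for rec in batch])
    return amounts, seq_lens, targets

amounts, seq_lens, targets = toy_collate(records)
```

Keeping the true sequence lengths next to the padded tensor lets a downstream encoder ignore the padded positions.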
SequenceToTarget in supervised mode
SequenceToTarget is a Lightning module for supervised training.
This module assumes one target per sequence.
There are several types of supervised tasks, such as classification or regression.
The SequenceToTarget parameters let you fit this module to your task.
SequenceToTarget requires seq_encoder, head, loss and metrics.
See examples of usage in the SequenceToTarget docstring.
Layers from seq_encoder and head can be randomly initialized or pretrained.
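The roles of the four components can be sketched in plain PyTorch. This is not the SequenceToTarget implementation: `ToySeqEncoder` is a hypothetical stand-in for seq_encoder, the head is a single linear layer, and accuracy is computed by hand instead of through a metrics object. Padding is ignored for brevity (all toy sequences have the same length).

```python
import torch
from torch import nn

class ToySeqEncoder(nn.Module):
    """Hypothetical stand-in for seq_encoder: (B, T, F) events -> (B, H) embedding."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)

    def forward(self, x):
        _, h_n = self.rnn(x)   # h_n: (num_layers, B, H)
        return h_n[-1]         # final hidden state of the last layer, (B, H)

torch.manual_seed(0)
seq_encoder = ToySeqEncoder(input_size=2, hidden_size=8)
head = nn.Linear(8, 2)             # classification head producing 2 class logits
loss_fn = nn.CrossEntropyLoss()    # loss matched to the task

x = torch.randn(4, 5, 2)           # 4 sequences, 5 events each, 2 features
y = torch.tensor([0, 0, 1, 1])     # one target per sequence

logits = head(seq_encoder(x))      # sequence -> embedding -> prediction
loss = loss_fn(logits, y)
accuracy = (logits.argmax(dim=1) == y).float().mean()  # hand-rolled metric
loss.backward()                    # gradients reach both head and encoder
```

For a regression task, the same structure would apply with a one-output head such as `nn.Linear(8, 1)` and a loss such as `nn.MSELoss()`.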
SequenceToTarget in inference mode
You may provide just a pretrained seq_encoder to SequenceToTarget and use trainer.predict to get embeddings from the pretrained seq_encoder.
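Conceptually, getting embeddings means running the encoder over each batch without gradients and concatenating the per-batch outputs. The sketch below shows that loop in plain PyTorch; the `nn.GRU` is a hypothetical stand-in for a pretrained seq_encoder (its weights here are random), and the manual loop is only an illustration of the idea, not what trainer.predict does internally.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-in for a pretrained seq_encoder (random weights here).
rnn = nn.GRU(input_size=2, hidden_size=8, batch_first=True)
rnn.eval()                          # switch to inference behaviour

# Two toy batches with different batch sizes and sequence lengths.
batches = [torch.randn(3, 5, 2), torch.randn(2, 7, 2)]

per_batch = []
with torch.no_grad():               # gradients are not needed for inference
    for x in batches:
        _, h_n = rnn(x)
        per_batch.append(h_n[-1])   # (batch_size, hidden_size)

embeddings = torch.cat(per_batch)   # one embedding row per sequence
```

The resulting matrix has one fixed-size row per input sequence, which can then be used as features for a downstream model.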
Classes
See docstrings for classes.
ptls.frames.supervised.SeqToTargetDataset
ptls.frames.supervised.SequenceToTarget