Model
PyTorch Model
Every model inherits from dlex.torch.models.BaseModel.
class dlex.torch.models.BaseModel(params: dlex.configs.MainConfig, dataset: dlex.datasets.torch.Dataset)

    config_class
        alias of dlex.configs.AttrDict

    abstract get_loss(batch: dlex.torch.datatypes.Batch, output)
        Return model loss to optimize

        Parameters
            batch (Batch)
            output – Output of model forward

        Returns
            A torch.FloatTensor with the loss value.

    abstract infer(batch: dlex.torch.datatypes.Batch)
        Infer from batch

        Parameters
            batch (Batch)

        Returns
            tuple containing:
                pred: prediction
                ref: reference
                model_outputs
                others

        Return type
            tuple
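
The contract described above (a subclass supplies get_loss and infer, and infer returns a four-element tuple) can be illustrated with a minimal, dependency-free sketch. Everything in it is an assumption made for illustration: SketchBaseModel is a hypothetical stand-in rather than the real dlex.torch.models.BaseModel, the batch is a plain dict rather than a dlex Batch, the forward method and the "constant" parameter are invented, and the loss is a Python float rather than a torch.FloatTensor.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Tuple


class SketchBaseModel(ABC):
    """Hypothetical stand-in for dlex.torch.models.BaseModel (not the real class)."""

    def __init__(self, params: Any, dataset: Any) -> None:
        self.params = params
        self.dataset = dataset

    @abstractmethod
    def get_loss(self, batch: Any, output: Any) -> float:
        """Return the loss to optimize, given the batch and the forward output."""

    @abstractmethod
    def infer(self, batch: Any) -> Tuple[Any, Any, Any, Any]:
        """Return (pred, ref, model_outputs, others)."""


class ConstantPredictor(SketchBaseModel):
    """Toy model that predicts the same constant for every input."""

    def forward(self, batch: dict) -> List[float]:
        # One prediction per input example.
        return [self.params["constant"]] * len(batch["x"])

    def get_loss(self, batch: dict, output: List[float]) -> float:
        # Mean squared error between the forward output and the targets.
        errors = [(o - y) ** 2 for o, y in zip(output, batch["y"])]
        return sum(errors) / len(errors)

    def infer(self, batch: dict) -> Tuple[List[float], List[float], List[float], None]:
        output = self.forward(batch)
        # pred, ref, model_outputs, others (nothing extra to report here)
        return output, batch["y"], output, None


batch = {"x": [1.0, 2.0, 3.0], "y": [2.0, 2.0, 5.0]}
model = ConstantPredictor(params={"constant": 2.0}, dataset=None)
loss = model.get_loss(batch, model.forward(batch))
pred, ref, model_outputs, others = model.infer(batch)
```

Because both methods are declared abstract, instantiating the base class directly, or a subclass that omits either method, raises TypeError.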

class dlex.torch.models.ClassificationModel(params, dataset)

    get_loss(batch: dlex.torch.datatypes.Batch, output)
        Return model loss to optimize

        Parameters
            batch (Batch)
            output – Output of model forward

        Returns
            A torch.FloatTensor with the loss value.

    infer(batch)
        Infer from batch

        Parameters
            batch (Batch)

        Returns
            tuple containing:
                pred: prediction
                ref: reference
                model_outputs
                others

        Return type
            tuple
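
The reference above does not say which loss ClassificationModel computes or how it turns outputs into predictions. As a rough sketch only, the following assumes the common choice of softmax cross-entropy for get_loss and an argmax over class scores for infer. The function names, the list-of-lists logits, and the integer labels are all hypothetical and written in plain Python; none of this is taken from dlex.

```python
import math
from typing import List, Optional, Tuple


def cross_entropy_loss(logits: List[List[float]], labels: List[int]) -> float:
    """Mean negative log-likelihood of the true labels under softmax(logits)."""
    total = 0.0
    for row, label in zip(logits, labels):
        # Log-sum-exp with the row maximum subtracted for numerical stability.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[label]
    return total / len(labels)


def infer_classes(
    logits: List[List[float]], labels: List[int]
) -> Tuple[List[int], List[int], List[List[float]], Optional[dict]]:
    """Return (pred, ref, model_outputs, others) in the documented order."""
    # Predicted class = index of the largest score in each row.
    pred = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return pred, labels, logits, None


logits = [[2.0, 0.0], [0.5, 1.5]]
labels = [0, 0]
loss = cross_entropy_loss(logits, labels)
pred, ref, model_outputs, others = infer_classes(logits, labels)
```

Here the first example is classified correctly and the second is not, so the second row contributes most of the loss.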