Model
PyTorch Model
Models inherit from dlex.torch.models.BaseModel.
class dlex.torch.models.BaseModel(params: dlex.configs.MainConfig, dataset: dlex.datasets.torch.Dataset)
config_class: alias of dlex.configs.AttrDict
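dlex.configs.AttrDict is not described further on this page. Assuming it follows the common AttrDict pattern (a dict whose keys can also be read and written as attributes), a minimal stand-in might look like the following. The class below is a hypothetical illustration, not dlex's implementation.

```python
class AttrDict(dict):
    """Hypothetical stand-in: a dict whose keys are also accessible as attributes."""

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; fall back to dict keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Store attribute assignments as dict entries.
        self[name] = value


cfg = AttrDict(hidden_size=128, dropout=0.1)
print(cfg.hidden_size)  # 128
cfg.num_layers = 2
print(cfg["num_layers"])  # 2
```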
abstract get_loss(batch: dlex.torch.datatypes.Batch, output)
Return the model loss to optimize.
Parameters
batch (Batch) – the input batch
output – output of the model's forward pass
Returns
A torch.FloatTensor with the loss value.
abstract infer(batch: dlex.torch.datatypes.Batch)
Run inference on a batch.
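Parameters
batch (Batch) – the input batch
Returns
A tuple (pred, ref, model_outputs, others), where:
pred – the prediction
ref – the reference
model_outputs – the outputs of the model
others – additional values
Return type
tuple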
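As a rough sketch of how a subclass might satisfy this contract, the example below uses a hypothetical Batch with X and Y fields and a hypothetical SketchBaseModel in place of the real dlex classes. The forward method is an assumption, included because get_loss receives the output of the model's forward pass. Plain Python floats stand in for torch tensors so the sketch runs without PyTorch.

```python
import abc
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class Batch:
    # Hypothetical stand-in for dlex.torch.datatypes.Batch: inputs X, targets Y.
    X: List[float]
    Y: List[float]


class SketchBaseModel(abc.ABC):
    """Hypothetical stand-in mirroring the abstract methods documented for BaseModel."""

    @abc.abstractmethod
    def forward(self, batch: Batch) -> Any:
        """Assumed forward pass; its result is what get_loss receives as `output`."""

    @abc.abstractmethod
    def get_loss(self, batch: Batch, output: Any) -> float:
        """Return the scalar loss to optimize (a torch.FloatTensor in dlex)."""

    @abc.abstractmethod
    def infer(self, batch: Batch) -> Tuple[Any, Any, Any, Any]:
        """Return (pred, ref, model_outputs, others)."""


class LinearRegressor(SketchBaseModel):
    """Toy model y = w * x, trained with mean squared error."""

    def __init__(self, w: float):
        self.w = w

    def forward(self, batch: Batch) -> List[float]:
        return [self.w * x for x in batch.X]

    def get_loss(self, batch: Batch, output: List[float]) -> float:
        # Mean squared error between the forward output and the targets.
        return sum((o - y) ** 2 for o, y in zip(output, batch.Y)) / len(batch.Y)

    def infer(self, batch: Batch) -> Tuple[List[float], List[float], List[float], None]:
        output = self.forward(batch)
        # For regression the prediction is the raw output; nothing extra goes in `others`.
        return output, batch.Y, output, None


model = LinearRegressor(w=2.0)
batch = Batch(X=[1.0, 2.0, 3.0], Y=[2.0, 4.0, 7.0])
loss = model.get_loss(batch, model.forward(batch))
pred, ref, model_outputs, others = model.infer(batch)
print(loss)  # 0.3333333333333333
print(pred)  # [2.0, 4.0, 6.0]
```

Because SketchBaseModel declares its methods with abc.abstractmethod, a subclass that forgets to implement get_loss or infer fails at instantiation time rather than during training.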
class dlex.torch.models.ClassificationModel(params, dataset)
get_loss(batch: dlex.torch.datatypes.Batch, output)
Return the model loss to optimize.
Parameters
batch (Batch) – the input batch
output – output of the model's forward pass
Returns
A torch.FloatTensor with the loss value.
infer(batch)
Run inference on a batch.
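Parameters
batch (Batch) – the input batch
Returns
A tuple (pred, ref, model_outputs, others), where:
pred – the prediction
ref – the reference
model_outputs – the outputs of the model
others – additional values
Return type
tuple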
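This page does not say which loss ClassificationModel uses. Assuming the usual choice for classification (cross-entropy over per-class logits, with the predicted class taken as the argmax), a self-contained sketch of get_loss and infer could look like this. SketchClassifier, its linear forward pass, and the Batch stand-in are hypothetical, and plain floats replace torch tensors.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Batch:
    # Hypothetical stand-in for dlex.torch.datatypes.Batch:
    # X holds feature vectors, Y holds integer class labels.
    X: List[List[float]]
    Y: List[int]


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


class SketchClassifier:
    """Hypothetical linear classifier with the get_loss/infer shape of ClassificationModel."""

    def __init__(self, weights: List[List[float]]):
        self.weights = weights  # one row of weights per class

    def forward(self, batch: Batch) -> List[List[float]]:
        # Logits: dot product of each sample with each class's weight row.
        return [[sum(w * x for w, x in zip(row, sample)) for row in self.weights]
                for sample in batch.X]

    def get_loss(self, batch: Batch, output: List[List[float]]) -> float:
        # Mean cross-entropy: average negative log-probability of the true class.
        return -sum(log_softmax(logits)[y] for logits, y in zip(output, batch.Y)) / len(batch.Y)

    def infer(self, batch: Batch) -> Tuple[List[int], List[int], List[List[float]], None]:
        output = self.forward(batch)
        # Predicted class = index of the largest logit.
        pred = [max(range(len(logits)), key=logits.__getitem__) for logits in output]
        return pred, batch.Y, output, None


model = SketchClassifier(weights=[[1.0, 0.0], [0.0, 1.0]])  # identity: logits == features
batch = Batch(X=[[2.0, 0.0], [0.0, 1.0]], Y=[0, 1])
loss = model.get_loss(batch, model.forward(batch))
pred, ref, logits, others = model.infer(batch)
print(pred)  # [0, 1]
```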