Model

PyTorch Model

A model class inherits from dlex.torch.models.BaseModel.

class dlex.torch.models.BaseModel(params: dlex.configs.MainConfig, dataset: dlex.datasets.torch.Dataset)
config_class

alias of dlex.configs.AttrDict
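`config_class` points to `AttrDict`, a dictionary whose keys can also be read as attributes. The snippet below is a minimal, hypothetical stand-in written only to illustrate that idea. It is not dlex's implementation, and the automatic wrapping of nested dicts is an assumption.

```python
class AttrDict(dict):
    """Hypothetical stand-in: a dict whose keys are also readable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Wrap nested plain dicts so chained access like cfg.model.hidden_size works.
        for key, value in list(self.items()):
            if isinstance(value, dict) and not isinstance(value, AttrDict):
                self[key] = AttrDict(value)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; fall back to the keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Attribute assignment stores a key instead of an instance attribute.
        self[name] = value


cfg = AttrDict({"model": {"hidden_size": 128}, "lr": 0.001})
print(cfg.model.hidden_size)  # 128
print(cfg["lr"])              # 0.001
```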

abstract get_loss(batch: dlex.torch.datatypes.Batch, output)

Return the model loss to optimize.

Parameters
  • batch (Batch) – Input batch.

  • output – Output of the model's forward pass.

Returns

A torch.FloatTensor with the loss value.

abstract infer(batch: dlex.torch.datatypes.Batch)

Run inference on a batch.

Parameters

batch (Batch) – Input batch.

Returns

A tuple containing:

  • pred: prediction

  • ref: reference

  • model_outputs

  • others

Return type

tuple
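The two abstract methods define the contract a subclass must fulfil. As a rough, framework-free illustration, the sketch below mirrors that interface with Python's `abc` module. The real `BaseModel` is a PyTorch model whose constructor takes `params` and `dataset`; both are omitted here. `ToyBatch`, `BaseModelSketch`, `ScaleModel`, and `forward` are hypothetical names, and the mean-squared-error loss is an arbitrary choice for the example.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class ToyBatch:
    """Hypothetical batch with inputs X and targets Y (stand-in for dlex's Batch)."""
    X: List[float]
    Y: List[float]


class BaseModelSketch(ABC):
    """Mirrors the two abstract methods documented for BaseModel."""

    @abstractmethod
    def get_loss(self, batch: ToyBatch, output: Any) -> float:
        """Return the loss to optimize, given the batch and the forward output."""

    @abstractmethod
    def infer(self, batch: ToyBatch) -> Tuple[Any, Any, Any, Any]:
        """Return (pred, ref, model_outputs, others)."""


class ScaleModel(BaseModelSketch):
    """Hypothetical model that predicts y = w * x."""

    def __init__(self, w: float):
        self.w = w

    def forward(self, batch: ToyBatch) -> List[float]:
        return [self.w * x for x in batch.X]

    def get_loss(self, batch, output):
        # Mean squared error between the forward output and the targets.
        return sum((o - y) ** 2 for o, y in zip(output, batch.Y)) / len(batch.Y)

    def infer(self, batch):
        outputs = self.forward(batch)
        pred = outputs   # predictions
        ref = batch.Y    # references (ground truth)
        others = None    # nothing extra to report in this sketch
        return pred, ref, outputs, others


model = ScaleModel(w=2.0)
batch = ToyBatch(X=[1.0, 2.0], Y=[2.0, 5.0])
loss = model.get_loss(batch, model.forward(batch))
pred, ref, model_outputs, others = model.infer(batch)
print(loss)  # 0.5
print(pred)  # [2.0, 4.0]
```

Because both methods are declared with `@abstractmethod`, a subclass that forgets either one cannot be instantiated, which surfaces the mistake early rather than during training.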

class dlex.torch.models.ClassificationModel(params, dataset)
get_loss(batch: dlex.torch.datatypes.Batch, output)

Return the model loss to optimize.

Parameters
  • batch (Batch) – Input batch.

  • output – Output of the model's forward pass.

Returns

A torch.FloatTensor with the loss value.
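This page does not state which loss ClassificationModel uses. Assuming it is cross-entropy over per-class logits (common for classification, but not confirmed here), the value is the batch mean of the negative log-softmax probability of each target class. The pure-Python sketch below shows that arithmetic; the helper name `cross_entropy` is hypothetical, and in PyTorch the same quantity is usually computed with `torch.nn.functional.cross_entropy`.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Mean cross-entropy over a batch of logit rows and integer class targets."""
    total = 0.0
    for row, target in zip(logits, targets):
        # Log-sum-exp with the row maximum subtracted for numerical stability.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        # Negative log-probability of the correct class.
        total += log_z - row[target]
    return total / len(targets)


# Two examples, three classes; uniform logits give a loss of ln(3).
loss = cross_entropy([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [0, 2])
print(round(loss, 4))  # 1.0986
```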

infer(batch)

Run inference on a batch.

Parameters

batch (Batch) – Input batch.

Returns

A tuple containing:

  • pred: prediction

  • ref: reference

  • model_outputs

  • others

Return type

tuple
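For a classifier, a natural reading of the returned tuple is: `pred` holds the predicted class indices, `ref` the ground-truth labels, and `model_outputs` the raw per-class scores. The sketch below assumes that reading, with the model emitting one logit vector per example and `others` left empty; the helper `infer_from_logits` is hypothetical and not part of dlex.

```python
from typing import Any, List, Tuple


def infer_from_logits(
    logits: List[List[float]], labels: List[int]
) -> Tuple[List[int], List[int], List[List[float]], Any]:
    """Hypothetical helper: build (pred, ref, model_outputs, others) from class logits."""
    # Predicted class = index of the largest logit in each row (argmax).
    pred = [max(range(len(row)), key=row.__getitem__) for row in logits]
    ref = list(labels)
    others = None
    return pred, ref, logits, others


pred, ref, outputs, others = infer_from_logits([[0.1, 2.0, -1.0], [3.0, 0.5, 0.2]], [1, 2])
accuracy = sum(p == r for p, r in zip(pred, ref)) / len(ref)
print(pred)      # [1, 0]
print(accuracy)  # 0.5
```

Keeping `pred` and `ref` aligned element by element lets downstream metrics such as accuracy be computed by a simple pairwise comparison.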