```python
from datasets import load_dataset, load_dataset_builder
from nntrain.dataloaders import DataLoaders, hf_ds_collate_fn
```
Learner
Module containing helper functions and classes built around the `Learner`.
PublishEvents
PublishEvents (name)
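Only the signature of `PublishEvents` is documented here. Its name, together with the `Subscriber` classes on this page, suggests a wrapper that announces a step of the training loop to every subscriber. Assuming that reading, the sketch below implements the pattern as a decorator that calls `before_<name>` and `after_<name>` hooks around a method. `PublishEventsSketch`, the `subs` attribute, the hook names and the toy classes are all hypothetical, not nntrain's code.

```python
import functools


class PublishEventsSketch:
    """Hypothetical decorator: announce `before_<name>` / `after_<name>`
    events to every subscriber in `obj.subs` around the wrapped method."""

    def __init__(self, name):
        self.name = name

    def __call__(self, method):
        @functools.wraps(method)
        def wrapper(obj, *args, **kwargs):
            self._publish(obj, f"before_{self.name}")
            result = method(obj, *args, **kwargs)
            self._publish(obj, f"after_{self.name}")
            return result

        return wrapper

    @staticmethod
    def _publish(obj, event):
        # Call the hook only on subscribers that implement it.
        for sub in obj.subs:
            hook = getattr(sub, event, None)
            if hook is not None:
                hook(obj)


class Recorder:
    """Toy subscriber that logs the events it receives."""

    def __init__(self):
        self.log = []

    def before_batch(self, learner):
        self.log.append("before_batch")

    def after_batch(self, learner):
        self.log.append("after_batch")


class ToyLearner:
    def __init__(self, subs):
        self.subs = subs

    @PublishEventsSketch("batch")
    def one_batch(self):
        return "done"


rec = Recorder()
print(ToyLearner([rec]).one_batch(), rec.log)  # done ['before_batch', 'after_batch']
```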
CancelBatchException
Subclass of `Exception` used as a signal to cancel the current batch.
CancelEpochException
Subclass of `Exception` used as a signal to cancel the current epoch.
CancelFitException
Subclass of `Exception` used as a signal to cancel the whole `fit` run.
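The three exception classes differ only in name, which suggests they are control-flow signals rather than errors: raising one aborts the current batch, epoch or training run. The sketch below shows how such signals are typically caught, each at its own level of a nested loop, so that cancelling a batch does not end the epoch and cancelling an epoch does not end training. The classes are redefined locally and the `fit` loop is hypothetical, not nntrain's implementation.

```python
class CancelBatchException(Exception):
    """Abort the current batch."""


class CancelEpochException(Exception):
    """Abort the current epoch."""


class CancelFitException(Exception):
    """Abort the whole training run."""


def fit(n_epochs, n_batches, step):
    """Hypothetical loop: each Cancel*Exception is caught at its own level,
    so it aborts only that batch, that epoch, or the whole fit."""
    done = []
    try:
        for epoch in range(n_epochs):
            try:
                for batch in range(n_batches):
                    try:
                        step(epoch, batch)
                        done.append((epoch, batch))
                    except CancelBatchException:
                        pass  # skip the rest of this batch only
            except CancelEpochException:
                pass  # skip the remaining batches of this epoch
    except CancelFitException:
        pass  # stop training altogether
    return done


def step(epoch, batch):
    if batch == 1:
        raise CancelBatchException  # drop batch 1 of every epoch
    if batch == 2 and epoch == 0:
        raise CancelEpochException  # end epoch 0 early
    if epoch == 2:
        raise CancelFitException  # stop at the first batch of epoch 2


print(fit(n_epochs=5, n_batches=4, step=step))  # [(0, 0), (1, 0), (1, 2), (1, 3)]
```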
Learner
Learner (model, dls, loss_fn, optim_class, lr, subs)
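The `Learner` signature suggests the usual arrangement: an optimizer is built from `optim_class` and `lr`, and `fit` alternates a training pass and an evaluation pass per epoch, which would match the `train` and `eval` rows of the example output. The framework-free sketch below shows that loop structure under these assumptions. To stay runnable without PyTorch it uses a one-weight model, a loss function that also returns its gradient (there is no autograd), and a toy `SGD` class; subscribers are left out, and every name here is hypothetical.

```python
class LinearModel:
    """One-weight model y = w * x, standing in for an nn.Module."""

    def __init__(self, w=0.0):
        self.w, self.grad = w, 0.0

    def __call__(self, x):
        return self.w * x


def mse_with_grad(model, x, y):
    """Squared error and its gradient w.r.t. model.w (computed by hand)."""
    err = model(x) - y
    return err ** 2, 2 * err * x


class SGD:
    """Toy optimizer: plain gradient descent on model.w."""

    def __init__(self, model, lr):
        self.model, self.lr = model, lr

    def step(self):
        self.model.w -= self.lr * self.model.grad

    def zero_grad(self):
        self.model.grad = 0.0


class LearnerSketch:
    """Hypothetical learner: builds the optimizer from optim_class and lr,
    then runs a training pass and an evaluation pass per epoch."""

    def __init__(self, model, dls, loss_fn, optim_class, lr):
        self.model, self.dls, self.loss_fn = model, dls, loss_fn
        self.opt = optim_class(model, lr)
        self.history = []  # (epoch, mode, mean loss) rows

    def one_epoch(self, epoch, mode):
        losses = []
        for x, y in self.dls[mode]:
            loss, grad = self.loss_fn(self.model, x, y)
            losses.append(loss)
            if mode == "train":  # update weights only while training
                self.model.grad += grad
                self.opt.step()
                self.opt.zero_grad()
        self.history.append((epoch, mode, sum(losses) / len(losses)))

    def fit(self, n_epochs):
        for epoch in range(n_epochs):
            self.one_epoch(epoch, "train")
            self.one_epoch(epoch, "eval")


data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]  # targets follow y = 3x
learn = LearnerSketch(LinearModel(), {"train": data, "eval": data}, mse_with_grad, SGD, lr=0.1)
learn.fit(5)
for epoch, mode, loss in learn.history:
    print(f"{epoch} | {mode} | {loss:.4f}")
```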
Subscriber
Subscriber ()
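`Subscriber` takes no arguments, which fits a base class that the concrete subscribers on this page (`MetricsS`, `DeviceS`, `LRFindS`, `ProgressS`) presumably build on. A common design for such a base class gives each subscriber an `order` attribute that fixes the sequence in which it is notified, and lets the learner call only the hooks a subscriber actually defines. The sketch below illustrates that design; `order`, `notify`, the hook names and the toy subclasses are assumptions, not nntrain's API.

```python
class Subscriber:
    """Hypothetical base class: subclasses set `order` and define hooks."""
    order = 0


class MoveToDevice(Subscriber):
    order = 0

    def before_fit(self, log):
        log.append("device")


class Progress(Subscriber):
    order = 10

    def before_fit(self, log):
        log.append("progress")

    def after_fit(self, log):
        log.append("progress done")


def notify(subs, event, log):
    """Call `event` on each subscriber that defines it, lowest `order` first."""
    for sub in sorted(subs, key=lambda s: s.order):
        hook = getattr(sub, event, None)
        if hook is not None:
            hook(log)


log = []
subs = [Progress(), MoveToDevice()]  # deliberately listed out of order
notify(subs, "before_fit", log)
notify(subs, "after_fit", log)  # MoveToDevice has no after_fit, so it is skipped
print(log)  # ['device', 'progress', 'progress done']
```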
MetricsS
MetricsS (**metrics)
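In the example, `MetricsS` receives named torcheval metrics (`accuracy=tem.MulticlassAccuracy()`), and the output table reports one loss and one accuracy value per epoch and mode. One common way to get those numbers is to `update` each metric with every batch, `compute` at the end of the epoch and `reset` before the next, while averaging the loss weighted by batch size so that a smaller final batch counts proportionally. The pure-Python sketch below mimics torcheval's `update`/`compute`/`reset` method names with hypothetical classes; it is not nntrain's implementation.

```python
class Accuracy:
    """Hypothetical stand-in for a torcheval-style metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct, self.total = 0, 0

    def update(self, preds, targets):
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total


class WeightedMean:
    """Mean of per-batch losses, weighted by batch size."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum, self.n = 0.0, 0

    def update(self, value, weight):
        self.sum += value * weight
        self.n += weight

    def compute(self):
        return self.sum / self.n


class MetricsSketch:
    def __init__(self, **metrics):
        self.metrics, self.loss = metrics, WeightedMean()

    def after_batch(self, loss, preds, targets):
        self.loss.update(loss, len(targets))
        for m in self.metrics.values():
            m.update(preds, targets)

    def after_epoch(self):
        row = {"loss": self.loss.compute()}
        row.update({name: m.compute() for name, m in self.metrics.items()})
        self.loss.reset()  # start the next epoch from scratch
        for m in self.metrics.values():
            m.reset()
        return row


ms = MetricsSketch(accuracy=Accuracy())
ms.after_batch(loss=1.0, preds=[0, 1, 1], targets=[0, 1, 2])  # 3 samples, 2 correct
ms.after_batch(loss=2.0, preds=[2], targets=[2])              # 1 sample, 1 correct
row = ms.after_epoch()
print(row)  # {'loss': 1.25, 'accuracy': 0.75}
```

An unweighted mean of the two batch losses would give 1.5; weighting by batch size gives 1.25, the mean loss per sample.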
DeviceS
DeviceS (device)
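`DeviceS(device)` takes a single device argument and, in the example, is passed to the learner alongside the other subscribers. A device subscriber typically moves the model to that device before training and then moves every batch before it is used. Because a batch is often a tuple or dict of tensors, a small recursive helper is a common way to do the per-batch move. The sketch below is hypothetical (`to_device` and `FakeTensor` are illustrative names) and works on anything with a `.to(device)` method, such as a `torch.Tensor`.

```python
def to_device(batch, device):
    """Recursively call .to(device) on every element of a (nested) batch
    made of dicts, plain lists and plain tuples."""
    if isinstance(batch, dict):
        return {k: to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_device(v, device) for v in batch)
    if hasattr(batch, "to"):
        return batch.to(device)
    return batch  # leave ints, strings, etc. untouched


class FakeTensor:
    """Minimal object with a .to() method, standing in for torch.Tensor."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


xb, yb = to_device((FakeTensor(), FakeTensor()), "cuda")
print(xb.device, yb.device)  # cuda cuda
```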
LRFindS
LRFindS (mult=1.25)
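`LRFindS(mult=1.25)` reads like a learning-rate finder subscriber. The usual procedure, which this sketch assumes, is to start from a small learning rate, multiply it by `mult` after every batch while recording the loss, and stop once the loss has clearly diverged. Here the hypothetical `lr_find` helper replaces real training steps with a toy loss curve `toy_loss`, and the stop rule (loss above three times the best loss so far) is an illustrative choice, not necessarily nntrain's.

```python
import math


def lr_find(loss_at, lr0=1e-4, mult=1.25, max_steps=100, blowup=3.0):
    """Grow the LR geometrically (lr0 * mult**i) and record the loss at each step.
    Stop when the loss exceeds `blowup` times the best loss seen so far."""
    lrs, losses = [], []
    lr, best = lr0, math.inf
    for _ in range(max_steps):
        loss = loss_at(lr)  # stands in for training one batch at this LR
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        if loss > blowup * best:
            break  # diverged; a real finder would typically cancel the fit here
        lr *= mult
    return lrs, losses


def toy_loss(lr):
    """Toy loss curve: falls as the LR grows, then rises sharply past lr = 0.1."""
    return 1.0 / (1.0 + 100 * lr) + (lr / 0.1) ** 4


lrs, losses = lr_find(toy_loss)
print(f"{len(lrs)} steps, last lr {lrs[-1]:.3g}")
```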
MomentumLearner
MomentumLearner (model, dls, loss_fn, optim_class, lr, subs, mom=0.85)
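`MomentumLearner` takes the same arguments as `Learner` plus `mom=0.85`. One way a learner can add momentum without a dedicated optimizer is to scale the gradients by `mom` instead of zeroing them after each step, so that each update uses the new gradient plus an exponentially decaying sum of the previous ones. Assuming that is what `mom` controls, the pure-Python sketch below compares it with plain gradient descent on f(w) = (w - 3)^2; the `descend` helper is hypothetical.

```python
def descend(steps, lr, mom=None):
    """Minimise f(w) = (w - 3)**2 starting from w = 0.
    mom=None: zero the gradient after each step (plain gradient descent).
    mom=m:    scale the gradient by m instead, so old gradients keep contributing."""
    w, grad = 0.0, 0.0
    for _ in range(steps):
        grad += 2 * (w - 3)  # accumulate the new gradient, as backward() does
        w -= lr * grad       # optimizer step
        grad = 0.0 if mom is None else grad * mom  # zero_grad vs. momentum
    return w


plain = descend(steps=10, lr=0.01)
momentum = descend(steps=10, lr=0.01, mom=0.85)
print(f"plain: {plain:.3f}, with mom=0.85: {momentum:.3f}")
```

After the same ten steps the momentum variant ends much closer to the minimum at w = 3, because past gradients keep pushing in the same direction.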
ProgressS
ProgressS (plot=False)
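With `ProgressS(True)` among the subscribers, the example's `l.fit(5)` call prints an `epoch | mode | loss | accuracy |` table, presumably produced by this subscriber. A minimal sketch of formatting such rows from per-epoch metric values follows; `format_rows` is a hypothetical helper, and the two sample rows are copied from the example output.

```python
def format_rows(rows, metric_names):
    """Render (epoch, mode, {metric: value}) records as a pipe-separated
    table with three decimals, like the training log in the example."""
    header = ["epoch", "mode", *metric_names]
    lines = [" | ".join(header) + " |", "|".join("---" for _ in header) + "|"]
    for epoch, mode, values in rows:
        cells = [str(epoch), mode, *(f"{values[m]:.3f}" for m in metric_names)]
        lines.append(" | ".join(cells) + " |")
    return "\n".join(lines)


rows = [
    (0, "train", {"loss": 1.763, "accuracy": 0.458}),
    (0, "eval", {"loss": 1.151, "accuracy": 0.647}),
]
print(format_rows(rows, ["loss", "accuracy"]))
```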
Example usage:
```python
name = "fashion_mnist"
ds_builder = load_dataset_builder(name)
hf_dd = load_dataset(name)

bs = 1024       # batch size
n_in = 28*28    # flattened 28x28 image
n_h = 50        # hidden units
n_out = 10      # number of classes
lr = 0.01

dls = DataLoaders.from_hf_dd(hf_dd, batch_size=bs)
```
Reusing dataset fashion_mnist (/root/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1)
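```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torcheval.metrics as tem  # assumption: `tem` refers to torcheval.metrics
from nntrain.learner import DeviceS, MetricsS, MomentumLearner, ProgressS  # assumed module path

def get_model():
    # two-layer MLP on flattened images
    layers = [nn.Linear(n_in, n_h), nn.ReLU(), nn.Linear(n_h, n_out)]
    return nn.Sequential(*layers)

metrics = MetricsS(accuracy=tem.MulticlassAccuracy())
progress = ProgressS(True)
# assumption: the original relies on a `device` defined elsewhere; choose one here
device = DeviceS("cuda" if torch.cuda.is_available() else "cpu")

l = MomentumLearner(get_model(), dls, F.cross_entropy, torch.optim.SGD, lr, [metrics, progress, device])
l.fit(5)
```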
epoch | mode | loss | accuracy |
---|---|---|---|
0 | train | 1.763 | 0.458 |
0 | eval | 1.151 | 0.647 |
1 | train | 0.949 | 0.669 |
1 | eval | 0.846 | 0.685 |
2 | train | 0.777 | 0.719 |
2 | eval | 0.748 | 0.725 |
3 | train | 0.697 | 0.757 |
3 | eval | 0.683 | 0.762 |
4 | train | 0.643 | 0.780 |
4 | eval | 0.641 | 0.778 |