Skip to content

translation

GeneralCalculator

Bases: BasicTaskCalculator

Source code in flgo\benchmark\toolkits\nlp\translation\__init__.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class GeneralCalculator(BasicTaskCalculator):
    """
    Task calculator for the translation benchmark.

    It computes a cross-entropy loss between the model outputs and the target
    sequences, both for a single training step (compute_loss) and averaged
    over a whole dataset (test).
    """

    def __init__(self, device, optimizer_name='sgd'):
        """
        Args:
            device: the device that batches are moved to before the forward pass
            optimizer_name: forwarded unchanged to BasicTaskCalculator
        """
        super().__init__(device, optimizer_name)
        self.DataLoader = torch.utils.data.DataLoader

    def criterion(self, outputs, targets, ignore_index=-100):
        """
        Cross-entropy loss between outputs and targets.

        The first entry along dimension 0 of both tensors is dropped before
        the loss is computed, and the remaining entries are flattened so that
        outputs becomes (-1, outputs.shape[-1]) and targets becomes (-1,).

        Args:
            outputs: the model predictions; the last dimension holds the class scores
            targets: the target indices
            ignore_index: target value that does not contribute to the loss
        Returns: the loss tensor produced by torch.nn.CrossEntropyLoss
        """
        # NOTE(review): presumably dimension 0 is the sequence axis and position 0
        # holds a start-of-sequence token - confirm against the model/collate code.
        loss_func = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        num_classes = outputs.shape[-1]
        return loss_func(outputs[1:].view(-1, num_classes), targets[1:].view(-1))

    def compute_loss(self, model, data):
        """
        Args:
            model: the model to train
            data: the training data; self.to_device must split it into (sources, targets)
        Returns: dict of train-one-step's result, which should at least contains the key 'loss'
        """
        sources, targets = self.to_device(data)
        outputs = model(sources, targets)
        # Models may expose their own padding index; otherwise use -100.
        ignore_index = getattr(model, 'ignore_index', -100)
        loss = self.criterion(outputs, targets, ignore_index)
        return {'loss': loss}

    @torch.no_grad()
    def test(self, model, dataset, batch_size=64, num_workers=0, pin_memory=False):
        """
        Metric = mean_loss

        Args:
            model: the model to evaluate; it is switched to eval mode
            dataset: the dataset to evaluate on
            batch_size: the evaluation batch size; -1 means one batch holding the whole dataset
            num_workers: forwarded to the data loader
            pin_memory: forwarded to the data loader
        Returns: dict {'loss': mean_loss}
        """
        model.eval()
        if batch_size == -1:
            batch_size = len(dataset)
        # NOTE(review): get_dataloader defaults to shuffle=True, so evaluation
        # batches are shuffled as well - confirm this is intended.
        data_loader = self.get_dataloader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
        # The ignore index only depends on the model, so look it up once.
        ignore_index = getattr(model, 'ignore_index', -100)
        total_loss = 0.0
        for batch_data in data_loader:
            batch_data = self.to_device(batch_data)
            outputs = model(batch_data[0], batch_data[1])
            batch_mean_loss = self.criterion(outputs, batch_data[1], ignore_index).item()
            # NOTE(review): the weight is the size of the targets' first dimension;
            # criterion slices that same dimension with [1:], so it may be the
            # sequence length rather than the batch size - verify.
            total_loss += batch_mean_loss * len(batch_data[-1])
        return {'loss': total_loss / len(dataset)}

    def to_device(self, data):
        """
        Move the first two elements of data to self.device.

        Args:
            data: an indexable whose items 0 and 1 both provide .to(device)
        Returns: tuple (data[0], data[1]) on self.device
        """
        return data[0].to(self.device), data[1].to(self.device)

    def get_dataloader(self, dataset, batch_size=64, shuffle=True, num_workers=0, pin_memory=False, drop_last=False):
        """
        Build a data loader over dataset with self.DataLoader.

        Args:
            dataset: the dataset to iterate over
            batch_size: number of samples per batch
            shuffle: forwarded to the data loader
            num_workers: forwarded to the data loader
            pin_memory: forwarded to the data loader
            drop_last: forwarded to the data loader
        Returns: the data loader, built with collate_fn=self.collect_fn
        Raises:
            NotImplementedError: if self.DataLoader is None
        """
        if self.DataLoader is None:
            raise NotImplementedError("DataLoader Not Found.")
        # self.collect_fn is not set in this class; presumably it is provided
        # by BasicTaskCalculator or assigned externally - TODO confirm.
        return self.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, collate_fn=self.collect_fn)

compute_loss(model, data)

Parameters:

Name Type Description Default
model

the model to train

required
data

the training dataset

required
Source code in flgo\benchmark\toolkits\nlp\translation\__init__.py
22
23
24
25
26
27
28
29
30
31
32
def compute_loss(self, model, data):
    """
    Run one forward pass and compute the training loss.

    Args:
        model: the model to train; it is called as model(sources, targets)
        data: the training data; self.to_device must split it into (sources, targets)
    Returns: dict of the step's result, holding the loss under the key 'loss'
    """
    batch = self.to_device(data)
    sources, targets = batch
    predictions = model(sources, targets)
    # Use the model's own ignore index when it defines one, otherwise -100.
    pad_index = getattr(model, 'ignore_index', -100)
    return {'loss': self.criterion(predictions, targets, pad_index)}

test(model, dataset, batch_size=64, num_workers=0, pin_memory=False)

Metric = mean_loss, returned as the dict {'loss': mean_loss}

Parameters:

Name Type Description Default
model required
dataset required
batch_size 64
Source code in flgo\benchmark\toolkits\nlp\translation\__init__.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@torch.no_grad()
def test(self, model, dataset, batch_size=64, num_workers=0, pin_memory=False):
    """
    Metric = mean_loss

    Args:
        model: the model to evaluate; it is switched to eval mode
        dataset: the dataset to evaluate on
        batch_size: the evaluation batch size; -1 means one batch holding the whole dataset
        num_workers: forwarded to self.get_dataloader
        pin_memory: forwarded to self.get_dataloader
    Returns: dict {'loss': mean_loss}
    """
    model.eval()
    if batch_size == -1:
        batch_size = len(dataset)
    data_loader = self.get_dataloader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    # The ignore index only depends on the model, so look it up once.
    ignore_index = getattr(model, 'ignore_index', -100)
    total_loss = 0.0
    for batch_data in data_loader:
        batch_data = self.to_device(batch_data)
        outputs = model(batch_data[0], batch_data[1])
        batch_mean_loss = self.criterion(outputs, batch_data[1], ignore_index).item()
        # NOTE(review): the weight is the size of the targets' first dimension;
        # verify that this is the batch size and not the sequence length.
        total_loss += batch_mean_loss * len(batch_data[-1])
    return {'loss': total_loss / len(dataset)}