flgo.utils.fmodule

FModule

Bases: nn.Module

This module implements commonly used model-level operators like add, sub, and so on.

Example:

    >>> class TestModel(FModule):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.mlp = torch.nn.Linear(2, 2, bias=False)
    >>> m1 = TestModel()
    >>> m2 = TestModel()
    >>> m3 = m1+m2
    >>> (m1.mlp.weight+m2.mlp.weight)==m3.mlp.weight
Source code in flgo\utils\fmodule.py
class FModule(nn.Module):
    r"""
    This module implements commonly used model-level operators like add, sub, and so on.

    Example:
    ```python
        >>> class TestModel(FModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.mlp = torch.nn.Linear(2,2, bias=False)
        >>> m1 = TestModel()
        >>> m2 = TestModel()
        >>> m3 = m1+m2
        >>> (m1.mlp.weight+m2.mlp.weight)==m3.mlp.weight
    ```
    """
    def __init__(self):
        super().__init__()
        self.ingraph = False

    def __add__(self, other):
        if isinstance(other, int) and other == 0 : return self
        if not isinstance(other, FModule): raise TypeError
        return _model_add(self, other)

    def __radd__(self, other):
        return _model_add(self, other)

    def __sub__(self, other):
        if isinstance(other, int) and other == 0: return self
        if not isinstance(other, FModule): raise TypeError
        return _model_sub(self, other)

    def __mul__(self, other):
        return _model_scale(self, other)

    def __rmul__(self, other):
        return self*other

    def __truediv__(self, other):
        return self*(1.0/other)

    def __pow__(self, power, modulo=None):
        return _model_norm(self, power)

    def __neg__(self):
        return _model_scale(self, -1.0)

    def __sizeof__(self):
        if not hasattr(self, '__size'):
            param_size = 0
            param_sum = 0
            for param in self.parameters():
                param_size += param.nelement() * param.element_size()
                param_sum += param.nelement()
            buffer_size = 0
            buffer_sum = 0
            for buffer in self.buffers():
                buffer_size += buffer.nelement() * buffer.element_size()
                buffer_sum += buffer.nelement()
            self.__size = param_size + buffer_size
        return self.__size

    def norm(self, p=2):
        r"""
        Args:
            p (float): p-norm

        Returns:
            the scale value of the p-norm of vectorized model parameters
        """
        return self**p

    def zeros_like(self):
        r"""
        Returns:
             a new model with the same architecture and all the parameters being set zero
        """
        return self*0

    def dot(self, other):
        r"""
        Args:
            other (Fmodule): the model with the same architecture of self

        Returns:
            the dot value of the two vectorized models
        """
        return _model_dot(self, other)

    def cos_sim(self, other):
        r"""
        Args:
            other (Fmodule): the model with the same architecture of self

        Returns:
            the cosine similarity value of the two vectorized models
        """
        return _model_cossim(self, other)

    def op_with_graph(self):
        self.ingraph = True

    def op_without_graph(self):
        self.ingraph = False

    def load(self, other):
        r"""
        Set the values of model parameters the same as the values of another model
        Args:
            other (Fmodule): the model with the same architecture of self
        """
        self.op_without_graph()
        self.load_state_dict(other.state_dict())
        return

    def freeze_grad(self):
        r"""
        All the gradients of the model parameters won't be computed after calling this method
        """
        for p in self.parameters():
            p.requires_grad = False

    def enable_grad(self):
        r"""
        All the gradients of the model parameters will be computed after calling this method
        """
        for p in self.parameters():
            p.requires_grad = True

    def zero_dict(self):
        r"""
        Set all the values of model parameters to be zero
        """
        self.op_without_graph()
        for p in self.parameters():
            p.data.zero_()

    def normalize(self):
        r"""
        Normalize the parameters of self to enable self.norm(2)=1
        """
        self.op_without_graph()
        self.load_state_dict((self/(self**2)).state_dict())

    def has_nan(self):
        r"""
        Check whether there is nan value in model's parameters
        Returns:
            res (bool): True if there is nan value
        """
        for p in self.parameters():
            if torch.any(torch.isnan(p)).item():
                return True
        return False

    def get_device(self):
        r"""
        Returns:
            the device of the tensors of this model
        """
        return next(self.parameters()).device

    def count_parameters(self, output=True):
        r"""
        Count the parameters for this model

        Args:
            output (bool): whether to output the information to the stdin (i.e. console)
        Returns:
            the number of all the parameters in this model
        """
        # table = pt.PrettyTable(["Modules", "Parameters"])
        total_params = 0
        for name, parameter in self.named_parameters():
            if not parameter.requires_grad:
                # table.add_row([name, 0])
                continue
            params = parameter.numel()
            # table.add_row([name, params])
            total_params += params
        # if output:
        #     print(table)
        #     print(f"TotalTrainableParams: {total_params}")
        return total_params

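As a quick illustration of how these overloaded operators compose in practice, the sketch below averages several models with the same architecture (a FedAvg-style mean). `LinearModel` is a hypothetical toy subclass defined only for this example.

```python
import torch
from flgo.utils.fmodule import FModule

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

# Three "client" models with the same architecture
clients = [LinearModel() for _ in range(3)]

# Model-level arithmetic: add the models and scale by 1/n (a FedAvg-style mean)
global_model = (clients[0] + clients[1] + clients[2]) / len(clients)

# The averaged parameters equal the element-wise mean of the client parameters
expected = torch.stack([c.mlp.weight for c in clients]).mean(dim=0)
print(torch.allclose(global_model.mlp.weight, expected))  # True
```
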
cos_sim(other)

Parameters:

    other (FModule, required): a model with the same architecture as self

Returns:

    The cosine similarity of the two vectorized models

Source code in flgo\utils\fmodule.py
def cos_sim(self, other):
    r"""
    Args:
        other (Fmodule): the model with the same architecture of self

    Returns:
        the cosine similarity value of the two vectorized models
    """
    return _model_cossim(self, other)

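To illustrate `cos_sim`: cosine similarity is invariant to positive scaling, so a model and a scaled copy of it should have similarity close to 1. `LinearModel` is a hypothetical toy subclass; the example assumes the result can be converted to a Python float.

```python
import torch
from flgo.utils.fmodule import FModule

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

m = LinearModel()
scaled = 2.0 * m  # same direction in parameter space, larger magnitude

print(float(m.cos_sim(scaled)))  # expected to be ~1.0
```
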
count_parameters(output=True)

Count the trainable parameters of this model

Parameters:

    output (bool, default True): whether to print the information to stdout (i.e. the console)

Returns:

    The total number of trainable parameters in this model

Source code in flgo\utils\fmodule.py
def count_parameters(self, output=True):
    r"""
    Count the parameters for this model

    Args:
        output (bool): whether to output the information to the stdin (i.e. console)
    Returns:
        the number of all the parameters in this model
    """
    # table = pt.PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in self.named_parameters():
        if not parameter.requires_grad:
            # table.add_row([name, 0])
            continue
        params = parameter.numel()
        # table.add_row([name, params])
        total_params += params
    # if output:
    #     print(table)
    #     print(f"TotalTrainableParams: {total_params}")
    return total_params

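A short example of `count_parameters`, showing that only parameters with requires_grad=True are counted; `LinearModel` is a hypothetical toy subclass.

```python
import torch
from flgo.utils.fmodule import FModule

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

m = LinearModel()
print(m.count_parameters())   # 4: the 2x2 weight matrix (bias is disabled)

m.freeze_grad()               # frozen parameters are no longer counted
print(m.count_parameters())   # 0
```
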
dot(other)

Parameters:

    other (FModule, required): a model with the same architecture as self

Returns:

    The dot product of the two vectorized models

Source code in flgo\utils\fmodule.py
def dot(self, other):
    r"""
    Args:
        other (Fmodule): the model with the same architecture of self

    Returns:
        the dot value of the two vectorized models
    """
    return _model_dot(self, other)

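The following sketch checks `dot` against the dot product of the flattened parameter tensors; `LinearModel` is a hypothetical toy subclass with a single weight matrix.

```python
import torch
from flgo.utils.fmodule import FModule

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

m1, m2 = LinearModel(), LinearModel()

# The model-level dot product should match the dot product of the flattened parameters
manual = (m1.mlp.weight * m2.mlp.weight).sum()
print(float(m1.dot(m2)), float(manual))  # the two values should match
```
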
enable_grad()

Gradients will be computed for all model parameters after calling this method

Source code in flgo\utils\fmodule.py
def enable_grad(self):
    r"""
    All the gradients of the model parameters will be computed after calling this method
    """
    for p in self.parameters():
        p.requires_grad = True

freeze_grad()

Gradients will no longer be computed for any model parameter after calling this method

Source code in flgo\utils\fmodule.py
def freeze_grad(self):
    r"""
    All the gradients of the model parameters won't be computed after calling this method
    """
    for p in self.parameters():
        p.requires_grad = False

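A common pattern in federated optimizers is to keep a frozen snapshot of the received global model as a reference point (e.g. for a proximal term) while training a separate copy. The sketch below illustrates `freeze_grad` and `enable_grad` under that assumption; `LinearModel` is a hypothetical toy subclass.

```python
import copy
import torch
from flgo.utils.fmodule import FModule

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

local = LinearModel()
snapshot = copy.deepcopy(local)
snapshot.freeze_grad()   # the snapshot is only a reference point; no gradients needed

# ... local training of `local` would happen here ...

local.enable_grad()      # ensure the trainable copy still requires gradients
print(all(p.requires_grad for p in local.parameters()))      # True
print(any(p.requires_grad for p in snapshot.parameters()))   # False
```
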
get_device()

Returns:

    The device on which this model's tensors reside

Source code in flgo\utils\fmodule.py
def get_device(self):
    r"""
    Returns:
        the device of the tensors of this model
    """
    return next(self.parameters()).device

has_nan()

Check whether any of the model's parameters contains a NaN value

Returns:

    res (bool): True if any parameter contains a NaN value

Source code in flgo\utils\fmodule.py
def has_nan(self):
    r"""
    Check whether there is nan value in model's parameters
    Returns:
        res (bool): True if there is nan value
    """
    for p in self.parameters():
        if torch.any(torch.isnan(p)).item():
            return True
    return False

load(other)

Set the values of the model parameters to those of another model

Parameters:

    other (FModule, required): a model with the same architecture as self
Source code in flgo\utils\fmodule.py
def load(self, other):
    r"""
    Set the values of model parameters the same as the values of another model
    Args:
        other (Fmodule): the model with the same architecture of self
    """
    self.op_without_graph()
    self.load_state_dict(other.state_dict())
    return

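For instance, `load` can be used to overwrite a client model's parameters with a global model's parameters; both `LinearModel` instances below are hypothetical and share the same architecture, as required.

```python
import torch
from flgo.utils.fmodule import FModule

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

global_model = LinearModel()
client_model = LinearModel()

client_model.load(global_model)  # copy the global parameter values into the client model
print(torch.equal(client_model.mlp.weight, global_model.mlp.weight))  # True
```
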
norm(p=2)

Parameters:

    p (float, default 2): the order of the norm

Returns:

    The scalar value of the p-norm of the vectorized model parameters

Source code in flgo\utils\fmodule.py
def norm(self, p=2):
    r"""
    Args:
        p (float): p-norm

    Returns:
        the scale value of the p-norm of vectorized model parameters
    """
    return self**p

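As a sanity check, `m.norm(2)` should agree with the L2 norm of the flattened parameter vector; `LinearModel` is a hypothetical toy subclass, and the example assumes the result can be converted to a Python float.

```python
import torch
from flgo.utils.fmodule import FModule

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

m = LinearModel()

flat = torch.cat([p.reshape(-1) for p in m.parameters()])
print(float(m.norm(2)), float(flat.norm(p=2)))  # the two values should match
```
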
normalize()

Normalize the parameters of self in place so that self.norm(2) == 1

Source code in flgo\utils\fmodule.py
def normalize(self):
    r"""
    Normalize the parameters of self to enable self.norm(2)=1
    """
    self.op_without_graph()
    self.load_state_dict((self/(self**2)).state_dict())

zero_dict()

Set all the model parameter values to zero in place

Source code in flgo\utils\fmodule.py
def zero_dict(self):
    r"""
    Set all the values of model parameters to be zero
    """
    self.op_without_graph()
    for p in self.parameters():
        p.data.zero_()

zeros_like()

Returns:

    A new model with the same architecture whose parameters are all zero

Source code in flgo\utils\fmodule.py
def zeros_like(self):
    r"""
    Returns:
         a new model with the same architecture and all the parameters being set zero
    """
    return self*0

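`zeros_like` is a convenient starting point for weighted aggregation, as in the sketch below; `LinearModel` and the client weights are hypothetical.

```python
import torch
from flgo.utils.fmodule import FModule

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

clients = [LinearModel() for _ in range(3)]
weights = [0.5, 0.3, 0.2]

# Start from an all-zero model and accumulate a weighted combination of the clients
agg = clients[0].zeros_like()
for w, c in zip(weights, clients):
    agg = agg + w * c

expected = sum(w * c.mlp.weight for w, c in zip(weights, clients))
print(torch.allclose(agg.mlp.weight, expected))  # True
```
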
cos_sim(m1, m2)

The cosine similarity of the two models, res = m1·m2 / (||m1|| * ||m2||)

Parameters:

    m1 (FModule, required): model 1
    m2 (FModule, required): model 2

Returns:

    The cosine similarity of the two models

Source code in flgo\utils\fmodule.py
def cos_sim(m1, m2):
    r"""
    The cosine similarity value of the two models res=m1·m2/(||m1||*||m2||)

    Args:
        m1 (FModule): model 1
        m2 (FModule): model 2

    Returns:
        The cosine similarity value of the two models
    """
    return m1.cos_sim(m2)

dot(m1, m2)

The dot product of the two models, res = m1·m2

Parameters:

    m1 (FModule, required): model 1
    m2 (FModule, required): model 2

Returns:

    The dot product of the two models

Source code in flgo\utils\fmodule.py
def dot(m1, m2):
    r"""
    The dot value of the two models res = m1·m2

    Args:
        m1 (FModule): model 1
        m2 (FModule): model 2

    Returns:
        The dot value of the two models
    """
    return m1.dot(m2)

element_wise_func(m, func)

Apply an element-wise function to the parameters of a model

Parameters:

    m (FModule, required): the model
    func (required): an element-wise function applied to each parameter tensor

Returns:

    A new model whose parameters satisfy m_i = func(m_i)

Source code in flgo\utils\fmodule.py
def element_wise_func(m, func):
    r"""
    The element-wise function on this model

    Args:
        m (FModule): the model
        func: element-wise function

    Returns:
        The new model whose parameters satisfy mi=func(mi)
    """
    if m is None: return None
    res = m.__class__().to(m.get_device())
    if m.ingraph:
        res.op_with_graph()
        ml = get_module_from_model(m)
        for md in ml:
            rd = _modeldict_element_wise(md._parameters, func)
            for l in md._parameters.keys():
                md._parameters[l] = rd[l]
    else:
        _modeldict_cp(res.state_dict(), _modeldict_element_wise(m.state_dict(), func))
    return res

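The sketch below uses `element_wise_func` to clamp every parameter into a fixed range without touching the original model; `LinearModel` and the clipping threshold are hypothetical choices made only for illustration.

```python
import torch
from flgo.utils.fmodule import FModule, element_wise_func

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

m = LinearModel()

# Clamp every parameter into [-0.1, 0.1]; a new model is returned and m itself is unchanged
clipped = element_wise_func(m, lambda t: torch.clamp(t, -0.1, 0.1))
print(bool(clipped.mlp.weight.abs().max() <= 0.1))  # True
```
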
exp(m)

The element-wise exponential res = exp(m), where every model parameter satisfies m_i = exp(m_i)

Parameters:

    m (FModule, required): the model

Returns:

    A new model whose parameters satisfy m_i = exp(m_i)

Source code in flgo\utils\fmodule.py
def exp(m):
    r"""
    The element-wise res=exp(m) where all the model parameters satisfy mi=exp(mi)

    Args:
        m (FModule): the model

    Returns:
        The new model whose parameters satisfy mi=exp(mi)
    """
    return element_wise_func(m, torch.exp)

get_module_from_model(model, res=None)

Walk through all the sub-modules of a model and collect its leaf modules that hold parameters into a list

Parameters:

    model (FModule, required): the model
    res (None, default None): internal accumulator; should be left as None

Returns:

    The list of leaf sub-modules of the model that hold parameters

Source code in flgo\utils\fmodule.py
def get_module_from_model(model, res = None):
    r"""
    Walk through all the sub modules of a model and return them as a list

    Args:
        model (FModule): model
        res (None): should be remained None

    Returns:
        The list of all the sub-modules of a model
    """
    if res==None: res = []
    ch_names = [item[0] for item in model.named_children()]
    if ch_names==[]:
        if model._parameters:
            res.append(model)
    else:
        for name in ch_names:
            get_module_from_model(model.__getattr__(name), res)
    return res

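For example, on a hypothetical two-layer model, `get_module_from_model` returns the two parameterized leaf modules.

```python
import torch
from flgo.utils.fmodule import FModule, get_module_from_model

class TwoLayer(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 8)
        self.fc2 = torch.nn.Linear(8, 2)

m = TwoLayer()
leaves = get_module_from_model(m)
print([type(leaf).__name__ for leaf in leaves])  # ['Linear', 'Linear']
```
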
log(m)

The element-wise logarithm res = log(m), where every model parameter satisfies m_i = log(m_i)

Parameters:

    m (FModule, required): the model

Returns:

    A new model whose parameters satisfy m_i = log(m_i)

Source code in flgo\utils\fmodule.py
def log(m):
    r"""
    The element-wise res=log(m) where all the model parameters satisfy mi=log(mi)

    Args:
        m (FModule): the model

    Returns:
        The new model whose parameters satisfy mi=log(mi)
    """
    return element_wise_func(m, torch.log)

normalize(m)

Return a new model that is the normalized version of the input model, m = m / ||m||_2

Parameters:

    m (FModule, required): the model

Returns:

    A new model that is the L2-normalized version of the input model

Source code in flgo\utils\fmodule.py
def normalize(m):
    r"""
    The new model that is the normalized version of the input model m=m/||m||_2

    Args:
        m (FModule): the model

    Returns:
        The new model that is the normalized version of the input model
    """
    return m/(m**2)

with_multi_gpus(func)

Decorate methods whose first parameter (after self) is a model: the wrapper moves the model and all tensor arguments to self.device before the call, then moves the model and any tensor results back to the model's original device afterwards

Source code in flgo\utils\fmodule.py
def with_multi_gpus(func):
    r"""
    Decorate functions whose first parameter is model to carry out all the operations on the same device
    """
    def cal_on_personal_gpu(self, model, *args, **kargs):
        origin_device = model.get_device()
        # transfer to new device
        new_args = []
        new_kargs = {}
        for arg in args:
            narg = arg.to(self.device) if hasattr(arg, 'get_device') or hasattr(arg, 'device') else arg
            new_args.append(narg)
        for k,v in kargs.items():
            nv = v.to(self.device) if hasattr(v, 'get_device') or hasattr(v, 'device') else v
            new_kargs[k] = nv
        model.to(self.device)
        # calculating
        res = func(self, model, *tuple(new_args), **new_kargs)
        # transter to original device
        model.to(origin_device)
        if res is not None:
            if type(res)==dict:
                for k,v in res.items():
                    nv = v.to(origin_device) if hasattr(v, 'get_device') or hasattr(v, 'device') else v
                    res[k] = nv
            elif type(res)==tuple or type(res)==list:
                new_res = []
                for v in res:
                    nv = v.to(origin_device) if hasattr(v, 'get_device') or hasattr(v, 'device') else v
                    new_res.append(nv)
                if type(res)==tuple:
                    res = tuple(new_res)
            else:
                res = res.to(origin_device) if hasattr(res, 'get_device') or hasattr(res, 'device') else res
        return res
    return cal_on_personal_gpu
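
The sketch below shows one way the decorator might be used: the owning object is expected to carry a `device` attribute, and the wrapper transparently moves the model and tensor arguments to that device and back. `LinearModel` and `Worker` are hypothetical names introduced only for this example.

```python
import torch
from flgo.utils.fmodule import FModule, with_multi_gpus

class LinearModel(FModule):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(2, 2, bias=False)

class Worker:  # hypothetical object owning a computation device
    def __init__(self, device):
        self.device = device

    @with_multi_gpus
    def evaluate(self, model, batch):
        # Inside the wrapper, `model` and `batch` have already been moved to self.device
        return model.mlp(batch).sum()

worker = Worker(torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))
model = LinearModel()  # created on CPU; the caller never has to move it manually
loss = worker.evaluate(model, torch.randn(4, 2))
print(model.get_device(), loss.device)  # both reported on the model's original device
```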