diff --git a/src/model/layers.py b/src/model/layers.py index 66ba056..3e2b87d 100644 --- a/src/model/layers.py +++ b/src/model/layers.py @@ -6,7 +6,7 @@ import torch.nn.functional as F class Phi(nn.Module): - def __init__(self, cnn, ff, norm): + def __init__(self, cnn, ff, norm=None): super(Phi, self).__init__() self.cnn = cnn self.ff = ff