discrepancy between Tensor flow model and Pyorch model - python
I trained a Unet model on the same dataset using both Tensor flow model and Pyorch. Both model showed a fine loss curve for training data, but the Pyorch validation data keeps zigzagging.
I think I'm good on Tensor flow but I'm new to Pyorch. Please, check if I made a mistake on Pyorch .
Below is the Tensor flow:
and this the Pyorch
The below code for Tensor flow :
class DataGen(keras.utils.Sequence):
def __init__(self, ids, path, batch_size=8, image_size=128):
self.ids = ids
self.path = path
self.batch_size = batch_size
self.image_size = image_size
def __load__(self, id_name):
## Path
#image_path = os.path.join(self.path, id_name, "images", id_name) + ".png"
#/content/drive/My Drive/mycolab/training/ patient0001 / patient0001
image_path = os.path.join(self.path, id_name, id_name,) + "_2CH_ED.mhd"
#mask_path = os.path.join(self.path, id_name, "masks/")
mask_path = os.path.join(self.path, id_name, id_name,) + "_2CH_ED_gt.mhd"
# not required all_masks = os.listdir(mask_path)
## Reading Image
#image = cv2.imread(image_path, 1)
my_img1= io.imread(image_path , plugin='simpleitk')
#--------------image = cv2.merge((image,image,image))
#image =convert_to_3_channel( cv2.resize(image, (self.image_size, self.image_size)))
image = cv2.resize(image, (self.image_size, self.image_size))
#image = cv2.merge((image,image,image,image))
#image = cv2.merge((image,image,image,image))
# same for mask
my_mask1= io.imread(mask_path , plugin='simpleitk')
mask= cv2.resize( mask, (self.image_size, self.image_size))
#one_hot_tensor= K.one_hot(K.cast( tf.convert_to_tensor(mask, dtype=tf.int32) , 'int32'), num_classes=4)
#mask=np.asarray(one_hot_tensor, np.float32)
#mask=np.asarray(one_hot_tensor, np.int32)
masks = [(mask == v) for v in range(4) ] #self.class_values]
mask = np.stack(masks, axis=-1).astype('float')
#masks = [(mask == v) for v in range(4)]#self.class_values]
#print("mask ttttttttttt5555555555:", type(masks))
#mask = np.stack(masks, axis=-1).astype('float')
# add background if mask is not binary
#if mask.shape[-1] != 1:
# #print("adding background if mask is not binary******************++++++++__________________$$")
# background = 1 - mask.sum(axis=-1, keepdims=True)
# mask = np.concatenate((mask, background), axis=-1)
#mask = np.zeros((self.image_size, self.image_size, 1))
## Reading Masks
#for name in all_masks:
# _mask_path = mask_path + name
# _mask_image = cv2.imread(_mask_path, -1)
# _mask_image = cv2.resize(_mask_image, (self.image_size, self.image_size)) #128x128
# _mask_image = np.expand_dims(_mask_image, axis=-1)
# mask = np.maximum(mask, _mask_image)
#print("image &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
#print("image &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
#print("image &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
## Normalizaing
#-------------image = image/255.0
#---- check if tis correct mask = mask/255.0
return image, mask# image.astype('float'), mask.astype('float')
def __getitem__(self, index):
#print("Index *************************************:", index )
if(index+1)*self.batch_size > len(self.ids):
self.batch_size = len(self.ids) - index*self.batch_size
files_batch = self.ids[index*self.batch_size : (index+1)*self.batch_size]
image = []
mask = []
for id_name in files_batch:
_img, _mask = self.__load__(id_name)
image.append(_img ) #.astype('float'))
mask.append(_mask ) #.astype('float'))
image = np.array(image)
mask = np.array(mask)
#print("image shape%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%",image.shape)
#print("mask shape%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%",mask.shape)
return image , mask# image.astype('float'), mask.astype('float')
def on_epoch_end(self):
def __len__(self):
return int(np.ceil(len(self.ids)/float(self.batch_size)))
# you may need to change variables names
image_size = 256
train_path ="path" #"dataset/stage1_train/"
epochs = 10 #70 # 5 # paper require 30
batch_size =1#32#1# 8
num_class = 4
print("train_ids length:", len(train_ids))
## Validation Data Size
val_data_size = 10
valid_ids = train_ids[:val_data_size]
my_slice_index=2 # class index that we are calauting
## Training Ids
train_ids = next(os.walk(train_path))[1]
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
p = keras.layers.MaxPool2D((2, 2), (2, 2))(c)
return c, p
def down_block_test(x, filters, kernel_size=(3, 3), padding="same", strides=1):
residual = x
print("down_block: residual size", residual.shape)
c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
print("down_block: c size", c.shape)
c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
print("down_block: c2 size", c.shape)
print("down_block: residual.shape[1]", residual.shape[1])
print("down_block: residual.shape[3]", residual.shape[3])
if residual.shape[3] != c.shape[3]:
residual = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(residual)
c += residual
p = keras.layers.MaxPool2D((2, 2), (2, 2))(c)
return c, p
def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
us = keras.layers.UpSampling2D((2, 2))(x)
concat = keras.layers.Concatenate()([us, skip])
c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(concat)
c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
return c
def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
c = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
return c
#------------------------------------ model in the paper as Unet 1 ----------------------------------------
def UNet_test():
f = [16, 32, 64, 128, 256]
#inputs = keras.layers.Input((image_size, image_size, 3))
inputs = keras.layers.Input((image_size, image_size, 1))
p0 = inputs
tf.print("********************************************p0: ")
c1, p1 = down_block(p0, f[1]) #264 -> 128
print("c1",c1.shape )
print("p1",p1.shape )
c2, p2 = down_block(p1, f[1]) #128 -> 64
print("c2",c2.shape )
print("p2",p2.shape )
c3, p3 = down_block(p2, f[2]) #64 -> 32
print("c3",c3.shape )
print("p3",p3.shape )
c4, p4 = down_block(p3, f[3]) #32->16
print("c4",c4.shape )
print("p4",p4.shape )
c5, p5 = down_block(p4, f[3]) #16->8
print("c5",c5.shape )
print("p5",p5.shape )
bn = bottleneck(p5, f[3]) # 8
u1 = up_block(bn, c5, f[3]) #8 -> 16
u2 = up_block(u1, c4, f[3]) #16 -> 32
u3 = up_block(u2, c3, f[2]) #32 -> 64
u4 = up_block(u3, c2, f[1]) #64 -> 128
u5 = up_block(u4, c1, f[0]) #128 -> 256
#outputs = keras.layers.Conv2D(1, (1, 1), padding="same", activation="sigmoid")(u4)
outputs = keras.layers.Conv2D(num_class, (1, 1), padding="same", activation="softmax")(u5)
#outputs = keras.layers.Conv2D(1, (1, 1), padding="same", activation="softmax")(u4)
model = keras.models.Model(inputs, outputs)
return model
import segmentation_models as sm
#---------------------model =new_model(Resmodel,'sigmoid')#model_standard() # UNet_1()
LR = 0.0001
optim = keras.optimizers.Adam(LR)
dice_loss_se2 = sm.losses.DiceLoss()
mae = tf.keras.losses.MeanAbsoluteError( )
metrics = [ mae,sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5) , dice_loss_se2]
model.compile(optimizer=optim,loss= dice_loss_se2,metrics= metrics)
train_gen = DataGen(train_ids, train_path, image_size=image_size, batch_size=batch_size)
valid_gen = DataGen(valid_ids, train_path, image_size=image_size, batch_size=batch_size)
train_steps = len(train_ids)//batch_size
valid_steps = len(valid_ids)//batch_size
history =model.fit_generator(train_gen, validation_data=valid_gen, steps_per_epoch=train_steps, validation_steps=valid_steps,
and below code for pytorch:
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
self.maxpool_conv = nn.Sequential(
DoubleConv(in_channels, out_channels)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
class UNet_standard(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=False):
super(UNet_standard, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
class DiceLoss(nn.Module):
def __init__(self, n_classes):
super(DiceLoss, self).__init__()
self.n_classes = n_classes
def _one_hot_encoder(self, input_tensor):
tensor_list = []
for i in range(self.n_classes):
temp_prob = input_tensor == i # * torch.ones_like(input_tensor)
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()
def _dice_loss(self, score, target):
target = target.float()
smooth = 1e-5
intersect = torch.sum(score * target)
y_sum = torch.sum(target * target)
z_sum = torch.sum(score * score)
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
loss = 1 - loss
return loss
def forward(self, inputs, target, weight=None, softmax=False):
if softmax:
inputs = torch.softmax(inputs, dim=1)
target = self._one_hot_encoder(target)
if weight is None:
weight = [1] * self.n_classes
assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
class_wise_dice = []
loss = 0.0
for i in range(0, self.n_classes):
dice = self._dice_loss(inputs[:, i], target[:, i])
class_wise_dice.append(1.0 - dice.item())
loss += dice * weight[i]
return loss / self.n_classes
def iou_score(output, target):
smooth = 1e-5
if torch.is_tensor(output):
output = torch.sigmoid(output).data.cpu().numpy()
if torch.is_tensor(target):
target = target.data.cpu().numpy()
output_ = output > 0.5
target_ = target > 0.5
intersection = (output_ & target_).sum()
union = (output_ | target_).sum()
return (intersection + smooth) / (union + smooth)
class easy_Synapse_dataset(Dataset):
def __init__(self, split, transform=None):
self.transform = transform # using transform in torch!
self.split = split
if self.split == "train":
self.sample_list =next(os.walk(use_path))[1] #open(os.path.join(list_dir, self.split+'.txt')).readlines()
def __len__(self):
return len(self.sample_list)
def __getitem__(self, idx):
if self.split == "train":
if self.split == "train":
slice_name = self.sample_list[idx].strip('\n')
data_path = os.path.join(self.data_dir, slice_name+'.npz')
data = np.load(data_path)
image, label = data['image'], data['label']
vol_name = self.sample_list[idx].strip('\n')
filepath = self.data_dir + "/{}.npy.h5".format(vol_name)
data = h5py.File(filepath)
image, label = data['image'][:], data['label'][:]
image_path = os.path.join(use_path, self.sample_list[idx], self.sample_list[idx],) + "_2CH_ED.mhd"
#mask_path = os.path.join(self.path, id_name, "masks/")
mask_path = os.path.join(use_path, self.sample_list[idx], self.sample_list[idx] ,) + "_2CH_ED_gt.mhd"
# not required all_masks = os.listdir(mask_path)
## Reading Image
#image = cv2.imread(image_path, 1)
my_img1= iio.imread(image_path , plugin='simpleitk')
#--------------image = cv2.merge((image,image,image))
#image =convert_to_3_channel( cv2.resize(image, (self.image_size, self.image_size)))
image = cv2.resize(image, (img_size, img_size))
#image = cv2.merge((image,image,image))
#image = np.moveaxis(image , 2, 0)
#image = cv2.merge((image,image,image,image))
# same for mask
my_mask1= iio.imread(mask_path , plugin='simpleitk')
mask= cv2.resize( mask, (img_size, img_size))
#one_hot_tensor= K.one_hot(K.cast( tf.convert_to_tensor(mask, dtype=tf.int32) , 'int32'), num_classes=4)
#mask=np.asarray(one_hot_tensor, np.float32)
#mask=np.asarray(one_hot_tensor, np.int32)
#masks = [(mask == v) for v in range(4) ] #self.class_values]
#mask = np.stack(masks, axis=-1).astype('float')
mask = torch.Tensor(mask)
mask=torch.nn.functional.one_hot(mask.to(torch.int64) , num_classes=4)
mask = mask.to(torch.float)
mask = mask.permute(2, 0, 1)
#print("image ", image.shape)
#print("mask ", mask.shape)
transform = transforms.Compose([
image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
label = torch.from_numpy(label.astype(np.float32))
#if self.split != "train":
# image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
# label = torch.from_numpy(label.astype(np.float32))
sample = {'image': image, 'label': label}
if self.transform:
sample = self.transform(sample)
#print("sample[image].size() ", sample["image"].shape)
#print("sample[label].size() ", sample["label"].shape)
return sample# sample#transform(image), transform(mask).squeeze(0) #sample
img_size=256# 224
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
db_train = easy_Synapse_dataset( split="train",
transform=None) #transforms.Compose(
#[RandomGenerator(output_size=[img_size, img_size])]))
print("The length of train set is: {}".format(len(db_train)))
train_loader = DataLoader(db_train, batch_size=1, shuffle=True)#, num_workers=8, pin_memory=True,
db_test = easy_Synapse_dataset( split="test_vol")
val_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
# Now we can create a model and send it at once to the device
#----model = ManualLinearRegression().to(device)
config_vit = CONFIGS['R50-ViT-B_16']
config_vit.n_classes = 4#args.num_classes
config_vit.n_skip =3 # args.n_skip
if 'R50-ViT-B_16'.find('R50') != -1:
config_vit.patches.grid = (int(img_size / vit_patches_size), int(img_size / vit_patches_size))
model =UNet_standard( 1, 4).to(device)
# We can also inspect its parameters using its state_dict
lr =0.01# 1e-1
n_epochs = 10
loss_fn =DiceLoss(4)#num_classes) nn.MSELoss(reduction='mean')
optimizer = optim.SGD(model.parameters(), lr=lr)#optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
def make_train_step(model, loss_fn, optimizer):
# Builds function that performs a step in the train loop
def train_step(x, y):
# Sets model to TRAIN mode
# Makes predictions
yhat = model(x)
#d0, d1, d2, d3= model(x)
# Computes loss
#loss =muti_bce_loss_fusion2(d0, d1,d2, d3, y)# loss_fn(yhat, y, softmax=True) #loss_fn(y, yhat)
loss = loss_fn(yhat, y, softmax=True)
# Computes gradients
# Updates parameters and zeroes gradients
# Returns the loss
return loss.item()
# Returns the function that will be called inside the train loop
return train_step
# Creates the train_step function for our model, loss function and optimizer
train_step = make_train_step(model, loss_fn, optimizer)
y_val_average_loss = []
y_average_loss = []
x_epoch= []
# For each epoch...
for epoch in tqdm(range(n_epochs)):
losses = []
val_losses = []
iou_metric = []
#for x_batch, y_batch in train_loader:
for i_batch, sampled_batch in enumerate(train_loader):
x_batch, y_batch = sampled_batch['image'], sampled_batch['label']
#x_batch, y_batch =image_batch.cpu(), label_batch.cpu()
# the dataset "lives" in the CPU, so do our mini-batches
# therefore, we need to send those mini-batches to the
# device where the model "lives"
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)
#print( x_batch.shape , " ",y_batch.shape )
loss = train_step(x_batch, y_batch)
#print("loss = ",loss )
#print('loss : %f' % (sum(losses) / len(losses) ))
avg=sum(losses) / len(losses)
y_average_loss .append (avg)
print('loss : %f' % (avg) )
losses = [] #clear
with torch.no_grad():
#for x_val, y_val in val_loader:
for i_batch, sampled_batch2 in enumerate(val_loader):
x_val, y_val = sampled_batch2['image'], sampled_batch2['label']
x_val = x_val.to(device)
y_val = y_val.to(device)
d3 = model(x_val)
#d0, d1, d2, d3 = model(x_val)
val_loss =loss_fn(d3, y_val, softmax=True) # loss_fn(y_val, yhat)
iou = iou_score(d3, y_val)
#print('Validation loss : %f' % (sum(val_losses) / len(val_losses) ))
print('Validation iou : %f' % (sum(iou_metric) / len(iou_metric) ))
val_avg=sum(val_losses) / len(val_losses)
y_val_average_loss .append (val_avg)
x_epoch.append (epoch)
print('Validation loss : %f' % ( val_avg ))
val_losses = [] # clear
iou_metric = []
# Checks model's parameters
I fixed the issue by using different Unet implementation from https://github.com/usuyama/pytorch-unet/blob/master/pytorch_unet.py
