Jax / Neural Tangents `linearize` induces CUDA_ERROR_OUT_OF_MEMORY - python

Cross-posting from GitHub: https://github.com/google/neural-tangents/issues/144
We're trying to fine-tune a linearized Vision Transformer by adapting code from https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb.
We're running into a really puzzling problem: when we load a model, we can train it, and after we linearize it, the pre-linearized model still trains fine. However, when we try using the linearized model, we get:
RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
This error emerges regardless of whether we use 1 GPU or multiple, and regardless of whether the batch size is large (512) or small (1).
We manually tested that neither a forward pass nor a backward pass raises an error on its own. We suspect that the error might arise from the following code (although we could be wrong!):
Their code:
def make_update_fn(*, apply_fn, accum_steps, lr_fn):
  """Returns update step for data parallel training."""

  def update_fn(opt, step, batch, rng):

    _, new_rng = jax.random.split(rng)
    # Bind the rng key to the device id (which is unique across hosts)
    # Note: This is only used for multi-host training (i.e. multiple computers
    # each with multiple accelerators).
    dropout_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    def cross_entropy_loss(*, logits, labels):
      logp = jax.nn.log_softmax(logits)
      return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
      logits = apply_fn(
          dict(params=params),
          rngs=dict(dropout=dropout_rng),
          inputs=images,
          train=True)
      return cross_entropy_loss(logits=logits, labels=labels)

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)
    g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
    l = jax.lax.pmean(l, axis_name='batch')

    opt = opt.apply_gradient(g, learning_rate=lr_fn(step))
    return opt, l, new_rng

  return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))
That function is then called via:
# Check out train.make_update_fn in the editor on the right side for details.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
update_fn_repl = train.make_update_fn(
    apply_fn=vit_apply, accum_steps=accum_steps, lr_fn=lr_fn)
# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements the gradient clipping.
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)
The training loop where the memory error arises:
losses = []
lrs = []
# Completes in ~20 min on the TPU runtime.
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):
  opt_repl, loss_repl, update_rng_repl = update_fn_repl(
      opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)  # ERROR IS HERE
  losses.append(loss_repl[0])
  lrs.append(lr_fn(step))
In order to linearize the ViT, we do the following:
def vit_apply(params, input):
  return model.apply(dict(params=params), input, train=True)
f_lin = nt.linearize(vit_apply, params)
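In case it matters, here is roughly how we substitute the linearized function into the training step. This is only a sketch built on the snippets above (the adapter and the name vit_apply_lin are illustrative; the exact wiring in our code may differ slightly):
f_lin = nt.linearize(vit_apply, params)  # f_lin keeps the (params, inputs) signature of vit_apply

def vit_apply_lin(variables, *, rngs=None, inputs, train=True):
  # rngs/train are accepted to match the apply_fn call site inside update_fn, but are unused here.
  del rngs, train
  return f_lin(variables['params'], inputs)

update_fn_repl = train.make_update_fn(
    apply_fn=vit_apply_lin, accum_steps=accum_steps, lr_fn=lr_fn)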

Related

PyTorch: "one of the variables needed for gradient computation has been modified by an inplace operation"

I'm training a PyTorch RNN on a text file of song lyrics to predict the next character given a character.
Here's how my RNN is defined:
import torch.nn as nn
import torch.optim

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        # from input, previous hidden state to new hidden state
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        # from input, previous hidden state to output
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        # softmax on output
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        # get new hidden state
        hidden = self.i2h(combined)
        # get output
        output = self.i2o(combined)
        # apply softmax
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

rnn = RNN(input_size=num_chars, hidden_size=200, output_size=num_chars)
criterion = nn.NLLLoss()
lr = 0.01
optimizer = torch.optim.AdamW(rnn.parameters(), lr=lr)
Here's my training function:
def train(train, target):
    hidden = rnn.initHidden()
    loss = 0
    for i in range(len(train)):
        optimizer.zero_grad()
        # get output, hidden state from rnn given input char, hidden state
        output, hidden = rnn(train[i].unsqueeze(0), hidden)
        # returns the index with '1' - identifying the index of the right character
        target_class = (target[i] == 1).nonzero(as_tuple=True)[0]
        loss += criterion(output, target_class)
        loss.backward(retain_graph=True)
        optimizer.step()
        print("done " + str(i) + " loop")
    return output, loss.item() / train.size(0)
When I run my training function, I get this error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [274, 74]], which is output 0 of TBackward, is at version 5; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Interestingly, it makes it through two complete loops of the training function before giving me that error.
Now, when I remove the retain_graph = True from loss.backward(), I get this error:
RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
It shouldn't be trying to go backward through the graph multiple times here. Perhaps the graph is not getting cleared between training loops?
The issue is that you are accumulating your loss values (and, with them, their associated computation graphs) in the variable loss, here:
loss += criterion(output, target_class)
In turn, this means that at every iteration you are trying to backpropagate through the current loss value as well as through all the previous ones computed in earlier forward passes. In this particular instance, where you are looping through your dataset, that isn't the right thing to do.
A simple fix is to accumulate the underlying scalar value of loss using item(), rather than the tensor itself, and to backpropagate only on the current loss tensor:
total_loss = 0
for i in range(len(train)):
    optimizer.zero_grad()
    output, hidden = rnn(train[i].unsqueeze(0), hidden)
    target_class = (target[i] == 1).nonzero(as_tuple=True)[0]
    loss = criterion(output, target_class)
    loss.backward()
    total_loss += loss.item()
Since you are updating the model's parameters straight after the backpropagation, you don't need to retain the graph in memory.
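Note also that the hidden state carries graph history from earlier time steps. If you call backward() at every step, you will most likely also need to detach it, otherwise the next backward pass tries to reach into a graph that has already been freed. A sketch of the full loop under that assumption:
total_loss = 0
hidden = rnn.initHidden()
for i in range(len(train)):
    optimizer.zero_grad()
    output, hidden = rnn(train[i].unsqueeze(0), hidden)
    target_class = (target[i] == 1).nonzero(as_tuple=True)[0]
    loss = criterion(output, target_class)
    loss.backward()
    optimizer.step()
    # Cut the graph here so the next iteration does not backpropagate through this one
    hidden = hidden.detach()
    total_loss += loss.item()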

Debugging Tensorflow 2.0: Printing in a tf.function that crashes

I am trying to debug a relatively complex custom training method that uses custom loss functions, etc. In particular, I am trying to debug an issue in a custom training step which is compiled into a TensorFlow @tf.function and fitted as part of a compiled Keras model. I want to be able to print out an intermediate tensor value in a function call that is crashing. The difficulty is that tensors inside a @tf.function are graph values and aren't evaluated immediately, and since the function crashes during evaluation, it seems the values are never actually calculated. Here is a simple example:
class debug_model(tf.keras.Model):
    def __init__(self, width, depth, insize, outsize, batch_size):
        super(debug_model, self).__init__()
        self.width = width
        self.depth = depth
        self.insize = insize
        self.outsize = outsize
        self.net = tf.keras.models.Sequential()
        self.net.add(tf.keras.Input(shape=(insize,)))
        for i in range(depth):
            self.net.add(tf.keras.layers.Dense(width, activation='swish'))
        self.net.add(tf.keras.layers.Dense(outsize))

    def call(self, ipts):
        return self.net(ipts)

    @tf.function
    def train_step(self, data):
        ipt, target = data
        with tf.GradientTape(persistent=True) as tape_1:
            tape_1.watch(ipt)
            y = self(ipt)
            tf.print('y:', y)
            assert False
            loss = tf.keras.losses.MAE(target, y)
        trainable_vars = self.trainable_variables
        loss_grad = tape_1.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(loss_grad, trainable_vars))
        self.compiled_metrics.update_state(target, y)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
If you compile this model with some data of your choice and run it:
train_set = tf.data.Dataset.from_tensor_slices(data_tuple).batch(opt.batchSize)
train_set.shuffle(buffer_size = trainpoints)
model = debug_model(opt.width,opt.depth,in_size,out_size,batchSize)
optimizer = tf.keras.optimizers.Adam(learning_rate=opt.lr)
lr_sched = lambda epoch, lr: lr * 0.95**(1 / (8))
cb_scheduler = tf.keras.callbacks.LearningRateScheduler(schedule = lr_sched, verbose = 1)
model.build((None,1))
model.summary()
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.MeanAbsoluteError(),
              )
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(path, verbose=2),
    cb_scheduler,
    tf.keras.callbacks.CSVLogger(path + 'log.csv')
]
hist = model.fit(train_set, epochs=opt.nEpochs, callbacks=callbacks)
If you run this, you will see that it exits due to the assertion error without printing anything. Is there a way I can force this tensor to be evaluated so that I can print it?
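For what it's worth, one workaround I have seen suggested (a sketch; tf.config.run_functions_eagerly requires TF 2.3+, older versions have tf.config.experimental_run_functions_eagerly) is to disable graph compilation while debugging, so that tf.print sees concrete values before the assertion fires:
import tensorflow as tf

# Debugging only: run all tf.function-decorated code eagerly (much slower).
tf.config.run_functions_eagerly(True)

# Per-model alternative:
# model.compile(optimizer=optimizer,
#               loss=tf.keras.losses.MeanAbsoluteError(),
#               run_eagerly=True)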

Model behaves differently after saving and loading

I want to use torch.save() to save a trained model for inference. However, whether I reload it with load_state_dict() or torch.load(), I can't recover the saved model: the loss computed by the loaded model is simply different from the loss computed by the saved model.
The relevant Libraries:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
The model:
class nn_block(nn.Module):
    def __init__(self, feats_dim):
        super(nn_block, self).__init__()
        self.linear = nn.Linear(feats_dim, feats_dim)
        self.bn = nn.BatchNorm1d(feats_dim)
        self.softplus1 = nn.Softplus()
        self.softplus2 = nn.Softplus()

    def forward(self, rep_mat):
        transformed_mat = self.linear(rep_mat)
        transformed_mat = self.bn(transformed_mat)
        transformed_mat = self.softplus1(transformed_mat)
        transformed_mat = self.softplus2(transformed_mat + rep_mat)
        return transformed_mat


class test_nn(nn.Module):
    def __init__(self, in_feats, feats_dim, num_conv, num_classes):
        super(test_nn, self).__init__()
        self.linear1 = nn.Linear(in_feats, feats_dim)
        self.convs = [nn_block(feats_dim) for _ in range(num_conv)]
        self.linear2 = nn.Linear(feats_dim, num_classes)
        self.softmax = nn.Softmax()

    def forward(self, rep_mat):
        h = self.linear1(rep_mat)
        for conv_func in self.convs:
            h = conv_func(h)
        h = self.linear2(h)
        h = self.softmax(h)
        return h
Train, save, and reload a model:
# fake a classification task
num_classes = 2; input_dim = 8
one = np.random.multivariate_normal(np.zeros(input_dim),np.eye(input_dim),20)
two = np.random.multivariate_normal(np.ones(input_dim),np.eye(input_dim),20)
inputs = np.concatenate([one, two], axis=0)
labels = np.concatenate([np.zeros(20), np.ones(20)])
inputs = Variable(torch.Tensor(inputs))
labels = torch.LongTensor(labels)
# build a model
net = test_nn(input_dim, 5, 2, num_classes)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
net.train()
losses = []
best_score = 1e10
for epoch in range(25):
    preds = net(inputs)
    loss = F.cross_entropy(preds, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    state_dict = {'state_dict': net.state_dict()}
    if loss.item() - best_score < -1e-4:
        # save only parameters
        torch.save(state_dict, 'model_params.torch')
        # save the whole model
        torch.save(net, 'whole_model.torch')
    best_score = np.min([best_score, loss.item()])
    losses.append(loss.item())
net_params = test_nn(input_dim, 5, 2, num_classes)
net_params.load_state_dict(torch.load('model_params.torch')['state_dict'])
net_params.eval()
preds_params = net_params(inputs)
loss_params = F.cross_entropy(preds_params, labels)
print('reloaded params %.4f %.4f' % (loss_params.item(), np.min(losses)))
net_whole = torch.load('whole_model.torch')
net_whole.eval()
preds_whole = net_whole(inputs)
loss_whole = F.cross_entropy(preds_whole, labels)
print('reloaded whole %.4f %.4f' % (loss_whole.item(), np.min(losses)))
As you can see by running the code, the losses computed by the two loaded models are different, even though the two loaded models are exactly the same. Not only are the two losses different from each other, they also differ from the loss computed by the best model that was saved in the first place.
Why can this happen?
The state dict contains every parameter (nn.Parameter) and every buffer (similar to a parameter, but not meant to be trained/optimised) that has been registered on the module and on all of its submodules. Everything else will not be included in that state dict.
Your test_nn module uses a list for convs, therefore it is not included in the state dict:
self.convs = [nn_block(feats_dim) for _ in range(num_conv)]
Not only are they not contained in the state dict, they are also not visible to net.parameters(), which means they are not trained/optimised at all.
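A quick way to see this, using the classes from the question (a minimal sketch; the printed keys simply follow the registered sub-modules):
net = test_nn(in_feats=8, feats_dim=5, num_conv=2, num_classes=2)

print(list(net.state_dict().keys()))
# ['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias']
# -> nothing from the nn_block layers held in the plain Python list

print(len(list(net.parameters())))
# 4 -> only linear1 and linear2 would be optimised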
To register the modules from the list you can wrap it in nn.ModuleList, which is a module that acts like a list, while correctly registering the modules it contains:
self.convs = nn.ModuleList([nn_block(feats_dim) for _ in range(num_conv)])
With that change both models produce the same result.
Since you are calling the convs modules sequentially in the for-loop (the output of one module is the input of the next), you may consider using nn.Sequential, which you can call directly instead of having to use the for-loop. nn.Sequential is used a lot and it makes things a little simpler; for example, if you want to replace the sequence of modules with a single module, you don't need to change anything in the forward method, as in the sketch below.
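For illustration, a sketch of that variant (the explicit dim=1 on the softmax is an extra tweak, not required for the fix):
class test_nn(nn.Module):
    def __init__(self, in_feats, feats_dim, num_conv, num_classes):
        super().__init__()
        self.linear1 = nn.Linear(in_feats, feats_dim)
        # nn.Sequential registers the blocks and chains their calls
        self.convs = nn.Sequential(*[nn_block(feats_dim) for _ in range(num_conv)])
        self.linear2 = nn.Linear(feats_dim, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, rep_mat):
        h = self.linear1(rep_mat)
        h = self.convs(h)  # replaces the explicit for-loop
        h = self.linear2(h)
        return self.softmax(h)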
Not only are the two losses different from each other, they also differ from the loss computed by the best model that was saved in the first place.
When you are training, you calculate the loss for the current input (batch) and only then optimise the parameters based on that input. This means the parameters after the update differ from the ones that were used to calculate the loss. Because you are saving the model after that update, it will also have a different loss (the one that would occur in the next iteration).
preds = net(inputs)
# Calculating the loss of the current model
loss = F.cross_entropy(preds, labels)
optimizer.zero_grad()
loss.backward()
# Updating the model's parameters based on the loss
optimizer.step()
# State of the model after it has been updated
state_dict = {'state_dict': net.state_dict()}
# Comparing the loss from BEFORE the update
# But saving the model from AFTER the update
if loss.item() - best_score < -1e-4:
    # save only parameters
    torch.save(state_dict, 'model_params.torch')
    # save the whole model
    torch.save(net, 'whole_model.torch')
It's important to evaluate the model after the updates have been made. For this reason a validation set should be used, which is run after each epoch to assess the model's accuracy.
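A minimal sketch of that, reusing the names from the question (the loss is computed after optimizer.step(), so the saved checkpoint and the reported score refer to the same parameters; here the training inputs stand in for a proper validation set):
# inside the training loop, after optimizer.step()
net.eval()
with torch.no_grad():
    eval_loss = F.cross_entropy(net(inputs), labels).item()
net.train()

if eval_loss < best_score - 1e-4:
    torch.save({'state_dict': net.state_dict()}, 'model_params.torch')
    torch.save(net, 'whole_model.torch')
    best_score = eval_loss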

Memory utilization much higher than it should be

I'm using a simple method to extract descriptors from images and save them to disk into a .csv file. I have around 1M images and my network returns 512 features per image (float32).
Therefore, I estimate that at the end of the loop I would have roughly 1e6 * 512 * 4 bytes / 1e9 ≈ 2 GB of features (a float32 is 4 bytes). However, I observed that it is using more than twice that much memory.
index is a string and class_id is an int64, so I don't think they are the culprits here.
I have already tried using gc.collect() without any success. Do you think my code is leaving references behind?
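As a quick sanity check on that estimate (a sketch), 1M rows of 512 float32 values in one contiguous array come to about 2 GB:
import numpy as np

n_images, n_feats = 1_000_000, 512
print(n_images * n_feats * np.dtype(np.float32).itemsize / 1e9)  # ~2.05 GB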
Here is the method:
def prepare_gallery(self, data_loader, TTA, pbar=False, dump_path=None):
    '''Compute embeddings for a data_loader and store it in model.
    This is required before predicting to a test set.
    New entries should be removed from data before calling this function
    to avoid inferring on useless images.
    data_loader: A linear loader containing the database that test is
    compared against.'''
    self.set_mode('valid')
    self.net.cuda()
    n_iter = len(data_loader.dataset) / data_loader.batch_size
    if pbar:
        loader = tqdm(enumerate(data_loader), total=n_iter)
    else:
        loader = enumerate(data_loader)
    # Run inference and get embeddings
    feat_list = []
    index_list = []
    class_list = []
    for i, (index, im, class_id) in loader:
        with torch.no_grad():
            feat = tta(self.net, im)
            # Returns something like np.random.random((32, 512))
        feat_list.extend(feat)
        index_list.extend(index)
        class_list.extend(class_id.item())
    if dump_path is not None:
        np.save(dump_path + '_ids', index_list)
        np.save(dump_path + '_cls', class_list)
        np.save(dump_path + '_feat', feat_list)
    return np.asarray(index_list), np.asarray(feat_list), np.asarray(class_list)

How can I get the indices of the data used in every batch?

I need to save the indices of the data that are used in every mini-batch.
For example if my data is:
x = np.array([[1.1], [2.2], [3.3], [4.4]])
and the first mini-batch is [1.1] and [3.3], then I want to store 0 and 2 (since [1.1] is the 0th observation and [3.3] is the 2nd observation).
I am using tensorflow in eager execution with the keras.sequential APIs.
As far as I can tell from reading the source code, this information is not stored anywhere so I was unable to do this with a callback.
I am currently solving my problem by creating an object that stores the indices.
class IndexIterator(object):
    def __init__(self, n, n_epochs, batch_size, shuffle=True):
        data_ix = np.arange(n)
        if shuffle:
            np.random.shuffle(data_ix)
        self.ix_batches = np.array_split(data_ix, np.ceil(n / batch_size))
        self.batch_indices = []

    def generate_arrays(self, x, y):
        batch_ixs = np.arange(len(self.ix_batches))
        while 1:
            np.random.shuffle(batch_ixs)
            for batch in batch_ixs:
                self.batch_indices.append(self.ix_batches[batch])
                yield (x[self.ix_batches[batch], :], y[self.ix_batches[batch], :])

data_gen = IndexIterator(n=32, n_epochs=100, batch_size=16)

dnn.fit_generator(data_gen.generate_arrays(x, y),
                  steps_per_epoch=2,
                  epochs=100)

# This is what I am looking for
print(data_gen.batch_indices)
Is there no way to do this using a tensorflow callback?
Not sure if this will be more efficient than your solution, but it is certainly more general.
If you have training data with n indices you can create a secondary Dataset that contains only these indices and zip it with the "real" dataset.
For example:
real_data = tf.data.Dataset ...
indices = tf.data.Dataset.from_tensor_slices(tf.range(data_set_length))
total_dataset = tf.data.Dataset.zip((real_data, indices))
# Perform optional pre-processing ops.
iterator = total_dataset.make_one_shot_iterator()
# Next line yields `(original_data_element, index)`
item_and_index_tuple = iterator.get_next()
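In eager TF 2.x the same idea works without the explicit iterator; a minimal sketch (the toy data just mirrors the question):
import numpy as np
import tensorflow as tf

x = np.array([[1.1], [2.2], [3.3], [4.4]], dtype=np.float32)
y = np.array([0., 1., 0., 1.], dtype=np.float32)

real_data = tf.data.Dataset.from_tensor_slices((x, y))
indices = tf.data.Dataset.from_tensor_slices(tf.range(len(x)))
total_dataset = tf.data.Dataset.zip((real_data, indices)).shuffle(4).batch(2)

for (xb, yb), idx in total_dataset:
    print(idx.numpy())  # indices of the observations used in this batch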
