Error while loading model weights in PyTorch - Python

My model is in a .pth file, and to load it I wrote this code:
import torch

model = torch.jit.load('/content/drive/MyDrive/fod.pth')
torch.save(model.state_dict(), 'weights.pt')

u2net = U2NETP()
u2net.eval()
u2net.load_state_dict(torch.load('/content/weights.pt'), strict=False)
U2NETP is the network architecture, but the problem here is that I am getting an error which goes like this:
_IncompatibleKeys(missing_keys=['stage1.rebnconvin.bn_s1.running_mean', 'stage1.rebnconvin.bn_s1.running_var', 'stage1.rebnconv1.bn_s1.running_mean', 'stage1.rebnconv1.bn_s1.running_var', 'stage1.rebnconv2.bn_s1.running_mean', 'stage1.rebnconv2.bn_s1.running_var', 'stage1.rebnconv3.bn_s1.running_mean', 'stage1.rebnconv3.bn_s1.running_var', 'stage1.rebnconv4.bn_s1.running_mean', 'stage1.rebnconv4.bn_s1.running_var', 'stage1.rebnconv5.bn_s1.running_mean', 'stage1.rebnconv5.bn_s1.running_var', 'stage1.rebnconv6.bn_s1.running_mean', 'stage1.rebnconv6.bn_s1.running_var', 'stage1.rebnconv7.bn_s1.running_mean', 'stage1.rebnconv7.bn_s1.running_var', 'stage1.rebnconv6d.bn_s1.running_mean', 'stage1.rebnconv6d.bn_s1.running_var', 'stage1.rebnconv5d.bn_s1.running_mean', 'stage1.rebnconv5d.bn_s1.running_var', 'stage1.rebnconv4d.bn_s1.running_mean', 'stage1.rebnconv4d.bn_s1.running_var', 'stage1.rebnconv3d.bn_s1.running_mean', 'stage1.rebnconv3d.bn_s1.running_var', 'stage1.rebnconv2d.bn_s1.running_mean', 'stage1.rebnconv2d.bn_s1.running_var', 'stage1.rebnconv1d.bn_s1.running_mean', 'stage1.rebnconv1d.bn_s1.running_var', 'stage2.rebnconvin.bn_s1.running_mean', 'stage2.rebnconvin.bn_s1.running_var', 'stage2.rebnconv1.bn_s1.running_mean', 'stage2.rebnconv1.bn_s1.running_var', 'stage2.rebnconv2.bn_s1.running_mean', 'stage2.rebnconv2.bn_s1.running_var', 'stage2.rebnconv3.bn_s1.running_mean', 'stage2.rebnconv3.bn_s1.running_var', 'stage2.rebnconv4.bn_s1.running_mean', 'stage2.rebnconv4.bn_s1.running_var', 'stage2.rebnconv5.bn_s1.running_mean', 'stage2.rebnconv5.bn_s1.running_var', 'stage2.rebnconv6.bn_s1.running_mean', 'stage2.rebnconv6.bn_s1.running_var', 'stage2.rebnconv5d.bn_s1.running_mean', 'stage2.rebnconv5d.bn_s1.running_var', 'stage2.rebnconv4d.bn_s1.running_mean', 'stage2.rebnconv4d.bn_s1.running_var', 'stage2.rebnconv3d.bn_s1.running_mean', 'stage2.rebnconv3d.bn_s1.running_var', 'stage2.rebnconv2d.bn_s1.running_mean', 'stage2.rebnconv2d.bn_s1.running_var', 'stage2.rebnconv1d.bn_s1.running_mean', 'stage2.rebnconv1d.bn_s1.running_var', 'stage3.rebnconvin.bn_s1.running_mean', 'stage3.rebnconvin.bn_s1.running_var', 'stage3.rebnconv1.bn_s1.running_mean', 'stage3.rebnconv1.bn_s1.running_var', 'stage3.rebnconv2.bn_s1.running_mean', 'stage3.rebnconv2.bn_s1.running_var', 'stage3.rebnconv3.bn_s1.running_mean', 'stage3.rebnconv3.bn_s1.running_var', 'stage3.rebnconv4.bn_s1.running_mean', 'stage3.rebnconv4.bn_s1.running_var', 'stage3.rebnconv5.bn_s1.running_mean', 'stage3.rebnconv5.bn_s1.running_var', 'stage3.rebnconv4d.bn_s1.running_mean', 'stage3.rebnconv4d.bn_s1.running_var', 'stage3.rebnconv3d.bn_s1.running_mean', 'stage3.rebnconv3d.bn_s1.running_var', 'stage3.rebnconv2d.bn_s1.running_mean', 'stage3.rebnconv2d.bn_s1.running_var', 'stage3.rebnconv1d.bn_s1.running_mean', 'stage3.rebnconv1d.bn_s1.running_var', 'stage4.rebnconvin.bn_s1.running_mean', 'stage4.rebnconvin.bn_s1.running_var', 'stage4.rebnconv1.bn_s1.running_mean', 'stage4.rebnconv1.bn_s1.running_var', 'stage4.rebnconv2.bn_s1.running_mean', 'stage4.rebnconv2.bn_s1.running_var', 'stage4.rebnconv3.bn_s1.running_mean', 'stage4.rebnconv3.bn_s1.running_var', 'stage4.rebnconv4.bn_s1.running_mean', 'stage4.rebnconv4.bn_s1.running_var', 'stage4.rebnconv3d.bn_s1.running_mean', 'stage4.rebnconv3d.bn_s1.running_var', 'stage4.rebnconv2d.bn_s1.running_mean', 'stage4.rebnconv2d.bn_s1.running_var', 'stage4.rebnconv1d.bn_s1.running_mean', 'stage4.rebnconv1d.bn_s1.running_var', 'stage5.rebnconvin.bn_s1.running_mean', 'stage5.rebnconvin.bn_s1.running_var', 
'stage5.rebnconv1.bn_s1.running_mean', 'stage5.rebnconv1.bn_s1.running_var', 'stage5.rebnconv2.bn_s1.running_mean', 'stage5.rebnconv2.bn_s1.running_var', 'stage5.rebnconv3.bn_s1.running_mean', 'stage5.rebnconv3.bn_s1.running_var', 'stage5.rebnconv4.bn_s1.running_mean', 'stage5.rebnconv4.bn_s1.running_var', 'stage5.rebnconv3d.bn_s1.running_mean', 'stage5.rebnconv3d.bn_s1.running_var', 'stage5.rebnconv2d.bn_s1.running_mean', 'stage5.rebnconv2d.bn_s1.running_var', 'stage5.rebnconv1d.bn_s1.running_mean', 'stage5.rebnconv1d.bn_s1.running_var', 'stage6.rebnconvin.bn_s1.running_mean', 'stage6.rebnconvin.bn_s1.running_var', 'stage6.rebnconv1.bn_s1.running_mean', 'stage6.rebnconv1.bn_s1.running_var', 'stage6.rebnconv2.bn_s1.running_mean', 'stage6.rebnconv2.bn_s1.running_var', 'stage6.rebnconv3.bn_s1.running_mean', 'stage6.rebnconv3.bn_s1.running_var', 'stage6.rebnconv4.bn_s1.running_mean', 'stage6.rebnconv4.bn_s1.running_var', 'stage6.rebnconv3d.bn_s1.running_mean', 'stage6.rebnconv3d.bn_s1.running_var', 'stage6.rebnconv2d.bn_s1.running_mean', 'stage6.rebnconv2d.bn_s1.running_var', 'stage6.rebnconv1d.bn_s1.running_mean', 'stage6.rebnconv1d.bn_s1.running_var', 'stage5d.rebnconvin.bn_s1.running_mean', 'stage5d.rebnconvin.bn_s1.running_var', 'stage5d.rebnconv1.bn_s1.running_mean', 'stage5d.rebnconv1.bn_s1.running_var', 'stage5d.rebnconv2.bn_s1.running_mean', 'stage5d.rebnconv2.bn_s1.running_var', 'stage5d.rebnconv3.bn_s1.running_mean', 'stage5d.rebnconv3.bn_s1.running_var', 'stage5d.rebnconv4.bn_s1.running_mean', 'stage5d.rebnconv4.bn_s1.running_var', 'stage5d.rebnconv3d.bn_s1.running_mean', 'stage5d.rebnconv3d.bn_s1.running_var', 'stage5d.rebnconv2d.bn_s1.running_mean', 'stage5d.rebnconv2d.bn_s1.running_var', 'stage5d.rebnconv1d.bn_s1.running_mean', 'stage5d.rebnconv1d.bn_s1.running_var', 'stage4d.rebnconvin.bn_s1.running_mean', 'stage4d.rebnconvin.bn_s1.running_var', 'stage4d.rebnconv1.bn_s1.running_mean', 'stage4d.rebnconv1.bn_s1.running_var', 'stage4d.rebnconv2.bn_s1.running_mean', 'stage4d.rebnconv2.bn_s1.running_var', 'stage4d.rebnconv3.bn_s1.running_mean', 'stage4d.rebnconv3.bn_s1.running_var', 'stage4d.rebnconv4.bn_s1.running_mean', 'stage4d.rebnconv4.bn_s1.running_var', 'stage4d.rebnconv3d.bn_s1.running_mean', 'stage4d.rebnconv3d.bn_s1.running_var', 'stage4d.rebnconv2d.bn_s1.running_mean', 'stage4d.rebnconv2d.bn_s1.running_var', 'stage4d.rebnconv1d.bn_s1.running_mean', 'stage4d.rebnconv1d.bn_s1.running_var', 'stage3d.rebnconvin.bn_s1.running_mean', 'stage3d.rebnconvin.bn_s1.running_var', 'stage3d.rebnconv1.bn_s1.running_mean', 'stage3d.rebnconv1.bn_s1.running_var', 'stage3d.rebnconv2.bn_s1.running_mean', 'stage3d.rebnconv2.bn_s1.running_var', 'stage3d.rebnconv3.bn_s1.running_mean', 'stage3d.rebnconv3.bn_s1.running_var', 'stage3d.rebnconv4.bn_s1.running_mean', 'stage3d.rebnconv4.bn_s1.running_var', 'stage3d.rebnconv5.bn_s1.running_mean', 'stage3d.rebnconv5.bn_s1.running_var', 'stage3d.rebnconv4d.bn_s1.running_mean', 'stage3d.rebnconv4d.bn_s1.running_var', 'stage3d.rebnconv3d.bn_s1.running_mean', 'stage3d.rebnconv3d.bn_s1.running_var', 'stage3d.rebnconv2d.bn_s1.running_mean', 'stage3d.rebnconv2d.bn_s1.running_var', 'stage3d.rebnconv1d.bn_s1.running_mean', 'stage3d.rebnconv1d.bn_s1.running_var', 'stage2d.rebnconvin.bn_s1.running_mean', 'stage2d.rebnconvin.bn_s1.running_var', 'stage2d.rebnconv1.bn_s1.running_mean', 'stage2d.rebnconv1.bn_s1.running_var', 'stage2d.rebnconv2.bn_s1.running_mean', 'stage2d.rebnconv2.bn_s1.running_var', 'stage2d.rebnconv3.bn_s1.running_mean', 
'stage2d.rebnconv3.bn_s1.running_var', 'stage2d.rebnconv4.bn_s1.running_mean', 'stage2d.rebnconv4.bn_s1.running_var', 'stage2d.rebnconv5.bn_s1.running_mean', 'stage2d.rebnconv5.bn_s1.running_var', 'stage2d.rebnconv6.bn_s1.running_mean', 'stage2d.rebnconv6.bn_s1.running_var', 'stage2d.rebnconv5d.bn_s1.running_mean', 'stage2d.rebnconv5d.bn_s1.running_var', 'stage2d.rebnconv4d.bn_s1.running_mean', 'stage2d.rebnconv4d.bn_s1.running_var', 'stage2d.rebnconv3d.bn_s1.running_mean', 'stage2d.rebnconv3d.bn_s1.running_var', 'stage2d.rebnconv2d.bn_s1.running_mean', 'stage2d.rebnconv2d.bn_s1.running_var', 'stage2d.rebnconv1d.bn_s1.running_mean', 'stage2d.rebnconv1d.bn_s1.running_var', 'stage1d.rebnconvin.bn_s1.running_mean', 'stage1d.rebnconvin.bn_s1.running_var', 'stage1d.rebnconv1.bn_s1.running_mean', 'stage1d.rebnconv1.bn_s1.running_var', 'stage1d.rebnconv2.bn_s1.running_mean', 'stage1d.rebnconv2.bn_s1.running_var', 'stage1d.rebnconv3.bn_s1.running_mean', 'stage1d.rebnconv3.bn_s1.running_var', 'stage1d.rebnconv4.bn_s1.running_mean', 'stage1d.rebnconv4.bn_s1.running_var', 'stage1d.rebnconv5.bn_s1.running_mean', 'stage1d.rebnconv5.bn_s1.running_var', 'stage1d.rebnconv6.bn_s1.running_mean', 'stage1d.rebnconv6.bn_s1.running_var', 'stage1d.rebnconv7.bn_s1.running_mean', 'stage1d.rebnconv7.bn_s1.running_var', 'stage1d.rebnconv6d.bn_s1.running_mean', 'stage1d.rebnconv6d.bn_s1.running_var', 'stage1d.rebnconv5d.bn_s1.running_mean', 'stage1d.rebnconv5d.bn_s1.running_var', 'stage1d.rebnconv4d.bn_s1.running_mean', 'stage1d.rebnconv4d.bn_s1.running_var', 'stage1d.rebnconv3d.bn_s1.running_mean', 'stage1d.rebnconv3d.bn_s1.running_var', 'stage1d.rebnconv2d.bn_s1.running_mean', 'stage1d.rebnconv2d.bn_s1.running_var', 'stage1d.rebnconv1d.bn_s1.running_mean', 'stage1d.rebnconv1d.bn_s1.running_var'], unexpected_keys=[])
sd = model.state_dict()
for param_tensor in sd:
    print(param_tensor, "\t", sd[param_tensor].size())
I used this code to print the weights. It seems to contain the weight and bias keys, but not the batch-norm running mean/variance buffers.
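For reference, a minimal way to see exactly which keys differ; this is just a sketch, assuming the u2net instance and the weights.pt file from the snippet above:
import torch

ckpt = torch.load('/content/weights.pt')    # state dict saved from the TorchScript module
expected = u2net.state_dict()               # keys the freshly built U2NETP expects

missing = sorted(set(expected) - set(ckpt))       # e.g. the bn_s1.running_mean / running_var buffers
unexpected = sorted(set(ckpt) - set(expected))

print(f'missing from checkpoint: {len(missing)}')
print(missing[:5])
print(f'unexpected in checkpoint: {len(unexpected)}')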

Related

How to add all standard special tokens to my hugging face tokenizer and model?

I want all special tokens to always be available. How do I do this?
My first attempt at giving them to my tokenizer:
from transformers import AutoTokenizer, PreTrainedTokenizerFast


def does_t5_have_sep_token():
    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained('t5-small')
    assert isinstance(tokenizer, PreTrainedTokenizerFast)
    print(tokenizer)
    print(f'{len(tokenizer)=}')
    # print(f'{tokenizer.all_special_tokens=}')
    print(f'{tokenizer.sep_token=}')
    print(f'{tokenizer.eos_token=}')
    print(f'{tokenizer.all_special_tokens=}')

    special_tokens_dict = {'additional_special_tokens': ['<bos>', '<cls>', '<s>'] + tokenizer.all_special_tokens}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

    print(f'{tokenizer.sep_token=}')
    print(f'{tokenizer.eos_token=}')
    print(f'{tokenizer.all_special_tokens=}')


if __name__ == '__main__':
    does_t5_have_sep_token()
    print('Done\a')
but it feels hacky.
refs:
https://github.com/huggingface/tokenizers/issues/247
https://discuss.huggingface.co/t/how-to-add-all-standard-special-tokens-to-my-tokenizer-and-model/21529
seems useful: https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
I want to add the standard tokens by adding the right "standard tokens". The solution provided did not work for me, since .bos_token is still None. See:
tokenizer.bos_token=None
tokenizer.cls_token=None
tokenizer.sep_token=None
tokenizer.mask_token=None
tokenizer.eos_token='</s>'
tokenizer.unk_token='<unk>'
tokenizer.bos_token_id=None
tokenizer.cls_token_id=None
tokenizer.sep_token_id=None
tokenizer.mask_token_id=None
tokenizer.eos_token_id=1
tokenizer.unk_token_id=2
tokenizer.all_special_tokens=['</s>', '<unk>', '<pad>', '<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_id_44>', '<extra_id_45>', '<extra_id_46>', '<extra_id_47>', '<extra_id_48>', '<extra_id_49>', '<extra_id_50>', '<extra_id_51>', '<extra_id_52>', '<extra_id_53>', '<extra_id_54>', '<extra_id_55>', '<extra_id_56>', '<extra_id_57>', '<extra_id_58>', '<extra_id_59>', '<extra_id_60>', '<extra_id_61>', '<extra_id_62>', '<extra_id_63>', '<extra_id_64>', '<extra_id_65>', '<extra_id_66>', '<extra_id_67>', '<extra_id_68>', '<extra_id_69>', '<extra_id_70>', '<extra_id_71>', '<extra_id_72>', '<extra_id_73>', '<extra_id_74>', '<extra_id_75>', '<extra_id_76>', '<extra_id_77>', '<extra_id_78>', '<extra_id_79>', '<extra_id_80>', '<extra_id_81>', '<extra_id_82>', '<extra_id_83>', '<extra_id_84>', '<extra_id_85>', '<extra_id_86>', '<extra_id_87>', '<extra_id_88>', '<extra_id_89>', '<extra_id_90>', '<extra_id_91>', '<extra_id_92>', '<extra_id_93>', '<extra_id_94>', '<extra_id_95>', '<extra_id_96>', '<extra_id_97>', '<extra_id_98>', '<extra_id_99>']
Using bos_token, but it is not set yet.
Using cls_token, but it is not set yet.
Using sep_token, but it is not set yet.
Using mask_token, but it is not set yet.
code:
from transformers import AutoTokenizer, PreTrainedTokenizerFast


def does_t5_have_sep_token():
    """
    https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
    """
    import torch
    from transformers import AutoModelForSeq2SeqLM

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained('t5-small')
    assert isinstance(tokenizer, PreTrainedTokenizerFast)
    print(tokenizer)
    print(f'{len(tokenizer)=}')
    print()

    print(f'{tokenizer.sep_token=}')
    print(f'{tokenizer.eos_token=}')
    print(f'{tokenizer.all_special_tokens=}')
    print()

    # special_tokens_dict = {'additional_special_tokens': ['<bos>', '<cls>', '<s>'] + tokenizer.all_special_tokens}
    # num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.add_tokens([f"_{n}" for n in range(1, 100)], special_tokens=True)

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    assert isinstance(model, torch.nn.Module)
    model.resize_token_embeddings(len(tokenizer))
    # tokenizer.save_pretrained('pathToExtendedTokenizer/')
    # tokenizer = T5Tokenizer.from_pretrained("sandbox/t5_models/pretrained/tokenizer/")
    print()

    print(f'{tokenizer.bos_token=}')
    print(f'{tokenizer.cls_token=}')
    print(f'{tokenizer.sep_token=}')
    print(f'{tokenizer.mask_token=}')
    print(f'{tokenizer.eos_token=}')
    print(f'{tokenizer.unk_token=}')

    print(f'{tokenizer.bos_token_id=}')
    print(f'{tokenizer.cls_token_id=}')
    print(f'{tokenizer.sep_token_id=}')
    print(f'{tokenizer.mask_token_id=}')
    print(f'{tokenizer.eos_token_id=}')
    print(f'{tokenizer.unk_token_id=}')
    print(f'{tokenizer.all_special_tokens=}')
    print()


if __name__ == '__main__':
    does_t5_have_sep_token()
    print('Done\a')
I do not entirely understand what you're trying to accomplish, but here are some notes that might help:
T5 documentation shows that T5 has only three special tokens (</s>, <unk> and <pad>). You can also see this in the T5Tokenizer class definition. I am confident this is because the original T5 model was trained only with these special tokens (no BOS, no MASK, no CLS).
Running, e.g.,
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
print(tokenizer.all_special_tokens)
will show you these three tokens as well as the <extra_id_*> tokens.
Is there a reason you want the other tokens like BOS?
(Edit, to answer your comments:)
(I really think you would benefit from reading the linked documentation at Hugging Face. The point of a pretrained model is to take advantage of what has already been done. T5 does not use BOS or CLS in the way you seem to be imagining. Maybe you can get it to work, but IMO it makes more sense to adapt the task you want to solve to the T5 approach.)
</s> is the sep token and is already available.
As I understand, for the T5 model, masking (for the sake of ignoring loss) is implemented using attention_mask. On the other hand, if you want to "fill in the blank" then <extra_id> is used to indicate to the model that it should predict the missing token (this is how semi-supervised pretraining is done). See the section on training in the documentation.
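For illustration, here is a minimal sketch of the fill-in-the-blank (span-corruption) format, adapted from the training section of the T5 docs; the sentence is just an example:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# <extra_id_0>, <extra_id_1>, ... are the sentinel tokens marking the spans to be filled in.
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids

# The forward pass builds the decoder inputs from the labels and returns the denoising loss.
loss = model(input_ids=input_ids, labels=labels).loss
print(loss.item())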
BOS is similar: T5 is not trained to use a BOS token. E.g. (again from the documentation):
Note that T5 uses the pad_token_id as the decoder_start_token_id, so
when doing generation without using generate(), make sure you start it
with the pad_token_id.
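A small sketch of what that quote means in practice, assuming a hand-rolled greedy decoding loop instead of generate():
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids

# Decoding starts from pad_token_id, which T5 uses as its decoder_start_token_id.
assert model.config.decoder_start_token_id == tokenizer.pad_token_id
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

for _ in range(20):  # greedy decoding, one token at a time
    logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
    next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    decoder_input_ids = torch.cat([decoder_input_ids, next_id], dim=-1)
    if next_id.item() == tokenizer.eos_token_id:
        break

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))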
T5 does not use the CLS token. If you want to do classification, you should fine-tune on a new task (or find a corresponding one done in pretraining), fine-tuning the model to generate a word (or words) that corresponds to the classification you want.
(again from documentation:)
Build model inputs from a sequence or a pair of sequence for sequence
classification tasks by concatenating and adding special tokens. A
sequence has the following format:
I think this is correct. Please correct me if I'm wrong:
from typing import Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


def add_special_all_special_tokens(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
    """
    special_tokens_dict = {"cls_token": "<CLS>"}

    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    print("We have added", num_added_toks, "tokens")
    # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary,
    # i.e., the length of the tokenizer.
    model.resize_token_embeddings(len(tokenizer))

    assert tokenizer.cls_token == "<CLS>"
    """
    original_len: int = len(tokenizer)
    num_added_toks: dict = {}
    if tokenizer.bos_token is None:
        num_added_toks['bos_token'] = "<bos>"
    if tokenizer.cls_token is None:
        num_added_toks['cls_token'] = "<cls>"
    if tokenizer.sep_token is None:
        num_added_toks['sep_token'] = "<s>"
    if tokenizer.mask_token is None:
        num_added_toks['mask_token'] = "<mask>"
    # num_added_toks = {"bos_token": "<bos>", "cls_token": "<cls>", "sep_token": "<s>", "mask_token": "<mask>"}
    # special_tokens_dict = {'additional_special_tokens': new_special_tokens + tokenizer.all_special_tokens}
    num_new_tokens: int = tokenizer.add_special_tokens(num_added_toks)
    assert tokenizer.bos_token == "<bos>"
    assert tokenizer.cls_token == "<cls>"
    assert tokenizer.sep_token == "<s>"
    assert tokenizer.mask_token == "<mask>"
    msg = f"Error, not equal: {len(tokenizer)=}, {original_len + num_new_tokens=}"
    assert len(tokenizer) == original_len + num_new_tokens, msg
Here is the comment from the docs that inspired my answer:
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
    """
    Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
    special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
    current vocabulary).

    Note: When adding new tokens to the vocabulary, you should make sure to also resize the token embedding
    matrix of the model so that its embedding matrix matches the tokenizer.

    In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.

    Using `add_special_tokens` will ensure your special tokens can be used in several ways:

    - Special tokens are carefully handled by the tokenizer (they are never split).
    - You can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This
      makes it easy to develop model-agnostic training and fine-tuning scripts.

    When possible, special tokens are already registered for provided pretrained models (for instance
    [`BertTokenizer`] `cls_token` is already registered to be `'[CLS]'` and XLM's one is also registered to be
    `'</s>'`).

    Args:
        special_tokens_dict (dictionary *str* to *str* or `tokenizers.AddedToken`):
            Keys should be in the list of predefined special attributes: [`bos_token`, `eos_token`, `unk_token`,
            `sep_token`, `pad_token`, `cls_token`, `mask_token`, `additional_special_tokens`].

            Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
            assign the index of the `unk_token` to them).

    Returns:
        `int`: Number of tokens added to the vocabulary.

    Examples:

    ```python
    # Let's see how to add a new classification token to GPT-2
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2Model.from_pretrained("gpt2")

    special_tokens_dict = {"cls_token": "<CLS>"}

    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    print("We have added", num_added_toks, "tokens")
    # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
    model.resize_token_embeddings(len(tokenizer))

    assert tokenizer.cls_token == "<CLS>"
    ```"""
It is in HF's tokenization_utils_base.py.
I think the right answer is here: https://stackoverflow.com/a/73361984/1601580
Links can be bad answers so here is the code:
from typing import Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


def add_special_all_special_tokens(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
    """
    special_tokens_dict = {"cls_token": "<CLS>"}

    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    print("We have added", num_added_toks, "tokens")
    # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary,
    # i.e., the length of the tokenizer.
    model.resize_token_embeddings(len(tokenizer))

    assert tokenizer.cls_token == "<CLS>"
    """
    original_len: int = len(tokenizer)
    num_added_toks: dict = {}
    if tokenizer.bos_token is None:
        num_added_toks['bos_token'] = "<bos>"
    if tokenizer.cls_token is None:
        num_added_toks['cls_token'] = "<cls>"
    if tokenizer.sep_token is None:
        num_added_toks['sep_token'] = "<s>"
    if tokenizer.mask_token is None:
        num_added_toks['mask_token'] = "<mask>"
    # num_added_toks = {"bos_token": "<bos>", "cls_token": "<cls>", "sep_token": "<s>", "mask_token": "<mask>"}
    # special_tokens_dict = {'additional_special_tokens': new_special_tokens + tokenizer.all_special_tokens}
    num_new_tokens: int = tokenizer.add_special_tokens(num_added_toks)
    assert tokenizer.bos_token == "<bos>"
    assert tokenizer.cls_token == "<cls>"
    assert tokenizer.sep_token == "<s>"
    assert tokenizer.mask_token == "<mask>"
    err_msg = f"Error, not equal: {len(tokenizer)=}, {original_len + num_new_tokens=}"
    assert len(tokenizer) == original_len + num_new_tokens, err_msg
Feedback is always welcomed.
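For what it's worth, here is a usage sketch, assuming the function above together with a T5 tokenizer and a model whose embedding matrix still needs resizing:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

add_special_all_special_tokens(tokenizer)      # registers <bos>, <cls>, <s>, <mask> if they are missing
model.resize_token_embeddings(len(tokenizer))  # keep the embedding matrix in sync with the tokenizer

print(tokenizer.bos_token, tokenizer.cls_token, tokenizer.sep_token, tokenizer.mask_token)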

huggingface fine tuning distilbert-base-uncased and then using pipeline throws error KeyError: 'logits'

I am referring to the distilbert-base-uncased model, which I am fine-tuning. Before fine-tuning I could use pipeline as below:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

model_check = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_check)
modelz = AutoModelForSequenceClassification.from_pretrained(model_check)

classifier = pipeline("text-classification", model=modelz, tokenizer=tokenizer)

custom_tweet = "I saw a movie today and it was really good."
preds = classifier(custom_tweet, return_all_scores=True)
preds
[[{'label': 'LABEL_0', 'score': 0.5338158011436462},
{'label': 'LABEL_1', 'score': 0.46618416905403137}]]
type(modelz)
transformers.models.distilbert.modeling_distilbert.DistilBertForSequenceClassification
Then I am using from transformers import Trainer to train the model and saving it. But after that I haven't been able to use pipeline:
saved_model_path = '/zyz'
trainer.save_model(saved_model_path)

model_saved = AutoModel.from_pretrained(saved_model_path)
tokenizer_saved = AutoTokenizer.from_pretrained(saved_model_path)

type(model_saved)
# transformers.models.distilbert.modeling_distilbert.DistilBertModel

from transformers import pipeline
classifier = pipeline("text-classification", model=model_saved, tokenizer=tokenizer_saved)

custom_tweet = "I saw a movie today and it was really good."
preds = classifier(custom_tweet)  # , return_all_scores=True)
The model 'DistilBertModel' is not supported for text-classification. Supported models are ['FNetForSequenceClassification', 'GPTJForSequenceClassification', 'LayoutLMv2ForSequenceClassification', 'RemBertForSequenceClassification', 'CanineForSequenceClassification', 'RoFormerForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BigBirdForSequenceClassification', 'ConvBertForSequenceClassification', 'LEDForSequenceClassification', 'DistilBertForSequenceClassification', 'AlbertForSequenceClassification', 'CamembertForSequenceClassification', 'XLMRobertaForSequenceClassification', 'MBartForSequenceClassification', 'BartForSequenceClassification', 'LongformerForSequenceClassification', 'RobertaForSequenceClassification', 'SqueezeBertForSequenceClassification', 'LayoutLMForSequenceClassification', 'BertForSequenceClassification', 'XLNetForSequenceClassification', 'MegatronBertForSequenceClassification', 'MobileBertForSequenceClassification', 'FlaubertForSequenceClassification', 'XLMForSequenceClassification', 'ElectraForSequenceClassification', 'FunnelForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'GPT2ForSequenceClassification', 'GPTNeoForSequenceClassification', 'OpenAIGPTForSequenceClassification', 'ReformerForSequenceClassification', 'CTRLForSequenceClassification', 'TransfoXLForSequenceClassification', 'MPNetForSequenceClassification', 'TapasForSequenceClassification', 'IBertForSequenceClassification'].
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-92-cf10c8e85589> in <module>()
4
5 custom_tweet = "I saw a movie today and it was really good."
----> 6 preds = classifier(custom_tweet)#, return_all_scores=True)
5 frames
/usr/local/lib/python3.7/dist-packages/transformers/file_utils.py in __getitem__(self, k)
1952 if isinstance(k, str):
1953 inner_dict = {k: v for (k, v) in self.items()}
-> 1954 return inner_dict[k]
1955 else:
1956 return self.to_tuple()[k]
KeyError: 'logits'
I am confused about two things:
1. Why does it say "The model 'DistilBertModel' is not supported for text-classification. Supported models are ..." in the second part of the code, when the first part runs without any problems? I tried type(modelz) and type(model_saved) and I see that the trained model is loaded back as a different class. Why is that? How should I save the trained model so that its class doesn't change?
2. Why does it throw KeyError: 'logits'?
Update 1: found the answer.
I fixed it by changing the loading line as below:
model_saved = AutoModelForSequenceClassification.from_pretrained(saved_model_path)
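For completeness, a minimal sketch of the fixed reload path, assuming the saved_model_path used above and that the tokenizer was saved alongside the model:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

saved_model_path = '/zyz'  # path passed to trainer.save_model(...)

# AutoModelForSequenceClassification restores the DistilBertForSequenceClassification head,
# so the pipeline finds the 'logits' output it expects.
model_saved = AutoModelForSequenceClassification.from_pretrained(saved_model_path)
tokenizer_saved = AutoTokenizer.from_pretrained(saved_model_path)

classifier = pipeline("text-classification", model=model_saved, tokenizer=tokenizer_saved)
print(classifier("I saw a movie today and it was really good."))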

How to access the embeddings using tensorflow hub.module?

I am using the following code to access the embeddings using the TF Hub Universal Sentence Encoder.
import tensorflow as tf
import tensorflow_hub as hub

model = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")

def embed(input):
    return model(input)

messages = ["There is no hard limit on how long the paragraph is. Roughly, the longer the more 'diluted' the embedding will be."]
message_embeddings = embed(messages)
How can I access the actual vectors now?
The actual embedding vectors can be accessed from the variable message_embeddings.
message_embeddings is a tensor of shape=(1, 512), meaning the dimensionality of the vector returned by USE-4 is 512.
In other words, every sentence is encoded into a 512-dimensional vector.
The output of print(message_embeddings) is:
tf.Tensor(
[[-0.00366504 -0.00703163 -0.0061244 0.02026021 -0.09436475 0.00027828
0.05004153 -0.01591516 0.088241 0.07551358 -0.01868021 0.04386544
0.00105771 0.03730893 -0.05554571 0.02852311 0.01709696 0.08152976
-0.03092775 0.00683713 -0.08059237 0.042355 -0.07580714 -0.00443942
-0.03430099 0.03240041 -0.05212452 -0.04247908 -0.05534476 -0.02328587
-0.0438301 -0.03972115 0.01639873 0.00163302 0.07708091 -0.02310511
0.01288455 0.04831124 0.0089498 -0.02632253 -0.01840279 0.02118563
0.03758964 0.08740229 0.02880297 -0.00486817 0.0115555 -0.00451289
-0.00162866 0.01446948 0.00189139 -0.07941346 -0.0216493 -0.02580371
-0.00930381 -0.00526039 -0.01272183 0.02215818 0.04742621 0.02226813
0.0110765 -0.01790449 0.01739751 -0.08388933 0.05826297 -0.05230762
-0.07484917 0.06905693 0.01646299 0.00850342 -0.0022191 -0.07555264
0.01601691 0.06028103 0.00524664 0.03776945 -0.05246941 0.03556651
0.06253887 -0.04647287 -0.03415112 -0.03473583 0.04833042 -0.01264609
0.01788526 -0.07143527 -0.02432756 0.04081429 -0.0524265 -0.05402376
-0.02753968 0.06558003 0.01936845 -0.08112626 0.0157347 0.05620547
-0.06219236 -0.03654391 0.03936478 -0.01247254 -0.03957544 0.07394353
-0.06131149 -0.0550663 0.08301188 -0.01699291 0.03726438 0.00248359
-0.00569713 0.04109528 -0.05154289 0.05428214 -0.06594346 0.06009263
0.02753788 0.01492724 -0.01422153 0.02779302 0.02881143 -0.01985389
0.05809831 -0.02661227 -0.06907296 0.01192496 -0.03630216 0.03146286
-0.02979902 0.05192203 -0.0479207 0.03564131 0.05351846 0.02681697
0.02597373 -0.03392426 -0.05286925 -0.05110073 0.01331552 -0.00612995
-0.04932296 -0.0185418 -0.0841584 0.02415963 -0.01051812 0.05603031
-0.0083728 -0.05966095 0.0321536 -0.03968453 0.03799454 -0.05958865
-0.07585841 0.04390398 -0.03674331 0.01918785 0.03446485 -0.04106916
-0.05183128 0.02947152 -0.03531763 0.03698466 0.06261521 -0.00646621
0.01130813 -0.02275244 -0.04280937 0.01955702 -0.03919312 0.00476116
0.01887495 -0.00195181 -0.02401051 -0.06942239 -0.06978329 0.06458326
0.00362934 0.03588834 0.04921037 -0.03195003 0.02806171 -0.0193333
0.00994556 -0.02342404 0.10165592 -0.02853323 0.04147425 0.00914851
0.00497671 0.00073764 -0.00318258 0.03595887 -0.01817959 0.01496308
-0.03551586 0.02536247 -0.07170779 -0.03153825 -0.04042004 -0.01769615
0.00958568 0.00038516 0.00799816 0.04089458 0.02171035 -0.08852603
-0.06747856 0.05664572 -0.06597329 0.02299296 0.03397151 -0.03845559
0.00395073 0.00314357 0.01119022 0.05957965 -0.05583638 0.02908287
0.0112076 0.07695369 -0.03935304 -0.02383705 -0.04208985 -0.00359387
0.06851663 -0.05395376 -0.00246254 -0.01888378 -0.01391678 -0.07573339
0.05811501 0.02059502 -0.00418438 -0.01210096 -0.06286791 -0.07645103
-0.02463043 -0.03153505 0.05593796 -0.02202086 -0.00274707 0.04458077
-0.06263509 0.06126784 -0.04235342 0.00322403 0.02189728 -0.06388599
-0.03919036 -0.00010863 0.02531325 0.02581233 -0.01304512 -0.03001025
-0.02754986 0.0531372 -0.02369525 -0.04376267 0.0641819 0.09532097
-0.06730784 0.04478338 0.02004733 0.05244097 -0.01885018 -0.06137342
-0.08407518 -0.00084469 -0.02145135 -0.0091182 -0.06907462 0.06986497
0.0600312 -0.04390564 -0.00131028 0.06390417 0.03533437 0.03813365
0.04030495 -0.01402102 -0.06857175 -0.06571147 0.01421791 -0.0381003
-0.04138157 0.05040992 -0.05724671 0.01490439 -0.07905842 -0.03806996
-0.01071311 -0.01229521 -0.00771822 -0.03641455 -0.04578875 0.00925799
0.0403841 0.00132017 0.031641 0.01162737 0.0101506 -0.01761867
0.0579349 0.03595775 -0.01147426 -0.01525036 0.05006553 0.03747585
-0.05307707 -0.08915938 0.02942844 -0.05546442 -0.0128964 0.04225868
-0.01534053 -0.04580414 0.01088955 -0.03184818 0.02326705 -0.08861458
-0.07253686 -0.02572111 -0.03711193 0.0474383 -0.05628109 -0.01391787
0.00941848 -0.06177152 -0.06071901 -0.0092127 -0.10220838 -0.01376523
0.03162379 0.03983926 0.00640659 -0.00418033 -0.01612685 0.01891562
-0.04313575 0.01139805 -0.00378637 0.08349139 0.08300766 -0.0494319
-0.03658734 0.00325003 -0.05251636 -0.04457545 -0.079386 -0.05799922
-0.01254137 0.02311826 -0.00766293 -0.06729192 -0.03971054 -0.0663051
0.08720677 0.04582898 -0.08557201 -0.01054355 -0.02762848 0.06243869
-0.08848279 0.02289506 0.05723204 -0.01221769 -0.0393519 -0.00582338
0.02841124 -0.03293297 -0.03143778 -0.00352248 0.0073043 0.01209227
-0.00148794 0.03695554 0.03136331 -0.03311655 -0.0221175 -0.07959055
-0.04138357 -0.00950083 -0.01173625 0.01499144 -0.0121095 0.00823302
0.07642982 0.05198056 0.05955188 0.03240911 0.09211077 -0.05317325
-0.06024589 0.00489183 0.04719653 0.02498623 0.03750401 -0.02352423
0.05042319 -0.01633615 -0.02236294 0.04443104 0.02694818 0.00881322
0.02469178 -0.06206469 -0.00215397 -0.02641553 0.00405129 -0.07184313
-0.02841844 0.0309756 0.02459977 -0.03155032 0.01407542 0.00524732
-0.01893367 0.0102607 -0.00333736 0.02885202 -0.03275619 -0.08507563
0.02076722 -0.02471628 -0.00449985 0.0004644 -0.0923043 0.02101186
0.0352884 0.03790538 -0.00372656 0.06751391 0.02638355 0.01678842
0.03843728 0.10451197 -0.06375936 -0.05324562 0.03276567 -0.01112294
-0.0082361 -0.01735083 -0.03767544 -0.04266915 -0.04767371 0.07573947
-0.01247379 -0.01048137 -0.02308911 -0.01484709 -0.00733855 0.06788232
-0.08163249 -0.01530467 -0.01805264 -0.07910046 -0.06530869 0.07402557
0.06713054 -0.01659747 -0.00980262 0.05586078 0.03396358 -0.06102567
-0.06640005 0.02269907 0.03265672 -0.01353668 -0.08313932 -0.02356159
-0.03383274 0.05942128 -0.08610516 -0.08445066 -0.01306568 -0.05279852
0.00986506 0.00461306 0.08119206 0.00604 0.10107437 0.00191085
-0.05926891 0.01157635 0.0284292 -0.08671403 0.01851062 0.05745851
-0.06798992 0.02700593 0.00208116 -0.00829788 0.08901995 -0.00418414
-0.06217562 -0.07832154 0.02027107 0.06713033 0.04617893 0.05885412
-0.04505047 0.09581003 0.033753 -0.00888314 -0.07608356 -0.03729891
0.02724086 0.02371461 -0.01081131 -0.00809431 -0.04376922 -0.04656423
0.00886904 0.01995739]], shape=(1, 512), dtype=float32)
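If you want the raw values rather than the tf.Tensor wrapper, one option (a small sketch, assuming eager execution as in TF 2.x) is to convert to a NumPy array:
vectors = message_embeddings.numpy()   # convert the tf.Tensor to a NumPy array, shape (1, 512)
first_vector = vectors[0]              # the 512-dimensional embedding of the first message
print(first_vector.shape, first_vector[:5])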
Hope this helps. Happy Learning!

Tensorflow Object Detection Jupyter Notebook no detection

I tried to run the Jupyter Notebook example for TensorFlow object detection (tutorial), but there are no detections. I printed the scores and it seems to run, but the results are very bad. Does anyone have an idea what I might have done wrong?
print(scores[0]):
[ 0.03587309 0.02224856 0.01864638 0.01096715 0.0100315
0.0065446
0.00633551 0.00534311 0.00495995 0.00410238 0.00362363 0.00339175
0.00308251 0.0030337 0.00293387 0.00277085 0.00269581 0.00266825
0.00263924 0.00263331 0.00258721 0.00240822 0.00225823 0.00186966
0.00184308 0.00180467 0.00177474 0.00173643 0.0017281 0.00171935
0.00171891 0.00170284 0.00163754 0.00162967 0.00160267 0.00156545
0.00153614 0.00140936 0.00132406 0.00131524 0.00131041 0.00129431
0.00125819 0.0012553 0.00122365 0.00119179 0.00115673 0.00115186
0.00112368 0.00107096 0.00105803 0.00104337 0.00102719 0.00102337
0.00100349 0.00097767 0.0009685 0.00092741 0.00088506 0.00087696
0.0008734 0.00084825 0.00084135 0.00083512 0.00083396 0.00082068
0.00080583 0.00078979 0.00078059 0.00077475 0.00075449 0.00074426
0.00074421 0.00070195 0.00068741 0.00068138 0.00067261 0.00067125
0.00067032 0.00066041 0.0006473 0.00064205 0.00061964 0.00061793
0.00060834 0.00060468 0.00059547 0.00059478 0.00059461 0.00059436
0.00059426 0.00059411 0.00059406 0.00059392 0.00059365 0.00059351
0.00059191 0.00058798 0.00058682 0.00058148]
[ 0.01044157 0.00982138 0.00942336 0.00846517 0.00613665 0.00398568
0.00357755 0.00300539 0.00255862 0.00236576 0.00232631 0.00220291
0.00185227 0.00163544 0.00159791 0.00145071 0.0014366 0.0014137
0.00122685 0.00118978 0.00108457 0.00104252 0.00099215 0.00096401
0.0008708 0.00084774 0.00080484 0.00078507 0.00078379 0.00076875
0.00072774 0.00071732 0.00071343 0.00070812 0.00069253 0.0006762
0.00067269 0.00059905 0.00059367 0.000588 0.00056114 0.0005504
0.00051472 0.00051055 0.00050973 0.00048484 0.00047297 0.00046204
0.00044787 0.00043259 0.00042987 0.00042673 0.00041978 0.00040494
0.00040087 0.00039576 0.00039059 0.00037274 0.00036828 0.00036417
0.0003612 0.00034645 0.00034479 0.00034078 0.00033771 0.00033605
0.0003333 0.0003304 0.0003294 0.00032325 0.00031787 0.00031773
0.00031748 0.00031741 0.00031732 0.00031729 0.00031724 0.00031722
0.00031717 0.00031708 0.00031702 0.00031579 0.00030416 0.00030222
0.00029739 0.00029726 0.00028289 0.00026527 0.00026325 0.00024584
0.00024221 0.00024156 0.0002391 0.00023335 0.00021617 0.0002001
0.00019127 0.00018342 0.00017271 0.00015507]
I'm running the example with TensorFlow 1.4 and Python 3.5, and I tested the installation as suggested.
I had the same issue. I found in a post that you have to change:
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_08'
To:
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
and it worked fine.
Original answer: https://stackoverflow.com/a/47332228/8954260

R DTW multivariate series with asymmetric step fails to compute alignment

I'm using the DTW implementation found in R along with the Python bindings (rpy2) in order to verify the effects of changing different parameters (like the local constraint, the local distance function and others) for my data. The data represents feature vectors that an audio processing frontend outputs (MFCCs). Because of this I am dealing with multivariate time series; each feature vector has a size of 8. The problem I'm facing is that when I try to use certain local constraints (or step patterns) I get the following error:
Error in if (is.na(gcm$distance)) { : argument is of length zero
Traceback (most recent call last):
File "r_dtw_simplified.py", line 32, in <module>
alignment = R.dtw(canDist, rNull, "Euclidean", stepPattern, "none", True, Fa
lse, True, False )
File "D:\Python27\lib\site-packages\rpy2\robjects\functions.py", line 86, in _
_call__
return super(SignatureTranslatedFunction, self).__call__(*args, **kwargs)
File "D:\Python27\lib\site-packages\rpy2\robjects\functions.py", line 35, in _
_call__
res = super(Function, self).__call__(*new_args, **new_kwargs)
rpy2.rinterface.RRuntimeError: Error in if (is.na(gcm$distance)) { : argument is
of length zero
Because the process of generating and adapting the input data is complicated, I only made a simplified script to illustrate the error I'm receiving.
#data works
#reference = [[-0.126678, -1.541763, 0.29985, 1.719757, 0.755798, -3.594681, -1.492798, 3.493042], [-0.110596, -1.638184, 0.128174, 1.638947, 0.721085, -3.247696, -0.920013, 3.763977], [-0.022415, -1.643539, -0.130692, 1.441742, 1.022064, -2.882172, -0.952225, 3.662842], [0.071259, -2.030411, -0.531891, 0.835114, 1.320419, -2.432281, -0.469116, 3.871094], [0.070526, -2.056702, -0.688293, 0.530396, 1.962128, -1.681915, -0.368973, 4.542419], [0.047745, -2.005127, -0.798203, 0.616028, 2.146988, -1.895874, 0.371597, 4.090881], [0.013962, -2.162796, -1.008545, 0.363495, 2.062866, -0.856613, 0.543884, 4.043335], [0.066757, -2.152969, -1.087097, 0.257263, 2.592697, -0.422424, -0.280533, 3.327576], [0.123123, -2.061035, -1.012863, 0.389282, 2.50206, 0.078186, -0.887711, 2.828247], [0.157455, -2.060425, -0.790344, 0.210419, 2.542114, 0.016983, -0.959274, 1.916504], [0.029648, -2.128204, -1.047318, 0.116547, 2.44899, 0.166534, -0.677551, 2.49231], [0.158554, -1.821365, -1.045044, 0.374207, 2.426712, 0.406952, -1.055084, 2.543762], [0.077026, -1.863235, -1.14827, 0.277069, 2.669067, 0.362549, -1.294342, 1.66748], [0.101822, -1.800293, -1.126801, 0.364594, 2.503815, 0.294846, -0.881302, 1.281616], [0.166138, -1.627762, -0.866013, 0.494476, 2.450668, 0.569, -1.392868, 0.651184], [0.225006, -1.596069, -1.07634, 0.550049, 2.167435, 0.554123, -1.432983, 1.166931], [0.114777, -1.462769, -0.793167, 0.565704, 2.183792, 0.345978, -1.410919, 0.708679], [0.144028, -1.444458, -0.831985, 0.536652, 2.222366, 0.330368, -0.715149, 0.517212], [0.147888, -1.450577, -0.809372, 0.479584, 2.271378, 0.250763, -0.540359, -0.036072], [0.090714, -1.485474, -0.888153, 0.268768, 2.001221, 0.412537, -0.698868, 0.17157], [0.11972, -1.382767, -0.890457, 0.218414, 1.666519, 0.659592, -0.069641, 0.914307], [0.189774, -1.18428, -0.785797, 0.106659, 1.429977, 0.195236, 0.627029, 0.503296], [0.194702, -1.098068, -0.956818, 0.020386, 1.369247, 0.10437, 0.641724, 0.410767], [0.215134, -1.069092, -1.11644, 0.283234, 1.313507, 0.110962, 0.600861, 0.752869], [0.216766, -1.065338, -1.047974, 0.080231, 1.500702, -0.113388, 0.712646, 0.914307], [0.259933, -0.964386, -0.981369, 0.092224, 1.480667, -0.00238, 0.896255, 0.665344], [0.265991, -0.935257, -0.93779, 0.214966, 1.235275, 0.104782, 1.33754, 0.599487], [0.266098, -0.62619, -0.905792, 0.131409, 0.402908, 0.103363, 1.352814, 1.554688], [0.273468, -0.354691, -0.709579, 0.228027, 0.315125, -0.15564, 0.942123, 1.024292], [0.246429, -0.272522, -0.609924, 0.318604, -0.007355, -0.165756, 1.07019, 1.087708], [0.248596, -0.232468, -0.524887, 0.53009, -0.476334, -0.184479, 1.088089, 0.667358], [0.074478, -0.200455, -0.058411, 0.662811, -0.111923, -0.686462, 1.205154, 1.271912], [0.063065, -0.080765, 0.065552, 0.79071, -0.569946, -0.899506, 0.875687, 0.095215], [0.117706, -0.270584, -0.021027, 0.723694, -0.200073, -0.365158, 0.892624, -0.152466], [0.00148, -0.075348, 0.017761, 0.757507, 0.719299, -0.355362, 0.749329, 0.315247], [0.035034, -0.110794, 0.038559, 0.949677, 0.478699, 0.005951, 0.097305, -0.388245], [-0.101944, -0.392487, 0.401886, 1.154938, 0.199127, 0.117371, -0.070007, -0.562439], [-0.083282, -0.388657, 0.449066, 1.505951, 0.46405, -0.566208, 0.216293, -0.528076], [-0.152054, -0.100113, 0.833054, 1.746857, 0.085861, -1.314102, 0.294632, -0.470947], [-0.166672, -0.183777, 0.988373, 1.925262, -0.202057, -0.961441, 0.15242, 0.594421], [-0.234573, -0.227707, 1.102112, 1.802002, -0.382492, -1.153336, 0.29335, 0.074036], [-0.336426, 0.042435, 1.255096, 1.804535, -0.610153, -0.810745, 
1.308441, 0.599854], [-0.359344, 0.007248, 1.344543, 1.441559, -0.758286, -0.800079, 1.0233, 0.668213], [-0.321823, 0.027618, 1.1521, 1.509827, -0.708267, -0.668152, 1.05722, 0.710571], [-0.265335, 0.012344, 1.491501, 1.844971, -0.584137, -1.042419, -0.449188, 0.5354], [-0.302399, 0.049698, 1.440643, 1.674866, -0.626633, -1.158554, -0.906937, 0.405579], [-0.330276, 0.466675, 1.444153, 0.855499, -0.645447, -0.352158, 0.730423, 0.429932], [-0.354721, 0.540207, 1.570786, 0.626648, -0.897446, -0.007416, 0.174042, 0.100525], [-0.239609, 0.669983, 0.978851, 0.85321, -0.156784, 0.107986, 0.915054, 0.114197], [-0.189346, 0.930756, 0.824295, 0.516083, -0.339767, -0.206314, 0.744049, -0.36377]]
#query = [[0.387268, -1.21701, -0.432266, -1.394104, -0.458984, -1.469788, 0.12764, 2.310059], [0.418091, -1.389526, -0.150146, -0.759155, -0.578003, -2.123199, 0.276001, 3.022339], [0.264694, -1.526886, -0.238907, -0.511108, -0.90683, -2.699249, 0.692032, 2.849854], [0.246628, -1.675171, -0.533432, 0.070007, -0.392151, -1.739227, 0.534485, 2.744019], [0.099335, -1.983826, -0.985291, 0.428833, 0.26535, -1.285583, -0.234451, 2.4729], [0.055893, -2.108063, -0.401825, 0.860413, 0.724106, -1.959137, -1.360458, 2.350708], [-0.131592, -1.928314, -0.056213, 0.577698, 0.859146, -1.812286, -1.21669, 2.2052], [-0.162796, -2.149933, 0.467239, 0.524231, 0.74913, -1.829498, -0.741913, 1.616577], [-0.282745, -1.971008, 0.837616, 0.56427, 0.198288, -1.826935, -0.118027, 1.599731], [-0.497223, -1.578705, 1.277298, 0.682983, 0.055084, -2.032562, 0.64151, 1.719238], [-0.634232, -1.433258, 1.760513, 0.550415, -0.053787, -2.188568, 1.666687, 1.611938], [-0.607498, -1.302826, 1.960556, 1.331726, 0.417633, -2.271973, 2.095001, 0.9823], [-0.952957, -0.222076, 0.772064, 2.062256, -0.295258, -1.255371, 3.450974, -0.047607], [-1.210587, 1.00061, 0.036392, 1.952209, 0.470123, 0.231628, 2.670502, -0.608276], [-1.213287, 0.927002, -0.414825, 2.104065, 1.160126, 0.088898, 1.32959, -0.018311], [-1.081558, 1.007751, -0.337509, 1.7146, 0.653687, 0.297089, 1.916733, -0.772461], [-1.064804, 1.284302, -0.393585, 2.150635, 0.132294, 0.443298, 1.967575, 0.775513], [-0.972366, 1.039734, -0.588135, 1.413818, 0.423813, 0.781494, 1.977509, -0.556274], [-0.556381, 0.591309, -0.678314, 1.025635, 1.094284, 2.234711, 1.504013, -1.71875], [-0.063477, 0.626129, 0.360489, 0.149902, 0.92804, 0.936493, 1.203018, 0.264282], [0.162003, 0.577698, 0.956863, -0.477051, 1.081161, 0.817749, 0.660843, -0.428711], [-0.049515, 0.423615, 0.82489, 0.446228, 1.323853, 0.562775, -0.144196, 1.145386], [-0.146851, 0.171906, 0.304871, 0.320435, 1.378937, 0.673004, 0.188416, 0.208618], [0.33992, -2.072418, -0.447968, 0.526794, -0.175858, -1.400299, -0.452454, 1.396606], [0.226089, -2.183441, -0.301071, -0.475159, 0.834961, -2.191864, -1.092361, 2.434814], [0.279556, -2.073181, -0.517639, -0.766479, 0.974808, -2.070374, -2.003891, 2.706421], [0.237961, -1.9245, -0.708435, -0.582153, 1.285934, -1.75882, -2.146164, 2.369995], [0.149658, -1.703705, -0.539749, -0.215332, 1.369705, -1.484802, -1.506256, 1.04126], [0.078735, -1.719543, 0.157013, 0.382385, 1.100998, -0.223755, 0.021683, -0.545654], [0.106003, -1.404358, 0.372345, 1.881165, -0.292511, -0.263855, 1.579529, -1.426025], [0.047729, -1.198608, 0.600769, 1.901123, -1.106949, 0.128815, 1.293701, -1.364258], [0.110748, -0.894348, 0.712601, 1.728699, -1.250381, 0.674377, 0.812302, -1.428833], [0.085754, -0.662903, 0.794312, 1.102844, -1.234283, 1.084442, 0.986938, -1.10022], [0.140823, -0.300323, 0.673508, 0.669983, -0.551453, 1.213074, 1.449326, -1.567261], [0.03743, 0.550293, 0.400909, -0.174622, 0.355301, 1.325867, 0.875854, 0.126953], [-0.084885, 1.128906, 0.292099, -0.248779, 0.722961, 0.873871, -0.409515, 0.470581], [0.019684, 0.947754, 0.19931, -0.306274, 0.176849, 1.431702, 1.091507, 0.701416], [-0.094162, 0.895203, 0.687378, -0.229065, 0.549088, 1.376953, 0.892303, -0.642334], [-0.727692, 0.626495, 0.848877, 0.521362, 1.521912, -0.443481, 1.247238, 0.197388], [-0.82048, 0.117279, 0.975174, 1.487244, 1.085281, -0.567993, 0.776093, -0.381592], [-0.009827, -0.553009, -0.213135, 0.837341, 0.482712, -0.939423, 0.140884, 0.330566], [-0.018127, -1.362335, -0.199265, 1.260742, 0.005188, -1.445068, 
-1.159653, 1.220825], [0.186172, -1.727814, -0.246552, 1.544128, 0.285416, 0.081848, -1.634003, -0.47522], [0.193649, -1.144043, -0.334854, 1.220276, 1.241302, 1.554382, 0.57048, -1.334961], [0.344604, -1.647461, -0.720749, 0.993774, 0.585709, 0.953522, -0.493042, -1.845703], [0.37471, -1.989471, -0.518555, 0.555908, -0.025787, 0.148132, -1.463425, -0.844849], [0.34523, -1.821625, -0.809418, 0.59137, -0.577927, 0.037903, -2.067764, -0.519531], [0.413193, -1.503876, -0.752243, 0.280396, -0.236206, 0.429932, -1.684097, -0.724731], [0.331299, -1.349243, -0.890121, -0.178589, -0.285721, 0.809875, -2.012329, -0.157227], [0.278946, -1.090057, -0.670441, -0.477539, -0.267105, 0.446045, -1.95668, 0.501343], [0.127304, -0.977112, -0.660324, -1.011658, -0.547409, 0.349182, -1.357574, 1.045654], [0.217728, -0.793182, -0.496262, -1.259949, -0.128937, 0.38855, -1.513306, 1.863647], [0.240143, -0.891541, -0.619995, -1.478577, -0.361481, 0.258362, -1.630585, 1.841064], [0.241547, -0.758453, -0.515442, -1.370605, -0.428238, 0.23996, -1.469406, 1.307617], [0.289948, -0.714661, -0.533798, -1.574036, 0.017929, -0.368317, -1.290283, 0.851563], [0.304916, -0.783752, -0.459915, -1.523621, -0.107651, -0.027649, -1.089905, 0.969238], [0.27179, -0.795593, -0.352432, -1.597656, -0.001678, -0.06189, -1.072495, 0.637329], [0.301956, -0.823578, -0.152115, -1.637634, 0.2034, -0.214508, -1.315491, 0.773071], [0.282486, -0.853271, -0.162094, -1.561096, 0.15686, -0.289307, -1.076874, 0.673706], [0.299881, -0.97052, -0.051086, -1.431152, -0.074692, -0.32428, -1.385452, 0.684326], [0.220886, -1.072266, -0.269531, -1.038269, 0.140533, -0.711273, -1.7453, 1.090332], [0.177628, -1.229126, -0.274292, -0.943481, 0.483246, -1.214447, -2.026321, 0.719971], [0.176987, -1.137543, -0.007645, -0.794861, 0.965118, -1.084717, -2.37677, 0.598267], [0.135727, -1.36795, 0.09462, -0.776367, 0.946655, -1.157959, -2.794403, 0.226074], [0.067337, -1.648987, 0.535721, -0.665833, 1.506119, -1.348755, -3.092728, 0.281616], [-0.038101, -1.437347, 0.983917, -0.280762, 1.880722, -1.351318, -3.002258, -0.599609], [-0.152573, -1.146027, 0.717545, -0.60321, 2.126541, -0.59198, -2.282028, -1.048584], [-0.113525, -0.629669, 0.925323, 0.465393, 2.368698, -0.352661, -1.969391, -0.915161], [-0.140121, -0.311951, 0.884262, 0.809021, 1.557693, -0.552429, -1.776062, -0.925537], [-0.189423, -0.117767, 0.975174, 1.595032, 1.284485, -0.698639, -2.007202, -1.307251], [-0.048874, -0.176941, 0.820679, 1.306519, 0.584259, -0.913147, -0.658066, -0.630981], [-0.127594, 0.33313, 0.791336, 1.400696, 0.685577, -1.500275, -0.657959, -0.207642], [-0.044128, 0.653351, 0.615326, 0.476685, 1.099625, -0.902893, -0.154449, 0.325073], [-0.150223, 1.059845, 1.208405, -0.038635, 0.758667, 0.458038, -0.178909, -0.998657], [-0.099854, 1.127197, 0.789871, -0.013611, 0.452805, 0.736176, 0.948273, -0.236328], [-0.250275, 1.188568, 0.935989, 0.34314, 0.130463, 0.879913, 1.669037, 0.12793], [-0.122818, 1.441223, 0.670029, 0.389526, -0.15274, 1.293549, 1.22908, -1.132568]]
#this one doesn't
reference = [[-0.453598, -2.439209, 0.973587, 1.362091, -0.073654, -1.755112, 1.090057, 4.246765], [-0.448502, -2.621201, 0.723282, 1.257324, 0.26619, -1.375351, 1.328735, 4.46991], [-0.481247, -2.29718, 0.612854, 1.078033, 0.309708, -2.037506, 1.056305, 3.181702], [-0.42482, -2.306702, 0.436157, 1.529907, 0.50708, -1.930069, 0.653198, 3.561768], [-0.39032, -2.361343, 0.589294, 1.965607, 0.611801, -2.417084, 0.035675, 3.381104], [-0.233444, -2.281525, 0.703171, 2.17868, 0.519257, -2.474442, -0.502808, 3.569153], [-0.174652, -1.924591, 0.180267, 2.127075, 0.250626, -2.208527, -0.396591, 2.565552], [-0.121078, -1.53801, 0.234344, 2.221039, 0.845367, -1.516205, -0.174149, 1.298645], [-0.18631, -1.047806, 0.629654, 2.073303, 0.775024, -1.931076, 0.382706, 2.278442], [-0.160477, -0.78743, 0.694214, 1.917572, 0.834885, -1.574707, 0.780045, 2.370422], [-0.203659, -0.427246, 0.726486, 1.548767, 0.465698, -1.185379, 0.555206, 2.619629], [-0.208298, -0.393707, 0.771881, 1.646484, 0.612946, -0.996277, 0.658539, 2.499146], [-0.180679, -0.166656, 0.689209, 1.205994, 0.3918, -1.051483, 0.771072, 1.854553], [-0.1978, 0.082764, 0.723541, 1.019104, 0.165405, -0.127533, 1.0522, 0.552368], [-0.171127, 0.168533, 0.529541, 0.584839, 0.702011, -0.36525, 0.711792, 1.029114], [-0.224243, 0.38765, 0.916031, 0.45108, 0.708923, -0.059326, 1.016312, 0.437561], [-0.217072, -0.981766, 1.67363, 1.864014, 0.050812, -2.572815, -0.22937, 0.757996], [-0.284714, -0.784927, 1.720383, 1.782379, -0.093414, -2.492111, 0.623398, 0.629028], [-0.261169, -0.427979, 1.680038, 1.585358, 0.067093, -1.8181, 1.276291, 0.838989], [-0.183075, -0.08197, 1.094147, 1.120392, -0.117752, -0.86142, 1.94194, 0.966858], [-0.188919, 0.121521, 1.277664, 0.90979, 0.114288, -0.880875, 1.920517, 0.95752], [-0.226868, 0.338455, 0.78067, 0.803009, 0.347092, -0.387955, 0.641296, 0.374634], [-0.206329, 0.768158, 0.759537, 0.264099, 0.15979, 0.152618, 0.911636, -0.011597], [-0.230453, 0.495941, 0.547165, 0.137604, 0.36377, 0.594406, 1.168839, 0.125916], [0.340851, -0.382736, -1.060455, -0.267792, 1.1306, 0.595047, -1.544922, -1.6828], [0.341492, -0.325836, -1.07164, -0.215607, 0.895645, 0.400177, -0.773956, -1.827515], [0.392075, -0.305389, -0.885422, -0.293427, 0.993225, 0.66655, -1.061218, -1.730713], [0.30191, -0.339005, -0.877853, 0.153992, 0.986588, 0.711823, -1.100525, -1.648376], [0.303574, -0.491241, -1.000183, 0.075378, 0.686295, 0.752792, -1.192123, -1.744568], [0.315781, -0.629456, -0.996063, 0.224731, 1.074173, 0.757736, -1.170807, -2.08313], [0.313675, -0.804688, -1.00325, 0.431641, 0.685883, 0.538879, -0.988373, -2.421326], [0.267181, -0.790329, -0.726974, 0.853027, 1.369629, -0.213638, -1.708023, -1.977844], [0.304459, -0.935257, -0.778061, 1.042633, 1.391861, -0.296768, -1.562164, -2.014099], [0.169754, -0.792953, -0.481842, 1.404236, 0.766983, -0.29805, -1.587265, -1.25531], [0.15918, -0.9814, -0.197662, 1.748718, 0.888367, -0.880234, -1.64949, -1.359802], [0.028244, -0.772934, -0.186172, 1.594238, 0.863571, -1.224701, -1.153183, -0.292664], [-0.020401, -0.461578, 0.368088, 1.000366, 1.079636, -0.389603, -0.144409, 0.651733], [0.018555, -0.725418, 0.632599, 1.707336, 0.535049, -1.783859, -0.916122, 1.557007], [-0.038971, -0.797668, 0.820419, 1.483093, 0.350494, -1.465073, -0.786453, 1.370361], [-0.244888, -0.469513, 1.067978, 1.028809, 0.4879, -1.796585, -0.77887, 1.888977], [-0.260193, -0.226593, 1.141754, 1.21228, 0.214005, -1.200943, -0.441177, 0.532715], [-0.165283, 0.016129, 1.263016, 0.745514, -0.211288, -0.802368, 0.215698, 
0.316406], [-0.353134, 0.053787, 1.544189, 0.21106, -0.469086, -0.485367, 0.767761, 0.849548], [-0.330215, 0.162704, 1.570053, 0.304718, -0.561172, -0.410294, 0.895126, 0.858093], [-0.333847, 0.173904, 1.56958, 0.075531, -0.5569, -0.259552, 1.276764, 0.749084], [-0.347107, 0.206665, 1.389832, 0.50473, -0.721664, -0.56955, 1.542618, 0.817444], [-0.299057, 0.140244, 1.402924, 0.215363, -0.62767, -0.550461, 1.60788, 0.506958], [-0.292084, 0.052063, 1.463348, 0.290497, -0.462875, -0.497452, 1.280609, 0.261841], [-0.279877, 0.183548, 1.308609, 0.305756, -0.6483, -0.374771, 1.647781, 0.161865], [-0.28389, 0.27916, 1.148636, 0.466736, -0.724442, -0.21991, 1.819901, -0.218872], [-0.275528, 0.309753, 1.192856, 0.398163, -0.828781, -0.268066, 1.763672, 0.116089], [-0.275284, 0.160019, 1.200623, 0.718628, -0.925552, -0.026596, 1.367447, 0.174866], [-0.302795, 0.383438, 1.10556, 0.441833, -0.968323, -0.137375, 1.851791, 0.357971], [-0.317078, 0.22876, 1.272217, 0.462219, -0.855789, -0.294296, 1.593994, 0.127502], [-0.304932, 0.207718, 1.156189, 0.481506, -0.866776, -0.340027, 1.670105, 0.657837], [-0.257217, 0.155655, 1.041428, 0.717926, -0.761597, -0.17244, 1.114151, 0.653503], [-0.321426, 0.292358, 0.73848, 0.422607, -0.850754, -0.057907, 1.462357, 0.697754], [-0.34642, 0.361526, 0.69722, 0.585175, -0.464508, -0.26651, 1.860596, 0.106201], [-0.339844, 0.584229, 0.542603, 0.184937, -0.341263, 0.085648, 1.837311, 0.160461], [-0.32338, 0.661224, 0.512833, 0.319702, -0.195572, 0.004028, 1.046799, 0.233704], [-0.346329, 0.572388, 0.385986, 0.118988, 0.057556, 0.039001, 1.255081, -0.18573], [-0.383392, 0.558395, 0.553391, -0.358612, 0.443573, -0.086014, 0.652878, 0.829956], [-0.420395, 0.668991, 0.64856, -0.021271, 0.511475, 0.639221, 0.860474, 0.463196], [-0.359039, 0.748672, 0.522964, -0.308899, 0.717194, 0.218811, 0.681396, 0.606812], [-0.323914, 0.942627, 0.249069, -0.418365, 0.673599, 0.797974, 0.162674, 0.120361], [-0.411301, 0.92775, 0.493332, -0.286346, 0.165054, 0.63446, 1.085571, 0.120789], [-0.346191, 0.632309, 0.635056, -0.402496, 0.143814, 0.785614, 0.952164, 0.482727], [-0.203812, 0.789261, 0.240433, -0.47699, -0.12912, 0.91832, 1.145493, 0.052002], [-0.048203, 0.632095, 0.009583, -0.53833, 0.232727, 1.293045, 0.308151, 0.188904], [-0.062393, 0.732315, 0.06694, -0.697144, 0.126221, 0.864578, 0.581635, -0.088379]]
query = [[-0.113144, -3.316223, -1.101563, -2.128418, 1.853867, 3.61972, 1.218185, 1.71228], [-0.128952, -3.37915, -1.152237, -2.033081, 1.860199, 4.008179, 0.445938, 1.665894], [-0.0392, -2.976654, -0.888245, -1.613953, 1.638641, 3.849518, 0.034073, 0.768188], [-0.146042, -2.980713, -1.044113, -1.44397, 0.954514, 3.20929, -0.232422, 1.050781], [-0.155029, -2.997192, -1.064438, -1.369873, 0.67688, 2.570709, -0.855347, 1.523438], [-0.102341, -2.686401, -1.029648, -1.00531, 0.950089, 1.933228, -0.526367, 1.598633], [-0.060272, -2.538727, -1.278259, -0.65332, 0.630875, 1.459717, -0.264038, 1.872925], [0.064087, -2.592682, -1.112823, -0.775024, 0.848618, 0.810883, 0.298965, 2.312134], [0.111557, -2.815277, -1.203506, -1.173584, 0.54863, 0.46756, -0.023071, 3.029053], [0.266068, -2.624786, -1.089066, -0.864136, 0.055389, 0.619446, -0.160965, 2.928589], [0.181488, -2.31073, -1.307785, -0.720276, 0.001297, 0.534668, 0.495499, 2.989502], [0.216202, -2.25354, -1.288193, -0.902039, -0.152283, -0.060791, 0.566315, 2.911621], [0.430084, -2.0289, -1.099594, -1.091736, -0.302505, -0.087799, 0.955963, 2.677002], [0.484253, -1.412842, -0.881882, -1.087158, -1.064072, -0.145935, 1.437683, 2.606567], [0.339081, -1.277222, -1.24498, -1.048279, -0.219498, 0.448517, 1.168625, 0.563843], [0.105728, 0.138275, -1.01413, -0.489868, 1.319275, 1.604645, 1.634003, -0.94812], [-0.209061, 1.025665, 0.180405, 0.955566, 1.527405, 0.91745, 1.951233, -0.40686], [-0.136993, 1.332275, 0.639862, 1.277832, 1.277313, 0.361267, 0.390717, -0.728394], [-0.217758, 1.416718, 1.080002, 0.816101, 0.343933, -0.154175, 1.10347, -0.568848]]
import numpy as np
import rpy2.robjects
import rpy2.robjects.numpy2ri
from rpy2.robjects.packages import importr

reference = np.array(reference)
query = np.array(query)

rpy2.robjects.numpy2ri.activate()

# Set up our R namespaces
R = rpy2.robjects.r
rNull = R("NULL")
rprint = rpy2.robjects.globalenv.get("print")
rplot = rpy2.robjects.r('plot')
distConstr = rpy2.robjects.r('proxy::dist')
DTW = importr('dtw')

stepName = "asymmetricP05"
stepPattern = rpy2.robjects.r(stepName)

canDist = distConstr(reference, query, "Euclidean")
alignment = R.dtw(canDist, rNull, "Euclidean", stepPattern, "none", True, False, True, False)
For some series the script doesn't generate the error, but there are some which do; see the commented lines for examples. It is worth noting that for the classic constraint this error does not appear. I am thinking that perhaps I have not set something up correctly, but I am no expert in Python or R, which is why I was hoping that others who have used the R DTW package can help me with this. I am sorry for the long lines for reference and query (the data comes from the MFCCs of a 2-second wav file).
One of the two series is too short to be compatible with the fancy step pattern you chose. Use the common symmetric2 pattern, which does not restrict slopes, before the more exotic ones.
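In the rpy2 script above, that amounts to swapping the step pattern name; a small sketch, with the rest of the setup unchanged:
# symmetric2 imposes no slope restriction, so short queries can still be aligned
stepName = "symmetric2"
stepPattern = rpy2.robjects.r(stepName)

canDist = distConstr(reference, query, "Euclidean")
alignment = R.dtw(canDist, rNull, "Euclidean", stepPattern, "none", True, False, True, False)
rprint(alignment)  # inspect the alignment object, including the global distance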
