Spaces:
Runtime error
Runtime error
| from cmath import isnan | |
| import pytest | |
| import torch | |
| from mmcv import Config | |
| from risk_biased.models.nn_blocks import ( | |
| SequenceDecoderLSTM, | |
| SequenceDecoderMLP, | |
| SequenceEncoderMaskedLSTM, | |
| SequenceEncoderMLP, | |
| AttentionBlock, | |
| ) | |
@pytest.fixture
def params():
    """Provide the shared hyper-parameters requested by the tests in this module.

    Returns:
        An ``mmcv`` ``Config`` holding the batch size, feature dimensions,
        number of attention heads / hidden layers, and device string.

    The ``@pytest.fixture`` decorator is what lets pytest inject this value
    into test functions that declare a ``params`` argument. Without it, those
    tests error out with "fixture 'params' not found".

    The default (function) scope is deliberate: ``torch.manual_seed(0)`` then
    runs before *each* test. The random tensors a test draws therefore do not
    depend on which other tests ran before it.
    """
    torch.manual_seed(0)
    cfg = Config()
    cfg.batch_size = 4
    cfg.input_dim = 10
    # output_dim, latent_dim and device are not read by the tests visible
    # here; kept for any other consumer of this fixture.
    cfg.output_dim = 15
    cfg.latent_dim = 3
    cfg.h_dim = 32
    cfg.num_attention_heads = 4
    cfg.num_h_layers = 2
    cfg.device = "cpu"
    return cfg
def test_AttentionBlock(params):
    """AttentionBlock returns a (batch, agents, h_dim) tensor free of NaNs.

    Agent features attend over randomly generated map features. Both masks
    are random boolean tensors that are True with probability ~0.9
    (uniform draws compared against 0.1). Which polarity means "valid" is
    defined by AttentionBlock, not by this test.
    """
    n_agents, n_map = 4, 8
    batch, hidden = params.batch_size, params.h_dim

    # Build the module before drawing inputs: its weight init consumes the
    # RNG seeded in ``params``, so this ordering fixes every value below.
    block = AttentionBlock(hidden, params.num_attention_heads)

    # Draw order must stay exactly as-is for reproducible tensors.
    agents = torch.rand(batch, n_agents, hidden)
    agents_mask = torch.rand(batch, n_agents) > 0.1
    agents_abs = torch.rand(batch, n_agents, hidden)
    map_feats = torch.rand(batch, n_map, hidden)
    map_mask = torch.rand(batch, n_map) > 0.1

    result = block(agents, agents_mask, agents_abs, map_feats, map_mask)

    expected_shape = (batch, n_agents, hidden)
    assert result.shape == expected_shape
    assert not torch.isnan(result).any()
def test_SequenceDecoder(params):
    """SequenceDecoderLSTM expands per-agent features along a time axis.

    A (batch, agents, h_dim) tensor decoded for ``sequence_length`` steps
    must come back as (batch, agents, sequence_length, h_dim) with no NaNs.
    """
    decoder = SequenceDecoderLSTM(params.h_dim)
    num_agents = 8
    sequence_length = 16
    # Named ``inputs`` (not ``input``) to avoid shadowing the builtin.
    inputs = torch.rand(params.batch_size, num_agents, params.h_dim)
    output = decoder(inputs, sequence_length)
    assert output.shape == (
        params.batch_size,
        num_agents,
        sequence_length,
        params.h_dim,
    )
    assert not torch.isnan(output).any()
def test_SequenceDecoderMLP(params):
    """SequenceDecoderMLP maps (batch, agents, h_dim) to (batch, agents, T, h_dim).

    ``T`` (``sequence_length``) is passed both to the constructor and to the
    forward call. The output must have the expected shape and contain no NaNs.
    """
    sequence_length = 16
    # NOTE(review): the trailing positional ``True`` is a flag whose meaning
    # is defined by SequenceDecoderMLP; pass it by keyword once confirmed.
    decoder = SequenceDecoderMLP(
        params.h_dim, params.num_h_layers, sequence_length, True
    )
    num_agents = 8
    # Named ``inputs`` (not ``input``) to avoid shadowing the builtin.
    inputs = torch.rand(params.batch_size, num_agents, params.h_dim)
    output = decoder(inputs, sequence_length)
    assert output.shape == (
        params.batch_size,
        num_agents,
        sequence_length,
        params.h_dim,
    )
    assert not torch.isnan(output).any()
def test_SequenceEncoder(params):
    """SequenceEncoderMaskedLSTM collapses the time axis of a masked sequence.

    Input is (batch, agents, sequence_length, input_dim) with a boolean
    (batch, agents, sequence_length) mask that is True with probability ~0.9.
    The output must be (batch, agents, h_dim) and contain no NaNs.
    """
    encoder = SequenceEncoderMaskedLSTM(params.input_dim, params.h_dim)
    num_agents = 8
    sequence_length = 16
    # Named ``inputs`` (not ``input``) to avoid shadowing the builtin.
    inputs = torch.rand(
        params.batch_size, num_agents, sequence_length, params.input_dim
    )
    mask_input = torch.rand(params.batch_size, num_agents, sequence_length) > 0.1
    output = encoder(inputs, mask_input)
    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()
def test_SequenceEncoderMLP(params):
    """SequenceEncoderMLP collapses the time axis of a masked sequence.

    Input is (batch, agents, sequence_length, input_dim) with a boolean
    (batch, agents, sequence_length) mask that is True with probability ~0.9.
    The output must be (batch, agents, h_dim) and contain no NaNs.
    """
    sequence_length = 16
    num_agents = 8
    # NOTE(review): the trailing positional ``True`` is a flag whose meaning
    # is defined by SequenceEncoderMLP; pass it by keyword once confirmed.
    encoder = SequenceEncoderMLP(
        params.input_dim, params.h_dim, params.num_h_layers, sequence_length, True
    )
    # Named ``inputs`` (not ``input``) to avoid shadowing the builtin.
    inputs = torch.rand(
        params.batch_size, num_agents, sequence_length, params.input_dim
    )
    mask_input = torch.rand(params.batch_size, num_agents, sequence_length) > 0.1
    output = encoder(inputs, mask_input)
    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()