
AIPEAC/PyTorch-PredNet


PredNet PyTorch Implementation

References


Implementation

Directory

Training

  • File: mnist_train_all.py
  • Parameters:
    • num_epochs: 6
    • batch_size: 2
    • lr: 0.001
    • nt: 20 (sequence length)
    • n_train_seq: 7000
    • n_val_seq: 1000
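One way to keep these settings together and see how much work one run involves is to put them in a single config object. The sketch below is hypothetical: the `TrainConfig` dataclass and its helper methods are illustrative and are not taken from mnist_train_all.py, which may use module-level constants or argparse instead. It also assumes that a final partial batch still counts as one step. With these values that makes no difference, because 7000 and 1000 are both divisible by 2.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical bundle of the training settings listed for mnist_train_all.py."""
    num_epochs: int = 6
    batch_size: int = 2
    lr: float = 0.001
    nt: int = 20            # frames per sequence
    n_train_seq: int = 7000
    n_val_seq: int = 1000

    def steps_per_epoch(self) -> int:
        # Ceiling division: a final partial batch would still be one optimizer step
        return -(-self.n_train_seq // self.batch_size)

    def val_steps(self) -> int:
        return -(-self.n_val_seq // self.batch_size)

    def total_steps(self) -> int:
        return self.num_epochs * self.steps_per_epoch()

    def frames_per_epoch(self) -> int:
        # Every training sequence contributes nt frames
        return self.n_train_seq * self.nt


cfg = TrainConfig()
print(cfg.steps_per_epoch())   # 3500
print(cfg.total_steps())       # 21000
print(cfg.frames_per_epoch())  # 140000
```

With batch_size 2, one epoch is 3500 optimizer steps, so the full 6-epoch run is 21000 steps at lr 0.001.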

Dataset

  • File: mnist_data.py
  • Parameters:
    • nt: 20 (frames per sequence)
    • image_size: (64, 64)
    • channels: 3 (RGB)
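Moving MNIST frames are greyscale, so 3-channel input has to be produced somehow. The sketch below shows one plausible way, under stated assumptions, to turn raw frames into sequences of shape (nt, 3, 64, 64). It assumes the raw data is a uint8 array laid out as (time, sequence, height, width), which is the layout of the widely distributed Moving MNIST `.npy` file. It also assumes the grey channel is simply copied three times. Both `to_rgb_sequence` and `split_indices` are hypothetical helpers, and mnist_data.py may do this differently.

```python
import numpy as np


def to_rgb_sequence(raw: np.ndarray, index: int, nt: int = 20) -> np.ndarray:
    """Return sequence `index` as float32 with shape (nt, 3, H, W), values in [0, 1].

    Assumes `raw` is uint8 with shape (T, N, H, W): time first, then sequence index.
    """
    frames = raw[:nt, index]                     # (nt, H, W)
    frames = frames.astype(np.float32) / 255.0   # scale pixels to [0, 1]
    frames = frames[:, None, :, :]               # add channel axis -> (nt, 1, H, W)
    return np.repeat(frames, 3, axis=1)          # copy grey channel into 3 channels


def split_indices(n_train: int = 7000, n_val: int = 1000):
    """Hypothetical split: first n_train sequences for training, next n_val for validation."""
    train = list(range(n_train))
    val = list(range(n_train, n_train + n_val))
    return train, val


# Small synthetic stand-in for the real data: 20 frames x 10 sequences of 64x64 pixels
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(20, 10, 64, 64), dtype=np.uint8)
seq = to_rgb_sequence(raw, index=0)
print(seq.shape)  # (20, 3, 64, 64)
```

In a PyTorch `Dataset`, the returned array could be wrapped with `torch.from_numpy`. The batch dimension would then give shape (batch_size, nt, 3, 64, 64).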

Model Hyperparameters

  • loss_mode: 'L_0' or 'L_all' (default: 'L_all')
  • peephole: False
  • lstm_tied_bias: False
  • gating_mode: 'mul' or 'sub' (default: 'mul')
  • A_channels and R_channels: (3, 48, 96, 192) (channels per layer, starting from the 3-channel input)
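To make `loss_mode` and the channel tuple concrete, the sketch below follows the conventions of the original PredNet paper and its reference Keras code. This repository may differ, so treat these as assumptions:

- **Layer weights.** `'L_0'` weights the layers [1, 0, 0, 0], so only the pixel-level error counts. `'L_all'` weights them [1, 0.1, 0.1, 0.1].
- **Time weights.** The first frame gets zero weight and each remaining frame gets 1/(nt - 1).
- **Error channels.** Each error map E stacks ReLU(A - Ahat) and ReLU(Ahat - A), so it has twice the channels of A.
- **Resolution.** Height and width halve between layers through 2x2 pooling.

All function names here are illustrative.

```python
from typing import List, Sequence, Tuple


def layer_weights(loss_mode: str, n_layers: int) -> List[float]:
    """Per-layer loss weights (PredNet-paper convention, assumed here)."""
    if loss_mode == "L_0":
        return [1.0] + [0.0] * (n_layers - 1)   # only the lowest layer's error
    if loss_mode == "L_all":
        return [1.0] + [0.1] * (n_layers - 1)   # higher layers contribute at 0.1
    raise ValueError(f"unknown loss_mode: {loss_mode!r}")


def time_weights(nt: int) -> List[float]:
    """Zero weight on the first frame (no history yet), equal weight on the rest."""
    return [0.0] + [1.0 / (nt - 1)] * (nt - 1)


def weighted_loss(errors: Sequence[Sequence[float]], loss_mode: str = "L_all") -> float:
    """Combine errors[t][layer] (mean error of that layer at step t) into one scalar."""
    nt, n_layers = len(errors), len(errors[0])
    lw, tw = layer_weights(loss_mode, n_layers), time_weights(nt)
    return sum(tw[t] * lw[layer] * errors[t][layer]
               for t in range(nt) for layer in range(n_layers))


def layer_shapes(a_channels=(3, 48, 96, 192),
                 image_size=(64, 64)) -> List[Tuple[int, int, int, int]]:
    """(A channels, E channels, height, width) per layer, assuming 2x2 pooling between layers."""
    h, w = image_size
    shapes = []
    for a in a_channels:
        shapes.append((a, 2 * a, h, w))  # E concatenates positive and negative errors
        h, w = h // 2, w // 2
    return shapes


errors = [[1.0] * 4 for _ in range(20)]   # toy case: unit error everywhere, nt = 20
print(round(weighted_loss(errors, "L_0"), 6))    # 1.0
print(round(weighted_loss(errors, "L_all"), 6))  # 1.3
print(layer_shapes())
# [(3, 6, 64, 64), (48, 96, 32, 32), (96, 192, 16, 16), (192, 384, 8, 8)]
```

Because the time weights sum to 1, a uniform unit error gives 1.0 under `'L_0'`. Under `'L_all'` it gives 1 + 3 × 0.1 = 1.3, which shows how much the higher layers add to the objective.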

Output

Results

  • Example predicted frames: [image]

License

This project is licensed under the Apache 2.0 License.

Third-Party Notices

This project is derived from the following works (see NOTICE for full attribution):

About

Implements PredNet in PyTorch, tested on the Moving MNIST dataset.
