Interaction Modeling with Multiplex Attention

NeurIPS 2022

Multi-agent systems are often guided by a variety of interactions. In this work, we propose a forecasting model for multi-agent systems that infers interactions in the form of graphs, entirely through the task of predicting agent trajectories.


Modeling multi-agent systems requires understanding how agents interact. Such systems are often difficult to model because they can involve several types of interactions that layer together to drive rich social behavioral dynamics. Leading approaches to modeling multi-agent systems use graph neural networks (GNNs) to infer edge types for every pair of entities in the interacting system. However, these approaches do not explicitly handle the multiple layers of interactions present in social multi-agent systems, which, as shown empirically, leads to at least two shortcomings: reduced performance on long-term predictions and decreased interpretability.

We present Interaction Modeling with Multiplex Attention (IMMA), a forward prediction model with:
  • a multiplex latent graph to represent multiple independent types of interactions
  • attention to account for relations of different strengths
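The two ingredients above can be illustrated with a minimal sketch. It assumes that each layer of the multiplex graph is obtained by a softmax over an agent's raw pairwise scores toward the other agents; the score tensors, the function names `multiplex_attention` and `aggregate`, and the toy inputs are all hypothetical and are not taken from the IMMA implementation.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a non-empty list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def multiplex_attention(scores_per_layer):
    """Turn raw pairwise scores into one attention graph per layer.

    scores_per_layer[k][i][j] is a hypothetical raw score for how strongly
    agent i relates to agent j in layer k. Self-edges are excluded, so at
    least two agents are required. In the result, graphs[k][i] sums to 1.
    """
    graphs = []
    for scores in scores_per_layer:
        n = len(scores)
        layer = []
        for i in range(n):
            others = [j for j in range(n) if j != i]
            weights = softmax([scores[i][j] for j in others])
            row = [0.0] * n
            for j, w in zip(others, weights):
                row[j] = w
            layer.append(row)
        graphs.append(layer)
    return graphs

def aggregate(graphs, features):
    """Attention-weighted sum of the other agents' features, per layer."""
    dim = len(features[0])
    messages = []
    for layer in graphs:
        layer_msgs = []
        for row in layer:
            msg = [sum(row[j] * features[j][d] for j in range(len(features)))
                   for d in range(dim)]
            layer_msgs.append(msg)
        messages.append(layer_msgs)
    return messages

# Toy example: two layers over three agents. In layer 0, agent 0 mostly
# attends to agent 1; in layer 1, agent 0 mostly attends to agent 2.
raw_scores = [
    [[0.0, 3.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.0, 0.0, 3.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
]
positions = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
graphs = multiplex_attention(raw_scores)
messages = aggregate(graphs, positions)
```

Keeping one normalized graph per layer lets each layer capture a different type of relation, while the continuous weights within a layer express relations of different strengths.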
We also introduce Progressive Layer Training (PLT), a strategy for learning a set of good graph bases: the model first learns the high-level, most consequential interactions and is then progressively grown to model lower-level, more intricate interactions. PLT decreases the information dependency between latent graph layers and, unlike adding a disentanglement loss, does not reduce forecasting accuracy.
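The progressive growth described for PLT can be sketched as a stage-by-stage training plan. This is an illustration under stated assumptions, not the authors' procedure: it assumes that earlier layers are frozen when a new layer is added, which the text above does not specify, and the function `plt_schedule` and its parameters are hypothetical.

```python
def plt_schedule(num_layers, epochs_per_stage):
    """Build a hypothetical stage-by-stage plan for progressive layer training.

    Stage s (1-based) uses the first s latent graph layers. Only the newest
    layer is trainable; freezing the earlier layers is an assumption made
    for this sketch.
    """
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")
    stages = []
    for stage in range(1, num_layers + 1):
        stages.append({
            "stage": stage,
            "active_layers": list(range(stage)),      # layers used in the forward pass
            "frozen_layers": list(range(stage - 1)),  # learned in earlier stages
            "trainable_layers": [stage - 1],          # the newly added layer
            "epochs": epochs_per_stage,
        })
    return stages

# Example: grow a three-layer multiplex graph one layer at a time.
schedule = plt_schedule(num_layers=3, epochs_per_stage=50)
```

Under this plan the first layer is fit alone and therefore absorbs the most consequential interactions, while each later layer only has to explain what the earlier layers left unmodeled.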

Relational Prediction

Visualization of the latent graph and agent trajectories in the Social Navigation Environment. The leftmost column shows the ground truth trajectories and the ground truth graph used to simulate them. With RFM, the relational prediction for the red agent is inaccurate: in the row highlighted by the arrow, the green agent is incorrectly given a higher weight than the blue agent. As a result, the predicted trajectories deviate from the ground truth, especially on long-horizon predictions.

Controlled Generation

To further understand how our model uses the inferred social graph, we ask our decoder to generate new trajectories conditioned on a modified latent graph. Changing the leader of an agent clearly alters that agent's predicted trajectory to target the new leader while keeping the predictions for the other agents intact. In contrast, the trajectories generated by the baseline contain unrealistic turns (red agent 0), and its predictions for the other agents deteriorate at the same time.
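The graph edit behind this experiment can be sketched as follows. The sketch assumes a latent graph layer stored as a row-normalized matrix in which row i holds the weights agent i places on the other agents; the helper `set_leader` is hypothetical, and the decoder that consumes the edited graph is not shown.

```python
def set_leader(graph_layer, follower, new_leader):
    """Return a copy of one latent graph layer in which `follower`
    places all of its weight on `new_leader`.

    Rows of the other agents are copied unchanged, so only the
    follower's relations are altered.
    """
    n = len(graph_layer)
    if not (0 <= follower < n and 0 <= new_leader < n):
        raise IndexError("agent index out of range")
    if follower == new_leader:
        raise ValueError("an agent cannot lead itself")
    edited = [row[:] for row in graph_layer]
    edited[follower] = [1.0 if j == new_leader else 0.0 for j in range(n)]
    return edited

# Hypothetical inferred layer for three agents: agent 0 mostly follows agent 1.
inferred = [
    [0.0, 0.9, 0.1],
    [0.5, 0.0, 0.5],
    [0.3, 0.7, 0.0],
]
# Make agent 0 follow agent 2 instead; rows 1 and 2 stay as inferred.
edited = set_leader(inferred, follower=0, new_leader=2)
```

Because only one row changes, any difference in the decoded trajectories of the other agents would come from the decoder rather than from the edit itself.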

Additional Visualizations


IMMA (ours)




Prediction results visualized for the PHASE dataset and the NBA dataset. Dotted lines: past trajectories. Solid lines: ground truth future trajectories. Circles: predicted future trajectories. Our model achieves the best performance in predicting future trajectories.

Please send any questions to Fan-Yun Sun.